import numpy as np
import matplotlib.pyplot as plt
rng = np.random.default_rng(42)
A Blog Entry on Bayesian Computation by an Applied Mathematician
1 Introduction
1.1 A Time-Agnostic Learning Framework
The absorbing process, a.k.a. masked diffusion, has a unique characteristic as a forward process in a discrete denoising diffusion model: it offers a time-agnostic 'learning to unmask' training framework.
When the state space is \(E^d\) where \(E=\{0,\cdots,K-1\}\) is finite, a current practice is to learn a neural network \(p_\theta\) based on a loss given by \[ \mathcal{L}(\theta):=\int^1_0\frac{\dot{\alpha}_t}{1-\alpha_t}\operatorname{E}\left[\sum_{i=1}^d\log p_\theta(X_0^i|X_t)\right]\,dt, \tag{1}\] where \(\alpha_t\) is a noising schedule, determining the convergence speed of the forward process.1
The expectation in (1) is exactly a cross-entropy loss. Therefore, the loss (1) can be understood as a weighted cross-entropy loss, with time weights \(\frac{\dot{\alpha}_t}{1-\alpha_t}\) determined by the noising schedule \(\alpha_t\).
Note that \(p_\theta\) predicts the true state \(x_0\), based on the current state \(x_t\), some components of which might be masked. Hence we call this framework 'learning to unmask'.
Note also that \(p_\theta\) doesn’t take \(t\) as an argument (Shi et al., 2024), (Ou et al., 2025), which we call the time-agnostic property, following (Zheng et al., 2025).
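To make the objective concrete, here is a minimal NumPy sketch (my own illustration, not code from the cited works) of a single Monte Carlo term of (1) under the linear schedule \(\alpha_t=1-t\). Here p_theta is a hypothetical stand-in for the trained predictor, i.e. any function mapping a partially masked \(x_t\) to a \((d,K)\) array of categorical probabilities for \(x_0\).
def loss_term(x0, p_theta, rng, K=5, MASK=-1):
    """One Monte Carlo sample of the weighted cross-entropy loss (1) with alpha_t = 1 - t."""
    d = x0.shape[0]
    t = rng.random()                     # t ~ Uniform(0, 1)
    keep = rng.random(d) < (1.0 - t)     # each component survives masking with probability alpha_t
    x_t = np.where(keep, x0, MASK)       # forward masking
    log_probs = np.log(p_theta(x_t))     # (d, K); note: p_theta is time-agnostic, no t argument
    log_lik = log_probs[np.arange(d), x0]
    weight = -1.0 / t                    # = alpha_t' / (1 - alpha_t); blows up as t -> 0 (cf. Section 1.2)
    return weight * log_lik.sum()
For instance, loss_term(np.array([0, 2, 1]), lambda x_t: np.full((3, 5), 0.2), np.random.default_rng(0)) evaluates one term with a context-free uniform predictor; averaging many such terms gives a Monte Carlo estimate of (1).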
1.2 Choice of \(\alpha_t\)
This ‘learning to unmask’ task might be very hard when \(t\) is near \(1\), since most of the \(x_t^i\)’s are still masked.
For instance, if we choose a linear schedule \[ \alpha_t=1-t\qquad(t\in[0,1]), \] the scalar \(\frac{\dot{\alpha}_t}{1-\alpha_t}=-t^{-1}\) in front of the expectation in (1) puts less weight on large \(t\approx1\) and much more weight on small \(t\approx0\), where most of the \(x_t^i\)'s should already be unmasked.
Hence, selecting the schedule \(\alpha_t\) to make learning easier can be very effective, for example by reducing the variance of the gradient estimator in an SGD algorithm. Actually, this is part of how (Arriola et al., 2025) achieved their remarkable empirical results.
This flexibility of \(\alpha_t\) is why the loss (1) is considered a potential competitor to the currently dominant autoregressive models. In fact, one work (Chao et al., 2025), still under review, claimed that their masked diffusion model surpassed the autoregressive model on the task of language modeling, achieving an evaluation perplexity of 15.36 on the OpenWebText dataset.2
We briefly discuss their trick and related promising techniques to improve the model, before programming toy examples in Section 2 and Section 3 to deepen our understanding of the absorbing forward process.
1.3 A Gibbs Sampler Take
One problem about the loss (1) is that \(p_\theta\) predicts the unmasked complete sequence in a product form: \[ p_\theta(x_0|x_t)=\prod_{i=1}^d p_\theta(x_0^i|x_t). \]
This should cause no problem when we unmask one component at a time, since it is then a form of ancestral sampling based on the disintegration property.
However, when unmasking two or more components simultaneously, for example when the number of steps is less than \(d\), the product form assumption simply introduces a bias, as the data distribution is generally not of product form on \(E^d\).
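As a minimal illustration with \(d=2\) and both components still masked, a perfectly trained product-form predictor can at best reproduce the marginals, \[ p_\theta(x^1,x^2\mid x_t)=p_\theta(x^1\mid x_t)\,p_\theta(x^2\mid x_t)\approx\pi_{\text{data}}(x^1)\,\pi_{\text{data}}(x^2)\neq\pi_{\text{data}}(x^1,x^2),\qquad x_t=(\text{m},\text{m}), \] whereas unmasking \(x^1\) first and then \(x^2\) conditionally on the revealed \(x^1\) recovers \(\pi_{\text{data}}(x^1)\,\pi_{\text{data}}(x^2|x^1)=\pi_{\text{data}}(x^1,x^2)\). This is exactly the bias we will observe numerically in Section 3.4.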
Here, to recover correctness asymptotically, the analogy with Gibbs sampling becomes very important.
For example, the predictor-corrector technique can be readily employed to mitigate this bias, as discussed in (Lezama et al., 2023), (S. Zhao et al., 2024), (L. Zhao et al., 2024), (Gat et al., 2024), (Wang et al., 2025).
We demonstrate this strategy in Section 3.5.
1.4 Intermediate States
As we mentioned earlier, unmasking can be a very hard task, as closely investigated in (Kim et al., 2025, sec. 3).
To alleviate this problem, (Chao et al., 2025) introduced intermediate states by re-encoding the token in a base-\(2\) encoding, such as \[ 5\mapsto 101. \] The right-hand side needs three steps to be completely unmasked, while the left-hand side only needs one jump.
Therefore, unmasking can be easier to learn, compared to the original token encoding.
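As a tiny sketch of the re-encoding (my own illustration; the actual tokenization in (Chao et al., 2025) may differ), a token value is split into \(\lceil\log_2 K\rceil\) binary sub-tokens, each of which can be masked and unmasked separately:
def to_bits(x: int, n_bits: int):
    """Split a token value into n_bits binary sub-tokens (most significant bit first)."""
    return [(x >> b) & 1 for b in reversed(range(n_bits))]

print(to_bits(5, 3))   # [1, 0, 1] -- three sub-tokens instead of one opaque token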
However, this is not the only advantage of intermediate states. (Chao et al., 2025) were able to construct a full predictor \(p_\theta\) without the product form assumption on each token.
This approach might act as a block Gibbs sampler and make the convergence faster, as we will discuss later in Section 3.6.
1.5 State-Dependent Rate
As long as \(\alpha_t\) is a fixed function of \(t\) alone, different choices of \(\alpha_t\) seem to have little impact on the overall performance of the model, as we observe in our toy example in Section 2.4.
However, if we allow \(\alpha_t\) to depend on the state as in (Shi et al., 2024, sec. 6), I believe the masked diffusion model will start to show its real potential over the currently dominant autoregressive framework.
A problem arises when one tries to learn \(\alpha_t\) at the same time, for example, by including a corresponding term into the loss (1). This will lead to amplified variances of the gradient estimates and unstable training, as reported in (Shi et al., 2024) and (Arriola et al., 2025).
The idea of exploiting a state-dependent rate is already very common at sampling time (Peng et al., 2025), (Liu et al., 2025), (Kim et al., 2025), (Rout et al., 2025), where it determines which token to unmask next during backward sampling, a bit like Monte Carlo tree search in reinforcement learning.
2 Demo
To demonstrate what an absorbing process does, we carry out generation from a toy data distribution \(\pi_{\text{data}}\) on \(5=\{0,1,2,3,4\}\), by running an exact reverse kernel of the absorbing (masked) forward process.
Therefore, no neural network training will be involved. A 2d example in Section 3, which is much more interesting, will basically proceed in parallel.
2.1 Setup
p_data = np.array([0.40, 0.30, 0.18, 0.10, 0.02], dtype=float)
We will represent the MASK as \(-1\). The state space is then \(E:=5 \cup \{-1\}\).
MASK = -1
An important design choice in the forward process is the noising schedule \(\alpha_t\), which can be interpreted as a survival probability and satisfies the following relationship with the jump intensity \(\beta_t\): \[ \alpha_t=\exp\left(-\int^t_0\beta_s\,ds\right),\qquad t\in[0,1]. \]
Let us keep it simple and set \(\alpha_t=1-t\). To achieve this, we need to set \[ \beta_t=\frac{1}{1-t},\qquad t\in[0,1), \] which clearly diverges as \(t\to1\); this divergence is what ensures that the forward process converges (i.e., is fully masked) in finite time.
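As a quick numerical sanity check of the relationship above (my own addition, not part of the demo proper), we can integrate \(\beta_t=1/(1-t)\) numerically and confirm that \(\exp(-\int_0^t\beta_s\,ds)\) indeed reproduces \(\alpha_t=1-t\):
ts = np.linspace(0.0, 0.9, 1000)
betas = 1.0 / (1.0 - ts)
cum = np.concatenate([[0.0], np.cumsum(0.5 * (betas[1:] + betas[:-1]) * np.diff(ts))])  # trapezoid rule
print(np.max(np.abs(np.exp(-cum) - (1.0 - ts))))   # tiny discretization error (of order 1e-5)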
T = 10 # number of steps
alpha = np.linspace(1.00, 0.00, T+1)
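Before moving on to the backward kernel, here is a small forward-process sanity check (my own addition, using a separate random generator so the outputs below are unaffected): masking each surviving coordinate with probability \(1-\alpha_k/\alpha_{k-1}\) at step \(k\) yields \(\operatorname{P}[X_k=\text{MASK}]=1-\alpha_k\), matching the survival-probability interpretation.
rng_check = np.random.default_rng(0)
x_fwd = rng_check.choice(5, size=50_000, p=p_data)   # x_0 ~ p_data
masked_frac = [0.0]
for k in range(1, T + 1):
    surv = alpha[k] / alpha[k-1] if alpha[k-1] > 0 else 0.0
    hit = (x_fwd != MASK) & (rng_check.random(x_fwd.size) >= surv)
    x_fwd[hit] = MASK                                # absorbed: once masked, stays masked
    masked_frac.append((x_fwd == MASK).mean())
print(np.round(masked_frac, 3))                      # close to 1 - alpha = [0.0, 0.1, ..., 1.0]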
2.2 The Backward Transition Kernel
In this setting, the backward transition kernel \(p(x_{t-1} | x_t)\) satisfies \[ \operatorname{P}[X_{t-1}=-1|X_t=-1]=\frac{\operatorname{P}[X_{t-1}=-1]}{\operatorname{P}[X_t=-1]}=\frac{1 - \alpha_{t-1}}{1 - \alpha_t}, \] since the mask state is absorbing in the forward process, so that \(\{X_{t-1}=-1\}\subset\{X_t=-1\}\). In the other cases, the unmasked values \(x_{t-1}\) should be drawn according to \(\pi_{\text{data}}\), which is of course unavailable in a real setting.
p_unmask = (alpha[:-1] - alpha[1:]) / (1.0 - alpha[1:]) # length T
def reverse_sample(num_samples: int, p_unmask: np.ndarray):
    """
    Start from x_T = MASK for all samples, apply the exact reverse transitions down to t=0.
    Returns x_0 samples in 5 = {0,1,...,4}.
    """
    x_t = np.full(num_samples, MASK, dtype=int)
    hist = np.empty((T+1, num_samples), dtype=int)
    hist[0] = x_t.copy()
    for t in range(T, 0, -1):
        idx_mask = np.where(x_t == MASK)[0]  # masked indices
        if idx_mask.size > 0:
            u = rng.random(idx_mask.size)
            unmask_now = idx_mask[u < p_unmask[t-1]]  # indices that are going to be unmasked
            if unmask_now.size > 0:
                cats = rng.choice(5, size=unmask_now.size, p=p_data)
                x_t[unmask_now] = cats
        hist[T-t+1] = x_t.copy()
    # At t=0, all remaining MASKs (if any) must have already unmasked earlier with probability 1,
    # but numerically we ensure no MASK remains:
    assert np.all(x_t != MASK), "Some samples remained MASK at t=0, which should not happen."
    return x_t, hist
2.2.1 A Note on Alternative Sampling Strategies
Note that we need not obey this exact backward transition kernel in order to sample from the data distribution.
For example, remasking (Lezama et al., 2023), (S. Zhao et al., 2024), (Gat et al., 2024), (Wang et al., 2025), a form of predictor-corrector sampling, can be incorporated to improve sample quality, mitigating numerical errors, as we will see in Section 3.5.
Recently, sampling-time path planning (Peng et al., 2025), (Liu et al., 2025), (Kim et al., 2025), (Rout et al., 2025) has been proposed to improve sample quality and model log-likelihood, but this lies outside the scope of this post.
2.3 Honest Sampling
N = 100_000 # size of sample to get
x0_samples, hist = reverse_sample(N, p_unmask)
We first make sure our implementation is correct by checking that the empirical distribution of the generated samples agrees with the true distribution.
counts = np.bincount(x0_samples, minlength=5).astype(float)
p_emp = counts / counts.sum()

print("Toy data marginal p_data:", p_data.round(4))
print("Empirical p after reverse sampling:", p_emp.round(4))

# ---------- Bar chart: p_data vs empirical ----------
xs = np.arange(5)
width = 0.4
plt.figure(figsize=(6,3))
plt.bar(xs - width/2, p_data, width=width, label="true p_data")
plt.bar(xs + width/2, p_emp, width=width, label="empirical (reverse)")
plt.title("Reverse samples match the data marginal")
plt.xlabel("category id")
plt.ylabel("probability")
plt.legend()
plt.tight_layout()
plt.show()
Toy data marginal p_data: [0.4 0.3 0.18 0.1 0.02]
Empirical p after reverse sampling: [0.3979 0.3002 0.1828 0.0992 0.0198]
Perfect! Having made sure everything is working, we now plot 1000 sample paths from the reverse process.
n_samples_to_plot = min(1000, hist.shape[1])

plt.figure()
for i in range(n_samples_to_plot):
    plt.plot(range(hist.shape[0]), hist[:, i], alpha=0.5, linewidth=0.8)

plt.xlabel('Time step')
plt.ylabel('State')
plt.title(f'Sample trajectories (first {n_samples_to_plot} samples)')
plt.grid(True, alpha=0.3)
plt.show()
We see a relatively equal number of jumps per step:
jump_counts = np.zeros(T)
for i in range(T):
    jump_counts[i] = sum(hist[i] != hist[i+1])
print(jump_counts)
[ 9916. 10017. 10001. 10143. 10062. 10195. 9815. 10005. 9859. 9987.]
This is because we set \(\alpha_t=1-t\) to be linear: the expected number of jumps per step is \(N(\alpha_{t-1}-\alpha_t)=100000/10=10^4\), which matches the counts above.
# ---------- Plot schedule α_t (survival probability) ----------
plt.figure(figsize=(5,3))
plt.plot(range(T+1), alpha, marker="o")
plt.title(r"Survival probability $\alpha_t$")
plt.xlabel("t")
plt.ylabel(r"$\alpha_t$")
plt.tight_layout()
plt.show()
2.4 Choice of \(\alpha_t\)
\(\alpha_t\) controls the convergence rate of the forward process.
We change \(\alpha_t\) to see the impact on the sampling accuracy. (There should be no influence as long as the exact backward kernel is used.)
Let us change \(\alpha_t\) to be an exponential schedule:
alpha_exp = np.exp(np.linspace(0.00, -10.00, T+1))
p_unmask_exp = (alpha_exp[:-1] - alpha_exp[1:]) / (1.0 - alpha_exp[1:])

plt.figure(figsize=(5,3))
plt.plot(range(T+1), alpha_exp, marker="o")
plt.title(r"Survival probability $\alpha_t$")
plt.xlabel("t")
plt.ylabel(r"$\alpha_t$")
plt.tight_layout()
plt.show()
In this way, most of the unmasking events should occur in the very last step of the reverse process.
x0_exp, hist_exp = reverse_sample(N, p_unmask_exp)
n_samples_to_plot = min(1000, hist_exp.shape[1])

plt.figure()
for i in range(n_samples_to_plot):
    plt.plot(range(hist_exp.shape[0]), hist_exp[:, i], alpha=0.5, linewidth=0.8)

plt.xlabel('Time step')
plt.ylabel('State')
plt.title(f'Sample trajectories (first {n_samples_to_plot} samples)')
plt.grid(True, alpha=0.3)
plt.show()
We see many jumps happen in the latter half.
Practically speaking, this is certainly not what we want.
We spend almost half of the computational time (up to the 6th step) simulating phantom jumps that simply do not happen. The same concern was raised by (Chao et al., 2025).
However, the accuracy is the same, since the exact kernel is used for the simulation, even though the computational cost differs.
def calc_l1_kl(x0_samples, split = 10):
    chunks = np.array_split(x0_samples, split)
    counts = np.array([np.bincount(chunk, minlength=5).astype(float) for chunk in chunks])

    p_emp = counts / counts.sum(axis=1)[0]
    l1 = np.abs(p_emp - p_data).sum(axis=1).mean()
    l1_var = np.abs(p_emp - p_data).sum(axis=1).var()
    kl = (np.where(p_emp > 0, p_emp * np.log(p_emp / p_data), 0)).sum(axis=1).mean()
    kl_var = (np.where(p_emp > 0, p_emp * np.log(p_emp / p_data), 0)).sum(axis=1).var()
    return l1, l1_var, kl, kl_var
l1, l1_var, kl, kl_var = calc_l1_kl(x0_samples)
print("Linear Schedule: L1 distance:", round(l1, 6), " ± ", round(l1_var, 6), " KL(p_emp || p_data):", round(kl, 6), " ± ", round(kl_var, 6))
l1_exp, l1_exp_var, kl_exp, kl_exp_var = calc_l1_kl(x0_exp)
print("Exponential Schedule: L1 distance:", round(l1_exp, 6), " ± ", round(l1_exp_var, 6), " KL(p_emp || p_data):", round(kl_exp, 6), " ± ", round(kl_exp_var, 6))
Linear Schedule: L1 distance: 0.0147 ± 3e-05 KL(p_emp || p_data): 0.000196 ± 0.0
Exponential Schedule: L1 distance: 0.0148 ± 4.9e-05 KL(p_emp || p_data): 0.000232 ± 0.0
3 2D Example
3.1 Setup
We consider a highly correlated distribution on \(5^2\) whose support is degenerate: it only charges pairs whose two components have the same parity.
K = 5
MASK = -1

# Base marginal for a single site
p_single = np.array([0.40, 0.30, 0.18, 0.10, 0.02], dtype=float)
p_single /= p_single.sum()

# Build correlated joint with same-parity constraint
W = np.zeros((K, K), dtype=float)
for i in range(K):
    for j in range(K):
        if (i % 2) == (j % 2):
            W[i, j] = p_single[i] * p_single[j]
pi_joint = W / W.sum()
pi_x = pi_joint.sum(axis=1)
pi_y = pi_joint.sum(axis=0)

# Conditionals
cond_x_given_y = np.zeros((K, K), dtype=float)  # [j, i]
cond_y_given_x = np.zeros((K, K), dtype=float)  # [i, j]
for j in range(K):
    col = pi_joint[:, j]; s = col.sum()
    if s > 0:
        cond_x_given_y[j, :] = col / s
for i in range(K):
    row = pi_joint[i, :]; s = row.sum()
    if s > 0:
        cond_y_given_x[i, :] = row / s

fig = plt.figure(figsize=(8, 3.4))

# Heatmap
ax1 = plt.subplot(1, 2, 1)
im = ax1.imshow(pi_joint, cmap='viridis', aspect='equal')
ax1.set_xlabel('Y')
ax1.set_ylabel('X')
ax1.set_title('Joint Probability Distribution (Heatmap)')
ax1.set_xticks(range(K))
ax1.set_yticks(range(K))

# Value annotation
for i in range(K):
    for j in range(K):
        ax1.text(j, i, f'{pi_joint[i, j]:.3f}',
                 ha='center', va='center', color='white', fontsize=8)

plt.colorbar(im, ax=ax1)

# 3D bar plot
ax2 = plt.subplot(1, 2, 2, projection='3d')
x = np.arange(K)
y = np.arange(K)
X, Y = np.meshgrid(x, y)
Z = pi_joint

ax2.bar3d(X.ravel(), Y.ravel(), np.zeros_like(Z.ravel()), 0.8, 0.8, Z.ravel(), alpha=0.8, cmap='viridis')
ax2.set_xlabel('Y')
ax2.set_ylabel('X')
ax2.set_zlabel('Probability')
ax2.set_title('Joint Probability Distribution (3D)')
ax2.set_xticks(x)
ax2.set_yticks(y)

plt.tight_layout()
plt.show()
3.2 The Backward Transition Kernel
We will first consider, again, a linear schedule:
T = 10
alpha = np.linspace(1.0, 0.0, T + 1)
p_unmask = (alpha[:-1] - alpha[1:]) / (1.0 - alpha[1:])
p_unmask = np.clip(p_unmask, 0.0, 1.0)
plt.figure(figsize=(5, 3))
plt.plot(range(T+1), alpha, marker="o")
plt.title(r"Survival probability $\alpha_t$")
plt.xlabel("t")
plt.ylabel(r"$\alpha_t$")
plt.tight_layout()
plt.show()
The code for backward sampling is basically the same, except that the number of if branches is now four, rather than just one.
Code (definition of reverse_sample_pairs)
def reverse_sample_pairs(num_samples: int, p_unmask: np.ndarray, T: int):
    x1 = np.full(num_samples, MASK, dtype=int)
    x2 = np.full(num_samples, MASK, dtype=int)
    hist1 = np.empty((T + 1, num_samples), dtype=int); hist1[0] = x1
    hist2 = np.empty((T + 1, num_samples), dtype=int); hist2[0] = x2

    for t in range(T, 0, -1):
        p = p_unmask[t-1]

        # both masked
        both = (x1 == MASK) & (x2 == MASK)
        idx = np.where(both)[0]
        if idx.size > 0:
            um1 = rng.random(idx.size) < p
            um2 = rng.random(idx.size) < p

            idx_both = idx[um1 & um2]
            if idx_both.size > 0:
                flat = pi_joint.ravel()
                choices = rng.choice(K*K, size=idx_both.size, p=flat)
                xs = choices // K; ys = choices % K
                x1[idx_both] = xs; x2[idx_both] = ys

            idx_only1 = idx[um1 & (~um2)]
            if idx_only1.size > 0:
                x1[idx_only1] = rng.choice(K, size=idx_only1.size, p=pi_x)

            idx_only2 = idx[(~um1) & um2]
            if idx_only2.size > 0:
                x2[idx_only2] = rng.choice(K, size=idx_only2.size, p=pi_y)

        # x1 masked, x2 revealed
        idx_b1 = np.where((x1 == MASK) & (x2 != MASK))[0]
        if idx_b1.size > 0:
            will = rng.random(idx_b1.size) < p
            idx_now = idx_b1[will]
            if idx_now.size > 0:
                y_vals = x2[idx_now]
                for val in np.unique(y_vals):
                    m = (y_vals == val); n = m.sum()
                    x1[idx_now[m]] = rng.choice(K, size=n, p=cond_x_given_y[val, :])

        # x2 masked, x1 revealed
        idx_b2 = np.where((x2 == MASK) & (x1 != MASK))[0]
        if idx_b2.size > 0:
            will = rng.random(idx_b2.size) < p
            idx_now = idx_b2[will]
            if idx_now.size > 0:
                x_vals = x1[idx_now]
                for val in np.unique(x_vals):
                    m = (x_vals == val); n = m.sum()
                    x2[idx_now[m]] = rng.choice(K, size=n, p=cond_y_given_x[val, :])

        hist1[T - t + 1] = x1; hist2[T - t + 1] = x2

    assert np.all(x1 != MASK) and np.all(x2 != MASK)
    return np.stack([x1, x2], axis=1), hist1, hist2
3.3 Honest Sampling
Again, using the exact backward kernel, we are able to reproduce the true joint distribution.
# Run
N = 100_000
pairs, h1, h2 = reverse_sample_pairs(N, p_unmask, T)

# Empirical joint
counts = np.zeros((K, K), dtype=float)
for a, b in pairs:
    counts[a, b] += 1.0
pi_emp = counts / counts.sum()

fig, ax = plt.subplots(1, 2, figsize=(8, 3.4))
im0 = ax[0].imshow(pi_joint, origin="lower", aspect="equal")
ax[0].set_title("True joint π_data")
ax[0].set_xlabel("x2"); ax[0].set_ylabel("x1")
for i in range(K):
    for j in range(K):
        ax[0].text(j, i, f'{pi_joint[i, j]:.3f}',
                   ha='center', va='center', color='white', fontsize=8)
fig.colorbar(im0, ax=ax[0], fraction=0.046, pad=0.04)

im1 = ax[1].imshow(pi_emp, origin="lower", aspect="equal")
ax[1].set_title("Empirical joint (reverse)")
ax[1].set_xlabel("x2"); ax[1].set_ylabel("x1")
for i in range(K):
    for j in range(K):
        ax[1].text(j, i, f'{pi_emp[i, j]:.3f}',
                   ha='center', va='center', color='white', fontsize=8)
fig.colorbar(im1, ax=ax[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()
def l1_kl(pairs, split=10):
    chunks = np.array_split(pairs, split)
    l1, kl = [], []
    for chunk in chunks:
        counts = np.zeros((K, K), dtype=float)
        for a, b in chunk:
            counts[a, b] += 1.0
        pi_emp = counts / counts.sum()

        eps = 1e-12
        l1.append(np.abs(pi_emp - pi_joint).sum())
        # KL restricted to cells where both distributions are positive; empirical mass on
        # zero-probability cells is ignored, so the reported value can even be negative.
        nz = (pi_emp > 0) & (pi_joint > 0)
        kl.append((pi_emp[nz] * np.log((pi_emp[nz] + eps) / pi_joint[nz])).sum())
    return l1, kl
l1, kl = l1_kl(pairs)

print("L1 distance:", round(np.mean(l1), 6), " ± ", round(np.var(l1), 6), " KL(emp || true):", f"{np.mean(kl):.6e} ± {np.var(kl):.6e}")
L1 distance: 0.024162 ± 4.6e-05 KL(emp || true): 7.286359e-04 ± 7.462192e-08
Note that since the survival rate \(\alpha_t\) decreases linearly, each coordinate is unmasked by time index \(t\) with probability \(\alpha_t\), so the number of newly unmasked coordinates per step stays roughly constant, at about \(2N(\alpha_{t-1}-\alpha_t)=2N/T\) per step, just as in the 1D example.
new_unmasks_per_step = []
for t in range(T):
    changed1 = (h1[t] == MASK) & (h1[t+1] != MASK)
    changed2 = (h2[t] == MASK) & (h2[t+1] != MASK)
    new_unmasks_per_step.append(changed1.sum() + changed2.sum())

plt.figure(figsize=(6, 3))
plt.plot(range(1, T+1), new_unmasks_per_step, marker="o")
plt.title("Newly unmasked coordinates per step")
plt.xlabel("reverse step (t→t-1)")
plt.ylabel("#coords")
plt.tight_layout()
plt.show()
3.3.1 Larger Step Size
What if we employ a large step size?
Actually, the result doesn’t change. If anything, the accuracy is slightly higher, since with \(T=1\) both coordinates are unmasked jointly in the single step, which is equivalent to direct sampling from \(\pi_{\text{data}}\).
T = 1 # Number of steps
alpha = np.linspace(1.0, 0.0, T + 1)
p_unmask = (alpha[:-1] - alpha[1:]) / (1.0 - alpha[1:])
p_unmask = np.clip(p_unmask, 0.0, 1.0)
# Run
N = 100_000
pairs, h1, h2 = reverse_sample_pairs(N, p_unmask, T)

# Empirical joint
counts = np.zeros((K, K), dtype=float)
for a, b in pairs:
    counts[a, b] += 1.0
pi_emp = counts / counts.sum()

fig, ax = plt.subplots(1, 2, figsize=(8, 3.4))
im0 = ax[0].imshow(pi_joint, origin="lower", aspect="equal")
ax[0].set_title("True joint π_data")
ax[0].set_xlabel("x2"); ax[0].set_ylabel("x1")
for i in range(K):
    for j in range(K):
        ax[0].text(j, i, f'{pi_joint[i, j]:.3f}',
                   ha='center', va='center', color='white', fontsize=8)
fig.colorbar(im0, ax=ax[0], fraction=0.046, pad=0.04)

im1 = ax[1].imshow(pi_emp, origin="lower", aspect="equal")
ax[1].set_title("Empirical joint (reverse)")
ax[1].set_xlabel("x2"); ax[1].set_ylabel("x1")
for i in range(K):
    for j in range(K):
        ax[1].text(j, i, f'{pi_emp[i, j]:.3f}',
                   ha='center', va='center', color='white', fontsize=8)
fig.colorbar(im1, ax=ax[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()
l1, kl = l1_kl(pairs)

print("L1 distance:", round(np.mean(l1), 6), " ± ", round(np.var(l1), 6), " KL(emp || true):", f"{np.mean(kl):.6e} ± {np.var(kl):.6e}")
L1 distance: 0.021245 ± 5e-05 KL(emp || true): 6.510835e-04 ± 1.258177e-07
3.4 Coordinate-wise Sampling
The joint distribution would be unavailable, even if the learning based on the loss (1) were perfectly done, because of the product form assumption on the neural network predictor \(p_\theta\).3
We mock this situation by replacing the joint distribution in the exact kernel, programmed in Section 3.2, with the product of its marginals.
Code (definition of reverse_sample_incorrect)
def reverse_sample_incorrect(num_samples: int, p_unmask: np.ndarray, T: int):
    x1 = np.full(num_samples, MASK, dtype=int)
    x2 = np.full(num_samples, MASK, dtype=int)
    hist1 = np.empty((T + 1, num_samples), dtype=int); hist1[0] = x1
    hist2 = np.empty((T + 1, num_samples), dtype=int); hist2[0] = x2

    for t in range(T, 0, -1):
        p = p_unmask[t-1]

        # both masked
        both = (x1 == MASK) & (x2 == MASK)
        idx = np.where(both)[0]
        if idx.size > 0:
            um1 = rng.random(idx.size) < p
            um2 = rng.random(idx.size) < p

            idx_both = idx[um1 & um2]
            if idx_both.size > 0:
                flat = np.outer(pi_x, pi_y).ravel()
                choices = rng.choice(K*K, size=idx_both.size, p=flat)
                xs = choices // K; ys = choices % K
                x1[idx_both] = xs; x2[idx_both] = ys

            idx_only1 = idx[um1 & (~um2)]
            if idx_only1.size > 0:
                x1[idx_only1] = rng.choice(K, size=idx_only1.size, p=pi_x)

            idx_only2 = idx[(~um1) & um2]
            if idx_only2.size > 0:
                x2[idx_only2] = rng.choice(K, size=idx_only2.size, p=pi_y)

        # x1 masked, x2 revealed
        idx_b1 = np.where((x1 == MASK) & (x2 != MASK))[0]
        if idx_b1.size > 0:
            will = rng.random(idx_b1.size) < p
            idx_now = idx_b1[will]
            if idx_now.size > 0:
                y_vals = x2[idx_now]
                for val in np.unique(y_vals):
                    m = (y_vals == val); n = m.sum()
                    x1[idx_now[m]] = rng.choice(K, size=n, p=cond_x_given_y[val, :])

        # x2 masked, x1 revealed
        idx_b2 = np.where((x2 == MASK) & (x1 != MASK))[0]
        if idx_b2.size > 0:
            will = rng.random(idx_b2.size) < p
            idx_now = idx_b2[will]
            if idx_now.size > 0:
                x_vals = x1[idx_now]
                for val in np.unique(x_vals):
                    m = (x_vals == val); n = m.sum()
                    x2[idx_now[m]] = rng.choice(K, size=n, p=cond_y_given_x[val, :])

        hist1[T - t + 1] = x1; hist2[T - t + 1] = x2

    assert np.all(x1 != MASK) and np.all(x2 != MASK)
    return np.stack([x1, x2], axis=1), hist1, hist2
T = 10 # Number of steps
alpha = np.linspace(1.0, 0.0, T + 1)
p_unmask = (alpha[:-1] - alpha[1:]) / (1.0 - alpha[1:])
p_unmask = np.clip(p_unmask, 0.0, 1.0)

pairs, h1, h2 = reverse_sample_incorrect(N, p_unmask, T)

# Empirical joint
counts = np.zeros((K, K), dtype=float)
for a, b in pairs:
    counts[a, b] += 1.0
pi_emp = counts / counts.sum()

fig, ax = plt.subplots(1, 2, figsize=(8, 3.4))
im0 = ax[0].imshow(pi_joint, origin="lower", aspect="equal")
ax[0].set_title("True joint π_data")
ax[0].set_xlabel("x2"); ax[0].set_ylabel("x1")
for i in range(K):
    for j in range(K):
        ax[0].text(j, i, f'{pi_joint[i, j]:.3f}',
                   ha='center', va='center', color='white', fontsize=8)
fig.colorbar(im0, ax=ax[0], fraction=0.046, pad=0.04)

im1 = ax[1].imshow(pi_emp, origin="lower", aspect="equal")
ax[1].set_title("Empirical joint (reverse)")
ax[1].set_xlabel("x2"); ax[1].set_ylabel("x1")
for i in range(K):
    for j in range(K):
        ax[1].text(j, i, f'{pi_emp[i, j]:.3f}',
                   ha='center', va='center', color='white', fontsize=8)
fig.colorbar(im1, ax=ax[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()
This time, the result is not so good.
l1, kl = l1_kl(pairs)

print("L1 distance:", round(np.mean(l1), 6), " ± ", round(np.var(l1), 6), " KL(emp || true):", f"{np.mean(kl):.6e} ± {np.var(kl):.6e}")
L1 distance: 0.089135 ± 1.2e-05 KL(emp || true): -4.100716e-02 ± 6.515707e-06
There is a noticeable deviation from the exact kernel, whose \(\ell^1\) distance was \(0.024\pm0.00005\).
We can easily see that, in the empirical distribution, some cells are assigned positive mass although the true probability is \(0\).
This effect becomes smaller when the number of steps T is large. For example, setting T=1000 gives almost the same accuracy as the exact kernel (although we don't show it here for brevity).
3.4.1 Larger Step Size
The situation gets worse when the step size is large, for example T=1.
T = 1 # Number of steps
alpha = np.linspace(1.0, 0.0, T + 1)
p_unmask = (alpha[:-1] - alpha[1:]) / (1.0 - alpha[1:])
p_unmask = np.clip(p_unmask, 0.0, 1.0)

pairs, h1, h2 = reverse_sample_incorrect(N, p_unmask, T)

# Empirical joint
counts = np.zeros((K, K), dtype=float)
for a, b in pairs:
    counts[a, b] += 1.0
pi_emp = counts / counts.sum()

fig, ax = plt.subplots(1, 2, figsize=(8, 3.4))
im0 = ax[0].imshow(pi_joint, origin="lower", aspect="equal")
ax[0].set_title("True joint π_data")
ax[0].set_xlabel("x2"); ax[0].set_ylabel("x1")
for i in range(K):
    for j in range(K):
        ax[0].text(j, i, f'{pi_joint[i, j]:.3f}',
                   ha='center', va='center', color='white', fontsize=8)
fig.colorbar(im0, ax=ax[0], fraction=0.046, pad=0.04)

im1 = ax[1].imshow(pi_emp, origin="lower", aspect="equal")
ax[1].set_title("Empirical joint (reverse)")
ax[1].set_xlabel("x2"); ax[1].set_ylabel("x1")
for i in range(K):
    for j in range(K):
        ax[1].text(j, i, f'{pi_emp[i, j]:.3f}',
                   ha='center', va='center', color='white', fontsize=8)
fig.colorbar(im1, ax=ax[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()
l1, kl = l1_kl(pairs)

print("L1 distance:", round(np.mean(l1), 6), " ± ", round(np.var(l1), 6), " KL(emp || true):", f"{np.mean(kl):.6e} ± {np.var(kl):.6e}")
L1 distance: 0.85124 ± 4.4e-05 KL(emp || true): -2.884951e-01 ± 1.114368e-05
The error is now significant, because the incorrect product kernel is used for every pair: setting \(T=1\) means everything is unmasked in just one step!
One way to fix this is to set T as large as possible, making it unlikely that more than one coordinate jumps simultaneously.
3.5 Corrector Sampling
Another remedy is to utilise Gibbs sampling ideas to correct the bias, again at the expense of computational cost.
To do this, we must identify a Markov kernel that keeps the marginal distribution of \(X_t\) invariant for every timestep \(t\in[0,1]\).
In our case, let us simply add a re-masking step, applied only to those pairs \((x_t^1,x_t^2)\) that are completely unmasked.
Since there is some probability that the two components were unmasked simultaneously (and hence drawn from the biased product of marginals), re-masking one of them, here \(x_t^1\), gives us a second chance to arrive at a correct pair \((x_t^{1,\text{corrected}},x_t^2)\).
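Before looking at the implementation, here is a minimal one-coordinate sanity check of the invariance idea (my own sketch; the corrector below uses the simpler heuristic re-masking probability \(q=1-p\) rather than these exact rates): one forward re-masking step followed by one exact backward step leaves the time-\(t\) marginal unchanged.
alpha_t, alpha_s = 0.6, 0.4                                # survival probabilities at times t < s
mu_t = np.concatenate([[1 - alpha_t], alpha_t * p_data])   # marginal at time t: [P(MASK), P(0), ..., P(4)]

remask = 1 - alpha_s / alpha_t                             # forward step: re-mask revealed coordinates
mu_s = np.concatenate([[mu_t[0] + remask * mu_t[1:].sum()], (1 - remask) * mu_t[1:]])

unmask = (alpha_t - alpha_s) / (1 - alpha_s)               # exact backward step: unmask, drawing from p_data
mu_back = np.concatenate([[(1 - unmask) * mu_s[0]], mu_s[1:] + unmask * mu_s[0] * p_data])

print(np.allclose(mu_back, mu_t))                          # True: the time-t marginal is preserved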
Code (definition of reverse_sample_corrector)
def reverse_sample_corrector(num_samples: int, p_unmask: np.ndarray, T: int):
    x1 = np.full(num_samples, MASK, dtype=int)
    x2 = np.full(num_samples, MASK, dtype=int)
    hist1 = np.empty((T + 1, num_samples), dtype=int); hist1[0] = x1
    hist2 = np.empty((T + 1, num_samples), dtype=int); hist2[0] = x2

    for t in range(T, 0, -1):
        p = p_unmask[t-1]

        # both masked
        both = (x1 == MASK) & (x2 == MASK)
        idx = np.where(both)[0]
        if idx.size > 0:
            um1 = rng.random(idx.size) < p
            um2 = rng.random(idx.size) < p

            idx_both = idx[um1 & um2]
            if idx_both.size > 0:
                flat = np.outer(pi_x, pi_y).ravel()
                choices = rng.choice(K*K, size=idx_both.size, p=flat)
                xs = choices // K; ys = choices % K
                x1[idx_both] = xs; x2[idx_both] = ys

            idx_only1 = idx[um1 & (~um2)]
            if idx_only1.size > 0:
                x1[idx_only1] = rng.choice(K, size=idx_only1.size, p=pi_x)

            idx_only2 = idx[(~um1) & um2]
            if idx_only2.size > 0:
                x2[idx_only2] = rng.choice(K, size=idx_only2.size, p=pi_y)

        # x1 masked, x2 revealed
        idx_b1 = np.where((x1 == MASK) & (x2 != MASK))[0]
        if idx_b1.size > 0:
            will = rng.random(idx_b1.size) < p
            idx_now = idx_b1[will]
            if idx_now.size > 0:
                y_vals = x2[idx_now]
                for val in np.unique(y_vals):
                    m = (y_vals == val); n = m.sum()
                    x1[idx_now[m]] = rng.choice(K, size=n, p=cond_x_given_y[val, :])

        # x2 masked, x1 revealed
        idx_b2 = np.where((x2 == MASK) & (x1 != MASK))[0]
        if idx_b2.size > 0:
            will = rng.random(idx_b2.size) < p
            idx_now = idx_b2[will]
            if idx_now.size > 0:
                x_vals = x1[idx_now]
                for val in np.unique(x_vals):
                    m = (x_vals == val); n = m.sum()
                    x2[idx_now[m]] = rng.choice(K, size=n, p=cond_y_given_x[val, :])

        # corrector step
        q = 1.0 - p  # masking probability
        both = (x1 != MASK) & (x2 != MASK)
        idx = np.where(both)[0]
        if idx.size > 0:
            will = rng.random(idx.size) < q
            idx_now = idx[will]
            if idx_now.size > 0:
                x1[idx_now] = MASK

        # predictor step
        # both masked
        both = (x1 == MASK) & (x2 == MASK)
        idx = np.where(both)[0]
        if idx.size > 0:
            um1 = rng.random(idx.size) < p
            um2 = rng.random(idx.size) < p

            idx_both = idx[um1 & um2]
            if idx_both.size > 0:
                flat = np.outer(pi_x, pi_y).ravel()
                choices = rng.choice(K*K, size=idx_both.size, p=flat)
                xs = choices // K; ys = choices % K
                x1[idx_both] = xs; x2[idx_both] = ys

            idx_only1 = idx[um1 & (~um2)]
            if idx_only1.size > 0:
                x1[idx_only1] = rng.choice(K, size=idx_only1.size, p=pi_x)

            idx_only2 = idx[(~um1) & um2]
            if idx_only2.size > 0:
                x2[idx_only2] = rng.choice(K, size=idx_only2.size, p=pi_y)

        # x1 masked, x2 revealed
        idx_b1 = np.where((x1 == MASK) & (x2 != MASK))[0]
        if idx_b1.size > 0:
            will = rng.random(idx_b1.size) < p
            idx_now = idx_b1[will]
            if idx_now.size > 0:
                y_vals = x2[idx_now]
                for val in np.unique(y_vals):
                    m = (y_vals == val); n = m.sum()
                    x1[idx_now[m]] = rng.choice(K, size=n, p=cond_x_given_y[val, :])

        # x2 masked, x1 revealed
        idx_b2 = np.where((x2 == MASK) & (x1 != MASK))[0]
        if idx_b2.size > 0:
            will = rng.random(idx_b2.size) < p
            idx_now = idx_b2[will]
            if idx_now.size > 0:
                x_vals = x1[idx_now]
                for val in np.unique(x_vals):
                    m = (x_vals == val); n = m.sum()
                    x2[idx_now[m]] = rng.choice(K, size=n, p=cond_y_given_x[val, :])

        hist1[T - t + 1] = x1; hist2[T - t + 1] = x2

    assert np.all(x1 != MASK) and np.all(x2 != MASK)
    return np.stack([x1, x2], axis=1), hist1, hist2
T = 10 # Number of steps
alpha = np.linspace(1.0, 0.0, T + 1)
p_unmask = (alpha[:-1] - alpha[1:]) / (1.0 - alpha[1:])
p_unmask = np.clip(p_unmask, 0.0, 1.0)

pairs, h1, h2 = reverse_sample_corrector(N, p_unmask, T)

# Empirical joint
counts = np.zeros((K, K), dtype=float)
for a, b in pairs:
    counts[a, b] += 1.0
pi_emp = counts / counts.sum()

fig, ax = plt.subplots(1, 2, figsize=(8, 3.4))
im0 = ax[0].imshow(pi_joint, origin="lower", aspect="equal")
ax[0].set_title("True joint π_data")
ax[0].set_xlabel("x2"); ax[0].set_ylabel("x1")
for i in range(K):
    for j in range(K):
        ax[0].text(j, i, f'{pi_joint[i, j]:.3f}',
                   ha='center', va='center', color='white', fontsize=8)
fig.colorbar(im0, ax=ax[0], fraction=0.046, pad=0.04)

im1 = ax[1].imshow(pi_emp, origin="lower", aspect="equal")
ax[1].set_title("Empirical joint (reverse)")
ax[1].set_xlabel("x2"); ax[1].set_ylabel("x1")
for i in range(K):
    for j in range(K):
        ax[1].text(j, i, f'{pi_emp[i, j]:.3f}',
                   ha='center', va='center', color='white', fontsize=8)
fig.colorbar(im1, ax=ax[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()
l1, kl = l1_kl(pairs, 10)

print("L1 distance:", round(np.mean(l1), 6), " ± ", round(np.var(l1), 6), " KL(emp || true):", f"{np.mean(kl):.6e} ± {np.var(kl):.6e}")
L1 distance: 0.022518 ± 2e-05 KL(emp || true): 4.694469e-05 ± 8.500292e-08
Compared with the uncorrected product kernel, whose \(\ell^1\) distance was \(0.0891\), this is a dramatic improvement: the error is back down to the level of the exact kernel (\(0.0242\)). The reported variances, of order \(O(10^{-5})\), are calculated from \(10\) disjoint chunks of the \(100{,}000\) samples, which amounts to \(10\) repeated experiments.
This is the idea behind the predictor-corrector technique discussed, for example, in (Gat et al., 2024), (S. Zhao et al., 2024), (L. Zhao et al., 2024).
We only included one corrector step per reverse step, although more corrector steps would further increase accuracy.
This can be regarded as a form of inference-time compute scaling (Wang et al., 2025), although it would be a fairly bad idea to wait for the predictor-corrector iterations to fully converge.
3.6 Discussion
Let us come back to the technique employed in (Chao et al., 2025), which we mentioned in Section 1.4.
As we have seen, the number of steps \(T\) has to be kept much larger than \(d\) so that simultaneous jumps become unlikely.
However, this increases the computational cost, as more time steps must be spent simulating phantom jumps.
One thing we can do is to fill these blank steps with informed 'half-jumps', by introducing a sub-structure in each state \(x\in E\).
Such a half-jump is a bit more robust than a direct unmasking, so we are able to spend the time up to \(T\) more meaningfully.
Thus, this strategy can be viewed as another way to mitigate the bias introduced by the incorrect backward kernel.
4 Future Works
Since the absorbing process is favoured mainly because of its time-agnostic property, its modelling capability should be analysed separately from the other properties of the process itself.
4.1 A Reinforcement Learning Take
The inference step and the sampling step should be decoupled, at least conceptually.
To tune the forward-process noising schedule \(\alpha_t(x)\), I believe a reinforcement learning framework will be employed in the near future.
This is a variant of meta-learning and, in this way, the unmasking network \(p_\theta\) will be able to efficiently learn the correct dependence structure in the data domain.
For example, in language modeling, there is a natural sequential structure, which is partly why autoregressive modeling has been dominant in this domain. However, by learning \(\alpha_t\) in the masking process, a much more efficient factorization over texts can be acquired in a data-driven way.
I even think this \(\alpha_t\) can play an important role, just as word embeddings do currently.
In the sampling step, sampling-time path planning will greatly enhance sample quality, just as Monte Carlo tree search does in reinforcement learning.
In conclusion, the flexible framework of masked diffusion models enables a marriage with reinforcement learning and meta-learning, which may be a way to overcome the autoregressive modeling framework, since the latter imposes an unnecessary sequential inductive bias on the model.
4.2 A Scaling Analysis
There are two papers, namely (Santos et al., 2023) and (Winkler et al., 2024), which carry out a scaling analysis as \(K\to\infty\), in order to bridge the gap between discrete and continuous state spaces.
In (Santos et al., 2023, sec. 2.4) and its Appendix C, it is proved that a fairly large class of discrete processes will converge to Ito processes, as \(K\to\infty\).
Their discussion is based on a formal argument: they show that the Kramers-Moyal expansion of the Chapman-Kolmogorov equation of the discrete process converges to that of an Itô process.
In other words, they did not identify the direct correspondence, for example by deriving the exact expressions of the limiting SDEs.
(Winkler et al., 2024) build on their analysis, identifying an OU process on \(\mathbb{R}^d\) that corresponds to an Ehrenfest process.
This line of research has a lot to do with thermodynamics, and might provide insights into whether images should be modeled discretely or continuously.
Also, the limit \(d\to\infty\) has yet to be explored.
4.3 Concluding Remarks
I believe that masked diffusion modeling can be viewed as a form of Gibbs sampling, with a learned transition kernel.
Many current practices correspond to (uniformly) random-scan Gibbs, while autoregressive models correspond to fixed-scan Gibbs.
Most recent improvements are based upon ideas from reinforcement learning and meta learning, where an optimal order to unmask components is pursued.
This point of view might not be mere abstract nonsense; I actually believe it will prove fruitful in the future.
5 Related Articles
References
Footnotes
Discrete Flow Matching (Campbell et al., 2024), (Gat et al., 2024), (Shaul et al., 2025) and simplified masked diffusion (Shi et al., 2024), (Ou et al., 2025), (Zheng et al., 2025) are different frameworks with different scopes, but both lead to the same training objective (1) when applied to the forward masking process.↩︎
In the context of language modeling, the perplexity is defined as \(2^{-l}\), where \(l\) is the average log-likelihood (in bits per token) of the test set.↩︎
Of course, exact sampling would have been available, for example, if we learned the backward intensity as in (Campbell et al., 2022). However, these methods have been marginalized due to suboptimal performance.↩︎