Diffusion LM初探:LLaDA

LLaDa 论文&代码阅读

Github链接

Paper链接

Introduction

Insights

The generative modeling principles:

\[\max_{\theta} \mathbb{E}_{p_{\text{data}}(x)} \log p_{\theta}(x) \Leftrightarrow \min_{\theta} \mathbf{KL}(p_{\text{data}}(x) \| p_{\theta}(x))\]

KL散度是:

\[D_{KL}(P \| Q) = \sum_{x} P(x) \log P(x) - \sum_{x} P(x) \log Q(x)\]

其中\(\sum_{x} P(x) \log P(x)\)是一个常数,因此最小化KL散度等价于最大化$ _{x} P(x) Q(x)$

AR formulation:

\[p_{\theta}(x) = p_{\theta}(x^1) \prod_{i=2}^{L} p_{\theta}(x^i \mid x^1, \dots, x^{i-1})\]

Insight: 是 generative modeling principles 而不是 autoregressive formulation 决定了 LLM的那些属性 (scalability, insturction following, in-context learning)

Contributions

scalable, 可以做in-context learning, SFT 后 instruction- following 能力明显增强, break the reversal curse

reversal curse: 如果模型知道“A是B”,它并不一定能自动推导出“B是A”。这可能是因为自回归模型优化的是 \(\prod_{t=1}^n P(x_t | x_{<t})\) 它偏向记住“从左到右”的知识模式。

Approach

Theory

\[\mathcal{L}(\theta) \triangleq -\mathbb{E}_{t, x_0, x_t} \left[ \frac{1}{t} \sum_{i=1}^{L} \mathbf{1}[x_t^i = \text{M}] \log p_\theta(x_0^i | x_t) \right]\]

\(t \in [0,1]\) uniformly random, 这个序列每一位都有\(t\)的概率被替换成[MASK], 训练目标是预测被MASK掉的token。\(\frac{1}{t}\) 是为了归一化.

我们无法直接优化左边的真正的负对数似然,但可以优化它的上界。

\[-\mathbb{E}_{p_{\text{data}}(x_0)} [\log p_\theta(x_0)] \leq \mathcal{L}(\theta)\]

Training and Inference

Pretraning: learning \(p_{\theta}(x_0)\)

Backbone: Transformer

与自回归模型的区别:LLaDa不用causal mask,而是可以看到整个sequence。因为不是自回归,所以没有KV cache,因此用的是vanilla multihead attention 而不是 gqa(多个head共享同一个K/V)

Dataset: 2.3 Trillion tokens

lr: Warmup-Stable-Decay learning rate scheduler

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def forward_process(input_ids, eps=1e-3):
b, l = input_ids.shape # batch size, sequence length
t = torch.rand(b, device=input_ids.device) #为batch中每一个序列单独采样一个t
p_mask = (1 - eps) * t + eps # 避免t为0导致p_mask为0,eps为一个很小的数
p_mask = p_mask[:, None].repeat(1, l) # (b,) → (b,1) → (b,l)

# 逐元素比较,把小于 p_mask 的位置标记为 True,其余为 False,实现以p_mask的概率随机mask掉输入序列中的token
masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
# masked_indices为True的位置被替换成126336,False的位置保持不变
noisy_batch = torch.where(masked_indices, 126336, input_ids)
return noisy_batch, masked_indices, p_mask

# The data is an integer tensor of shape (b, 4096),
# where b represents the batch size and 4096 is the sequence length.
input_ids = batch["input_ids"]

# We set 1% of the pre-training data to a random length that is uniformly sampled from the range [1, 4096].
# The following implementation is not elegant and involves some data waste.
# However, the data waste is minimal, so we ignore it.
if torch.rand(1) < 0.01:
# 随机输出一个范围为[1, seqlength)的整数,作为新的序列长度
random_length = torch.randint(1, input_ids.shape[1] + 1, (1,))
# (b, 4096) → (b, random_length)
input_ids = input_ids[:, :random_length]

noisy_batch, masked_indices, p_mask = forward_process(input_ids)
# (b, l, vocab_size)
logits = model(input_ids=noisy_batch).logits

# Loss的计算
# 只取被 mask 的 token 的 logit 和 input_id 来计算交叉熵损失
# (b, l, vocab_size) -> (num_masked_tokens, vocab_size)
token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
#注意这里的p_mask[masked_indices]就是论文公式里的t
loss = token_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])

Finetuning: learning \(p_{\theta}(r_0 \mid p_0)\)

把response里面的一些词给mask掉,训练模型去预测这些被mask掉的词。

\[-\mathbb{E}_{t, p_0, r_0, r_t} \left[ \frac{1}{t} \sum_{i=1}^{L'} \mathbf{1}[r_t^i = \text{M}] \log p_\theta(r_0^i | p_0, r_t) \right]\]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
input_ids, prompt_lengths = batch["input_ids"], batch["prompt_lengths"]
# 先对整个序列进行一次整体的加噪
noisy_batch, _, p_mask = forward_process(input_ids)

# Do not add noise to the prompt
token_positions = torch.arange(noisy_batch.shape[1], device=noisy_batch.device).expand(noisy_batch.size(0), noisy_batch.size(1))
prompt_mask = (token_positions < prompt_length.unsqueeze(1))
#把prompt部分的token恢复成原来的input_id,response部分的token保持加噪状态
noisy_batch[prompt_mask] = input_ids[prompt_mask]

# Calculate the answer length (including the padded <EOS> tokens)
prompt_mask = prompt_mask.to(torch.int64)
# 1 - prompt_mask之后,prompt位置为0,response位置为1
answer_lengths = torch.sum((1 - prompt_mask), dim=-1, keepdim=True)
# (b,1) → (b,l),每一行都是answer_length的值,方便后续计算loss时进行归一化
answer_lengths = answer_length.repeat(1, noisy_batch.shape[1])

masked_indices = (noisy_batch == 126336)

logits = model(input_ids=noisy_batch).logits

token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
ce_loss = torch.sum(token_loss / answer_lengths[masked_indices]) / input_ids.shape[0]

Inference

块与块之间是autoregessive的生成,块内是去噪生成。

引入的remasking机制非常关键,它允许低置信度 token 被重新 mask, 下一步模型有机会重新预测,提高生成质量, 而且可以平滑过渡。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def get_num_transfer_tokens(mask_index, steps):
'''
在反向过程中,区间 [0, 1] 被均匀离散化为 steps 个间隔。
此外,由于 LLaDA 采用线性噪声调度,
每一步转换的预期 token 数量应该是一致的。
该函数用于预先计算每一步需要转换的 token 数量。

参数:
mask_index: 布尔张量,标记哪些位置是mask
steps: 总采样步数

返回:
num_transfer_tokens: 形状为 (batch_size, steps),每步要转换的token数量
'''

mask_num = mask_index.sum(dim=1, keepdim=True)
base = mask_num // steps
# 需要额外分配的token数
remainder = mask_num % steps
# 初始化 [batch_size, steps] tensor,每步的基础转换token数为base
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
# 将余数均匀分配到前remainder步(每步+1)
for i in range(mask_num.size(0)):
num_transfer_tokens[i, :remainder[i]] += 1
return num_transfer_tokens

完整的inference核心代码,重点关注remasking机制的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0., remasking='low_confidence', mask_id=126336):
'''
掩码扩散模型的生成函数
参数:
model: 掩码预测模型
prompt: 形状为 (batch_size, prompt_length) 的张量,输入提示
attention_mask: 注意力掩码
steps: 采样步数,小于或等于 gen_length
gen_length: 生成答案的长度
block_length: 块长度,小于或等于 gen_length。如果小于 gen_length,表示使用半自回归重掩码
temperature: 分类分布采样温度,0表示确定性采样
cfg_scale: 无监督分类器引导缩放因子
remasking: 重掩码策略。'low_confidence' 或 'random'
mask_id: [MASK] token 的 ID 是 126336
logits_eos_inf: 是否将 EOS token 的 logits 设置为 -inf。详见 LLaDA 附录 B.4
confidence_eos_eot_inf: 是否将 EOS 和 EoT token 的置信度设置为 -inf。详见 LLaDA 附录 B.4
'''
# 1. 初始化:创建全为mask_id的序列
x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
# 将prompt部分复制到x中,生成部分保持为mask_id
x[:, :prompt.shape[1]] = prompt.clone()
# 扩展attention_mask,把生成部分设置成1,确保生成部分被注意到
if attention_mask is not None:
attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)
prompt_index = (x != mask_id)

# 2. 计算块数和每块的步数
assert gen_length % block_length == 0
num_blocks = gen_length // block_length # 总块数
assert steps % num_blocks == 0
steps = steps // num_blocks # 每块的步数

# 3. 分块生成(半自回归方式)
for num_block in range(num_blocks):
# prompt.shape[1]: 跳过prompt部分
# start: prompt.shape[1] + num_block * block_length
# end: prompt.shape[1] + (num_block + 1) * block_length
block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
# 预计算每步要转换的token数量
num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

# 4. 迭代去噪过程
for i in range(steps):
mask_index = (x == mask_id)
logits = model(x, attention_mask=attention_mask).logits

# 添加Gumbel噪声并采样
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
x0 = torch.argmax(logits_with_noise, dim=-1) # 预测的token, 形状: (b, l)

# 5. 计算置信度
if remasking == 'low_confidence':
# 使用softmax概率作为置信度
p = F.softmax(logits, dim=-1)
# gather函数根据x0中的索引从p中提取对应位置的概率值,得到每个位置的置信度
x0_p = torch.squeeze(
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # 形状: (b, l)
elif remasking == 'random':
# 使用随机值作为置信度
x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
else:
raise NotImplementedError(remasking)

# 将当前块之后的位置置信度设为-inf, 屏蔽未来位置
x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf
# 只在mask位置更新预测值
x0 = torch.where(mask_index, x0, x)
# 只在mask位置保留置信度
confidence = torch.where(mask_index, x0_p, -np.inf)

# 6. 重掩码:选择置信度最高的num_transfer_tokens个位置进行更新
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) # 初始全False
for j in range(confidence.shape[0]):
# 选择置信度最高的k个位置
_, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) # (k,)
transfer_index[j, select_index] = True #把被选择的标记为True
# 只更新选中的位置,其余保持mask状态
x[transfer_index] = x0[transfer_index]

return x

Diffusion LM初探:LLaDA
http://example.com/2026/02/24/post-2/
作者
瑾瑜當年
发布于
2026年2月24日
许可协议