Last updated: 2026-02-24T22:29:45+08:00
LLaDA: Paper & Code Notes
GitHub link
Paper link
Introduction
Insights
The generative modeling principles:
\[\max_{\theta}
\mathbb{E}_{p_{\text{data}}(x)} \log p_{\theta}(x) \Leftrightarrow
\min_{\theta} \mathbf{KL}(p_{\text{data}}(x) \|
p_{\theta}(x))\]
The KL divergence is:
\[D_{KL}(P \| Q) = \sum_{x} P(x) \log P(x)
- \sum_{x} P(x) \log Q(x)\]
Here \(\sum_{x} P(x) \log P(x)\) depends only on the data distribution, so it is a constant; minimizing the KL divergence is therefore equivalent to maximizing \(\sum_{x} P(x) \log Q(x)\).
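The decomposition above can be checked numerically. A minimal sketch with toy distributions (the values of P and Q are made up purely for illustration):

```python
import math

# Toy distributions over 3 outcomes (made-up values for illustration).
P = [0.5, 0.3, 0.2]
Q = [0.4, 0.4, 0.2]

# D_KL(P || Q) computed directly from the definition.
kl = sum(p * math.log(p / q) for p, q in zip(P, Q))

# The same value as (sum_x P(x) log P(x)) minus (sum_x P(x) log Q(x)):
neg_entropy = sum(p * math.log(p) for p in P)
cross_term = sum(p * math.log(q) for p, q in zip(P, Q))

print(abs(kl - (neg_entropy - cross_term)) < 1e-12)
```

Since `neg_entropy` is fixed by the data, only the cross term depends on \(Q\), which is why maximizing expected log-likelihood under the model is the same objective as minimizing the KL.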
AR formulation:
\[p_{\theta}(x) = p_{\theta}(x^1)
\prod_{i=2}^{L} p_{\theta}(x^i \mid x^1, \dots, x^{i-1})\]
Insight: it is the generative modeling principle, not the autoregressive
formulation, that gives LLMs their key properties (scalability, instruction
following, in-context learning).
Contributions
Scalable; capable of in-context learning; instruction-following improves markedly after SFT; breaks the reversal curse.
Reversal curse:
a model that knows "A is B" cannot necessarily infer "B is A". A plausible cause is that autoregressive models optimize
\(\prod_{t=1}^n P(x_t \mid x_{<t})\),
which biases them toward memorizing knowledge in left-to-right order.
Approach
Theory
\[\mathcal{L}(\theta) \triangleq
-\mathbb{E}_{t, x_0, x_t} \left[ \frac{1}{t} \sum_{i=1}^{L}
\mathbf{1}[x_t^i = \text{M}] \log p_\theta(x_0^i | x_t)
\right]\]
\(t \in [0,1]\) is sampled uniformly at random; each position in the sequence is independently replaced by [MASK] with probability \(t\).
The training objective is to predict the masked tokens, and the \(\frac{1}{t}\) factor normalizes the loss across noise levels.
We cannot directly optimize the true negative log-likelihood (the left-hand side below), but we can optimize its upper bound:
\[-\mathbb{E}_{p_{\text{data}}(x_0)} [\log
p_\theta(x_0)] \leq \mathcal{L}(\theta)\]
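The expectation over \(t\), \(x_0\), \(x_t\) is estimated by Monte Carlo during training. A pure-Python sketch of the estimator, using a made-up "model" that assigns probability 0.8 to every masked token (so the true loss value is known): the \(\frac{1}{t}\) weight cancels the fact that a fraction \(t\) of positions are masked, so the estimate concentrates around \(L \cdot (-\log 0.8)\) regardless of \(t\).

```python
import math
import random

random.seed(0)

L = 8            # sequence length
p_correct = 0.8  # made-up model: p_theta(x_0^i | x_t) = 0.8 for every masked token

def loss_sample():
    # Sample t ~ U(0, 1] (kept away from 0, like the eps in the training code),
    # then mask each position independently with probability t.
    t = random.uniform(1e-3, 1.0)
    masked = [random.random() < t for _ in range(L)]
    # Sum -log p over masked positions, weighted by 1/t.
    return (1.0 / t) * sum(-math.log(p_correct) for m in masked if m)

est = sum(loss_sample() for _ in range(20000)) / 20000
print(est)
```

The printed average should be close to \(8 \cdot (-\ln 0.8) \approx 1.785\): since the expected number of masked tokens is \(tL\), the \(\frac{1}{t}\) factor makes the estimator unbiased for the per-sequence loss.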
Training and Inference
Pretraining: learning \(p_{\theta}(x_0)\)
Backbone: Transformer
Difference from autoregressive models: LLaDA uses no causal mask; every position can attend to the entire sequence. Since generation is not autoregressive there is no KV cache, so the model uses vanilla multi-head attention rather than GQA (where groups of heads share the same K/V).
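For context on why GQA matters for autoregressive models in the first place: its saving is in the KV projections and the KV cache. A back-of-the-envelope sketch (the head counts and dimensions below are made-up illustrative numbers, not LLaDA's configuration):

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    # K and V each store (seq_len, n_kv_heads * head_dim) per layer.
    return 2 * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem

# MHA: one K/V head per query head; GQA: 4 query heads share each K/V head.
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=4096)
print(mha // gqa)  # cache shrinks by n_heads / n_kv_heads = 4
```

Since LLaDA never builds a KV cache, this saving does not apply, and plain MHA carries no such cost.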
Dataset: 2.3 trillion tokens
LR: Warmup-Stable-Decay learning rate scheduler
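The Warmup-Stable-Decay schedule is, as the name says, piecewise: a linear warmup, a long constant plateau, then a decay phase. A minimal sketch (the phase lengths, peak LR, and linear decay shape are made-up illustrative choices; the paper's exact constants differ):

```python
def wsd_lr(step, max_lr=4e-4, warmup_steps=2000,
           stable_steps=100000, decay_steps=10000, min_lr=1e-5):
    """Warmup-Stable-Decay: linear warmup -> constant plateau -> linear decay."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps          # warmup phase
    if step < warmup_steps + stable_steps:
        return max_lr                                # stable phase
    # Decay phase: anneal linearly from max_lr down to min_lr, then hold.
    d = min(step - warmup_steps - stable_steps, decay_steps)
    return max_lr + (min_lr - max_lr) * d / decay_steps

print(wsd_lr(0), wsd_lr(1000), wsd_lr(50000), wsd_lr(112000))
```

The appeal of WSD over cosine decay is that the plateau can be extended without re-planning the whole run; decay checkpoints can be branched off at any point.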
```python
def forward_process(input_ids, eps=1e-3):
    b, l = input_ids.shape
    t = torch.rand(b, device=input_ids.device)
    p_mask = (1 - eps) * t + eps
    p_mask = p_mask[:, None].repeat(1, l)

    masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
    # 126336 is the [MASK] token id
    noisy_batch = torch.where(masked_indices, 126336, input_ids)
    return noisy_batch, masked_indices, p_mask

input_ids = batch["input_ids"]

# 1% of pre-training sequences are truncated to a random length
if torch.rand(1) < 0.01:
    random_length = torch.randint(1, input_ids.shape[1] + 1, (1,))
    input_ids = input_ids[:, :random_length]

noisy_batch, masked_indices, p_mask = forward_process(input_ids)
logits = model(input_ids=noisy_batch).logits

token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices],
                             reduction='none') / p_mask[masked_indices]
loss = token_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
```
Finetuning: learning
\(p_{\theta}(r_0 \mid p_0)\)
Mask some of the tokens in the response and train the model to predict them; the prompt itself is never masked.
\[-\mathbb{E}_{t, p_0, r_0, r_t} \left[
\frac{1}{t} \sum_{i=1}^{L'} \mathbf{1}[r_t^i = \text{M}] \log
p_\theta(r_0^i | p_0, r_t) \right]\]
```python
input_ids, prompt_lengths = batch["input_ids"], batch["prompt_lengths"]

noisy_batch, _, p_mask = forward_process(input_ids)

# Do not add noise to the prompt
token_positions = torch.arange(noisy_batch.shape[1], device=noisy_batch.device).expand(
    noisy_batch.size(0), noisy_batch.size(1))
prompt_mask = (token_positions < prompt_lengths.unsqueeze(1))
noisy_batch[prompt_mask] = input_ids[prompt_mask]

# Calculate the answer length (including padded <EOS> tokens)
prompt_mask = prompt_mask.to(torch.int64)
answer_lengths = torch.sum((1 - prompt_mask), dim=-1, keepdim=True)
answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1])

masked_indices = (noisy_batch == 126336)

logits = model(input_ids=noisy_batch).logits

token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices],
                             reduction='none') / p_mask[masked_indices]
ce_loss = torch.sum(token_loss / answer_lengths[masked_indices]) / input_ids.shape[0]
```
Inference
Generation is autoregressive between blocks and denoising-based within a block.
The remasking mechanism introduced here is crucial: low-confidence tokens can be re-masked, giving the model a chance to re-predict them at the next step, which improves generation quality and smooths the transition between steps.
```python
def get_num_transfer_tokens(mask_index, steps):
    '''
    In the reverse process, the interval [0, 1] is discretized uniformly into
    `steps` intervals. Since LLaDA uses a linear noise schedule, the expected
    number of tokens transferred at each step should be the same, so this
    function precomputes the number of tokens to transfer at each step.

    Args:
        mask_index: boolean tensor marking which positions are masked
        steps: total number of sampling steps
    Returns:
        num_transfer_tokens: shape (batch_size, steps), tokens to transfer per step
    '''
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = torch.zeros(mask_num.size(0), steps,
                                      device=mask_index.device, dtype=torch.int64) + base
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1
    return num_transfer_tokens
```
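To see the schedule concretely: with 7 masked positions and 3 steps, each step reveals either \(\lceil 7/3 \rceil\) or \(\lfloor 7/3 \rfloor\) tokens. A torch-free sketch of the same per-row logic (`transfer_schedule` is a hypothetical helper written just for this illustration):

```python
def transfer_schedule(mask_num, steps):
    """Spread mask_num token reveals as evenly as possible over `steps` steps,
    mirroring get_num_transfer_tokens for a single sequence."""
    base, remainder = divmod(mask_num, steps)
    # The first `remainder` steps reveal one extra token.
    return [base + 1 if i < remainder else base for i in range(steps)]

print(transfer_schedule(7, 3))    # [3, 2, 2]
print(transfer_schedule(128, 8))  # [16, 16, 16, 16, 16, 16, 16, 16]
```

Every schedule sums to the number of masked tokens, so the block is fully revealed after exactly `steps` steps.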
The full core inference code; pay particular attention to how the remasking mechanism is implemented:
```python
def generate(model, prompt, attention_mask=None, steps=128, gen_length=128,
             block_length=128, temperature=0., remasking='low_confidence',
             mask_id=126336):
    '''
    Generation function for the masked diffusion model.

    Args:
        model: the mask predictor
        prompt: tensor of shape (batch_size, prompt_length), the input prompt
        attention_mask: attention mask
        steps: number of sampling steps, at most gen_length
        gen_length: length of the generated answer
        block_length: block length, at most gen_length; if smaller than
            gen_length, semi-autoregressive remasking is used
        temperature: categorical sampling temperature; 0 means deterministic
        remasking: remasking strategy, 'low_confidence' or 'random'
        mask_id: id of the [MASK] token, 126336
    '''
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id,
                   dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()
    if attention_mask is not None:
        attention_mask = torch.cat(
            [attention_mask,
             torch.ones((prompt.shape[0], gen_length),
                        dtype=attention_mask.dtype, device=model.device)], dim=-1)

    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length:
                              prompt.shape[1] + (num_block + 1) * block_length] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            logits = model(x, attention_mask=attention_mask).logits
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Never transfer tokens beyond the current block
            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            # Keep already-revealed tokens; masked positions take the new predictions
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Reveal only the top-k most confident predictions this step;
            # everything else stays masked (i.e., is remasked for the next step)
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x
```