# 从零到一：用 PyTorch 手撕 Transformer（附完整代码与调试技巧）
## 1. 为什么需要手写 Transformer

第一次接触 Transformer 时，你可能会有这样的疑问：现在有这么多现成的深度学习框架（如 HuggingFace 的 transformers 库），为什么还要从零开始实现呢？

这里我分享一个真实案例：去年我们团队在部署一个翻译模型时，发现直接调用预训练模型做长文本翻译会出现内存泄漏。由于不熟悉底层实现，调试花了整整两周；而亲手实现过 Transformer 的同事仅用两天就定位到是注意力矩阵的内存管理问题。

手写 Transformer 的三大核心价值：

- **维度魔术**：真正理解 `(batch_size, seq_len, d_model)` 这些张量在每一层的变换过程。比如多头注意力中 $d_k$ 的缩放操作，只有亲手实现过，才会明白为什么需要除以 $\sqrt{d_k}$。
- **调试能力**：当模型输出异常时，能快速定位是 mask 机制的问题还是残差连接的问题。我曾遇到过一个 bug：解码时总是重复生成相同的词，最后发现是解码器自注意力的 mask 没有正确设置。
- **定制魔改**：想给注意力加个稀疏约束？想尝试新型位置编码？只有掌握底层实现，才能灵活改造模型结构。

下面这张表对比了不同学习方式的收益：

| 学习方式 | 理论理解 | 调试能力 | 改造灵活性 | 时间成本 |
| --- | --- | --- | --- | --- |
| 直接调用 API | ★★☆ | ★☆☆ | ★☆☆ | 低 |
| 阅读论文 | ★★★ | ★☆☆ | ★★☆ | 中 |
| 手写实现 | ★★★ | ★★★ | ★★★ | 高 |

## 2. 环境准备与数据预处理

### 2.1 极简 PyTorch 环境

推荐使用 conda 创建纯净环境：

```bash
conda create -n transformer python=3.8
conda activate transformer
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install numpy matplotlib tqdm
```

**避坑指南**

- CUDA 版本要与 PyTorch 对应：`nvidia-smi` 右上角显示的是当前驱动支持的最高 CUDA 版本，安装 `cu113` 版本的 PyTorch 要求该值不低于 11.3。
- 建议固定 PyTorch 版本：不同版本间张量操作可能有细微差异。
- 如果遇到 `RuntimeError: CUDA out of memory`，尝试减小 batch_size 或使用梯度累积。

### 2.2 玩具级数据集构建

本文所有代码片段共用以下导入：

```python
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
```

为了聚焦模型实现，我们构造一个极简的德语→英语翻译数据集：

```python
# Special tokens
PAD = 0  # padding
SOS = 1  # start of sentence
EOS = 2  # end of sentence

sentences = [
    # German source           decoder input        decoder target
    ['ich mochte ein bier', 'S i want a beer', 'i want a beer E'],
    ['ich mochte ein cola', 'S i want a coke', 'i want a coke E'],
]

vocab = {
    'de': {'P': PAD, 'ich': 3, 'mochte': 4, 'ein': 5, 'bier': 6, 'cola': 7},
    'en': {'P': PAD, 'S': SOS, 'E': EOS, 'i': 8, 'want': 9, 'a': 10, 'beer': 11, 'coke': 12},
}

# IDs are not contiguous, so the embedding tables must cover max ID + 1 rows.
SRC_VOCAB_SIZE = max(vocab['de'].values()) + 1  # 8
TGT_VOCAB_SIZE = max(vocab['en'].values()) + 1  # 13
```

注意：词表 ID 并不连续（德语最大 ID 为 7，英语最大 ID 为 12），所以嵌入层的大小应取“最大 ID + 1”，而不是 `len(vocab['de'])`，否则 `nn.Embedding` 会因索引越界而报错。

**数据处理技巧**

- 序列填充：使用 `torch.nn.utils.rnn.pad_sequence` 自动处理不等长序列。
- 批量生成：`DataLoader` 的 `collate_fn` 参数可以自定义批次组装逻辑。
- 设备转移：用 `.to(device)` 统一管理数据位置。

完整的数据管道实现（`DataLoader` 默认的采样器会调用 `len(dataset)`，因此 `Dataset` 必须实现 `__len__`）：

```python
class TranslationDataset(Dataset):
    """Map-style dataset of (encoder input, decoder input, decoder target) ID sequences."""

    def __init__(self, sentences, vocab):
        self.enc_inputs = []
        self.dec_inputs = []
        self.dec_outputs = []
        for de, en_in, en_out in sentences:
            self.enc_inputs.append([vocab['de'][word] for word in de.split()])
            self.dec_inputs.append([vocab['en'][word] for word in en_in.split()])
            self.dec_outputs.append([vocab['en'][word] for word in en_out.split()])

    def __len__(self):
        # Required by DataLoader's default sampler; without it iteration raises TypeError.
        return len(self.enc_inputs)

    def __getitem__(self, idx):
        return (
            torch.LongTensor(self.enc_inputs[idx]),
            torch.LongTensor(self.dec_inputs[idx]),
            torch.LongTensor(self.dec_outputs[idx]),
        )


def collate_fn(batch):
    """Pad each of the three sequence lists in a batch to a common length with PAD."""
    enc_inputs = [item[0] for item in batch]
    dec_inputs = [item[1] for item in batch]
    dec_outputs = [item[2] for item in batch]
    return (
        pad_sequence(enc_inputs, batch_first=True, padding_value=PAD),
        pad_sequence(dec_inputs, batch_first=True, padding_value=PAD),
        pad_sequence(dec_outputs, batch_first=True, padding_value=PAD),
    )


dataset = TranslationDataset(sentences, vocab)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
```

## 3. Transformer 核心组件实现

### 3.1 位置编码的数学之美

Transformer 的位置编码采用正弦/余弦函数：

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right),\qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

其精妙之处在于：

- **相对位置信息**：对任意固定偏移 $k$，$PE_{pos+k}$ 都可以表示为 $PE_{pos}$ 的线性函数，模型因此容易学到相对位置关系。
- **可扩展性**：编码由公式直接计算、无需训练，即使遇到比训练时更长的序列也能生成合理的编码（只要不超过预先计算的 `max_len`）。

```python
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an input of shape (batch, seq_len, d_model)."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i / d_model), computed in log space for numerical stability.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        # There are only d_model // 2 odd dimensions; trimming keeps odd d_model from failing.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer follows .to(device) and is saved in state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(f'sequence length {seq_len} exceeds max_len {self.pe.size(0)}')
        return x + self.pe[:seq_len].unsqueeze(0)  # broadcast over the batch dimension
```

**调试技巧**

- 可视化位置编码：用 `plt.imshow(PositionalEncoding(128, max_len=100).pe.numpy(), aspect='auto')` 查看，应呈现左侧（低维、高频）条纹密集、越往右（高维、低频）变化越平缓的波形图案。
- 数值检查：确保相邻位置的编码差异适中，太大或太小都会影响训练。

### 3.2 注意力机制的三大核心

#### 3.2.1 缩放点积注意力

```python
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Return (softmax(Q K^T / sqrt(d_k)) V, attention weights).

    Q/K/V: (batch_size, n_heads, seq_len, d_k). Positions where mask == 0 get ~0 weight;
    mask must broadcast to the score shape (batch_size, n_heads, len_q, len_k).
    """
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
    if mask is not None:
        # The dtype's most negative finite value; a literal -1e9 overflows float16.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, V), attn
```

**关键点**

- 缩放因子 $\sqrt{d_k}$ 防止点积过大、使 softmax 进入饱和区导致梯度消失。
- mask 操作要在 softmax 之前完成：被屏蔽的位置填充一个极大的负数，softmax 之后其权重约为 0。
- 不要硬编码 `-1e9`：它超出了 float16 的表示范围，混合精度训练时 `masked_fill` 会直接报错，用 `torch.finfo(scores.dtype).min` 更稳妥。

#### 3.2.2 多头注意力

```python
class MultiHeadAttention(nn.Module):
    """Project Q/K/V, attend independently in n_heads subspaces, merge heads, project out."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f'd_model ({d_model}) must be divisible by n_heads ({n_heads})')
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        # Attention map of the most recent forward pass, kept for visualization (section 6.2).
        self.attn_weights = None

    def _split_heads(self, x):
        """(batch_size, seq_len, d_model) -> (batch_size, n_heads, seq_len, d_k)."""
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        # Linear projection, then split into heads
        Q = self._split_heads(self.W_Q(Q))
        K = self._split_heads(self.W_K(K))
        V = self._split_heads(self.W_V(V))
        if mask is not None and mask.dim() == 3:
            # (B, len_q, len_k) -> (B, 1, len_q, len_k): one mask shared by every head.
            # 4-D masks (as built by create_masks) already have the head axis and are used as-is.
            mask = mask.unsqueeze(1)
        x, attn = scaled_dot_product_attention(Q, K, V, mask)
        self.attn_weights = attn.detach()
        # Merge heads: (B, n_heads, len_q, d_k) -> (B, len_q, d_model)
        batch_size, _, seq_len, _ = x.size()
        x = x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_heads * self.d_k)
        return self.W_O(x)
```

**维度变换解析**

1. 输入形状：`(batch_size, seq_len, d_model)`
2. 线性变换后：`(batch_size, seq_len, d_model)`
3. 分头操作：`(batch_size, seq_len, n_heads, d_k)` → 转置为 `(batch_size, n_heads, seq_len, d_k)`
4. 注意力计算后：保持形状不变
5. 合并输出：`(batch_size, seq_len, d_model)`

**掩码形状的坑**：`create_masks`（见 3.2.3）返回的掩码已经带有头维度（4 维）。如果在多头注意力里再无条件地 `mask.unsqueeze(1)`，就会得到 5 维掩码，`masked_fill` 广播后注意力输出的形状完全错误。因此这里只对 3 维掩码 `(batch_size, len_q, len_k)` 补头维度。

#### 3.2.3 掩码机制

Transformer 使用两种掩码：

- **填充掩码**：避免注意力机制处理填充符号。
- **序列掩码**：防止解码器看到未来信息。

```python
def create_decoder_mask(dec_input):
    """Padding + causal mask for decoder input (B, T); returns a (B, 1, T, T) bool tensor."""
    seq_len = dec_input.size(1)
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=dec_input.device)
    )  # position t may only attend to positions <= t
    pad_mask = (dec_input != PAD).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T): hide PAD keys
    return pad_mask & causal


def create_masks(enc_input, dec_input):
    """Return (encoder padding mask (B, 1, 1, S), decoder padding+causal mask (B, 1, T, T))."""
    enc_mask = (enc_input != PAD).unsqueeze(1).unsqueeze(2)
    dec_mask = create_decoder_mask(dec_input)
    return enc_mask, dec_mask
```

## 4. 模型训练与调试技巧

### 4.1 学习率调度策略

Transformer 使用带预热（warmup）的学习率调度：

$$lrate = d_{model}^{-0.5} \cdot \min\left(step^{-0.5},\ step \cdot warmup\_steps^{-1.5}\right)$$

```python
class TransformerOptimizer:
    """Optimizer wrapper that applies the warmup / inverse-square-root schedule before each step."""

    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.current_step = 0
        self.lr = 0.0  # learning rate applied at the most recent step

    def rate(self, step=None):
        """Learning rate for `step` (defaults to the current step); steps below 1 are treated as 1."""
        step = max(self.current_step if step is None else step, 1)
        return self.d_model ** -0.5 * min(step ** -0.5, step * self.warmup_steps ** -1.5)

    def step(self):
        self.current_step += 1
        self.lr = self.rate()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()
```

**训练曲线解读**

- 初期（`step ≤ warmup_steps`）：学习率线性增长，避免参数尚处于随机状态时步子过大（冷启动）。
- 之后：学习率按 $step^{-0.5}$ 衰减。峰值出现在 `step = warmup_steps` 处；当 `d_model=512`、`warmup_steps=4000` 时，峰值约为 7e-4。
- 后期：学习率持续缓慢减小（并不会停在某个固定值），相当于用越来越小的学习率微调。
- 玩具数据集总步数很少。若仍用 `warmup_steps=4000`，前几十步的学习率都低于 1e-5，模型几乎学不动，因此下面的示例改用较小的 `warmup_steps`。

### 4.2 梯度裁剪与损失函数

```python
D_MODEL = 512
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Transformer is defined in section 5.3; run that code first.
model = Transformer(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, d_model=D_MODEL).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=PAD)  # padded target positions add no loss
optimizer = TransformerOptimizer(
    # lr is overwritten by the schedule on every step
    torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9),
    d_model=D_MODEL,
    warmup_steps=100,  # the toy run is only ~50 steps; 4000 would keep lr below 1e-5
)


def train_step(batch):
    """Run one forward/backward/update step on a batch and return the loss value."""
    model.train()
    enc_input, dec_input, dec_output = (t.to(device) for t in batch)
    enc_mask, dec_mask = create_masks(enc_input, dec_input)
    pred = model(enc_input, dec_input, enc_mask, dec_mask)  # (B, T, tgt_vocab_size)
    loss = criterion(pred.reshape(-1, pred.size(-1)), dec_output.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    # Clip before the update so one bad batch cannot blow up the weights.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()


for epoch in range(50):
    for batch in loader:
        loss = train_step(batch)
    if (epoch + 1) % 10 == 0:
        print(f'epoch {epoch + 1:3d}  loss {loss:.4f}')
```

**常见问题排查**

- NaN 损失：检查注意力分数是否在 softmax 之前被正确 mask；混合精度下不要使用 `-1e9` 这类超出 float16 范围的常数。
- 梯度爆炸：调小学习率，或增强梯度裁剪（减小 `max_norm`）。
- 欠拟合：增加模型深度，或检查数据预处理。

## 5. 完整模型组装

### 5.1 编码器实现

与原论文保持一致：词嵌入先乘以 $\sqrt{d_{model}}$（论文 §3.4），再与位置编码相加，并对相加结果做 dropout（论文 §5.4）。

```python
class EncoderLayer(nn.Module):
    """Post-norm encoder layer: self-attention and feed-forward, each with residual + LayerNorm."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ffn_output = self.ffn(x)
        return self.norm2(x + self.dropout(ffn_output))


class Encoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of EncoderLayers."""

    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])

    def forward(self, x, mask):
        # Scale embeddings by sqrt(d_model), add positions, then dropout the sum.
        x = self.dropout(self.pos_encoding(self.embedding(x) * math.sqrt(self.d_model)))
        for layer in self.layers:
            x = layer(x, mask)
        return x
```

### 5.2 解码器实现

```python
class DecoderLayer(nn.Module):
    """Post-norm decoder layer: masked self-attention, encoder-decoder attention, feed-forward."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.enc_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        # Self-attention; tgt_mask hides future positions
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Encoder-decoder attention: queries from the decoder, keys/values from the encoder
        attn_output = self.enc_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ffn_output = self.ffn(x)
        return self.norm3(x + self.dropout(ffn_output))


class Decoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of DecoderLayers."""

    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])

    def forward(self, x, enc_output, src_mask, tgt_mask):
        x = self.dropout(self.pos_encoding(self.embedding(x) * math.sqrt(self.d_model)))
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        return x
```

### 5.3 Transformer 完整架构

```python
class Transformer(nn.Module):
    """Encoder-decoder Transformer returning target-vocabulary logits of shape (B, T, tgt_vocab_size)."""

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_layers=6,
                 n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_heads, d_ff, dropout)
        self.decoder = Decoder(tgt_vocab_size, d_model, n_layers, n_heads, d_ff, dropout)
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask):
        enc_output = self.encoder(src, src_mask)
        dec_output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        return self.fc(dec_output)
```

## 6. 模型部署与推理

### 6.1 贪婪解码实现

解码时每一步都要为当前已生成的前缀重新构造掩码，这里复用 3.2.3 中的 `create_decoder_mask`：

```python
@torch.no_grad()
def greedy_decode(model, src, src_mask, max_len=20, start_symbol=SOS):
    """Greedily decode one source sentence (batch size 1) and return the (1, L) token IDs.

    Puts the model in eval mode. Stops after max_len tokens or once EOS is produced.
    """
    model.eval()
    memory = model.encoder(src, src_mask)
    ys = torch.full((1, 1), start_symbol, dtype=src.dtype, device=src.device)
    for _ in range(max_len - 1):
        tgt_mask = create_decoder_mask(ys)
        out = model.decoder(ys, memory, src_mask, tgt_mask)
        prob = model.fc(out[:, -1])       # logits for the last position: (1, tgt_vocab_size)
        next_word = prob.argmax(dim=-1)   # (1,)
        ys = torch.cat([ys, next_word.unsqueeze(0)], dim=1)
        if next_word.item() == EOS:
            break
    return ys


device = next(model.parameters()).device
src = torch.LongTensor([[vocab['de'][w] for w in 'ich mochte ein bier'.split()]]).to(device)
src_mask = (src != PAD).unsqueeze(1).unsqueeze(2)  # (1, 1, 1, S)
pred = greedy_decode(model, src, src_mask)
idx2word = {i: w for w, i in vocab['en'].items()}
print(' '.join(idx2word[i] for i in pred[0].tolist()))
```

### 6.2 可视化注意力权重

`attention` 为形状 `(tgt_len, src_len)` 的注意力矩阵，例如最后一层编码器-解码器注意力某一个头的权重：

```python
def plot_attention(attention, src_sentence, tgt_sentence):
    """Plot a (tgt_len, src_len) attention matrix with one labelled tick per token."""
    if torch.is_tensor(attention):
        # matplotlib needs a CPU array without autograd history.
        attention = attention.detach().cpu().numpy()
    src_tokens = src_sentence.split()
    tgt_tokens = tgt_sentence.split()
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    cax = ax.matshow(attention, cmap='bone')
    fig.colorbar(cax)
    # Fix tick positions explicitly; the old [''] + tokens trick relied on the automatic
    # tick locator and misaligns labels on current matplotlib versions.
    ax.set_xticks(range(len(src_tokens)))
    ax.set_yticks(range(len(tgt_tokens)))
    ax.set_xticklabels(src_tokens, rotation=90)
    ax.set_yticklabels(tgt_tokens)
    return fig


# The last decoder call saw every generated token except the final one.
tgt_tokens = ' '.join(idx2word[i] for i in pred[0, :-1].tolist())
attn = model.decoder.layers[-1].enc_attn.attn_weights[0, 0]  # head 0: (tgt_len, src_len)
fig = plot_attention(attn, 'ich mochte ein bier', tgt_tokens)
```