PyTorch 2.0+ 实现 Transformer:6层编码器/解码器在 WMT14 数据集上的完整训练流程
# PyTorch 2.0+ 实现 Transformer:6层编码器/解码器在 WMT14 数据集上的完整训练流程

Transformer 架构自 2017 年提出以来，已成为自然语言处理领域的基石模型。本文将深入探讨如何使用 PyTorch 2.0 实现一个完整的 Transformer 模型，并在 WMT14 英德翻译数据集上进行训练。不同于简单的玩具实现，我们将重点关注工业级的数据流水线构建、混合精度训练和梯度累积等高级技巧。

## 1. 环境准备与数据预处理

### 1.1 安装依赖与配置

首先确保已安装 PyTorch 2.0 和必要的依赖库：

```bash
pip install torch torchtext torchdata sacrebleu tensorboard
```

对于 GPU 加速训练，建议使用 CUDA 11.7 版本。我们可以通过以下代码检查环境配置：

```python
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
```

### 1.2 WMT14 数据集加载

WMT14 是机器翻译领域的标准基准数据集，包含约 450 万句英德平行语料。我们将使用 torchtext 提供的 API 进行加载：

```python
from torchtext.datasets import WMT14
from torchtext.data.utils import get_tokenizer

SRC_LANGUAGE = "de"
TGT_LANGUAGE = "en"

# 加载分词器
token_transform = {
    SRC_LANGUAGE: get_tokenizer("spacy", language="de_core_news_sm"),
    TGT_LANGUAGE: get_tokenizer("spacy", language="en_core_web_sm")
}

# 构建词汇表
def build_vocab(filepaths, tokenizer, min_freq=2):
    counter = Counter()
    for filepath in filepaths:
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                counter.update(tokenizer(line))
    return Vocab(counter, min_freq=min_freq, specials=["<unk>", "<pad>", "<bos>", "<eos>"])

train_iter = WMT14(split="train", language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
vocab_transform = {
    SRC_LANGUAGE: build_vocab([train_iter.src], token_transform[SRC_LANGUAGE]),
    TGT_LANGUAGE: build_vocab([train_iter.tgt], token_transform[TGT_LANGUAGE])
}
```

> 提示：在实际项目中，建议预先处理好数据集并保存到本地，避免每次训练都重新处理。

### 1.3 数据流水线优化

为了高效加载数据，我们实现一个自定义的 Dataset 和 DataLoader：

```python
from torch.utils.data import Dataset, DataLoader

class TranslationDataset(Dataset):
    def __init__(self, src_sentences, tgt_sentences, src_vocab, tgt_vocab):
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src_sentence = self.src_sentences[idx]
        tgt_sentence = self.tgt_sentences[idx]

        src_tensor = torch.tensor([self.src_vocab[token] for token in token_transform[SRC_LANGUAGE](src_sentence)])
        tgt_tensor = torch.tensor([self.tgt_vocab[token] for token in token_transform[TGT_LANGUAGE](tgt_sentence)])

        return src_tensor, tgt_tensor

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
```

## 2. Transformer 模型实现

### 2.1 模型参数配置

我们首先定义 Transformer 的核心参数：

```python
class TransformerConfig:
    def __init__(self):
        self.d_model = 512  # 嵌入维度
        self.nhead = 8  # 注意力头数
        self.num_encoder_layers = 6  # 编码器层数
        self.num_decoder_layers = 6  # 解码器层数
        self.dim_feedforward = 2048  # 前馈网络维度
        self.dropout = 0.1  # Dropout概率
        self.activation = "relu"  # 激活函数
        self.max_seq_length = 100  # 最大序列长度
        self.src_vocab_size = len(vocab_transform[SRC_LANGUAGE])
        self.tgt_vocab_size = len(vocab_transform[TGT_LANGUAGE])
```

### 2.2 位置编码实现

位置编码是 Transformer 的关键组件，用于注入序列的位置信息：

```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
```

### 2.3 完整 Transformer 实现

基于 PyTorch 的 nn.Transformer 模块，我们可以构建完整的模型：

```python
class TransformerModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.src_tok_emb = nn.Embedding(config.src_vocab_size, config.d_model)
        self.tgt_tok_emb = nn.Embedding(config.tgt_vocab_size, config.d_model)
        self.positional_encoding = PositionalEncoding(config.d_model, config.dropout)

        self.transformer = nn.Transformer(
            d_model=config.d_model,
            nhead=config.nhead,
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=config.num_decoder_layers,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation=config.activation
        )

        self.fc_out = nn.Linear(config.d_model, config.tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))

        output = self.transformer(
            src_emb, tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
        )

        return self.fc_out(output)
```

## 3. 训练流程优化

### 3.1 混合精度训练

PyTorch 的 AMP (Automatic Mixed Precision) 可以显著加速训练并减少显存占用：

```python
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

def train_step(model, optimizer, criterion, src, tgt):
    model.train()
    optimizer.zero_grad()

    tgt_input = tgt[:-1, :]
    tgt_output = tgt[1:, :]

    with autocast():
        output = model(src, tgt_input)
        loss = criterion(output.view(-1, output.size(-1)), tgt_output.view(-1))

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    return loss.item()
```

### 3.2 梯度累积

对于大 batch size 训练，可以使用梯度累积技术：

```python
accumulation_steps = 4

def train_epoch(model, optimizer, criterion, train_loader):
    model.train()
    total_loss = 0
    optimizer.zero_grad()

    for i, (src, tgt) in enumerate(train_loader):
        src = src.to(device)
        tgt = tgt.to(device)

        tgt_input = tgt[:-1, :]
        tgt_output = tgt[1:, :]

        with autocast():
            output = model(src, tgt_input)
            loss = criterion(output.view(-1, output.size(-1)), tgt_output.view(-1))
            loss = loss / accumulation_steps

        scaler.scale(loss).backward()

        if (i + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps

    return total_loss / len(train_loader)
```

### 3.3 学习率调度

使用余弦退火学习率调度器：

```python
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.98), eps=1e-9)
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5)
```

## 4. 
评估与结果分析

### 4.1 BLEU 分数评估

使用 sacreBLEU 进行自动评估：

```python
from sacrebleu import corpus_bleu

def evaluate(model, val_loader, max_len=100):
    model.eval()
    translations = []
    references = []

    with torch.no_grad():
        for src, tgt in val_loader:
            src = src.to(device)
            tgt = tgt.to(device)

            # 使用贪心解码生成翻译
            translation = greedy_decode(model, src, max_len)
            translations.append(translation)
            references.append(tgt.cpu().numpy())

    bleu_score = corpus_bleu(translations, references)
    return bleu_score.score

def greedy_decode(model, src, max_len):
    memory = model.encode(src)
    ys = torch.ones(1, 1).fill_(BOS_IDX).type_as(src.data)

    for i in range(max_len-1):
        out = model.decode(memory, ys)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        if next_word == EOS_IDX:
            break

    return ys
```

### 4.2 训练监控

使用 TensorBoard 记录训练过程：

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

for epoch in range(30):
    train_loss = train_epoch(model, optimizer, criterion, train_loader)
    val_bleu = evaluate(model, val_loader)

    writer.add_scalar("Loss/train", train_loss, epoch)
    writer.add_scalar("BLEU/val", val_bleu, epoch)

    scheduler.step()

    print(f"Epoch: {epoch+1:02d} | Train Loss: {train_loss:.3f} | Val BLEU: {val_bleu:.2f}")
```

## 5. 
高级优化技巧

### 5.1 标签平滑

标签平滑可以防止模型对训练数据过度自信：

```python
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, padding_idx, smoothing=0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.classes = classes
        self.true_dist = None

    def forward(self, x, target):
        assert x.size(1) == self.classes
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.classes - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist)

criterion = LabelSmoothingLoss(len(vocab_transform[TGT_LANGUAGE]), PAD_IDX)
```

### 5.2 模型并行与数据并行

对于大型 Transformer 模型，可以使用分布式训练：

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train_distributed(rank, world_size):
    setup(rank, world_size)

    model = TransformerModel(config).to(rank)
    model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.Adam(model.parameters())

    # 训练循环...

    cleanup()
```

### 5.3 模型量化与优化

训练完成后，可以对模型进行量化以提升推理速度：

```python
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```

## 6. 
实际部署建议

在生产环境中部署 Transformer 模型时，建议考虑以下优化：

使用 TorchScript 导出模型：

```python
scripted_model = torch.jit.script(model)
scripted_model.save("transformer_scripted.pt")
```

实现高效的批处理推理：

```python
from torch.utils.data import DataLoader
from concurrent.futures import ThreadPoolExecutor

class InferencePipeline:
    def __init__(self, model_path, batch_size=32, max_workers=4):
        self.model = torch.jit.load(model_path)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.batch_size = batch_size

    def process_batch(self, batch):
        with torch.no_grad():
            return self.model(batch)

    async def predict(self, inputs):
        batches = [inputs[i:i+self.batch_size] for i in range(0, len(inputs), self.batch_size)]
        results = list(self.executor.map(self.process_batch, batches))
        return torch.cat(results)
```