# Reproducing the Transformer Line by Line in PyTorch: From Harvard NLP's Annotated Code to Your First Translation Model
In this walkthrough we build a Transformer line by line in PyTorch, going from the theory to a working German-English translation model.

## 1. Environment Setup and Data Loading

Before building the Transformer we need a suitable development environment. Python 3.8 with PyTorch 1.10 is a well-tested combination for the code in this tutorial.

Basic environment setup:

```bash
conda create -n transformer python=3.8
conda activate transformer
pip install torch torchtext spacy matplotlib seaborn
python -m spacy download en_core_web_sm
python -m spacy download de_core_news_sm
```

For the IWSLT German-English dataset we can use torchtext's built-in loader. The dataset contains roughly 200,000 sentence pairs, which makes it a good size for a first machine-translation experiment.

Key preprocessing steps:

- Text normalization: lowercase the text and handle special characters
- Tokenization: language-specific tokenization with spaCy
- Vocabulary building: filter low-frequency words and add special tokens
- Batching: group sentences by length to reduce padding

```python
from torchtext.data import Field, BucketIterator
from torchtext.datasets import IWSLT

# batch_first=True so all tensors are (batch, seq_len), matching the model code below
SRC = Field(tokenize='spacy', tokenizer_language='de',
            init_token='<sos>', eos_token='<eos>',
            lower=True, batch_first=True)
TGT = Field(tokenize='spacy', tokenizer_language='en',
            init_token='<sos>', eos_token='<eos>',
            lower=True, batch_first=True)

train_data, valid_data, test_data = IWSLT.splits(
    exts=('.de', '.en'), fields=(SRC, TGT),
    filter_pred=lambda x: len(vars(x)['src']) <= 100 and len(vars(x)['trg']) <= 100
)

SRC.build_vocab(train_data, min_freq=2)
TGT.build_vocab(train_data, min_freq=2)
```

## 2. Implementing the Core Transformer Components

### 2.1 Multi-Head Attention

Multi-head attention is the central innovation of the Transformer: it lets the model attend to information from different representation subspaces in parallel. We first implement scaled dot-product attention and then build the full multi-head module on top of it.

Scaled dot-product attention formula:

$$ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        d_k = q.size(-1)
        # (batch, h, len_q, d_k) x (batch, h, d_k, len_k) -> (batch, h, len_q, len_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return torch.matmul(attn, v), attn
```

Key points of the multi-head implementation:

- Linearly project the input into h different subspaces
- Compute attention independently in each head
- Concatenate the outputs of all heads and apply a final linear transformation

```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, h, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1)  # same mask broadcast over all heads
        # Linear projections and split into heads
        q = self.q_linear(q).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        # Apply attention on all heads in parallel
        x, attn = self.attention(q, k, v, mask=mask)
        # Concatenate heads and apply the final linear layer
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.out(x)
```

### 2.2 Positional Encoding and the Position-wise Feed-Forward Network

Because the Transformer contains no recurrence or convolution, it needs explicit position information. We use sine/cosine positional encodings:

```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.dropout(x + self.pe[:, :x.size(1)])
```

The position-wise feed-forward network consists of two linear transformations with a ReLU activation in between:

```python
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w2(self.dropout(F.relu(self.w1(x))))
```
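The full model assembled in Section 4 also expects an `Embeddings` module that is not spelled out in this walkthrough. As a minimal sketch, the version below follows the Annotated Transformer: a plain `nn.Embedding` lookup scaled by √d_model, keeping a `d_model` attribute because the optimizer setup in Section 4.2 reads `model.src_embed[0].d_model`. It reuses the imports from the code above.

```python
class Embeddings(nn.Module):
    """Token embedding lookup scaled by sqrt(d_model), as in 'Attention Is All You Need'."""
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        # x: (batch, seq_len) token ids -> (batch, seq_len, d_model)
        return self.lut(x) * math.sqrt(self.d_model)
```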
## 3. Encoder and Decoder Architecture

### 3.1 The Encoder Layer

An encoder layer contains a self-attention sublayer and a feed-forward sublayer; each sublayer is wrapped in a residual connection with layer normalization:

```python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = nn.ModuleList([
            SublayerConnection(d_model, dropout) for _ in range(2)
        ])
        self.size = d_model

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)


class SublayerConnection(nn.Module):
    """Residual connection around a sublayer, with pre-norm and dropout."""
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
```

### 3.2 The Decoder Layer

A decoder layer additionally contains encoder-decoder (cross) attention, and its self-attention masks out future positions:

```python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = nn.ModuleList([
            SublayerConnection(d_model, dropout) for _ in range(3)
        ])
        self.size = d_model

    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)
```

## 4. Assembling the Full Model and Training

### 4.1 Model Assembly

Combine the components into a complete Transformer model (each submodule is deep-copied so that the encoder, decoder, and individual layers do not share weights):

```python
import copy


class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, N=6, d_model=512,
                 d_ff=2048, h=8, dropout=0.1):
        super().__init__()
        c = copy.deepcopy
        attn = MultiHeadAttention(d_model, h, dropout)
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
        position = PositionalEncoding(d_model, dropout)
        self.encoder = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)
        self.decoder = Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N)
        self.src_embed = nn.Sequential(Embeddings(d_model, src_vocab), c(position))
        self.tgt_embed = nn.Sequential(Embeddings(d_model, tgt_vocab), c(position))
        self.generator = Generator(d_model, tgt_vocab)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
```

### 4.2 Training Strategy and Optimization

The Transformer uses a particular optimization recipe with learning-rate warmup. The learning rate follows

$$ \mathrm{lrate} = \mathrm{factor} \cdot d_{\text{model}}^{-0.5} \cdot \min\left(\mathrm{step}^{-0.5},\ \mathrm{step} \cdot \mathrm{warmup}^{-1.5}\right) $$

i.e. it grows linearly for the first `warmup` steps and then decays with the inverse square root of the step number:

```python
class NoamOpt:
    """Wraps an optimizer and applies the warmup-then-decay schedule from the paper."""
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))


def get_std_opt(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
                   torch.optim.Adam(model.parameters(), lr=0,
                                    betas=(0.9, 0.98), eps=1e-9))
```

### 4.3 Label Smoothing and the Loss

Label smoothing prevents the model from becoming over-confident in its predictions:

```python
class LabelSmoothing(nn.Module):
    def __init__(self, size, padding_idx, smoothing=0.0):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        # x: log-probabilities of shape (N, vocab_size); target: (N,) token ids
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist.clone().detach())
```
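The assembly in Section 4.1 also refers to `Encoder`, `Decoder`, and `Generator` modules (and the `Embeddings` module sketched at the end of Section 2) that the walkthrough does not define. The following is a minimal sketch in the style of the Annotated Transformer: each stack deep-copies its layer N times and applies a final layer norm, and the generator maps decoder outputs to log-probabilities, which is what the KL-divergence-based `LabelSmoothing` above expects. Treat it as one reasonable way to fill the gap rather than the only one.

```python
import copy


def clones(module, N):
    """Produce N independent (deep-copied) instances of a module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class Encoder(nn.Module):
    """Stack of N encoder layers followed by a final layer norm."""
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class Decoder(nn.Module):
    """Stack of N decoder layers followed by a final layer norm."""
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)


class Generator(nn.Module):
    """Final linear projection + log-softmax over the target vocabulary."""
    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)
```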
## 5. Inference and Evaluation

### 5.1 Greedy Decoding

```python
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        out = model.decode(memory, src_mask, ys,
                           subsequent_mask(ys.size(1)).type_as(src.data))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys


def subsequent_mask(size):
    """Mask that hides future positions: True below and on the diagonal, False above."""
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1)
    return subsequent_mask == 0
```

### 5.2 Model Evaluation and BLEU

```python
from torchtext.data.metrics import bleu_score


def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    pad_idx = TGT.vocab.stoi['<pad>']
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            trg_in, trg_out = trg[:, :-1], trg[:, 1:]
            src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
            # Combine the padding mask with the subsequent mask for the decoder input
            trg_mask = (trg_in != pad_idx).unsqueeze(-2) & \
                       subsequent_mask(trg_in.size(-1)).type_as(src_mask)
            output = model(src, trg_in, src_mask, trg_mask)
            output = model.generator(output)  # log-probabilities over the target vocabulary
            loss = criterion(output.contiguous().view(-1, output.size(-1)),
                             trg_out.contiguous().view(-1))
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)


def calculate_bleu(model, iterator, max_len=50):
    model.eval()
    trgs = []
    pred_trgs = []
    with torch.no_grad():
        for batch in iterator:
            src = batch.src
            src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
            pred = greedy_decode(model, src, src_mask, max_len, TGT.vocab.stoi['<sos>'])
            pred_trg = [TGT.vocab.itos[i] for i in pred[0] if i != TGT.vocab.stoi['<eos>']]
            trg = [TGT.vocab.itos[i] for i in batch.trg[0] if i != TGT.vocab.stoi['<eos>']]
            pred_trgs.append(pred_trg)
            trgs.append([trg])
    return bleu_score(pred_trgs, trgs)
```

## 6. Practical Tips and Performance Optimization

### 6.1 Batching and Memory Management

Batching optimization strategies:

- Dynamic batching: group sequences by length to reduce padding
- Gradient accumulation: accumulate gradients over several small batches before updating, simulating a larger batch
- Mixed-precision training: use FP16 to cut GPU memory usage (see the AMP sketch after Section 6.2)

```python
def create_batches(data, batch_size, max_padding=0):
    """Group examples into batches and record per-batch max lengths.
    max_padding=0 means no cap on the padded length."""
    batches = []
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        src_len = max(len(x.src) for x in batch)
        trg_len = max(len(x.trg) for x in batch)
        if max_padding > 0:
            src_len = min(src_len, max_padding)
            trg_len = min(trg_len, max_padding)
        batches.append((batch, src_len, trg_len))
    return batches
```

### 6.2 Multi-GPU Training

```python
class MultiGPULossCompute:
    """Compute the generator + loss in sequence chunks, spread across multiple GPUs."""
    def __init__(self, generator, criterion, devices, opt=None, chunk_size=5):
        self.generator = generator
        self.criterion = nn.parallel.replicate(criterion, devices=devices)
        self.opt = opt
        self.devices = devices
        self.chunk_size = chunk_size

    def __call__(self, out, targets, normalize):
        total = 0.0
        generator = nn.parallel.replicate(self.generator, devices=self.devices)
        out_scatter = nn.parallel.scatter(out, target_gpus=self.devices)
        out_grad = [[] for _ in out_scatter]
        targets = nn.parallel.scatter(targets, target_gpus=self.devices)

        # Walk over the sequence dimension in chunks to bound peak memory
        for i in range(0, out_scatter[0].size(1), self.chunk_size):
            # Chunks are detached leaves so their gradients can be collected below
            out_column = [[o[:, i:i + self.chunk_size]
                           .detach().requires_grad_(self.opt is not None)]
                          for o in out_scatter]
            gen = nn.parallel.parallel_apply(generator, out_column)
            # Compute the loss on each GPU
            y = [(g.contiguous().view(-1, g.size(-1)),
                  t[:, i:i + self.chunk_size].contiguous().view(-1))
                 for g, t in zip(gen, targets)]
            loss = nn.parallel.parallel_apply(self.criterion, y)
            # Sum, normalize, and accumulate the loss
            l = nn.parallel.gather(loss, target_device=self.devices[0])
            l = l.sum() / normalize
            total += l.item()
            # Backprop the loss into the chunked transformer outputs
            if self.opt is not None:
                l.backward()
                for j, _ in enumerate(loss):
                    out_grad[j].append(out_column[j][0].grad.data.clone())

        # Backprop the collected gradients through the rest of the transformer
        if self.opt is not None:
            out_grad = [torch.cat(og, dim=1) for og in out_grad]
            o1 = out
            o2 = nn.parallel.gather(out_grad, target_device=self.devices[0])
            o1.backward(gradient=o2)
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return total * normalize
```
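Section 6.1 listed mixed-precision training as a way to cut memory use, but showed no code for it. Below is a minimal sketch of a single training step with PyTorch's native `torch.cuda.amp`; the `batch`, `pad_idx`, `criterion`, and `optimizer` names are placeholders following the conventions used earlier (here `optimizer` is assumed to be a plain `torch.optim` optimizer, e.g. the one wrapped inside `NoamOpt`), and the surrounding epoch loop is omitted.

```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()


def amp_train_step(model, criterion, optimizer, batch, pad_idx, device='cuda'):
    """One mixed-precision training step: forward pass in FP16, gradients scaled in FP32."""
    src, trg = batch.src.to(device), batch.trg.to(device)
    trg_in, trg_out = trg[:, :-1], trg[:, 1:]
    src_mask = (src != pad_idx).unsqueeze(-2)
    tgt_mask = (trg_in != pad_idx).unsqueeze(-2) & subsequent_mask(trg_in.size(-1)).to(device)

    optimizer.zero_grad()
    with autocast():                      # run the forward pass in float16 where safe
        out = model(src, trg_in, src_mask, tgt_mask)
        log_probs = model.generator(out)
        loss = criterion(log_probs.contiguous().view(-1, log_probs.size(-1)),
                         trg_out.contiguous().view(-1))
    scaler.scale(loss).backward()         # scale the loss to avoid FP16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```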
### 6.3 Visualizing Attention

Visualizing attention weights helps you understand which parts of the input sequence the model focuses on:

```python
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker


def plot_attention(src, tgt, attention):
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    cax = ax.matshow(attention, cmap='bone')
    fig.colorbar(cax)
    ax.set_xticklabels([''] + src, rotation=90)
    ax.set_yticklabels([''] + tgt)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.show()
```

## 7. Common Problems and Solutions

### 7.1 Unstable Training

Symptom: the loss fluctuates wildly or becomes NaN.

Solutions:

- Gradient clipping: `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)` with a suitable `max_norm` (a sketch combining this with early stopping appears at the end of this section)
- Learning-rate warmup: use the NoamOpt scheduler
- Check the input data for invalid values or corrupted examples

### 7.2 Dealing with Overfitting

Strategies:

- Increase the dropout rate to 0.3-0.5
- Early stopping: monitor the validation loss (see the sketch at the end of this section)
- Label smoothing: set smoothing=0.1
- Weight decay: set weight_decay=0.01 in the Adam optimizer

### 7.3 Handling Long Sequences

Optimization options:

- Limit the maximum sequence length (e.g. 100-200 tokens)
- Implement a more memory-efficient attention computation
- Replace the fixed sinusoidal encoding with a learned positional embedding, as in the simple drop-in below (fully relative position schemes require changes inside the attention computation itself)

```python
class RelativePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.pe = nn.Parameter(torch.zeros(max_len, d_model))
        nn.init.uniform_(self.pe, -0.1, 0.1)

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_len}")
        return x + self.pe[:seq_len, :]
```
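To close Section 7, here is a minimal sketch of how the gradient clipping from 7.1 and the early stopping from 7.2 fit around the training/validation loop. It assumes the `NoamOpt` wrapper from Section 4.2 and the `evaluate` helper from Section 5.2; `compute_batch_loss` is a hypothetical placeholder for whatever forward-plus-loss function your own training code defines, and the checkpoint filename is arbitrary.

```python
def train_with_early_stopping(model, opt, criterion, train_iter, valid_iter,
                              max_epochs=20, patience=3, max_norm=1.0):
    """Training driver with gradient clipping and early stopping on validation loss."""
    best_valid = float('inf')
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        model.train()
        for batch in train_iter:
            loss = compute_batch_loss(model, criterion, batch)   # placeholder: forward + loss
            opt.optimizer.zero_grad()
            loss.backward()
            # Clip the global gradient norm to stabilize training (Section 7.1)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            opt.step()
        valid_loss = evaluate(model, valid_iter, criterion)       # Section 5.2
        if valid_loss < best_valid:
            best_valid = valid_loss
            epochs_without_improvement = 0
            torch.save(model.state_dict(), 'best_transformer.pt')
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:            # early stopping (Section 7.2)
                print(f"Stopping at epoch {epoch}: no improvement for {patience} epochs")
                break
    return best_valid
```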