从零实现Self-Attention用Python拆解Transformer核心引擎当你第一次接触Transformer模型时那些复杂的矩阵运算和注意力分数计算可能让人望而生畏。但事实上自注意力机制的核心思想出奇地简单——它不过是在告诉模型在处理当前这个词时应该重点关注序列中的哪些其他词。今天我们将抛开那些令人头疼的数学公式直接用Python代码构建一个功能完整的Self-Attention模块。1. 准备工作理解输入输出的张量结构在开始编码之前我们需要明确Self-Attention的输入输出规范。假设我们有一个包含3个单词的序列每个单词用维度为4的向量表示那么输入张量的形状就是(3,4)。这个张量将经历一系列线性变换最终产生具有相同长度的输出序列。import torch import torch.nn as nn import torch.nn.functional as F # 示例输入3个token每个token的embedding维度为4 x torch.randn(3, 4) # 形状(batch_size, seq_len, embed_dim)在Transformer中每个输入向量会被转换成三种不同的表示Query(Q)当前词想要查询其他词信息的问题Key(K)其他词提供的可用于匹配的关键词Value(V)实际要被提取的内容信息embed_dim 4 Wq nn.Linear(embed_dim, embed_dim, biasFalse) Wk nn.Linear(embed_dim, embed_dim, biasFalse) Wv nn.Linear(embed_dim, embed_dim, biasFalse) Q Wq(x) # 查询向量 K Wk(x) # 键向量 V Wv(x) # 值向量2. 注意力分数计算SoftMax背后的故事注意力机制的核心在于计算每个词对其他所有词的关注程度。这个过程可以分为三个关键步骤相似度计算通过点积衡量Query和Key的匹配程度缩放处理防止点积结果过大导致SoftMax梯度消失概率转换使用SoftMax将分数转换为概率分布def scaled_dot_product_attention(Q, K, V): dim_k K.size(-1) scores torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(dim_k)) weights F.softmax(scores, dim-1) return torch.matmul(weights, V) output scaled_dot_product_attention(Q, K, V)注意缩放因子√dₖdₖ是Key的维度的引入至关重要。没有它点积的值会随着维度增加而变大导致SoftMax函数在某些维度上梯度接近于零。让我们用一个具体的数值例子来说明假设 Q [1, 0, 1] K [1, 1, 0] 未缩放的点积分数 Q·K 1*1 0*1 1*0 1 缩放后的分数 1/√3 ≈ 0.583. 多头注意力并行化的注意力视角单一注意力机制有一个明显局限——它只能学习一种类型的词语关系。多头注意力通过并行运行多组注意力机制让模型能够同时关注不同方面的信息。头数优点缺点1头计算简单表达能力有限4头平衡性能需要更多参数8头强大表征能力计算成本高实现多头注意力的关键在于将Q、K、V分割到多个子空间class MultiHeadAttention(nn.Module): def __init__(self, embed_dim, num_heads): super().__init__() self.embed_dim embed_dim self.num_heads num_heads self.head_dim embed_dim // num_heads self.Wq nn.Linear(embed_dim, embed_dim) self.Wk nn.Linear(embed_dim, embed_dim) self.Wv nn.Linear(embed_dim, embed_dim) self.Wo nn.Linear(embed_dim, embed_dim) def split_heads(self, x): batch_size x.size(0) return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) def forward(self, Q, K, V): batch_size Q.size(0) Q self.split_heads(self.Wq(Q)) K self.split_heads(self.Wk(K)) V self.split_heads(self.Wv(V)) attn_output scaled_dot_product_attention(Q, K, V) attn_output attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim) return self.Wo(attn_output)4. 位置编码弥补序列顺序信息的缺失纯Self-Attention有一个致命缺陷——它对输入序列的顺序不敏感。无论词序如何打乱只要词集合相同输出就相同。位置编码通过为每个位置添加独特的信号来解决这个问题。Transformer使用正弦余弦函数生成位置编码class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len5000): super().__init__() pe torch.zeros(max_len, d_model) position torch.arange(0, max_len, dtypetorch.float).unsqueeze(1) div_term torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] torch.sin(position * div_term) pe[:, 1::2] torch.cos(position * div_term) pe pe.unsqueeze(0) self.register_buffer(pe, pe) def forward(self, x): return x self.pe[:, :x.size(1)]为什么选择这种特定的编码方式因为它允许模型轻松学习相对位置关系。对于任意固定偏移量kPE(posk)可以表示为PE(pos)的线性函数这使得模型能够捕捉到距离k的概念。5. 完整Self-Attention模块实现现在我们将所有组件组合起来构建一个完整的Self-Attention模块class SelfAttentionBlock(nn.Module): def __init__(self, embed_dim, num_heads): super().__init__() self.norm1 nn.LayerNorm(embed_dim) self.attn MultiHeadAttention(embed_dim, num_heads) self.norm2 nn.LayerNorm(embed_dim) self.ffn nn.Sequential( nn.Linear(embed_dim, 4*embed_dim), nn.ReLU(), nn.Linear(4*embed_dim, embed_dim) ) self.pos_encoder PositionalEncoding(embed_dim) def forward(self, x): x self.pos_encoder(x) attn_output self.attn(x, x, x) x self.norm1(x attn_output) ffn_output self.ffn(x) x self.norm2(x ffn_output) return x这个实现包含了Transformer中的几个关键设计残差连接Add层归一化Norm前馈网络FFN多头注意力Multi-Head Attention6. 与PyTorch官方实现对比验证为了验证我们的实现是否正确我们可以将其与PyTorch内置的nn.MultiheadAttention进行比较# 我们的实现 our_attn MultiHeadAttention(embed_dim512, num_heads8) our_output our_attn(Q, K, V) # PyTorch官方实现 pytorch_attn nn.MultiheadAttention(embed_dim512, num_heads8) pytorch_output, _ pytorch_attn(Q, K, V) # 比较两者差异 print(最大差异:, torch.max(torch.abs(our_output - pytorch_output)))在实际项目中我发现在某些情况下直接使用官方实现可能更高效但自己实现的好处是你可以完全控制每一行代码的行为这对调试和理解模型内部运作非常有帮助。7. 实际应用示例文本分类任务让我们看看如何将自注意力机制应用到一个简单的文本分类任务中class TextClassifier(nn.Module): def __init__(self, vocab_size, embed_dim, num_heads, num_classes): super().__init__() self.embedding nn.Embedding(vocab_size, embed_dim) self.attn_block SelfAttentionBlock(embed_dim, num_heads) self.classifier nn.Linear(embed_dim, num_classes) def forward(self, x): x self.embedding(x) # (batch, seq_len, embed_dim) x self.attn_block(x) x x.mean(dim1) # 全局平均池化 return self.classifier(x)这个简单的分类器已经能够捕捉文本中的长距离依赖关系相比传统的RNN模型它的并行计算效率要高得多。在实现过程中我发现注意力权重可视化是一个强大的调试工具。通过观察模型在不同层、不同头上关注的内容可以直观理解模型的工作机制def plot_attention(attention_weights, sentence): fig, ax plt.subplots() cax ax.matshow(attention_weights, cmapviridis) fig.colorbar(cax) ax.set_xticks(range(len(sentence))) ax.set_yticks(range(len(sentence))) ax.set_xticklabels(sentence, rotation90) ax.set_yticklabels(sentence) plt.show()自注意力机制最令人着迷的地方在于它的通用性。同样的代码架构只需调整输入数据的类型就可以应用于图像、音频甚至结构化数据。这种统一性正是Transformer模型能够在多个领域取得突破的关键所在。