别再死记硬背了!用Python+PyTorch手把手图解自注意力机制(附完整代码)
别再死记硬背了用PythonPyTorch手把手图解自注意力机制附完整代码理解自注意力机制最有效的方式不是背诵公式而是亲手实现它。本文将带你用PyTorch从零构建一个可交互的自注意力模块并通过动态可视化揭示其核心计算逻辑。无论你是准备面试的开发者还是正在学习Transformer架构的研究者这套代码实验都能让你真正掌握注意力的本质。1. 环境准备与数据建模我们先构建一个极简的文本处理场景输入4个单词的嵌入向量模拟Transformer中的单头自注意力计算。这里使用PyTorch的自动微分功能避免手动计算矩阵导数。import torch import torch.nn as nn import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation # 模拟输入4个单词的嵌入向量维度64 tokens [deep, learning, is, fun] embed_dim 64 x torch.randn(4, embed_dim) # 形状[序列长度, 嵌入维度]定义可训练的权重矩阵实际项目中这些参数会自动学习class SelfAttention(nn.Module): def __init__(self, embed_dim): super().__init__() self.W_q nn.Linear(embed_dim, embed_dim, biasFalse) self.W_k nn.Linear(embed_dim, embed_dim, biasFalse) self.W_v nn.Linear(embed_dim, embed_dim, biasFalse) def forward(self, x): Q self.W_q(x) # 查询向量 K self.W_k(x) # 键向量 V self.W_v(x) # 值向量 return Q, K, V2. 动态计算注意力分数自注意力的核心是计算单词间的关联程度。我们通过查询-键点积得到原始分数然后用softmax归一化def compute_attention(Q, K): scores torch.matmul(Q, K.transpose(0, 1)) # 点积运算 scores scores / (embed_dim ** 0.5) # 缩放防止梯度消失 attn_weights torch.softmax(scores, dim-1) return attn_weights # 实例化并计算 attn_layer SelfAttention(embed_dim) Q, K, V attn_layer(x) attn_weights compute_attention(Q, K)用热力图实时显示注意力矩阵的变化fig, ax plt.subplots() im ax.imshow(attn_weights.detach().numpy(), cmapviridis) def update(i): # 模拟训练过程中权重更新 with torch.no_grad(): attn_layer.W_q.weight 0.01 * torch.randn_like(attn_layer.W_q.weight) Q, K, V attn_layer(x) im.set_data(compute_attention(Q, K).detach().numpy()) return [im] ani FuncAnimation(fig, update, frames20, interval500) plt.colorbar(im) plt.show()这段代码会生成一个动态图展示随着权重矩阵更新各单词间注意力分布的变化过程。你会直观看到某些单词组合如deep和learning逐渐形成强关联。3. 权重聚合与输出生成获得注意力权重后我们需要用它加权求和值向量def weighted_sum(attn_weights, V): return torch.matmul(attn_weights, V) # 形状[序列长度, 嵌入维度] output weighted_sum(attn_weights, V)为了验证效果可以对比输入输出向量的相似度cos nn.CosineSimilarity(dim1) print(输入输出相似度:, cos(x, output))典型输出可能显示输入输出相似度: tensor([0.3124, 0.2897, 0.2568, 0.3012])4. 扩展为多头注意力单头注意力只能捕捉一种模式的关系。实际Transformer使用多头机制class MultiHeadAttention(nn.Module): def __init__(self, embed_dim, num_heads8): super().__init__() self.head_dim embed_dim // num_heads self.W_o nn.Linear(embed_dim, embed_dim) # 输出投影 def split_heads(self, x): return x.view(x.size(0), -1, self.head_dim) def forward(self, x): Q, K, V attn_layer(x) Q self.split_heads(Q) # [序列长度, 头数, 头维度] K self.split_heads(K) V self.split_heads(V) # 各头独立计算 attn_outputs [] for i in range(Q.size(1)): attn compute_attention(Q[:,i], K[:,i]) attn_outputs.append(weighted_sum(attn, V[:,i])) # 拼接并投影 combined torch.cat(attn_outputs, dim1) return self.W_o(combined)关键改进点查询/键/值被分割到不同子空间每个头独立计算注意力最终结果通过线性层融合5. 可视化技巧进阶使用NetworkX库绘制动态注意力图import networkx as nx def draw_attention_graph(weights, tokens): G nx.DiGraph() G.add_nodes_from(tokens) for i, src in enumerate(tokens): for j, dst in enumerate(tokens): G.add_edge(src, dst, weightweights[i,j].item()) pos nx.circular_layout(G) nx.draw(G, pos, with_labelsTrue, edge_color[G[u][v][weight] for u,v in G.edges()], width[2*G[u][v][weight] for u,v in G.edges()])调用示例draw_attention_graph(attn_weights, tokens)这会生成带权重的有向图边的粗细和颜色深度反映注意力强度。通过对比不同层的注意力图可以直观理解Transformer如何构建层级表征。