FlashAttention:让大模型训练快三倍的“拼菜师傅“
和一个做推荐系统的朋友吃饭他问我“我训练千问模型Attention层特别慢听说FlashAttention能加速但我不懂CUDA这玩意儿到底是怎么快的”我想了一下跟他说“你把大模型训练想象成一个超大的餐厅厨房。每次做一道菜处理一个batch厨师GPU/NPU要做三件事切菜QK^T矩阵乘、调味Softmax、翻炒乘V。传统做法是切完菜放到盘子里写HBM再从盘子拿起来调味调完味又放盘子再拿起来翻炒——来来回回跑好多趟。”“FlashAttention是什么它是一个拼菜师傅把切菜、调味、翻炒三步合并成一步在灶台上直接完成中间不用来回跑厨房和餐厅。”朋友眼睛亮了“所以快的原因是不用来回跑”“对。专业术语叫IO-aware——不是算力不够是搬运数据太费时间。”传统 Attention 的来回跑问题要理解 FlashAttention先得知道传统 Attention 是怎么工作的。假设你有一个句子128个token每个token用512维向量表示。Attention 要计算每个token和所有其他token的关系得到一个128×128的注意力矩阵。传统实现分三步# 传统 Attention 实现简化版importtorchdeftraditional_attention(Q,K,V):# 第一步计算 QK^T得到注意力得分矩阵# 大小batch × heads × seq_len × seq_lenscorestorch.matmul(Q,K.transpose(-2,-1))/math.sqrt(d_k)# ⚠️ 这里 scores 要写回 HBM显存占用 seq_len × seq_len 空间# 第二步Softmax 归一化attn_weightstorch.softmax(scores,dim-1)# ⚠️ 这里要读 scores从 HBM 读再写 attn_weights写回 HBM# 第三步乘 V得到输出outputtorch.matmul(attn_weights,V)# ⚠️ 这里要读 attn_weights从 HBM 读returnoutput# 问题三步都有 HBM 读写来回搬运数据占用了 60% 以上的时间# 算力矩阵乘只占了不到 40%这三步每一步都要把中间结果写到 HBMHigh Bandwidth Memory显存下一步再读出来。就像那个餐厅比喻——切完菜放盘子再从盘子拿起来调味。当 seq_len 是 4096 的时候那个注意力矩阵的大小是 4096×4096×2 bytesfloat16 32MB。看着不大但这是每个头、每个 batch 都要存的。32 heads × 4 batch 128 份总共 4GB——就存个中间结果。FlashAttention 的灶台合并策略FlashAttention 的核心思路别把中间结果写回 HBM在灶台上直接搞定。具体做法是把 K 和 V 按小块tile读入 UBUnified Buffer昇腾NPU 上的高速片上存储在 UB 里完成一个 tile 的 QK^T → Softmax → 乘 V 完整计算然后把结果累积到输出里。# FlashAttention 的灶台合并思路伪代码defflash_attention_npu(Q,K,V,tile_size128):# Q: (batch, heads, seq_len, dim)# K, V: (batch, heads, seq_len, dim)outputtorch.zeros_like(Q)# 把 K 和 V 按 tile 分块# 每次只把一块 K_tile 和 V_tile 读到 UB 上foriinrange(0,seq_len,tile_size):K_tileK[:,:,i:itile_size,:]# 从 HBM 读一小块 KV_tileV[:,:,i:itile_size,:]# 从 HBM 读一小块 V# 在 UB 上计算 QK^T这块很小UB 放得下scores_tiletorch.matmul(Q,K_tile.transpose(-2,-1))# 在 UB 上做 Softmax不写回 HBMattn_tiletorch.softmax(scores_tile,dim-1)# 在 UB 上乘 V_tile不写回 HBMoutputtorch.matmul(attn_tile,V_tile)# 只有 output 的最终结果才写回 HBMreturnoutput# 优势中间结果scores_tile, attn_tile一直留在 UB 上不写 HBM# HBM 访存量从 34GB 降到 6GBseq_len4096, batch4, heads32这个策略在 GPU 上已经很快了但在昇腾NPU 上还能更快——因为昇腾NPU 的 UB 比 GPU 的 shared memory 大256KB vs 通常 64~164KB可以放更大的 tile减少循环次数。昇腾NPU 上的 FlashAttentionops-transformer 的实现ops-transformer 是昇腾CANN 社区的开源仓库里面有针对昇腾NPU 高度优化的 FlashAttention 实现。关键点ops-transformer 的 FlashAttention 不是简单的算法移植而是针对达芬奇架构做了深度优化Cube 和 Vector 并行达芬奇架构有两套计算单元——Cube 做矩阵乘QK^T 和 PVVector 做逐元素运算Softmax。ops-transformer 的实现让这两步 pipeline 起来一边算矩阵乘一边算 Softmax不浪费时间。异步数据搬运在当前 tile 计算的同时预加载下一个 tile 的 K 和 V 到 UB。这样计算单元就不会等数据。Tiling 策略自动调优不同 seq_len 和 dim 的最优 tile 大小不一样。ops-transformer 的 tiling 策略会根据输入形状自动选择最优分块大小。用代码验证 ops-transformer 的 FlashAttention 效果importtorchimporttorch_npu# 确保 torch-npu 已安装昇腾NPU 的 PyTorch 后端# pip install torch-npu2.1.0 (版本号以 CANN 为准)# 构造输入batch,heads,seq_len,dim4,32,4096,64Qtorch.randn(batch,heads,seq_len,dim,dtypetorch.float16).npu()Ktorch.randn(batch,heads,seq_len,dim,dtypetorch.float16).npu()Vtorch.randn(batch,heads,seq_len,dim,dtypetorch.float16).npu()# 方法1PyTorch 原生 Attention逐算子路径无融合withtorch.no_grad():output_nativetorch.nn.functional.scaled_dot_product_attention(Q,K,V,is_causalTrue)torch.npu.synchronize()# 方法2ops-transformer 的 FlashAttention融合算子# 需要先编译安装 ops-transformer# git clone https://atomgit.com/cann/ops-transformer# cd ops-transformer mkdir build cd build# cmake .. make -j make installfromflash_attention_opsimportflash_attention_npuwithtorch.no_grad():output_faflash_attention_npu(Q,K,V,causalTrue)torch.npu.synchronize()# 对比结果误差应该在 1e-3 以内max_err(output_native.cpu().float()-output_fa.cpu().float()).abs().max().item()print(fPyTorch 原生 vs FlashAttention 最大误差:{max_err:.6f})print(误差 1e-3正确性验证通过ifmax_err1e-3else误差过大检查实现)# 性能对比用 torch_npu.profiler 抓 tracefromtorch_npu.profilerimportprofile,ProfilerActivitywithprofile(activities[ProfilerActivity.NPU],export_namenative_attention.json):output_nativetorch.nn.functional.scaled_dot_product_attention(Q,K,V,is_causalTrue)torch.npu.synchronize()withprofile(activities[ProfilerActivity.NPU],export_nameflash_attention.json):output_faflash_attention_npu(Q,K,V,causalTrue)torch.npu.synchronize()# 在 Profiler GUI 里看# - native_attention.json有三个大色块MatMul / Softmax / MatMul每个色块前后都有 HBM 读写的小色块# - flash_attention.json只有一个大的 FlashAttentionKernel 色块HBM 读写少很多怎么确认 FlashAttention 真的生效了光看代码不够得用 Profiler 抓 trace 确认。# 第一步跑一次训练抓 Profiler tracepython train.py --use-flash-attention --profiler-output trace.json# 第二步在昇腾 CANN 的 Profiler GUI 里打开 trace.json# 看 Attention 层对应的色块# - 如果看到 MatMul、Softmax、MatMul 三个独立色块 → FlashAttention 没生效# - 如果看到一个 FlashAttentionKernel 色块 → 生效了# 第三步看 HBM 访存量# 在 Profiler GUI 的 Memory 标签页# - 传统 AttentionHBM 访存量 ~34GBseq_len4096# - FlashAttentionHBM 访存量 ~6GB节省 82%如果 FlashAttention 没生效检查一下框架适配层配置PyTorch 的scaled_dot_product_attention是否路由到了 ops-transformer 的实现需要安装 torch-npu 并正确配置GE 融合规则CANN 的 GE 图引擎是否识别到了 MatMul→Softmax→MatMul 的融合模式查看 GE 的融合日志输入形状FlashAttention 对 seq_len 有要求通常是 2 的幂次方比如 512、1024、2048、4096如果碰到问题可以去 atomgit 上的 Discussions 区提问社区响应很快。相关仓库https://atomgit.com/cann/ops-transformerhttps://atomgit.com/cann/cann-learning-hubhttps://atomgit.com/cann/cann-samples