FlashAttention的量化:INT8和FP8怎么用才不影响精度?
FlashAttention的量化INT8和FP8怎么用才不影响精度某团队在昇腾NPU上跑Llama-2-7B想进一步提升推理速度。他们看到FlashAttention已经用了FP16想再进一步——做INT8量化。他们把权重转成了INT8把KV Cache也设成了INT8结果模型输出的全是乱码。问题出在FlashAttention的量化方式不对。FlashAttention不是简单地把FP16换成INT8就行——Attention的计算需要在高精度下做Softmax归一化只有特定的中间结果才能量化。今天把FlashAttention量化的正确姿势讲清楚哪些地方可以量化哪些地方不能量化以及怎么做到INT8量化还不影响精度。先打个比方翻译中的信息损耗想象翻译一段中文文章整句翻译FP16→INT8把整句话一次性翻译成英文。如果原句很长翻译误差会累积最后整句意思都变了。逐句翻译摘要混合精度先把每句话翻译成英文可以量化再对整段做摘要必须高精度。这样既省了工作量又保留了关键信息。FlashAttention的量化也是这个道理不是所有地方都能量化。QKV投影可以量化逐token计算但Softmax的中间结果不能量化需要高精度归一化。FlashAttention的量化可以怎么分FlashAttention的计算流程分三层每层的量化策略不同第一层QKV投影可以量化QKV投影是对每个token独立做的输入是隐藏状态输出是Q、K、V向量。这层的计算可以量化误差只影响单个token的向量不会在token之间传播。QKV投影量化策略 输入FP16隐藏状态动态范围小 权重INT8大部分值在[-127, 127]之间 输出FP16Q、K、V向量计算需要高精度 量化公式W_int8 round(W_fp16 / scale) 反量化公式W_fp16 W_int8 × scale第二层Attention Score不能量化QK^T之后得到的注意力分数矩阵是Softmax的输入。这层绝对不能量化——Softmax是指数运算对数值范围非常敏感。量化后的分数会导致归一化出错模型输出退化。Attention Score绝对不能量化 QK^T的值范围[-11, 11]经scale之后 Softmax需要精确的浮点运算 量化成INT8会导致exp(-11/quant_scale)的值全都变成0或1 结果注意力退化成分布式one-hot第三层KV Cache可以量化KV Cache存的是V向量值向量不参与Softmax计算。V向量经过Softmax加权求和之后才进入下一层。所以KV Cache可以量化但有条件。KV Cache量化条件 1. V向量的值范围要稳定不能有异常值 2. 量化误差不能累积量化-反量化误差要在容限内 3. 要做per-token或per-channel的动态量化不能用静态量化 推荐per-token动态量化 每个token的V向量独立量化自己的scale 这样不同token的值范围不同也能正确处理怎么做INT8 KV Cache量化步骤1确认硬件和算子支持# 检查昇腾NPU是否支持INT8 FlashAttentionpython3-cfrom torch_npu.contrib.functional import npu_flash_attention; print(npu_flash_attention)# 查看torch_npu版本和支持的dtypepython3-cimport torch_npu; print(torch_npu.__version__)步骤2校准KV Cache的量化scaleINT8量化需要先做校准确定每个token的scale。importtorchclassInt8KVCacheQuantizer:INT8 KV Cache量化器def__init__(self,num_layers,num_kv_heads,head_dim,quant_schemeper_token):self.num_layersnum_layers self.num_kv_headsnum_kv_heads self.head_dimhead_dim self.quant_schemequant_scheme# per_token or per_channel# 存储每个token的scaleself.k_scales{}# {layer_idx: tensor [B, num_kv_heads, seq_len, 1]}self.v_scales{}defcalibrate(self,model,calibration_data,num_samples100):校准收集KV Cache的数值分布确定量化scalemodel.eval()withtorch.no_grad():fori,batchinenumerate(calibration_data):ifinum_samples:breakoutputsmodel(input_idsbatch[input_ids],use_cacheTrue,return_dictTrue)# 收集每层的KV Cache统计past_kvoutputs.past_key_valuesforlayer_idx,(k,v)inenumerate(past_kv):# 记录最大绝对值用于确定scaleiflayer_idxnotinself.k_scales:self.k_scales[layer_idx][]self.v_scales[layer_idx][]# per-token量化每个token独立scaleifself.quant_schemeper_token:k_maxk.abs().amax(dim-1,keepdimTrue)# [B, H, S, 1]v_maxv.abs().amax(dim-1,keepdimTrue)else:# per-channel量化每个channel一个scalek_maxk.abs().amax(dim[0,2,3],keepdimTrue)v_maxv.abs().amax(dim[0,2,3],keepdimTrue)self.k_scales[layer_idx].append(k_max)self.v_scales[layer_idx].append(v_max)defcompute_scales(self):根据校准数据计算量化scaleforlayer_idxinself.k_scales:# 取所有样本的最大值确保量化范围覆盖所有数据k_cattorch.cat(self.k_scales[layer_idx],dim2)v_cattorch.cat(self.v_scales[layer_idx],dim2)# scale max / 127INT8的范围是[-127, 127]self.k_scales[layer_idx](k_cat.amax(dim2,keepdimTrue)/127.0).clamp(min1e-6)self.v_scales[layer_idx](v_cat.amax(dim2,keepdimTrue)/127.0).clamp(min1e-6)defquantize_kv(self,k,v,layer_idx):量化KV Cachek_scaleself.k_scales[layer_idx]v_scaleself.v_scales[layer_idx]k_int8torch.clamp(torch.round(k/k_scale),-127,127).to(torch.int8)v_int8torch.clamp(torch.round(v/v_scale),-127,127).to(torch.int8)returnk_int8,v_int8,k_scale,v_scaledefdequantize_kv(self,k_int8,v_int8,k_scale,v_scale):反量化KV Cachek_fp16k_int8.float()*k_scale v_fp16v_int8.float()*v_scalereturnk_fp16,v_fp16步骤3在FlashAttention里用INT8 KV Cachedefflash_attention_with_int8_kv_cache(q,k_int8,v_int8,k_scale,v_scale,head_num,kv_head_num,layer_idx):用INT8 KV Cache做FlashAttention# Step 1: 反量化K和Vk_fp16,v_fp16quantizer.dequantize_kv(k_int8,v_int8,k_scale,v_scale)# Step 2: FlashAttention计算# 注意计算本身还是FP16只是KV Cache存储是INT8outputnpu_flash_attention(q,k_fp16,v_fp16,head_numhead_num,kv_head_numkv_head_num,scale_value1.0/(q.shape[-1]**0.5))returnoutput# 使用outputflash_attention_with_int8_kv_cache(qq,k_int8k_cache_int8[layer_idx],v_int8v_cache_int8[layer_idx],k_scalek_scale[layer_idx],v_scalev_scale[layer_idx],head_num32,kv_head_num32,layer_idx0)FP8量化更激进的方案FP8是比INT8更激进的量化方案——用8位浮点数代替16位浮点数动态范围比INT8大。昇腾910B和更新的NPU支持FP8。FP8格式 E4M34位指数3位尾数范围[-448, 448]适合权重 E5M25位指数2位尾数范围[-57344, 57344]适合激活值 FlashAttention中FP8的用法 QKV投影输出E5M2激活值范围大 KV CacheE4M3或E5M2权重稳定 Attention ScoreFP16不能量化 输出FP16继续往后传defflash_attention_fp8(q,k,v,head_num):FP8版本的FlashAttention# 量化QKV到FP8E5M2q_fp8q.to(torch.float8_e5m2)k_fp8k.to(torch.float8_e5m2)v_fp8v.to(torch.float8_e5m2)# FlashAttention计算输入是FP8# 注意算子内部会把FP8转成更高精度再计算outputnpu_flash_attention(q_fp8,k_fp8,v_fp8,head_numhead_num,input_dtypefloat8_e5m2,# 指定输入dtypeoutput_dtypefloat16# 输出是FP16)returnoutput⚠️ 踩坑预警FP8在昇腾910B上支持但910不支持。910上只能用INT8。在910B上FP8的KV Cache量化能节省50%显存同时精度损失0.5%。量化后的精度验证量化一定会带来精度损失关键是控制损失在可接受范围内。defverify_quantization_accuracy(model,fp16_outputs,# 原始FP16输出ground truthquant_outputs,# 量化后的输出rtol1e-3,atol1e-3):验证量化后的精度# 计算误差abs_diff(fp16_outputs-quant_outputs).abs()rel_diffabs_diff/fp16_outputs.abs().clamp(min1e-6)max_absabs_diff.max().item()max_relrel_diff.max().item()mean_absabs_diff.mean().item()print(f量化精度验证)print(f 最大绝对误差{max_abs:.6f})print(f 最大相对误差{max_rel:.6f})print(f 平均绝对误差{mean_abs:.6f})print(f 容限rtol{rtol}, atol{atol})within_tolerancetorch.allclose(fp16_outputs,quant_outputs,rtolrtol,atolatol)ifwithin_tolerance:print(✅ 量化精度在容限范围内)else:print(❌ 量化精度超出容限需要调整量化方案)print( 建议)print( 1. 增加校准样本数)print( 2. 改用per-token量化而不是per-channel)print( 3. 用BF16代替FP16作为反量化目标)returnwithin_tolerance总结量化方案选择FlashAttention的量化按这个清单选方案量化方案显存节省精度损失适用场景支持硬件FP16baseline0%0%所有场景所有NPUINT8 KV Cache~50%0.5%长序列推理昇腾910INT8 QKV投影~30%1%极致压缩昇腾910BFP8 E5M2激活~50%0.5%高吞吐推理昇腾910BFP8 E4M3权重~60%1%极致压缩昇腾910B优先级建议先开INT8 KV Cache收益最大风险最低再加FP8 E5M2激活如果硬件支持最后考虑INT8 QKV投影精度风险最高代码和文档https://atomgit.com/cann/ops-transformer