LLM 模型图模式改造指南
LLM 模型图模式改造指南【免费下载链接】cannbot-skillsCANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体本仓库为其提供可复用的 Skills 模块。项目地址: https://gitcode.com/cann/cannbot-skills本文档专门针对LLM 推理模型的图模式适配提炼核心改造思路和最佳实践。适用场景GPT、Llama、DeepSeek、Qwen 等 Transformer 架构的大语言模型推理。核心原则图模式适配的本质将动态变化的东西提取为模型输入模型内部尽量保证静态。图编译后会生成静态计算图任何动态行为都可能导致图中断graph break或重编译。改造的关键是识别并隔离动态因素动态因素问题表现解决思路内存地址变化Guard 失败、重编译预分配固定大小原地更新Shape 变化图中断、多次编译固定 shape 或通过参数控制Python 控制流Graph Break使用 Tensor 操作或模式参数.item()调用强制 Graph Break保持 Tensor 或外部传入LM 模型图模式适配工作流重要对于 LLM 推理模型必须区分 prefill 和 decode 阶段通常只对 decode 做图模式优化。LLM 图模式适配流程 │ ├─→ 1. 区分 Prefill/Decode │ ├── Prefill: 保持 Eager 模式序列长度变化大 │ ├── Decode: 图模式优化固定单 token 输入 │ └── 添加 is_prefill 参数 │ ├─→ 2. 改造模型代码 │ ├── 预分配 KV Cache │ ├── 位置等变化参数提取出来作为输入传入 │ └── 见下方各模块改造指南 │ ├─→ 3. 配置图模式 │ └── 见推荐配置章节 │ └─→ 4. 验证性能 ├── 正常 → 完成 └── 性能劣化 → 定位重编译问题重编译问题定位与解决关键如果图模式性能劣化必须定位是否发生了重编译定位重编译# 开启重编译日志 torch._logging.set_logs(recompilesTrue) # 运行模型 output compiled_model(input) # 如果发生重编译会打印类似 # [recompiles] Recompiling function func_name for reason: reason重编译解决方案dynamicFalse: 检测到重编译 │ └─→ 分析重编译原因 ├── 固定 shape 但仍重编译 → dynamicFalse skip_guard_eval_unsafeTrue └── 输入 shape 变化 → dynamicTrue模块改造指南1. KV Cache 模块改造改造目标消除 KV Cache 动态扩展导致的 shape 变化实现固定大小 cache 的原地更新。核心思路预分配策略在模型初始化时分配固定大小的 cache原地更新原则使用原地更新算子写入新值避免重新分配有效长度控制通过参数控制实际参与计算的长度返回优化图模式下不返回 KV cache已原地更新问题模式 vs 改造模式# 问题模式动态扩展 KV cache key torch.cat([past_key, new_key], dim1) # shape 变化 # 改造模式固定大小预分配 原地更新 # 1. 初始化时预分配 def _init_kv_cache(self, batch_size, max_seq_len, device): cache_shape (batch_size, 1, max_seq_len, head_dim) self.kv_cache torch.zeros(cache_shape, dtypedtype, devicedevice) # 2. forward 中原地更新 def forward(self, ..., kv_len, past_key_value): torch_npu.scatter_update_(past_key_cache, kv_len, new_key_states, dim-2)常见问题问题现象根因解决方案每次 decode 触发重编译torch.cat扩展 KV cache预分配固定大小原地更新内存占用过大预分配浪费结合 PagedAttention 按 block 管理返回 KV cache 开销大图模式下返回大量 tensor已原地更新无需返回2. Rotary Embedding 模块改造改造目标消除动态计算实现静态图 cos/sin 查询。如果已经使用了融合算子没有触发静态图的限制则无需改造否则需要进行改造核心思路预计算 cos/sin初始化时计算所有位置值并缓存索引查询通过index_select或切片获取外层计算优化在模型外层统一计算传入各层实现示意def forward(self, x, kv_len, is_prefillTrue): if is_prefill: cos self.cos_cached[:seq_len] # prefill切片 else: cos torch.index_select(self.cos_cached, dim0, indexkv_len.view(-1)) # decode索引 return cos.to(x.dtype), sin.to(x.dtype)3. Attention 模块改造改造目标使 Attention 计算图模式友好支持 Flash Attention 等融合算子。核心思路使用 NPU 原生融合算子优先使用 NPU 提供的融合 attention 算子有效长度参数化通过参数控制避免大规模 attention mask区分 prefill/decode使用模式参数选择不同计算路径提示可使用model-infer-fusionskill 获取融合算子指导。4. Buffer/Parameter 模块改造改造目标避免 buffer/parameter 地址变化触发 guard 失败。核心思路预分配策略初始化时分配最大可能大小原地更新原则使用copy_()、fill_()等原地操作只读访问通过index_select、切片等只读方式访问5. 动态信息外部化设计改造目标将动态变化的信息从模型内部移到输入参数。动态信息内部计算外部传入位置索引position_ids torch.arange(seq_len)作为参数传入序列长度seq_len hidden_states.size(1)actual_seq_lengths参数写入位置内部计算kv_lenkv_len参数模式切换内部判断is_prefill参数forward 签名设计参考def forward( self, input_ids: torch.LongTensor, # 位置相关Tensor 形式支持图追踪 position_ids: Optional[torch.LongTensor] None, kv_len: Optional[torch.IntTensor] None, # KV 写入位置 # 序列长度List[int] 传给 NPU 算子 actual_seq_lengths_kv: Optional[List[int]] None, actual_seq_lengths_q: Optional[List[int]] None, # 模式控制 is_prefill: bool False, # KV Cache past_key_values: Optional[Tuple[torch.Tensor]] None, # 预计算的 cos/sin避免重复计算 cos: Optional[torch.Tensor] None, sin: Optional[torch.Tensor] None, ... ): pass注意事项不要在模型 forward 中使用.item()——将 Tensor 转换为 Python 标量会触发 Graph Break。# 错误写法 - 会导致 Graph Break max_pos_id position_ids.max().item() 1 # 正确写法 - 使用静态参数或预计算 max_pos_id MAX_SEQ_LEN # 作为常量传入推荐配置npugraph_ex 后端推荐用于 LLM Decodeimport torch import torch_npu model YourModel().npu() opt_model torch.compile( model, backendnpugraph_ex, fullgraphTrue, dynamicFalse, # LLM decode 固定 shape options{ # FX图优化 inplace_pass: True, input_inplace_pass: True, pattern_fusion_pass: True, # 内存优化 reuse_graph_pool_in_same_fx: True, clone_input: True, clone_output: False, # 性能优化 remove_noop_ops: True, } )GE 图模式import torch import torch_npu import torchair from torchair import patch_for_hcom patch_for_hcom() # 集合通信入图有 TP/EP 并行时需调用 config torchair.CompilerConfig() # 根据需要配置 inference_config, ge_config 等 npu_backend torchair.get_npu_backend(compiler_configconfig) opt_model torch.compile(model, backendnpu_backend)区分 Prefill/Decode 实践指南核心思想为模型添加独立的prefill()和decode()方法通过is_prefill参数区分执行路径。代码示例# 模型层 class MyModelForCausalLM(nn.Module): def forward(self, input_ids, position_ids, past_key_values, is_prefillFalse, **kwargs): # is_prefill 控制不同执行路径 if is_prefill: # Prefill 专属SP all-gather、取最后 token logits pass else: # Decode 专属多流并行、原地更新 KV cache pass return logits def prefill(self, **kwargs): return self.forward(is_prefillTrue, **kwargs) def decode(self, **kwargs): return self.forward(is_prefillFalse, **kwargs) # Runner 层 class MyRunner: def model_inference(self, model_inputs, is_prefillFalse): if is_prefill: return self.model.prefill(**model_inputs) else: return self.model.decode(**model_inputs) # 适合图模式Prefill vs Decode 关键差异组件PrefillDecode输入变长序列提示词固定数量 token图模式通常不图化适合dynamicFalse图模式相关文档npugraph_ex 详细指南npugraph_ex-guide.mdGE 图模式详细指南ge-graph-guide.md【免费下载链接】cannbot-skillsCANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体本仓库为其提供可复用的 Skills 模块。项目地址: https://gitcode.com/cann/cannbot-skills创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考