# Don't Just Stare at `model.load_state_dict`! Optimizer Pitfalls and the Right Way to Save and Load Checkpoints in PyTorch
In PyTorch training we often need to interrupt a run and resume it later, which makes saving and loading checkpoints especially important. Yet many developers focus only on loading the model weights correctly and overlook whether the optimizer state matches. This article takes a deep dive into the complete checkpoint save/load workflow in PyTorch, with particular attention to the pitfalls around optimizer state and the right way to handle them.

## 1. Why the optimizer state matters too

When training a model in PyTorch, the optimizer holds more than the current parameter values; it also carries key pieces of training state that are essential for resuming a run:

- Momentum buffers (e.g. `momentum` in SGD, the `m` and `v` moments in Adam)
- Adaptive learning-rate statistics (`exp_avg` and `exp_avg_sq` in Adam)
- Learning-rate scheduler state (e.g. the best-loss record in `ReduceLROnPlateau`)
- Parameter-group information (different groups may use different learning-rate policies)

Failing to save and restore this state correctly can cause:

- Discontinuous training curves (the loss suddenly jumps)
- Slower convergence (the run effectively has to warm up again)
- Degraded model quality, especially with adaptive optimizers
- Broken learning-rate scheduling

```python
# A typical optimizer state_dict structure
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print(optimizer.state_dict())

# Example output:
# {
#     'state': {
#         0: {'step': 100, 'exp_avg': ..., 'exp_avg_sq': ...},
#         1: {'step': 100, 'exp_avg': ..., 'exp_avg_sq': ...},
#         ...
#     },
#     'param_groups': [
#         {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, ...},
#         ...
#     ]
# }
```

## 2. Comparing common checkpoint-saving strategies

In PyTorch there are three main checkpoint-saving strategies; each has its use cases and potential problems.

### 2.1 Saving only the model weights

The simplest strategy saves nothing but the model's `state_dict`:

```python
torch.save({
    'model_state_dict': model.state_dict(),
}, 'checkpoint.pth')
```

**Pros:** small files; simple to load; highly compatible.
**Cons:** the training state cannot be restored; the optimizer must be reinitialized; all training history is lost.

### 2.2 Saving model and optimizer state

A more complete approach includes both the model and the optimizer:

```python
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss,
}, 'checkpoint.pth')
```

**Pros:** training can resume where it left off; training stays continuous; the optimizer's internal state is preserved.
**Cons:** larger files; sensitive to changes in model structure; parameter-group mismatches become possible.

### 2.3 Saving the full training state

The most comprehensive approach stores everything needed to resume:

```python
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
    'epoch': epoch,
    'loss': loss,
    'best_metric': best_metric,
    'config': training_config,
}, 'checkpoint.pth')
```

**Pros:** the entire training environment can be restored; all training state is kept; experiments are easier to reproduce.
**Cons:** the largest files; sensitive to code versions and dependencies; less portable.
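Strategy 2.2 can be exercised end-to-end in a few lines. This sketch (model and variable names are illustrative, not from the article) saves both state dicts to an in-memory buffer and checks that the Adam moments survive the round trip:

```python
import io

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One step so the optimizer actually has state (step, exp_avg, exp_avg_sq)
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

# Apply strategy 2.2, writing to an in-memory buffer instead of a file
buf = io.BytesIO()
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": 0,
}, buf)
buf.seek(0)

# Fresh model and optimizer, then restore both
model2 = torch.nn.Linear(4, 2)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
checkpoint = torch.load(buf)
model2.load_state_dict(checkpoint["model_state_dict"])
optimizer2.load_state_dict(checkpoint["optimizer_state_dict"])

# The Adam moments survived the round trip
restored = optimizer2.state_dict()["state"][0]
print(sorted(restored.keys()))  # ['exp_avg', 'exp_avg_sq', 'step']
```

Had we rebuilt only the model (strategy 2.1), `optimizer2` would start with empty `state` and Adam's moving averages would restart from zero.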
## 3. Common pitfalls when loading optimizer state, and how to fix them

In practice, loading the optimizer state can fail in several ways. Let's look at a few typical pitfalls and their solutions.

### 3.1 Parameter-group mismatch

This is one of the most common errors:

```
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group
```

Root causes:

- The model structure changed (layers were added or removed)
- The parameters are grouped differently
- A different optimizer type is in use

Solutions:

First, check that the model structures match:

```python
# Print the parameter names of the current model and of the checkpoint
print(set(model.state_dict().keys()))
print(set(checkpoint['model_state_dict'].keys()))
```

Then align the parameter groups. Note that in the `state_dict` format the `state` keys are integer indices into the flattened parameter list, so alignment has to work on those indices:

```python
def align_optimizer_state(optimizer, checkpoint):
    # Current optimizer structure; each group's 'params' lists parameter indices
    current_sd = optimizer.state_dict()
    valid_ids = {i for group in current_sd['param_groups'] for i in group['params']}

    aligned_state_dict = {
        'state': {},
        'param_groups': current_sd['param_groups'],
    }

    # Copy over only the states whose parameter index still exists
    for param_id, state in checkpoint['optimizer_state_dict']['state'].items():
        if int(param_id) in valid_ids:
            aligned_state_dict['state'][param_id] = state
    return aligned_state_dict
```

### 3.2 Optimizer-type mismatch

Loading the state of one optimizer type into another (say, SGD state into Adam) leads to all sorts of implicit problems.

Solution (this assumes you stored `type(optimizer).__name__` under an `optimizer_type` key when saving, since PyTorch's own `state_dict` does not record the optimizer class):

```python
def load_optimizer_safely(optimizer, checkpoint):
    saved_type = checkpoint.get('optimizer_type')
    if saved_type is not None and type(optimizer).__name__ != saved_type:
        print(f"Warning: optimizer type mismatch! "
              f"Current: {type(optimizer).__name__}, checkpoint: {saved_type}")
        return False
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return True
    except ValueError as e:
        print(f"Failed to load optimizer state: {e}")
        return False
```

### 3.3 State mismatch caused by model changes

When the model structure has changed, loading the optimizer state directly can fail in various ways.

Solution:

```python
def load_with_model_changes(model, optimizer, checkpoint):
    # Load the model weights, ignoring keys that no longer match
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    # In the state_dict format, optimizer state is keyed by the parameter's
    # index in the flattened param_groups list
    num_current = sum(len(g['params'])
                      for g in optimizer.state_dict()['param_groups'])
    checkpoint_state = checkpoint['optimizer_state_dict']['state']

    new_state_dict = {
        'state': {},
        'param_groups': optimizer.state_dict()['param_groups'],
    }
    # Keep only the states whose index still refers to an existing parameter.
    # Caveat: index-based matching silently misassigns state if the parameter
    # *order* changed, not just the count.
    for param_id, state in checkpoint_state.items():
        if int(param_id) < num_current:
            new_state_dict['state'][param_id] = state

    optimizer.load_state_dict(new_state_dict)
```
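The 3.1 error is easy to reproduce deliberately, which is useful for testing your recovery path. This sketch loads the state of an optimizer built over a 2-parameter model into one built over a 4-parameter model:

```python
import torch

small = torch.nn.Linear(4, 2)                     # 2 parameters (weight, bias)
big = torch.nn.Sequential(torch.nn.Linear(4, 2),  # 4 parameters
                          torch.nn.Linear(2, 1))

opt_small = torch.optim.SGD(small.parameters(), lr=0.1)
opt_big = torch.optim.SGD(big.parameters(), lr=0.1)

try:
    # Same number of param groups, but the group sizes differ -> ValueError
    opt_big.load_state_dict(opt_small.state_dict())
except ValueError as e:
    print(f"Caught: {e}")
```

`load_state_dict` validates the group structure before touching any state, so `opt_big` is left unmodified when the error is raised.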
## 4. Building robust checkpoint utilities

Based on the analysis above, we can implement robust checkpoint save and load utilities.

### 4.1 Saving a checkpoint

Here `get_git_revision_hash` and `is_under_git` are small helpers you would define yourself:

```python
def save_checkpoint(model, optimizer, epoch, loss, best_metric, config,
                    filename, scheduler=None):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'epoch': epoch,
        'loss': loss,
        'best_metric': best_metric,
        'config': config,
        'timestamp': datetime.datetime.now().isoformat(),
        'git_hash': get_git_revision_hash() if is_under_git() else None,
        'environment': {
            'python_version': sys.version,
            'torch_version': torch.__version__,
            'cuda_version': torch.version.cuda if torch.cuda.is_available() else None,
        },
    }

    # Make sure the target directory exists
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # Save to a temporary file, then rename, so an interrupted write
    # cannot leave a corrupted checkpoint behind
    temp_filename = filename + '.tmp'
    torch.save(checkpoint, temp_filename)
    os.replace(temp_filename, filename)

    # Also write the lightweight metadata as JSON
    metadata = {
        k: v for k, v in checkpoint.items()
        if k not in ['model_state_dict', 'optimizer_state_dict',
                     'scheduler_state_dict']
    }
    with open(filename + '.meta.json', 'w') as f:
        json.dump(metadata, f, indent=2)
```

### 4.2 Loading a checkpoint

```python
def load_checkpoint(model, optimizer, filename, scheduler=None, device='cuda'):
    if not os.path.exists(filename):
        raise FileNotFoundError(f"Checkpoint file {filename} not found")

    # Load the checkpoint
    checkpoint = torch.load(filename, map_location=device)

    # Validate the basic structure
    required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch']
    for key in required_keys:
        if key not in checkpoint:
            raise ValueError(f"Invalid checkpoint: missing key {key}")

    # Load the model state
    model.load_state_dict(checkpoint['model_state_dict'])

    # Try to load the optimizer state
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except ValueError as e:
        print(f"Warning: failed to load optimizer state directly: {e}")
        print("Attempting to align optimizer states...")
        # Fall back to aligning the optimizer states
        aligned_state = align_optimizer_state(optimizer, checkpoint)
        optimizer.load_state_dict(aligned_state)

    # Load the learning-rate scheduler state, if any
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        try:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        except ValueError as e:
            print(f"Warning: failed to load scheduler state: {e}")

    # Return the remaining metadata
    return {
        'epoch': checkpoint.get('epoch', 0),
        'loss': checkpoint.get('loss', float('inf')),
        'best_metric': checkpoint.get('best_metric', None),
        'config': checkpoint.get('config', {}),
        'environment': checkpoint.get('environment', {}),
        'timestamp': checkpoint.get('timestamp', 'unknown'),
    }
```

### 4.3 Checking checkpoint compatibility

```python
def check_checkpoint_compatibility(model, checkpoint):
    # Compare model parameter names
    model_keys = set(model.state_dict().keys())
    checkpoint_keys = set(checkpoint['model_state_dict'].keys())

    # Compute the differences
    extra_in_model = model_keys - checkpoint_keys
    extra_in_checkpoint = checkpoint_keys - model_keys

    # Check the optimizer type. PyTorch does not store the class name in
    # the optimizer's state_dict, so this only works if you added a 'name'
    # field to the param group yourself when saving.
    optimizer_type = None
    if 'optimizer_state_dict' in checkpoint:
        optimizer_type = (checkpoint['optimizer_state_dict']
                          .get('param_groups', [{}])[0]
                          .get('name', 'unknown'))

    return {
        'model': {
            'matched_params': len(model_keys & checkpoint_keys),
            'extra_in_model': list(extra_in_model),
            'extra_in_checkpoint': list(extra_in_checkpoint),
            'compatible': len(extra_in_checkpoint) == 0,
        },
        'optimizer': {
            'type': optimizer_type,
            'present': 'optimizer_state_dict' in checkpoint,
        },
        'scheduler': {
            'present': 'scheduler_state_dict' in checkpoint,
        },
    }
```
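The temp-file-then-rename trick used in `save_checkpoint` is worth knowing on its own; it applies to any file you must never leave half-written. A framework-free sketch (the helper name `atomic_json_dump` is mine, not a PyTorch API):

```python
import json
import os
import tempfile


def atomic_json_dump(obj, path):
    """Write JSON so that readers never observe a partially written file."""
    directory = os.path.dirname(path) or "."
    # The temp file must live in the same directory (same filesystem),
    # otherwise os.replace degrades to a non-atomic copy
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(obj, f)
        os.replace(tmp_path, path)  # atomic rename on POSIX and Windows
    except BaseException:
        os.unlink(tmp_path)  # clean up the temp file on any failure
        raise
```

If the process dies mid-write, the destination either contains the previous complete file or the new complete file, never a truncated one.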
## 5. Best practices from real-world use

Based on years of working with PyTorch, here are my recommended practices for checkpoint management:

- **Version control**
  - Put the key information in the checkpoint filename: `modelname_epoch_valacc_timestamp.pth`
  - Save the full training configuration and hyperparameters
  - Record the git commit hash for reproducibility
- **Verify loading**
  - Immediately after saving, verify the checkpoint can be loaded
  - Periodically verify that older checkpoints still load
- **Error handling**
  - Handle corrupted files gracefully
  - Provide a degraded loading path (e.g. weights only)
- **Performance**
  - Use `torch.save(..., _use_new_zipfile_serialization=True)` to speed up saving large models
  - Consider an asynchronous saving strategy to reduce training stalls
- **Storage management**
  - Implement a checkpoint rotation policy
  - Consider compressing checkpoints kept in long-term storage

Putting it together (`evaluate`, `val_loss`, `train_config`, and `cleanup_checkpoints` are assumed to be defined elsewhere):

```python
# A complete training-loop example
def train_model(model, train_loader, val_loader, optimizer, scheduler,
                num_epochs, checkpoint_dir):
    best_metric = 0
    start_epoch = 0

    # Try to resume from the latest checkpoint
    checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, '*.pth')))
    if checkpoint_files:
        latest_checkpoint = checkpoint_files[-1]
        print(f"Resuming from checkpoint: {latest_checkpoint}")
        resume_data = load_checkpoint(model, optimizer, latest_checkpoint, scheduler)
        start_epoch = resume_data['epoch'] + 1
        best_metric = resume_data['best_metric']

    for epoch in range(start_epoch, num_epochs):
        # Training loop
        model.train()
        for batch in train_loader:
            # training step...
            pass

        # Validation loop
        model.eval()
        val_metric = evaluate(model, val_loader)

        # Update the learning rate
        if scheduler:
            scheduler.step(val_metric)

        # Save a checkpoint whenever the metric improves
        if val_metric > best_metric:
            best_metric = val_metric
            checkpoint_name = f"model_{epoch:03d}_{val_metric:.4f}.pth"
            save_checkpoint(
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                loss=val_loss,
                best_metric=best_metric,
                config=train_config,
                filename=os.path.join(checkpoint_dir, checkpoint_name),
            )
            # Keep only the best 3 checkpoints
            cleanup_checkpoints(checkpoint_dir, keep=3)
```
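The loop above calls a `cleanup_checkpoints` helper that is never defined. A minimal sketch, assuming the `model_{epoch:03d}_{metric:.4f}.pth` naming from the loop so the metric can be parsed back out of the filename:

```python
import glob
import os
import re


def cleanup_checkpoints(checkpoint_dir, keep=3):
    """Keep the `keep` checkpoints with the best metric, delete the rest.

    Assumes filenames shaped like model_{epoch:03d}_{metric:.4f}.pth.
    """
    pattern = re.compile(r"model_(\d+)_([\d.]+)\.pth$")
    scored = []
    for path in glob.glob(os.path.join(checkpoint_dir, "*.pth")):
        m = pattern.search(os.path.basename(path))
        if m:
            scored.append((float(m.group(2)), path))

    scored.sort(reverse=True)  # best metric first
    for _, path in scored[keep:]:
        os.remove(path)
        # Remove the sidecar metadata written by save_checkpoint, if present
        meta = path + ".meta.json"
        if os.path.exists(meta):
            os.remove(meta)
```

If you prefer "keep the most recent" over "keep the best", sort on the epoch field (`int(m.group(1))`) instead of the metric.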