保姆级教程:手把手带你用PyTorch复现SAM图像分割(附源码解读)
从零实现SAM图像分割PyTorch实战与核心模块解析引言图像分割一直是计算机视觉领域的核心任务之一而Meta AI发布的Segment Anything ModelSAM彻底改变了这一领域的游戏规则。作为一名长期从事计算机视觉开发的工程师我第一次接触SAM时就被它的通用性和强大性能所震撼。与传统的分割模型不同SAM不需要针对特定数据集进行微调就能在各种场景下表现出色这得益于它在1100万张图像上训练的庞大知识体系。本文将带您从零开始使用PyTorch框架完整复现SAM的核心功能。不同于简单地调用官方API我们会深入模型内部逐层解析其工作原理并实现一个精简但功能完整的Demo。您将学习到如何正确配置SAM运行环境模型权重加载与初始化技巧图像预处理的标准流程视觉Transformer在分割任务中的独特应用提示编码与掩码解码的协同工作机制无论您是希望深入理解SAM内部原理的研究者还是需要在项目中集成图像分割功能的开发者这篇实战指南都将为您提供清晰的实现路径。我们选择中等规模的sam_vit_b_01ec64.pth模型作为示例它在精度和计算资源之间提供了良好的平衡。1. 环境配置与模型加载1.1 基础环境搭建在开始之前我们需要准备一个兼容的Python环境。推荐使用Python 3.8和PyTorch 1.12版本这是经过验证能够稳定运行SAM的配置。以下是创建conda环境的命令conda create -n sam_env python3.8 conda activate sam_env pip install torch torchvision torchaudio pip install opencv-python matplotlib对于GPU加速请确保安装了对应CUDA版本的PyTorch。可以通过nvidia-smi查看CUDA版本然后从PyTorch官网获取匹配的安装命令。1.2 模型权重处理SAM提供了三种规模的预训练模型模型类型参数量文件大小适用场景vit_h636M2.4GB高精度需求vit_l308M1.2GB平衡场景vit_b91M357MB快速实验我们选择vit_b版本进行演示下载后需要验证文件完整性import torch from torchvision.models import resnet50 model_path sam_vit_b_01ec64.pth state_dict torch.load(model_path) print(f模型包含的参数数量: {len(state_dict)})注意首次加载模型时PyTorch会进行格式转换可能需要几分钟时间。建议在代码中添加缓存检查机制避免重复转换。1.3 模型架构初始化SAM的核心由三个组件构成图像编码器基于Vision Transformer的编码网络提示编码器处理点、框等交互输入的编码器掩码解码器生成最终分割结果的解码网络以下是简化的模型初始化代码from segment_anything.modeling import Sam def build_sam(checkpointNone): model Sam( image_encoderImageEncoderViT( depth12, embed_dim768, img_size1024, mlp_ratio4, num_heads12, patch_size16, qkv_biasTrue, use_rel_posTrue, window_size14, ), prompt_encoderPromptEncoder( embed_dim256, image_embedding_size(64, 64), input_image_size(1024, 1024), ), mask_decoderMaskDecoder( num_multimask_outputs3, transformer_dim256, iou_head_depth3, iou_head_hidden_dim256, ), ) if checkpoint is not None: model.load_state_dict(checkpoint) return model2. 图像预处理流程详解2.1 标准化处理SAM要求输入图像满足特定格式分辨率调整为1024x1024像素值归一化到[0,1]范围使用ImageNet风格的均值和标准差进行标准化以下是完整的预处理代码import cv2 import numpy as np import torch from torchvision.transforms import Normalize def preprocess_image(image_path): # 读取图像并转换为RGB image cv2.imread(image_path) image cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 调整大小并保持长宽比 h, w image.shape[:2] scale 1024 / max(h, w) new_h, new_w int(h * scale), int(w * scale) image cv2.resize(image, (new_w, new_h)) # 填充至1024x1024 top (1024 - new_h) // 2 bottom 1024 - new_h - top left (1024 - new_w) // 2 right 1024 - new_w - left image cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value(0, 0, 0)) # 转换为PyTorch张量并标准化 image torch.from_numpy(image).permute(2, 0, 1).float() mean torch.tensor([123.675, 116.28, 103.53]) / 255 std torch.tensor([58.395, 57.12, 57.375]) / 255 normalize Normalize(meanmean, stdstd) image normalize(image) return image.unsqueeze(0) # 添加batch维度2.2 预处理可视化为了确保预处理步骤正确执行建议添加可视化检查import matplotlib.pyplot as plt def show_preprocess_result(original, processed): plt.figure(figsize(12, 6)) plt.subplot(1, 2, 1) plt.title(Original Image) plt.imshow(original) plt.axis(off) plt.subplot(1, 2, 2) processed_img processed.squeeze().permute(1, 2, 0).numpy() processed_img (processed_img - processed_img.min()) / (processed_img.max() - processed_img.min()) plt.title(Processed Image) plt.imshow(processed_img) plt.axis(off) plt.show()3. 图像编码器实现3.1 Patch Embedding机制Vision Transformer的第一步是将图像分割为固定大小的patch。在SAM中使用16x16的patch大小class PatchEmbed(nn.Module): def __init__(self, kernel_size16, stride16, padding0): super().__init__() self.proj nn.Conv2d( in_channels3, out_channels768, kernel_sizekernel_size, stridestride, paddingpadding, ) def forward(self, x): x self.proj(x) # (B, 768, 64, 64) x x.permute(0, 2, 3, 1) # (B, 64, 64, 768) return x关键参数解析输入3通道1024x1024图像卷积核16x16步长16输出64x64的768维特征图1024/16643.2 窗口注意力机制SAM采用了窗口化的注意力计算大幅降低了计算复杂度class WindowAttention(nn.Module): def __init__(self, dim, num_heads8, window_size14): super().__init__() self.dim dim self.num_heads num_heads self.window_size window_size self.qkv nn.Linear(dim, dim * 3) self.proj nn.Linear(dim, dim) def forward(self, x): B, H, W, C x.shape x window_partition(x, self.window_size) # (num_windows*B, window_size, window_size, C) qkv self.qkv(x).reshape(-1, self.window_size * self.window_size, 3, self.num_heads, C // self.num_heads) q, k, v qkv.unbind(2) # 每个形状为 (num_windows*B, num_heads, window_size*window_size, head_dim) attn (q k.transpose(-2, -1)) * (C ** -0.5) attn attn.softmax(dim-1) x (attn v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C) x window_unpartition(x, self.window_size, (H, W)) x self.proj(x) return x提示窗口注意力是SAM高效处理高分辨率图像的关键相比全局注意力计算复杂度从O(n²)降至O(n)4. 提示编码与掩码解码4.1 交互式提示处理SAM支持多种交互方式包括点、框和掩码。以下是点提示的编码实现class PointEncoder(nn.Module): def __init__(self, embed_dim256): super().__init__() self.position_embed nn.Parameter(torch.zeros(1, 2, embed_dim)) self.label_embed nn.Embedding(3, embed_dim) # 0背景, 1前景, 2未指定 def forward(self, points, labels): points: (B, N, 2) 坐标归一化到[0,1] labels: (B, N) 0/1/2 point_embed self.position_embed[:, :2] # (1, 2, embed_dim) point_embed point_embed.unsqueeze(2) # (1, 2, 1, embed_dim) # 坐标编码 coords_embed points.unsqueeze(-1) * point_embed # (B, N, 2, embed_dim) coords_embed coords_embed.sum(dim2) # (B, N, embed_dim) # 标签编码 label_embed self.label_embed(labels) # (B, N, embed_dim) # 组合编码 point_embed coords_embed label_embed return point_embed4.2 掩码解码流程掩码解码器整合图像特征和提示信息生成最终分割结果class MaskDecoder(nn.Module): def __init__(self, transformer_dim256, num_multimask_outputs3): super().__init__() self.transformer TwoWayTransformer(depth2, embedding_dimtransformer_dim) self.output_upscaling nn.Sequential( nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size2, stride2), nn.LayerNorm(transformer_dim // 4), nn.GELU(), nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size2, stride2), nn.GELU(), ) self.output_hypernetworks nn.ModuleList([ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(num_multimask_outputs) ]) def forward(self, image_embedding, prompt_embedding): # Transformer融合特征 tokens torch.cat([ image_embedding.flatten(1, 2), # (B, H*W, C) prompt_embedding # (B, N, C) ], dim1) tokens self.transformer(tokens) # 生成掩码 mask_tokens tokens[:, :image_embedding.shape[1]*image_embedding.shape[2]] mask_tokens mask_tokens.view_as(image_embedding) # 上采样 upscaled_embedding self.output_upscaling(mask_tokens.permute(0, 3, 1, 2)) masks torch.stack([ h(upscaled_embedding) for h in self.output_hypernetworks ], dim1) return masks5. 完整推理流程实现5.1 端到端推理函数整合所有组件实现完整的推理流程def predict_masks(image, points, point_labels, model): # 图像预处理 input_image preprocess_image(image) # 图像编码 image_embedding model.image_encoder(input_image) # 提示编码 points torch.as_tensor(points, dtypetorch.float32) point_labels torch.as_tensor(point_labels, dtypetorch.int64) prompt_embedding model.prompt_encoder(points, point_labels) # 掩码解码 masks model.mask_decoder(image_embedding, prompt_embedding) # 后处理 masks masks 0 # 转换为二值掩码 return masks5.2 交互式演示示例使用matplotlib实现简单的交互界面def interactive_demo(image_path, model): image cv2.imread(image_path) image cv2.cvtColor(image, cv2.COLOR_BGR2RGB) fig, ax plt.subplots() ax.imshow(image) points [] labels [] def onclick(event): if event.inaxes ! ax: return # 左键前景右键背景 label 1 if event.button 1 else 0 points.append([event.xdata, event.ydata]) labels.append(label) # 绘制点标记 color green if label 1 else red ax.scatter(event.xdata, event.ydata, ccolor, s50) fig.canvas.draw() # 执行预测 if len(points) 0: masks predict_masks(image_path, [points], [labels], model) for i in range(masks.shape[1]): mask masks[0, i].cpu().numpy() ax.imshow(mask, alpha0.3 * mask, cmapBlues) fig.canvas.draw() fig.canvas.mpl_connect(button_press_event, onclick) plt.show()在实际项目中我发现合理设置提示点的位置对分割结果影响很大。通常建议在物体边界附近放置前景和背景点这样模型能更准确地理解分割边界。对于复杂场景可以尝试多次点击逐步优化分割结果。