金字塔场景解析网络PSPNet:打通全局上下文,屠榜语义分割三大基准
论文基本信息标题Pyramid Scene Parsing Network会议CVPR 2017单位香港中文大学、商汤科技代码https://github.com/hszhao/PSPNet论文https://arxiv.org/pdf/1612.01105.pdf前言在语义分割领域FCN虽然实现了端到端像素级预测但在复杂场景中经常闹笑话把河里的船认成汽车、把楼认成摩天楼、小目标枕头直接忽略。核心原因就是缺乏全局上下文信息。PSPNet横空出世提出金字塔池化模块PPM把不同尺度、不同区域的全局特征融合进来再搭配深度监督损失稳定训练直接拿下2016年ImageNet场景解析、PASCAL VOC 2012、Cityscapes三个榜单第一单模型mIoU高达85.4%成为语义分割史上的里程碑模型。图 1.ADE20K 数据集中复杂场景的示例图。一、FCN做场景解析的三大致命问题作者在ADE20K数据集上分析FCN预测结果总结出三大通病关系不匹配缺乏场景常识汽车不会出现在河里FCN却把船认成汽车类别混淆墙/房子/大楼/摩天楼外观相似FCN会给同一个物体标多个类小目标忽略路灯、招牌、枕头等不显眼物体FCN容易分类错误图 2.我们在 ADE20K [43] 数据集上观察到的场景解析问题。第一行显示了关系不匹配的问题——汽车很少会浮在水面上而船只则不然。第二行展示了混淆的类别其中“建筑物”类别很容易被误认为是“摩天大楼”。第三行展示了不显眼的类别。在这个例子中枕头在颜色和质地方面与床单非常相似。这些不显眼的物体很容易被全卷式网络FCN错误分类。图片分析FCN预测结果错误百出PSPNet借助全局上下文完美修正精准识别物体类别与边界。二、核心创新金字塔池化模块PPM2.1 设计思路CNN理论感受野远大于实际感受野无法有效捕捉全局信息。全局平均池化又过于简单会丢失空间关系。金字塔池化模块对特征图做四种不同尺度的池化融合全局局部不同区域的上下文信息。2.2 结构与计算流程输入ResNet空洞卷积输出的特征图尺寸为输入图像的1/8四层金字塔池化bin大小分别为1×1、2×2、3×3、6×6每层经过1×1卷积降维减少通道数上采样到原特征图尺寸拼接原始特征四层金字塔特征得到融合全局信息的最终特征图 3.我们所提出的 PSPNet 的概述。给定一个输入图像a我们首先使用卷积神经网络CNN获取最后一层卷积层的特征图b然后应用金字塔解析模块来获取不同的子区域表示接着经过上采样和拼接层来形成最终的特征表示c该表示包含了局部和全局的上下文信息。最后该表示被输入到卷积层中以获得最终的像素级预测d。图片分析输入图像→CNN提取特征→金字塔池化融合多尺度全局特征→卷积输出像素级预测。2.3 数学表达最终特征FfinalF_{final}Ffinal由原始特征F0F_{0}F0与金字塔各层特征F1,F2,F3,F4F_{1},F_{2},F_{3},F_{4}F1,F2,F3,F4拼接而成FfinalConcat(F0,F1,F2,F3,F4)F_{final} Concat(F_{0}, F_{1}, F_{2}, F_{3}, F_{4})FfinalConcat(F0,F1,F2,F3,F4)F0F_{0}F0CNN主干输出的原始特征F1−F4F_{1}−F_{4}F1−F4金字塔四层池化降维上采样后的特征ConcatConcatConcat通道维度拼接通俗解释把不同尺度的特征叠在一起让模型同时看到全局和局部三、训练技巧深度监督损失深层ResNet训练困难作者提出辅助损失分支在ResNet第4阶段后添加辅助分类器主损失辅助损失共同优化网络辅助损失权重设为0.4测试时丢弃辅助分支不影响推理速度图 4.ResNet101 中辅助损失的示例。每个蓝色方框代表一个残差块。辅助损失是在第 4b22 个残差块之后添加的。图片分析蓝色块为残差块res4b22后添加辅助损失帮助梯度回传让深层网络更好收敛。损失函数LossLossmain0.4×LossauxLoss Loss_{main} 0.4 × Loss_{aux}LossLossmain0.4×LossauxLossmainLoss_{main}Lossmain主分支分割损失LossauxLoss_{aux}Lossaux辅助分支分割损失0.4辅助损失权重实验最优值四、核心代码实现PyTorch4.1 金字塔池化模块PPMimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFclassPyramidPooling(nn.Module):def__init__(self,in_channels,pool_sizes,out_channels):super().__init__()# 1×1卷积降维self.convsnn.ModuleList()forpool_sizeinpool_sizes:self.convs.append(nn.Sequential(nn.AdaptiveAvgPool2d(pool_size),nn.Conv2d(in_channels,out_channels,1,biasFalse),nn.BatchNorm2d(out_channels),nn.ReLU(inplaceTrue)))defforward(self,x):features[x]h,wx.shape[2:]forconvinself.convs:outconv(x)# 上采样到原尺寸outF.interpolate(out,size(h,w),modebilinear,align_cornersFalse)features.append(out)# 拼接所有特征returntorch.cat(features,dim1)# 初始化PPM输入通道2048池化尺度[1,2,3,6]降维到512ppmPyramidPooling(in_channels2048,pool_sizes[1,2,3,6],out_channels512)4.2 PSPNet主干classPSPNet(nn.Module):def__init__(self,num_classes):super().__init__()# 假设使用ResNet101作为主干self.resnetresnet101(pretrainedTrue)# 金字塔池化2048→4×512拼接后共204820484096通道self.ppmPyramidPooling(2048,[1,2,3,6],512)# 最终分类卷积self.cls_convnn.Sequential(nn.Conv2d(4096,512,3,padding1,biasFalse),nn.BatchNorm2d(512),nn.ReLU(inplaceTrue),nn.Dropout(0.1),nn.Conv2d(512,num_classes,1))# 辅助损失分支self.aux_convnn.Sequential(nn.Conv2d(1024,256,3,padding1,biasFalse),nn.BatchNorm2d(256),nn.ReLU(inplaceTrue),nn.Dropout(0.1),nn.Conv2d(256,num_classes,1))defforward(self,x):# 主干前向xself.resnet.conv1(x)xself.resnet.bn1(x)xself.resnet.relu(x)xself.resnet.maxpool(x)xself.resnet.layer1(x)xself.resnet.layer2(x)xself.resnet.layer3(x)auxself.resnet.layer4(x)# 辅助分支特征xself.resnet.layer5(aux)# 主分支特征# 金字塔池化xself.ppm(x)outself.cls_conv(x)outF.interpolate(out,sizex.shape[2:],modebilinear,align_cornersFalse)ifself.training:aux_outself.aux_conv(aux)aux_outF.interpolate(aux_out,sizex.shape[2:],modebilinear,align_cornersFalse)returnout,aux_outreturnout五、实验结果三大榜单屠榜验证5.1 金字塔池化消融实验方法Mean IoU(%)Pixel Acc.(%)ResNet50-Baseline37.2378.01ResNet50全局池化40.0779.52ResNet50PPM(AVG)41.6880.04表格1 出处论文表1表格分析单纯全局池化提升有限四层金字塔池化比基线高4.45% mIoU效果显著平均池化效果优于最大池化5.2 辅助损失消融实验损失权重αMean IoU(%)Pixel Acc.(%)无辅助损失35.8277.07α0.337.0177.87α0.437.2378.01表格2 出处论文表2表格分析辅助损失有效提升精度权重α0.4时效果最优5.3 三大数据集SOTAPASCAL VOC 2012单模型mIoU85.4%超越所有同期方法CityscapesmIoU80.2%大幅领先DeepLab、FCNImageNet场景解析冠军单模型超多数集成模型图 7.对 PASCAL VOC 2012 数据集的视觉效果进行了改进。PSPNet 能够生成更准确、更详细的结果。图片分析基线把牛认成马和狗PSPNet精准修正对小物体、遮挡物体识别更准确。六、全文总结核心贡献金字塔池化模块PPM融合多尺度全局上下文解决FCN缺乏场景信息的问题深度监督损失稳定训练深度ResNet加速收敛、提升精度工程化实现公开完整代码与模型语义分割落地标配三榜第一验证方法通用性与强性能核心逻辑空洞卷积扩大感受野 → 金字塔池化抓取全局上下文 → 深度监督稳定训练 → 精准像素级预测。PSPNet证明了全局上下文在场景解析中的重要性后续的DeepLabv3、SegFormer等模型都借鉴了多尺度上下文融合思想至今仍是学习语义分割必读的经典算法。