经典语义分割(一)利用pytorch复现全卷积神经网络FCN

news2025/1/9 15:49:49

经典语义分割(一)利用pytorch复现全卷积神经网络FCN

这里选择B站up主[霹雳吧啦Wz]根据pytorch官方torchvision模块中实现的FCN源码。

Github连接:FCN源码

1 FCN模型搭建

1.1 FCN网络图

  • pytorch官方实现的FCN网络图,如下所示。
  • 在这里插入图片描述

1.2 backbone

  • FCN原文中的backbone是VGG,这里pytorch官方采用了resnet作为FCN的backbone。
    • ResNet的前两层跟GoogLeNet中的⼀样:
      • 在输出通道数为64、步幅为2的7 × 7卷积层后,接步幅为2的3 × 3的最大汇聚层。
      • 不同之处在于ResNet每个卷积层后增加了批量规范化层。
    • GoogLeNet在后面接了4个由Inception块组成的模块。ResNet后接4个由残差块。
      • ResNet则使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。
      • 第1个模块(layer1)由于之前已经使用了步幅为2的最大汇聚层,所以无须减小高和宽。
      • 原生的ResNet在之后的每个模块(layer2、layer3、layer4)在第⼀个残差块里将上一个模块的通道数翻倍,并将高和宽减半。
      • 不过,在这里和原生的ResNet不同的是,layer3和layer4使用了空洞卷积,并且高宽不减半。
# /fcn/src/backbone.py
import torch
import torch.nn as nn
from torchinfo import summary

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        '''
        1、ResNet的前两层
           ResNet的前两层跟GoogLeNet中的⼀样:
              在输出通道数为64、步幅为2的7 × 7卷积层后,接步幅为2的3 × 3的最⼤汇聚层。
              不同之处在于ResNet每个卷积层后增加了批量规范化层。
        '''
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.inplanes, kernel_size=7, stride=2, padding=3,bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        '''
        2、ResNet后接4个由残差块
            GoogLeNet在后⾯接了4个由Inception块组成的模块。
            ResNet则使⽤4个由残差块组成的模块,每个模块使⽤若⼲个同样输出通道数的残差块。
            第⼀个模块(layer1)由于之前已经使⽤了步幅为2的最⼤汇聚层,所以⽆须减⼩⾼和宽。
            之后的每个模块(layer2、layer3、layer4)在第⼀个残差块⾥将上⼀个模块的通道数翻倍,并将⾼和宽减半。
            
            不过,在这里和原生的ResNet不同的是,layer3和layer4使用了空洞卷积,并且高宽不减半。
        '''
        self.layer1 = self._make_layer(block, 64,  layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # layer3和layer4使用了空洞卷积,高宽不减半,因此设置stride = 1
            self.dilation *= stride
            stride = 1
        # layer2、layer3和layer4的stride=2,满足
        # layer1的stride=1,但是inplanes(64) != planes * block.expansion(64×4),因此也满足
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        # 对于每个layer,只有第1个Bottleneck需要downsample
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        # 对于每个layer,从第2个Bottleneck开始,就不需要downsample
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(block, layers, **kwargs):
    model = ResNet(block, layers, **kwargs)
    return model


def resnet50(**kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)



if __name__ == '__main__':
    net = resnet50(replace_stride_with_dilation=[False, True, True])
    print(net)
    # pip install torchinfo
    # 可以看到网络每一层的输出shape以及网络参数信息
    summary(net, input_size=(1, 3, 480, 480))

1.3 FCN Head

  • 经过backbone后,再通过FCN Head模块。
    • 通过3×3卷积层缩小通道为原来的1/4【2048-512】,再通过一个dropout和一个1×1卷积层
    • 这里1×1卷积层调整特征层的channel为分割类别中的类别个数。
    • layer3中引出的一条FCN Head辅助分类器,是为了防止误差梯度无法传递到网络浅层。
      • 训练的时候是可以使用辅助分类器件的。
      • 最后去预测或者部署到正式环境的时候只用主干的output,不用aux output。
  • 最后经过双线性插值还原特征图大小到原图。
# /fcn/src/fcn_model.py
from collections import OrderedDict

from typing import Dict

import torch
from torch import nn, Tensor
from torch.nn import functional as F
try:
    from .backbone import resnet50, resnet101
except:
    from backbone import resnet50, resnet101


class IntermediateLayerGetter(nn.ModuleDict):
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}

        # 重新构建backbone,将没有使用到的模块全部删掉
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            # self.return_layers = {'layer4': 'out', 'layer3': 'aux'}
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out


class FCN(nn.Module):
    __constants__ = ['aux_classifier']

    def __init__(self, backbone, classifier, aux_classifier=None):
        super(FCN, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        # 原论文中虽然使用的是ConvTranspose2d,但权重是冻结的,所以就是一个bilinear插值
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        result["out"] = x

        # FCN Head辅助分类器,是为了防止误差梯度无法传递到网络浅层
        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            # 原论文中虽然使用的是ConvTranspose2d,但权重是冻结的,所以就是一个bilinear插值
            x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
            result["aux"] = x

        return result


class FCNHead(nn.Sequential):
    def __init__(self, in_channels, channels):
        # 通过3×3卷积层缩小通道为原来的1/4【2048-512】,再通过一个dropout和一个1×1卷积层
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1) # 这里1×1卷积层调整特征层的channel为分割类别中的类别个数
        ]

        super(FCNHead, self).__init__(*layers)


def fcn_resnet50(aux, num_classes=21, pretrain_backbone=False):
    # 'resnet50_imagenet': 'https://download.pytorch.org/models/resnet50-0676ba61.pth'
    # 'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth'
    backbone = resnet50(replace_stride_with_dilation=[False, True, True])

    if pretrain_backbone:
        # 载入resnet50 backbone预训练权重
        backbone.load_state_dict(torch.load("resnet50.pth", map_location='cpu'))

    out_inplanes = 2048
    aux_inplanes = 1024

    return_layers = {'layer4': 'out'}
    if aux:
        return_layers['layer3'] = 'aux'
    # backbone经过前向传播的结果为OrderedDict()
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
    # why using aux: https://github.com/pytorch/vision/issues/4292
    if aux:
        aux_classifier = FCNHead(aux_inplanes, num_classes)

    classifier = FCNHead(out_inplanes, num_classes)

    model = FCN(backbone, classifier, aux_classifier)

    return model


def fcn_resnet101(aux, num_classes=21, pretrain_backbone=False):
    # 'resnet101_imagenet': 'https://download.pytorch.org/models/resnet101-63fe2227.pth'
    # 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth'
    backbone = resnet101(replace_stride_with_dilation=[False, True, True])

    if pretrain_backbone:
        # 载入resnet101 backbone预训练权重
        backbone.load_state_dict(torch.load("resnet101.pth", map_location='cpu'))

    out_inplanes = 2048
    aux_inplanes = 1024

    return_layers = {'layer4': 'out'}
    if aux:
        return_layers['layer3'] = 'aux'
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
    # why using aux: https://github.com/pytorch/vision/issues/4292
    if aux:
        aux_classifier = FCNHead(aux_inplanes, num_classes)

    classifier = FCNHead(out_inplanes, num_classes)

    model = FCN(backbone, classifier, aux_classifier)

    return model


if __name__ == '__main__':
    model = fcn_resnet50(aux=True, num_classes=21)
    print(model)
    x = torch.randn(size=(1, 3, 480, 480))
    print(model(x)['out'].shape)
    print(model(x)['aux'].shape)

2 损失函数的计算

2.1 VOC的标注详解

在这里插入图片描述

  • 这张图片大致可以分为四部分,一部分是黑色背景,一部分是粉红色的人,一部分是大红色的飞机,还有一部分是白色的神秘物体。

    • 图片的背景,它是黑色的,背景类别为0,因此在调色板中0所对应的RGB值为[0,0,0],为黑色。

    • pascal_voc_classes.json中"person": 15,可知人用数字15表示,而在palette.json中,"15": [192, 128, 128]可知15对应的RGB为粉红色,因此粉红色的是人。

    • 同理,可知飞机"aeroplane": 1在调色板中对应的颜色为大红色。

    • 这个白色的神秘物体其实也是一个小飞机,但很难分辨,故标注时用白色像素给隐藏起来了,最后白色对应的像素也不会参与损失计算。如果你足够细心的话,你会发现在人和飞机的边缘其实都是存在一圈白色的像素的,这是为了更好的区分不同类别对应的像素。同样,这里的白色也不会参与损失计算。

  • 我们可以用程序来看看标注图像中是否有白色像素。

    from PIL import Image
    import numpy as np
    img = Image.open('D:\\VOCdevkit\\VOC2007\\SegmentationClass\\2007_000032.png')
    img_np = np.array(img)
    

    在这里插入图片描述

    • 可以看到地下的像素是1,表示飞机(大红色),上面的像素为0,表示背景(黑色),中间的像素为255,这就对应着飞机周围的白色像素。
    • 我们可以看一下255对应的RGB值, [224,224,192]表示的RGB颜色为白色。
    • 这里的255需要注意,后面计算损失时白色部分不计算正是通过忽略这个值实现的。

2.2 交叉熵损失cross_entropy

l o s s ( x , c l a s s ) = − l o g ( e x [ c l a s s ] ∑ j e x [ j ] ) = − x [ c l a s s ] + l o g ( ∑ j e x [ j ] ) 举个例子:假设输入 x = [ 0.1 , 0.2 , 0.3 ] ,标签 c l a s s = 1 l o s s ( x , c l a s s ) = − x [ c l a s s ] + l o g ( ∑ j e x [ j ] ) = − 0.2 + l o g ( e x [ 0 ] + e x [ 1 ] + e x [ 2 ] ) = − 0.2 + l o g ( e 0.1 + e 0.2 + e 0.3 ) loss(x,class)=-log(\frac{e^{x[class]}}{\sum\limits_{j} e^{x[j]}})=-x[class]+log(\sum\limits_{j} e^{x[j]})\\ 举个例子:假设输入x=[0.1,0.2,0.3],标签class=1 \\ loss(x,class)=-x[class]+log(\sum\limits_{j} e^{x[j]})=-0.2 +log( e^{x[0]} + e^{x[1]} + e^{x[2]}) \\ = -0.2 +log( e^{0.1} + e^{0.2} + e^{0.3}) loss(x,class)=log(jex[j]ex[class])=x[class]+log(jex[j])举个例子:假设输入x=[0.1,0.2,0.3],标签class=1loss(x,class)=x[class]+log(jex[j])=0.2+log(ex[0]+ex[1]+ex[2])=0.2+log(e0.1+e0.2+e0.3)

我们可以用程序进行验证:

import torch
import numpy as np
import math

# 官方实现
input = torch.tensor([[0.1, 0.2, 0.3],
                      [0.1, 0.2, 0.3],
                      [0.1, 0.2, 0.3]])
target = torch.tensor([0, 1, 2])
loss = torch.nn.functional.cross_entropy(input, target)
print('官方计算 loss = ', loss.numpy())

# 自己计算
res0 = -0.1 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res1 = -0.2 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res2 = -0.3 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res = (res0 + res1 + res2) / 3
print('自己计算 loss = %.7f ' % res)
# 仅精度有差别,所以这证明了我们的计算方式是没有错的。
官方计算 loss = 1.1019429
自己计算 loss = 1.1019428 

FCN在计算损失是会忽略白色的像素,其就对应着标签中的255。

忽略白色像素的损失其实很简单,只要在函数调用时传入ignore_index并指定对应的值即可。

如对本例来说,现我打算忽略target中标签为2的数据,即不让其参与损失计算,我们来看看如何使用cross_entropy函数来实现。

import torch
import numpy as np
import math

# 官方实现
input = torch.tensor([[0.1, 0.2, 0.3],
                      [0.1, 0.2, 0.3],
                      [0.1, 0.2, 0.3]])
target = torch.tensor([0, 1, 2])
loss = torch.nn.functional.cross_entropy(input, target,  ignore_index=2)
print('官方计算 loss = ', loss.numpy())

# 自己计算
res0 = -0.1 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res1 = -0.2 + np.log(math.exp(0.1) + math.exp(0.2) + math.exp(0.3))
res = (res0 + res1 ) / 2
print('自己计算 loss = %.6f ' % res)
官方计算 loss = 1.151943
自己计算 loss = 1.151943 

2.3 FCN中损失计算过程

  • 程序中输入cross_entropy函数中的x通常是4维的tensor,即[N,C,H,W],这时候训练损失是怎么计算的呢?我们以x的维度为[1,2,2,2]为例讲解

  • 我们手动计算时候,会将数据按通道方向展开,然后分别计算cross_entropy,最后求平均(如下图所示)

在这里插入图片描述

import torch
import numpy as np
import math

# 1、官方计算
input = torch.tensor([[[[0.1, 0.2],
                        [0.3, 0.4]],

                       [[0.5, 0.6],
                        [0.7, 0.8]]]])    #shape(1 2 2 2 )

target = torch.tensor([[[0, 1],
                        [0, 1]]])

loss = torch.nn.functional.cross_entropy(input, target)
print('官方计算 loss = ', loss.numpy())


# 2、自己计算
res0 = -0.1 + np.log(math.exp(0.1) + math.exp(0.5))
res1 = -0.6 + np.log(math.exp(0.2) + math.exp(0.6))
res2 = -0.3 + np.log(math.exp(0.3) + math.exp(0.7))
res3 = -0.8 + np.log(math.exp(0.4) + math.exp(0.8))
res = (res0 + res1 + res2 + res3)/4
print('自己计算 loss = %.8f ' % res)
官方计算 loss = 0.71301526
自己计算 loss = 0.71301525 
  • 如果,我们此时忽略target=0

在这里插入图片描述

import torch
import numpy as np
import math

# 1、官方计算
input = torch.tensor([[[[0.1, 0.2],
                        [0.3, 0.4]],

                       [[0.5, 0.6],
                        [0.7, 0.8]]]])    #shape(1 2 2 2 )

target = torch.tensor([[[0, 1],
                        [0, 1]]])

loss = torch.nn.functional.cross_entropy(input, target , ignore_index=0)
print('官方计算 loss = ', loss.numpy())


# 2、自己计算
res1 = -0.6 + np.log(math.exp(0.2) + math.exp(0.6))
res3 = -0.8 + np.log(math.exp(0.4) + math.exp(0.8))
res = ( res1  + res3)/2
print('自己计算 loss = %.7f ' % res)
官方计算 loss =  0.5130153
自己计算 loss =  0.5130153 

2.4 FCN中损失代码

  • 通过上面讲解,我们就很容易理解FCN的损失计算了。这里忽略了255像素,不让其参与到损失的计算中。
  • 如果辅助分类器存在,给予较小的损失权重。
# fcn/train_utils/train_and_eval.py
def criterion(inputs, target):
    losses = {}
    for name, x in inputs.items():
        # 忽略target中值为255的像素,255的像素是目标边缘或者padding填充
        losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)

    if len(losses) == 1:
        return losses['out']

    return losses['out'] + 0.5 * losses['aux']

3 VOC数据集的读取及数据预处理

我们自定义VOCSegmentation类,继承pytorch提供的torch.utils.data.Dataset类,主要实现__getitem__函数。再利用pytorch提供的Dataloader,就可以通过调用__getitem__函数来批量读取VOC数据集图片和标签了。

VOCSegmentation类的初始化部分,如下方的代码所示:

# fcn/my_dataset.py
class VOCSegmentation(data.Dataset):
    def __init__(self, voc_root, year="2007", transforms=None, txt_name: str = "train.txt"):
        super(VOCSegmentation, self).__init__()
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        assert os.path.exists(root), "path '{}' does not exist.".format(root)
        image_dir = os.path.join(root, 'JPEGImages')
        mask_dir = os.path.join(root, 'SegmentationClass')

        txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)
        assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)
        with open(os.path.join(txt_path), "r") as f:
            file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))
        self.transforms = transforms
  • 首先我们需要获取输入(image)和标签(target)的路径。

    • voc_root是我们应该传入VOCdevkit所在的文件夹。

    • 最终self.image和self.masks里存储的就是我们输入和标签的路径了。

  • 接着我们对输入图片和标签进行transformer预处理(代码如下)

    • 训练集采用了随机缩放、水平翻转、随机裁剪、toTensor和Normalize。
    • 验证集仅使用了随机缩放、toTensor和Normalize。
    • crop_size设置为480,即训练图片都会裁剪到480*480大小,而验证时没有使用随机裁剪方法,因此验证集的图片尺寸是不一致的, 需要进行进一步的处理
# fcn/train.py
class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(2.0 * base_size)

        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        trans.extend([
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)


class SegmentationPresetEval:
    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.RandomResize(base_size, base_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)


def get_transform(train):
    base_size = 520
    crop_size = 480

    return SegmentationPresetTrain(base_size, crop_size) if train else SegmentationPresetEval(base_size)
  • 预处理代码完成后,就可以实现__getitem__以及__len__方法。

     # fcn/my_dataset.py
        def __getitem__(self, index):
            """
            Args:
                index (int): Index
    
            Returns:
                tuple: (image, target) where target is the image segmentation.
            """
            img = Image.open(self.images[index]).convert('RGB')
            target = Image.open(self.masks[index])
    
            if self.transforms is not None:
                img, target = self.transforms(img, target)
    
            return img, target
    
        def __len__(self):
            return len(self.images)
        
        
        @staticmethod
        def collate_fn(batch):
            images, targets = list(zip(*batch))
            batched_imgs = cat_list(images, fill_value=0)
            batched_targets = cat_list(targets, fill_value=255)
            return batched_imgs, batched_targets
    
  • 在VOCSegmentation类中,还实现了DataLoader中需要的collate_fn。

    • 在collate_fn中,接受一个List类型数据,其中每个元素是一个Tuple2类型,包括了image和target。
    • 在collate_fn中调用cat_list方法,对验证集图片尺寸是不一致进行处理。
    # fcn/my_dataset.py
    def cat_list(images, fill_value=0):
        # 计算该batch数据中,channel, h, w的最大值
        max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
        batch_shape = (len(images),) + max_size
        batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
        for img, pad_img in zip(images, batched_imgs):
            pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
        return batched_imgs
    
  • 最后就可以调用Dataloader批量获取数据了。

# fcn/train.py   
    # VOCdevkit -> VOC2007 -> ImageSets -> Segmentation -> train.txt
    train_dataset = VOCSegmentation(args.data_path,
                                    year="2007",
                                    transforms=get_transform(train=True),
                                    txt_name="train.txt")

    # VOCdevkit -> VOC2007 -> ImageSets -> Segmentation -> val.txt
    val_dataset = VOCSegmentation(args.data_path,
                                  year="2007",
                                  transforms=get_transform(train=False),
                                  txt_name="val.txt")

    num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

4 模型训练及测试

4.1 模型训练

  • 代码在 fcn/train.py 中。

  • 先利用Dataset和DataLoader批量获取数据。

  • 然后创建FCN网络模型,可以加载在COCO数据集上的预训练权重。

    def create_model(aux, num_classes, pretrain=True):
        model = fcn_resnet50(aux=aux, num_classes=num_classes)
    
        if pretrain:
            weights_dict = torch.load("./fcn_resnet50_coco.pth", map_location='cpu')
    
            if num_classes != 21:
                # 官方提供的预训练权重是21类(包括背景)
                # 如果训练自己的数据集,将和类别相关的权重删除,防止权重shape不一致报错
                for k in list(weights_dict.keys()):
                    if "classifier.4" in k:
                        del weights_dict[k]
    
            missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
            if len(missing_keys) != 0 or len(unexpected_keys) != 0:
                print("missing_keys: ", missing_keys)
                print("unexpected_keys: ", unexpected_keys)
    
        return model
    
  • 设置SGD优化器

    # 设置优化器
    optimizer = torch.optim.SGD(
            params_to_optimize,
            lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )
    
  • 设置学习率更新策略。

     # 创建学习率更新策略,这里是每个step更新一次(不是每个epoch)
        lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmup=True)
    
    # fcn/train_utils/train_and_eval.py
    def create_lr_scheduler(optimizer,
                            num_step: int,
                            epochs: int,
                            warmup=True,
                            warmup_epochs=1,
                            warmup_factor=1e-3):
        assert num_step > 0 and epochs > 0
        if warmup is False:
            warmup_epochs = 0
    
        def f(x):
            """
            根据step数返回一个学习率倍率因子,
            注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法
            """
            if warmup is True and x <= (warmup_epochs * num_step):
                alpha = float(x) / (warmup_epochs * num_step)
                # warmup过程中lr倍率因子从warmup_factor -> 1
                return warmup_factor * (1 - alpha) + alpha
            else:
                # warmup后lr倍率因子从1 -> 0
                # 参考deeplab_v2: Learning rate policy
                return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
    
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
    
  • 训练代码如下,可以代码调试。

    for epoch in range(args.start_epoch, args.epochs):
            mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch,
                                            lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)
            # 测试
            confmat = evaluate(model, val_loader, device=device, num_classes=num_classes)
            val_info = str(confmat)
            print(val_info)
            # write into txt
            with open(results_file, "a") as f:
                # 记录每个epoch对应的train_loss、lr以及验证集各指标
                train_info = f"[epoch: {epoch}]\n" \
                             f"train_loss: {mean_loss:.4f}\n" \
                             f"lr: {lr:.6f}\n"
                f.write(train_info + val_info + "\n\n")
    
            save_file = {"model": model.state_dict(),
                         "optimizer": optimizer.state_dict(),
                         "lr_scheduler": lr_scheduler.state_dict(),
                         "epoch": epoch,
                         "args": args}
            if args.amp:
                save_file["scaler"] = scaler.state_dict()
            torch.save(save_file, "save_weights/model_{}.pth".format(epoch))
    

4.2 模型测试

在 train_and_val.py 文件中的 evaluate 函数代码如下:

  • 创建 ConfusionMatrix 混淆矩阵
  • 使用 for 循环遍历 data_loader 得到 image 和 target 信息,并将其指给对应的设备当中
  • 再将 image 图像输入到 model 模型中进行预测,得到 output 输出(只使用主分支上的输出)
  • 调用 update 方法时,在计算每一批数据预测结果与真实结果对比的过程中,将 target 和 output.argmax(1) 进行 flatten 处理
    • output.argmax(1) 中的 1 是指在 channel 维度,而 argmax 方法用于 将每个像素预测值最大的类别作为其预测类别(如下图所示) 。
    • 在这里插入图片描述
# fcn/train_utils/train_and_eval.py
def evaluate(model, data_loader, device, num_classes):
    model.eval()
    confmat = utils.ConfusionMatrix(num_classes)
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image, target = image.to(device), target.to(device)
            output = model(image)
            output = output['out']

            confmat.update(target.flatten(), output.argmax(1).flatten())

        confmat.reduce_from_all_processes()

    return confmat

ConfusionMatrix 类代码如下:

  • ConfusionMatrix 类中的 update 函数传入了真实标签 a 和预测标签 b 等参数,代码的具体解析:

    • 这里的 num_classes 是指包含了背景的类别个数。
    • 如果 self.mat 是 None ,就使用 torch.zeros 创建一个全零矩阵作为混淆矩阵,大小为 n x n ,用于记录真实标签和预测标签之间的关系。
    • 通过检查真实标签 a 中的元素是否属于有效类别范围 [ 0 , N ) 来寻找属于目标类别的像素索引。
    • 根据像素的真实类别 a [ k ] 和预测类别 b [ k ] 计算类别索引 inds ,用于统计真实类别为 a [ k ] 被预测成 b [ k ] 的像素个数。
    • 使用 torch.bincount 统计类别索引 inds 在 [ 0 , n**2 ) 内的出现次数,并将结果重塑成 ( n , n ) 的矩阵形状,统计数据累加到混淆矩阵中。
      在这里插入图片描述
  • ConfusionMatrix 类中的 compute 函数计算常见的语义分割评价指标。

    • 语义分割评价指标主要包括 Pixel Accuracy ( Global Accuracy )、mean Accuracy、mean IoU 等:
      • Pixel Accuracy = 类别预测正确的像素个数总和 ÷ 图片的总像素个数
      • mean Accuracy = 对每个类别的 Accuracy 求平均值
      • mean IoU = 对每个类别的 IoU 求平均值
class ConfusionMatrix(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        n = self.num_classes
        if self.mat is None:
            # 创建混淆矩阵
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            # 寻找GT中为目标的像素索引(例如:255就不是目标的像素索引)
            k = (a >= 0) & (a < n)
            # 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙)
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        # 计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)
        acc_global = torch.diag(h).sum() / h.sum()
        # 计算每个类别的准确率
        acc = torch.diag(h) / h.sum(1)
        # 计算每个类别预测与真实目标的iou
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def reduce_from_all_processes(self):
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.mat)

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'mean IoU: {:.1f}').format(
                acc_global.item() * 100,
                ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
                ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
                iu.mean().item() * 100)

4.3 模型预测

  • 模型输出为1×c×h×w,因为这是预测,故batch=1,这里使用的是VOC数据,故这里的c=num_class=21。【包含一个背景类】
  • 首先我们会取输出中每个像素在21个通道中的最大值,如第一个像素在21个通道的最大值在通道0上取得。这个通道对应的索引是0,在VOC中是背景类,故这个像素所属类别为背景。其它像素同理。
  • 在这里插入图片描述
    # fcn/predict.py
    model.eval()  # 进入验证模式
    with torch.no_grad():
        # init model
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        output = model(img.to(device))
        t_end = time_synchronized()
        print("inference time: {}".format(t_end - t_start))
        # 在输出中的chanel维度求最大值对应的类别索引
        prediction = output['out'].argmax(1).squeeze(0)
        prediction = prediction.to("cpu").numpy().astype(np.uint8)
        mask = Image.fromarray(prediction)
        mask.putpalette(pallette)
        mask.save("test_result.png")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1486728.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

斐波那契数列模型---使用最小花费爬楼梯

746. 使用最小花费爬楼梯 - 力扣&#xff08;LeetCode&#xff09; 1、状态表示&#xff1a; 题目意思即&#xff1a;cost[i]代表从第i层向上爬1阶或者2阶&#xff0c;需要花费多少力气。如cost[0]&#xff0c;代表从第0阶爬到第1阶或者第2阶需要cost[0]的力气。 一共有cost.…

Java - List集合与Array数组的相互转换

一、List 转 Array 使用集合转数组的方法&#xff0c;必须使用集合的 toArray(T[] array)&#xff0c;传入的是类型完全一样的数组&#xff0c;大小就是 list.size() public static void main(String[] args) throws Exception {List<String> list new ArrayList<S…

梯度下降算法(带你 原理 实践)

目录 一、引言 二、梯度下降算法的原理 三、梯度下降算法的实现 四、梯度下降算法的优缺点 优点&#xff1a; 缺点&#xff1a; 五、梯度下降算法的改进策略 1 随机梯度下降&#xff08;Stochastic Gradient Descent, SGD&#xff09; 2 批量梯度下降&#xff08;Batch…

【解读】工信部数据安全能力提升实施方案

近日&#xff0c;工信部印发《工业领域数据安全能力提升实施方案&#xff08;2024-2026年&#xff09;》&#xff0c;提出到2026年底&#xff0c;我国工业领域数据安全保障体系基本建立&#xff0c;基本实现各工业行业规上企业数据安全要求宣贯全覆盖。数据安全保护意识普遍提高…

vue api封装

api封装 由于一个项目里api是很多的&#xff0c;随处都在调&#xff0c;如果按照之前的写法&#xff0c;在每个组件中去调api&#xff0c;一旦api有改动&#xff0c;遍地都要去改&#xff0c;所以api应该也要封装一下&#xff0c;将api的调用封装在函数中&#xff0c;将函数集…

sql 行列互换

在SQL中进行行列互换可以使用PIVOT函数。下面是一个示例查询及其对应的结果&#xff1a; 创建测试表格 CREATE TABLE test_table (id INT PRIMARY KEY,name VARCHAR(50),category VARCHAR(50) );向测试表格插入数据 INSERT INTO test_table VALUES (1, A, Category A); INSE…

Go语言必知必会100问题-15 缺少代码文档

缺少代码文档 文档&#xff08;代码注释&#xff09;是编码的一个重要方面&#xff0c;它可以降低客户端使用API的复杂度&#xff0c;也有助于项目维护。在Go语言中&#xff0c;我们应该遵循一些规则使得我们的代码更地道。下面一起来看看这些规则。 每个可导出的元素必须添加…

YOLOv9有效提点|加入MobileViT 、SK 、Double Attention Networks、CoTAttention等几十种注意力机制(五)

专栏介绍&#xff1a;YOLOv9改进系列 | 包含深度学习最新创新&#xff0c;主力高效涨点&#xff01;&#xff01;&#xff01; 一、本文介绍 本文只有代码及注意力模块简介&#xff0c;YOLOv9中的添加教程&#xff1a;可以看这篇文章。 YOLOv9有效提点|加入SE、CBAM、ECA、SimA…

JVM相关问题

JVM相关问题 一、Java继承时父子类的初始化顺序是怎样的&#xff1f;二、JVM类加载的双亲委派模型&#xff1f;三、JDK为什么要设计双亲委派模型&#xff0c;有什么好处&#xff1f;四、可以打破JVM双亲委派模型吗&#xff1f;如何打破JVM双亲委派模型&#xff1f;五、什么是内…

【数据结构】前缀树的模拟实现

目录 1、什么是前缀树&#xff1f; 2、模拟实现 2.1、前缀树节点结构 2.2、字符串的添加 2.3、字符串的查寻 2.3.1、查询树中有多少个以字符串"pre"作为前缀的字符串 2.3.2、查询某个字符串被添加过多少次 2.4、字符串的删除 3、完整代码 1、什么是前缀树&…

(资源篇)2025届暑假实习春招全攻略路线

绝对的全攻略&#xff0c;资源完善程度绝对的全网唯一。 觉得有帮助的&#xff1a;随手一键三连关注就是对up主最大的激励。 绝对的宝藏up主&#xff01;&#xff01;&#xff01;&#xff0c;up主每天都会进行更新视频&#xff0c;算法视频or校招信息or八股讲解。 【暴躁老…

数字化转型导师坚鹏:如何制定证券公司数字化转型年度培训规划

如何制定与实施证券公司数字化转型年度培训规划 ——以推动证券公司数字化转型战略落地为核心&#xff0c;实现知行果合一 课程背景&#xff1a; 很多证券公司都在开展数字化转型培训工作&#xff0c;目前存在以下问题急需解决&#xff1a; 缺少针对性的证券公司数字化转型…

账单怎么记账软件下载,佳易王账单记账汇总统计管理系统软件教程

账单怎么记账软件下载&#xff0c;佳易王账单记账汇总统计管理系统软件教程 一、前言 以下软件以 佳易王账单记账汇总统计管理系统软件V17.0为例说明 软件文件下载可以点击最下方官网卡片——软件下载——试用版软件下载 软件特色&#xff1a; 1、功能实用&#xff0c;操作…

第二天 Kubernetes落地实践之旅

第二天 Kubernetes落地实践之旅 本章学习kubernetes的架构及工作流程&#xff0c;重点介绍如何使用Workload管理业务应用的生命周期&#xff0c;实现服务不中断的滚动更新&#xff0c;通过服务发现和集群内负载均衡来实现集群内部的服务间访问&#xff0c;并通过ingress实现外…

one4all 排坑记录

one4all 排坑记录 任务踩坑回顾动作踩坑动作踩坑动作新一步测试Habitat-sim 测试habitat-lab继续ONE4ALL 任务 看了《One-4-All: Neural Potential Fields for Embodied Navigation》这篇论文&#xff0c;感觉挺有意思&#xff0c;他也开源了代码。视觉语言导航是我一直想做的…

CSS_实现三角形和聊天气泡框

如何用css画出一个三角形 1、第一步 写一个正常的盒子模型&#xff0c;先给个正方形的div&#xff0c;便于观察&#xff0c;给div设置宽高和背景颜色 <body><div class"box"></div> </body> <style>.box {width: 100px;height: 100px…

第三百七十九回

文章目录 1. 概念介绍2. 使用方法3. 代码与效果3.1 示例代码3.2 运行效果 4. 内容总结 013pickers2.gif 我们在上一章回中介绍了"如何实现Numberpicker"相关的内容&#xff0c;本章回中将介绍wheelChoose组件.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 概念…

sql 注入 之sqli-labs/less-5 双注入,也称:报错注入

该关卡返回正确或者错误页面,还有错误的代码&#xff0c;所以可以使用报错注入。报错注入的方式&#xff1a; updatexml 函数注入&#xff1a; mysql5.1.5 版本以上支持该函数&#xff0c;返回数据限制32位 模板&#xff1a;select * from user where id1 and (updatexml(&q…

MySQL:开始深入其数据(三)DQL的后续

上一章学习mysql语句里的where和join,这一章我们开始分析group by ,having,order by,limit语句。 three,too,one,go! 文章目录 重温select语法having:order by:limit 重温select语法 SELECT [ALL | DISTINCT] { * | table.* | [ table.field1 [ as alias1] [, table.field2 [a…

[通用] iPad 用于 Windows 扩展屏解决方案 Moonlight + Sunshine + Easy Virtual Display

文章目录 前言推流端 Sunshine 安装设置接收端安装 Moonlight安装虚拟屏幕软件 Easy Virtual Display 前言 上期博客讲了如何利用原生的 NVIDIA’s GameStream 传输协议实现 iPad 当作 Windows 副屏&#xff0c;对于非N卡用户&#xff0c;有一个软件 Sunshine 可以代替 Nvidia…