pytorch零基础实现语义分割项目(三)——语义分割模型(U-net和deeplavb3+)

news2024/9/27 23:23:47

文章目录

  • 项目列表
  • 前言
  • U-net
    • 模型概况
      • 下采样过程
      • 上采样过程
    • 模型代码
      • 上采样代码
      • U-net模型构建
  • deeplabv3+
    • 模型概况
    • 模型代码
      • resNet
      • ASPP
      • deeplabv3+模型构建
  • 结尾

项目列表

语义分割项目(一)——数据概况及预处理

语义分割项目(二)——标签转换与数据加载

语义分割项目(三)——语义分割模型(U-net和deeplavb3+)


前言

在前两篇中我们完成了针对数据的处理和加载,本篇文章中我们将介绍两个常见的语义分割模型U-Net和deeplabv3+来完成航拍语义分割项目

U-net

模型概况

首先来看U-net的模型结构
在这里插入图片描述

下采样过程

模型首先是进行多次下采样,每次下采样后的结果保留一份以供后边合并

上采样过程

进行多次下采样后,对最后一层下采样的结果进行上采样并与上一层下采样的结果通道合并,再对通道合并后的结果进行通道调整,再次进行上采样和合并操作直到图像宽高与原图像宽高相同为止,最后再把通道数调整到与分类的类别数相同即可

模型代码

在本文中我们使用vgg16来作为backbone,vgg16前向传递的过程就是下采样的过程,正好vgg16和图中相符,也是进行了5次下采样,所以我们只需对下采样过程进行编写即可。

上采样代码

向上采用的过程我们采用的是转置卷积(有些地方也翻译为反卷积),输入值是待上采样的值(设为A)和他上一层下采样得到的值(设为B),操作过程就是先对A进行上采样,然后进行线性插值使得A的高宽变为和B高宽相同,然后合并他们的通道,随后进行两次卷积操作并传入激活函数获得上采样的结果

class up_sample(nn.Module):
    def __init__(self, channel1, channel2):
        super(up_sample, self).__init__()
        self.up = nn.ConvTranspose2d(channel1, channel1, kernel_size=2, stride=2, bias=False)
        self.conv1  = nn.Conv2d(channel1 + channel2, channel2, kernel_size = 3, padding = 1)
        self.conv2  = nn.Conv2d(channel2, channel2, kernel_size = 3, padding = 1)
        self.relu   = nn.ReLU(inplace = True)
        
    def forward(self, input1, input2):
        if input1.shape[-2:] != input2.shape[-2:]:
            input1 = self.up(input1)
        
        image_size = input2.shape[-2:]
        input1 = F.interpolate(input1, image_size)
        outputs = torch.cat([input1, input2], dim=1)
        outputs = self.conv1(outputs)
        outputs = self.conv2(outputs)
        outputs = self.relu(outputs)
        return  outputs

U-net模型构建

由于我们是初试语义分割,所以我们希望能够迅速的获得实验结果,所以我们使用torch中预训练好的vgg16作为backbone,如果对于自己手动搭建更为感兴趣,可以参考下面这篇博客尝试——常用线性CNN模型的结构与训练

如果上文的过程能够理解,下面的代码就不难理解了,我们截取vgg16的五个下采样过程,然后进行上采样,最后调整通道数即可

class U_net(nn.Module):
    def __init__(self, backbone='VGG16', channels=[64, 128, 256, 512, 512], out_channel=6):
        super(U_net, self).__init__()
        if backbone == 'VGG16':
            vgg16 = models.vgg16(pretrained=True)
            backbone = list(vgg16.children())[0]
            self.b1 = nn.Sequential(*list(backbone.children())[:5])
            self.b2 = nn.Sequential(*list(backbone.children())[5:10])
            self.b3 = nn.Sequential(*list(backbone.children())[10:17])
            self.b4 = nn.Sequential(*list(backbone.children())[17:24])
            self.b5 = nn.Sequential(*list(backbone.children())[24:])
        
       
        self.up_sample2 = up_sample(channels[1], channels[0])
        self.up_sample3 = up_sample(channels[2], channels[1])
        self.up_sample4 = up_sample(channels[3], channels[2])
        self.up_sample5 = up_sample(channels[4], channels[3])
        
        self.uplayer = nn.ConvTranspose2d(channels[0], out_channel, kernel_size=2, stride=2, bias=False)
        
    def forward(self, X):
        X1 = self.b1(X)
        X2 = self.b2(X1)
        X3 = self.b3(X2)
        X4 = self.b4(X3)
        X5 = self.b5(X4)
        
        output = self.up_sample5(X5, X4)
        output = self.up_sample4(output, X3)
        output = self.up_sample3(output, X2)
        output = self.up_sample2(output, X1)
        output = self.uplayer(output)
        return output

deeplabv3+

模型概况

其中DCNN指的是深度卷积神经网络,本文中我们将采用resnet18,deeplabv3+的主要思路是从深度卷积神经网络中间的输出结果中先取一个结果经过卷积后放到解码器中(其实这里的解码器只是一种说法,实际就是取出备用,假设经过操作后的结果为A),然后从DCNN末尾的输出经过5种卷积组合分别得到5个值,并将这些值通道合并,然后对于合并后的结果进行通道数调整,调整后的结果再进行上采样使得结果的高和宽与A的高和宽相同,然后再将他们通道数合并,随后经过卷积和上采样得到与原图高宽相同的结果
在这里插入图片描述

模型代码

resNet

关于resNet的构建请参照常用结构化CNN模型构建在这里我们不再赘述,只放出代码

class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)

        else:
            self.conv3 = None

        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)

        Y += X
        return F.relu(Y)


def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))

    return blk


def resNet18(in_channels):
    b1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=3),
                       nn.BatchNorm2d(64), nn.ReLU(),
                       nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                       )

    b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))

    b3 = nn.Sequential(*resnet_block(64, 128, 2))

    b4 = nn.Sequential(*resnet_block(128, 256, 2))

    b5 = nn.Sequential(*resnet_block(256, 512, 2))

    net = nn.Sequential(b1, b2, b3, b4, b5)

    return net

ASPP

五种不同卷积组合的过程被称为ASPP,ASPP代码如下

class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        self.branch1 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate,bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )
        self.branch2 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 3, 1, padding=6*rate, dilation=6*rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 3, 1, padding=12*rate, dilation=12*rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 3, 1, padding=18*rate, dilation=18*rate, bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )
        self.branch5 = nn.Sequential(
#                 nn.AdaptiveAvgPool2d((1, 1)),  # (b, c, r, c)->(b, c, 1, 1)
                nn.Conv2d(dim_in, dim_out, 1, 1, 0,bias=True), # (b, c_out, 1, 1)
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True)
        )

        self.conv_cat = nn.Sequential(
                nn.Conv2d(dim_out*5, dim_out, 1, 1, padding=0,bias=True),
                nn.BatchNorm2d(dim_out, momentum=bn_mom),
                nn.ReLU(inplace=True),
        )
        

    def forward(self, x):
        [b, c, row, col] = x.size()
        
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5(global_feature)
        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result

deeplabv3+模型构建

class deeplabv3(nn.Module):
    def __init__(self, in_channels, num_classes, bonenet='resNet18'):
        super(deeplabv3, self).__init__()
        if bonenet == 'resNet18':
            self.bonenet = 'resNet18'
            self.layers = resNet18(3)
            low_level_channels = 64
            high_level_channels = 512
            low_out_channels = 64
            high_out_channels = 256
            
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, low_out_channels, 1),
            nn.BatchNorm2d(low_out_channels),
            nn.ReLU(inplace=True)
        )
        
        self.aspp = ASPP(high_level_channels, high_out_channels)
        self.cat_conv = nn.Sequential(
            nn.Conv2d(256+64, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Dropout(0.1),
        )
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)
        if self.bonenet == 'resNet18':
            self.up_sample = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=2, padding=2, stride=2, bias=False)
        
        
    def forward(self, X):
        img_size = X.shape[-2:]
        if self.bonenet == 'resNet18':
            X = self.layers[:2](X)
            
        short_cut = self.shortcut_conv(X)
        if self.bonenet == 'resNet18':
            X = self.layers[2:](X)
            
        aspp = self.aspp(X)
        aspp = F.interpolate(aspp, size=(short_cut.shape[-2], short_cut.shape[-1]), mode='bilinear', align_corners=True)
        
        concat = self.cat_conv(torch.cat([aspp, short_cut], dim=1))
        ans = self.cls_conv(concat)
        ans = self.up_sample(ans)
#         ans = F.interpolate(ans, size=(img_size[0], img_size[1]), mode='bilinear', align_corners=True)
        return ans

结尾

本篇中我们完成了对于U-net和deeplabv3+的构建,在下一篇中我们将构建用于训练的损失函数与训练过程,完成使用航拍语义分割

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/354240.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

简单的组合拳

前言:在最近的wxb举行hw中,同事让我帮他看看一些后台登录站点。尝试了未授权,弱口令皆无果,要么不存在弱口令,要么有验证码,没办法绕过。本文章仅提供一个思路,在hw中更多时候并不推荐尝试这种思…

如何正确使用 钳位二极管

在电路设计中,经常遇到需要IO保护的场景,比如ADC采样,GPIO接收电平信号等。 常见的保护方法有分压,限幅,限流等。本次我们讨论限幅方法中的 钳位二极管。 我们以BAT54S为例,它的符号是这样的, 而在很多手册里,我们可以看到,一般是这样使用的: 因此,我设计了简化…

第五章.与学习相关技巧—正则化,超参数

第五章.与学习相关技巧 5.4 正则化&超参数 在机器学习中,过拟合是一个很常见的问题。过拟合指的是只能拟合训练数据,但不能很好的拟合不包含在训练数据中的其他数据状态。 1.发生过拟合的原因 模型拥有大量参数,表现力强。训练数据少。…

使用梯度下降的线性回归(Matlab代码实现)

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 👨‍💻4 Matlab代码 💥1 概述 梯度下降法,是一种基于搜索的最优化方法,最用是最小化一个损失函数。梯度下降是迭代法的一种,可以用于求…

【办公类-16-05-04】“2022下学期 大班运动场地分配表-跳过节日循环排序”(python 排班表系列)

样例展示:跳过节日的运动场地循环排序表(8个班级8组内容 下学期一共20周)背景需求:上学期做过一次大班运动场地安排,跳过节日。2023.2下学期运动场地排班(跳过节日)又来了。一、场地器械微调二、…

哪里可以找到免费的 PDF 阅读编辑器?7 个免费 PDF 阅读编辑器分享

如果您曾经需要编辑 PDF,您可能会发现很难找到免费的 PDF 编辑器。幸运的是,您可以使用在线资源来编辑该文档,而无需为软件付费。 在本文中,我将介绍七种不同的 PDF 编辑器,它们至少可以让您免费编辑几个文件。我通过…

目标检测笔记(八):自适应缩放技术Letterbox完整代码和结果展示

文章目录自适应缩放技术Letterbox介绍自适应缩放技术Letterbox流程自适应缩放Letterbox代码运行结果自适应缩放技术Letterbox介绍 由于数据集中存在多种不同和长宽比的样本图,传统的图片缩放方法按照固定尺寸来进行缩放会造成图片扭曲变形的问题。自适应缩放技术通…

Qt COM组件导出源文件

文章目录摘要dumpcpp.exe注册COM组件COM 组件转CPP参考关键字: Qt、 COM、 组件、 源文件、 dumpcpp摘要 由于厂家提供的库不是纯净C库,是基于COM组件开的库,在和厂家友好交流无果下,只能研究下Qt 如何调用,好在Qt 的…

rt-thread pwm 多通道

一通道pwm参考 https://blog.csdn.net/yangshengwei230612/article/details/128738351?spm1001.2014.3001.5501 以下主要是多通道与一通道的区别 芯片 stm32f407rgt6 1、配置PWM设备驱动相关宏定义 添加PWM宏定义 #define BSP_USING_PWM8 #define BSP_USING_PWM8_CH1 #d…

分析 vant4 源码,学会用 vue3 + ts 开发毫秒级渲染的倒计时组件,真是妙啊

2022年11月23日首发于掘金,现在同步到公众号。11. 前言大家好,我是若川。推荐点右上方蓝字若川视野把我的公众号设为星标。我倾力持续组织了一年多源码共读,感兴趣的可以加我微信 lxchuan12 参与。另外,想学源码,极力推…

浙江工商大学2023年硕士研究生 入学考试初试成绩查询通知及说明

根据往年的情况,2023浙江工商大学MBA考试初试成绩可能将于2月21日下午两点公布,为了广大考生可以及时查询到自己的分数,杭州达立易考教育为大家汇总了信息。一、成绩查询考生可以登录中国研究生招生信息网(http://yz.chsi.com.cn/…

MySQL - 介绍

前言 本篇介绍认识MySQL,重装mysql操作 如有错误,请在评论区指正,让我们一起交流,共同进步! 本文开始 1.什么是数据库? 数据库: 一种通过SQL语言操作管理数据的软件; 重装数据库的卸载数据库步骤 : ① 停止MySQL服…

分享96个HTML体育竞技模板,总有一款适合您

分享96个HTML体育竞技模板,总有一款适合您 96个HTML体育竞技模板下载链接:https://pan.baidu.com/s/1k2vJUlbd2Boduuqqa0EWMA?pwdj8ji 提取码:j8ji Python采集代码下载链接:采集代码.zip - 蓝奏云 北京奥运火炬PSD模板 奥运…

CCNP350-401学习笔记(101-150题)

101、Refer to the exhibit SwitchC connects HR and Sales to the Core switch However, business needs require that no traffic from the Finance VLAN traverse this switch. Which command meets this requirement? A. SwitchC(config)#vtp pruning B. SwitchC(config)#…

信息时代企业的核心特征-读《硅谷之谜》

引言 几年前读完《浪潮之巅》上下部之后买的书,后来一直搁置没读,直到最近,每天晚上读一点,才把读完,虽然它说自己是《浪潮之巅》的续集,但是内容其实和《浪潮之巅》关系不大,直接读也没有什么问…

再学C语言38:指针操作

C提供了6种基本的指针操作 示例代码&#xff1a; #include <stdio.h>int main(void) {int arr[5] {1, 2, 3, 4, 5};int * p1, *p2, *p3;p1 arr; // 把一个地址赋给指针p2 &arr[2]; // 把一个地址赋给指针printf("指针指向的地址&#xff0c;指针指向地址中…

Yaklang websocket劫持教程

背景 随着Web应用的发展与动态网页的普及&#xff0c;越来越多的场景需要数据动态刷新功能。在早期时&#xff0c;我们通常使用轮询的方式(即客户端每隔一段时间询问一次服务器)来实现&#xff0c;但是这种实现方式缺点很明显: 大量请求实际上是无效的&#xff0c;这导致了大量…

matlab离散系统仿真分析——电机

目录 1.电机模型 2.数字PID控制 3.MATLAB数字仿真分析 3.1matlab程序 3.2 仿真结果 4. SIMULINK仿真分析 4.1simulink模型 4.2仿真结果 1.电机模型 即&#xff1a; 其中&#xff1a;J 0.0067&#xff1b;B 0.10 2.数字PID控制 首先我们来看一下连续PID&#xff1…

[一键CV] Blazor 拖放上传文件转换格式并推送到浏览器下载

前言 昨天有个小伙伴发了一个老外java编写的小工具给我,功能是转换西班牙邮局快递Coreeos express的单据格式成Amazon格式,他的需求是改一下程序为匹配转换另一个快递公司MRW格式到Amazon格式,然而我堂堂一个Blazor发烧友,怎么可能去反编译人家的java修改呢?必须直接撸一个Bl…

Docker 快速上手学习入门教程

目录 1、docker 的基础概念 2、怎样打包和运行一个应用程序&#xff1f; 3、如何对 docker 中的应用程序进行修改&#xff1f; 4、如何对创建的镜像进行共享&#xff1f; 5、如何使用 volumes 名称对容器中的数据进行存储&#xff1f;// 数据挂载 6、另一种挂载方式&…