深度学习--图像分割UNet介绍及代码分析

news2025/1/11 12:46:43

UNet介绍

  • 参考
  • UNet网络介绍
    • 整体架构
    • UNet过程
      • 输入
      • 编码器(下采样)
      • 中间特征表示
      • 解码器(上采样)
      • 输出
  • 代码详解
    • unetUP和Unet关系
    • 上采样模块——unetUp
    • 用于图像分割的卷积神经网络(CNN)架构模块——Unet
      • 类的定义
      • 初始化方法
      • 上采样模块
      • 额外的上采样卷积层(仅用于ResNet50)
      • 最终卷积层
      • 前向传播方法
      • 冻结和解冻骨干网络
  • 完整代码

参考

U-Net: Convolutional Networks for Biomedical Image Segmentation 输入文章名自行查询
参考博客

UNet网络介绍

整体架构

在这里插入图片描述

U-net 架构(以最低分辨率为 32x32 像素为例),每个蓝色框对应一个多通道特征图。通道数显示在框的顶部。x-y 大小位于框的左下边缘。白框表示复制的特征图。箭头表示不同的操作。
深蓝色箭头:利用3×3的卷积核对图片进行卷积后,通过ReLU激活函数输出特征通道;
灰色箭头:对左边下采样过程中的图片进行裁剪复制;
红色箭头:通过最大池化对图片进行下采样,池化核大小为2×2;
绿色箭头:反卷积,对图像进行上采样,卷积核大小为2×2;
青色箭头:使用1×1的卷积核对图片进行卷积。
因为网络形状像U,故被称为U-net
参考博客(网络结构很清晰,推荐)

UNet过程

输入

U-Net 的输入是一幅单通道的图像,通常大小为 572x572 像素,由于在不断valid卷积过程中,会使得图片越来越小,为了避免数据丢失,在图像输入前都需要进行镜像扩大。

编码器(下采样)

  • U-Net 的编码器部分输入图像通过卷积层进行特征提取,这些卷积层通常使用 3x3 的卷积核,逐步提取图像特征并缩小空间维度。
  • 然后,通过池化层(通常是最大池化)将图像的空间维度减小,例如从 572x572 缩小到 286x286。
  • 这个过程会重复多次,每次都会减小图像的空间维度和增加特征通道数。

中间特征表示

  • 在编码器的最后一层,我们获得了一个中间特征表示,通常是一个高维的特征张量。
  • 这个特征表示包含了图像的抽象特征,可以用于后续的分割任务。

解码器(上采样)

  • U-Net 的解码器部分将中间特征表示还原到原始的空间维度,并逐步增加分辨率。
  • 首先,通过上采样操作将特征张量的空间维度扩大,例如从 286x286 扩大到 572x572。
  • 然后,通过卷积层进行特征融合,将低级和高级特征结合起来。
  • 最后,输出通道数为 64 的卷积层将特征映射到最终的分割结果。

输出

  • U-Net 的输出是一个分割图像,大小与输入图像相同(通常为 572x572 像素),这幅分割图像被分成不同的区域,其中不同区域被分配不同的标签或类别。
  • 分割图像中的每个像素都被分类到不同的类别中,即可以准确地知道图像中的每个像素属于哪个结构或区域。这个分割图像可以用于识别生物医学图像中的不同结构,例如肿瘤、器官等。

代码详解

unetUP和Unet关系

  1. unetUp:

    • unetUp 是一个自定义的 PyTorch 模块(nn.Module),用于实现 U-Net 模型中的上采样部分。
    • 它接受两个输入特征张量 inputs1inputs2,并将它们进行上采样、特征融合和卷积操作,最终输出一个特征张量。
    • 在 U-Net 中,unetUp 负责将低分辨率的特征图上采样到与高分辨率特征图相同的尺寸,以便进行特征融合。
  2. Unet:

    • Unet 是整个 U-Net 模型的主体部分,它由多个 unetUp 模块组成。
    • 根据选择的 backbone(可以是 VGG 或 ResNet-50),Unet 使用不同的主干网络提取特征。
    • Unet 的前向传播过程包括多次特征融合,上采样和卷积操作,最终生成语义分割结果。

上采样模块——unetUp

这段代码定义了一个名为 unetUp 的类,它是一个用于UNet架构中的上采样模块。这个模块的作用是将低分辨率特征图上采样并与高分辨率特征图结合,以生成更高分辨率的输出。

# 定义unetUp类,unetUp类继承自nn.Module,是PyTorch中所有神经网络模块的基类。
class unetUp(nn.Module):
# 初始化方法
in_size 和 out_size 是输入和输出通道的数量。
self.conv1 和 self.conv2 是两个二维卷积层,卷积核大小为3,填充为1。
self.up 是一个最近邻插值的上采样层,放大倍数为2。
self.relu 是一个ReLU激活函数。

    def __init__(self, in_size, out_size):
        super(unetUp, self).__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
        self.up = nn.UpsamplingNearest2d(scale_factor=2)
        self.relu = nn.ReLU(inplace=True)
# 前向传播方法
# inputs1 和 inputs2 是前向传播时的输入张量。
# torch.cat([inputs1, self.up(inputs2)], 1) 将 inputs1 和上采样后的 inputs2 在通道维度上拼接。
# outputs = self.conv1(outputs) 对拼接后的张量进行第一次卷积。
# outputs = self.relu(outputs) 对卷积结果应用ReLU激活函数。
# outputs = self.conv2(outputs) 对激活后的张量进行第二次卷积。
# outputs = self.relu(outputs) 再次应用ReLU激活函数。最终返回处理后的 outputs。
# 这段代码的主要功能是将低分辨率特征图上采样并与高分辨率特征图结合,经过两次卷积和激活函数处理后,生成更高分辨率的输出特征图。

    def forward(self, inputs1, inputs2):
        outputs = torch.cat([inputs1,self.up(inputs2)],1)
        outputs = self.conv1(outputs)
        outputs = self.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = self.relu(outputs)
        return outputs

用于图像分割的卷积神经网络(CNN)架构模块——Unet

下面这段代码定义了一个名为 Unet 的类,它是一个用于图像分割的卷积神经网络(CNN)架构。这个类可以使用不同的骨干网络(backbone),如VGG16或ResNet50,并包含上采样模块以生成高分辨率的输出。以下是对代码的详细解释:

类的定义

class Unet(nn.Module):

Unet 类继承自 nn.Module,这是PyTorch中所有神经网络模块的基类。

初始化方法

def __init__(self, num_classes=21, pretrained=False, backbone='vgg'):
    super(Unet, self).__init__()
    if backbone == 'vgg':
        self.vgg = VGG16(pretrained=pretrained)
        in_filters = [192, 384, 768, 1024]
    elif backbone == "resnet50":
        self.resnet = resnet50(pretrained=pretrained)
        in_filters = [192, 512, 1024, 3072]
    else:
        raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))
    out_filters = [64, 128, 256, 512]
  • num_classes 是输出类别的数量。
  • pretrained 指示是否使用预训练的权重。
  • backbone 指定使用的骨干网络,可以是VGG16或ResNet50。
  • 根据选择的骨干网络,初始化相应的网络并设置输入过滤器的数量。
  • in_filters(输入通道数):在卷积神经网络(CNN)中,in_filters 表示输入图像的通道数或特征图的数量。在输入层,如果是灰度图片,那就只有一个 feature map;如果是彩色图片,一般就是 3 个 feature map(对应红、绿、蓝通道)。在其他层,每个卷积核(也称为过滤器)与上一层的每个 feature map 做卷积,产生下一层的一个 feature map。因此,如果有 N 个卷积核,下一层就会产生 N 个 feature map。
  • out_filters(输出通道数):在卷积神经网络中,out_filters 表示卷积核的数量或输出的特征图数量。卷积核的个数决定了下一层的 feature map 数量。每个卷积核可以提取一种特征,并生成一个新的特征图。在多层卷积网络中,下一层的卷积核的通道数等于上一层的 feature map 数量。如果通道数不相等,就无法继续进行卷积操作。

上采样模块

# upsampling
self.up_concat4 = unetUp(in_filters[3], out_filters[3])
self.up_concat3 = unetUp(in_filters[2], out_filters[2])
self.up_concat2 = unetUp(in_filters[1], out_filters[1])
self.up_concat1 = unetUp(in_filters[0], out_filters[0])
  • 定义四个上采样模块,每个模块将低分辨率特征图上采样到更高分辨率并与高分辨率特征图结合。
  • self.up_concat4, self.up_concat3, self.up_concat2 , self.up_concat1 是上采样操作的一部分。它们分别将不同层的特征图级联在一起,以获得更丰富的特征表示

额外的上采样卷积层(仅用于ResNet50)

if backbone == 'resnet50':
    self.up_conv = nn.Sequential(
    # 使用双线性插值(nn.UpsamplingBilinear2d)将特征图的大小放大两倍
        nn.UpsamplingBilinear2d(scale_factor=2), 
        # 通过两个卷积层对特征图进行处理,以获得更好的特征表示
        nn.Conv2d(out_filters[0], out_filters[0], kernel_size=3, padding=1),
         # 使用 nn.ReLU() 激活函数来确保非线性变换
        nn.ReLU(),
        nn.Conv2d(out_filters[0], out_filters[0], kernel_size=3, padding=1),
        nn.ReLU(),
    )
else:
    self.up_conv = None
  • 如果使用ResNet50作为骨干网络,定义一个额外的上采样卷积层。

最终卷积层

# self.final 是一个卷积层,用于生成最终的输出。它将高分辨率的特征图映射到类别数(num_classes)
self.final = nn.Conv2d(out_filters[0], num_classes, 1)
  • 定义一个最终的卷积层,将输出通道数转换为类别数。

前向传播方法

定义模型的前向传播过程,将输入数据通过网络的各层进行计算,最终生成输出。

def forward(self, inputs):
# 根据 self.backbone 的值,选择不同的模型(VGG 或 ResNet-50)进行前向传播。
    	# 通过卷积层和池化层对输入数据进行处理,得到特征图 feat1、feat2、feat3、feat4 和 feat5
    if self.backbone == "vgg":
        [feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs)
    elif self.backbone == "resnet50":
        [feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs)
# 通过上采样操作将这些特征图进行级联,得到更高分辨率的特征图 up4、up3、up2 和 up1
    up4 = self.up_concat4(feat4, feat5)
    up3 = self.up_concat3(feat3, up4)
    up2 = self.up_concat2(feat2, up3)
    up1 = self.up_concat1(feat1, up2)
# 如果存在上采样卷积层 self.up_conv,则对 up1 进行进一步处理
    if self.up_conv != None:
        up1 = self.up_conv(up1)
# 通过 self.final 层获得最终的输出
    final = self.final(up1)
    
    return final
  • 根据选择的骨干网络,获取不同层的特征图。
  • 使用上采样模块逐层上采样并结合特征图。
  • 如果定义了额外的上采样卷积层,则应用该层。
  • 最终通过一个卷积层生成输出。

冻结和解冻骨干网络

def freeze_backbone(self):
    if self.backbone == "vgg":
        for param in self.vgg.parameters():
            param.requires_grad = False
    elif self.backbone == "resnet50":
        for param in self.resnet.parameters():
            param.requires_grad = False

def unfreeze_backbone(self):
    if self.backbone == "vgg":
        for param in self.vgg.parameters():
            param.requires_grad = True
    elif self.backbone == "resnet50":
        for param in self.resnet.parameters():
            param.requires_grad = True
  • freeze_backbone 方法用于冻结骨干网络的参数,使其在训练过程中不更新。
  • unfreeze_backbone 方法用于解冻骨干网络的参数,使其在训练过程中可以更新。

冻结或解冻神经网络模型的特定层–更好地进行迁移学习或微调

  1. 迁移学习

    • 在迁移学习中,使用一个预训练的神经网络模型(通常在大规模数据集上进行训练)来解决新的任务。
    • 通过冻结模型的底层层(例如卷积层),可以保留其在原始任务上学到的特征表示,然后在新任务上进行微调。
    • 这样做有助于避免在新任务上过拟合,并且可以加快训练速度。
  2. 微调

    • 微调是指在预训练模型的基础上继续训练,以适应新任务的特定数据。
    • 解冻底层层,允许其权重在新任务上进行调整,以更好地适应新数据。
    • 通常,我们只微调模型的一部分,而不是整个模型,以避免丢失预训练模型的有用特征。

完整代码

import torch
import torch.nn as nn

from nets.resnet import resnet50
from nets.vgg import VGG16

class unetUp(nn.Module):
    def __init__(self, in_size, out_size):
        super(unetUp, self).__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
        self.up = nn.UpsamplingNearest2d(scale_factor=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs1, inputs2):
        outputs = torch.cat([inputs1,self.up(inputs2)],1)
        outputs = self.conv1(outputs)
        outputs = self.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = self.relu(outputs)
        return outputs
class Unet(nn.Module):
    def __init__(self, num_classes = 2, pretrained = False, backbone = 'vgg'):
        super(Unet, self).__init__()
        if backbone == 'vgg':
            self.vgg = VGG16(pretrained=pretrained)
            in_filters = [192, 384, 768, 1024]
        elif backbone == 'resnet50'
            self.resnet = resnet50(pretrained=pretrained)
            in_filters = [192, 512, 1024, 3072]
        else:
            raise ValueError('Unsupported backbone -`{}`, Use vgg, resnet50.'.format(backbone))
        out_filters = [64, 128, 256, 512]
        #???
        self.up_concat4 = unetUp(in_filters[3], out_filters[3])
        self.up_concat3 = unetUp(in_filters[2], out_filters[2])
        self.up_concat2 = unetUp(in_filters[1], out_filters[1])
        self.up_concat1 = unetUp(in_filters[0], out_filters[0])

        if backbone == 'resnet50':
            self.up_conv = nn.Sequential(
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.Conv2d(out_filters[0],out_filters[0], kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_filters[0], out_filters[0], kernel_size=3, padding=1),
                nn.ReLU(),
            )
        else:
            self.up_conv = None

        self.final = nn.Conv2d(out_filters[0], num_classes, 1)
        self.backbone = backbone

    def forward(self, inputs):
        if self.backbone == "vgg":
            [feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs)
        elif self.backbone == "resnet50":
            [feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs)

        up4 = self.up_concat4(feat4, feat5)
        up3 = self.up_concat3(feat3, up4)
        up2 = self.up_concat2(feat2, up3)
        up1 = self.up_concat1(feat1, up2)

        if self.up_conv != None:
            up1 = self.up_conv(up1)

        final = self.final(up1)
        return final

    def freeze_backbone(self):
        if self.backbone == "vgg":
            for param in self.vgg.parameters():
                param.requires_grad = False
        elif self.backbone == "resnet50":
            for param in self.resnet.parameters():
                param.requires_grad = False

    def unfreeze_backbone(self):
        if self.backbone == "vgg":
            for param in self.vgg.parameters():
                param.requires_grad = True
        elif self.backbone == "resnet50":
            for param in self.resnet.parameters():
                param.requires_grad = True
                

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1987363.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用 Manim 创建一个二维坐标平面【NumberPlane】

NumberPlane 是 Manim 中用于创建一个二维坐标平面的类。它可以帮助用户在场景中可视化坐标轴、网格线以及其他数学概念。具体来说,它的功能包括: 坐标轴:NumberPlane 提供了 x 轴和 y 轴,通常是中心对称的,允许用户清…

深入探究Python反序列化漏洞:原理剖析与实战复现

在现代应用程序开发中,Python反序列化漏洞已成为一个备受关注的安全问题。反序列化是Python中用于将字节流转换回对象的过程,但如果没有妥善处理,攻击者可以通过精心构造的恶意数据,利用反序列化漏洞执行任意代码,进而…

前端day4-表单标签

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>day4-表单</title> </head> <body&g…

# 基于MongoDB实现商品管理系统(2)

基于MongoDB实现商品管理系统&#xff08;2&#xff09; 基于 mongodb 实现商品管理系统之准备工作 1、案例需求 这里使用的不是前端页面&#xff0c;而是控制台来完成的。 具体的需求如下所示&#xff1a; 运行 查询所有 通过id查询详情 添加 - 通过id删除 2、案例分析 程…

进程创建,进程消亡

虚拟地址&#xff1a;通过虚拟技术&#xff0c;将外部存储设备的一部分空间&#xff0c;划分给系统&#xff0c;作为在内存不足时临时用作数据缓存。当内存耗尽时&#xff0c;电脑就会自动调用硬盘来充当内存&#xff0c;以缓解内存的紧张。 练习: 编写一个代码实现,一个父…

OGG转MP3音频格式转换:6种免费音频转换器推荐

在如今的数字音乐时代&#xff0c;不同音频格式的兼容性问题常常让我们感到困扰。其中&#xff0c;OGG和MP3是两种常见的音频格式&#xff0c;但由于设备和平台的支持问题&#xff0c;我们经常需要将OGG转换为MP3格式。 本文将为您详细介绍OGG和MP3的区别&#xff0c;为什么需要…

Spring Boot集成protobuf快速入门Demo

1.什么是protobuf&#xff1f; Protobuf&#xff08;Protocol Buffers&#xff09;是由 Google 开发的一种轻量级、高效的数据交换格式&#xff0c;它被用于结构化数据的序列化、反序列化和传输。相比于 XML 和 JSON 等文本格式&#xff0c;Protobuf 具有更小的数据体积、更快…

数据结构:队列(含源码)

目录 一、队列的概念和结构 二、队列的实现 头文件 初始化 入队列和出队列 获取队头队尾元素 队列有效数据数及队列判空 队列的销毁 完整源码 dl.h dl.c 一、队列的概念和结构 队列是一种只允许在一端进行插入数据操作&#xff0c;在另一端进行删除数据操作的特殊线性…

重生之我 学习【数据结构之顺序表(SeqList)】

⭐⭐⭐ 新老博友们&#xff0c;感谢各位的阅读观看 期末考试&假期调整暂时的停更了两个多月 没有写博客为大家分享优质内容 还容各位博友多多的理解 美丽的八月重生之我归来 继续为大家分享内容 你我共同加油 一起努力 ⭐⭐⭐ 数据结构将以顺序表、链表、栈区、队列、二叉树…

多米诺和托米诺平铺

有两种形状的瓷砖&#xff1a;一种是2 x 1的多米诺形&#xff0c;另一种是形如L的托米诺形。两种形状都可以旋转。 给定整数 n &#xff0c;返回可以平铺 2 x n 的面板的方法的数量。返回对 10^9 7 取模 的值。 平铺指的是每个正方形都必须有瓷砖覆盖。两个平铺不同&#xff…

maven常用命令与常见问题汇总

文章目录 一、IDEA 下载依赖包源码报错Sources not found for:xxxx二、常用命令1、打包 一、IDEA 下载依赖包源码报错Sources not found for:xxxx 解决方案&#xff1a; 方案1、在 terminal 运行 mvn dependency:resolve -Dclassifiersources 命令 方案2、右键特定的pom文件…

论文概览 |《IJGIS》2024 Vol.38 issue4

本次给大家整理的是《International Journal of Geographical Information Science》杂志2024年第38卷第4期的论文的题目和摘要&#xff0c;一共包括8篇SCI论文&#xff01; 论文1 knowledge-constrained large language model interactable with GIS: enhancing public risk …

笔试题 day1

目录 快速io 统计2的个数 两个数组的交集 点击消除 快速io import java.util.*; import java.io.*;public class Main {public static PrintWriter out new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));public static Read in new Read();publ…

瑞_Linux防火墙相关命令_Windows远程连接虚拟机的服务失败_Linux防火墙端口开放

&#x1f64a; 前言&#xff1a;博主在学习使用虚拟机的过程中&#xff0c;常常碰到 Windows 远程连接虚拟机的服务失败的问题。比如想要在主机上连接虚拟机中的 MongoDB 服务的时候&#xff0c;服务器或者虚拟机一般都会默认开启防火墙&#xff0c;则会导致远程连接失败&#…

做一个能和你互动玩耍的智能机器人之七-接入对话和大模型

接入科大迅飞的语音识别&#xff1a; private void printResult(RecognizerResult results) {String text JsonParser2.parseIatResult(results.getResultString());String sn null;// 读取json结果中的sn字段try {JSONObject resultJson new JSONObject(results.getResult…

如何忽略已经提交到 Git 仓库中的文件

文章目录 前言一、确认文件是否已经被提交二、确认 .git 文件存在三、修改 .git/info/exclude 文件四、修改文件名五、提交和推送六、验证总结 前言 在日常开发中&#xff0c;我们常常会遇到这样的情况&#xff1a;不小心将不应追踪的文件提交到了 Git 仓库中&#xff0c;例如…

LabVIEW中的Reverse String函数与字节序转换

在LabVIEW中&#xff0c;数据的字节序&#xff08;也称为端序&#xff09;问题通常出现在数据传输和存储过程中。字节序可以分为大端&#xff08;Big-Endian&#xff09;和小端&#xff08;Little-Endian&#xff09;&#xff0c;它们分别表示高字节存储在低地址和低字节存储在…

培训第二十二天(mysql数据库主从搭建)

上午 1、为mysql添加开机启动chkconfig [rootmysql1 ~]# chkconfig --list //列出系统服务在不同运行级别下的启动状态注&#xff1a;该输出结果只显示 SysV 服务&#xff0c;并不包含原生 systemd 服务。SysV 配置数据可能被原生 systemd 配置覆盖。 要列出 systemd 服务…

2024.8.2(MySQL)

一、mysql 1、下载mysql软件包 [rootmysql ~]# yum -y install wget [rootmysql ~]# wget https://downloads.mysql.com/archives/get/p/23/file/mysql-8.0.33-1.el7.x86_64.rpm-bundle.tar 2、解压 [rootmysql ~]# tar -xf mysql-8.0.33-1.el7.x86_64.rpm-bundle.tar 3、安装…

完美解决浏览器的输入框自动填入时,黄色背景问题,以及图标被遮住问题(最新)

用图说话↓↓↓ 首先用代码解决黄色背景问题&#xff0c;box-shadow颜色设置透明即可 :deep(input:-webkit-autofill) {box-shadow: 0 0 0 1000px transparent !important;/* 浏览器记住密码的底色的颜色 */-webkit-text-fill-color: #fff !important;/* 浏览器记住密码的字的…