图像分割——U-Net论文介绍+代码(PyTorch)

news2024/11/23 0:48:30

0、概要

原理大致介绍了一下,后续会不断精进改的更加详细,然后就是代码可以对自己的数据集进行一个训练,还会不断完善,相应其他代码可以私信我。

一、论文内容总结

摘要:人们普遍认为,深度网络成功需要数千样本,在本文中,提出一种网络和训练方法,它使用大量数据增强来有效使用现存的样本,我们的体系结构由一个捕获上下文的收缩路径和能够实现精确定位的对称扩展路径组成。我们证明出这个网络可以使用少量图像进行端到端训练,并且在ISBI挑战赛上优先于先前的最佳方法(滑动窗口卷积)。并且我们的网络速度很快。

1介绍

        目前卷积神经网络的具体用途是用在分类任务上,其中对图像的输出是一个单一的类标签。然而,在许多视觉任务中,特别是在生物医学图像处理中,所期望的输出应该包括定位(每个像素都应该分配一个类标签),另外,医学图像数目不是很多。因此,ciresan等人,在一个滑动窗口设置中训练一个网络,通过在每个像素周围提供一个局部区域(补丁)来预测每个像素的类标签,这个网络可以本地化,并且在当时效果还可以,但是这个网络的也有缺陷,很慢,每个网络必须在每个补丁单独运行,而且由于重叠的补丁,会有很多多余的预测,并且补丁的大小,也决定了预测的这个像素点所结合的上下文或者说是感受野的大小,而这个补丁不能太大也不能太小。所以这就是这个网络所存在问题。

        而我们的网络,建立了一个更好的网络,所谓的全卷积网络,我们修改和扩展了这种体系结构,使它可以在很少的训练图像下,产生更精确的分割,网络结构如下图所示。

        主要思想如下(1)编码器-解码器架构(Encoder-Decoder Structure):U-Net采用了经典的编码器-解码器设计。编码器部分通过一系列的卷积和池化操作对输入图像进行下采样,目的是提取出越来越抽象的特征表示。解码器部分则通过上采样操作(例如转置卷积)逐步将这些特征映射回原始输入的空间维度,以便进行像素级别的预测。

(2)跳跃连接:它允许将编码器路径中的特征图与相应解码器层的特征进行合并。具体来说,在每个上采样步骤之后,会将对应编码器层的输出与解码器的输出拼接在一起。这样做的目的是保留局部的精细结构信息,有助于恢复分割结果中的细节,因为编码器的早期层包含更多空间信息但语义信息较少。

(3)对称性:U-Net的结构在视觉上呈现为“U”形,体现了其编码器和解码器的对称性。这种对称不仅体现在网络结构上,也反映在处理图像信息的方式中,从特征提取到细节恢复的完整流程。

(4)端到端学习与像素级预测:U-Net能够直接在每个像素上进行类别预测,实现了端到端的学习,这对于图像分割任务尤为重要。网络的输出与输入图像大小相同,每个像素都有一个类别标签,适用于精确的图像分割任务。

(5)轻量级和高效性

2、网络结构

从上图能很清晰的清楚结构,结果十分简单。

3、训练

利用输入图像及其相应的分割图,利用随机梯度下降来训练网络,由于当时还未有填充的卷积,因此输出图像比输入图像小了一个恒定的边界宽度。后边的一些解释大家可以代码过程,这里介绍起来不是很清楚。

4、数据增强

当只有少量的训练样本可用的时候,数据增强对于教会网络所需的不变性和鲁棒性是非常重要的,对于显微镜图像,我们主要需要位移和旋转不变性,以及对变形和灰度值变化的鲁棒性,而训练样本的随机弹性变形时训练一个很少标注图像的分割网络的观念概念。因此我们采用了相应的方法进行了数据增强。

二、代码结构+解释

一、工程文件中有一个文件夹叫model,里面含有两个文件夹,一个是unet_model,另一个是unet_parts,这两个用来定义模型结构。

(1)unet_parts.py  主要包含常用的一些块

""" Parts of the U-Net model """
"""https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# 导入torch相关库

class DoubleConv(nn.Module): # 继承pytorch中的nn.Moudle类,该类用于构建神经网络中的双卷积块,利用两次连续的卷积操作增强特征表示能力,
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels):  # 初始化参数,设置输入特征图参数和输出而整体参数
        super().__init__() # 调用父类的初始化方法,继承父类必要步骤
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),  # padding=1来保持输出尺寸和输入相同
            nn.BatchNorm2d(out_channels),  # 批量归一化层(BN),加速训练过程,提高模型的稳定性和泛化能力。这里针对的是 out_channels 个通道。
            nn.ReLU(inplace=True), # 应用ReLU激活函数,非线性地增加网络的表达能力
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)  # 返回处理后的输出章

#  旨在通过连续的卷积和非线性提取更高级别的特征表示
class Down(nn.Module):  # 下采样模块
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # 最大池化层
            DoubleConv(in_channels, out_channels)  # 卷积或者是下采样
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):  # 上采样模块
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):  # 初始化参数,选择上采样方式(双线性插值或转置卷积)、定义内部组件。这里选用的是双线性插值来进行上采样
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)# 缩放因子为2,模式是对齐角落的选项
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels) # 用于进一步处理上采样后的特征

    def forward(self, x1, x2):  # 上采样、尺寸调整以及特征融合
        x1 = self.up(x1)  # 上采样特征图
        # input is CHW
        # 计算x1和x2在高度和宽度上的插值
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])    # 对x1进行填充


        x = torch.cat([x2, x1], dim=1) # 沿着维度进行拼接,实现特征融合
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # 1*1卷积核降维

    def forward(self, x):
        return self.conv(x)

(2)unet的网络结构,这里相对于原版的有一些更改的地方,代码也很简单

""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
import torch.nn as nn
import torch.nn.functional as F
from unet_parts import *
# 导入相关库
# 定义了一个完整的U-Net模型
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)

二、数据集设定以及图像增强代码

第一个文件夹主要是对数据集进行处理的一个脚本,第二个就是数据集的一个样式或者说规则,在训练过程中,主要相关的代码就是utils中的dataset.py这个脚本,主要作用是根据data路径,然后对数据集进行预处理,翻转这些操作,代码如下所示

import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random

class ISBI_Loader(Dataset):
    def __init__(self, data_path):
        # 初始化函数,读取所有data_path下的图片
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'Training_Images/*.jpg')) # 查找指定路径下的所有JPEG图片文件
        # 表示在data_path路径下的Training_Images文件夹中寻找扩展名为.jpg的所有文件,glob.glob函数会遍历这个路径并且返回一个包含所有匹配文件路径的列表

    def augment(self, image, flipCode):
        # 使用cv2.flip进行数据增强,filpCode为1水平翻转,0垂直翻转,-1水平+垂直翻转
        flip = cv2.flip(image, flipCode)
        return flip
        
    def __getitem__(self, index):
        # 根据index读取图片
        image_path = self.imgs_path[index]
        # 根据image_path生成label_path
        label_path = image_path.replace('Training_Images', 'Training_Labels')
        label_path = label_path.replace('.jpg', '.png') # todo 更新标签文件的逻辑
        # 生成对应标签图像的路径

        # 读取训练图片和标签图片
        # print(image_path)
        # print(label_path)
        image = cv2.imread(image_path)  # 读进来后就是numpy数组了
        label = cv2.imread(label_path)
        image = cv2.resize(image, (512, 512))
        label = cv2.resize(label, (512, 512), interpolation=cv2.INTER_NEAREST)
        # 对于label的图像处理时候,明确采用最近邻插值方法来处理尺寸变化,确保标签图像在缩放过程中类别标签不发生模糊,保持其原有的清晰界限。
        # 将数据转为单通道的图片
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # BGR转成二值图
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)

        # 处理标签,将像素值为255的改为1
        if label.max() > 1:
            label = label / 255
        # 随机进行数据增强,为2时不做处理,即
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        return image, label

    def __len__(self):
        # 返回训练集大小
        return len(self.imgs_path)


if __name__ == "__main__":
    isbi_dataset = ISBI_Loader("data/train/")
    print("数据个数:", len(isbi_dataset))
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=2,
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)

三、训练代码

这部分就是训练的一整个过程。大

from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch
from tqdm import tqdm


def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    # 加载训练集
    isbi_dataset = ISBI_Loader(data_path)
    per_epoch_num = len(isbi_dataset) / batch_size
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    # 定义RMSprop算法
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # 定义Loss算法
    criterion = nn.BCEWithLogitsLoss()
    # best_loss统计,初始化为正无穷
    best_loss = float('inf')
    # 训练epochs次
    with tqdm(total=epochs*per_epoch_num) as pbar:
        for epoch in range(epochs):
            # 训练模式
            net.train()
            # 按照batch_size开始训练
            for image, label in train_loader:
                optimizer.zero_grad()
                # 将数据拷贝到device中
                image = image.to(device=device, dtype=torch.float32)
                label = label.to(device=device, dtype=torch.float32)
                # 使用网络参数,输出预测结果
                pred = net(image)
                # 计算loss
                loss = criterion(pred, label)
                # print('{}/{}:Loss/train'.format(epoch + 1, epochs), loss.item())
                # 保存loss值最小的网络参数
                if loss < best_loss:
                    best_loss = loss
                    torch.save(net.state_dict(), 'best_model.pth')
                # 更新参数
                loss.backward()
                optimizer.step()
                pbar.update(1)


if __name__ == "__main__":
    # 选择设备,有cuda用cuda,没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络,图片单通道1,分类为1。
    net = UNet(n_channels=1, n_classes=1)  # todo edit input_channels n_classes
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 指定训练集地址,开始训练
    data_path = r"D:\新建文件夹 (3)\VOCdevkit3000\VOCdevkit\VOC2007\data" # todo 修改为你本地的数据集位置
    train_net(net, device, data_path, epochs=50, batch_size=4)

四、总结

大致可能有些粗劣的介绍了U-Net的相关原理,以及代码,给出的代码可以训练,如果有需要完整工程文件的可以私信我。有错误的地方希望批评指正,感谢感谢

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1831288.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HarmonyOS开发日记 :自定义节点,实现 UI 组件 动态创建、更新

引言 UI动态操作包含组件的动态创建、卸载、更新等相关操作。 通过组件预创建&#xff0c;可以满足开发者在非build生命周期中进行组件创建&#xff0c;创建后的组件可以进行属性设置、布局计算等操作。之后在页面加载时进行使用&#xff0c;可以极大提升页面响应速度。 UI …

中国500米分辨率年最大LAI数据集(2000-2020)

叶面积指数是生态系统的一个重要结构参数&#xff0c;用来反映植物叶面数量、冠层结构变化、植物群落生命活力及其环境效应&#xff0c;为植物冠层表面物质和能量交换的描述提供结构化的定量信息&#xff0c;并在生态系统碳积累、植被生产力和土壤、植物、大气间相互作用的能量…

[Qt的学习日常]--常用控件2

前言 作者&#xff1a;小蜗牛向前冲 名言&#xff1a;我可以接受失败&#xff0c;但我不能接受放弃 如果觉的博主的文章还不错的话&#xff0c;还请点赞&#xff0c;收藏&#xff0c;关注&#x1f440;支持博主。如果发现有问题的地方欢迎❀大家在评论区指正 目录 一、widget的…

如何使用视频文案提取帮手将手机上视频里的声音转成文字?

在自媒体短视频日益增加的时候不少自媒体创作者如何将视频转文字的需求日益增加。本次将给大家分享一款针对广大职场青年用户群体的视频转文字工具&#xff0c;旨在为用户提供高效、准确的视频转文字服务。 如何将手机上的视频转成文字呢 视频转文字工具具有转换速度快&#…

11.无代码爬虫八爪鱼采集器抓取网站信息的实操案例——选择目标网站、提取标题、发布时间、评论内容、作者昵称、点赞数量等字段

首先&#xff0c;多数情况下免费版本的功能&#xff0c;已经可以满足绝大多数采集需求&#xff0c;想了解八爪鱼采集器版本区别的详情&#xff0c;请访问这篇帖子&#xff1a; https://blog.csdn.net/cctv1123/article/details/139581468 八爪鱼采集器免费版和个人版、团队版下…

【SpringBoot】SpringBoot:打造现代化微服务架构

文章目录 引言微服务架构概述什么是微服务架构微服务的优势 使用SpringBoot构建微服务创建SpringBoot微服务项目示例&#xff1a;创建订单服务 配置数据库创建实体类和Repository创建服务层和控制器 微服务间通信使用RestTemplate进行同步通信示例&#xff1a;调用用户服务 使用…

hadoop/hive/DBeaver启动流程

hadoop 启动 cd到指定目录下 cd /opt/module/hadoop-3.3.0/sbin/启动文件 ./start-all.shjps一下&#xff0c;查看显示的内容 应该显示以下内容 NameNode SecondaryNameNode DataNode ResourceManager NodeManager如果缺少namenode&#xff0c;那么执行 rm -rf /tmp/hadoo…

数据可视化实验二:回归分析、判别分析与聚类分析

目录 一、使用回归分析方法分析某病毒是否与温度呈线性关系 1.1 代码实现 1.2 线性回归结果 1.3 相关系数验证 二、使用判别分析方法预测某病毒在一定的温度下是否可以存活&#xff0c;分别使用三种判别方法&#xff0c;包括Fish判别、贝叶斯判别、LDA 2.1 数据集展示&am…

超越中心化:Web3如何塑造未来数字生态

随着技术的不断发展&#xff0c;人们对于网络和数字生态的期望也在不断提升。传统的中心化互联网模式虽然带来了便利&#xff0c;但也暴露出了诸多问题&#xff0c;比如数据滥用、信息泄露、权力集中等。在这样的背景下&#xff0c;Web3技术应运而生&#xff0c;旨在打破传统中…

帕金森运动小贴士,壁纸里的健康密码

&#x1f31f; 在这个快节奏的时代&#xff0c;我们越来越关注身体的健康。今天&#xff0c;我想和大家分享一份特别的小贴士&#xff0c;它藏在一张精致的小红书壁纸里&#xff0c;是关于帕金森病的运动建议。帕金森病是一种常见的神经系统疾病&#xff0c;适当的运动对于缓解…

Excel 常用技巧(六)

Microsoft Excel 是微软为 Windows、macOS、Android 和 iOS 开发的电子表格软件&#xff0c;可以用来制作电子表格、完成许多复杂的数据运算&#xff0c;进行数据的分析和预测&#xff0c;并且具有强大的制作图表的功能。由于 Excel 具有十分友好的人机界面和强大的计算功能&am…

Oracle--服务器结构详解

一、Oracle服务器主要组成 实例&#xff08;系统全局区SGA、后台进程&#xff09;数据库程序全局区&#xff08;PGA&#xff09;前台进程 二、系统全局区SGA 1.高速数据缓冲区 用来存放Oracle系统最近访问过的数据块&#xff0c;经常或者最近被访问的数据块会被放置到高速数据…

【Win】识别Hyper-V虚拟机第一代与第二代及其差异

Hyper-V作为微软强大的虚拟化平台&#xff0c;允许用户创建虚拟机并安装各种操作系统。但您是否知道Hyper-V虚拟机分为第一代和第二代&#xff0c;并且它们之间存在一些关键差异&#xff1f;本文将指导您如何识别您的虚拟机属于哪一代&#xff0c;并详细解释两者之间的主要区别…

C#结合JS 修改解决 KindEditor 弹出层问题

目录 问题现象 原因分析 范例运行环境 解决问题 修改 kindeditor.js C# 服务端更新 小结 问题现象 KindEditor 是一款出色的富文本HTML在线编辑器&#xff0c;关于编辑器的详细介绍可参考我的文章《C# 将 TextBox 绑定为 KindEditor 富文本》&#xff0c;这里我们讲述在…

cad怎么转成pdf文件?方法很简单!

cad怎么转成pdf文件&#xff1f;在数字化时代&#xff0c;CAD图纸的转换与共享已成为日常工作中的常态。无论是建筑设计师、工程师还是学生&#xff0c;都可能遇到需要将CAD文件转换为PDF格式的需求。本文将为您推荐三款高效的CAD转PDF软件&#xff0c;让您轻松实现文件格式的转…

GPRS抄表技术是什么?

1.GPRS抄表技术概述 GPRS(GeneralPacketRadioService)抄表是一种基于移动通信网络的远程抄表技术&#xff0c;它利用GPRS网络进行数据传输&#xff0c;实现了对水、电、气等公用事业表计的实时、远程读取。这项技术的出现&#xff0c;极大地提升了公用事业管理的效率和准确性&…

apollo配置中心入门实践

说明&#xff1a; &#xff08;如果微服务开发没有严格统一的代码开发规范&#xff0c;不建议采用apollo&#xff0c;否则只会更浪费时间在一堆配置上&#xff09; 通常情况下&#xff0c;我们无论是但模块开发&#xff0c;还是微服务多模块开发&#xff0c;都采用springboot…

想要做好短视频?这5大关键点你知道吗?沈阳短视频剪辑培训

在新媒体运营中&#xff0c;短视频已成为抓住观众注意力的重要工具。制作成功的短视频需要细心规划和精确执行。今天小编就围绕做好短视频的五大关键点&#xff0c;为大家进行详细解析&#xff0c;帮助您提升视频的吸引力和效果。 做好短视频的5大关键点 01内容策划&#xff1…

docker通过容器id查看运行命令;Portainer监控管理docker容器

1、docker通过容器id查看运行命令 参考&#xff1a;https://blog.csdn.net/a772304419/article/details/138732138 docker inspect 运行镜像id“Cmd”: [ “–model”, “/qwen-7b”, “–port”, “10860”, “–max-model-len”, “4096”, “–trust-remote-code”, “–t…

【CMU 15-445】Proj3 Query Execution

Query Execution 通关记录Task1 Access Method ExecutorsSeqScanInsertUpdateDeleteIndexScanOptimizing SeqScan to IndexScan Task2 Aggregation & Join ExecutorsAggregationNextedLoopJoin Task3 HashJoin Executor and OptimizationHashJoinOptimizing NestedLoopJoin…