图像分割Unet算法及其Pytorch实现

news2025/1/21 5:50:20

文章目录

    • 简介
    • 实现
    • 数据集
    • 训练
    • 预测

简介

UNet是一种用于图像分割的神经网络,由于这个算法前后两个部分在处理上比较对称,类似一个U形,如下图所示,故称之为Unet,论文链接:U-Net: Convolutional Networks for Biomedical Image Segmentation,全文仅8页。

在这里插入图片描述

从此图可以看出,左边的基础操作是两次 3 × 3 3\times3 3×3卷积后池化,连续4次,图像从 572 × 572 572\times572 572×572变成 32 × 32 32\times32 32×32。右侧则调转过来,以两次 3 × 3 3\times3 3×3卷积核一个 2 × 2 2\times2 2×2上采样卷积作为一组,再来四次,最后恢复成 388 × 388 388\times388 388×388的图像。

实现

整理一下上图,其计算顺序依次是

  1. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2池化
  2. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2池化
  3. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2池化
  4. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2池化
  5. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2上采样,拼接4的结果
  6. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2上采样,拼接3的结果
  7. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2上采样,拼接2的结果
  8. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2上采样,拼接1的结果
  9. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 1 × 1 1\times1 1×1卷积

由于两次 3 × 3 3\times3 3×3卷积累计出现多次,故而先将其封装成类,便于后续调用

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, inSize, outSize):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inSize, outSize, kernel_size=3, padding=1),
            nn.BatchNorm2d(outSize),
            nn.ReLU(inplace=True),
            nn.Conv2d(outSize, outSize, kernel_size=3, padding=1),
            nn.BatchNorm2d(outSize),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

然后分别实现其降采样、上采样以及最终的输出过程,其中降采样没什么好说的,就是两次卷积一次池化,最终输出的 1 × 1 1\times1 1×1卷积当然就更简单了,二者一并实现如下

class Down(nn.Module):
    def __init__(self, inSize, outSize):
        super().__init__()
        self.conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(inSize, outSize))

    def forward(self, x):
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, inSize, outSize):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(inSize, outSize, 1)

    def forward(self, x):
        return self.conv(x)

上采样过程相对来说复杂一点,多了一个拼接操作,故而其forward函数中,除了需要输入被卷积的数据之外,还要输入U形中,与之对应的那部分计算结果

class Up(nn.Module):
    def __init__(self, inSize, outSize):
        super().__init__()

        self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv = DoubleConv(inSize, outSize)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

最后,将这几个组分拼接成一个UNet

class UNet(nn.Module):
    def __init__(self, nChannel, nClass):
        super(UNet, self).__init__()
        self.inc = DoubleConv(nChannel, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, nClass)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)

数据集

在具体训练之前,需要准备数据集,其中图像存放在image文件夹中,标签存放在label文件夹中,同名的图像和标签文件一一对应。

from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset

class ImgData(Dataset):
    def __init__(self, data_path):
        self.path = data_path
        self.imgForder = os.path.join(data_path, "image")

    # 加载图像
    def loadImg(self, path):
        img = np.array(Image.open(path))
        return img.reshape(1, *img.shape)

    # 根据index读取图片
    def __getitem__(self, index):
        pImg = os.path.join(self.path, f"image\{index}.png")
        pLabel = os.path.join(self.path, f"label\{index}.png")
        image = self.loadImg(pImg)
        label = self.loadImg(pLabel)
        # 数据标签归一化
        if label.max() > 1:
            label = label / 255
        # 随机翻转图像,增加训练样本
        flipCode = np.random.randint(3)
        if flipCode!=0:
            image = np.flip(image, flipCode).copy()
            label = np.flip(label, flipCode).copy()
        return image, label

    def __len__(self):
        # 返回训练集大小
        return len(os.listdir(self.imgForder))

训练

接下来就是激动人心的训练过程了,UNet采用RMSprop优化算法和BCEWithLogits损失函数,训练函数如下

from torch.utils.data import DataLoader
from torch import optim
import torch.nn as nn

def train(net, device, path, epochs=40, bSize=1, lr=0.00001):
    igmData = ImgData(path)
    train_loader = DataLoader(igmData, bSize, shuffle=True)
    # 优化算法
    optimizer = optim.RMSprop(net.parameters(),
            lr=lr, weight_decay=1e-8, momentum=0.9)

    criterion = nn.BCEWithLogitsLoss()      # 损失函数
    bestLoss = float('inf')                # 最佳loss,初始化为无穷大

    # 训练epochs次
    for epoch in range(epochs):
        net.train()     # 训练模式
        for image, label in train_loader:
            optimizer.zero_grad()
            # 将数据拷贝到device中
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)

            pred = net(image)   # 使用网络参数,输出预测结果
            loss = criterion(pred, label)   # 计算损失
            # 保存loss最小的网络参数
            if loss < bestLoss:
                bestLoss = loss
                torch.save(net.state_dict(), 'best_model.pth')

            loss.backward() # 更新参数
            optimizer.step()

        print(epoch, 'Loss/train', loss.item())

接下来调用训练函数,经过40次训练之后,得到51MB的best_model.pth模型文件,此即最佳测试结果

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(1, 1)
net.to(device=device)
path = "train/"
train(net, device, path)

预测

所谓预测,无非是重新做一次训练,而且不及损失,只需保存被神经网络处理之后的结果即可,下面是预测一张图像的函数,其输入net即为我们训练好的网络,device为设备。

def predictOne(net, device, pRead, pSave):
    img = Image.open(pRead)
    img = np.array(img)
    img = img.reshape(1, 1, *img.shape)

    img = torch.from_numpy(img)
    img = img.to(device=device, dtype=torch.float32)

    pred = net(img)     # 预测
    pred[pred >= 0.5] = 255
    pred[pred < 0.5] = 0

    pred = np.array(pred.data.cpu()[0])[0]
    img = Image.fromarray(pred.astype(np.uint8))
    img.save(pSave)

最后,批量处理预测数据集,test和predict分别是存放测试文件和预测图像的文件夹。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(1, 1)
net.to(device=device)
net.load_state_dict(torch.load('best_model.pth', map_location=device))

net.eval()      # 测试模式
fs = os.listdir('test')
for f in fs:
    pRead = os.path.join('test', f)
    pSave = os.path.join("predict",f)
    predictOne(net, device, pRead, pSave)

预测结果如下,左侧为图像,右侧为标签。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1347811.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Lesson 06 vector类(上)

C&#xff1a;渴望力量吗&#xff0c;少年&#xff1f; 文章目录 一、vector是什么&#xff1f;二、vector的使用1. 构造函数2. vector iterator3. vector 空间增长问题4. vector增删查改 三、vector实际使用 一、vector是什么&#xff1f; vector是表示可变大小数组的序列容器…

考研后SpringBoot复习2—容器底层相关注解

考研后SpringBoot复习2 SpringBoot底层注解学习 与容器功能相关的注解与springboot的底层原理密切相关 组件添加注解configuration Spring Ioc容器部分回顾 包括在配置中注册&#xff0c;开启包扫描和注解驱动开发等需要在进行重新的学习回顾 实例 package com.dzu.boot;imp…

启动gazebo harmonic

ros2 launch ros_gz_sim gz_sim.launch.py gz_version:8 如果不输入gz_version:8,默认就是6&#xff0c;启动的就是默认版本ign版本 左边那个是8&#xff0c;右边那个是6

STC8H系列单片机入门教程之NVC系列语音播报模块(九)

一、模块简述 ● 模组支持3.3V和5V单片机供电系统 ● 标准2.54MM间距排针与外部连接 ● 支持喇叭0.5W/8欧 ● 适合用于超声波距离、电子秤重量、时钟时间、温度、球赛比分等语音播报 二、引脚说明 序号 名称 说明 1 VCC 电源正&#xff08;3.3V-5V&#…

老子的《道德经》透露,不努力反而更成功

人类生而自由&#xff0c;但到处都是枷锁。 永远不要怀疑经过慎思且足够投入的一小群人能否改变这个世界。事实上&#xff0c;只有他们才办得到。 优美灵魂的两个发展方向&#xff1a;崇拜道德的天才&#xff0c;对别人实行道德的判断。 一、道 《道德经》开始的名字是《老子…

2023年度学习总结

想想大一刚开始在CSDN写作&#xff0c;这一坚持&#xff0c;就是我在CSDN的第九个年头&#xff0c;这也是在CSDN最有里程碑的一年&#xff0c;这一年我被评为CSDN的博客专家啦&#xff01;先是被评为Unity开发领域新星创作者&#xff0c;写的关于一部分Unity开发的心得获得大家…

系统功能测试的最好方法

我的新书《Android App开发入门与实战》已于2020年8月由人民邮电出版社出版&#xff0c;欢迎购买。点击进入详情 测试系统功能是软件开发和工程过程中的关键步骤。 它确保系统或软件应用程序按预期运行、满足用户要求并可靠运行。 在这里&#xff0c;我们深入探讨最佳方法&a…

matplotlib绘制柱状图

代码 import matplotlib.pyplot as plt import numpy as np# 数据 categories [denoise, double-digit, 100% 5R] existence [0.9778, 0.9768, 0.9767] non_existence [0.9772, 0.9767, 0.9778]# 设置每组柱状图的宽度 bar_width 0.25# 计算每组柱状图的位置 x np.arange…

react + redux 之 美团案例

1.案例展示 2.环境搭建 克隆项目到本地&#xff08;内置了基础静态组件和模版&#xff09; git clone http://git.itcast.cn/heimaqianduan/redux-meituan.git 安装所有依赖 npm i 启动mock服务&#xff08;内置了json-server&#xff09; npm run serve 启动前端服务 npm…

【模拟电路】软件Circuit JS

一、模拟电路软件Circuit JS 二、Circuit JS软件配置 三、Circuit JS 软件 常见的快捷键 四、Circuit JS软件基础使用 五、Circuit JS软件使用讲解 欧姆定律电阻的串联和并联电容器的充放电过程电感器和实现理想超导的概念电容阻止电压的突变&#xff0c;电感阻止电流的突变LR…

Linux基础知识学习3

vim编辑器 其分为四种模式 1.普通(命令)模式 2.编辑模式 3.底栏模式 4.可视化模式 vim编辑器被称为编辑器之神&#xff0c;而Emacs更是神之编辑器 普通模式&#xff1a; 1.光标移动 ^ 移动到行首 w 跳到下一个单词的开头…

【完整思路】2023 年中国高校大数据挑战赛 赛题 B DNA 存储中的序列聚类与比对

2023 年中国高校大数据挑战赛 赛题 B DNA 存储中的序列聚类与比对 任务 1.错误率和拷贝数分析&#xff1a;分析“train_reads.txt”和“train_reference.txt”数据集中的错误率&#xff08;插入、删除、替换、链断裂&#xff09;和序列拷贝数。 2.聚类模型开发&#xff1a;开发…

【快速全面掌握 WAMPServer】10.HTTP2.0时代,让 WampServer 开启 SSL 吧!

网管小贾 / sysadm.cc 如今的互联网就是个看脸的时代&#xff0c;颜值似乎成了一切&#xff01; 不信&#xff1f;看看那些直播带货的就知道了&#xff0c;颜值与出货量绝对成正比&#xff01; 而相对于 HTTP 来说&#xff0c;HTTPS 绝对算得上是高颜值的帅哥&#xff0c;即安…

java的参数传递机制概述,方法重载概述,以及相关案例

前言&#xff1a; 学了Java的传递机制&#xff0c;稍微记录一下。循循渐进&#xff0c;daydayup&#xff01; java的参数传递机制概述 1&#xff0c;java的参数传递机制是什么&#xff1f; java的参数传递机制是一种值传递机制。 2&#xff0c;值传递是什么&#xff1f; 值…

NGUI基础-三大基础组件之Panel组件

目录 Panel组件 Panel的作用&#xff1a; 注意&#xff1a; 相关关键参数讲解&#xff1a; Alpha&#xff08;透明度值&#xff09;&#xff1a; Depth&#xff08;深度&#xff09;&#xff1a; Clippinng&#xff08;裁剪&#xff09;&#xff1a; ​编辑 None Tex…

八. 实战:CUDA-BEVFusion部署分析-环境搭建

目录 前言0. 简述1. CUDA-BEVFusion浅析2. CUDA-BEVFusion环境配置2.1 简述2.2 源码下载2.3 模型数据下载2.4 基础软件安装2.5 protobuf安装2.5.1 apt 方式安装2.5.2 源码方式安装 2.6 编译运行2.6.1 配置 environment.sh2.6.2 利用TensorRT构建模型2.6.3 编译运行程序 2.7 拓展…

Baumer工业相机堡盟工业相机如何通过NEOAPI SDK设置相机的图像剪切(ROI)功能(C++)

Baumer工业相机堡盟工业相机如何通过NEOAPI SDK设置相机的图像剪切&#xff08;ROI&#xff09;功能&#xff08;C&#xff09; Baumer工业相机Baumer工业相机的图像剪切&#xff08;ROI&#xff09;功能的技术背景CameraExplorer如何使用图像剪切&#xff08;ROI&#xff09;功…

分库分表之Mycat应用学习五

5 Mycat 离线扩缩容 当我们规划了数据分片&#xff0c;而数据已经超过了单个节点的存储上线&#xff0c;或者需要下线节 点的时候&#xff0c;就需要对数据重新分片。 5.1 Mycat 自带的工具 5.1.1 准备工作 1、mycat 所在环境安装 mysql 客户端程序。 2、mycat 的 lib 目录…

汇川PLC(H5U):定时器指令

一、H5U系列的定时器种类 H5U系列PLC的定时器指令都封装成指令块了&#xff0c;共4种类型&#xff1a;脉冲定时器、接通延时定时器、关断延时定时器、时间累加定时器。 H5U系列PLC的定时器时间基准是1ms&#xff0c;在IN引脚的执行指令有效的时候开始跟新计数器的值。 我们知…

以太网转RS485通讯类库封装

最近选用有人科技的以太网转RS485模块做项目&#xff0c;设备真漂亮&#xff0c;国货之光。调通了通讯的代码&#xff0c;发到网上供大家参考&#xff0c;多多交流。 以下分别是配套的头文件与源文件&#xff1a; /*******************************************************…