烹饪第一个U-Net进行图像分割

news2025/1/9 1:04:48

今天我们将学习如何准备计算机视觉中最重要的网络之一:U-Net。如果你没有代码和数据集也没关系,可以分别通过下面两个链接进行访问:

代码:

https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation?source=post_page-----e812e37e9cd0--------------------------------

Kaggle的MRI分割数据集:

https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation?source=post_page-----e812e37e9cd0--------------------------------

主要步骤:

1. 数据集的探索

2. 数据集和Dataloader类的创建

3. 架构的创建

4. 检查损失(DICE和二元交叉熵)

5. 结果

数据集的探索

我们得到了一组(255 x 255)的MRI扫描的2D图像,以及它们相应的必须将每个像素分类为0(健康)或1(肿瘤)。

这里有一些例子:

8f4f6fa5dd3eee62cf8537f20a64c9ca.jpeg

第一行:肿瘤,第二行:健康主题

数据集和Dataloader类

这是涉及神经网络的每个项目中都会找到的一步。

数据集类

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class BrainMriDataset(Dataset):
    def __init__(self, df, transforms):
        # df contains the paths to all files
        self.df = df
        # transforms is the set of data augmentation operations we use
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        image = cv2.imread(self.df.iloc[idx, 1])
        mask = cv2.imread(self.df.iloc[idx, 2], 0)
        
        augmented = self.transforms(image=image, 
                                    mask=mask)
 
        image = augmented['image'] # Dimension (3, 255, 255)
        mask = augmented['mask']   # Dimension (255, 255)


        # We notice that the image has one more dimension (3 color channels), so we have to one one "artificial" dimension to the mask to match it
        mask = np.expand_dims(mask, axis=0) # Dimension (1, 255, 255)
        
        return image, mask

数据加载器

既然我们已经创建了Dataset类来重新整形张量,我们首先需要定义训练集(用于训练模型),验证集(用于监控训练并避免过拟合),以及测试集,最终评估模型在未见数据上的性能。

# Split df into train_df and val_df
train_df, val_df = train_test_split(df, stratify=df.diagnosis, test_size=0.1)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)


# Split train_df into train_df and test_df
train_df, test_df = train_test_split(train_df, stratify=train_df.diagnosis, test_size=0.15)
train_df = train_df.reset_index(drop=True)


train_dataset = BrainMriDataset(train_df, transforms=transforms)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)


val_dataset = BrainMriDataset(val_df, transforms=transforms)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)


test_dataset = BrainMriDataset(test_df, transforms=transforms)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

U-Net架构

e528761129f5514c2493d1871166da60.jpeg

U-Net架构是用于图像分割任务的强大模型,是卷积神经网络(CNN)的一种类型,其名称来自其U形状的结构。U-Net最初由Olaf Ronneberger等人在2015年的论文中首次开发,标题为“U-Net:用于生物医学图像分割的卷积网络”。

其结构涉及编码(降采样)路径和解码(上采样)路径。U-Net至今仍然是一个非常成功的模型,其成功来自两个主要因素:

1. 对称结构(U形状)

2. 前向连接(图片上的灰色箭头)

前向连接的主要思想是,随着我们在层中越来越深入,我们会失去有关原始图像的一些信息。然而,我们的任务是对图像进行分割,我们需要精确的图像来对每个像素进行分类。这就是为什么我们在对称解码器层的每一层中重新注入图像的原因。以下是通过Pytorch实现的代码:

train_dataset = BrainMriDataset(train_df, transforms=transforms)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)


val_dataset = BrainMriDataset(val_df, transforms=transforms)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)


test_dataset = BrainMriDataset(test_df, transforms=transforms)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)




class UNet(nn.Module):


    def __init__(self):
        super().__init__()


        # Define convolutional layers
        # These are used in the "down" path of the U-Net,
        # where the image is successively downsampled
        self.conv_down1 = double_conv(3, 64)
        self.conv_down2 = double_conv(64, 128)
        self.conv_down3 = double_conv(128, 256)
        self.conv_down4 = double_conv(256, 512)


        # Define max pooling layer for downsampling
        self.maxpool = nn.MaxPool2d(2)


        # Define upsampling layer
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)


        # Define convolutional layers
        # These are used in the "up" path of the U-Net,
        # where the image is successively upsampled
        self.conv_up3 = double_conv(256 + 512, 256)
        self.conv_up2 = double_conv(128 + 256, 128)
        self.conv_up1 = double_conv(128 + 64, 64)


        # Define final convolution to output correct number of classes
        # 1 because there are only two classes (tumor or not tumor)
        self.last_conv = nn.Conv2d(64, 1, kernel_size=1)


    def forward(self, x):
        # Forward pass through the network


        # Down path
        conv1 = self.conv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.conv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.conv_down3(x)
        x = self.maxpool(conv3)
        x = self.conv_down4(x)


        # Up path
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.conv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.conv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.conv_up1(x)


        # Final output
        out = self.last_conv(x)
        out = torch.sigmoid(out)


        return out

损失和评估标准

与每个神经网络一样,都有一个目标函数,一种损失,我们通过梯度下降最小化它。我们还引入了评估标准,它帮助我们训练模型(如果它在连续的3个时期中没有改善,那么我们停止训练,因为模型正在过拟合)。从这一段中有两个主要要点:

1. 损失函数是两个损失函数的组合(DICE损失,二元交叉熵)

2. 评估函数是DICE分数,不要与DICE损失混淆

DICE损失:

c5b7a1b8fafaa01bdbfe5f549725a7ce.jpeg

DICE损失

备注:我们添加了一个平滑参数(epsilon)以避免除以零。

二元交叉熵损失:

9a4538b4bda0db62bfa673b13131def8.jpeg

BCE

于是,我们的总损失是:

9da84b914878f2bd0934f20ef2aae91c.jpeg

让我们一起实现它:

def dice_coef_loss(inputs, target):
    smooth = 1.0
    intersection = 2.0 * ((target * inputs).sum()) + smooth
    union = target.sum() + inputs.sum() + smooth


    return 1 - (intersection / union)




def bce_dice_loss(inputs, target):
    inputs = inputs.float()
    target = target.float()
    
    dicescore = dice_coef_loss(inputs, target)
    bcescore = nn.BCELoss()
    bceloss = bcescore(inputs, target)


    return bceloss + dicescore

评估标准(Dice系数):

我们使用的评估函数是DICE分数。它在0到1之间,1是最好的。

0be7e283b0e50c64a2fd1a935d03b85a.jpeg

Dice分数的图示

其数学实现如下:

3f035ea9fdf85f0e09abd6293c5bd257.jpeg

def dice_coef_metric(inputs, target):
    intersection = 2.0 * (target * inputs).sum()
    union = target.sum() + inputs.sum()
    if target.sum() == 0 and inputs.sum() == 0:
        return 1.0


    return intersection / union

训练循环

def train_model(model_name, model, train_loader, val_loader, train_loss, optimizer, lr_scheduler, num_epochs):  
    
    print(model_name)
    loss_history = []
    train_history = []
    val_history = []


    for epoch in range(num_epochs):
        model.train()  # Enter train mode
        
        # We store the training loss and dice scores
        losses = []
        train_iou = []
                
        if lr_scheduler:
            warmup_factor = 1.0 / 100
            warmup_iters = min(100, len(train_loader) - 1)
            lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
        
        # Add tqdm to the loop (to visualize progress)
        for i_step, (data, target) in enumerate(tqdm(train_loader, desc=f"Training epoch {epoch+1}/{num_epochs}")):
            data = data.to(device)
            target = target.to(device)
                      
            outputs = model(data)
            
            out_cut = np.copy(outputs.data.cpu().numpy())


            # If the score is less than a threshold (0.5), the prediction is 0, otherwise its 1
            out_cut[np.nonzero(out_cut < 0.5)] = 0.0
            out_cut[np.nonzero(out_cut >= 0.5)] = 1.0
            
            train_dice = dice_coef_metric(out_cut, target.data.cpu().numpy())
            
            loss = train_loss(outputs, target)
            
            losses.append(loss.item())
            train_iou.append(train_dice)


            # Reset the gradients
            optimizer.zero_grad()
            # Perform backpropagation to compute gradients
            loss.backward()
            # Update the parameters with the computed gradients
            optimizer.step()
    
            if lr_scheduler:
                lr_scheduler.step()
        
        val_mean_iou = compute_iou(model, val_loader)
        
        loss_history.append(np.array(losses).mean())
        train_history.append(np.array(train_iou).mean())
        val_history.append(val_mean_iou)
        
        print("Epoch [%d]" % (epoch))
        print("Mean loss on train:", np.array(losses).mean(), 
              "\nMean DICE on train:", np.array(train_iou).mean(), 
              "\nMean DICE on validation:", val_mean_iou)
        
    return loss_history, train_history, val_history

结果

让我们在一个带有肿瘤的主题上评估我们的模型:

326c72e0a3dc7e3cf4b716108a4a58e6.jpeg

结果看起来相当不错!我们可以看到模型明显学到了关于图像结构的一些有用信息。然而,它可能可以更好地细化分割,这可以通过我们将很快讨论的更先进的技术来实现。U-Net至今仍然广泛使用,但有一个著名的模型达到了最先进的性能,称为nn-UNet。

·  END  ·

HAPPY LIFE

0ce980e3d287c8e1d8cf10ae1f388a0f.png

本文仅供学习交流使用,如有侵权请联系作者删除

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1449390.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言】常见字符串函数的功能与模拟实现

目录 1.strlen() 模拟实现strlen() 2.strcpy() 模拟实现strcpy() 3.strcat() 模拟实现strcat() 4.strcmp() 模拟实现strcmp() 5.strncpy() 模拟实现strncpy() 6.strncat() 模拟实现strncat() 7.strncmp() 模拟实现strncmp() 8.strstr() 模拟实现strstr() 9.str…

第二十九回 施恩三入死囚牢 武松大闹飞云浦-分布式版本控制系统Git使用

武松要蒋门神答应三件事&#xff1a;离开快活林、东西都归还施恩&#xff0c;公开对施恩赔礼道歉&#xff0c;不许在孟州住。蒋门神不得已都答应了&#xff0c;灰溜溜地离开了孟州城。 一个月之后&#xff0c;天气转凉&#xff0c;张都监调武松到孟州城&#xff0c;做了他的亲…

vue3-应用规模化-路由和状态

客户端 vs. 服务端路由 服务端路由指的是服务器根据用户访问的 URL 路径返回不同的响应结果。当我们在一个传统的服务端渲染的 web 应用中点击一个链接时&#xff0c;浏览器会从服务端获得全新的 HTML&#xff0c;然后重新加载整个页面。 然而&#xff0c;在单页面应用中&…

CSS设置盒子阴影

语法 box-shadow: *h-shadow v-shadow blur spread color* inset; 注释: box-shadow向框添加一个或多个阴影. 该属性是由逗号分隔的阴影列表,每个阴影由2-4个长度值、可选的颜色值及可选的inset关键词来规定。省略长度的值是0。 外阴影 a、给元素右边框和下边框加外阴影——把…

生活篇——华为手机去除负一屏

华为手机去除如下图的恶心负一屏 打开华为的应用市场app 进入&#xff1a;我的-设置-国家/地区&#xff08;改为俄罗斯&#xff09;-进入智慧助手检查更新并更新智慧助手。 然后重复开始的操作&#xff0c;将地区改回中国&#xff0c;这样就没有负一屏了。

python自学...

一、稍微高级一点的。。。 1. 闭包&#xff08;跟js差不多&#xff09; 2. 装饰器 就是spring的aop 3. 多线程

拟合案例2:matlab实现分段函数拟合(分段点未知)及源码

案例介绍: 本案是针对一个分段函数中的参数进行拟合,使用的拟合工具是matlab中的lsqcurvefit或nlinfit。函数形式和待拟合参数如下所示。该案例的特殊之处在于分段点也是待拟合参数,因此如何自定义拟合函数,实现分段点的拟合是本案例最大的难点。本案例提供了三种分段函数…

双端队列,优先队列,单调队列

单调队列 单调队列是指一个队列内部元素具有单调性的数据结构 分为单调递增队列和单调递减队列 单调队列满足三个性质&#xff1a; 单调队列也是队列&#xff0c;满足先进先出单调队列必须满足从队头到队尾的单调性排在队列前面的元素比排在队列后面的元素要先进队 代码实现上…

excel统计分析——多组数据的秩和检验

单因素资料不完全满足方差的基本假定时&#xff0c;可进行数据转换后再进行方差分析&#xff0c;但有时数据转换后仍不满足方差分析的基本假定&#xff0c;就只能进行秩和检验了。 多组数据秩和检验的主要方法为Kruskal-Wallis检验&#xff0c;也称为Kruskal-Wallis秩和方差分析…

分布式文件系统 SpringBoot+FastDFS+Vue.js【二】

分布式文件系统 SpringBootFastDFSVue.js【二】 六、实现上传功能并展示数据6.1.创建数据库6.2.创建spring boot项目fastDFS-java6.3.引入依赖6.3.fastdfs-client配置文件6.4.跨域配置GlobalCrosConfig.java6.5.创建模型--实体类6.5.1.FastDfsFile.java6.5.2.FastDfsFileType.j…

__attribute__ ---Compile

Section for attribute attribute_&#xff1f;嵌入式C代码属性怎么定义 https://www.elecfans.com/d/2269222.html section 属性的主要作用是&#xff1a;在程序编译时&#xff0c;将一个函数或者变量放到指定的段&#xff0c;即指定的section 中。 一个可执行文件注意由代…

STM32——菜单(二级菜单)

文章目录 一.补充二. 二级菜单代码 简介&#xff1a;首先在我的51 I2C里面有OLED详细讲解&#xff0c;本期代码从51OLED基础上移植过来的&#xff0c;可以先看完那篇文章&#xff0c;在看这个&#xff0c;然后按键我是用的定时器扫描不会堵塞程序,可以翻开我的文章有单独的定时…

BUGKU-WEB 矛盾

题目描述 进入场景看看&#xff1a; 代码如下&#xff1a; $num$_GET[num]; if(!is_numeric($num)) { echo $num; if($num1) echo flag{**********}; }解题思路 需要读懂一下这段PHP代码的意思明显是一道get相关的题目&#xff0c;需要提供一个num的参数,然后需要传入一个不…

【数据结构】顺序栈和链式栈的简单实现和解析(C语言版)

数据结构——栈的简单解析和实现 一、概念二、入栈&#xff08;push&#xff09;三、出栈&#xff08;pop&#xff09;四、顺序栈简单实现 &#xff08;1&#xff09;进栈操作&#xff08;2&#xff09;出栈操作 一、概念 本篇所讲解的栈和队列属于逻辑结构上的划分。逻辑结构…

GPDB - 高可用 - FTS机制(一):探测成功

GPDB - 高可用 - FTS机制&#xff08;一&#xff09;&#xff1a;探测成功 作为GreenPlum高可用的核心功能&#xff0c;FTS&#xff08;Fault Tolerance Server&#xff09;进程负责故障检测。该进程是master上的一个子进程&#xff0c;可以快速检测到primary或者mirror是否宕机…

PyTorch深度学习快速入门教程 - 【小土堆学习笔记】

小土堆Pytorch视频教程链接 声明&#xff1a; 博主本人技术力不高&#xff0c;这篇博客可能会因为个人水平问题出现一些错误&#xff0c;但作为小白&#xff0c;还是希望能写下一些碰到的坑&#xff0c;尽力帮到其他小白 1 环境配置 1.1 pycharm pycharm建议使用2020的&…

【C语言】指针的进阶篇,深入理解指针和数组,函数之间的关系

欢迎来CILMY23的博客喔&#xff0c;本期系列为【C语言】指针的进阶篇&#xff0c;深入理解指针和数组&#xff0c;函数之间的关系&#xff0c;图文讲解其他指针类型以及指针和数组&#xff0c;函数之间的关系&#xff0c;带大家更深刻理解指针&#xff0c;以及数组指针&#xf…

LeetCode Python - 16.最接近的三数之和

目录 题目答案运行结果 题目 给你一个长度为 n 的整数数组 nums 和 一个目标值 target。请你从 nums 中选出三个整数&#xff0c;使它们的和与 target 最接近。 返回这三个数的和。 假定每组输入只存在恰好一个解。 示例 1&#xff1a; 输入&#xff1a;nums [-1,2,1,-4],…

Vulnhub靶场 DC-6

目录 一、环境搭建 二、主机发现 三、漏洞复现 1、wpscan工具 2、后台识别 dirsearch 3、爆破密码 4、rce漏洞利用 activity monitor 5、rce写shell 6、新线索 账户 7、提权 8、拿取flag 四、总结 一、环境搭建 Vulnhub靶机下载&#xff1a; 官网地址&#xff1a…

凭证获取:Linux凭证获取

目录 shadow文件详解 1.Shadow文件的组成与作用 2.破解Shadow文件内容 利用strace记录密码 shadow文件 1.Shadow文件的组成与作用 在Linux中&#xff0c;/etc/shadow文件又被称为“影子文件”&#xff0c;主要用于存储Linux文件系统中的用户凭据信息。该文件只能由root权…