unet脑肿瘤分割完整代码

news2025/2/24 6:07:38

U-net脑肿瘤分割完整代码

    • 代码目录
    • 数据集
    • 网络
    • 训练
    • 测试

代码目录

在这里插入图片描述

数据集

在这里插入图片描述
https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation

dataset.py

在这里插入代码片import os
import numpy as np
import glob
from PIL import Image
import cv2
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import matplotlib.pyplot as plt

kaggle_3m='./kaggle_3m/'
dirs=glob.glob(kaggle_3m+'*')
#print(dirs)
#os.listdir('./kaggle_3m\\TCGA_HT_A61B_19991127')
data_img=[]
data_label=[]
for subdir in dirs:
    dirname=subdir.split('\\')[-1]
    for filename in os.listdir(subdir):
        img_path=subdir+'/'+filename #图片的绝对路径
        if 'mask' in img_path:
            data_label.append(img_path)
        else:
            data_img.append(img_path)
#data_img[:5] #前几张图 和标签是否对应
#data_label[:5]
data_imgx=[]
for i in range(len(data_label)):#图片和标签对应
    img_mask=data_label[i]
    img=img_mask[:-9]+'.tif'
    data_imgx.append(img)
#data_imgx
data_newimg=[]
data_newlabel=[]
for i in data_label:#获取只有病灶的数据
    value=np.max(cv2.imread(i))
    try:
        if value>0:
            data_newlabel.append(i)
            i_img=i[:-9]+'.tif'
            data_newimg.append(i_img)
    except:
        pass
#查看结果
#data_newimg[:5]
#data_newlabel[:5]
im=data_newimg[20]
im=Image.open(im)
#im.show(im)
im=data_newlabel[20]
im=Image.open(im)
#im.show(im)
#print("可用数据:")
#print(len(data_newlabel))
#print(len(data_newimg))
#数据转换
train_transformer=transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
])
test_transformer=transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor()
])
class BrainMRIdataset(Dataset):
    def __init__(self, img, mask, transformer):
        self.img = img
        self.mask = mask
        self.transformer = transformer

    def __getitem__(self, index):
        img = self.img[index]
        mask = self.mask[index]

        img_open = Image.open(img)
        img_tensor = self.transformer(img_open)

        mask_open = Image.open(mask)
        mask_tensor = self.transformer(mask_open)

        mask_tensor = torch.squeeze(mask_tensor).type(torch.long)

        return img_tensor, mask_tensor

    def __len__(self):
        return len(self.img)
s=1000#划分训练集和测试集
train_img=data_newimg[:s]
train_label=data_newlabel[:s]
test_img=data_newimg[s:]
test_label=data_newlabel[s:]
#加载数据
train_data=BrainMRIdataset(train_img,train_label,train_transformer)
test_data=BrainMRIdataset(test_img,test_label,test_transformer)

dl_train=DataLoader(train_data,batch_size=4,shuffle=True)
dl_test=DataLoader(test_data,batch_size=4,shuffle=True)

img,label=next(iter(dl_train))
plt.figure(figsize=(12,8))
for i,(img,label) in enumerate(zip(img[:4],label[:4])):
    img=img.permute(1,2,0).numpy()
    label=label.numpy()
    plt.subplot(2,4,i+1)
    plt.imshow(img)
    plt.subplot(2,4,i+5)
    plt.imshow(label)

网络

在这里插入图片描述
model.py


import torch
import torch.nn as nn


class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels,
                                      kernel_size=3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(out_channels, out_channels,
                                      kernel_size=3, padding=1),
                            nn.ReLU(inplace=True)
            )
        self.pool = nn.MaxPool2d(kernel_size=2)
    def forward(self, x, is_pool=True):
        if is_pool:
            x = self.pool(x)
        x = self.conv_relu(x)
        return x


class Upsample(nn.Module):
    def __init__(self, channels):
        super(Upsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(2 * channels, channels,
                      kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels,
                      kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(channels,
                               channels // 2,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv_relu(x)
        x = self.upconv_relu(x)
        return x


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 1024)

        self.up = nn.Sequential(
            nn.ConvTranspose2d(1024,
                               512,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.ReLU(inplace=True)
        )

        self.up1 = Upsample(512)
        self.up2 = Upsample(256)
        self.up3 = Upsample(128)

        self.conv_2 = Downsample(128, 64)
        self.last = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x):
        x1 = self.down1(x, is_pool=False)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)

        x5 = self.up(x5)

        x5 = torch.cat([x4, x5], dim=1)  # 32*32*1024
        x5 = self.up1(x5)  # 64*64*256)
        x5 = torch.cat([x3, x5], dim=1)  # 64*64*512
        x5 = self.up2(x5)  # 128*128*128
        x5 = torch.cat([x2, x5], dim=1)  # 128*128*256
        x5 = self.up3(x5)  # 256*256*64
        x5 = torch.cat([x1, x5], dim=1)  # 256*256*128

        x5 = self.conv_2(x5, is_pool=False)  # 256*256*64

        x5 = self.last(x5)  # 256*256*3
        return x5

if __name__ == '__main__':
    x = torch.rand([8, 3, 256, 256])
    model = Net()
    y = model(x)

训练

train.py

import torch as t
import torch.nn as nn
from tqdm import tqdm  #进度条
import model
from dataset import *


device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

train_data=BrainMRIdataset(train_img,train_label,train_transformer)
test_data=BrainMRIdataset(test_img,test_label,test_transformer)

dl_train=DataLoader(train_data,batch_size=4,shuffle=True)
dl_test=DataLoader(test_data,batch_size=4,shuffle=True)

model = model.Net()
img,label=next(iter(dl_train))
model=model.to('cuda')
img=img.to('cuda')
pred=model(img)
label=label.to('cuda')
loss_fn=nn.CrossEntropyLoss()#交叉熵损失函数
loss_fn(pred,label)
optimizer=torch.optim.Adam(model.parameters(),lr=0.0001)
def train_epoch(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    epoch_iou = [] #交并比

    net=model.train()
    for x, y in tqdm(testloader):
        x, y = x.to('cuda'), y.to('cuda')
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()

            intersection = torch.logical_and(y, y_pred)
            union = torch.logical_or(y, y_pred)
            batch_iou = torch.sum(intersection) / torch.sum(union)
            epoch_iou.append(batch_iou.item())

    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_acc = correct / (total * 256 * 256)

    test_correct = 0
    test_total = 0
    test_running_loss = 0
    epoch_test_iou = []

    t.save(net.state_dict(), './Results/weights/unet_weight/{}.pth'.format(epoch))

    model.eval()
    with torch.no_grad():
        for x, y in tqdm(testloader):
            x, y = x.to('cuda'), y.to('cuda')
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()

            intersection = torch.logical_and(y, y_pred)#预测值和真实值之间的交集
            union = torch.logical_or(y, y_pred)#预测值和真实值之间的并集
            batch_iou = torch.sum(intersection) / torch.sum(union)
            epoch_test_iou.append(batch_iou.item())

    epoch_test_loss = test_running_loss / len(testloader.dataset)
    epoch_test_acc = test_correct / (test_total * 256 * 256)#预测正确的值除以总共的像素点

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'IOU:', round(np.mean(epoch_iou), 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3),
          'test_iou:', round(np.mean(epoch_test_iou), 3)
          )

    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc


if __name__ == "__main__":
    epochs=20
    for epoch in range(epochs):
        train_epoch(epoch,
                    model,
                    dl_train,
                    dl_test)


在这里插入图片描述
只跑了20个epoch

测试

test.py

import torch as t
import torch.nn as nn
import model
from dataset import *
import matplotlib.pyplot as plt

device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

train_data=BrainMRIdataset(train_img,train_label,train_transformer)
test_data=BrainMRIdataset(test_img,test_label,test_transformer)

dl_train=DataLoader(train_data,batch_size=4,shuffle=True)
dl_test=DataLoader(test_data,batch_size=4,shuffle=True)

model = model.Net()
img,label=next(iter(dl_train))
model=model.to('cuda')
img=img.to('cuda')
pred=model(img)
label=label.to('cuda')
loss_fn=nn.CrossEntropyLoss()
loss_fn(pred,label)
optimizer=torch.optim.Adam(model.parameters(),lr=0.0001)
def test():
    image, mask = next(iter(dl_test))
    image=image.to('cuda')
    net = model.eval()
    net.to(device)
    net.load_state_dict(t.load("./Results/weights/unet_weight/18.pth"))
    pred_mask = model(image)
    pred_mask=pred_mask
    mask=torch.squeeze(mask)
    pred_mask=pred_mask.cpu()
    num=4
    plt.figure(figsize=(10, 10))
    for i in range(num):
        plt.subplot(num, 4, i*num+1)
        plt.imshow(image[i].permute(1,2,0).cpu().numpy())
        plt.subplot(num, 4, i*num+2)
        plt.imshow(mask[i].cpu().numpy(),cmap='gray')#标签
        plt.subplot(num, 4, i*num+3)
        plt.imshow(torch.argmax(pred_mask[i].permute(1,2,0), axis=-1).detach().numpy(),cmap='gray')#预测
    plt.show()


if __name__ == "__main__":
    test()

模型分割效果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1383882.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

详谈Python的开发工具

Python作为一种流行的编程语言,在开发过程中需要使用各种工具来提高效率、简化工作流程和改善开发体验。在本文中,我们将介绍一些常用的Python开发工具,包括文本编辑器、集成开发环境(IDE)、虚拟环境管理工具、包管理器…

MCU最小系统原理图中四个问题详解——芯片中有很多电源管脚的原因(VDD/VSS/VBAT)、LC滤波、两级滤波、NC可切换元件

前言:本文对MCU最小系统原理图中的四个问题进行详解:芯片中有很多电源管脚的原因(VDD/VSS/VBAT)、LC滤波、两级滤波、NC可切换元件。本文以GD32F103C8T6最小系统原理图举例 目录: 芯片中有很多电源管脚的原因&#x…

svn spring项目增量打包工具

svn spring项目增量打包工具 前提介绍 项目使用svn ,打包方式为war包,开发工具ide 项目有时候更新功能只需要更新部分class和html文件,但是要每个都打包并不是很简单 听说idea有现成的插件可以实现这个功能,但是我没找到&…

Java内容

目录 1.命名规范 1.命名规范 2.变量

PiflowX如何快速开发flink程序

PiflowX如何快速开发flink程序 参考资料 Flink最锋利的武器:Flink SQL入门和实战 | 附完整实现代码-腾讯云开发者社区-腾讯云 (tencent.com) Flink SQL 背景 Flink SQL 是 Flink 实时计算为简化计算模型,降低用户使用实时计算门槛而设计的一套符合标…

GPT5会是什么样的?奥特曼在YC W24会上演讲要点

“YC启动活动上,Sam Altman表示:以GPT-5和AGI将在’相对不久的将来’实现的心态来构建。” 在Y Combinator的一个启动活动中,Sam Altman表示,人工通用智能(AGI)的发展即将到来,并建议在构建产品…

【Fiddler抓包】微信扫码访问链接打不开网页

又来每天进步一点点~~~ 背景:某天发版的时候,手机连接电脑抓包查看用户登录之前的sessionID,由于业务需要,是需要用户登录微信扫码跳转至某一页面的,微信(分身)扫码成功,跳转时打不…

dcm数据格式转nrrd数据格式(2维转3维)

目的 将dcm数据格式(2D)转成nrrd数据格式(3D) 将一个文件夹下的dcm数据转成一个nrrd数据 代码 1. 安装必要包 pip install SimpleITK2. 上代码 Descripttion: Result: Author: Philo Date: 2024-01-10 14:25:49 LastEditors: …

WorkPlus企业打破信息孤岛,构建统一工作平台的首选之一

在当今数字化时代,企业内部存在着繁多的工作应用和系统。要实现高效的工作协作,企业需要一个统一的工作平台来打破信息孤岛,提升协作效率。作为一家领先的企业统一工作平台,WorkPlus以其卓越的性能和专业的功能,助力企…

【算法】Java-二叉树的右视图(BFS、DFS两种解法)

题目要求: 给定一个二叉树的 根节点 root,想象自己站在它的右侧,按照从顶部到底部的顺序,返回从右侧所能看到的节点值。 示例 1: 输入: [1,2,3,null,5,null,4] 输出: [1,3,4]示例 2: 输入: [1,null,3] 输出: [1,3]示例 3: 输入…

python基础-base64编码理解

目录 1、base64是什么 2、base64有什么用 3、base64如何用 4、理解base64 5、扩展 1、base64是什么 base64 就是包括字母a-z,A-Z,数字0-9,符号“”,“/”一共64个字符的字符集;还有一个‘’ 字符,占位补充; …

WorkPlus助力企业高效协作的企业级内网即时通讯解决方案

在企业内部,高效沟通和协作是推动工作顺利进行的关键。而企业级内网即时通讯成为了提升内部沟通效率的重要工具。作为一家领先的企业级内网即时通讯解决方案,WorkPlus以其卓越的性能和高安全性,打造了高效沟通协作的新标杆。 为什么选择WorkP…

嘘……快进来!这儿有最新版Microsoft照片程序的安装秘籍!(附安装引导程序下载)

网管小贾 / sysadm.cc 最近啊有不少小伙伴向我反馈,自个的 Windows 10 系统里边居然没有 Microsoft 照片 程序。 我觉得有点不可思议,为啥呢,因为他们的电脑是新买的! 你看哈,系统是 22H2 最新版,安装日期…

陀螺仪LSM6DSV16X与AI集成(6)----检测自由落体

陀螺仪LSM6DSV16X与AI集成.6--检测自由落体 概述视频教学样品申请源码下载生成STM32CUBEMX串口配置IIC配置CS和SA0设置串口重定向参考程序初始换管脚获取ID复位操作BDU设置 概述 本文介绍如何初始化传感器并配置其参数,以便在检测到自由落体事件时发送通知。 最近…

STM32H5 Nucleo-144 board开箱

文章目录 开发板资料下载 【目标】 点亮LD1(绿)、LD2(黄)和LD3(红)三个LED灯 【开箱过程】 博主使用的是STM32CubeMX配置生成代码,具体操作如下: 打开STM32CubeMX,File-…

620基于51单片机的密码锁设计[Proteus仿真]

620基于51单片机的密码锁设计[proteus仿真] 密码锁设计这个题目算是课 程设计和毕业设计中常见的题目了,本期是一个基于51单片机的密码锁设计 需要的源文件和程序的小伙伴可以关注公众号【阿目分享嵌入式】,赞赏任意文章 2¥,私信…

第 11 章 树结构实际应用

文章目录 11.1 堆排序11.1.1 堆排序基本介绍11.1.2 堆排序基本思想11.1.3 堆排序步骤图解说明11.1.4 堆排序代码实现 11.2 赫夫曼树11.2.1 基本介绍11.2.2 赫夫曼树几个重要概念和举例说明11.2.3 赫夫曼树创建思路图解11.2.4 赫夫曼树的代码实现 11.3 赫夫曼编码11.3.1 基本介绍…

关于java类与对象的创建

关于java类与对象的创建 我们在前面的文章中回顾了方法的定义和方法的调用,以及了解了面向对象的初步认识,我们本篇文章来了解一下类和对象的关系,还是遵循结合现实的方式去理解,不是死记硬背😀。 1、类 类是一种抽…

第 3 场 蓝桥杯小白入门赛 解题报告 | 珂学家 | 单调队列优化的DP + 三指针滑窗

前言 整体评价 T5, T6有点意思&#xff0c;这场小白入门场&#xff0c;好像没真正意义上的签到&#xff0c;整体感觉是这样。 A. 召唤神坤 思路: 前后缀拆解 #include <iostream> #include <algorithm> #include <vector> using namespace std;int main()…

03.neuvector之组的划分逻辑

neuvector之组的划分逻辑 原文链接,欢迎大家关注我的github账号 一、组的定义 NeuVector 会自动从正在运行的应用程序中创建组。这些组以前缀‘nv‘开头。您也可以使用 CRD 或 REST API 手动添加它们&#xff0c;并且可以在任何模式下创建、发现、监视或保护。网络和响应规则需…