深度学习实战:UNet模型的训练与测试详解

news2024/9/24 23:15:59
🍑个人主页:Jupiter.
🚀 所属专栏:Linux从入门到进阶
欢迎大家点赞收藏评论😊

在这里插入图片描述

在这里插入图片描述

目录

    • 1、云实例:配置选型与启动
      • 1.1 登录注册
      • 1.2 配置 SSH 密钥对
      • 1.3 创建实例
      • 1.4 登录云实例
    • 2、云存储:数据集上传与下载
    • 3、云开发:眼底血管分割案例
      • 3.1 案例背景
      • 3.2 网络搭建
      • 3.3 网络训练
      • 3.4 模型测试


1、云实例:配置选型与启动

1.1 登录注册

首先进入登录界面注册并登录账号

在这里插入图片描述

1.2 配置 SSH 密钥对

配置 SSH 密钥对的作用是后续远程登录服务器不需要密码验证,更加方便。

首先创建本地公钥,进入本地.ssh目录输入ssh-keygen -o命令,这里文件名可以设置为id_dsa,也可以是其他任意名字
在这里插入图片描述
之后我们可以在.ssh目录看到刚刚创建的两个文件

id_dsa id_dsa.pub
其中id_dsa.pub就是需要的公钥文件

进入密钥对配置,创建密钥对,将id_dsa.pub的内容复制到这里就可以

1.3 创建实例

进入GPU 云实例,点击创建实例。如下图所示,按需选择需要的 GPU 型号和镜像
在这里插入图片描述
在这里插入图片描述

1.4 登录云实例

等待实例创建完成后,点击复制“访问链接”。

在这里插入图片描述
接着来到任意一个 SSH 连接终端进行云实例登录,我这里选择的是 VSCode,如下所示
在这里插入图片描述
成功后,输入:

nvidia-smi
torch.cuda.is_available()

简单验证一下功能即可,如下所示即为成功
在这里插入图片描述

2、云存储:数据集上传与下载

文件存储为网络共享存储,可挂载至的不同实例中。相比本地数据盘,其优势是实例间共享,可以多点读写,不受实例释放的影响;此外存储后端有多冗余副本,数据可靠性非常高;但缺陷是 IO 性能一般

考虑到以上优劣,推荐使用方式:将重要数据或代码存放于文件存储中,所有实例共享,便利的同时数据可靠性也有保障;在训练时,需要高 IO 性能的数据(如训练数据),先拷贝到实例本地数据盘,从本地盘读数据获得更好的 IO 性能。如此兼顾便利、安全和性能。

接下来,我们将训练数据上传到云实例数据盘中。使用scp工具如下

scp -rP 35740 ./DRIVE-SEG-DATA root@cn-north-b.ssh.damodel.com:/root/workspace

具体地:

35740与cn-north-b.ssh.damodel.com分别为端口号和远程地址,请参考 1.4 节替换为自己的参数
./DRIVE-SEG-DATA是本地数据集路径
/root/workspace是远程实例数据集路径

在这里插入图片描述
数据的下载也是类似的命令

scp -rP 35740 root@cn-north-b.ssh.damodel.com:/root/workspace ./DRIVE-SEG-DATA

本文提到的数据集可以在DRIVE 数据集中下载:链接:https://drive.grand-challenge.org/Download/

3、云开发:眼底血管分割案例

3.1 案例背景

眼底也称为眼球的内膜,包括黄斑、视网膜和视网膜中央动静脉等结构。在临床医学中,眼底图像是眼科医生对眼疾病患者进行诊断的重要依据。随着深度学习的发展,医学影像分割技术产生了深远的变化,尤其是卷积神经网络 AlexNet、VGGNet、GoogLeNet、ResNet 等,能够学习到更加抽象和高级的特征表示,从而实现更加精确的分割结果。深度学习模型在大规模数据上训练后,通常能够获得更好的泛化能力,即对未见过的数据也能做出相对准确的预测。对于医学影像分割来说,这意味着模型可以更好地适应不同类型和来源的医学图像数据,提高了分割结果的可靠性和稳定性。同时,深度学习技术支持端到端的学习方式,即从原始输入数据直接学习到最终的分割结果,无需手工设计复杂的特征提取和预处理流程。这简化了分割算法的开发流程,提高了效率和准确性。此外,医学影像数据常常包含多种模态,如 CT、MRI 等。深度学习技术能够更好地处理多模态数据,实现不同模态之间的信息融合,从而提高了医学影像分割的准确性和全面性

在这里插入图片描述
本次实践,我们采用 UNet 进行眼底血管医学图像分割任务。UNet 是一种被广泛应用于语义分割任务的网络结构,其编码器-解码器结构以及跳跃连接的设计,使其能够有效地捕获图像中不同尺度的特征信息,从而在眼底血管分割任务中取得较好的效果。同时,在推理阶段,UNet 采用全卷积网络结构,能够快速对新的眼底图像进行血管分割,为临床应用提供了实时性支持。

3.2 网络搭建

选用 U-Net 网络结构作为基础分割模型的原因在于其通过编解码器架构,有效地结合局部信息和全局信息,提高分割准确性;同时,U-Net 的跳跃连接结构有助于保留和恢复图像中的细节和边缘信息,且在小样本情况下表现优异,能够充分利用有限数据进行有效训练,广泛应用于医学图像分割任务中。网络架构如下

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

3.3 网络训练

基于 PyTorch 的神经网络训练流程可以分为以下步骤(不考虑前期数据准备和模型结构):

定义损失函数 根据任务类型选择合适的损失函数(loss function),如分类任务常用的交叉熵损失(Cross-Entropy Loss)或回归任务中的均方误差(Mean Square Error)。

选择优化器 选择合适的优化器(optimizer),如随机梯度下降(SGD)、Adam 或 RMSprop,并设置初始学习率及其它优化参数。

训练模型 在训练过程中,通过迭代训练数据集来调整模型参数。每个迭代周期称为一个 epoch。对于每个 epoch,数据会被分成多个 batch,每个 batch 被输入到模型中进行前向传播、计算损失、反向传播更新梯度,并最终优化模型参数。

保存模型 当满足需求时,可以将训练好的模型保存下来,以便后续部署和使用。

根据这个步骤编写以下代码

def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    dataset = Dateset_Loader(data_path)
    per_epoch_num = len(dataset) / batch_size
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    optimizer = optim.Adam(net.parameters(),lr=lr,betas=(0.9, 0.999),eps=1e-08, weight_decay=1e-08,amsgrad=False)
    criterion = nn.BCEWithLogitsLoss()
    best_loss = float('inf')
    loss_record = []
    with tqdm(total=epochs*per_epoch_num) as pbar:
        for epoch in range(epochs):
            net.train()
            for image, label in train_loader:
                optimizer.zero_grad()
                image = image.to(device=device, dtype=torch.float32)
                label = label.to(device=device, dtype=torch.float32)
                pred = net(image)
                loss = criterion(pred, label)
                pbar.set_description("Processing Epoch: {} Loss: {}".format(epoch+1, loss))
                if loss < best_loss:
                    best_loss = loss
                    torch.save(net.state_dict(), 'best_model.pth')
                loss.backward()
                optimizer.step()
                pbar.update(1)
            loss_record.append(loss.item())

    plt.figure()
    plt.plot([i+1 for i in range(0, len(loss_record))], loss_record)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.savefig('/root/shared-storage/results/training_loss.png')

运行这个脚本,可以在终端看到进度

在这里插入图片描述
训练损失函数如下,可以看到已经收敛

在这里插入图片描述

3.4 模型测试

测试逻辑如下所示,主要是计算 IoU 指标

def cal_miou(test_dir="/root/workspace/DRIVE-SEG-DATA/Test_Images",
             pred_dir="/root/workspace/DRIVE-SEG-DATA/results", gt_dir="/root/workspace/DRIVE-SEG-DATA/Test_Labels",
             model_path='best_model_drive.pth'):
    name_classes = ["background", "vein"]
    num_classes = len(name_classes)
    if not os.path.exists(pred_dir):
        os.makedirs(pred_dir)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet(n_channels=1, n_classes=1)
    net.to(device=device)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()

    img_names = os.listdir(test_dir)
    image_ids = [image_name.split(".")[0] for image_name in img_names]

    time.sleep(1)
    for image_id in tqdm(image_ids):
        image_path = os.path.join(test_dir, image_id + ".png")
        img = cv2.imread(image_path)
        origin_shape = img.shape
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img = cv2.resize(img, (512, 512))
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        pred = net(img_tensor)
        pred = np.array(pred.data.cpu()[0])[0]
        pred[pred >= 0.5] = 255
        pred[pred < 0.5] = 0
        pred = cv2.resize(pred, (origin_shape[1], origin_shape[0]), interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(os.path.join(pred_dir, image_id + ".png"), pred)

    hist, IoUs, PA_Recall, Precision = compute_mIoU_gray(gt_dir, pred_dir, image_ids, num_classes, name_classes)
    miou_out_path = "/root/shared-storage/results/"
    show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)

模型保存的时候保存到共享存储路径/root/shared-storage,其他实例可以直接从共享存储中获取训练后的模型
在这里插入图片描述
在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2161743.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript --json格式字符串和对象的转化

json字符串解析成对象 &#xff1a; var obj JSON.parse(str) 对象转化成字符串&#xff1a;var str1 JSON.stringify(obj1) <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Com…

第五篇:Linux进程的相关知识总结(1)

目录 第四章&#xff1a;进程 4.1进程管理 4.1.1进程管理需要的学习目标 4.1.1.1了解进程的相关信息 4.1.1.2僵尸进程的概念和处理方法&#xff1a; 4.1.1.3PID、PPID的概念以及特性&#xff1a; 4.1.1.4进程状态 4.1.2进程管理PS 4.1.2.1静态查看进程 4.1.2.1.1自定义…

搭建EMQX MQTT服务器并接入Home Assistant和.NET程序

本文主要介绍如何使用Docker搭建EMQX MQTT服务器&#xff0c;并将其接入到Home Assistant中&#xff0c;最后演示如何使用.NET接入MQTT。 1. 背景 在智能家居系统中&#xff0c;MQTT&#xff08;消息队列遥测传输协议&#xff09;是一种轻量级的消息传输协议&#xff0c;特别适…

《深度学习》—— 神经网络中的数据增强

文章目录 一、为什么要进行数据增强&#xff1f;二、常见的数据增强方法1. 几何变换2. 颜色变换3. 尺寸变换4. 填充5. 噪声添加6. 组合变换 三、代码实现四、注意事项五、总结 一、为什么要进行数据增强&#xff1f; 神经网络中的数据增强是一种通过增加训练数据的多样性和数量…

动态规划11,完全背包模板

NC309 完全背包 问题一&#xff1a;求这个背包至多能装多大价值的物品&#xff1f; 状态表示&#xff1a;经验题目要求 dp[i][j] 表示 从前i个物品中挑选&#xff0c;总体积不超过j&#xff0c;所有选法中&#xff0c;能选出来的最大价值。 状态转移方程 根据最后一步的状态&a…

vue2 搜索高亮关键字

界面&#xff1a; 搜索 “成功” 附上代码&#xff08;开箱即用&#xff09; <template><div class"box"><input class"input-box" v-model"searchKeyword" placeholder"输入搜索关键字" /><div class"r…

【深度】边缘计算神器之数据网关

分布式计算、云边协同、互联互通是边缘计算设备的三项重要特征。 边缘计算设备通过分布式计算模式&#xff0c;将数据处理和分析任务从中心化的云平台下放到设备网关&#xff0c;即更接近数据源的地方&#xff0c;从而显著降低了数据传输的延迟&#xff0c;提高了响应速度和处…

OpenCV normalize() 函数详解及用法示例

OpenCV的normalize函数用于对数组&#xff08;图像&#xff09;进行归一化处理&#xff0c;即将数组中的元素缩放到一个指定的范围或具有一个特定的标准&#xff08;如均值和标准差&#xff09;。它有两个原型函数, 如下: Normalize()规范化数组的范数或值范围。当normTypeNORM…

拾色器的取色的演示

前言 今天&#xff0c;有一个新新的程序员问我&#xff0c;如何确定图片中我们需要选定的颜色范围。一开始&#xff0c;我感到对这个问题很不屑。后来&#xff0c;想了想&#xff0c;还是对她说&#xff0c;你可以参考一下“拾色器”。 后来&#xff0c;我想关于拾色器&#…

C++ std::any升级为SafeAny

std::any测试 #include <any>class A { public:int8_t a; };int main(int argc, char* argv[]) {std::any num((int8_t)42);auto a std::any_cast<A>(num);return 0; }异常&#xff1a; 0x00007FFA9385CD29 处(位于 test.exe 中)有未经处理的异常: Microsoft C 异…

通信工程学习:什么是NFVO网络功能虚拟化编排器

NFVO&#xff1a;网络功能虚拟化编排器 NFVO&#xff08;Network Functions Virtualization Orchestrator&#xff09;&#xff0c;即网络功能虚拟化编排器&#xff0c;是网络功能虚拟化&#xff08;NFV&#xff09;架构中的核心组件之一。NFV是一种将传统电信网络中的网络节点…

Health Check

强大的自愈能力是Kubernetes这类容器编排引擎的一个重要特性&#xff0c;自愈的默认实现方式是自动重启发生故障的容器&#xff0c;除此之外&#xff0c;用户还可以利用Liveness和Readiness探测机制设置更精细的健康检查&#xff0c;进而实现如下需求&#xff1a; 零停机部署避…

c++优先队列priority_queue(自定义比较函数)

c优先队列priority_queue&#xff08;自定义比较函数&#xff09;_c优先队列自定义比较-CSDN博客 373. 查找和最小的 K 对数字 - 力扣&#xff08;LeetCode&#xff09; 官方题解&#xff1a; class Solution { public:vector<vector<int>> kSmallestPairs(vecto…

开源UNI-SOP云统一认证平台

今天给大家分享一款开源的商用级别认证平台UNI-SOP&#xff0c;这块软件分为开源版本和专业版本&#xff0c;由于专业版涉及到一些代码授权问题&#xff0c;暂时未开源&#xff0c;不过&#xff0c;一般应用开源版本足够了。 先来看看系统管理平台界面&#xff0c;然后我们再来…

[OPEN SQL] SELECT语句

本次操作使用的数据库表为SCUSTOM&#xff0c;其字段内容如下所示 航班用户(SCUSTOM) 1.SELECT语句 SELECT语句从数据库表中读取必要的数据 1.1 读取一行数据 语法格式 SELECT SINGLE <cols>... WHERE cols&#xff1a;数据库表的字段 从数据库表中读取一条数据可使…

[数据结构]动态顺序表的实现与应用

文章目录 一、引言二、动态顺序表的基本概念三、动态顺序表的实现1、结构体定义2、初始化3、销毁4、扩容5、缩容5、打印6、增删查改 四、分析动态顺序表1、存储方式2、优点3、缺点 五、总结1、练习题2、源代码 一、引言 想象一下&#xff0c;你有一个箱子&#xff08;静态顺序…

武汉大学首个人形机器人来了!

B站&#xff1a;啥都会一点的研究生公众号&#xff1a;啥都会一点的研究生 AI圈又发生了哪些新鲜事&#xff1f; 武汉大学展示首个人形机器人“天问”&#xff1a;1.7米高&#xff0c;65公斤重&#xff0c;36个自由度 武汉大学近日展示了其首个人形机器人“天问”&#xff0…

屏幕演示工具 | 水豚鼠标助手 v1.0.7

水豚鼠标助手是一款功能强大的屏幕演示工具&#xff0c;专为Windows 10及以上系统设计。这款软件提供了多种实用功能&#xff0c;旨在增强用户的屏幕演示体验&#xff0c;特别适合教师、讲师和需要进行屏幕演示的用户。鼠标换肤&#xff1a;软件提供多种鼠标光标样式&#xff0…

国庆出行新宠:南卡Pro5骨传导耳机,让旅途不再孤单

国庆长假即将来临&#xff0c;对于热爱旅行和户外运动的朋友们来说&#xff0c;一款适合旅行使用的耳机无疑是提升旅途体验的神器。今天&#xff0c;我要向大家推荐一款特别适合国庆出行的耳机——南卡Runner Pro5骨传导耳机。作为一名热爱旅游的体验者&#xff0c;我强烈推荐南…

2024年主流前端框架的比较和选择指南

在选择前端框架时&#xff0c;开发者通常会考虑多个因素&#xff0c;包括框架的功能、性能、易用性、社区支持和学习曲线等。以下是一些主流前端框架的比较和选择指南。 1. 主流前端框架简介 React 优点: 组件化开发&#xff0c;易于复用和维护。虚拟DOM提高了性能。强大的生…