UNet - 训练数据train

news2024/11/15 13:35:31

目录

1. train 训练数据

2. Loss 值

3. 完整代码


1. train 训练数据

训练的代码只是在之前图像分类的基础上做了一些更改,具体的可以看下面的文章

pytorch 搭建 LeNet 网络对 CIFAR-10 图片分类https://blog.csdn.net/qq_44886601/article/details/127498256

首先,导入之前定义的UNet 网络

然后,加载训练集和测试集

因为加载数据集被重写过,所以这里传入的是训练的图像,然后根据里面的replace就能找到对应的标签

这里训练的时候可以将数据打乱,测试的时候没有必要,batch_size 因为电脑硬件的问题设置成2,再大的话这里内存就会不够了

 

然后定义优化器和损失函数,这里用的是BCE加上sigmoid的损失函数

训练的时候,要将模式改为train模式,然后训练的步骤很常规

梯度清零->前向传播->计算损失函数->反向传播->更新参数

 

这里测试的时候有些区别

因为这里UNet 网络的输出是一幅图像,而之前将label改为了二值图像(归一化后是0 1)。所以这里计算准确率的时候,将预测的图像也变为二值图像,计算准确率用的是对应图像像素点的灰度值是否相等的方法

 

最后保留最好准确率的那个参数就行了

 

2. Loss 值

这是跑了20 个epoch的输出

 

3. 完整代码

from model import UNet                  # 导入Unet 网络
from dataset import Data_Loader         # 数据处理
from torch import optim
import torch.nn as nn
import torch


# 网络训练模块
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # GPU or CPU
print(device)
net = UNet(in_channels=1, num_classes=1)                                # 加载网络
net.to(device)                                                          # 将网络加载到device上

# 加载训练集
train_path = "./data/train/image"
trainset = Data_Loader(train_path)
train_loader = torch.utils.data.DataLoader(dataset=trainset,batch_size=2,shuffle=True)

# len(trainset)  样本总数:21

# 加载测试集
test_path = "./data/test/image"
testset = Data_Loader(test_path)
test_loader = torch.utils.data.DataLoader(dataset=testset,batch_size=2)

optimizer = optim.RMSprop(net.parameters(),lr = 0.000001,weight_decay=1e-8,momentum=0.9)     # 定义优化器
criterion = nn.BCEWithLogitsLoss()                                                           # 定义损失函数

save_path = './UNet.pth'        # 网络参数的保存路径
best_acc = 0.0                  # 保存最好的准确率


for epoch in range(20):

    net.train()     # 训练模式
    running_loss = 0.0

    for image,label in train_loader:                   # 读取数据和label

        optimizer.zero_grad()                          # 梯度清零
        pred = net(image.to(device))                   # 前向传播

        loss = criterion(pred, label.to(device))       # 计算损失
        loss.backward()                                # 反向传播
        optimizer.step()                               # 梯度下降

        running_loss += loss.item()                    # 计算损失和

    net.eval()  # 测试模式
    acc = 0.0   # 正确率
    total = 0
    with torch.no_grad():
        for test_image, test_label in test_loader:

            outputs = net(test_image.to(device))     # 前向传播

            outputs[outputs >= 0] = 1  # 将预测图片转为二值图片
            outputs[outputs < 0] = 0

            acc += (outputs == test_label.to(device)).sum().item() / (480*480)     # 计算预测图片与真实图片像素点一致的精度:acc = 相同的 / 总个数
            total += test_label.size(0)

    accurate = acc / total  # 计算整个test上面的正确率
    print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f %%' %
          (epoch + 1, running_loss, accurate*100))

    if accurate > best_acc:     # 保留最好的精度
        best_acc = accurate
        torch.save(net.state_dict(), save_path)     # 保存网络参数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/12964.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

EventBridge 生态实践:融合 SLS 构建一体化日志服务

作者&#xff1a; 昶风 引言 阿里云日志服务 SLS 是一款优秀的日志服务产品&#xff0c;提供一站式地数据采集、加工、查询与分析、可视化、告警、消费与投递等服务。对于使用 SLS 的用户业务而言&#xff0c;SLS 上存储的日志信息反映着业务的运行状态&#xff0c;通过适当地…

2021年认证杯SPSSPRO杯数学建模D题(第一阶段)停车的策略全过程文档及程序

2021年认证杯SPSSPRO杯数学建模 D题 停车的策略 原题再现&#xff1a; 开车前往人流集中的目的地时&#xff0c;决定在何处停车经常是一个难题。是停在距离目的地较远的地方&#xff0c;因为那里的空余车位可能较多&#xff0c;然后再走很远的路&#xff1f;或者是否应该乐观…

【C语言】程序的翻译环境和执行环境

&#x1f6a9;write in front&#x1f6a9; &#x1f50e;大家好&#xff0c;我是謓泽&#xff0c;希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流&#x1f50e; &#x1f3c5;2021年度博客之星物联网与嵌入式开发TOP5&#xff5…

智慧博物馆解决方案-最新全套文件

智慧博物馆解决方案-最新全套文件一、建设背景二、思路架构三、解决方案建成5个方面1、集约化2、物联网接入3、大数据可视化分析4、室内室外地图集成5、可视化信息多元交互四、获取 - 智慧博物馆全套最新解决方案合集一、建设背景 博物馆是征集、典藏、陈列和研究代表自然和人…

【FME实战教程】002:FME完美实现CAD数据转shp案例教程(以三调土地利用现状数据为例)

FME完美实现CAD数据转shp案例教程&#xff08;以三调土地利用数据为例&#xff09; 文章目录1. cad数据预览2. 转换过程3. shp数据预览1. cad数据预览 2. 转换过程 &#xff08;1&#xff09;打开FME Desktop2020中文软件&#xff0c;点击【新建】。 &#xff08;2&#xff09…

【Spring】——2、使用@ComponentScan自动扫描组件并指定扫描规则

&#x1f4eb;作者简介&#xff1a;zhz小白 公众号&#xff1a;小白的Java进阶之路 专业技能&#xff1a; 1、Java基础&#xff0c;并精通多线程的开发&#xff0c;熟悉JVM原理 2、熟悉Java基础&#xff0c;并精通多线程的开发&#xff0c;熟悉JVM原理&#xff0c;具备⼀定的线…

微信小程序开发(九):使用扩展组件库

前端开发中离不开各种组件库&#xff0c;我最先接触的组件库还是Bootstrap&#xff0c;后来工作中又陆续使用了inoic、ng-zorro等各种不同的库。 在微信小程序开发中也有多种组件库&#xff0c;这里记录其中几种不同组件库的使用方法。 WeUI 这是微信官方推出的一款和微信原…

使用Python,Open3D对点云散点投影到面上并可视化,使用3种方法计算面的法向量及与平均法向量的夹角

使用Python&#xff0c;Open3D对点云散点投影到面上并可视化&#xff0c;使用3种方法计算面的法向量及与平均法向量的夹角 写这篇博客源于博友的提问&#xff0c;他坚定了我继续坚持学习的心&#xff0c;带给了我充实与快乐。 将介绍以下5部分&#xff1a; 随机生成点云点投影…

LaTeX学习笔记

LaTeX学习笔记 文章目录LaTeX学习笔记1. 开始的尝试2.文档类与宏包3.标题与章节4.标注5.列表6.对齐7.插入代码块8.绘制表格9.插入图片10.数学公式10.1.基础公式10.2.复杂公式10.3 常用符号11.参考文献冲鸭&#xff01;&#xff01;&#xff01; 1. 开始的尝试 先开始试一下一个…

MySQL数据库索引和事务详解

目录 前言&#xff1a; 索引 查看索引 创建索引 删除索引 索引使用 底层数据结构分析 事务 事务引出 MySQL设计事务 事务四大特性 小结&#xff1a; 前言&#xff1a; 数据库索引和事务的存在&#xff0c;对于数据库的一些性能有了显著提升。我们需掌握其底层的实现…

NUMA那些事儿

NUMA——Non Uniform Memory Access&#xff0c;中文为非统一内存访问&#xff0c;在NUMA出现之前&#xff0c;内存的控制器是包含在北桥芯片中的&#xff0c;所有内存由北桥统一管理&#xff0c;因此可以保证访问内存的一致性。随着CPU架构的不断迭代和演进&#xff0c;核数越…

Elasticsearch与Kibana安装

现有环境 windows docker ubuntu Elasticsearch安装 安装包下载 ES不同平台、版本下载路径&#xff1a;Download Elasticsearch | Elastic 本文演示用linux # 启动ubuntu环境&#xff0c;开放端口9200、9300、5601 docker run -name es -p 9200:9200 -p 9300:9300 -p 5…

指夹式血氧饱和检测仪方案分析

指夹式心率血氧饱和度方案的测量原理是根据血红蛋白(Hb)和氧合血红蛋白 (HbO2)在红光和近红光区域的吸收光谱特性为依据&#xff0c;运用Lambert Beer定律建立数据处理经验公式&#xff0c;采用光电血氧检测技术结合光电容积脉搏波描记&#xff08;PPG&#xff09;技术&#xf…

化工制造行业数字化升级案例—基于HK-Domo商业智能分析工具

案例背景导读 世伟洛克&#xff08;Swagelok&#xff09;是全球领先的流体系统解决方案的开发商和制造商&#xff0c;为包括科研、仪表、制药、油气、电力、石化、代用燃料和半导体等在内的各个行业提供产品、组装和服务。世伟洛克通过独立的销售和服务中心网站进行运营&#x…

使用 Typescript 封装 Axios

对 axios 二次封装,更加的可配置化、扩展性更加强大灵活 通过 class 类实现&#xff0c;class 具备更强封装性(封装、继承、多态)&#xff0c;通过实例化类传入自定义的配置 创建 class 严格要求实例化时传入的配置&#xff0c;拥有更好的代码提示 /*** param {AxiosInstance…

C语言习题练习8--二进制操作符

IO型--从main函数开始写&#xff0c;要写输入、计算、输出 接口型--不需要写主函数&#xff0c;默认主函数是存在的&#xff0c;你只需要完成函数就行 一、二进制中1的个数 (12条消息) C语言丨关键字signed和unsigned 的使用与区别详解_Emily-C的博客-CSDN博客_signed unsi…

【笔记】samba shell 脚本 离线安装 - Ubuntu 20.04

前言 按照官网调试代码、网上各种步骤来走&#xff08;还收费&#xff09;都不行 结果发现是防火墙问题 公司服务器安装的ufw使用失效&#xff0c;导致端口号放行添加失败 换用firewall-cmd成功 现在免费放下代码&#xff0c;气死他们收费的 目录 ├── home│ ├── k…

linux备份mysql8.0数据库脚本

文章目录环境要求步骤1、创建一个.sh文件编写shell脚本2、添加定时任务环境要求 linux系统&#xff0c;安装了mysql8.0 步骤 1、创建一个.sh文件编写shell脚本 创建文件的命令&#xff1a; vim ***.shshell文件文件参考自文章 链接 export LANGen_US.UTF-8 #注意&#xf…

测试开发技术:Python测试框架Pytest的基础入门

Pytest简介 Pytest is a mature full-featured Python testing tool that helps you write better programs.The pytest framework makes it easy to write small tests, yet scales to support complex functional testing for applications and libraries. 通过官方网站介绍…

十五、Lua 协同程序(coroutine)的学习

Lua 协同程序(coroutine) 什么是协同(coroutine)&#xff1f; Lua 协同程序(coroutine)与线程比较类似&#xff1a;拥有独立的堆栈&#xff0c;独立的局部变量&#xff0c;独立的指令指针&#xff0c;同时又与其它协同程序共享全局变量和其它大部分东西。 协同是非常强大的功…