UNet网络训练

news2024/11/18 6:17:13

UNet网络训练

训练资源

构建好UNet网络模型后,需要进行训练。但是训练需要特别多的原始图像和标签图像,对于一般而言这一步特别繁琐,不过在网上有一些免费的数据集可以让我们省略这一步,直接进行训练测试。

VOC(Visual Object Classes)数据集是一个广泛使用的计算机视觉数据集,主要用于目标检测、图像分割和图像分类等任务。VOC数据集最初由英国牛津大学的计算机视觉小组创建,并在PASCAL VOC挑战赛中使用。

VOC数据集包含各种不同类别的标记图像,每个图像都有与之相关联的边界框(bounding box)和对象类别的标签。数据集中包括了20个常见的目标类别,例如人、汽车、猫、狗等。此外,VOC数据集还提供了用于图像分割任务的像素级标注。

VOC数据集涵盖了多个年度的发布,每个年度的数据集包含训练集、验证集和测试集。训练集用于模型的训练和参数优化,验证集用于模型的调参和性能评估,而测试集则用于最终模型的性能评估和比较。

VOC数据集:https://host.robots.ox.ac.uk/pascal/VOC/voc2007/

image-20230925130121998

image-20230925130229417

一般目标检测只需用到Annotations、ImageSets、JPEGImages这3个文件夹,剩下的可以删掉。

Annotations:存放xml格式的标注文件

JPEGImages:该文件夹存储了 VOC 数据集中的图像数据。

ImageSets:该文件夹包含了几个用于数据集划分和评估的文本文件。

SegmentationClass:包含了每个图像像素的语义类别标注信息。

训练

创建一个Dataloader对象用来加载自定义数据集。

进入训练循环,每个 epoch 遍历一遍数据集。对于每个 batch,将输入数据和真实标签拷贝到计算设备上,并进行前向推理得到输出结果。计算输出结果与真实标签之间的交叉熵损失,并计算梯度并反向传播。每隔一定的周期,打印当前的训练损失并保存模型权重。另外,在每个 batch 中,将输入、真实标签和输出结果合并成一张图像,保存到指定的路径中,用于观察训练效果。

from torch import nn,optim # 优化器
import torch 
from torch.utils.data import DataLoader # 用于加载自定义数据集类
from data import * # 导入自定义类
from net import * # 导入UNet网络模型
from torchvision.utils import save_image # 导入保存图像方法


# 如果有cuda,就用;否则就用cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 权重 用于存储和加载训练好的Unet深度学习模型的权重或参数。".pth" 是PyTorch模型文件的命名约定。
weight_path = "params/unet.pth"
# 数据集路径
data_path = r"E:\Undergraduate\School\Scientific_research\ML\Machine_Learning\lab\图形分割\Unet模型总\VOCdevkit\VOC2012"
# 训练结果图像保存路径
save_path = "train_image"

# 在主程序中运行
if __name__ == "__main__":
    # 创建数据加载器对象,MyDataset 是自定义的数据集类,用于加载训练数据,batch_size=2 表示每次训练使用的图像数量为 2。
    data_loader = DataLoader(MyDataset(data_path), batch_size=2, shuffle=True)
    # 实例化UNet网络模型 通过 .to(device) 将模型移动到指定的设备上。
    net = UNet().to(device)

    # 检查是否存在预训练的模型权重文件,如果存在则加载权重到模型中,否则输出提示信息。
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path))
        print("Successful load weight!")
    else:
        print("Not successful load weight")
    
    # 创建优化器和损失函数对象
    opt = optim.Adam(net.parameters())
    loss_fun = nn.BCELoss()

    # 设置起始训练轮数,并开始训练
    epoch = 1
    while True:
        # 遍历数据加载器中的每个批次,将图像数据和分割图像数据移动到指定设备上。
        for i, (image, segment_image) in enumerate(data_loader):
            image, segment_image = image.to(device), segment_image.to(device)

            # 前向传播计算网络输出结果,并计算训练损失。
            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image)

            # 梯度清零,反向传播计算梯度。
            opt.zero_grad()
            train_loss.backward()

            # 隔一段时间进行打印信息
            if i%5 == 0:
                print(f"{epoch} {i} - train_loss ==>{train_loss.item()}")

            if i%50 == 0:
                torch.save(net.state_dict(), weight_path)
            
            # 从批次中取出第一张图像、分割图像和网络输出结果。将图像、分割图像和网络输出结果按顺序堆叠,并保存为图像文件。
            _image = image[0]
            _segment_image = segment_image[0]
            _out_image = out_image[0]
            
            img = torch.stack([_image, _segment_image, _out_image],dim = 0)
            save_image(img, f"{save_path}/{i}.png")
            
        epoch += 1

在这里插入图片描述

UNet网络测试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1039887.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Centos环境使用Docker安装Kafka

1 Kafka简介 1、kafka是什么? Kafka是一种高吞吐量的分布式发布订阅消息系统,它可以处理消费者规模的网站中的所有动作流数据,具有高性能、持久化、多副本备份、横向扩展能力。 2、kafka的工作原理[去耦合] Kafka采用的是订阅-发布的模式&am…

Android应用线上闪退问题解决

解决Android应用线上闪退问题需要仔细的监控、调试和分析。以下是一些解决Android线上闪退问题的工具和方法,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎交流合作。 工具: 1.Google Play 控制台&…

anaconda navigator启动时一直卡在 loading applications 页面

anaconda navigator启动时一直卡在 loading applications 页面 方法1 在安装目录找到D:\anaconda\Lib\site-packages\anaconda_navigator\api 然后打开conda_api.py, 在1358行找到data yaml.load(f),将其改为data yaml.safeload(f) 猜测为保证代码…

精准对接促合作:飞讯受邀参加市工信局举办的企业供需对接会

2023年9月21日,由惠州市工业和信息化局主办的惠州市工业软件企业与制造业企业供需对接会成功举办,对接会旨在促进本地工业软件企业与制造业企业的紧密合作,推动数字化转型的深入发展。此次会议在市工业和信息化局16楼会议室举行,会…

【校招VIP】产品行测之逻辑计算题

考点介绍: 数理逻辑包括对于统计学有基础的了解,有基础的数据敏感性,拥有从数据层层深挖定位到问题的能力。知道先验概率,置信度,归因方法等基础的统计学概念。作为产品经理都应该去理解这些逻辑,并且思考如…

DirectX12学习笔记-创建窗口

创建窗口就是纯的Win API,我设想的窗口是这样的: 我们调用WinMain启动窗口,然后在WinMain初始化和启动消息循环。 消息会传入OnEvent, WndProc是窗口过程函数(每个窗口都有一个WndProc函数,用于接收和处理窗口相关的…

yolov8训练自己的数据集(标注到训练)

yolov8可以用作目标检测,分割,姿态,跟踪。这里举例目标检测从标注到训练的过程。 官网连接 先把代码下载下来,这个不用说了。 然后准备数据集,创建一个文件夹dataset(自己命名),下面…

m1芯片-centos安装mysql

在m1芯片中,虚拟机centos7使用mysql官方的yum源安装mysql没问题,但是在启动mysql的时候会报错,从日志上看是硬件问题,报错信息为 Most likely, you have hit a bug, but this error can also be caused by malfunctioning hardwar…

OpenCV项目开发实战--主成分分析(PCA)的特征脸应用(附C++/Python实现源码)

什么是主成分分析? 这是理解这篇文章的先决条件。 图 1:使用蓝线和绿线显示 2D 数据的主要组成部分(红点)。 快速回顾一下,我们了解到第一个主成分是数据中最大方差的方向。第二主成分是空间中与第一主成分垂直(正交)的最大方差方向,依此类推。第一和第二主成分红点(2…

【周赛364-数组】美丽塔 I-力扣 2865

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kuan 的首页,持续学…

SpringBoot统一返回处理和全局异常处理

统一接口返回 前后端分离项目,通常后端会返回给前端统一的数据格式,一般包括code,msg,data信息。 创建返回统一实体类 package com.example.exceptionspring.domain;import lombok.Data;Data public class Result {private Integer code;private Strin…

基于微信小程序的校园二手交易平台设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言学生的主要功能有:管理员的主要功能有:具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序(小蔡coding)有保障的售后福利 代码参考源码获取 前言 💗博主介绍:✌全网粉丝10W…

别问怎么下载,金蝶云星空SaaS BI系统不用下载

国产自研的奥威软件-金蝶云星空SaaS BI,不下载不安装,从浏览器上一键注册登录即可使用:一键点击下载金蝶云星空方案,执行后,BI系统将基于金蝶云星空内的数据与方案自带的BI报表,智能计算分析指标&#xff0…

python模拟斐波那契数列输出

用户输入指定的数列范围,正确输出结果。 源代码: def fiebo(n): a 1 b 1 for i in range(n): if i 0: print(a, end" ") elif i 1: print(b, end" ") else: …

yolov8模型训练遇到的问题

训练时有一种报错:no labels found in xxx.cache 首先要确定我们的图像,标签文件夹内容无误。检查完后如果还不行,就看看训练用到的东西,比如dataset.py,部分代码如下: def get_labels(self):""…

中国社科院大学-美国杜兰大学金融管理硕士暨能源管理硕士项目2023年毕业典礼

中国社科院大学-美国杜兰大学金融管理硕士暨能源管理硕士项目2023年毕业典礼 2023年9月16日,中国社会科学院大学-美国杜兰大学金融管理硕士项目暨能源管理硕士项目2023年毕业典礼在我校望京校区成功举办。 张波副校长致辞 中国社会科学院大学副校长张波教授、杜兰大…

求职应聘找工作的同学,在线测评怎么过?

信息时代,越来越多的公司在招聘时引入了人才测评机制。企业和单位希望通过人才测评在广大的应聘者中,找到符合自己要求的人才。虽然很多应聘者能力和简历都比较出众,但却在最开始的人才测评中吃了亏。有的公司很看重人才测评结果。测评就相当…

多模态大模型微调记录

VisualGLMhttps://github.com/THUDM/VisualGLM-6Bhttps://github.com/THUDM/VisualGLM-6B 清华大学开源的多模态大模型,具有62亿参数的中英双语言模型 基本思路: 1 通过中间模块(Qformer)构建起预训练视觉和语言的桥梁 2 中英…

C++文件交互实践:职工管理系统

管理系统需求 实现一个基于多态的职工管理系统 创建管理类 管理类负责内容&#xff1a; 与用户的沟通菜单界面对职工增删改查的操作与文件的读写交互 文件交互 -- 写文件 void workerManger::save() {ofstream ofs;ofs.open(FILENAME, ios::out);for (int i 0; i < th…

TP-LINK设备在防视频监控EasyCVR平台上无法使用语音对讲功能该如何解决?

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台可拓展性强、视频能力灵活、部署轻快&#xff0c;可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等&#xff0c;以及支持厂家私有协议与SDK接入&#xff0c;包括海康Ehome、海大宇等设备的SDK等。平台既具备传统安…