10-pytorch-完整模型训练

news2024/11/15 19:51:22

b站小土堆pytorch教程学习笔记

一、从零开始构建自己的神经网络

1.模型构建
#准备数据集
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

from model import *
from torch.utils.data import DataLoader

train_data=torchvision.datasets.CIFAR10('dataset',train=True,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
test_data=torchvision.datasets.CIFAR10('dataset',train=False,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
#查看训练数据集和测试集大小
train_data_size=len(train_data)
test_data_size=len(test_data)
print('训练数据集长度为:{}'.format(train_data_size))#训练数据集长度为:50000
print('测试数据集长度为:{}'.format(test_data_size))#测试数据集长度为:10000

#利用datalo加载数据集
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)

#搭建神经网络,在model文件中搭建网络,在此文件中引用
han=Han()

#损失函数
loss_fn=nn.CrossEntropyLoss()

#优化器
# learning_rate=0.01
learning_rate=1e-2
optimizer=torch.optim.SGD(han.parameters(),lr=learning_rate)

#设置训练网络的相关参数
total_train_step = 0#记录训练的次数
total_test_step = 0#记录测试的次数
epoch=10#训练轮数

#添加tensorboard
writer=SummaryWriter('logs/train')

for i in range(10):
    print('-------第{}轮训练开始-------'.format(i+1))

    for data in train_dataloader:
        imgs,target=data
        output=han(imgs)
        loss=loss_fn(output,target)

        #优化器优化模型
        optimizer.zero_grad()#梯度清零
        loss.backward()#反向传播计算梯度
        optimizer.step()#参数优化

        total_train_step=total_train_step+1
        if total_train_step % 100==0:#逢100打印
            print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))#loss.item()取出tensor类型的数字
            writer.add_scalar('train_loss',loss.item(),total_train_step)

    #每训练完一轮将在测试集上跑一遍,评估其训练效果
    total_test_loss=0
    with torch.no_grad():
        for data in test_dataloader:
            imgs,target=data
            output=han(imgs)
            loss=loss_fn(output,target)
            total_test_loss=total_test_loss+loss.item()

    print('所有测试集上的损失:{}'.format(total_test_loss))
    writer.add_scalar('test_loss',total_test_loss,total_test_step)
    total_test_step+=1

    #保存每一轮模型
    torch.save(han,'han_{}.pth'.format(i))
    print('模型已保存')
writer.close()
import torch
from torch import nn


class Han(nn.Module):
    def __init__(self):
        super(Han, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

if __name__ == '__main__':
    han=Han()
    input=torch.ones(64,3,32,32)
    output=han(input)
    print(output.shape)#torch.Size([64, 10])10表示十个类别输出概率

结果如下:
在这里插入图片描述

2.使用argmax计算整体正确率
#每训练完一轮将在测试集上跑一遍,评估其训练效果
    total_test_loss=0
    total_acc=0
    with torch.no_grad():
        for data in test_dataloader:
            imgs,target=data
            output=han(imgs)
            loss=loss_fn(output,target)
            total_test_loss=total_test_loss+loss.item()

            acc=(output.argmax(1)==target).sum()#(1)横着看
            total_acc+=acc
    print('所有测试集上的损失:{}'.format(total_test_loss))
    print('整体测试集上的正确率:{}'.format(total_acc/test_data_size))
    writer.add_scalar('test_loss',total_test_loss,total_test_step)
    writer.add_scalar('test_acc', total_acc/test_data_size, total_test_step)
    total_test_step+=1

整体测试集上的正确率:0.27480000257492065

3.当训练或测试时存在dropout层或batch normal层,则需要在训练训练和测试前加入:
#训练前
han.train()
#测试前
han.eval()

二、使用GPU

网络模型、数据(输入、标注)、损失函数调用cuda()

1.方式1
#模型
if torch.cuda.is_available():
    han=han.cuda()
#损失函数
loss_fn=nn.CrossEntropyLoss()
loss_fn=loss_fn.cuda()
imgs,target=data
imgs=imgs.cuda()
target=target.cuda()
2.方式2
#定义训练设备
device=torch.device('cuda')
han=han.to(device)
imgs = imgs.to(device)
target = target.to(device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1468996.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【深度学习】Pytorch 教程(十一):PyTorch数据结构:4、张量操作(2):索引和切片操作

文章目录 一、前言二、实验环境三、PyTorch数据结构1、Tensor(张量)1. 维度(Dimensions)2. 数据类型(Data Types)3. GPU加速(GPU Acceleration) 2、张量的数学运算1. 向量运算2. 矩阵…

企业如何定制化“可靠的”系统,实现数字化转型?

二十大提出高质量发展是首要任务,为顺应数字经济时代的发展,数字化转型正不断赋能各行各业。越来越多的企业管理者也意识到数字化转型是帮助企业提升内部运营效率,提升业务开展效率,减低企业成本的有效手段。 那么如何推动企业数字…

linux前端部署

安装jdk 配置环境变量 刷新配置文件 source profile source /etc/profile tomcat 解压文件 进去文件启动tomcat 开放tomcat的端口号 访问 curl localhsot:8080 改配置文件 改IP,改数据库名字,密码, 安装数据库 将war包拖进去 访问http:…

wpf 3d 后台加载模型和调整参数

下载了一个代码,加载obj模型;它的参数在xaml里,模型加载出来刚好; 然后加载另一个obj模型;加载出来之后大,偏到很高和左的位置; 它之前的摄像机位置, Position"9.94759830064…

橘子学es原理01之准备工作

es本身是具备很好的使用特性的,我指的是他的部署方面的,至于后期的使用和运维那还是很一眼难尽的。 我们从这一篇开始就着重于es的一些原理性的的一些探讨,当然我们也会有一些操作性的,业务性的会分为多个栏目来写。比如前面我写的…

AutoSAR(基础入门篇)10.8-实验:模式管理

目录 一、配置BswM 二、配置唤醒源 三、配置ComM Users 四、配置BswM的通信控制 五、Service Mapping 首先备份上一次的工程,养成好习惯(最好还是用Git,这次最后再安利一下Git这个神器)。今天的实验异常的简单,基…

SpringMVC 学习(三)之 @RequestMapping 注解

目录 1 RequestMapping 注解介绍 2 RequestMapping 注解的位置 3 RequestMapping 注解的 value 属性 4 RequestMapping 注解的 method 属性 5 RequestMapping 注解的 params 属性(了解) 6 RequestMapping 注解的 headers 属性(了解&…

CSS三大定位方式(浮动、定位、弹性盒)详细解析

CSS三大定位方式 前言:作为一名前端开发,已经工作2年了。由于自己是半路出家,从嵌入式方向转到前端开发,都是边百度边开发,很多基础都不了解,只要解决问题就好,但是近来为了让自己知识体系化&a…

【数据结构(顺序表)】

一、什么是数据结构? 数据结构是由“数据”和“结构”两词组合而来。 什么是数据?常见的数值1、2、3、4.....、教务系统里保存的用户信息(姓名、性别、年龄、学历等等)、网页里肉眼可以看到的信息(文字、图片、视频等等&#xff…

Yolo v9 “Silence”模块结构及作用!

论文链接:👿 YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information 代码链接:👿 https://github.com/WongKinYiu/yolov9/tree/main Silence代码 class Silence(nn.Module):def __init__(self):supe…

Mysql运维篇(四) MHA

大佬博文 https://www.cnblogs.com/gomysql/p/3675429.html MySQL 高可用(MHA) - 知乎 一、MHA简介: MHA(Master High Availability)目前在MySQL高可用方面是一个相对成熟的解决方案,它由日本DeNA公司y…

【Linux进程】进程状态---进程僵尸与孤儿

📙 作者简介 :RO-BERRY 📗 学习方向:致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 📒 日后方向 : 偏向于CPP开发以及大数据方向,欢迎各位关注,谢谢各位的支持 目录 1.进程排队2.进程状态…

uni-app nvue vue3 setup中实现加载webview,解决nvue中获取不到webview实例的问题

注意下面的方法只能在app端使用, let wv plus.webview.create("","custom-webview",{plusrequire:"none", uni-app: none, width: 300,height:400,top:uni.getSystemInfoSync().statusBarHeight44 }) wv.loadURL("https://ww…

Sentinel微服务流量治理组件实战上

目录 分布式系统遇到的问题 解决方案 Sentinel 是什么? Sentinel 工作原理 Sentinel 功能和设计理念 流量控制 熔断降级 Sentinel工作主流程 Sentinel快速开始 Sentinel资源保护的方式 基于API实现 SentinelResource注解实现 Spring Cloud Alibaba整合…

《隐私计算简易速速上手小册》第7章:隐私计算与云计算/边缘计算(2024 最新版)

文章目录 7.1 云计算中的隐私保护7.1.1 基础知识7.1.2 主要案例:使用 Python 实现云数据的安全上传和访问7.1.3 拓展案例 1:实现基于角色的访问控制7.1.4 拓展案例 2:使用 Python 保护 API 安全7.2 边缘计算的隐私问题7.2.1 基础知识7.2.2 主要案例:使用 Python 实现边缘设…

halcon中的一维测量

一维测量 像点到点的距离,边缘对的距离等沿着一维方向的测量都属于1D测量范畴。Halocn的一维测量首先构建矩形或者扇形的ROI测量对象,然后在ROI内画出等距离的、长度与ROI宽度一致的、垂直于ROI的轮廓线(profile line)的等距线。…

Jenkins解决Host key verification failed (2)

Jenkins解决Host key verification failed 分析原因情况 一、用OpenSSH的人都知ssh会把你每个你访问过计算机的公钥(public key)都记录在~/.ssh/known_hosts。当下次访问相同计算机时,OpenSSH会核对公钥。如果公钥不同,OpenSSH会发出警告,避免…

Java 学习和实践笔记(20):static的含义和使用

static的本义是静止的。在计算机里就表示静态变量。 在Java中,从内存分析图上可以看到,它与类、常量池放在一个区里: 从图可以看到,普通的方法和对象属性,都在heep里,而static则在方法区里。 static声明的…

Linux第65步_学习“Makefie”

学习“Makefie”,为后期学习linux驱动开发做铺垫。 1、在“/home/zgq/linux/atk-mp1”创建一个“Test_MakeFile”目录用于学习“Makefie”。 打开终端 输入“cd /home/zgq/linux/回车”,切换到“/home/zgq/linux/”目录 输入“mkdir Linux_Drivers回…

【AUTOSAR】--02 AUTOSAR网络管理相关参数

这是AUTOSAR网络管理梳理的第二篇文章,主要讲解AUTOSAR网络管理的相关参数。第一篇链接【01 AUTOSAR网络管理基础】。​ 相关参数有很多,我挑了一些相对重要的参数,分三部分进行讲解: 第一部分:比较常用&#xff0c…