在pytroch中使用CIFAR10完成完整的模型训练套路

news2025/1/18 1:54:05

模型训练套路:

  • 1.准备数据集
  • 2.加载数据集
  • 3.搭建神经网络
  • 4创建损失函数
  • 5.优化器
  • 6.设置训练网络的一些参数
  • 7.添加tensorboard(方便观察)
  • 8.开始训练
  • .测试
  • 9.保存神经网络

准备数据

#准备数据集
dataset_train=torchvision.datasets.CIFAR10("./dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)
dataset_test=torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)

加载数据

#加载数据集
loader_train=DataLoader(dataset_train,batch_size=16,drop_last=True,shuffle=False)
loader_test=DataLoader(dataset_test,batch_size=16,drop_last=True,shuffle=False)

搭建神经网络

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1=torch.nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self,x):
        return self.model1(x)

# #验证神经网络正确性
#
if __name__=="__main__":
    net = Net()
    input=torch.ones((64,3,32,32))
    output=net(input)
    print(output.shape)

创建损失函数

#创建损失函数
loss_cro=torch.nn.CrossEntropyLoss()

优化器

#创建优化器
optim=torch.optim.SGD(net.parameters(),lr=1e-3)

设置训练网络的一些参数

#设置训练网络的一些参数
total_train_stp=0#总训练次数
total_test_stp=0#总测试次数
epoch=20#训练轮次

添加tensorboard(方便观察)

#tensorboard
writer=SummaryWriter("./end")

开始训练

for i in range(epoch):
    loss=0
    print("-----------第{}轮训练开始--------------".format(i+1))
    for data in loader_train:
        imgs,targets=data
        output=net(imgs)
        loss=loss_cro(output,targets)

        #优化器优化
        optim.zero_grad()
        loss.backward()
        optim.step()

        #记录训练次数
        total_train_stp+=1
        if total_train_stp%100==0:
            print("训练次数:{},loss:{}".format(total_train_stp,loss.item()))
            writer.add_scalar("train_loss",loss.item(),total_train_stp)

测试

#测试
with torch.no_grad():
    total_current_test=0
    for data in loader_test:
        imgs,targets=data
        output=net(imgs)
        loss=loss_cro(output,targets)
        total_test_stp+=1

        writer.add_scalar("test_loss",loss,total_test_stp)
        accuracy=(output.argmax(1)==targets).sum()
        total_current_test+=accuracy
        print("整体测试的正确率为{}".format(total_current_test/(total_test_stp)*10))
        writer.add_scalar("test_accuracy",total_current_test/(total_test_stp)*10,total_test_stp)

保存模型

torch.save(net,"./model3.pth")

总代码

总代码分在两个py文件中,分别是

  • 负责训练模型的train.py
  • 神经网络所在文件network.py

network.py

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1=torch.nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self,x):
        return self.model1(x)

# #验证神经网络正确性
#
if __name__=="__main__":
    net = Net()
    input=torch.ones((64,3,32,32))
    output=net(input)
    print(output.shape)

train.py

import torchvision
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from network import Net


#准备数据集
dataset_train=torchvision.datasets.CIFAR10("./dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)
dataset_test=torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)


#加载数据集
loader_train=DataLoader(dataset_train,batch_size=16,drop_last=True,shuffle=False)
loader_test=DataLoader(dataset_test,batch_size=16,drop_last=True,shuffle=False)

#搭建神经网络
net=Net()


#创建损失函数
loss_cro=torch.nn.CrossEntropyLoss()

#创建优化器
optim=torch.optim.SGD(net.parameters(),lr=1e-3)

#设置训练网络的一些参数
total_train_stp=0#总训练次数
total_test_stp=0#总测试次数
epoch=20#训练轮次

#tensorboard
writer=SummaryWriter("./end")

#开始训练
for i in range(epoch):
    loss=0
    print("-----------第{}轮训练开始--------------".format(i+1))
    for data in loader_train:
        imgs,targets=data
        output=net(imgs)
        loss=loss_cro(output,targets)

        #优化器优化
        optim.zero_grad()
        loss.backward()
        optim.step()

        #记录训练次数
        total_train_stp+=1
        if total_train_stp%100==0:
            print("训练次数:{},loss:{}".format(total_train_stp,loss.item()))
            writer.add_scalar("train_loss",loss.item(),total_train_stp)

#测试
with torch.no_grad():
    total_current_test=0
    for data in loader_test:
        imgs,targets=data
        output=net(imgs)
        loss=loss_cro(output,targets)
        total_test_stp+=1

        writer.add_scalar("test_loss",loss,total_test_stp)
        accuracy=(output.argmax(1)==targets).sum()
        total_current_test+=accuracy
        print("整体测试的正确率为{}".format(total_current_test/(total_test_stp)*10))
        writer.add_scalar("test_accuracy",total_current_test/(total_test_stp)*10,total_test_stp)
torch.save(net,"./model3.pth")
print("模型已经保存")

writer.close()

最终运行截图

image.png
可以看出再训练了20轮后可以达到93%的正确率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1958531.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

外星人入侵_外星人

项目_外星人入侵_外星人 1创建第一个外星人1.1创建Alien类1.2创建Alien实例1.3让外星人出现在屏幕上 2创建一群外星人2.1确定一行可以容纳多少外星人2.2 创建多行外星人2.3创建外星人群2.4重构create_fleet()2.5添加行 3让外星人群移动3.1向右移动外星人3.2创建表示外星人移动方…

迷你世界魔方模型快速制作

做六个不一样颜色的顶部 --黄,绿,红,蓝,橙,白 --local ids{4000,3999, 3998,3997,3996,3995} 游戏脚本运行上一期文章 local x0,y0,z0-39,7,10--起点坐标 --框架、底面、侧面1-4、顶面 local id{682,671,681,680,66…

消息队列rabbitmq的使用

前提条件:环境安装amqp和安装rabbitmq sudo apt-get update sudo apt-get install rabbitmq-amqp-dev 1、创建CMakeLists.txt文件 # Copyright (c) Huawei Technologies Co., Ltd. 2019. All rights reserved.# CMake lowest version requirement cmake_minimum_…

tof系统标定流程之lens标定

1、lens标定详解 为什么在标定tof时需要进行lens的标定,可以说lens标定是一个必不可少的步骤,tof模组也是有镜头的,镜头的畸变会导致进入的光线出现偏差,最终照射到tof芯片表面导致深度图的分布出现畸变,通常是枕形畸变。例外一个用途在于,在计算fppn误差环节需要知道镜头…

机器学习算法与Python实战 | 两行代码即可应用 40 个机器学习模型--lazypredict 库!

本文来源公众号“机器学习算法与Python实战”,仅用于学术分享,侵权删,干货满满。 原文链接:两行代码即可应用 40 个机器学习模型 今天和大家一起学习使用 lazypredict 库,我们可以用一行代码在我们的数据集上实现许多…

【数据结构】队列(链表实现 + 力扣 + 详解 + 数组实现循环队列 )

Hi~!这里是奋斗的明志,很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~~ 🌱🌱个人主页:奋斗的明志 🌱🌱所属专栏:数据结构 📚本系列文章为个人学…

流行巨星布兰妮·斯皮尔斯发生了什么事?她现在在哪里过得怎么样

流行音乐公主布兰妮斯皮尔斯是 21 世纪初的经典偶像。她从 15 岁起就开始唱歌和表演,并创作了《Oops I Did it Again》和《Baby One More Time》等热门歌曲。她的歌曲非常出色,在 2000 年荣登榜首。她接下来的几张专辑变得更加畅销,她毫不畏惧…

学习008-02-04-04 Enable Split Layout in a List View(在列表视图中启用拆分布局 )

Enable Split Layout in a List View(在列表视图中启用拆分布局 ) This lesson explains how to enable a Split Layout in a List View. 本课介绍如何在列表视图中启用拆分布局。 The Detail View opens when you select an object from the List Vie…

G120 EPos配置方案及应用场景

EPos功能就是基本定位器功能,它可计算出轴的运行特性,使轴以时间最佳的方式移动到目标位置。EPos功能主要包括:设定值 直接给定(MDI)功能、 选择程序段功能、回参考点功能、点动功能、运行到固定挡块功能。 EPos功能通过处理给定的加速度、速度和位置值生成运行特性曲线,…

node+mysql+layui+ejs实现左侧导航栏菜单动态显示

nodemysqllayuiejs实现左侧导航菜单动态显示 实现思路效果图数据库技术栈代码实现main.html(前端首页页面)查询资源菜单方法 jsapp.js配置ejs模板 node入门到入土项目实战开始,前端篇项目适合node小白入门,因为我也是小白来学习no…

机器人笛卡尔空间阻抗控制

机器人笛卡尔空间阻抗控制是一种重要的机器人控制策略,它关注于机器人末端执行器在笛卡尔空间(即任务空间)内的动态特性,以实现与环境的柔顺交互。以下是对机器人笛卡尔空间阻抗控制的详细解释: 一、基本概念 笛卡尔空间:指机器人末端执行器(如手爪、工具等)所处的三维…

Hive之扩展函数(UDF)

Hive之扩展函数(UDF) 1、概念讲解 当所提供的函数无法解决遇到的问题时,我们通常会进行自定义函数,即:扩展函数。Hive的扩展函数可分为三种:UDF,UDTF,UDAF。 UDF:一进一出 UDTF:一进多出 UDAF&#xff1a…

YOLO v8目标检测(三)模型训练与正负样本匹配

YOLO v8目标检测 损失函数理论 在YOLO v5模型中,cls, reg, obj代表的是三个不同的预测组成部分,对应的损失函数如下: cls: 这代表类别预测(classification)。对应的损失是类别预测损失(loss_cls&#xff…

Win10出现错误代码0x80004005 一键修复指南

对于 Windows 10 用户来说,错误代码 0x80004005 就是这样一种迷雾,它可能在不经意间出现,阻碍我们顺畅地使用电脑。这个错误通常与组件或元素的缺失有关,它可能源自注册表的错误、系统文件的损坏,或者是软件的不兼容。…

listener监听

背景: 过滤器代码也可实现接口请求次数统计,但会影响过滤器本意;故在dispatcher servlet层进行监听统计 价值: 所有接口的次数统计可适用于系统全天访问量; 单个请求接口的次数统计可在企业中根据接口次数的高低,可分析出接口对应的功能受用户的喜好程度 请求通过过滤器到了s…

common-intellisense:助力TinyVue 组件书写体验更丝滑

本文由体验技术团队Kagol原创~ 前两天,common-intellisense 开源项目的作者 Simon-He95 在 VueConf 2024 群里发了一个重磅消息: common-intellisense 支持 TinyVue 组件库啦! common-intellisense 插件能够提供超级强大的智能提示功能&…

c生万物系列(职责链模式与if_else)

从处理器的角度来说,条件分支会导致指令流水线的中断,所以控制语句需要严格保存状态,因为处理器是很难直接进行逻辑判断的,有可能它会执行一段时间,发现出错后再返回,也有可能通过延时等手段完成控制流的正…

skynet 实操篇

文章目录 概述demo启动文件skynet_start配置文件main.luastart函数thread_workerskynet_context_message_dispatchskynet_mq_popdispatch_message 小结 概述 上一篇写完skynet入门篇,这一篇写点实操性质的。 demo 对于一个开源框架,大部分都有他们自己…

《Linux运维总结:基于x86_64架构CPU使用docker-compose一键离线部署zookeeper 3.8.4容器版分布式集群》

总结:整理不易,如果对你有帮助,可否点赞关注一下? 更多详细内容请参考:《Linux运维篇:Linux系统运维指南》 一、部署背景 由于业务系统的特殊性,我们需要面对不同的客户部署业务系统&#xff0…

C++客户端Qt开发——界面优化(美化登录界面)

美化登录界面 在.ui中拖入一个QFream,顶层窗口的QWidget无法设置背景图片,套上一层QFrame将背景图片设置到QFrame上即可 用布局管理器管理元素:用户名LineEdit,密码LineEdit,记住密码ComboBox,登录Button…