学习pytorch18 pytorch完整的模型训练流程

news2024/12/26 10:37:03

pytorch完整的模型训练流程

  • 1. 流程
    • 1. 整理训练数据 使用CIFAR10数据集
    • 2. 搭建网络结构
    • 3. 构建损失函数
    • 4. 使用优化器
    • 5. 训练模型
    • 6. 测试数据 计算模型预测正确率
    • 7. 保存模型
  • 2. 代码
    • 1. model.py
    • 2. train.py
  • 3. 结果
    • tensorboard结果
      • 以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值
      • 训练loss
      • 测试loss
      • 测试集正确率
  • 4. 需要注意的细节

1. 流程

1. 整理训练数据 使用CIFAR10数据集

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)

2. 搭建网络结构

在这里插入图片描述
model.py

3. 构建损失函数

loss_fn = nn.CrossEntropyLoss()

4. 使用优化器

learing_rate = 1e-2 # 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

5. 训练模型

output = net(imgs)    # 数据输入模型
loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少
# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化
optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0
loss.backward()        # 设置计算的损失值的钩子,调用损失的反向传播,计算每个参数结点的参数
optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化  

6. 测试数据 计算模型预测正确率

output = net(imags)
# 计算测试集的正确率
preds = (output.argmax(1)==targets).sum()
accuracy += preds 
rate = accuracy/len(test_data)

调用模型输出tensor 数据类型的 argmax方法, argmax或获取一行或者一列数值中最大数值的下标位置,argmax(0) 是从列的维度取一列数值的最大值的下标,argmax(1) 是从行的维度取一行数值的最大值的下标
output.argmax(1)==targets 会输出如下图最后一行 [false, ture], 对应位置相同则为true,对应位置不同则为false;
调用sum()方法,计算求和,false值为0,true值为1.
最后计算得出测试集整体正确率: rate = accuracy/len(test_data)
在这里插入图片描述

7. 保存模型

torch.save(net, './net_epoch{}.pth'.format(i))

2. 代码

1. model.py

import torch
from torch import nn

# 2. 搭建模型网络结构--神经网络
class Cifar10Net(nn.Module):
    def __init__(self):
        super(Cifar10Net, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.net(x)
        return x


if __name__ == '__main__':
    net = Cifar10Net()
    input = torch.ones((64, 3, 32, 32))
    output = net(input)
    print(output.shape)

2. train.py

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from p24_model import *

# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoader

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)

# 2. 导入模型结构 创建模型
net = Cifar10Net()

# 3. 创建损失函数  分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()

# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

# 设置训练网络的一些参数
epoch = 10   # 记录训练的轮数
total_train_step = 0  # 记录训练的次数
total_test_step = 0   # 记录测试的次数

# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')

for i in range(epoch):
    # 训练步骤开始
    net.train()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用
    for data in train_loader:
        imgs, targets = data  # 获取数据
        output = net(imgs)    # 数据输入模型
        loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少
        # 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化
        optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0
        loss.backward()        # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数
        optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化
        # 优化一次 认为训练了一次
        total_train_step += 1
        if total_train_step % 100 == 0:
            print('训练次数: {}   loss: {}'.format(total_train_step, loss))
        # 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】
        # print('训练次数: {}   loss: {}'.format(total_train_step, loss.item()))
        writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)

    # 利用现有模型做模型测试
    # 测试步骤开始
    total_test_loss = 0
    accuracy = 0
    net.eval()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用
    with torch.no_grad():
        for data in test_loader:
            imags, targets = data
            output = net(imags)
            loss = loss_fn(output, targets)
            total_test_loss += loss.item()
            # 计算测试集的正确率
            preds = (output.argmax(1)==targets).sum()
            accuracy += preds
    # writer.add_scalar('test-loss', total_test_loss, global_step=i+1)
    writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)
    writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)
    total_test_step += 1
    print("---------test loss: {}--------------".format(total_test_loss))
    print("---------test accuracy: {}--------------".format(accuracy))
    # 保存每一个epoch训练得到的模型
    torch.save(net, './net_epoch{}.pth'.format(i))

writer.close()

3. 结果

训练数据集大小: 50000
测试数据集大小: 10000
0.01
训练次数: 100   loss: 2.2905373573303223
训练次数: 200   loss: 2.2878968715667725
训练次数: 300   loss: 2.258394718170166
训练次数: 400   loss: 2.1968581676483154
训练次数: 500   loss: 2.0476632118225098
训练次数: 600   loss: 2.002145767211914
训练次数: 700   loss: 2.016021728515625
---------test loss: 316.382279753685--------------
训练次数: 800   loss: 1.8957302570343018
训练次数: 900   loss: 1.8659226894378662
训练次数: 1000   loss: 1.9004186391830444
训练次数: 1100   loss: 1.9708642959594727
......

tensorboard结果

安装tensorboard运行环境

pip install tensorboard
pip install opencv-python
pip install six
tensorboard --logdir=train_logs

以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值

训练loss

在这里插入图片描述

测试loss

在这里插入图片描述

测试集正确率

在这里插入图片描述

4. 需要注意的细节

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module

所有网络层继承于torch.nn.Module, net.train() net.eval() 在模型训练或测试之初 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用,当模型有这两个网络层的时候,两个代码需要加上。
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1310610.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DHCP--自动获取IP地址

目录 一、了解DHCP服务 1、概念 2、使用DHCP的好处 3、DHCP的分配方式 二、DHCP的租约过程 1、客户机请求IP地址 2、服务器响应 3、客户机选择IP地址 4、服务器确定租约 5、服务器租约期限到了之后续期问题 6、总结 三、部署DHCP实验 1、项目要求 2、规划设计 …

云计算与AI融合:Amazon Connect开创客户服务智能时代

授权说明:本篇文章授权活动官方亚马逊云科技文章转发、改写权,包括不限于在 亚马逊云科技开发者社区, 知乎,自媒体平台,第三方开发者媒体等亚马逊云科技官方渠道 在亚马逊云科技 re:Invent 2023 大会上,Amazon Connect…

数组相关的题目

数组相关的题目 128. 最长连续序列 128. 最长连续序列 题目:给定一个未排序的整数数组 nums ,找出数字连续的最长序列(不要求序列元素在原数组中连续)的长度。 很容易就能想到要先排序,再进行后续的处理。有一个坑&a…

CentOS7安装 Docker Compose

docker系列 CentOS7安装 Docker Compose docker系列前言1、下载 Docker Compose2、 授权执行权限3、添加软链接4、验证安装 前言 下面的操作是在centos7中完成的。这里安装的是2.23.3版本的docker-compose。 1、下载 Docker Compose 确保你具有 curl 工具,然后使用…

低代码开发:属于“美味膳食”还是“垃圾食品”

目录 引言低代码是什么?低代码的优点使用挑战未来展望最后 引言 随着数字化转型的迅猛发展,低代码开发平台逐渐成为了企业和开发者的关注焦点,尤其是前两年低代码的迅速火爆,来势汹汹,号称要让大部分程序员下岗的功能…

海外中企项目概设方案

目录 一、项目背景 二、业务需求 2.1 远程视频监控 2.2 多级视频指挥 2.3 无线集群通信 2.4 车辆实时跟踪 2.5 车辆视频调度 三、需求分析 3.1 总指挥中心-标段分指挥中心x 3.2 标段分指挥中心x-指挥车x - 语音/定位业务: 3.3 标段x-指挥车x-视频业务&am…

There appears to be trouble with your network connection. Retrying

一直在报如上错误,试了很多办法,比如删掉yarn.lock,yarn cache clean,删掉node_modules,rm proxy等等都没有用 甚至于重启电脑,然而并没有什么用 突然间想到,我用了clash for window 所以想了…

uniapp开发项目注意事项

uniapp创建项目用HBuilderX创建或者用脚手架命令创建都可以vue文件渲染h5,小程序很好nvue文件渲染原生app更好,注意nvue文件css的一些局限性,简称坑死人nvue所支持的通用样式已在本文档中全部列出,一些组件可能有自定义样式&#…

js点击按钮上传文件

vue语法 <template><div style"width: 152px;"><input id"file" ref"file" class"filepath" change"changepic()" type"file" /><el-button size"small" type"primary&quo…

Docker | 发布镜像到镜像仓库

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏:Docker系列 ✨特色专栏: MySQL学习 🥭本文内容:Docker | 发布镜像到镜像仓库 📚个人知识库: [Leo知识库]https://gaoziman.gitee.io/bl…

操作系统笔记——概论、进程、线程(王道408)

文章目录 前言计算机系统概述OS的基本概念OS的发展历程OS的运行机制OS体系结构OS引导虚拟机 进程和线程进程和线程基础进程进程状态进程控制进程通信线程线程实现 CPU调度调度的层次进程调度细节调度算法评价指标批处理调度算法交互式调度方法 同步与互斥基本概念互斥互斥软件实…

2023.12.6 关于 Spring Boot 事务的基本概念

目录 事务基本概念 前置准备 Spring Boot 事务使用 编程式事务 声明式事务 Transactional 注解参数说明 Transational 对异常的处理 解决方案一 解决方案二 Transactional 的工作原理 面试题 Spring Boot 事务失效的场景有那些&#xff1f; 事务基本概念 事务指一…

10.CSS浮动

CSS浮动 1.介绍 在最初&#xff0c;浮动是用来实现文字环绕图片效果的&#xff0c;现在浮动是主流的页面布局方式之一 2.作用 让元素脱离标准流&#xff0c;同一级的浮动的元素可以并排在一排显示 3.元素浮动后的特点 脱离文档流不管浮动前是什么元素&#xff0c;浮动后&…

【为什么POI的SXSSFWorkbook占用内存更小?】

&#x1f513;为什么POI的SXSSFWorkbook占用内存更小&#xff1f; &#x1f3c6;POI的SXSSFWorkbook&#x1f3c6;POI的SXSSFWorkbook占用内存&#x1f3c6;扩展配置行缓存限制 &#x1f3c6;POI的SXSSFWorkbook SXSSFWorkbook类是Apache POI库的一部分&#xff0c;它是一个流…

产品入门第二讲:Axure产品元件库的使用

&#x1f4da;&#x1f4da; &#x1f3c5;我是默&#xff0c;一个在CSDN分享笔记的博主。&#x1f4da;&#x1f4da; ​​​​ &#x1f31f;在这里&#xff0c;我要推荐给大家我的专栏《Axure》。&#x1f3af;&#x1f3af; &#x1f680;无论你是编程小白&#xff0c;还是…

python下使用Open3D

1.切记不要安装最新的python否则无法使用open3D &#xff0c;官网显示只支持python3.8-3.11 这是我安装的python版本 2.由于访问github很慢&#xff0c;所以我手动下载ply文件 https://github.com/isl-org/open3d_downloads/releases/download/20220201-data/fragment.ply 3…

手写进度条,鼠标移入显示悬浮框

效果 <template><div class"box"><div class"mid-box"><div class"mid-contant"><!-- 提示框 --><divv-if"hover"class"tooltip":style"{top: hovertop,}"><div>{{ ho…

c语言堆排序(详解)

堆排序 堆排序是一种基于二叉堆数据结构的排序算法&#xff0c;它的基本概念包括&#xff1a; 建立堆&#xff1a;将待排序的列表构建成一个二叉堆&#xff0c;即满足堆的性质的完全二叉树&#xff0c;可以是最大堆或最小堆。最大堆要求父节点的值大于等于其子节点的值&#x…

Linux(21):软件安装 RPM,SRPM 与 YUM

软件管理员简介 以原始码的方式来安装软件&#xff0c;是利用厂商释出的Tarball来进行软件的安装。 不过&#xff0c;你每次安装软件都需要侦测操作系统与环境、设定编译参数、实际的编译、最后还要依据个人喜好的方式来安装软件到定位。这过程是真的很麻烦的。 如果厂商先在他…

FastAPI之表单数据

FastAPI 表单数据处理教程 FastAPI 是一个现代、快速&#xff08;高性能&#xff09;的 Web 框架&#xff0c;用于构建 API&#xff0c;它用 Python 3.6类型提示的特性旨在方便和快速地设计和构建 APIs&#xff0c;并且减少代码的冗余与错误。下面将介绍如何在 FastAPI 中处理…