b站小土堆pytorch学习记录—— P27-P29 完整的模型训练套路

news2024/12/28 21:07:26

文章目录

  • 一、定义模型(放在model.py文件中)
  • 二、训练
  • 三、测试
  • 四、完整的训练和测试代码

一、定义模型(放在model.py文件中)

import torch
from torch import nn

class Guodong(nn.Module):
    def __init__(self):
        super(Guodong,self).__init__()
        self.module = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.module(x)
        return x

if __name__ == '__main__':
    guodong = Guodong()
    input = torch.ones((64, 3, 32, 32))
    output = guodong(input)
    print(output.shape)

二、训练

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from model import *

dataset_train = torchvision.datasets.CIFAR10("dataset1", train=True, transform=torchvision.transforms.ToTensor(), download=True)
dataset_test = torchvision.datasets.CIFAR10("dataset1", train=False, transform=torchvision.transforms.ToTensor(),download=False)

dataset_train_size = len(dataset_train)
dataset_test_size = len(dataset_test)
print("训练集的数据长度为{}".format(dataset_train_size))
print("测试集的数据长度为{}".format(dataset_test_size))

train_dataloader = DataLoader(dataset_train, batch_size=64)
test_dataloader = DataLoader(dataset_test, batch_size=64)

# 创建网络模型
guodong = Guodong()

# 损失函数
loss_fn = nn.CrossEntropyLoss()

# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(guodong.parameters(), learning_rate)

# 设置训练网络的一些参数
total_train_step =0
total_test_step = 0
epoch = 10

for i in range(10):
    print("------第{}次训练开始------".format(i+1))

    # 训练开始
    for data in train_dataloader:
        imgs, target = data
        output = guodong(imgs)
        loss = loss_fn(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step+1
        if total_train_step % 100 == 0:
            print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))

运行结果:(部分)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
可以看到,随着训练次数的增加,loss整体上在不断变小

三、测试

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import *

writer = SummaryWriter("train_logs")

dataset_train = torchvision.datasets.CIFAR10("dataset1", train=True, transform=torchvision.transforms.ToTensor(), download=True)
dataset_test = torchvision.datasets.CIFAR10("dataset1", train=False, transform=torchvision.transforms.ToTensor(),download=False)

dataset_train_size = len(dataset_train)
dataset_test_size = len(dataset_test)
print("训练集的数据长度为{}".format(dataset_train_size))
print("测试集的数据长度为{}".format(dataset_test_size))

train_dataloader = DataLoader(dataset_train, batch_size=64)
test_dataloader = DataLoader(dataset_test, batch_size=64)

# 创建网络模型
guodong = Guodong()

# 损失函数
loss_fn = nn.CrossEntropyLoss()

# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(guodong.parameters(), learning_rate)

# 设置训练网络的一些参数
total_train_step =0
total_test_step = 0
epoch = 10

for i in range(10):
    print("------第{}次训练开始------".format(i+1))

    # 训练开始
    for data in train_dataloader:
        imgs, targets = data
        outputs = guodong(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step+1
        if total_train_step % 100 == 0:
            # print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    total_test_loss = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = guodong(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
    print("整体测试集上的Loss:{}".format(total_test_loss))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)

writer.close()

运行结果:

在这里插入图片描述
打开tensorboard后,结果如下:

在这里插入图片描述

四、完整的训练和测试代码

主要功能:
加载和准备CIFAR-10数据集,以便训练和测试深度学习模型。
创建一个自定义的深度学习模型(Guodong),并定义损失函数和优化器。
执行训练循环和测试循环,通过反向传播优化模型参数,并评估模型在测试集上的性能。
使用TensorBoard记录训练过程中的损失和准确率等信息,以便后续分析和可视化。
保存训练后的模型参数到文件中,以便后续部署和使用。

此外

在深度学习中,通常使用**.train().eval()**这两个方法来设置模型的训练模式和评估模式。这两个方法通常用于 PyTorch 或 TensorFlow 等深度学习框架。

.train(): 这个方法将模型设置为训练模式。在训练模式下,模型会启用训练相关的功能,比如启用 dropout 或 batch normalization 层的运算,以及计算梯度用于参数更新。当调用该方法后,模型会处于可以接受输入数据并进行前向传播、反向传播的状态。

.eval(): 这个方法将模型设置为评估模式。在评估模式下,模型会关闭一些训练过程中的特殊操作,如 dropout 或 batch normalization 的自适应性,以确保在推理阶段的一致性。评估模式通常用于模型在验证集或测试集上的性能评估,以保证评估结果的稳定性和一致性。

通过在训练和评估阶段分别调用.train()和.eval()方法,可以确保模型在不同阶段有正确的行为表现,从而提高训练和评估的效果和可靠性。

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import Guodong  # 导入自定义的模型类

# 创建TensorBoard的SummaryWriter,用于记录训练过程中的损失和准确率等信息
writer = SummaryWriter("train_logs")

# 加载CIFAR-10数据集
dataset_train = torchvision.datasets.CIFAR10("dataset1", train=True, transform=torchvision.transforms.ToTensor(), download=True)
dataset_test = torchvision.datasets.CIFAR10("dataset1", train=False, transform=torchvision.transforms.ToTensor(), download=False)

dataset_train_size = len(dataset_train)
dataset_test_size = len(dataset_test)
print("训练集的数据长度为{}".format(dataset_train_size))
print("测试集的数据长度为{}".format(dataset_test_size))

# 创建训练和测试数据加载器
train_dataloader = DataLoader(dataset_train, batch_size=64)
test_dataloader = DataLoader(dataset_test, batch_size=64)

# 创建网络模型实例
guodong = Guodong()

# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(guodong.parameters(), learning_rate)

# 设置训练网络的一些参数
total_train_step = 0
total_test_step = 0
epoch = 10

for i in range(10):
    print("------第{}次训练开始------".format(i + 1))

    guodong.train()

    # 训练开始
    for data in train_dataloader:
        imgs, targets = data
        outputs = guodong(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试开始
    guodong.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = guodong(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy

    print("整体测试集上的Loss:{}".format(total_test_loss))
    print("整体测试集上的正确率:{}".format(total_accuracy / dataset_test_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / dataset_test_size, total_test_step)

    # 保存模型
    torch.save(guodong.state_dict(), "guodong_{}.pth".format(i))
    print("模型已保存")

    total_test_step += 1

writer.close()

代码运行结果:

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1501872.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决:ModuleNotFoundError: No module named ‘paddle‘

错误显示: 原因: 环境中没有‘paddle’的python模块,但是您在尝试导入 解决方法: 1.普通方式安装: pip install paddlepaddle #安装命令 2.镜像源安装 pip install paddlepaddle -i https://pypi.tuna.tsinghua.e…

黑马java-JavaSE进阶-java高级技术

1.单元测试 就是针对最小的功能单元方法,编写测试代码对其进行正确性测试 2.Junit单元测试框架 可以用来对方法进行测试,它是第三方公司开源出来的 优点: 可以灵活的编写测试代码,可以针对某个方法执行测试,也支持一键…

Javaweb之Maven高级分模块设计与开发的详细解析

1. 分模块设计与开发 1.1 介绍 所谓分模块设计,顾名思义指的就是我们在设计一个 Java 项目的时候,将一个 Java 项目拆分成多个模块进行开发。 1). 未分模块设计的问题 如果项目不分模块,也就意味着所有的业务代码是不是都写在这一个 Java 项…

基于AI软件平台 HEGERLS智能托盘四向车机器人物流仓储解决方案持续升级

随着各大中小型企业对仓储需求的日趋复杂,柔性、离散的物流子系统也不断涌现,各种多类型的智能移动机器人、自动化仓储装备大量陆续的应用于物流行业中,但仅仅依靠传统的物流技术和单点的智能化设备,已经无法更有效的应对这些挑战…

神经网络的矢量化,训练与激活函数

我们现在再回到我们的神经元部分,来看我们如何用python进行正向传递。 单层的正向传递: 我们回到我们的线性回归的函数。我们每个神经元通过上述的方法,就可以得到我们的激发值,从而可以继续进行下一层。 我们用这个方法就可以得…

【论文阅读】Segment Anything论文梳理

Abstract 我们介绍了Segment Anything(SA)项目:新的图像分割任务、模型和数据集。高效的数据循环采集,使我们建立了迄今为止最大的分割数据集,在1100万张图像中,共超过10亿个掩码。 该模型被设计和训练为可…

一文学会搭建 cli 脚手架工具

文章目录 设置工具命令package.json bin 字段注释:#!/usr/bin/env node设置环境变量 接收命令选项参数process 实现commander 命令行交互:inquirer下载项目模板:download-git-repo执行额外命令:自动安装依赖child_processexeca 体…

在Anaconda3的conda中创建虚拟环境下载opencv

opencv下载全流程 一、下载Anaconda 记得从官方网格站进行下载,会有一些慢 下载后进行配置 b站讲解视频(非本人(平台大神讲解)) 二、打开conda控制台 这里的两个都可以进行下载 通常我们受用anaconda prompt 三、…

pytorch CV入门3-预训练模型与迁移学习.md

专栏链接:https://blog.csdn.net/qq_33345365/category_12578430.html 初次编辑:2024/3/7;最后编辑:2024/3/8 参考网站-微软教程:https://learn.microsoft.com/en-us/training/modules/intro-computer-vision-pytorc…

mysql主从复制(同步阿里云的RDS至自建数据库)

从库同步阿里云的RDS 阿里云默认开启了binglog,所以我们无需对主库进行配置 查询主库的server_id,从库配置不要重复就行 show variables like %server_id%;编辑从库的my.cnf文件 在文件中增加如下配置 server-id 123456789 …

【微信小程序】传参存储

目录 一、本地数据存储 wx.setStorage wx.setStorageSync 1.1、异步缓存 存取数据 1.2、同步缓存 存取数据 二、使用url跳转路径携带参数 2.1、 wx.redirectTo({}) 2.2、 wx.navigateTo({}) 2.3、 wx.switchTab({}) 2.4 、wx.reLaunch({}) 2.5、组件跳转 三、…

spring boot 2.4.x 之前版本(对应spring-cloud-openfeign 3.0.0之前版本)feign请求异常逻辑

目录 feign SynchronousMethodHandler 第一部分 第二部分 第三部分 spring-cloud-openfeign LoadBalancerFeignClient ribbon AbstractLoadBalancerAwareClient 在之前写的文章配置基础上 https://blog.csdn.net/zlpzlpzyd/article/details/136060312 因为从 spring …

GPT-4 等大语言模型(LLM)如何彻底改变客户服务

GPT-4革命:如何用AI技术重新定义SEO策略 在当今快节奏的数字时代,客户服务不再局限于传统的电话线或电子邮件支持。 得益于人工智能 (AI) 和自然语言模型 (NLM)(例如 GPT-4)的进步,客户服务正在经历革命性的转变。 在这…

【棘手问题】Spring JPA一级缓存导致获取不到数据库表中的最新数据,对象地址不发生改变

【棘手问题】Spring JPA一级缓存导致获取不到数据库表中的最新数据,对象地址不发生改变 一、问题背景二、解决步骤2.1 debug2.2 原因分析2.2.1 数据步骤2.2.2 大模型解释2.2.3 解释举例2.2.4 关键函数 2.3 解决方案 三、Spring JPA一级缓存 一、问题背景 项目的数据…

在ubuntu上使用vscode+gcc-arm-none-eabi+openocd工具开发STM32

文章目录 所需工具安装调试搭建过程中遇到的问题 写在前面 老大上周让我用vscode开发STM32,我爽快的答应了,心想大学四年装了这么多环境了这不简简单单,更何况vscode这两年还用过,然而现实总是令人不快的——我竟然花了差不多两周…

Java SE入门及基础(29)

第三节 访问修饰符 1. 概念 访问修饰符就是控制访问权限的修饰符号 2. 类的访问修饰符 类的访问修饰符只有两种:public 修饰符和默认修饰符(不写修饰符就是默认) public 修饰符修饰类表示类可以公开访问。默认修饰符修饰类表示该类只能…

flutter逆向 ACTF native ap

言 算了一下好长时间没打过CTF了,前两天看到ACTF逆向有道flutter逆向题就过来玩玩啦,花了一个下午做完了.说来也巧,我给DASCTF十月赛出的逆向题其中一道也是flutter,不过那题我难度降的相当之低啦,不知道有多少人做出来了呢~ 还原函数名 flutter逆向的一大难点就是不知道lib…

lvs集群中NAT模式

群集的含义 由多台主机构成,但对外表现为一个整体,只提供一个访问入口,相当于一台大型的计算机。 横向发展:放更多的服务器,有调度分配的问题。 垂直发展:升级单机的硬件设备,提高单个服务器自身功能。 …

论文阅读:Scalable Diffusion Models with Transformers

Scalable Diffusion Models with Transformers 论文链接 介绍 传统的扩散模型基于一个U-Net骨架,这篇文章提出了一种新的扩散模型结构,将U-Net替换为一个transformer,并将这种结构称为Diffusion Transformers (DiTs)。他们还发现&#xff…

Codesys.运动控制电子齿轮

文章目录 一. 电子齿轮概念应用 二. 电子齿轮耦合功能块 2.1. MC_GearIn 2.2. MC_GearInPos 2.3. MC_GearOut 三. 电子齿轮案例 3.1. 样例介绍 3.2. 引入虚轴 3.3. 程序框架 3.4. 程序编写 3.5. 程序监控 一. 电子齿轮概念应用 在很多应用场景中有多个牵引轴每个牵引…