PyTorch学习笔记(十五)——完整的模型训练套路

news2025/1/23 4:40:58

以 CIFAR10 数据集为例,分类问题(10分类)

 

model.py

import torch
from torch import nn

# 搭建神经网络
class MyNN(nn.Module):
    def __init__(self):
        super(MyNN, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

if __name__ == '__main__':
    # 验证网络的正确性
    mynn = MyNN()
    input = torch.ones(64,3,32,32)
    output = mynn(input)
    print(output)

运行结果:torch.Size([64,10]) 

返回64行数据,每一行数据有10个数据,代表每一张图片在10个类别中的概率

train.py

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# model.py必须和train.py在同一个文件夹下
from model import *

# 准备数据集(CIFAR10 数据集是PIL Image,要转换为tensor数据类型)
train_data = torchvision.datasets.CIFAR10(root="../datasets",train=True,transform=torchvision.transforms.ToTensor(),download=False)
test_data = torchvision.datasets.CIFAR10(root="../datasets",train=False,transform=torchvision.transforms.ToTensor(),download=False)

# 获得数据集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))

# 利用dataloader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

# 创建网络模型
mynn = MyNN()
# 损失函数
loss_function = nn.CrossEntropyLoss()
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(mynn.parameters(), lr=learning_rate) # SGD 随机梯度下降

# 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 # 训练的轮数

# 添加tensorboard
writer = SummaryWriter("../logs_train")

for i in range(epoch):
    print("----------第{}轮训练开始----------".format(i+1))
    # 训练步骤开始
    mynn.train()
    for data in train_dataloader:
        imgs,targets = data
        outputs = mynn(imgs)
        loss = loss_function(outputs, targets)
        # 优化器优化模型
        optimizer.zero_grad() # 首先要梯度清零
        loss.backward() # 反向传播得到每一个参数节点的梯度
        optimizer.step() # 对参数进行优化

        total_train_step += 1
        # 训练步骤逢百才打印记录
        if total_train_step % 100 == 0:
            print("训练次数:{},loss:{}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss",loss.item(),total_train_step)

    # 测试步骤开始
    mynn.eval()
    total_test_loss = 0
    total_accuracy = 0
    # 无梯度,不进行调优
    with torch.no_grad():
        for data in test_dataloader:
            imgs,targets = data
            outputs = mynn(imgs)
            loss = loss_function(outputs, targets)
            total_test_loss += loss
            # 即便得到整体测试集上的 loss,也不能很好说明在测试集上的表现效果
            # 在分类问题中可以用正确率表示
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    print("整体测试集上的loss:{}".format(total_test_loss))
    print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
    writer.add_scalar("test_loss",total_test_loss,total_test_step)
    writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)
    total_test_step += 1

    # 保存每一轮训练的模型
    torch.save(mynn,"mynn_{}.pth".format(i))
    # torch.save(mynn.state_dict(),"mynn_{}.pth".format(i))
    print("模型已保存")

writer.close()

 

 关于正确率的计算:

方式1:

import torch

outputs = torch.tensor([[0.1,0.2],
                       [0.3,0.4]])
target = torch.tensor([0,1])

predict = outputs.argmax(1)
print(predict)

print(predict == target)
print((predict == target).sum())

 方式2:

import torch

outputs = torch.tensor([[0.1,0.2],
                       [0.3,0.4]])
target = torch.tensor([0,1])

predict = torch.max(outputs, dim=1)[1]
print(predict)

print(torch.eq(predict,target))
print(torch.eq(predict,target).sum())
print(torch.eq(predict,target).sum().item())

关于mynn.train()和mynn.eval():

这两句不写网络依然可以运行,它们的作用是:

 

 

这个案例没有 Dropout 层或 BatchNorm 层,所以有没有这两行都无所谓。但如果有这些特殊层,一定要调用。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/899290.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言:深度学习知识储备

目录 数据类型 每种类型的大小是多少呢? 变量 变量的命名: 变量的分类: 变量的作用域和生命周期 作用域: 生命周期: 常量 字符串转义字符注释 字符串: 转义字符 操作符: 算术操作符…

nginx反向代理、负载均衡

修改nginx.conf的配置 upstream nginx_boot{# 30s内检查心跳发送两次包,未回复就代表该机器宕机,请求分发权重比为1:2server 192.168.87.143 weight100 max_fails2 fail_timeout30s; server 192.168.87.1 weight200 max_fails2 fail_timeout30s;# 这里的…

【流程引擎】--Camunda基础及sprringboot简单集成Camunda

目录 一、前言二、Camunda基本介绍2.1、camunda基础--符号表示2.2、camunda基础--网关表示2.3、camunda基础--事件表示 三、springboot集成Camunda四、后续 一、前言 目前市场上有常见的流程引擎:JBPM、Activiti、Camunda、Flowable、CompileFlow。它们的发展史如下…

TR 已经释放 task未释放的问题

货铺QQ群号:834508274 微信群不能扫码进了,可以加我微信SAPliumeng拉进群,申请时请提供您哪个模块顾问,否则是一律不通过的。 进群统一修改群名片,例如BJ_ABAP_森林木。群内禁止发广告及其他一切无关链接,小…

16-案例-记账单

功能需求: <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title> </head> &l…

224、仿真-基于51单片机音乐播放器流水灯控制Proteus仿真设计(程序+Proteus仿真+原理图+程序流程图+元器件清单+配套资料等)

毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、原理图 五、程序源码 资料包括&#xff1a; 需要完整的资料可以点击下面的名片加下我&#xff0c;找我要资源压缩包的百度网盘下载地址及提取码。 方案选择 单片机的选…

C++音乐播放系统

C音乐播放系统 音乐的好处c发出声音乐谱与赫兹对照把歌打到c上 学习c的同学们都知道&#xff0c;c是一个一本正经的编程语言&#xff0c;因该没有人用它来做游戏、做病毒、做…做…做音乐播放系统吧&#xff01;&#xff01; 音乐的好处 提升情绪&#xff1a;音乐能够影响我们…

【C++进阶】继承、多态的详解(多态篇)

【C进阶】继承、多态的详解&#xff08;多态篇&#xff09; 目录 【C进阶】继承、多态的详解&#xff08;多态篇&#xff09;多态的概念多态的定义及实现多态的构成条件&#xff08;重点&#xff09;虚函数虚函数的重写&#xff08;覆盖、一种接口继承&#xff09;C11 override…

解决C#报“MSB3088 未能读取状态文件*.csprojAssemblyReference.cache“问题

今天在使用vscode软件C#插件&#xff0c;编译.cs文件时&#xff0c;发现如下warning: 图(1) C#报cache没有更新 出现该warning的原因&#xff1a;当前.cs文件修改了&#xff0c;但是其缓存文件*.csprojAssemblyReference.cache没有更新&#xff0c;需要重新清理一下工程&#x…

双层优化入门(4)—基于对偶变换的双层优化求解

之前的博客介绍了双层优化的基本原理、以及如何使用KKT条件和智能优化算法求解双层优化问题&#xff0c;这篇博客将继续介绍如何通过对偶变换求解双层优化问题。 1.线性规划的对偶问题 参考资料&#xff1a; 运筹学修炼日记&#xff1a;如何优雅地写出大规模线性规划的对偶_刘…

spring boot 整合支付宝微信支付

1.目录结构 2.引入依赖 <!--引入阿里支付--><dependency><groupId>com.alipay.sdk</groupId><artifactId>alipay-sdk-java</artifactId><version>4.11.8.ALL</version></dependency><!--引入微信支付--><depe…

Redis中的淘汰策略

前言 本文主要说明在Redis面临key过期和内存不足的情况时&#xff0c;可以采用什么策略进行解决问题。 Redis中是如何应对过期数据的 正如我们知道的Redis是基于内存的、单线程的一个中间件&#xff0c;在面对过期数据的时候&#xff0c;Redis并不会去直接把它从内存中进行剔…

运用工具Postman快速导出python接口测试脚本

Postman的脚本可以导出多种语言的脚本&#xff0c;方便二次维护开发。 Python的requests库&#xff0c;支持python2和python3&#xff0c;用于发送http/https请求 使用unittest进行接口自动化测试 一、环境准备 1、安装python&#xff08;使用python2或3都可以&#xff09;…

HCIP之VLAN实验

目录 一、实验题目 二、实验思路 三、实验步骤 3.1 将接口划入vlan&#xff0c;设置trunk干道 3.2 启动DHCP服务&#xff0c;下发地址 四、测试 一、实验题目 实验要求&#xff1a; 1&#xff0c;PC1/3的接口均为access模式&#xff0c;且属于vlan2&#xff0c;处于同一…

pyltp 0.2.1安装

1. LTP及pyltp pyltp是 LTP的 Python封装&#xff0c;它里面提供了包括分词&#xff0c;词性标注&#xff0c;命名实体识别&#xff0c;句法分析等等能力。 比较坑的是我们可能无法直接通过pip install pyltp0.2.1方式来安装&#xff0c;所以本文就简单记录下如何通过源码安装…

04_15页表缓存(TLB)和巨型页

前言 linux里面每个物理内存(RAM)页的一般大小都是4kb(32位就是4kb),为了使管理虚拟地址数变少 加快从虚拟地址到物理地址的映射 建议配值并使用HugePage巨型页特性 cpu和mmu和页表缓存(TLB)和cache和ram的关系 CPU看到的都是虚拟地址&#xff0c;需要经过MMU的转化&#xf…

langchain-ChatGLM源码阅读:模型加载

文章目录 使用命令行参数初始化加载器模型实例化清空显存加载模型调用链loader.py的_load_model方法auto_factory.py的from_pretrained方法modeling_utils.py的from_pretrained方法hub.py的get_checkpoint_shard_files方法modeling_utils.py的_load_pretrained_mode方法回到loa…

电脑远程接入软件可以进行文件传输吗?快解析内网穿透

电脑远程接入软件的出现&#xff0c;让我们可以在两台电脑之间进行交互和操作。但是&#xff0c;很多人对于这些软件能否进行文件传输还存在一些疑问。下面的文章将解答这个问题。 1.电脑远程接入软件可以进行文件传输。传统上&#xff0c;我们可能会通过传输线或者移动存储设…

听GPT 讲Prometheus源代码--promql/promdb

Prometheus的promql目录包含PromQL(Prometheus Query Language)的解析和执行代码: parser.go 定义PromQL语法结构和parser,用于将PromQL查询语句进行语法解析。 semantic.go 实现PromQL的语义分析,检查查询是否语法正确且语义合理。 engine.go 定义PromQL执行引擎的接口和数据结…