【Pytorch】13.搭建完整的CIFAR10模型

news2024/10/6 22:30:52

项目源码

已上传至githubCIFAR10Model,如果有帮助可以点个star

简介

在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程
本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型

基本步骤

搭建神经网络最主要的流程是

  • 导入数据集(包括训练集和测试集)
  • 创建DataLoader
  • 创建自定义的神经网络
  • 选择损失函数与梯度下降算法
  • 进行n轮训练
  • n轮训练完成后通过测试集进行验证
  • 引入TensorBoard进行可视化
  • 保存每轮训练好的模型
    接下来将逐步拆解这每一个步骤

1.导入数据集

因为我们本文是要训练CIFAR10的模型,所以我们导入CIFAR10的数据集

# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,
                                             transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,
                                            transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)

分别导入训练集与测试集,并且分别记录训练集与测试集的大小
对参数的解释可以看【Pytorch】4.torchvision.datasets的使用这篇文章

2.创建DataLoader

DataLoader主要定义了如何在数据集中取数据的规则,具体讲解可以看【Pytorch】5.DataLoder的使用

# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

3.创建自定义的神经网络

在这里插入图片描述
我们可以在网上搜到CIFAR10的网络模型,通过网络模型来搭建网络,具体可以看【Pytorch】10.CIFAR10模型搭建

import torch
from torch import nn


class CIFAR10Model(nn.Module):
    def __init__(self):
        super(CIFAR10Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(1024, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


if __name__ == '__main__':
    model = CIFAR10Model()
    input_test = torch.ones((64, 3, 32, 32))
    output_test = model(input_test)
    print(output_test.shape)

这里我们新创建了一个model.py用于专门存储网络结构,这样在我们的训练文件中,可以通过

from model import *

# 3.创建神经网络
model = CIFAR10Model()

来导入我们自定义的神经网络

4.选择损失函数和梯度下降的方法

我们选择了交叉熵损失函数与SGD的梯度下降算法,具体讲解可以看【Pytorch】11.损失函数与梯度下降

# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()

learn_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)

5.开始进行训练

首先将模型设置为训练模式

    model.train()

具体的训练流程分为以下几部

  • 从DataLoader中获取图片以及对应的编号
  • 将图片传入神经网络并获取输出
  • 将优化器清零
  • 计算损失函数
  • 进行梯度下降
  • 调用优化器进行更新
    for data in train_loader:
        # 训练基本流程
        inputs, labels = data
        outputs = model(inputs)
        optimizer.zero_grad()
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

在基础训练的基础上,还安排了每进行100次训练就将训练数据print出来,并且写入tensorboard

 # 第i轮训练次数加一
        pre_train_step += 1
        pre_train_loss += loss.item()
        total_train_step += 1
        # 每100次输出一下
        if pre_train_step % 100 == 0:
            end_train_time = time.time()
            print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')
            # 添加可视化
            writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)
    print(f"----------------------------第{i + 1}轮训练完成----------------------------")

6.测试集验证

首先将模型设置为测试集模式

    model.eval()

首先通过with关键字来创建一个没有梯度的上下文
验证方法与训练集类似,但是没有计算梯度与更新优化器的步骤

 with torch.no_grad():
        for data in test_loader:
            # 测试集流程
            inputs, labels = data
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

然后通过torch.argmax用于计算所有标签的最大值

  • 参数为1时代表横向判断
  • 参数为0的代表纵向判断
    计算当前模型在训练集中的正确次数
            pre_accuracy += outputs.argmax(1).eq(labels).sum().item()

7.引入TensorBoard进行可视化

我们主要是通过Summary中的add_scalar来建立可视化函数来进行可视化的,具体可以看【Pytorch】2.TensorBoard的运用

# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')

# 在训练集中,输出每一百次训练的损失函数平均值
 		# 每100次输出一下
        if pre_train_step % 100 == 0:
            end_train_time = time.time()
            print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')
            # 添加可视化
            writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)


# 在测试集中,输出模型在测试集中的正确率
pre_accuracy += outputs.argmax(1).eq(labels).sum().item()
    writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)

8.保存模型

具体可以看【Pytorch】12.网络模型的加载、修改与保存

    # 保存每轮的训练模型
    torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')

完整代码

import time
import torch
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriter

from model import *

# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,
                                             transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,
                                            transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)

# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

# 3.创建神经网络
model = CIFAR10Model()

# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()

learn_rate = 0.0001
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)

# 训练轮数
total_train_step = 0
total_test_step = 0

# 训练轮数
epoch = 20

# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')
# 5.开始训练
for i in range(epoch):
    # 将模型设置为训练模式
    print(f"----------------------------开启第{i+1}轮训练----------------------------")
    model.train()
    # 第i轮训练的次数
    pre_train_step = 0
    # 第i轮训练的总损失
    pre_train_loss = 0
    # 第i轮训练的起始时间
    start_time = time.time()
    for data in train_loader:
        # 训练基本流程
        inputs, labels = data
        outputs = model(inputs)
        optimizer.zero_grad()
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # 第i轮训练次数加一
        pre_train_step += 1
        pre_train_loss += loss.item()
        total_train_step += 1
        # 每100次输出一下
        if pre_train_step % 100 == 0:
            end_train_time = time.time()
            print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')
            # 添加可视化
            writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)
    print(f"----------------------------第{i + 1}轮训练完成----------------------------")
    # 设置为测试模式
    model.eval()
    # 第i轮训练集的总损失
    pre_test_loss = 0
    # 第i轮训练集的总正确次数
    pre_accuracy = 0
    print(f"----------------------------开启第{i + 1}轮测试----------------------------")
    # 配置没有梯度下降的环境
    with torch.no_grad():
        for data in test_loader:
            # 测试集流程
            inputs, labels = data
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

            # 定义参数
            pre_test_loss += loss.item()
            # 记录训练集的总正确率
            # argmax(1)代表横向判断,argmax(0)代表纵向判断
            pre_accuracy += outputs.argmax(1).eq(labels).sum().item()
    # 记录测试集运行完后的事件
    end_test_time = time.time()
    print(f'当前为第{i + 1}轮测试,已经过时间:{end_test_time - start_time},当前测试集的平均损失为:{pre_test_loss / test_size},当前在测试集的正确率为:{pre_accuracy / test_size}')
    writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)
    print(f"----------------------------第{i + 1}轮测试完成----------------------------")
    # 保存每轮的训练模型
    torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')
    print(f"----------------------------第{i + 1}轮模型保存完成----------------------------")

writer.close()

训练效果
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1690508.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

人类交互3 皮肤感觉与运动系统

皮肤感觉概述 皮肤是人体最大的器官之一,具有多种感觉功能,包括: 触觉:通过触觉,我们能感知物体的形状、质地,帮助我们与外界环境进行互动和感知周围物体的特征。 热觉:热觉使我们能感知周围环…

【笔记】Qt 按钮控件介绍(QPushButton,QCheckBox,QToolButton)

文章目录 QAbstractButton 抽象类(父类)QAbstractButton 类中的属性QAbstractButton 类中的函数QAbstractButton 类中的信号QAbstractButton 类中的槽 QPushButton 类(标准按钮)QPushButton 类中的属性QPushButton 类中的函数、槽 QCheckBox 类(复选按钮)QCheckBox 类的属性QCh…

CCF20221201——现值计算

CCF20221201——现值计算 代码如下&#xff1a; #include<bits/stdc.h> using namespace std; int main() {int n,a[1001];float i,sum0.0;scanf("%d %f",&n,&i);for(int j0;j<n1;j){scanf("%d",&a[j]);suma[j]*pow((1i),-j);}print…

Linux环境中部署docker私有仓库Registry与远程访问详细流程

目录 前言 1. 部署Docker Registry 2. 本地测试推送镜像 3. Linux 安装cpolar 4. 配置Docker Registry公网访问地址 5. 公网远程推送Docker Registry 6. 固定Docker Registry公网地址 前言 作者简介&#xff1a; 懒大王敲代码&#xff0c;计算机专业应届生 今天给大家聊…

用数据,简单点!奇点云2024 StartDT Day数智科技大会,直播见

在充满挑战的2024&#xff0c;企业如何以最小化的资源投入和试错成本&#xff0c;挖掘新的增长机会&#xff0c;实现确定性发展&#xff1f; “简单点”是当前商业环境的应对策略&#xff0c;也是奇点云2024 StartDT Day的核心理念。 5月28日&#xff0c;由奇点云主办的2024 S…

自定义全局变量3

变量删除 语法 unset var_name演示 自定义常量 介绍 就是变量设置值以后不可以修改的变量叫常量, 也叫只读变量 语法 readonly var_name演示 自定义全局变量 父子Shell环境介绍 例如: 有2个Shell脚本文件 A.sh 和 B.sh 如果 在A.sh脚本文件中执行了B.sh脚本文件, 那么A.…

重磅推荐!四信AI智能一体屏系列全网上线

近年来&#xff0c;随着物联网、云计算、人工智能等新兴技术快速发展&#xff0c;制造、能源、交通、零售、医疗等行业设备需要更高程度的自动化控制。 传统的计算机和控制设备早已无法满足如今高性能复杂任务的要求&#xff0c;越来越多主流行业的项目落地依靠工控机&#xff…

Java入门基础学习笔记43——包

什么是包&#xff1f; 包是用来分门别类的管理各种不同程序的&#xff0c;类似文件夹&#xff0c;建包有利于程序的管理和维护。 建包的语法规则&#xff1a; package cn.ensource.javabean;public class Car() {} 在自己的程序中调用其他包下的程序的注意事项&#xff1a; 1…

五分钟部署开源运维平台Spug结合内网穿透实现远程登录管理

文章目录 前言1. Docker安装Spug2 . 本地访问测试3. Linux 安装cpolar4. 配置Spug公网访问地址5. 公网远程访问Spug管理界面6. 固定Spug公网地址 前言 Spug 面向中小型企业设计的轻量级无 Agent 的自动化运维平台&#xff0c;整合了主机管理、主机批量执行、主机在线终端、文件…

Web应用防火墙的重要性

网络安全是一个永恒的话题&#xff0c;尤其是在未知威胁不断涌现的情况下。企业网络安全是保障业务稳定运行的基础&#xff0c;Web应用防火墙(WAF)是企业网络安全的重要屏障&#xff0c;其性能直接影响到网络服务的质量和安全。 Web应用防火墙是什么&#xff1f; Web应用防火墙…

java的unsafe

在Java中&#xff0c;sun.misc.Unsafe 是一个强大且危险的类&#xff0c;它提供了一些直接操作内存、对象和线程的底层功能。这个类通常不鼓励普通开发者使用&#xff0c;因为它绕过了Java语言的一些安全性和内存管理机制&#xff0c;可能会导致难以追踪的错误和安全漏洞。 Un…

[Algorithm][动态规划][路径问题][下降路径最小和][最小路径和][地下城游戏]详细讲解

目录 1.下降路径最小和1.题目链接2.算法原理详解3.代码实现 2.最小路径和1.题目链接2.算法原理详解3.代码实现 3.地下城游戏1.题目链接2.算法原理详解3.代码实现 1.下降路径最小和 1.题目链接 下降路径最小和 2.算法原理详解 思路&#xff1a; 确定状态表示 -> dp[i][j]的…

uniapp高校二手书交易商城回收系统 微信小程序python+java+node.js+php

每年因为有大量的学生在接受教育&#xff0c;每到大学毕业季的时候&#xff0c;所使用的大量书籍对他们自己来说&#xff0c;很多是没有用&#xff0c;同时由于书籍多和不方便携带&#xff0c;导致很多大学生在毕业时将教材直接丢弃是在校大学生处理已用教材的一种主要方式。然…

【调试笔记-20240525-Windows-配置 QEMU/x86_64 运行 OpenWrt-23.05 发行版并搭建 WordPress 博客网站】

调试笔记-系列文章目录 调试笔记-20240525-Windows-配置 QEMU/x86_64 运行 OpenWrt-23.05 发行版并搭建 WordPress 博客网站 文章目录 调试笔记-系列文章目录调试笔记-20240525-Windows-配置 QEMU/x86_64 运行 OpenWrt-23.05 发行版并搭建 WordPress 博客网站 前言一、调试环境…

线性规划库PuLP使用教程

Python求解线性规划——PuLP使用教程 简洁是智慧的灵魂&#xff0c;冗长是肤浅的藻饰。——莎士比亚《哈姆雷特》 文章目录 一、说明二、安装 PuLP 库三、线性规划简介3.1 线性规划3.1.1 高考题目描述3.1.2 基本概念 3.2 整数规划3.2.1 题目描述[3]3.2.2 解题思路 四、求解过程…

c++ vector实现出现的一些问题

目录 前言&#xff1a; 浅拷贝问题: typename指定类型&#xff1a; 前言&#xff1a; 最近学习了c vector的使用&#xff0c;然后也自己实现了一下vector的部分重要的功能。然后在其中出现了一些问题&#xff0c;在这就主要记录一下我解决哪些bug。 浅拷贝问题: 在实现res…

8个实用网站和软件,收藏起来一定不后悔~

整理了8个日常生活中经常能用得到的网站和软件&#xff0c;收藏起来一定不会后悔~ 1.ZLibrary zh.zlibrary-be.se/这个网站收录了超千万的书籍和文章资源&#xff0c;国内外的各种电子书资源都可以在这里搜索&#xff0c;98%以上都可以在网站内找到&#xff0c;并且支持免费下…

Py之llama-parse:llama-parse(高效解析和表示文件)的简介、安装和使用方法、案例应用之详细攻略

Py之llama-parse&#xff1a;llama-parse(高效解析和表示文件)的简介、安装和使用方法、案例应用之详细攻略 目录 llama-parse的简介 llama-parse的安装和使用方法 1、安装 2、使用方法 第一步&#xff0c;获取API 密钥 第二步&#xff0c;安装LlamaIndex、LlamaParse L…

React开发环境配置详细讲解-04

React环境 前端随着规范化&#xff0c;可以说规范和环境插件配置满天飞&#xff0c;笔者最早接触的是jquery&#xff0c;那个开发非常简单&#xff0c;只要引入jquery就可以了&#xff0c;当时还写了一套UI框架&#xff0c;至今在做小型项目中还在使用&#xff0c;show一张效果…

Java进阶学习笔记3——static修饰成员方法

成员方法的分类&#xff1a; 类方法&#xff1a;有static修饰的成员方法&#xff0c;属于类&#xff1a; 成员方法&#xff1a;无static修饰的成员方法&#xff0c;属于对象。 Student类&#xff1a; package cn.ensource.d2_staticmethod;public class Student {double scor…