pytorch完整模型训练套路

news2024/12/23 6:07:26

文章目录

  • CIFAR10数据集简介
  • 训练模型套路
    • 1、准备数据集
    • 2、加载数据集
    • 3、搭建神经网络
    • 4、创建网络模型、定义损失函数、优化器
    • 5、训练网络
    • 6、测试数据集
    • 7、添加tensorboard
    • 8、转化为正确率
    • 9、保存模型
  • 完整代码

本文以 CIFAR10数据集为例,介绍一个完整的模型训练套路。

CIFAR10数据集简介

CIFAR-10数据集包含60000张32x32彩色图像,分为10个类,每类6000张。有50000张训练图片和10000张测试图片。

数据集分为五个训练batches和一个测试batch,每个batch有10000张图像。测试batch包含从每个类中随机选择的1000个图像。训练batches以随机顺序包含剩余的图像,但有些训练batches可能包含一个类的图像多于另一个类的图像。在它们之间,训练batches包含来自每个类的5000张图像。

下面是数据集中的类,以及每个类的10张随机图片:
在这里插入图片描述

一共包含10 个类别的RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。

训练模型套路

1、准备数据集

# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="./source", train=True, transform=torchvision.transforms.ToTensor(), download=True)

test_data = torchvision.datasets.CIFAR10(root="./source", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为:{train_data_size}")
print(f"测试数据集的长度为:{test_data_size}")

在这里插入图片描述

2、加载数据集

# 利用 DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

3、搭建神经网络

我们准备搭建一个这样的网络模型结构:

在这里插入图片描述

# 搭建神经网络
class Aniu(nn.Module):
    def __init__(self):
        super(Aniu, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x
if __name__ == '__main__':
    aniu = Aniu()
    input = torch.ones((64, 3, 32, 32))
    output = aniu(input)
    print(output.shape)

我们在一个新的文件下搭建并简单测试神经网络。

4、创建网络模型、定义损失函数、优化器

# 创建网络模型
aniu = Aniu()
# 损失函数
loss_fn = nn.CrossEntropyLoss()
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(aniu.parameters(), lr=learning_rate)

5、训练网络

# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

for i in range(epoch):
    print(f"----------第{i+1}轮训练开始-----------")

    # 训练开始
    for data in train_dataloader:
        imgs, targets = data
        output = aniu(imgs)
        loss = loss_fn(output, targets)

        # 优化器优化模型
        optimizer.zero_grad() # 优化器梯度清零
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        print(f"训练次数:{total_train_step},loss:{loss.item()}") # .item()可以将tensor数据类型转化

6、测试数据集

我们可以通过with torch.mo_grad():来测试

for i in range(epoch):
    print(f"----------第{i+1}轮训练开始-----------")

    # 训练开始
    for data in train_dataloader:
        imgs, targets = data
        output = aniu(imgs)
        loss = loss_fn(output, targets)

        # 优化器优化模型
        optimizer.zero_grad() # 优化器梯度清零
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print(f"训练次数:{total_train_step},loss:{loss.item()}") # .item()可以将tensor数据类型转化

    # 测试步骤开始
    total_test_loss = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            output = aniu(imgs)
            loss = loss_fn(output, targets)
            total_test_loss = total_test_loss + loss.item()
    print(f"整体测试集上的Loss:{total_test_loss}")

在这里插入图片描述

7、添加tensorboard

我们在以上的代码基础上添加tensorboard,并通过tensorboard画图进行观察:

# 添加tensorboard
writer = SummaryWriter("./log_train")

for i in range(epoch):
    print(f"----------第{i+1}轮训练开始-----------")

    # 训练开始
    for data in train_dataloader:
        imgs, targets = data
        output = aniu(imgs)
        loss = loss_fn(output, targets)

        # 优化器优化模型
        optimizer.zero_grad() # 优化器梯度清零
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print(f"训练次数:{total_train_step},loss:{loss.item()}") # .item()可以将tensor数据类型转化
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试步骤开始
    total_test_loss = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            output = aniu(imgs)
            loss = loss_fn(output, targets)
            total_test_loss = total_test_loss + loss.item()
    print(f"整体测试集上的Loss:{total_test_loss}")
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    total_test_step = total_test_step + 1

writer.close()

运行并在终端输入:

tensorboard --logdir="log_train"

可以观察到图像:

在这里插入图片描述

8、转化为正确率

添加一段代码,算出测试集上的正确率:

# 整体正确的个数
total_accuracy = 0

with torch.no_grad():
    for data in test_dataloader:
        imgs, targets = data
        output = aniu(imgs)
        loss = loss_fn(output, targets)
        total_test_loss = total_test_loss + loss.item()
        accuracy = (output.argmax(1) == targets).sum()
        total_accuracy = total_accuracy + accuracy
print(f"整体测试集上的Loss:{total_test_loss}")
print(f"整体测试集上的正确率:{total_accuracy/test_data_size}")
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
total_test_step = total_test_step + 1

9、保存模型

每轮保存一下模型:

torch.save(aniu, f"aniu_{i}.pth")
print("模型已保存")

完整代码

train.py文件:

import torch.cuda
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import *
# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="./source", train=True,
                                          transform=torchvision.transforms.ToTensor(), download=True)

test_data = torchvision.datasets.CIFAR10(root="./source", train=False,
                                          transform=torchvision.transforms.ToTensor(), download=True)

# length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为:{train_data_size}")
print(f"测试数据集的长度为:{test_data_size}")

# 利用 DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 创建网络模型 搭建神经网络
class Aniu(nn.Module):
    def __init__(self):
        super(Aniu, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

aniu = Aniu()
# if torch.cuda.is_available():
#     aniu = aniu.cuda()


# 损失函数
loss_fn = nn.CrossEntropyLoss()
# if torch.cuda.is_available():
#     loss_fn = loss_fn.cuda()
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(aniu.parameters(), lr=learning_rate)

# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10


# 添加tensorboard
writer = SummaryWriter("./log_train")

for i in range(epoch):
    print(f"----------第{i+1}轮训练开始-----------")

    # 训练开始
    aniu.train()
    for data in train_dataloader:
        imgs, targets = data
        # if torch.cuda.is_available():
        #     imgs = imgs.cuda()
        #     targets = targets.cuda()
        output = aniu(imgs)
        loss = loss_fn(output, targets)

        # 优化器优化模型
        optimizer.zero_grad() # 优化器梯度清零
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print(f"训练次数:{total_train_step},loss:{loss.item()}") # .item()可以将tensor数据类型转化
            writer.add_scalar("train_loss", loss.item(), total_train_step)


    # 测试步骤开始
    aniu.eval()
    total_test_loss = 0
    # 整体正确的个数
    total_accuracy = 0

    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            # if torch.cuda.is_available():
            #     imgs = imgs.cuda()
            #     targets = targets.cuda()
            output = aniu(imgs)
            loss = loss_fn(output, targets)
            total_test_loss = total_test_loss + loss.item()
            accuracy = (output.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy
    print(f"整体测试集上的Loss:{total_test_loss}")
    print(f"整体测试集上的正确率:{total_accuracy/test_data_size}")
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step = total_test_step + 1

    # torch.save(aniu.state_dict(), f"aniu_{}.pth") 官方推荐保存方式
    torch.save(aniu, f"aniu_{i}.pth")
    print("模型已保存")

writer.close()

model.py:

import torch
from torch import nn
# 搭建神经网络
class Aniu(nn.Module):
    def __init__(self):
        super(Aniu, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

if __name__ == '__main__':
    aniu = Aniu()
    input = torch.ones((64, 3, 32, 32))
    output = aniu(input)
    print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/560490.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习-线性代数-向量、基底及向量空间

概述 文章目录 概述向量理解向量运算 基底与向量的坐标表示基底与向量的深入基底与向量选取与表示基底的特殊性张成空间 向量 理解 直观理解 行向量:把数字排成一行A [ 4 5 ] [4~ 5] [4 5]列向量:把数字排成一列A [ 4 5 ] \ \left [ \begin{matrix}…

多线性开发实例分享

一. 概述 首先,在这里有必要和大家复现一下我使用该技术的背景: 在使用若依框架的时候,由于实际开发的需要,我需要配置四个数据源,并且通过mapper轮流去查每个库的指定用户数据,从而去判断改库是否存在目标…

构建一个简易数据库-用C语言从头写一个sqlite的克隆 0.前言

英文源地址 一个数据库是如何工作的? 数据是以什么格式存储的(在内存以及在磁盘)?何时从内存中转移到此磁盘上?为什么每张表只能有一个主键?回滚一个事务是如何工作的?索引是以什么格式组织的?什么时候会发生全表扫描, 以及它是如何进行的?准备好的语句是以什么格式保…

#C2#S2.2~S2.3# 加入 factory/objection/virtual interface 机制

2.2 加入factory 机制 factory机制的实现被集成在了一个宏中:uvm_component_utils。这个宏所做的事情非常多,其中之一就是将my_driver登记在 UVM内部的一张表中,这张表是factory功能实现的基础。只要在定义一个新的类时使用这个宏&#xff0…

斐波那契数列相关简化4

看这篇文章前需要看下前面三篇文章,最起码第一第二篇是需要看一下的 斐波那契数列数列相关简化1_鱼跃鹰飞的博客-CSDN博客 斐波那契数列数列相关简化2_鱼跃鹰飞的博客-CSDN博客 算法玩的就是套路,练练就熟悉了 再来一个: 用1*2的瓷砖&am…

如何在 CentOS Linux 上安装和配置 DRBD?实现高可用性和数据冗余

DRBD(Distributed Replicated Block Device)是一种用于实现高可用性和数据冗余的开源技术。它允许在不同的服务器之间实时同步数据,以提供数据的冗余和容错能力。本文将详细介绍如何在 CentOS Linux 上安装和配置 DRBD。 1. 确认系统要求 在…

一文带你了解MySQL之InnoDB统计数据是如何收集的

前言 本文章收录在MySQL性能优化原理实战专栏,点击此处查看更多优质内容。 我们前边唠叨查询成本的时候经常用到一些统计数据,比如通过show table status可以看到关于表的统计数据,通过show index可以看到关于索引的统计数据,那…

MySQL之事务初步

0. 数据源 /*Navicat Premium Data TransferSource Server : localhost_3306Source Server Type : MySQLSource Server Version : 80016Source Host : localhost:3306Source Schema : tempdbTarget Server Type : MySQLTarget Server Version…

在线OJ常用输入规则

一、字符串输入规则 1.1 单行无空格字符串输入 输入连续字符串,cin默认空格/换行符为分割标志。 string s; //输入连续字符串,cin默认空格/换行符为分割标志。 cin >> s; 1.2 单行有空格字符串输入 getline函数接受带有空格的输入流&#xff…

C++——初识模板

文章目录 总述为什么要有模板函数模板概念函数模板使用方法函数模板的原理函数模板的实例化隐式示例化显式实例化 模板参数的匹配规则 类模板类模板的实例化 总述 本篇文章将带大家简单的了解一下c的模板方面的知识,带大家认识什么是模板,模板的作用&…

STL-常用算法(一.遍历 查找 排序)

目录 常用遍历算法: for_each和transform函数示例: 常用查找算法: find函数示例: find_if函数示例: adjacent_find示例: binary_search函数示例: count函数示例: count_if函…

训练/测试、过拟合问题

在机器学习中,我们创建模型来预测某些事件的结果,比如之前使用重量和发动机排量,预测了汽车的二氧化碳排放量 要衡量模型是否足够好,我们可以使用一种称为训练/测试的方法 训练/测试是一种测量模型准确性的方法 之所以称为训练…

springmvc升级到springboot2踩的坑

声明:删除springmvc的jar配置改成springboot的,若别的组件依赖springboot该升级就升级,该删掉就删掉,此文章只记录升级后的坑,升级springboot所需的jar请自行百度。 一.Hibernate的坑 概念:jpa和Hibernate的关系,jpa…

【JAVAEE】网络编程的简单介绍及其实现

目录 1.什么是网络编程 网络编程中的基本概念 常见的客户端服务端模型 2.Socket套接字 Socket套接字分类 举例对比TCP和UDP 3.UDP数据报套接字编程 DatagramSocket API DatagramPacket API InetSocketAddress API 4.实现一个简单的UDP回显服务器与客户端 服务端与客…

当前最新免费使用GPT-4方法汇总

目录 前言 温馨提示 Ora AI 使用方式 使用测试 Forefont chat 使用方式 使用测试 Perplexity AI 使用方式 使用测试 Poe 总结 前言 目前GPT-4的收费对于大多数人而言都还是不便宜,且付费方式复杂,使用上还有每3小时25个问题的限制&#xff…

Aspose.OCR For NET 23.5 Crack

使用几行代码将光学字符识别 (OCR) 添加到您的 .NET 应用程序。 适用于 .NET 的 Aspose.OCRAspose.OCR 文档 Aspose.OCR for .NET 是一个功能强大但易于使用且具有成本效益的光学字符识别 API。有了它,您可以用不到 5 行代码将 OCR 功能添加到您的 .NET 应用程序…

【Linux】初识优雅的Linux编辑器——Vim

❤️前言 大家好!今天给大家带来的博客内容是关于Linux操作系统下的一款多模式文本编辑器Vim。本文将和大家一起来了解Vim编辑器的一些基础知识。 正文 Vim是一个多模式的文本编辑器(一共有十二种模式),其中我们当我们初学Vim时主要了解如下三种工作模式…

Linux——多线程(线程概念|进程与线程|线程控制)

目录 地址空间和页表 如何看待地址空间和页表 虚拟地址如何转化到物理地址的 线程与进程的关系 什么叫进程? 什么叫线程? 如何看待我们之前学习进程时,对应的进程概念呢?和今天的冲突吗? windows线程与linux线…

Leetcode665. 非递减数列

Every day a Leetcode 题目来源:665. 非递减数列 解法1:贪心 本题是要维持一个非递减的数列,所以遇到递减的情况时(nums[i] > nums[i 1]),要么将前面的元素缩小,要么将后面的元素放大。 …

K8s in Action 阅读笔记——【2】First steps with Docker and Kubernetes

K8s in Action 阅读笔记——【2】First steps with Docker and Kubernetes 2.1 Creating, running, and sharing a container image 2.1.1 Installing Docker and running a Hello World container 在电脑上安装好Docker环境后,执行如下命令, $ dock…