Pytorch完整的模型训练套路

news2025/1/17 3:11:31

Pytorch完整的模型训练套路

文章目录

  • Pytorch完整的模型训练套路
  • 以CIFAR10为例实践

  1. 数据集加载步骤

使用适当的库加载数据集,例如torchvision、TensorFlow的tf.data等。
将数据集分为训练集和测试集,并进行必要的预处理,如归一化、数据增强等。

  1. 模型创建步骤

创建机器学习模型,可以是深度神经网络、传统机器学习模型或其它模型类型。
定义模型架构,包括输入层、隐藏层和输出层的结构、激活函数、损失函数等。

  1. 损失函数和优化器定义步骤

定义适当的损失函数来计算模型预测结果于真实标签之间的差异。
选择适当的优化器算法来更新模型参数,如随机梯度下降(SGD)、Adam等。

  1. 训练循环步骤

从训练集中获取一批样本数据,并将其输入模型进行前向传播。
计算损失函数,并根据损失函数进行反向传播和参数更新。
重复以上步骤,直到达到预定的训练次数或达到收敛条件。

  1. 测试循环步骤

从测试集中获取一批样本数据,并将其输入模型进行前向传播。
计算损失函数或评估指标,用于评估模型在测试集上的性能。

  1. 训练和测试过程的记录和输出步骤

使用适当的工具或库记录训练过程中的损失值、准确率、评估指标等。

  1. 结束训练步骤

根据训练结束条件、例如达到预定的训练次数或收敛条件,结束训练。可以保存模型参数或整个模型,以便日后部署和使用。

以CIFAR10为例实践

并利用tensorboard可视化

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
'''数据集加载'''
train_data = torchvision.datasets.CIFAR10(root='dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root='dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)

# 训练数据集的长度
train_data_size = len(train_data)
print(f"训练数据集的长度为:{train_data_size}")
# 测试数据集的长度
test_data_size = len(test_data)
print(f"测试数据集的长度:{test_data_size}")
#利用DataLoader加载数据集
train_dataloader = DataLoader(test_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
Files already downloaded and verified
Files already downloaded and verified
训练数据集的长度为:50000
测试数据集的长度:10000

‘’‘创建模型’‘’

以上篇文章《Pytorch损失函数、反向传播和优化器、Sequential使用》中的BS()为例

在这里插入图片描述

'''创建模型'''
class BS(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3,
                               out_channels=32,
                               kernel_size=5,
                               stride=1,
                               padding=2),  #stride和padding计算得到
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32,
                                   out_channels=32,
                                   kernel_size=5,
                                   stride=1,
                                   padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32,
                                   out_channels=64,
                                   kernel_size=5,
                                   padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),  #in_features变为64*4*4=1024
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10),
        )
    
    def forward(self,x):
        x = self.model(x)
        return x
    
bs = BS()
print(bs)
BS(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)

一般来说,会将网络单独存放在一个model.py文件当中,然后利用from model import * 进行导入

'''定义损失函数和优化器'''
# 使用交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()  
# 定义优化器
learning_rate = 1e-2  #学习率0.01
optimizer = torch.optim.SGD(bs.parameters(), lr=learning_rate)
"""
训练循环步骤
"""
# 开始设置训练神经网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10



writer = SummaryWriter(".logs") #Tensorboard可视化
for i in range(epoch):
    print("----第{}轮训练开始----".format(i))
    #bs.train() # bs.train()#有batchnorm、dropout层需要调用。官方文档见torch.nn.Module
    '''训练步骤开始'''
    for data in train_dataloader:
        imgs, targets = data
        outputs = bs(imgs)
        loss = loss_fn(outputs, targets)
        
        optimizer.zero_grad() # 首先要梯度清零
        loss.backward() #得到梯度
        optimizer.step() #进行优化

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, loss:{}".format(total_train_step,loss.item()))
            
            writer.add_scalar("train_loss", loss.item(),total_train_step)
            
    '''测试步骤开始'''
    #bs.eval() # bs.train()#有batchnorm、dropout层需要调用。官方文档见torch.nn.Module
    total_test_loss = 0
    #total_accuracy
    total_accuracy = 0
    with torch.no_grad():#torch.no_grad()是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。
        for imgs, targets in test_dataloader:
            outputs = bs(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item() #.item()取出数字
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    """测试过程的记录和输出"""
    print("整体测试集上损失函数loss:{}".format(total_test_loss))
    print("整体测试集上正确率:{}".format(total_accuracy/test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar('test_accuracy',total_accuracy/test_data_size)
    total_test_step = total_test_step + 1
    torch.save(bs, "test_{}.pth".format(i))
    print("模型已保存")
"""
结束训练步骤
"""
writer.close()

利用tensoraboard显示:

tensorboar --logdir logs

在这里插入图片描述

补充.item()

  1. .item()
import torch
a = torch.tensor(5)
print(a)
print(a.item())
tensor(5)
5
  1. model.train()和model.eval()
    官方网址见:torch.nn.Module(*args, **kwargs)
    在这里插入图片描述
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1233107.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PP-PicoDet算法训练行人检测模型

PP-PicoDet算法训练行人检测模型 1,效果图2,PP-PicoDet介绍3,使用飞浆框架训练模型1,准备好图片和对应的标注文件2,划分训练集和验证集3,vi label_list.txt4,目录结构5,修改配置文件…

Ubuntu文件系统损坏:The root filesystem on /dev/sda5 requires a manual fsck

前言 Ubuntu在启动过程中,经常会遇到一些开故障,导致设备无法正常开机,例如文件系统损坏等。 故障描述 Ubuntu系统启动过程中,出现以下文件系统损坏错误: 产生原因 该故障是由磁盘检测不能通过导致,可能是因…

WPS或Excel查找A列中有B列没有的值

就这一行代码: 在C列输入: IF(COUNTIF(B:B,A1)>0,"该行A列中值B列有","该行A列中值B列没有")

达梦数据库安装--注册服务类型错误

最近在学习达梦数据库,安装过程中遇到一点问题,做一下记录。 达梦数据库使用命令行的方式安装,最后一步为了用户管理及控制,需要把数据库服务注册为系统服务,在注册时出现以下错误: 在这我其实犯了一个自以…

Camtasia2024免费版mac电脑录屏软件

作为一个互联网人,没少在录屏软件这个坑里摸爬滚打。培训、学习、游戏、影视解说……都得用它。这时候没个拿得出手的私藏软件,还怎么混?说实话,录屏软件这两年也用了不少,基本功能是有但总觉得缺点什么,直…

CRM系统的客户细分有什么作用?

我们常常说,企业想要开展有针对性的营销活动,就需要进行客户细分。通过特定条件,将客户分为几类,从而对不同类型的客户提供不同的产品和服务。下面我们就针对这里来详细说说,CRM中客户细分是什么?如何细分客…

小程序开发平台源码系统 各行各业都可使用 功能强大 附带完整的搭建教程

当前,数字化转型已经成为各行各业的重要趋势,而小程序作为数字化转型的重要工具之一,具有广泛的应用前景。因此,我们开发了这个源码系统,以帮助各行各业快速开发出符合需求的小程序。 以下是部分代码示例:…

莫斯卡托·达斯蒂葡萄酒是庆祝活动的绝佳饮品首选

在阿斯蒂的山坡上种植莫斯卡托非常艰难,它们需要很长的生长期,在此期间葡萄非常容易受到虫害和疾病的影响,如灰腐病、霉变或浆果蛾。即使他们能在葡萄含糖量达到最佳水平的9月份到达收获季节,他们的产量也往往很低,因此…

Vue框架学习笔记——创建Vue实例、实例与容器对应关系

文章目录 创建Vue实例容器和Vue实例绑定容器中标签体的数据和实例中的数据动态绑定容器和实例一一对应 创建Vue实例 HTML文件中写下述代码&#xff0c;可以消除生产提示&#xff0c;创建Vue实例 <script type"text/javascript">Vue.config.productionTip fal…

【2021集创赛】IEEE杯一等奖:一种28GHz高能效Outphasing PA设计

本作品参与极术社区组织的有奖征集|秀出你的集创赛作品风采,免费电子产品等你拿~活动。 团队介绍 参赛单位&#xff1a;电子科技大学 队伍名称&#xff1a;PA调得队 指导老师&#xff1a;王政 参赛队员&#xff1a;倪梦虎、杨茂旋、张振翼 总决赛奖项&#xff1a;一等奖 1.项…

Dirac‘s BRA and KET notation

from kets to bras expansions the operater matrix elements adjoint of a linear operator Hermitian and Uniraty Operators Hermitian operator defination:

华为防火墙 Radius认证

实现的功能&#xff1a;本地内网用户上网时必须要进行Radius验证&#xff0c;通过后才能上网 前置工作请按这个配置&#xff1a;华为防火墙 DMZ 设置-CSDN博客 Windows 服务器安装 Radius 实现上网认证 拓扑图如下&#xff1a; 一、服务器配置 WinRadius 1、安装WinRadius …

Lightsail VPS 实例在哪些方面胜过 EC2 实例?

文章作者&#xff1a;Libai 引言 Lightsail VPS 实例和 EC2 实例是云计算领域中两种受欢迎的技术。虽然两者都提供虚拟服务器解决方案&#xff0c;但了解 Lightsail VPS 实例在哪些方面胜过 EC2 实例非常重要。在本文中&#xff0c;我们将探讨这两种技术之间的关键区别&#x…

【前端】前端监控⊆埋点

文章目录 前端监控分为三个方面前端监控流程异常监控常见的错误捕获方法主要是 try / catch 、window.onerror 和window.addEventListener 等。Promise 错误Vue 错误React 错误 性能监控用户行为监控常见的埋点方案来源 前端监控分为三个方面 异常监控&#xff08;监控前端页面…

如何选择示波器?

简介 对于很多工程师来讲&#xff0c;从市场中上百款不同价格和规格的各种型号的示波器中&#xff0c;选择一台新示波器是一件很挠首的事情。本文就旨在指引你拨开迷雾&#xff0c;希望能帮助你避免付出昂贵的代价。 重中之重 选择示波器的第一步不是要看那些示波器的广告和规…

Lombok注解式简化开发

Lombok&#xff08;发音为"lombk"&#xff09;是一种Java库&#xff0c;它通过注解的方式来简化Java代码的编写。它提供了一组注解&#xff0c;用于在编译时生成代码&#xff0c;减少了开发人员需要手动编写的样板代码&#xff0c;提高了代码的简洁性和可读性。 Lom…

【三种加载自定义控制器的方式 Objective-C语言】

一、关于这个手动创建Window呢,给大家说完了 1.但是呢,要给大家补充一个东西, 有时候,有的框架,可能会用到什么东西呢,我写到下面: [UIApplication sharedApplication] 什么东西,是不是应用程序对象, 然后呢,keyWindow 是不是拿到它的主窗口, 然后呢,add什么东西…

2013年12月2日 Go生态洞察:Go 1.2的测试覆盖率工具

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

Find My自行车|苹果Find My技术与自行车结合,智能防丢,全球定位

自行车&#xff0c;这项古老而简单的交通工具&#xff0c;近年来在中国经历了一场令人瞩目的复兴。从城市的街头巷尾到乡村的田园小路&#xff0c;自行车成了一种新的生活方式&#xff0c;一个绿色出行的选择。中国的自行车保有量超过两亿辆&#xff0c;但是自行车丢失事件还是…

java算法学习索引之数组矩阵问题

一 将正方形矩阵顺时针转动90 给定一个NN的矩阵matrix&#xff0c;把这个矩阵调整成顺时针转动90后的形式。 顺时针转动90后为&#xff1a; 【要求】额外空间复杂度为O&#xff08;1&#xff09;。 public void rotate(int[][] matrix) {int tR 0; // 左上角行坐标int tC 0;…