pytorch学习笔记(十二)

news2024/12/23 15:04:49

 以下代码是以CIFAR10这个10分类的图片数据集训练过程的完整的代码。

训练部分

train.py主要包含以下几个部件:

  • 准备训练、测试数据集
  • 用DateLoader加载两个数据集,要设置好batchsize
  • 创建网络模型(具体模型在model.py中)
  • 设置损失函数
  • 设置优化器,其中要包含优化的参数和学习率
  • 初始化一些参数,如训练测试的次数、以及训练的轮数epoch
  • 以训练轮数为循环进入训练
  • 从训练数据中加载数据,将数据(模型的输出和目标(标签))送进损失函数中计算损失
  • 梯度清零,并且反向传播损失函数,用优化器进行参数更新,并累计训练步数。
  • 在保证不调优的情况下看正确率(with)
    从测试集中拿数据,一样的讨论算损失,但是要算正确率
  • 用tensorboard可是话训练的结果

关于imgs, targets =data这句代码中的targets解释

  1. imgs (Images): 这个变量通常包含一批图像数据。在计算机视觉任务中,这些图像是模型的输入,可以是任何形式的视觉数据,比如照片、视频帧、医学影像等。在训练过程中,这些图像通过神经网络进行前向传播以生成预测结果。

  2. targets (Targets): 这个变量包含与 imgs 中每个图像对应的标签或目标。标签的具体形式取决于执行的任务:

    • 分类任务中,targets 可能是类别标签,例如识别图像中的对象(猫、狗、汽车等)。
    • 对象检测任务中,targets 可能包括对象的边界框(bounding boxes)和类别。
    • 语义分割任务中,targets 可能是每个像素的类别标签。
    • 回归任务中,targets 可能是一些连续值,如在面部关键点检测中的坐标点。

在训练过程中targets用于损失函数(交叉熵损失、均方误差等),这是模型学习并优化其参数的基础。损失函数衡量了模型预测和真实目标之间的差异,训练目标是最小化差异。

关于optimizer.step()的解释

在机器学习中,这玩意是个关键操作,就是用来根性模型参数的。

优化器和梯度下降,常用的优化算法(SGD、Adam、RMSprop等)来调整网络参数(如权重和偏差),以最小化损失函数。这个过程被称为梯度下降。

训练过程中的步骤:

  • Forward Pass:输入数据进行前向计算,生成预测。
  • 计算损失函数,比较网络的预测和真实计算损失
  • 反向传播:通过反向传播损失,计算每个参数梯度 loss.backward()来完成。
  • 更新参数optimizer.step()被调用来更新网络的参数。根据计算出的梯度和定义的优化算法,它会调整参数以减小损失。

注意: 

optimizer.step()根据优化器预定义的规则和计算出的梯度来更新模型参数。在调用它之后,会执行optimizer.zero_grad(),以便下一次迭代时从干净的状态开始。

import torch.nn
import torchvision
from torch.utils.tensorboard import SummaryWriter

from model import *
from torch.utils.data import DataLoader
from torch import nn

#准备数据集
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                           download=True)
#测试数据集
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
#length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))

#利用DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)


#创建网络模型
tudui = Tudui()

#损失函数
loss_fn = nn.CrossEntropyLoss()

#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

#设置训练网络的一些参数
#记录训练的次数
total_train_step = 0
#测试的次数
total_test_step = 0
#训练的轮数
epoch = 10

#添加 tensorboard
writer = SummaryWriter("../logs_train")

for i in range(epoch):
    print("-------------第{}轮训练开始-------------".format(i+1))

    #训练步骤开始
    #并不需要把网络设置成训练状态才能进行训练
    tudui.train()
    for data in train_dataloader:
        imgs, targets =data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)
        #梯度清零
        #优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_step = total_train_step + 1
        #避免无用信息覆盖
        if total_train_step % 100 == 0:
            print("训练次数: {},loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)
    #测试步骤
    #也不是需要把网络设置成eval状态才能进行网络的一个测试
    tudui.eval()
    total_test_loss = 0
    #看正确率
    total_accuracy = 0
    #在with里面的代码没有了梯度,保证不会进行调优
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets =data
            outputs = tudui(imgs)
            #一部分数据在网络模型上的损失
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy
    print("整体测试集上的Loss:{}".format(total_test_loss))
    print("整体测试的正确率:{}".format(total_accuracy/test_data_size))
    writer.add_scalar("train_loss", loss.item(), total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    #测试的步骤+1否则图画不出来
    total_train_step = total_test_step + 1

    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型已保存")
writer.close()

上面是一个训练过程,下面介绍一下训练准确率怎么得来的。

假设有一个2分类的模型

Model(2分类)

#下面是得分

Outputs = [[0.2,0.3],[0.1,0.4]]

#通过Argmax 变成

Preds = [1]

                [1]

Inputs target=[0][1]

Preds==inputs target

#上面的这个式子返回的就是T or F

#加起来就是分类正确的个数了。

[false,true].sum()=1

                                       

这边注意一下output.argmax(x)的方向,x是0或是1,0的方向是竖着来的,1的方向是横着来的。

import torch
outputs = torch.tensor([[0.1,0.2],
                        [0.3,0.4]])
print(outputs.argmax(1))
preds = outputs.argmax(1)
targets = torch.tensor([0,1])
print((preds == targets).sum())

-----------------------------------------------------未完待续1------------------------------------------------------------- 

 训练的一些细节:

如果有Dropout和BatchNorm等一些特殊层,需要

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1422327.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深入理解G0和G1指令:C++中的实现与激光雕刻应用

系列文章 ⭐深入理解G0和G1指令:C中的实现与激光雕刻应用⭐基于二值化图像转GCode的单向扫描实现⭐基于二值化图像转GCode的双向扫描实现⭐基于二值化图像转GCode的斜向扫描实现基于二值化图像转GCode的螺旋扫描实现基于OpenCV灰度图像转GCode的单向扫描实现基于Op…

Apple Pencil如何连接iPad?这里提供详细步骤

如果你刚拿起一支Apple Pencil,想和iPad一起使用,你需要先连接设备。将Apple Pencil与iPad配对的方法因你拥有的铅笔而异。 一旦你将Apple Pencil连接到iPad,你就可以利用这些方便的功能。你可以记下手写笔记,使用Scribble功能&a…

H5 嵌套iframe并设置全屏

H5 嵌套iframe并设置全屏 上图上代码 <template><view class"mp-large-screen-box"><view class"mp-large-screen-count"><view class"mp-mini-btn-color mp-box-count" hover-class"mp-mini-btn-hover" clic…

QEMU - e1000全虚拟化前端与TAP/TUN后端流程简析

目录 1. Host -> Guest 2.Guest ->Host 3. 如何修改以支持TUN设备的后端&#xff1f; 4. 相关 QEMU 源码 5. 实验 1. Host -> Guest 2.Guest ->Host 3. 如何修改以支持TUN设备的后端&#xff1f; 1. 简单通过后端网卡名字来判断是TUN还是TAP。 2. 需要前端全…

gdp调试—Linux

目录 介绍 使用 介绍 代码分为debug模式和release模式 如果一份代码要被调试&#xff0c;这份代码必须是debug Linux下编译代码默认是是release模式 如果你想代码是debug模式 必须加上 - g 小提&#xff1a; vim默认&#xff1a;命令模式 gcc默认&#xff1a;releas…

华为数通方向HCIP-DataCom H12-831题库(简答题01-27)

第01题 第02题 第03题 第04题 第05题 IS-IS是链路状态路由协议,使用SPF算法进行路由计算。某园区同时部署了IPV4和IPv6并运行IS-IS实现网络的互联与通。如图所示,该网络IPV4和IPV6开销相同,R1和R4只支持IPV4缺省情况下,计算形成的IPV6最短路径树中,R2访问R6的下一跳设备是…

【C++初阶】C++入门(2)

&#x1f525;博客主页&#xff1a; 小羊失眠啦. &#x1f3a5;系列专栏&#xff1a;《C语言》 《数据结构》 《C》 《Linux》 《Cpolar》 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 文章目录 一、函数重载1.1 函数重载的概念1.2 函数重载的种类1.3 C支持函数重载的原理 二…

最全前端 HTML 面试知识点

一、HTML 1.1 HTML 1.1.1 定义 超文本标记语言&#xff08;英语&#xff1a;HyperTextMarkupLanguage&#xff0c;简称&#xff1a;HTML&#xff09;是一种用于创建网页的标准标记语言 HTML元素是构建网站的基石 标记语言&#xff08;markup language &#xff09; 由无数个…

橱窗宝石 - 华为OD统一考试

OD统一考试&#xff08;C卷&#xff09; 分值&#xff1a; 100分 题解&#xff1a; Java / Python / C 题目描述 橱窗里有一排宝石&#xff0c;不同的宝石对应不同的价格&#xff0c;宝石的价格标记为 gems[i],0<i<n, n gems.length 宝石可同时出售0个或多个&#xff…

Mysql+MybatisPlus+Vue实现基础增删改查CRUD

数据库 设计数据库 设计几个字段&#xff0c;主键id自动增长且不可为空 create table if not exists user (id bigint(20) primary key auto_increment comment 主键id,username varchar(255) not null comment 用户名,sex char(1) not null comment 性…

十一、常用API——练习

常用API——练习 练习1 键盘录入&#xff1a;练习2 算法水题&#xff1a;练习3 算法水题&#xff1a;练习4 算法水题&#xff1a;练习5 算法水题&#xff1a; 练习1 键盘录入&#xff1a; 键盘录入一些1~100之间的整数&#xff0c;并添加到集合中。 直到集合中所有数据和超过2…

【新课】安装部署系列Ⅲ—Oracle 19c Data Guard部署之两节点RAC部署实战

本课程由云贝教育-刘峰老师出品&#xff0c;感谢关注 课程介绍 Oracle Real Application Clusters (RAC) 是一种跨多个节点分布数据库的企业级解决方案。它使组织能够通过实现容错和负载平衡来提高可用性和可扩展性&#xff0c;同时提高性能。本课程基于当前主流版本Oracle 1…

017 JavaDoc生成文档

什么是JavaDoc 示例 package se;/*** 用于学习Java* author Admin* version 1.0* since 1.8*/ public class Test20240119 {/*** 主方法* param args*/public static void main(String[] args) {// 你好&#xff0c;世界System.out.println("Hello world");} } 写一…

故障诊断 | 一文解决,GRU门控循环单元故障诊断(Matlab)

文章目录 效果一览文章概述专栏介绍模型描述源码设计参考资料效果一览 文章概述 故障诊断 | 一文解决,GRU门控循环单元故障诊断(Matlab) 专栏介绍 订阅【故障诊断】专栏,不定期更新机器学习和深度学习在故障诊断中的应用;订阅

opencv#40 图像细化

图像细化原理 作用&#xff1a;图像细化是将图像的线条从多像素宽度减少到单位像素宽度的过程&#xff0c;又被称为“骨架化”&#xff0c;删除像素点的标准&#xff1a; 通常情况下&#xff0c;我们使用二值化图像&#xff0c;我们在判断是否要删除某些像素点时&#xff0c;要…

遍历删除空文件夹

文章目录 遍历删除空文件夹概述笔记END 遍历删除空文件夹 概述 在手工整理openssl3.2编译完的源码工程中的文档, 其中有好多空文件夹. 做了一个小工具, 将空文件夹都遍历删除掉. 这样找文档方便一些. 删除后比对了一下, 空文件夹还真不少. 笔记 // EmptyDirRemove.cpp : 此…

音视频数字化(数字与模拟-音频广播)

在互联网飞速发展的今天,每晚能坐在电视机前面的人越来越少,但是每天收听广播仍旧是很多人的习惯。 从1906年美国费森登在实验室首次进行无线电广播算起,“广播”系统已经陪伴人们115年了。1916年,收音机开始上市,收音机核心是“矿石”。1920年开始“调幅”广播,1941年开…

Uniapp小程序端打包优化实践

背景描述&#xff1a; 在我们最近开发的一款基于uniapp的小程序项目中&#xff0c;随着功能的不断丰富和完善&#xff0c;发现小程序包体积逐渐增大&#xff0c;加载速度也受到了明显影响。为了提升用户体验&#xff0c;团队决定对小程序进行一系列打包优化。 项目优化点&…

Optimism的挑战期

1. 引言 前序博客&#xff1a; Optimism的Fault proof 用户将资产从OP主网转移到以太坊主网时需要等待一周的时间。这段时间称为挑战期&#xff0c;有助于保护 OP 主网上存储的资产。 而OP测试网的挑战期仅为60秒&#xff0c;以简化开发过程。 2. OP与L1数据交互 L1&#xf…

无人机除冰保障电网稳定运行

无人机除冰保障电网稳定运行 近日&#xff0c;受低温雨雪冰冻天气影响&#xff0c;福鼎市多条输配电线路出现不同程度覆冰。 为保障福鼎电网安全可靠运行&#xff0c;供电所员工运用无人机飞行技术&#xff0c;通过在无人机下方悬挂器具&#xff0c;将无人机飞到10千伏青坑线…