模型的深度优化

news2024/11/28 6:46:26

文章目录

    • 一、测试模型是否正确
    • 二、图形打印直观观察
    • 三、保存训练模型
    • 四、正确率(仅使用于分类问题)

一、测试模型是否正确

本文承接我的上一篇文章完整网络模型训练(一)
运用测试数据集(test_dataloader)进行检验

total_test_loss = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = sen(imgs)
            #这里的loss是一个tensor数据类型
            loss = loss_fn(outputs, targets)
            #加上item进行转换成整形
            total_test_loss = total_test_loss + loss.item()
    print(f"整体测试集上的Loss{total_test_loss}")

注释:
with torch.no_grad()使用 no_grad() 上下文,禁用梯度计算,以节省内存和加快计算速度(因为在测试阶段不需要反向传播)

在pycharm的运行框中按下crtl+f可以弹出搜索框
运行结果:
在这里插入图片描述
可以看到确实有整体数据集的信息,但是被很多乱七八遭的信息给掩盖了,所以可以改善一下代码

 if total_train_step % 100 == 0:
            print(f"训练次数:{total_train_step},Loss:{loss.item()}")

每当步骤为整百的时候才进行打印,这样子输出结果能够整洁一点:
在这里插入图片描述

二、图形打印直观观察

也可以用tensorboard进行画图观察,补上一下代码:

writer = SummaryWriter("./logs_train")
 if total_train_step % 100 == 0:
            print(f"训练次数:{total_train_step},Loss:{loss.item()}")
            writer.add_scalar("train_loss", loss.item(), total_train_step)
print(f"整体测试集上的Loss{total_test_loss}")
    writer.add_scalar("test_loss", total_test_loss,total_test_step)
    total_test_step = total_test_step + 1

运行结果:
在这里插入图片描述
在这里插入图片描述
可以直观的看到训练的模型的loss损失函数在不断的下降。

三、保存训练模型

还可以保存每一轮训练的模型:

torch.save(sen, f"sen_{i}.pth")
    print("模型已保存")

使用 torch.save() 将模型 ‘sen’ 保存到文件中,文件名为 “sen_{i}.pth”

四、正确率(仅使用于分类问题)

哪怕我们已经得到整体测试数据集上的Loss,也不能很好的说明数据集实际上的表现效果
在分类问题中我们可以用正确率进行表示

    total_accuracy = 0
accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy

使用 writer.add_scalar() 方法,将标量数据记录到日志文件中,用于在 TensorBoard 中可视化
“test_accuracy” 是图表的名称,用于表示测试集的准确率
total_accuracy/test_data_size 表示计算出的测试集上的准确率(总正确预测数 / 测试集数据总量)
total_test_step 表示当前测试步骤,用作 X 轴上的标记,表明该数据是在第几次测试时记录的

    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)

运行结果:
在这里插入图片描述
可以看到计算处理整体数据集上的正确率,为32%

有时我们能看到有些代码会写上以下两段代码:

在训练模型开始的时候

sen.train()

例如:

	sen.train()
    for data in train_dataloader:
        imgs, targets = data
        outputs = sen(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

在测试步骤开始的时候

sen.eval()

例如:

 	sen.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = sen(imgs)
            #这里的loss是一个tensor数据类型
            loss = loss_fn(outputs, targets)
            #加上item进行转换成整形
            total_test_loss = total_test_loss + loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy

这些只对特定模型起作用,例如dropout,不过一般都写上,反正无影响。

训练模型总体过程
准备数据->加载数据->准备模型->设置损失函数->设置优化器->开始训练->最后验证->结果聚合展示

整体代码展示:

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

#创建网络模型
class Sen(nn.Module):
    def __init__(self):
        super(Sen, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1 ,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )
    def forward(self,x):
        x = self.model(x)
        return x
#准备数据集
#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)

#length长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")

#利用DataLoader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

#创建网络模型
sen = Sen()

#损失函数
loss_fn = nn.CrossEntropyLoss()

#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)

#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epoch= 10

#添加tensorboard
writer = SummaryWriter("./logs_train")

for i in range(epoch):
    print(f"-------第{i+1}轮训练开始-------")

    #训练步骤开始
    sen.train()
    for data in train_dataloader:
        imgs, targets = data
        outputs = sen(imgs)
        loss = loss_fn(outputs, targets)

        #优化器模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print(f"训练次数:{total_train_step},Loss:{loss.item()}")
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    #测试步骤开始
    sen.eval()
    total_test_loss = 0
    total_accuracy = 0
    #不需要梯度进行调整或者优化
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = sen(imgs)
            #这里的loss是一个tensor数据类型
            loss = loss_fn(outputs, targets)
            #加上item进行转换成整形
            total_test_loss = total_test_loss + loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy

    print(f"整体测试集上的Loss{total_test_loss}")
    print(f"整体数据集上的正确率;{total_accuracy/test_data_size}")
    writer.add_scalar("test_loss", total_test_loss,total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step = total_test_step + 1


writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2186824.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第二十一章 (动态内存管理)

1. 为什么要有动态内存分配 2. malloc和free 3. calloc和realloc 4. 常⻅的动态内存的错误 5. 动态内存经典笔试题分析 6. 总结C/C中程序内存区域划分 1.为什么要有动态内存管理 我们目前已经掌握的内存开辟方式有 int main() {int num 0; //开辟4个字节int arr[10] …

全局安装cnpm并设置其使用淘宝镜像的仓库地址(地址最新版)

npm、cnpm和pnpm基本概念 首先介绍一下npm和cnpm是什么,顺便说一下pnpm。 npm npm(Node Package Manager)是Node.js的默认包管理器,用于安装、管理和分享JavaScript代码包。它是全球最大的开源库生态系统之一,提供了数…

centos环境安装JDK详细教程

centos环境安装JDK详细教程 一、前期准备二、JDK安装2.1 rpm方式安装JDK2.2 zip方式安装JDK2.3 yum方式安装JDK 本文主要说明CentOS下JDK的安装过程。JDK的安装有三种方式,用户可根据实际情况选择: 一、前期准备 查看服务器操作系统型号,执…

YOLOv7改进之主干DAMOYOLO结构,结合 CReToNeXt 结构,打造高性能检测器

一、DAMOYOLO理论部分 论文地址:2211.15444 (arxiv.org) 在本报告中,我们提出了一种快速准确的对象检测方法,称为 DAMO-YOLO,它实现了比最先进的 YOLO 系列更高的性能。DAMO-YOLO 是从 YOLO 扩展而来的,具有一些新技术,包括神经架构搜索 (NAS)、高效的重新参数化广义 …

【Linux】用虚拟机配置Ubuntu 24.04.1 LTS环境

目录 1.虚拟机安装Ubuntu系统 2.Ubuntu系统的网络配置 3.特别声明 首先我们先要下载VMware软件,大家自己去下啊! 1.虚拟机安装Ubuntu系统 我们进去之后点击创建新的虚拟机,然后选择自定义 接着点下一步 再点下一步 进入这个界面之后&…

[C++]使用纯opencv部署yolov11目标检测onnx模型

yolov11官方框架:https://github.com/ultralytics/ultralytics 【算法介绍】 在C中使用纯OpenCV部署YOLOv11进行目标检测是一项具有挑战性的任务,因为YOLOv11通常是用PyTorch等深度学习框架实现的,而OpenCV本身并不直接支持加载和运行PyTor…

Python+Matplotlib展示单射、双射、满射

import matplotlib.pyplot as pltplt.rcParams[font.sans-serif] [SimHei] # 用来正常显示中文标签 plt.rcParams[axes.unicode_minus] False # 用来正常显示负号def create_function_plot(ax, title, x, y, connections):ax.set_xlim(0, 5)ax.set_ylim(0, 5)ax.set_aspect…

UE4 材质学习笔记01(什么是着色器/PBR基础)

1.什么是shader 着色器是控制屏幕上每个像素颜色的代码,这些代码通常在图形处理器上运行。 现如今游戏引擎使用先进的基于物理的渲染和照明。而且照明模型模型大多数是被锁定的。 因此我们创建着色器可以控制颜色,法线,粗糙度,…

十一假期不停歇-学习ROS第二天

一、自动修改环境变量 source install/setup.bash 这句指令可不重复添加环境变量,直接运行就可以。 从左右图即可即可看出有所不同的。以后写完功能包发现不存在那就source一下就好了。 实际上它运行的是 python3.10那里的功能包与节点 我们在创建完功能包会出…

python-矩阵转置/将列表分割成块/和超过N的最短子数组

一:矩阵转置 题目描述 输入一个 n 行 m 列的矩阵 A,输出它的转置 AT。输入 第一行包含两个整数 n 和 m,表示矩阵 A 的行数和列数。1≤n≤100,1≤m≤100。接下来 n 行,每行 m 个整数,表示矩阵 A 的元素。相邻…

别再为开题报告发愁了,ChatGPT可以助力你高效完成!

学境思源,一键生成论文初稿: AcademicIdeas - 学境思源AI论文写作 面对开题报告时,我们常常感到迷茫和压力。如何清晰地表达研究目的、梳理文献综述、明确研究问题与假设,以及构建合理的研究方法,这些环节都对开题报告…

OpenCV-图像拼接

文章目录 一、基本原理二、步骤三、代码实现1.定义函数2.读取图像3.图像配准(1).特征点检测(2).特征匹配 4.透视变换5.图像拼接 四、图像拼接的注意事项 图像拼接是一种将多张有重叠部分的图像合并成一张无缝的全景图或高分辨率图…

蓝桥杯【物联网】零基础到国奖之路:十七. 扩展模块之单路ADC和NE555

蓝桥杯【物联网】零基础到国奖之路:十七. 扩展模块之单路ADC和NE555 第一节 硬件解读第二节 CubeMx配置第三节 代码1,脉冲部分代码2,ADC部分代码![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/57531a4ee76d46daa227ae0a52993191.png) 第一节 …

基于工业物联网的能源监控系统:边缘数据处理的应用

论文标题:《Industrial IoT-Based Energy Monitoring System: Using Data Processing at Edge》 作者信息: Akseer Ali MiraniAnshul AwasthiNiall O’MahonyJoseph Walsh 他们均来自爱尔兰的芒斯特技术大学IMaR研究中心,以及位于利默里克的…

秋招校招北森笔试测评北森笔测评常见图推题目解答

北森笔试测评常见题目解析第二弹: P1:每一行均出现圆,米,口加上内部的菱形和小圆形,行测上称之为遍历。 P2:题干也给提示了,肯定不选9,考虑封闭部分,选11. P3&#xf…

面试速通宝典——7

150. 数据库连接池的作用 数据库连接池的作用包括以下几个方面: 资源重用:连接池允许多个客户端共享有限的数据库连接,减少频繁创建和销毁连接的开销,从而提高资源的利用率。 统一的连接管理:连接池集中管理数据库连…

访问webapps下边的内容不能访问解决办法

1、看是否是带有中文路径,中文路径访问不了 2、是否在访问的时候没有带8080端口 如图,带有8080就可以访问了 原因: 当你尝试通过 http://localhost:8080/hello/hello.html 访问网页时,端口号 “8080” 指定了该请求将发送到本地主…

STM32-MPU6050+DAM库源码(江协笔记)

目录 1、MPU6050简介 2、MPU6050参数 3、MPU6050硬件电路 4、MPU6050结构 5、MPU6000和MPU6050的区别 6、MPU6050应用场景 7、MPU6050电气参数 8、MPU6050时钟源选择 9、MPU6050中断源 10、MPU6050的I2C读写操作 11、DMP库移植 1、MPU6050简介 10轴传感器&#xff1…

使用CSS实现酷炫加载

使用CSS实现酷炫加载 效果展示 整体页面布局 <div class"container"></div>使用JavaScript添加loading加载动画的元素 document.addEventListener("DOMContentLoaded", () > {let container document.querySelector(".container&q…

Unity初识+面板介绍

Unity版本使用 小版本号高&#xff0c;出现bug可能性更小&#xff1b;一台电脑可以安装多个版本的Unity&#xff0c;但是需要安装在不同路径&#xff1b;安装Unity时不能有中文路径&#xff1b;Unity项目路径也不要有中文。 Scene面板 相当于拍电影的片场&#xff0c;Unity程…