模型训练套路(一)

news2025/1/8 4:42:36

一、训练完整使用网络模型

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from model1 import* # 此处的引用为此文在实现过程中所解决的问题

train_data = torchvision.datasets.CIFAR10(root = "../data", train=True,
                                          transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root = "../data", train=False,
                                          transform=torchvision.transforms.ToTensor(),download=True)
# 查看数据集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
# 格式化 # 格式化注意的是,之间是.的连接
print("训练数据集的长度为: {}".format(train_data_size))
print("测试数据集的长度为: {}".format(test_data_size))

# 利用dataloade r加载数据集#加载数据集的参数设置
train_dataloader = DataLoader (train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 创建网络模型
sun = SUN()

# 损失函数 交叉熵函数的使用
loss_fn = nn.CrossEntropyLoss()

# 优化器(SGD随机梯度下降)
learning_rate = 0.01
optimizer = torch.optim.SGD(sun.parameters(), lr = learning_rate)

# 设置网络训练的参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

for i in range(epoch):
    print("-------第{}轮训练开始--------".format(i+1))
    # 训练网络模型,从训练的data中取数据
    # 训练步骤开始
    for data in train_dataloader:
        imgs, targets = data
        outputs = sun(imgs)
        # 将得到的输出与真实的target比较,得到误差
        loss =loss_fn(outputs, targets)

        # 优化器优化模型
        # 进行优化,首先是梯度清零
        optimizer.zero_grad()
        # 得到每个节点的梯度
       loss.backward()
        # 对其中的参数进行优化
       optimizer.step()

        total_test_step = total_test_step + 1
        print("训练次数:{}, Loss: {}".format(total_test_step, loss.item()))

二、调用的神经网络模型

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential


class SUN(nn.Module):
    def __init__(self):
        super(SUN, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x =self.model(x)
        return x



if __name__ == '__main__':
    sun = SUN()
    input = torch.ones((64, 3, 32,32))
    output = sun(input)
    print(output.shape)

三、调用python文件

在调用的python文件时,会出现一些问题:

from model1 import*

使用该语句调用,但是model1会画红色波浪线报错,并且,引用的神经网络也会出现报错,原因就是,未正确引用py文件。

尝试的解决办法:使用.model1,这种办法不可取后;

使用标记目录仍未成功;

最终,神经网络的py文件与训练的该文件在同一目录下,将被引用的Py文件,放在需引用文件的上一级目录下。也就是说,被引用文件在需引用文件的上一级。

# 接套路一代码:

# 如何知道数据训练好了没有
# 利用现有模型进行测试
# 在测试数据集上走一遍,以测试数据集的损失,来判定模型训练好了没有
# 测试过程中不需要在对模型进行调优
# 测试步骤开始
    total_test_loss = 0
    with torch.no_grad(): # 将参数梯度调零
        for data in test_dataloader:
            imgs, targets = data
            outputs = sun(imgs)
            loss =loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()

        print("整体测试集上的Loss:{}".format(total_test_loss))

        writer.add_scalar("test_loss",total_test_loss, total_train_step)
        total_test_step +=1

                # 对模型的保存

        torch.save(sun, "sun_{}.path".format(i))
        print("模型已保存")

writer.close()

在tensorboard上显示:

经过10轮的训练,测试集与训练集的损失值变化。

输出(outputs)与最终的预测(predicts)之间的转变,使用函数Argmax,就能够求出横向的最大值所在的位置。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2112922.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

63、Python之函数高级:装饰器缓存实战,优化递归函数的性能

引言 通过前面的文章,我们已经掌握了Python中常用的装饰器的使用技巧,这篇文章中,我们通过一个装饰器的实战案例,来进一步加深对装饰器的适用场景的理解。 本文的主要内容有: 1、递归函数 2、递归实现斐波那契数列…

AWTK HTML View 控件更新

AWTK HTML View 控件基于 Lite HTML 实现,从最初的版本开始,3 年多过去了,Lite HTML 做了大量的更新,最近抽空将 AWTK HTML View 控件适配到最新版本的 Lite HTML,欢迎大家使用。 AWTK HTML View 控件。HTML View 控件…

SAP B1 基础实操 - 用户定义字段 (UDF)

目录 一、功能介绍 1. 使用场景 2. 操作逻辑 3. 常用定义部分 3.1 主数据 3.2 营销单据 4. 字段设置表单 4.1 字段基础信息 4.2 不同类详细设置 4.3 默认值/必填 二、案例 1 要求 2 操作步骤 一、功能介绍 1. 使用场景 在实施过程中,经常会碰见用户需…

Qt线程使用

嗨嗨嗨,今天又学到了新的知识——线程,这个玩意在项目中使用的频率是非常高的,毕竟电脑的主线程就那么一个,那么这也就是我们为啥要学习线程的原因。比如说,我们们的游戏,如果我们的游戏界面显示的同时我们…

【生日视频制作】奥迪A8提车交车仪式AE模板修改文字软件生成器教程特效素材【AE模板】

奥迪A8提车交车仪式AE模板制作过程软件生成器素材 AE模板套用改图文教程↓↓: 怎么如何做的【生日视频制作】奥迪A8提车交车仪式AE模板修改文字软件生成器教程特效素材【AE模板】 生日视频制作步骤: 安装AE软件 下载AE模板 把AE模板导入AE软件 修改图片…

PD快充协议方案 及应用场景

快充协议诱骗原理主要依赖于快充协议芯片与供电端(如PD充电器)之间的握手通信,以申请所需要的电压与电流,确保充电过程安全、快速且高效。这种芯片通过内置的通讯模块与供电端通信,根据设备的实际需求调整输出电压和电…

大路灯护眼灯有必要吗安全吗?性价比高落地护眼灯推荐

大路灯护眼灯有必要吗安全吗?近几年来,随着生活节奏的加快,目前青少年的近视率呈现一个直线上升的趋势,其中占比达到了70%以上,并且最令人意外的是小学生竟然也占着比较大的比重,这一系列的数据不仅表明着近…

苍穹外卖学习笔记(一)

文章目录 开发环境搭建一. 前端环境搭建二. 后端环境搭建1.进入idea项目2.提交git仓库(推送github远程仓库)3.数据库环境搭建4.前后端联调(在源代码中项目已经实现登录功能)nginx反向代理好处: 三. 完善登录功能(md5加密存储)1.首先打开pojo模块中实体类的employee,…

[STL --stack_queue详解]stack、queue,deque,priority_queue,容器适配器

stack stack介绍 1、stack是一种容器适配器,专门用在具有后进先出操作的上下文环境中,其删除只能从容器的一端进行元素的插入与提取操作。 2、stack是作为容器适配器被实现的,容器适配器即是对特定类封装作为其底层的容器,并提供…

原理图库和PCB库的命名规范及创建封装、使用封装管理器

原理图库 命名规范 原理图中元件值标注规则 注:元件值(Component Value)就是元件最主要的特征对应的值。 Component value. Most analog components have a value that must be specified by this field (e.g., 2.7 kΩ). Additional disti…

c++数据结构之队列

目录 一、队列的含义 1.队列的使用 2.队列的结构 二、顺序队列的实现 1.队列的定义 2.队列的初始化 3.清空对列 4.队列是否为空 5.获取队列的长度 6.获取头元素的值 7.入队列 8.出队列 9.遍历队列中的值 10.总代码 11.打印结果 三、链表队列的实现 1.队列的…

【Hot100】LeetCode—347. 前 K 个高频元素

目录 1- 思路自定义Node结点 哈希表实现 2- 实现⭐347. 前 K 个高频元素——题解思路 3- ACM实现 原题连接:347. 前 K 个高频元素 1- 思路 自定义Node结点 哈希表实现 ① 自定义 Node 结点: 自定义 Node 结点中有 value 和 cnt 字段,其中…

力扣接雨水

给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水。 示例 1: 输入:height [0,1,0,2,1,0,1,3,2,1,2,1] 输出:6 解释:上面是由数组 [0,1,0,2,1,0,1,3,2,1,2,1] 表…

html css网页制作

​ 大家好,我是程序员小羊! 前言: HTML 和 CSS 是制作网页的基础。HTML 用于定义网页的结构和内容,CSS 用于设计网页的样式和布局。以下是一个详细的网页制作成品教程,包括 HTML 和 CSS 的基础知识,及如何…

MySQL基础(7)- 多表查询

目录 一、笛卡尔积的错误与正确的多表查询 1.出现笛卡尔积错误 2.正确的多表查询:需要有连接条件 3.查询多个表中都存在的字段 4.SELECT和WHERE中使用表的别名 二、等值连接vs非等值连接、自连接vs非自连接 1.等值连接 vs 非等值连接 2.自连接 vs 非自连…

安卓逆向(之)真机root(红米手机)

概览: 1, 手机解锁 2, 下载官方系统包,推荐线刷包,取出镜像文件 3, magisk工具修补 官方系统包 4, adb:命令对手机刷 root 5, 完成 6, 小米手机解锁 点击 小米手机解锁OEM官方教程 记得数据线连接手机电脑 工具下载 点击 下载adb(电脑操作…

进程间通信-进程池

目录 理解​ 完整代码 完善代码 回收子进程&#xff1a;​ 不回收子进程&#xff1a; 子进程使用重定向优化 理解 #include <iostream> #include <unistd.h> #include <string> #include <vector> #include <sys/types.h>void work(int rfd) {…

Windows下使用cmake编译OpenCV

Windows下使用cmake编译OpenCV cmake下载OpenCV下载编译OpenCV cmake下载 下载地址&#xff1a;https://cmake.org/download/ 下载完成&#xff0c;点击选择路径安装即可 OpenCV下载 下载地址&#xff1a;https://github.com/opencv/opencv/releases/tag/4.8.1因为我们是编译…

2024软件测试需要具备的技能(软技能硬技能)

软件测试的必备技能 在往期的文章分享了很多的面试题&#xff0c;索性做一个转型。从零基础开始讲解&#xff0c;结合面试题来和大家一起学习交流软件测试的艺术。 第一个是专业技能&#xff0c;也叫硬技能。 第二个叫做软技能。 我们在上一篇文章中讲到了软件测试流程的5个…

ChatGPT在论文写作领域的应用:初稿设计

学境思源&#xff0c;一键生成论文初稿&#xff1a; AcademicIdeas - 学境思源AI论文写作 学术论文写作中&#xff0c;内容清晰、结构合理的初稿至关重要。通过 ChatGPT&#xff0c;写作者可以快速生成内容框架、明确研究问题&#xff0c;并优化表达方式。不仅提高了写作效率&…