完整网络模型训练(一)

news2024/12/29 9:02:29

文章目录

    • 一、网络模型的搭建
    • 二、网络模型正确性检验
    • 三、创建网络函数

一、网络模型的搭建

以CIFAR10数据集作为训练例子

准备数据集:

#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)

查看数据集的长度:

train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")

运行结果:
在这里插入图片描述

利用DataLoader来加载数据集:

train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

搭建CIFAR10数据集神经网络:
在这里插入图片描述
卷积层【1】代码解释:
#第一个数字3表示inputs(可以看到图中为3),第二个数字32表示outputs(图中为32)
#第三个数字5为卷积核(图中为5),第四个数字1表示步长(stride)
#第五个数字表示padding,需要计算,计算公式:
在这里插入图片描述

nn.Conv2d(3, 32, 5, 1, 2)

最大池化代码解释:
#数字2表示kernel卷积核

nn.MaxPool2d(2)

读图
卷积层【1】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述

最大池化【1】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述
卷积层【2】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述
以此类推

展平:
在这里插入图片描述
Flatten后它会变成64*4 *4的一个结果

线性输出:
在这里插入图片描述
线性输入是64*4 *4,线性输出是64,故如下代码
nn.LInear(64 *4 *4,64)

继续线性输出
在这里插入图片描述
nn.LInear(64,10)

搭建网络完整代码:

class Sen(nn.Module):
    def __init__(self):
        super(Sen, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1 ,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )
    def forward(self,x):
        x = self.model(x)
        return x

二、网络模型正确性检验

if __name__ == '__main__':
    sen = Sen()
    input = torch.ones((64, 3, 32, 32))
    output = sen(input)
    print(output.shape)

注释:

input = torch.ones((64, 3, 32, 32))

这一行代码的含义是:创建一个大小为 (64, 3, 32, 32) 的全 1 张量,数据类型为 torch.float32。
64:这是批次大小,代表输入有 64 张图片。
3:这是图片的通道数,通常为 RGB 图像的三个通道 (红、绿、蓝)。
32, 32:这是图片的高和宽,表示每张图片的尺寸为 32x32 像素。
torch.ones 函数用于生成一个全 1 的张量,这里的张量形状适合用于输入图像分类或卷积神经网络(CNN)中常见的 CIFAR-10 或类似的 32x32 像素图像数据。

运行结果:
在这里插入图片描述
可以得到成功变成了【64, 10】的结果。

三、创建网络函数

创建网络模型:

sen = Sen()

搭建损失函数:

loss_fn = nn.CrossEntropyLoss()

优化器:

learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)

优化器注释:
使用随机梯度下降(SGD)优化器
learning_rate = 1e-2 这里的1e-2代表的是:1 x (10)^(-2) = 1/100 = 0.01

记录训练的次数:

total_train_step = 0

记录测试的次数:

total_test_step = 0

训练的轮数:

epoch= 10

进行循环训练:

for i in range(epoch):
    print(f"第{i+1}轮训练开始")

    for data in train_dataloader:
        imgs, targets = data
        outputs = sen(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        print(f"训练次数:{total_train_step},Loss:{loss.item()}")

注释:
imgs, targets = data是解包数据,imgs 是输入图像,targets 是目标标签(真实值)
outputs = sen(imgs)将输入图像传入模型 ‘sen’,得到模型的预测输出 outputs
loss = loss_fn(outputs, targets)计算损失值(Loss),loss_fn 是损失函数,它比较outputs的值与targets 是目标标签(真实值)的误差
optimizer.zero_grad()清除优化器中上一次计算的梯度,以免梯度累积
loss.backward()反向传播,计算损失相对于模型参数的梯度
optimizer.step()使用优化器更新模型的参数,以最小化损失
loss.item() 将张量转换为 Python 的数值
loss.item演示:

import torch
a = torch.tensor(5)
print(a)
print(a.item())

运行结果:
在这里插入图片描述
因此可以得到:item的作用是将tensor变成真实数字5

本章节完整代码展示:

import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader

class Sen(nn.Module):
    def __init__(self):
        super(Sen, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1 ,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )
    def forward(self,x):
        x = self.model(x)
        return x
#准备数据集
#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)

#length长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")

train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

sen = Sen()

#损失函数
loss_fn = nn.CrossEntropyLoss()

#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)

#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epoch= 10

for i in range(epoch):
    print(f"第{i+1}轮训练开始")

    for data in train_dataloader:
        imgs, targets = data
        outputs = sen(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        print(f"训练次数:{total_train_step},Loss:{loss.item()}")

运行结果:
在这里插入图片描述
可以看到训练的损失函数在一直进行修正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2183068.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《OpenCV》—— 指纹验证

用两张指纹图片中的其中一张对其验证 完整代码 import cv2def cv_show(name, img):cv2.imshow(name, img)cv2.waitKey(0)def verification(src, model):sift cv2.SIFT_create()kp1, des1 sift.detectAndCompute(src, None)kp2, des2 sift.detectAndCompute(model, None)fl…

消费电子制造企业如何使用SAP系统提升运营效率与竞争力

在当今这个日新月异的消费电子市场中,企业面临着快速变化的需求、激烈的竞争以及不断攀升的成本压力。为了在这场竞赛中脱颖而出,消费电子制造企业纷纷寻求数字化转型的突破点,其中,SAP系统作为业界领先的企业资源规划(ERP)解决方…

Python批量下载PPT模块并实现自动解压

日常工作中,我们总是找不到合适的PPT模板而烦恼。即使有免费的网站可以下载,但是一个一个地去下载,然后再批量解压进行查看也非常的麻烦,有没有更好方法呢? 今天,我们利用Python来爬取一个网站上的PPT&…

SSM整合:图书管理系统

图书管理系统 一.环境 1.数据库环境 CREATE DATABASE ssmbuild;USE ssmbuild;DROP TABLE IF EXISTS books;CREATE TABLE books (bookID INT(10) NOT NULL AUTO_INCREMENT COMMENT 书id,bookName VARCHAR(100) NOT NULL COMMENT 书名,bookCounts INT(11) NOT NULL COMMENT 数量…

Leecode热题100-48.旋转图像

给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像,这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 示例 1: 输入:matrix = [[1,2,3],[4,5,6],[7,8,9]] 输出:[[7,4,1],[8,5,2],[9,6,3]]示…

QML使用Qt自带软键盘例子

//注意:一定要保证Qt有安装VirtualKeyboard插件 import QtQuick 2.10 import QtQuick.Window 2.3 import QtQuick.Controls 2.3 import QtQuick.VirtualKeyboard 2.1 import QtQuick.VirtualKeyboard.Settings 2.1 Window { id: root visible: true w…

109.WEB渗透测试-信息收集-FOFA语法(9)

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 内容参考于: 易锦网校会员专享课 上一个内容:108.WEB渗透测试-信息收集-FOFA语法(8) 未授权burp: …

净利润暴跌,撤了,募投资金大比例购置不动产,突击申请专利

开科唯识终止原因如下:首先,报告期内,开科唯识收入规模较小,2023年上半年净利润更是出现暴跌的情况,其2023年可能难以满足创业板上市新规。此外,开科唯识研发费用率始终低于同行业可比公司,仅有…

线性代数书中求解齐次线性方程组、非齐次线性方程组方法的特点和缺陷(附实例讲解)

目录 一、克拉默法则 1. 方法概述 2. 例16(1) P45 3. 特点 (1) 只适用于系数矩阵是方阵 (2) 只适用于行列式非零 (3) 只适用于唯一解的情况 (4) 只适用于非齐次线性方程组 二、逆矩阵 1. 方法概述 2. 例16(2) P45 3. 特点 (1) 只适用于系数矩阵必须是方阵且可逆 …

每日读则推(二)

n.免疫疗法 n.策略,行动计划,战略 n.一代 v.设计(engineer n.工程师,设计师 v.设计,建造) A novel immunotherapy strategy using in vivo generation of engineered CAR T cells can n.(长篇)小说 a.新颖的,珍奇的 …

WebGIS包括哪些技术栈?怎么学习?

WebGIS,其实是利用Web开发技术结合地理信息系统(GIS)的产物,它是一种通过Internet实现GIS交互操作和服务的最佳途径。 WebGIS通过图形化界面直观地呈现地理信息和特定数据,具有可扩展性和跨平台性。 它提供交互性&am…

CSP-J二轮模拟赛----张浩轩补题报告

1.题目报告 1.交替出场2.翻翻转转3.方格取数4.圆圆中的方方AC0分--文件读写0分20分--骗分 2.赛中概况 第一题比较顺利,五六分钟就开始敲代码,暴力AC。 第二题耗了30分钟左右才有思路,写的时候也不大顺利,用得递归。文件读写错了…

HTML+CSS 基础第三季课堂笔记

一、CSS基础概念 CSS有两个重要的概念,分别是样式和布局 CSS的样式分为两种,一种是文字的样式,一种是盒模型的样式 CSS的另一个重要的特性就是辅助页面布局,完成HTML不能完成的功能,比如并排显示,比如精…

Flowable之任务撤回(支持主流程、子流程相互撤回)

撤回任务:主流程 > 主流程 处室主管【送科长审核】 处室主管【撤回科长审核】 流程日志 撤回任务:子流程 > 子流程 会办接收岗【送处室主管】 会办接收岗【撤回处室主管】 会办接收岗【同意】 撤回任务:子流程 > 主流程 处室主管…

秋招内推--招联金融2025

【投递方式】 直接扫下方二维码,或点击内推官网https://wecruit.hotjob.cn/SU61025e262f9d247b98e0a2c2/mc/position/campus,使用内推码 igcefb 投递) 【招聘岗位】 后台开发 前端开发 数据开发 数据运营 算法开发 技术运维 软件测试 产品策…

每日读则推(一)

v.避免,回避,撤销 n.黑花园蚁 Several animals are known to alter their behavior to avoid infections. But black garden v.改变,改动 n.传染病,感染,污染ants are th…

低空经济时代来临,挑战和机遇详细分析

低空经济作为一种新兴的经济形态,正逐步成为推动国民经济发展的新增长点。它依托于低空空域,涵盖通用航空、无人机应用、航空运动、低空旅游等多个领域,展现了广阔的发展前景和巨大的发展潜力。本文旨在详细分析低空经济时代来临所带来的挑战…

C语言数组和指针笔试题(三)

目录 字符数组四例题1例题2例题3例题4例题5例题6例题7 结果字符数组五例题1例题2例题3例题4例题5例题6例题7结果字符数组六例题1例题2例题3例题4例题5例题6例题7 结果 感谢各位大佬对我的支持,如果我的文章对你有用,欢迎点击以下链接 🐒🐒🐒个…

C语言、Eazy_x——井字棋

#include<graphics.h>char board_data[3][3] { { -,-,-},{ -,-,-},{ -,-,-}, };char current_piece o;//检测指定棋子玩家是否获胜 bool CheckWin(char c) {if (board_data[0][0] c && board_data[0][1] c && board_data[0][2] c)return true;if (…

html+css+js实现step进度条效果

实现效果 代码实现 HTML部分 <div class"box"><ul class"step"><li class"circle actives ">1</li><li class"circle">2</li><li class"circle">3</li><li class&quo…