完整模型训练套路 试写 以CIFAR10分类数据集为例

news2024/9/23 21:22:48

思路步骤:  

第一步准备数据集:(训练集,测试集)

import torchvision

Train_dataset = torchvision.datasets.CIFAR10("./data",True,transform=torchvision.transforms.ToTensor())
Test_Dataset =torchvision.datasets.CIFAR10("./data",False,transform=torchvision.transforms.ToTensor())

 transform=torchvision.transforms.ToTensor() 将数据集图片转换为Tensor数据类型

第二步查询数据集长度:(len)

###第二步查询数据集长度
Train_length = len(Train_dataset)
Test_length = len(Test_Dataset)
print(f"训练集长度为{Train_length}")
print((f"测试集长度为{Test_length}"))

 

第三步加载数据集(DataLoader)

##第三步加载数据集
Train_load = DataLoader(Train_dataset,64,True)
Test_load = DataLoader(Test_Dataset,64,True)

 

第四步搭建神经网络,将神经网络放入单独的一个python文件中,测试网络正确性

class Tudui(nn.Module):

 nn.Module 是所有神经网络模块的基类,用于创建自定义网络。

通过继承nn.Module,模型可以使用PyTorch提供的各种功能和方法,如参数管理、设备移动等。 

遵循这种结构的模型可以轻松地与PyTorch的其他组件(如优化器、损失函数等)集成 

 

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui,self).__init__()

 

  • super() 主要用于调用父类(超类)的方法。

 

 

nn.Conv2d(3,32,padding=2,stride=1)

输入通道为3,输出通道为32,高宽不变padding = [(卷积核数-1 )/2]

 对应神经网络部分:

        self.model = nn.Sequential(
            nn.Conv2d(3,32,padding=2,stride=1,kernel_size=5),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,padding=2,stride=1,kernel_size=5),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,padding=2,stride=1,kernel_size=5),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10),

        )

具体搞清楚python中 self.model写法  不能用self 直接替代的原因

###第四步搭建神经网络
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,32,padding=2,stride=1,kernel_size=5),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,padding=2,stride=1,kernel_size=5),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,padding=2,stride=1,kernel_size=5),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10),

        )

    def forward(self,x):
        x = self.model(x)
        return  x

测试网络的正确性 

input = torch.ones(64,3,32,32)

创建一个四维张量,表示一批图像数据 ,参数batch_size,channels,H,W

if __name__ == '__main__':
    tudui = Tudui()
    input = torch.ones(64,3,32,32)
    output = tudui(input)
    print(output.shape)

结果: 

 

 

 

第五步创建网络模型

###创建网络模型
tudui = Tudui()

第六步创建损失函数:

loss_fn = nn.CrossEntropyLoss()

第七步创建优化器:

##创建优化器
learning_rate =0.01
optimize = torch.optim.SGD(tudui.parameters(),lr=learning_rate)

第八步设置训练参数,开始训练: 

起始参数设置:

train_step = 0###起始训练次数
test_step = 0###起始测试次数
epoch = 10 ####训练十轮

训练开始 :从train_load读取训练集

 

 for data in Train_load:####训练开始
        imgs,targets = data

将图片输入至模型中得到输出 ,并利用损失函数计算Loss

 output = tudui(imgs)
loss = loss_fn(output,targets)

 优化器梯度清零:

  1. 防止梯度累积:

    • PyTorch 默认累积梯度。不清理会导致梯度在多个批次间累加。
  2. 确保每批次独立:

    每个批次的梯度应该只反映当前批次的计算结果。
        optimize.zero_grad()
        loss.backward()
        optimize.step()
        train_step+=1
        print(f"训练次数为{train_step},loss值{loss}")

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1966558.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小程序记账系统2024

小程序记账系统2024,编号weixin002 下载在最后 技术栈: 前台:js 后台:java 展示: 下载地址: CSDN现在上传有问题,有兴趣的朋友先收藏.正常了贴上下载地址 备注:

十二星座男、被戴绿帽后有啥反应 。

白羊座——火爆型 金牛座——沉默型 双子座——智慧型 巨蟹座——自毁型 狮子座——暴力型 处女座——喊包型 天秤座——潇洒型 天蝎座——有仇必报型 射手座——气过就算型 摩羯座——冷傲型 双鱼座——大方伟大型 水瓶座——飘忽型

Shell脚本的进程管理

进程管理是系统管理的重要方面,通过对进程的监控、启动、停止和重启,可以保证系统的稳定运行。Shell脚本是一种强大的工具,可以对进程进行自动化管理,提高效率和准确性。 参考:shell脚本进程管理 - CSDN文库 shell脚本…

JavaScript(四)——JavaScript 语法

目录 JavaScript 语法 JavaScript 字面量 JavaScript 变量 JavaScript 操作符 JavaScript 语句 JavaScript 关键字 JavaScript 注释 JavaScript 数据类型 JavaScript 函数 JavaScript 字母大小写 JavaScript 字符集 驼峰命名法 小驼峰命名法 大驼峰命名法&#xf…

DrissionPage,一个超实用的 Python 库!

更多资料获取 📚 个人网站:ipengtao.com 大家好,今天为大家分享一个超实用的 Python 库 - DrissionPage。 Github地址:https://github.com/g1879/DrissionPage 在网页数据抓取和自动化测试中,Selenium 和 Requests 是…

LeetCode-35 - 在排序数组中查找元素的第一个和最后一个位置

力扣35题 题目描述:在排序数组中查找元素的第一个和最后一个位置 给你一个按照非递减顺序排列的整数数组 nums,和一个目标值 target。请你找出给定目标值在数组中的开始位置和结束位置。 如果数组中不存在目标值 target,返回 [-1, -1]。 你必…

N-way K-shot Few shot learning

首先需要明确的是少样本领域的数据划分和大规模监督学习方法的数据划分不一样。在大规模监督学习方法中,训练集和测试集是混合后按比例随机切分,训练集和测试集的数据分布一致。以分类问题为例,切分后训练集中的类别和测试集中的类别相同&…

二进制部署k8s单master集群

一、常见的K8S部署方式 1、Minikube Minikube是一个工具,可以在本地快速运行一个单节点微型K8S,仅用于学习、预览K8S的一些特性使用。 https://kubernetes.io/docs/setup/minikube 2、Kubeadmin Kubeadmin也是一个工具,提供kubeadm init…

大模型之技术概述

本文作为大模型综述第一篇,介绍大模型技术基本情况。 目录: 1.大模型技术的发展历程 2.大模型技术的生态发展 3.大模型技术的风险与挑战 1.大模型技术的发展历程 2006 年 Geoffrey Hinton 提出通过逐层无监督预训练的方式来缓解由于梯度消失而导致的…

HTML连接样式CSS和表格,表单

HTML连接样式CSS <!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title>菜鸟教程(runoob.com)</title> </head> <-设置背景都是红色-> <body style"background-color:red;"> <-设置…

从赛场到云端:视频监控技术与赛事直播的技术融合与革新

在当今信息化高速发展的时代&#xff0c;视频监控技术和赛事直播作为两个重要的应用领域&#xff0c;正在以前所未有的速度融合&#xff0c;共同推动着传媒与安防领域的进步。本文将探讨视频监控技术在赛事直播中的应用及其带来的革新。 一、视频监控技术的演进 视频监控技术…

如何对CXL Port做Link Disable和Hot Reset

✨前言&#xff1a; 在CXL的验证测试中&#xff0c;对CXL Port做Link Disable和Hot Reset对比PCie的Port做相同的操作略有不同 ✨1.CXL Extensions DVSEC for Ports 协议里我们可以找到CXL协议里的第八张内容里的CXL Extensions DVSEC for Ports里的Port Control Extensions …

什么情况下你能接受 996

在当下的职场环境中&#xff0c;996 工作制一直是一个备受争议的话题。 “996”是一种工作制度的代称&#xff0c;指的是工作日早上 9 点上班&#xff0c;晚上 9 点下班&#xff0c;中午和傍晚休息 1 小时&#xff08;或不到&#xff09;&#xff0c;总计工作 10 小时以上&…

XSP04 PD诱骗芯片Type-C受电端用电5V9V10V11V12V15V20V,多协议PD+QC+AFC+FCP+SCP+VOOC使用体验

Type-C受电端控制芯片&#xff0c;顾名思义就是应用在用电端&#xff0c;例如3C数码产品、小家电、锂电池快充、小型发热产品等&#xff0c;一般产品使用Type-C接口&#xff0c;需要充电器的快充&#xff08;如9V以上&#xff09;供电&#xff0c;就可以使用XSP04 Type-C控制芯…

DevExpress WPF中文教程:如何将GridControl的更改发布到数据库?

DevExpress WPF拥有120个控件和库&#xff0c;将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress WPF能创建有着强大互动功能的XAML基础应用程序&#xff0c;这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。 无论是Office办公软件…

【Linux】全志Tina使用swupdate命令进行ab区分区升级操作

一、代码 swupdate -v -i /ota.swu -e stable,now_A_next_B 二、介绍 首先需具备swupdate命令&#xff0c;没有此命令需使用make menuconfig开启。 其次需指定swu文件的路径&#xff0c;代码中的路径是“/ota.swu”。 之后需要知道当前是分区A还是分区B。 --从A升B&#x…

零基础入门AI:一键本地运行各种开源大语言模型 - Ollama

什么是 Ollama&#xff1f; Ollama 是一个可以在本地部署和管理开源大语言模型的框架&#xff0c;由于它极大的简化了开源大语言模型的安装和配置细节&#xff0c;一经推出就广受好评&#xff0c;目前已在github上获得了46k star。 不管是著名的羊驼系列&#xff0c;还是最新…

程序员转行大模型:从代码到无限可能

在技术日新月异的时代背景下&#xff0c;许多程序员开始思考自己的职业发展路径。面对着人工智能与机器学习领域的迅速崛起&#xff0c;越来越多的技术人员将目光投向了更为广阔的天地——转行成为大模型研究者或开发者。这一转变不仅要求个人技能的迭代升级&#xff0c;更是一…

解锁开发新纪元:GPT-4o mini的实战探索与效率革命

&#x1f308;所属专栏&#xff1a;【其它】✨作者主页&#xff1a; Mr.Zwq✔️个人简介&#xff1a;一个正在努力学技术的Python领域创作者&#xff0c;擅长爬虫&#xff0c;逆向&#xff0c;全栈方向&#xff0c;专注基础和实战分享&#xff0c;欢迎咨询&#xff01; 您的点…

容器操作基础命令

文章目录 一、启动容器启动容器用法 二、查看容器状态三、容器相关操作删除容器容器的开启和停止进入容器attachexec 暴露容器的端口查看容器的日志传递运行命令容器内部的hosts文件指定容器的DNS容器内和宿主机之间复制文件 一、启动容器 容器的生命周期 docker run可以启动…