6-2 pytorch中训练模型的3种方法

news2024/12/25 9:04:06

Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异。(养成自己的习惯)
有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环。
下面以minist数据集的多分类模型的训练为例,演示这3种训练模型的风格。
其中类形式训练循环我们同时演示torchkeras.KerasModel和torchkeras.LightModel两种示范。

准备数据

transform = transforms.Compose([transforms.ToTensor()])

ds_train = torchvision.datasets.MNIST(root="./data/mnist/",train=True,download=True,transform=transform)
ds_val = torchvision.datasets.MNIST(root="./data/mnist/",train=False,download=True,transform=transform)

dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=4)

print(len(ds_train))
print(len(ds_val))

image.png

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

#查看部分样本
from matplotlib import pyplot as plt 

plt.figure(figsize=(8,8)) 
for i in range(9):
    img,label = ds_train[i] 
    img = torch.squeeze(img) # 删除为1的维度
    ax=plt.subplot(3,3,i+1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d"%label)
    ax.set_xticks([])
    ax.set_yticks([]) 
plt.show()

image.png

一、脚本风格

脚本风格的训练循环非常常见。

net = nn.Sequential()
net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3))
net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("dropout",nn.Dropout2d(p = 0.1))
net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
net.add_module("flatten",nn.Flatten())
net.add_module("linear1",nn.Linear(64,32))
net.add_module("relu",nn.ReLU())
net.add_module("linear2",nn.Linear(32,10))

print(net)

image.png
代码量较多,可以查看最下方链接对应的notebook。

二、函数风格

该风格在脚本形式上做了进一步的函数封装。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Dropout2d(p = 0.1),
            nn.AdaptiveMaxPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(64,32),
            nn.ReLU(),
            nn.Linear(32,10)]
        )
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x
net = Net()
print(net)

image.png
代码量较多,可以查看最下方链接对应的notebook。

三、类风格

此处使用**torchkeras.KerasModel(其源码其实就是脚本风格中的代码)**高层次API接口中的fit方法训练模型。
使用该形式训练模型非常简洁明了。
先构建模型,同一二。

from torchkeras import KerasModel 

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Dropout2d(p = 0.1),
            nn.AdaptiveMaxPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(64,32),
            nn.ReLU(),
            nn.Linear(32,10)]
        )
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x
    
net = Net() 

print(net)

使用kerasModel:

from torchmetrics import Accuracy

model = KerasModel(net,
                   loss_fn=nn.CrossEntropyLoss(),
                   metrics_dict = {"acc":Accuracy(task='multiclass',num_classes=10)},
                   optimizer = torch.optim.Adam(net.parameters(),lr = 0.01)  )

model.fit(
    train_data = dl_train,
    val_data= dl_val,
    epochs=10,
    patience=3,
    monitor="val_acc", 
    mode="max",
    plot=True,
    cpu=True
)

训练过程:
image.png
其实编码训练代码按照自己的习惯即可,不必要按照以上三种方式。

参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1020864.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Git --- 基础介绍

Git --- 基础介绍 git 是什么git --- 工作区, 暂存区, 资源库git --- 文件状态git --- branch 和 HEADgit --- 一次正常的git提交流程 git 是什么 Git是一款分布式源代码管理工具(版本控制工具)Git和其他传统版本控制系统比较: 传统的版本控制系统(例如 SVN)是基于差异的版本控…

家政小程序开发制作,家政保洁上门维修小程序搭建

家政小程序开发制作,现如今家政上门服务,也越来越普及到我们的生活中,比如家电清洗,水电维修,家政保洁,上门护理等等方面。那么一个合格的家政小程序,需要满足哪些功能呢?今天就带大…

视频图像处理算法opencv模块硬件设计图像颜色识别模块

1、Opencv简介 OpenCV是一个基于Apache2.0许可(开源)发行的跨平台计算机视觉和机器学习软件库,可以运行在Linux、Windows、Android和Mac OS操作系统上 它轻量级而且高效——由一系列 C 函数和少量 C 类构成,同时提供了Python、Rub…

Stable Diffusion AI绘图使用记录

1、下载安装使用 官方网站https://github.com/AUTOMATIC1111/stable-diffusion-webui 跟着一步步安装就行(英文版的) 2、真人转二次元 下载控制插件Contro lnetGitHub - Mikubill/sd-webui-controlnet: WebUI extension for ControlNet 按照官方的安…

爬虫 — 反爬

目录 一、UA 反爬二、Cookie 验证与反爬1、Cookie 简介2、使用 Cookie 原因3、Cookie 作用3.1、模拟登录3.2、反反爬 三、Referer 反爬 一、UA 反爬 UA(User Agent):用户代理,是一个特殊字符串头,使得服务器能够识别客…

深入解读什么是期权的内在价值和时间价值?

期权品种越来越丰富,对于大家套利对冲都有很多的选择。而有些初学者对时间价值一直不理解,今天呢,就给大家讲一讲深入解读什么是期权的内在价值和时间价值?本文来自:期权酱 01在期权交易过程中,想必大家都会…

Layui快速入门之第十四节 分页

目录 一:基本用法 API 渲染 属性 二:自定义主题 三:自定义文本 四:自定义排版 五:完整显示 一:基本用法 分页组件 laypage 提供了前端的分页逻辑,使得我们可以很灵活处理不同量级的数…

SAP MM学习笔记32 - 购买依赖的承认(采购申请的审批)

多数公司都不会随便让采购员买东西的,而是要设置一个审批,甚至是层层审批,之后才能购买。 一般流程是 购买依赖(采购申请) > 购买发注(采购),这个承认(审批&#xf…

PyCharm使用技巧小记

目录 1 汉化2 主题3 关联滚轮和字体大小4 快捷键4.1 常用快捷键4.2 查看软件快捷键5 上下文菜单取消 注:本文基于PyCharm 2023.2.1 版本进行截图演示。不同版本有所差异,注意区分。 1 汉化 1、点击右上角‘齿轮’图标 2、单击插件 3、在上方搜索框中…

python selenium如何带cookie访问网站

python selenium如何带cookie访问网站 要使用Python的Selenium库带有cookie访问网站,你可以按照以下步骤进行操作: 一、流程介绍 安装Selenium库(如果尚未安装): pip install selenium导入Selenium库并启动一个浏览…

DragGAN使用记录

效果图 调整人物动作 调整裙子长度 调整动物的动作 DragGAN介绍 DragGAN是一种基于人工智能的图像编辑工具,它可以根据用户的输入生成逼真的图像。与传统的图像编辑工具只能扭曲或裁剪现有的像素不同,DragGAN可以创建与用户意图匹配的新内容。 Drag…

Kubernetes网络揭秘:看完你就懂了

一、Master集群网络 master集群的网络比较简单,和通常的负载均衡集群一样。多个节点的apiserver的ip与端口(6443)使用负载均衡的ip与端口。在master/node节点join时均使用此负载均衡的ip与端口,这样就是master节点的集群网络。 master 节点之间的网络&a…

QT:使用多窗口做一个登录注册小项目(登录窗口、登录结果窗口、注册窗口)

widget.h(登录窗口) #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QCheckBox> #include <QLabel> #include <QLineEdit> #include <QPushButton> #include <Qmap> //模板类class Widget : public QWidget …

7年阿里测试经验之谈 —— 用UI自动化测试实现元素定位

随着IT行业的发展&#xff0c;产品愈渐复杂&#xff0c;web端业务及流程更加繁琐&#xff0c;目前UI测试仅是针对单一页面&#xff0c;操作量大。为了满足多页面功能及流程的需求及节省工时&#xff0c;设计了这款UI 自动化测试程序。旨在提供接口&#xff0c;集成到蜗牛自动化…

软件测试:什么是敏捷测试?

1. 什么是敏捷测试 敏捷测试是一种在敏捷开发环境中进行软件测试的方法&#xff0c;不同于传统瀑布模型中的测试阶段&#xff0c;敏捷测试强调持续测试、快速反馈和合作开发。 敏捷测试与敏捷开发相辅相成&#xff0c;通过频繁的迭代和增量开发来提高软件的交付速度和质量。 …

C++QT day9

完善登录框 点击登录按钮后&#xff0c;判断账号&#xff08;admin&#xff09;和密码&#xff08;123456&#xff09;是否一致&#xff0c;如果匹配失败&#xff0c;则弹出错误对话框&#xff0c;文本内容“账号密码不匹配&#xff0c;是否重新登录”&#xff0c;给定两个按钮…

基于GPIO子系统编写LED灯驱动

驱动程序 #include <linux/init.h> #include <linux/module.h> #include <linux/of.h> #include <linux/of_gpio.h> #include <linux/gpio.h> #include <linux/fs.h> #include <linux/io.h> #include <linux/device.h> #incl…

【Spring】IOC基本用法

&#x1f388;博客主页&#xff1a;&#x1f308;我的主页&#x1f308; &#x1f388;欢迎点赞 &#x1f44d; 收藏 &#x1f31f;留言 &#x1f4dd; 欢迎讨论&#xff01;&#x1f44f; &#x1f388;本文由 【泠青沼~】 原创&#xff0c;首发于 CSDN&#x1f6a9;&#x1f…

js自带的字体图标

let body document.querySelector(body)body.width 100%for (let i 1; i < 10000; i) {let str &#i;let sm str.big()let st document.createElement(span)st.innerHTML smst.style.fontSize 30pxbody.appendChild(st)} 结果如下

Vue2+Vue3基础入门到实战项目全套教程的学习笔记

内容的视频链接点击此处可进入 这套笔记是按照视频和视频笔记总结的笔记&#xff0c;主要是方便vue的学习或温习&#xff0c;基本抛弃css样式的添加&#xff0c;专注于vue的使用。 第一天 Vue 快速上手 Vue的概念 Vue 是一个用于 构建用户界面 的 渐进式 框架 创建实例 …