22.优化器

news2025/1/15 16:52:40

优化器

当使用损失函数时,可以调用损失函数的 backward,得到反向传播,反向传播可以求出每个需要调节的参数对应的梯度,有了梯度就可以利用优化器,优化器根据梯度对参数进行调整,以达到整体误差降低的目的。

网站 : torch.optim — PyTorch 1.10 documentation

1.如何使用优化器?

(1)构造

# Example:
# SGD为构造优化器的算法,Stochastic Gradient Descent 随机梯度下降
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  #模型参数、学习速率、特定优化器算法中需要设定的参数
optimizer = optim.Adam([var1, var2], lr=0.0001)

(2)调用优化器的step方法

利用之前得到的梯度对参数进行更新

for input, target in dataset:
    optimizer.zero_grad() #把上一步训练的每个参数的梯度清零
    output = model(input)
    loss = loss_fn(output, target)  # 输出跟真实的target计算loss
    loss.backward() #调用反向传播得到每个要更新参数的梯度
    optimizer.step() #每个参数根据上一步得到的梯度进行优化

算法

如Adadelta、Adagrad、Adam、RMSProp、SGD等等,不同算法前两个参数:params、lr 都是一致的,后面的参数不同

CLASS torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
# params为模型的参数、lr为学习速率(learning rate)
# 后续参数都是特定算法中需要设置的

学习速率不能太大(太大模型训练不稳定)也不能太小(太小模型训练慢),一般建议先采用较大学习速率,后采用较小学习速率

SGD为例

以 **SGD(随机梯度下降法)**为例进行说明:

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
 
# 加载数据集并转为tensor数据类型
dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 加载数据集
dataloader = DataLoader(dataset,batch_size=1)
 
# 创建网络名叫Tudui
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )
 
    def forward(self,x):   # x为input,forward前向传播
        x = self.model1(x)
        return x
 
# 计算loss
loss = nn.CrossEntropyLoss()
 
# 搭建网络
tudui = Tudui()
 
# 设置优化器
optim = torch.optim.SGD(tudui.parameters(),lr=0.01)  # SGD随机梯度下降法,我们上面搭建出来的模型就叫tudui,所以parameters就是tudui.parameters()
 
for data in dataloader:
    imgs,targets = data  # imgs为输入,放入神经网络中
    outputs = tudui(imgs)  # outputs为输入通过神经网络得到的输出,targets为实际输出
    result_loss = loss(outputs,targets)
    optim.zero_grad()  # 把网络模型中每一个可以调节的参数对应梯度设置为0
    result_loss.backward()  # backward反向传播求出每一个节点的梯度,是对result_loss,而不是对loss
    optim.step()  # 对每个参数进行调优

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可以在以下地方打断点,debug:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

tudui ——> Protected Attributes ——> _modules ——> ‘model1’ ——> Protected Attributes ——> _modules ——> ‘0’ ——> weight ——> data 或 grad

通过每次按箭头所指的按钮(点一次运行一行),观察 data 和 grad 值的变化

  • 第一行 optim.zero_grad() 是让grad清零

  • 第三行 optim.step() 会通过grad更新data

    外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

完整代码

在 data 循环外又套一层 epoch 循环,一次 data 循环相当于对数据训练一次,加了 epoch 循环相当于对数据训练 20 次

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
 
# 加载数据集并转为tensor数据类型
dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)
 
# 创建网络名叫Tudui
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )
 
    def forward(self,x):   # x为input,forward前向传播
        x = self.model1(x)
        return x
 
# 计算loss
loss = nn.CrossEntropyLoss()
 
# 搭建网络
tudui = Tudui()
 
# 设置优化器
optim = torch.optim.SGD(tudui.parameters(),lr=0.01)  # SGD随机梯度下降法,我们上面搭建出来的模型就叫tudui,所以parameters就是tudui.parameters()
for epoch in range(20):
    running_loss = 0.0  # 在每一轮开始前将loss设置为0
    for data in dataloader:  # 该循环相当于只对数据进行了一轮学习
        imgs,targets = data  # imgs为输入,放入神经网络中
        outputs = tudui(imgs)  # outputs为输入通过神经网络得到的输出,targets为实际输出
        result_loss = loss(outputs,targets)
        optim.zero_grad()  # 把网络模型中每一个可以调节的参数对应梯度设置为0
        result_loss.backward()  # backward反向传播求出每一个节点的梯度,是对result_loss,而不是对loss
        optim.step()  # 对每个参数进行调优
        running_loss = running_loss + result_loss  # 每一轮所有loss的和
    print(running_loss)

部分运行结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

优化器对模型参数不断进行优化,每一轮的 loss 在不断减小

实际过程中模型在整个数据集上的训练次数(即最外层的循环)都是成百上千/万的,本例仅以 20 次为例。

print(running_loss)

部分运行结果:

[外链图片转存中...(img-lfoM0bJi-1724861918306)]

优化器对模型参数不断进行优化,每一轮的 loss 在不断减小

实际过程中模型在整个数据集上的训练次数(即最外层的循环)都是成百上千/万的,本例仅以 20 次为例。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2084199.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Cryptomator:开源云存储加密

采用最新技术标准,提供最佳保护 如果有人查看您云中的文件夹,他们无法对您的数据得出任何结论。 Cryptomator 提供开源的客户端云文件加密。 它适用于 Windows、Linux、macOS 和 iOS。 Cryptomator 可与 Dropbox、Google Drive、OneDrive、MEGA、pClo…

【QT | 开发环境搭建】Linux系统(Ubuntu 18.04) 安装 QT 5.12.12 开发环境

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 ⏰发布时间⏰: 2024-08-29 …

C# 委托详解(Delegate)

引言 在 C# 编程当中,委托(Delegate)是一种特殊的类型,它允许将方法作为参数传递给其他方法,或者将方法作为返回值返回,这种特性使得委托成为实现回调函数、事件处理等,所有的委托都派生自Syst…

【STM32开发笔记】使用RT-Thread的SDIO驱动和FATFS实现SD卡文件读写

【STM32开发笔记】使用RT-Thread的SDIO驱动和FATFS实现SD卡文件读写 一、准备工作1.1 准备好开发板和SD卡1.2 创建RT-Thread项目 二、配置RT-Thread2.1 打开文件系统相关配置2.2 打开SD卡相关配置2.3 打开RTC配置2.4 重新生成Keil项目文件 三、编译、烧录、运行3.1 编译项目3.2…

网站建设完成后, 营销型网站如何做seo

营销型网站的SEO优化旨在提高网站在搜索引擎中的排名,从而吸引更多潜在客户并促进销售。以下是营销型网站SEO的详细解析: 关键词研究与优化 目标受众分析:了解目标受众的搜索习惯和需求,确定适合的关键词。使用工具来发现相关关键…

RV1126的GPIO计算和使用

1、获取GPIO芯片对应的序号值 先读取下/sys/kernel/debug/gpio的值,得到每个GPIO芯片的序号范围,如GPIO芯片0就为0~31。 2、根据GPIO硬件编号计算出系统内使用的GPIO序号 根据GPIO的编号,比如说GPIO3_B0,前面GPIO3代表看GPIO3的信…

傻瓜操作:GraphRAG、Ollama 本地部署及踩坑记录

目录 一、GraphRAG 介绍1.引言2.创新点3. 算法4. 数据和实验结果5.不足和展望 二、本地部署1.为什么要本地部署2.环境准备3. GraphRAG 安装3.1 下载 GraphGAG3.2 安装依赖包3.3 创建数据目录3.4 项目初始化3.5 修改配置文件 3.6 修改.env文件3.7 修改源码 4. Indexing5. query5…

Linux关于压缩之后文件更大的解释

记录于24年八月29 使用vim命令创建了lianxi1和lianxi2并在里面填写了一些内容,发现使用gzip和zip压缩后文件反而更大 事后问了一下ai回答了我的疑惑 压缩算法开销:如前所述,压缩文件需要存储额外的元数据和文件结构信息。这种开销在处理非常…

C++ TinyWebServer项目总结(13. 多进程编程)

本章讨论Linux多进程编程的以下内容: 复制进程映像的fork系统调用和替换进程映像的exec系列系统调用。僵尸进程以及如何避免僵尸进程。进程间通信(Inter Process Communication,IPC)最简单的方式:管道。三种System V进…

浏览器插件利器--allWebPluginV2.0.0.18-alpha版发布

allWebPlugin简介 allWebPlugin中间件是一款为用户提供安全、可靠、便捷的浏览器插件服务的中间件产品,致力于将浏览器插件重新应用到所有浏览器。它将现有ActiveX控件直接嵌入浏览器,实现插件加载、界面显示、接口调用、事件回调等。支持Chrome、Firefo…

[MRCTF2020]Unravel!!

使用zsteg查看图片有隐藏文件,没有头绪,先放弃 使用zsteg和010editor查看都发现一个png图片 把JM.png拷贝到kali,使用binwalk分离,得到一个aes.png 使用010editor查看wav,发现尾部有可疑的字符串,拷贝出来备…

记一次应急响应之网站暗链排查

目录 前言 1. 从暗链而起的开端 1.1 暗链的介绍 1.2 暗链的分类 2. 在没有日志的情况下如何分析入侵 2.1 寻找指纹 2.2 搜索引擎搜索fofa资产搜索 2.2.1 fofa资产搜索 2.2.2 bing搜索引擎搜索 3.通过搭建系统并进行漏洞复现 4. 应急响应报告编写 前言 免责声明 博文…

二叉树(binary tree)遍历详解

一、简介 二叉树常见的遍历方式包括前序遍历、中序遍历、后序遍历和层序遍历等。我将以下述二叉树来讲解这几种遍历算法。 1、创建二叉树代码实现 class TreeNode:def __init__(self,data):self.datadataself.leftNoneself.rightNonedef createTree():treeRootTreeNode(F)N…

大模型提示词工程技术3-提示词输入与输出的优化的技巧详细介绍

大模型提示词工程技术3-提示词输入与输出的优化的技巧详细介绍。《大模型提示词工程技术》的作者:微学AI,这是一本专注于提升人工智能大模型性能的著作,它深入浅出地讲解了如何通过优化输入提示词来引导大模型生成高质量、准确的输出。书中不…

腾讯地图三维模型加载GLTF,播放模型动画

腾讯地图三维模型加载,播放模型动画 关键代码 const clock new THREE.Clock();console.log(gltf)// 确保gltf对象包含scene和animations属性if (gltf && gltf.scene && gltf.animations) {// 创建AnimationMixer实例,传入模型的scenec…

【51单片机】2-3-1 【I/O口】【电动车防盗报警项目】震动传感器实验1—震动点灯

1.硬件 51单片机最小系统LED灯模块震动传感器模块 2.软件 main.c程序 #include "reg52.h"sbit led1 P3^7;//根据原理图(电路图),设备变量led1指向P3组IO口的第7口 sbit vibrate P3^3;//Do接到了P3.3口void Delay2000ms() //…

力扣刷题--2185. 统计包含给定前缀的字符串【简单】

题目描述 给你一个字符串数组 words 和一个字符串 pref 。 返回 words 中以 pref 作为 前缀 的字符串的数目。 字符串 s 的 前缀 就是 s 的任一前导连续字符串。 示例 1: 输入:words [“pay”,“attention”,“practice”,“attend”], pref “at…

用 Higress AI 网关降低 AI 调用成本 - 阿里云天池云原生编程挑战赛参赛攻略

作者介绍:杨贝宁,爱丁堡大学博士在读,研究方向为向量数据库 《Higress AI 网关挑战赛》正在火热进行中,Higress 社区邀请了目前位于排行榜 top5 的选手杨贝宁同学分享他的心得。下面是他整理的参赛攻略: 背景 我们…

Jmeter(十四)Jmeter分布式部署测试

单个接口测试,我们使用谷歌的插件postman 多个接口测试,我们使用Jmeter进行测试 一、使用工具测试 1、使用Jmeter对接口测试 首先我们说一下为什么用Posman测试后我们还要用Jmeter做接口测试,在用posman测试时候会发现的是一个接口一个接…

存储实验:基于华为存储实现存储双活(HyperMetro特性)

目录 什么是存储双活仲裁机制 实验需求实验拓扑实验环境实验步骤1. 双活存储存储初始化(OceanStor v3 模拟器)1.1开机,设置密码1.2登录DM,修改设备名、系统时间和导入License1.3 设置接口IP 2. 仲裁服务器配置(Centos7…