【Pytorch】深度学习之优化器

news2024/10/5 19:10:51

文章目录

  • Pytorch提供的优化器
    • 所有优化器的基类`Optimizer`
  • 实际操作
  • 实验
  • 参考资料

优化器
根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值,使得模型输出更加接近真实标签的工具
学习目标
image.png

Pytorch提供的优化器

优化器的库torch.optim
优化器举例:
image.png

所有优化器的基类Optimizer

optimizer定义

class Optimizer(object):
	def __init__(self, params, defaults):
		self.defaults = defaults
		self.state = defaultdict(dict)
		self.param_groups = []

Optimizer属性
defaults:存储优化器的超参数,举个例子

# 使用的超参数包括:学习率lr,动量momentum,阻尼动量抑制项dampening,权重衰减weight_decay,nesterov——bool值,决定是否使用Nesterov动量方法
{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}

state: 参数缓存

# defaultdict类型的参数缓存,存储的是一个tensor键值对,key值为一个需要计算梯度的模型参数,value值为一个momentum_buffer的键值对存储动量缓冲张量
defaultdict(<class 'dict'>, {tensor([[ 0.3864, -0.0131],
        [-0.1911, -0.4511]], requires_grad=True):{'momentum_buffer': tensor([[0.0052, 0.0052],
        [0.0052, 0.0052]])}})

param_groups: 参数组,一个list,每个元素是一个字典,字典的key值顺序是params,lr,momentum,dampening,weight_decay,nesterov

# 'params'参数对应的是一个存储待优化参数的list
[{'params': [tensor([[-0.1022, -1.6890],[-1.5116, -1.7846]], requires_grad=True)], 'lr': 1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]

Optimizer方法
zero_grad()方法、step()方法、add_param_group()方法、load_state_dict()方法、state_dict()方法
zero_grad(): 清空所管理参数的梯度,由于Tensor的梯度不会自动清零,因而每次backward时均需要清空梯度

def zero_grad(self, set_to_none: bool = False)for group in self.param_groups: # 遍历optimizer的参数组,不同参数组往往有着不同的超参数
        for p in group['params']:  # 遍历参数组中的tensor参数
            if p.grad is not None:  #梯度不为空,即需要优化的参数
                if set_to_none: 
                    p.grad = None # 将参数的梯度设置为None,表示在backward过程中不再跟踪这个梯度的计算图
                else:
                    if p.grad.grad_fn is not None: # 判断参数的梯度是否有梯度函数
                        p.grad.detach_() # 有梯度函数的参数梯度,即其是通过某个操作计算得到,使用`detach_`方法将其从计算图中分离
                    else:
                        p.grad.requires_grad_(False) # 没有梯度函数的参数梯度,通过`requires_grad`方法设置其不需要梯度
                    p.grad.zero_()# 梯度设置为0 

step():执行一步梯度更新,参数更新

def step(self, closure): 
    raise NotImplementedError # 在Optimizer基类中,step()函数被定义为抛出`NotImplementedError`异常,表明继承Optimizer的优化器类必须实现自己的step方法

add_param_group():添加参数组

def add_param_group(self, param_group):
	# 参数类型检查:检查传入的`param_group`是否为字典类型,如果不是则抛出异常
    assert isinstance(param_group, dict), "param group must be a dict"
    # 参数整理:获取传入参数组中的`params`字段,将其整理为list形式
	# 检查类型是否为tensor
    params = param_group['params']
    if isinstance(params, torch.Tensor):
        param_group['params'] = [params]
    elif isinstance(params, set):
	    # 如果参数为set类型,抛出异常
        raise TypeError('optimizer parameters need to be organized in ordered collections, but the ordering of tensors in sets will change between runs. Please use a list instead.')
    else:
        param_group['params'] = list(params)

	# 参数检查:对每个参数进行检查,确保是 leaf Tensor 类型
    for param in param_group['params']:
        if not isinstance(param, torch.Tensor):
            raise TypeError("optimizer can only optimize Tensors, but one of the params is " + torch.typename(param))
        if not param.is_leaf:
            raise ValueError("can't optimize a non-leaf Tensor")

    # 超参数设置检查:检查在该优化器defaults中要求的超参数是否被提供,若未提供则抛出异常
    for name, default in self.defaults.items():
        if default is required and name not in param_group:
            raise ValueError("parameter group didn't specify a value of required optimization parameter " + name)
        else:
            param_group.setdefault(name, default)

	# 检查当前提供的参数组中是否有重复参数,若有则抛出warning
    params = param_group['params']
    if len(params) != len(set(params)):
        warnings.warn("optimizer contains a parameter group with duplicate parameters; in future, this will cause an error; see github.com/PyTorch/PyTorch/issues/40967 for more information", stacklevel=3)
# 上面好像都在进行一些类的检测,报Warning和Error
	# 参数集合检查:检查当前参数组的参数是否与之前已有的参数组参数集合没有交集
    param_set = set()
    for group in self.param_groups:
        param_set.update(set(group['params']))

    if not param_set.isdisjoint(set(param_group['params'])):
        raise ValueError("some parameters appear in more than one parameter group")
# 添加参数
    self.param_groups.append(param_group)

load_state_dict():加载状态参数字典,可以实现模型的断点续训练

def load_state_dict(self, state_dict):
    r"""Loads the optimizer state.
    Arguments:
        state_dict (dict): optimizer state. Should be an object returned from a call to :meth:`state_dict`.
    """
    # deepcopy, to be consistent with module API
    state_dict = deepcopy(state_dict)
    # Validate the state_dict: 检验状态字典的参数组是否和当前优化器的参数组一致
    groups = self.param_groups
    saved_groups = state_dict['param_groups']
	# 检验参数组长度和参数组下tensor参数长度
    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of parameter groups")
    param_lens = (len(g['params']) for g in groups)
    saved_lens = (len(g['params']) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError("loaded state dict contains a parameter group that doesn't match the size of optimizer's group")

    # Update the state
    # 创建id映射,将状态字典中的参数与当前优化器中的参数对应起来
    id_map = {old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups)), chain.from_iterable((g['params'] for g in groups)))}

    def cast(param, value):
        r"""Make a deep copy of value, casting all tensors to device of param."""
   		.....

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for backward compatibility).
    # 转换并更新状态:将状态字典中的状态转换并更新到当前优化器中。
    state = defaultdict(dict)
    for k, v in state_dict['state'].items():
        if k in id_map:
            param = id_map[k]
            state[param] = cast(param, v)
        else:
            state[k] = v

    # Update parameter groups, setting their 'params' value
    def update_group(group, new_group):
       ...
    # 更新参数组
    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    # 调用`__setstate__`方法,实现将更新后的状态设置到当前优化器中
    self.__setstate__({'state': state, 'param_groups': param_groups})

state_dict():获取优化器当前状态信息字典

def state_dict(self):
    r"""Returns the state of the optimizer as a :class:`dict`.
    It contains two entries:
    * state - a dict holding current optimization state. Its content differs between optimizer classes.
    * param_groups - a dict containing all parameter groups
    """
    # Save order indices instead of Tensors
    param_mappings = {}
    start_index = 0
    
	# 将Optimizer类的状态字典进行打包操作
    def pack_group(group):
		......
	# 使用pack_group函数对param_groups中所有参数组进行打包
    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use order indices as keys
    # 遍历当前优化器状态字典(`self.state`)中的每一项,将键映射到参数组的顺序索引
    packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items()}
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }

实际操作

代码示例

import os
import torch

# 设置权重
weight = torch.randn((2,2), requires_grad=True)
# 设置梯度
weight.grad = torch.ones((2,2))
# 输出当前权重和梯度
print("The data of weight before step:\n{}".format(weight.data))
print("The grad of weight before step:\n{}".format(weight.grad))
# 实例化优化器
optimizer = torch.optim.SGD([weight], lr=0.1, momentum=0.9)
# 一步优化操作
optimizer.step()
# 查看一步更新后的参数结果
print("The data of weight before step:\n{}".format(weight.data))
print("The grad of weight before step:\n{}".format(weight.grad))
# 权重清零
optimizer.zero_grad()
# 检验清零是否成功
print("The grad of weight after optimizer.zero_grad():\n{}".format(weight.grad))
# 输出优化器参数
print("optimizer.params_group is \n{}".format(optimizer.param_groups))
# 查看参数位置 --optimizer和weight的位置一样
print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))
# 添加参数
weight2 = torch.randn((3, 3), requires_grad=True)
optimizer.add_param_group({'params':weight2, 'lr': 0.0001, 'nesterov':True})
# 查看现有参数
print("optimizer.param_groups is\n{}".format(optimizer.param_groups))
# 查看当前优化器的状态信息
opt_state_dict = optimizer.state_dict()
print("state_dict before step:\n", opt_state_dict)
# 进行5次step操作
for _ in range(50):
    optimizer.step()
# 输出现有状态信息
print("state_dict after step:\n", optimizer.state_dict())
# 保存参数信息 --路径自行更换
torch.save(optimizer.state_dict(), os.path.join(r"D:\pythonProject\Attention_Unet", "optimizer_state_dict.pkl"))
print("Done!")
# 加载参数信息
state_dict = torch.load(r"D:\pythonProject\Attention_Unet\optimizer_state_dict.pkl") # 需要修改为你自己的路径
optimizer.load_state_dict(state_dict)
print("load state_dict successfully\n{}".format(state_dict))
# 输出属性信息
print("\n{}".format(optimizer.defaults))
print("\n{}".format(optimizer.state))
print("\n{}".format(optimizer.param_groups))

注意事项

  1. 每个优化器都是一个类,只有其经过实例化之后才能使用
class Net(nn.Module):
	...
net = Net()
optim = torch.optim.SGD(net.parameters(), lr = lr)
optim.step
  1. optimizer的操作分为两步:梯度置零,梯度更新
optimizer = torch.optim.SGD(net.parameters(), lr=1e-5)
for epoch in range(EPOCH):
	...
	optimizer.zero_grad()
	loss = ...
	loss.backward()
	optimizer.step()
  1. 以层为单位,设置每个优化器更新的参数权重
from torch import optim
from torchvision.models import resnet18

net = resnet18

optimizer = optim.SGD([
	{'params': net.fc.parameters()},
	{'params': net.layer4[0].conv1.parameters(), 'lr': 1e-2}
], lr=1e-5)

实验

数据生成

a = torch.linspace(-1, 1, 1000)
# 利用unsqueeze进行升维操作
x = torch.unsqueeze(a, dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(x.size()))

# 数据可视化
import matplotlib.pyplot as plt
plt.scatter(x,y)
plt.title('Generated Data') 
plt.xlabel('X-axis') 
plt.ylabel('Y-axis') 
plt.show()

网络结构

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)

    def forward(self, x):
        x = self.hidden(x)
        x = F.relu(x)
        x = self.predict(x)
        return x

利用不同的优化器对该网络结构的权重参数进行优化,并绘制loss随着step变化的图示,得到收敛速度

import torch.optim as optim
import torch.nn.functional as F

# 定义模型
model1 = Net()
model2 = Net()

# 定义损失函数
criterion = nn.MSELoss()

# 定义两个不同的优化器
optimizer1 = optim.SGD(model1.parameters(), lr=0.01)
optimizer2 = optim.Adam(model2.parameters(), lr=0.01)

# 训练模型
num_epochs = 1000
losses1, losses2 = [], []

for epoch in range(num_epochs):
    # 将数据转换为 PyTorch 张量
    x_tensor = torch.FloatTensor(x).view(-1, 1)
    y_tensor = torch.FloatTensor(y).view(-1, 1)

    # 使用第一个优化器进行训练
    optimizer1.zero_grad()
    outputs1 = model1(x_tensor)
    loss1 = criterion(outputs1, y_tensor)
    loss1.backward()
    optimizer1.step()
    losses1.append(loss1.item())

    # 使用第二个优化器进行训练
    optimizer2.zero_grad()
    outputs2 = model2(x_tensor)
    loss2 = criterion(outputs2, y_tensor)
    loss2.backward()
    optimizer2.step()
    losses2.append(loss2.item())

# 绘制损失变化图
import matplotlib.pyplot as plt

plt.plot(range(num_epochs), losses1, label='SGD')
plt.plot(range(num_epochs), losses2, label='Adam')
plt.title('Loss over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

训练结果图示:
image.png

问题说明
使用不同的optimizer对相同数据进行优化时,应该要用不同的模型,因为如果使用相同的模型,两个优化器的优化过程是相互干扰的
总结一下就是,相同输入数据,不同model实例,不同optimizer,相同criterion标准

参考资料

  1. datawhale through-pytorch repo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1092550.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniapp打包配置

安卓&#xff1a; 首先不管是什么打包都需要证书&#xff0c;安卓的证书一般都是公司提供或者自己去申请。然后把包名等下图框住的信息填上&#xff0c;点击打包即可。 ios&#xff1a;ios需要使用mac到苹果开发者平台去申请证书&#xff0c;流程可以参考下边的链接 参考链接…

Ceph 中的写入放大

新钛云服已累计为您分享769篇技术干货 介绍 Ceph 是一个开源的分布式存储系统&#xff0c;设计初衷是提供较好的性能、可靠性和可扩展性。 Ceph 独一无二地在一个统一的系统中同时提供了对象、块、和文件存储功能。 Ceph 消除了对系统单一中心节点的依赖&#xff0c;实现了无中…

基于单片机的感应自动门系统

目录 摘 要......................................................................................................................... 3 第一章 绪论.............................................................................................................…

论文阅读:

来源&#xff1a;公众号看到一篇文章 原文&#xff1a;https://arxiv.org/pdf/2301.04275.pdf 代码&#xff1a;GitHub - fengluodb/LENet: LENet: Lightweight And Efficient LiDAR Semantic Segmentation Using Multi-Scale Convolution Attention 0、摘要 基于LiDAR的语义…

【LeetCode刷题(数据结构)】:二叉树的前序遍历

给你二叉树的根节点root 返回它节点值的前序遍历 示例1&#xff1a; 输入&#xff1a;root [1,null,2,3] 输出&#xff1a;[1,2,3] 示例 2&#xff1a; 输入&#xff1a;root [] 输出&#xff1a;[] 示例 3&#xff1a; 输入&#xff1a;root [1] 输出&#xff1a;[1] 示例…

计算机毕业设计 基于Java的敬老院管理系统 Javaweb项目 Java实战项目 前后端分离 文档报告 代码讲解 安装调试

&#x1f34a;作者&#xff1a;计算机编程-吉哥 &#x1f34a;简介&#xff1a;专业从事JavaWeb程序开发&#xff0c;微信小程序开发&#xff0c;定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事&#xff0c;生活就是快乐的。 &#x1f34a;心愿&#xff1a;点…

基于pid控制的小功率直流电机调速

摘 要 随着电子技术的高度发展 ,直流电机测控逐步从模拟化向数字化转变。完全由硬件电路实现的直流电机测控系统 ,电路复杂 ,调整困难且可靠性不高 ,缺乏控制的灵活性。在工业控制中 ,按偏差的比例P、积分I和微分D进行控制的PID调节器现在得到广泛的应用。在小型微型 计算机用…

LoRa模块的通信范围与其他无线通信技术的比较

在物联网&#xff08;IoT&#xff09;和远程传感应用中&#xff0c;选择合适的无线通信技术至关重要。LoRa&#xff08;低功耗广域网&#xff09;模块因其低功耗、远距离通信和广覆盖范围而备受关注。本文将探讨LoRa模块的通信范围&#xff0c;并与其他无线通信技术如Wi-Fi和蓝…

Leetcode刷题详解——盛最多水的容器

1.题目链接&#xff1a;盛最多水的容器 2.题目描述&#xff1a; 给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回容…

基于单片机的智能交通灯控制系统设计

目录 摘 要...................................................................................... 2 第一章 绪论........................................................................ 5 1.1 研究课题背景...................................................…

伪元素和伪类的区别和作用?

什么是伪元素和伪类 伪元素 伪元素&#xff08;pseudo-elements&#xff09;是CSS中的一种选择器&#xff0c;用于选择元素的特定部分而不是整个元素本身。伪元素允许你在已选择的元素内部创建或修改内容&#xff0c;而无需在文档结构中添加额外的HTML元素。伪元素的语法以::&…

基于局部结构特征的图像匹配

目录 第一章 绪论........................................................................ 6 1.1 研究课题背景....................................................... 6 1.2 图像匹配技术国内外发展现状........................... 8 1.3 课题研究的目的......…

FBI分享AvosLocker勒索软件的技术细节和防御建议

导语 近日&#xff0c;美国联邦调查局&#xff08;FBI&#xff09;和网络安全与基础设施安全局&#xff08;CISA&#xff09;联合发布了一份关于AvosLocker勒索软件的技术细节和防御建议的联合网络安全公告。该公告详细介绍了AvosLocker勒索软件的攻击方式和使用的工具&#xf…

MFF论文笔记

论文名称&#xff1a;Improving Pixel-based MIM by Reducing Wasted Modeling Capability_发表时间&#xff1a;ICCV2023 作者及组织&#xff1a;上海人工智能实验室&#xff0c;西门菲沙大学&#xff0c;香港中文大学 问题与贡献 MIM(Model Maksed Model)方法可以分为两部分…

WSL 配置 Linux

WSL 配置 Linux Windows 启动 Linux 子系统 控制面板 -> 程序和功能&#xff0c; 将 适用于 Linux 的 Windows 子系统 勾选。 安装 Terminal 在 Microsoft Store 市场上搜索 Terminal 安装 Windows Terminal。 安装 编译工具链 sudo apt update # 更新软件包 sudo apt i…

MyBatis自定义映射resultMap,处理一对多,多对一

1、自定义映射resultMap 复习&#xff1a;查询的标签select必须设置属性resultType或resultMap&#xff0c;用于设置实体类和数据库表的映射 关系 resultType&#xff1a;自动映射&#xff0c;用于属性名和表中字段名一致的情况 &#xff08;或设置了下划线映射为驼峰&#x…

qemu基础篇——VSCode 配置 GDB 调试

文章目录 VSCode 配置 GDB 调试安装 VSCode 插件调试文件创建调试配置配置脚本qemu 启动脚 启动调试报错情况一报错情况二报错情况三 调试界面运行 GDB 命令查看反汇编断点查看内核寄存器查看变量参考链接 VSCode 配置 GDB 调试 上一节中直接使用 GDB 命令行调试&#xff0c;本…

基于海洋捕食者优化的BP神经网络(分类应用) - 附代码

基于海洋捕食者优化的BP神经网络&#xff08;分类应用&#xff09; - 附代码 文章目录 基于海洋捕食者优化的BP神经网络&#xff08;分类应用&#xff09; - 附代码1.鸢尾花iris数据介绍2.数据集整理3.海洋捕食者优化BP神经网络3.1 BP神经网络参数设置3.2 海洋捕食者算法应用 4…

JVM字节码指令详解

文章目录 前言一、JVM字节码指令概述1. 什么是JVM字节码指令&#xff1a;2. 字节码指令的作用&#xff1a;3. 字节码指令的分类&#xff1a; 二、字节码指令的种类1. 加载和存储指令2. 算术指令3. 类型转换指令4. 对象和数组操作指令5. 操作数栈管理指令6. 控制转移指令7. 方法…

在emacs中,设置latex的主文档

文档&#xff1a; chapter1.tex chapter2.tex main.tex 在chapter1.tex中&#xff0c;先按下 ctrlc ctrln&#xff0c;再按下ctrlc ctrla&#xff0c;在下方的提示框中输入主文档。