06_PyTorch 模型训练[学习率与优化器基类]

news2024/12/23 20:45:05
当数据、模型和损失函数确定,任务的数学模型就已经确定,接着就要选择一个合适
的优化器(Optimizer)对该模型进行优化。
PyTorch 中所有的优化器(如:optim.Adadelta、optim.SGD、optim.RMSprop 等)均是
Optimizer 的子类,Optimizer 中定义了一些常用的方法,有 zero_grad()、 step(closure)、state_dict()、load_state_dict(state_dict)和 add_param_group(param_group)
optimizer 对参数的管理是基于组的概念,可以为每一组参数配置特定 lr,momentum,weight_decay 等等。
参数组在 optimizer 中表现为一个 list(self.param_groups),其中每个元素是 dict,表示一个参数及其相应配置,在 dict 中包含'params'、'weight_decay'、'lr' 、 'momentum'等字段。

1.基本概念代码

import torch
import torch.optim as optim


w1 = torch.randn(2, 2)
w1.requires_grad = True

w2 = torch.randn(2, 2)
w2.requires_grad = True

w3 = torch.randn(2, 2)
w3.requires_grad = True

print("w1",w1)
print("w2",w2)
print("w3",w3)

# 一个参数组
optimizer_1 = optim.SGD([w1, w3], lr=0.1)
print('len(optimizer.param_groups): ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')

# 两个参数组
optimizer_2 = optim.SGD([{'params': w1, 'lr': 0.1},
                         {'params': w2, 'lr': 0.001}])
print('len(optimizer.param_groups): ', len(optimizer_2.param_groups))
print(optimizer_2.param_groups)

2. zero_grad()

作用:将梯度清零。由于 PyTorch 不会自动清零梯度,所以在每一次更新前会进行此操作。

代码与输出:

import torch
import torch.optim as optim

# ----------------------------------- zero_grad

w1 = torch.randn(2, 2)
w1.requires_grad = True

w2 = torch.randn(2, 2)
w2.requires_grad = True

optimizer = optim.SGD([w1, w2], lr=0.001, momentum=0.9)

print(optimizer.param_groups)
print("=======================")
print(optimizer.param_groups[0])
print("=======================")
print(optimizer.param_groups[0]['params'])
print("=======================")
print(optimizer.param_groups[0]['params'][0])  #参数w1

optimizer.param_groups[0]['params'][0].grad = torch.randn(2, 2)  

print('参数w1的梯度:')
print(optimizer.param_groups[0]['params'][0].grad, '\n')  # 参数组,第一个参数(w1)的梯度

optimizer.zero_grad()
print('执行zero_grad()之后,参数w1的梯度:')
print(optimizer.param_groups[0]['params'][0].grad)  # 参数组,第一个参数(w1)的梯度

 

3.state_dict()

作用:获取模型当前的参数,以一个有序字典形式返回。

这个有序字典中,key 是各层参数名,value 就是参数。
代码与输出:
import torch.nn as nn
import torch.nn.functional as F


# ----------------------------------- state_dict
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, 3)  #输出一个特征图,需要3个 3*3 的矩阵
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(1 * 3 * 3, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 1 * 3 * 3)
        x = F.relu(self.fc1(x))
        return x


net = Net()

# 获取网络当前参数
net_state_dict = net.state_dict()

print('net_state_dict类型:', type(net_state_dict))
print('net_state_dict管理的参数: ', net_state_dict.keys())
for key, value in net_state_dict.items():
    print('参数名: ', key, '\t大小: ',  value.shape)

4.add_param_group()  

作用:

给 optimizer 管理的参数组中增加一组参数,可为该组参数定制 lr, momentum, weight_decay 等,在 finetune 中常用。

代码与输出:

# coding: utf-8

import torch
import torch.optim as optim

# ----------------------------------- add_param_group

w1 = torch.randn(2, 2)
w1.requires_grad = True

w2 = torch.randn(2, 2)
w2.requires_grad = True

w3 = torch.randn(2, 2)
w3.requires_grad = True

# 一个参数组
optimizer_1 = optim.SGD([w1, w2], lr=0.1)
print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')

# 增加一个参数组
print('增加一组参数 w3\n')
optimizer_1.add_param_group({'params': w3, 'lr': 0.001, 'momentum': 0.8})

print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')

print('可以看到,参数组是一个list,一个元素是一个dict,每个dict中都有lr, momentum等参数,这些都是可单独管理,单独设定,十分灵活!')

 5.load_state_dict(state_dict)

作用:

将 state_dict 中的参数加载到当前网络,常用于 finetune。

代码与输出:

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------------------- load_state_dict

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(1 * 3 * 3, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 1 * 3 * 3)
        x = F.relu(self.fc1(x))
        return x

    def zero_param(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.constant_(m.weight.data, 0)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.constant_(m.weight.data, 0)
                m.bias.data.zero_()
net = Net()

# 保存,并加载模型参数(仅保存模型参数)
torch.save(net.state_dict(), 'net_params.pkl')   # 假设训练好了一个模型net
pretrained_dict = torch.load('net_params.pkl')

# 将net的参数全部置0,方便对比
net.zero_param()
net_state_dict = net.state_dict()
print('conv1层的权值为:\n', net_state_dict['conv1.weight'], '\n')

# 通过load_state_dict 加载参数
net.load_state_dict(pretrained_dict)
print('加载之后,conv1层的权值变为:\n', net_state_dict['conv1.weight'])

 6.step(closure)

作用:执行一步权值更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/194904.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STM32串口收发、串口中断、串口波特率的理解、普通IO模拟串口

STM32串口收发、串口中断一 、串口中断二、使用DMA三、串口波特率的理解开发环境:stm32cubuMax Keil5一 、串口中断 1.当收到消息的时候,立即进入控制程序,实现通过串口控制硬件; 2.在stm32cubeMax中配置串口 配置全局中断 2.在main函数中…

Django项目搭建_修改目录结构

1.安装环境 使用conda下载Django项目需要的依赖 pip install django2.2.6 -i https://pypi.douban.com/simple/pip install djangorestframework -i https://pypi.douban.com/simple/pip install PymySQL -i https://pypi.douban.com/simple/pip install Pillow -i https://p…

CSDN为什么会发展社区?看看官方怎么说

文章目录🌟 课前小差🌟 23年可兼收名利?🌟 博客之星🌟 红包活动🌟 相聚线下🌟 妙笔生花🌟 原力计划🌟 个人定位🌟 为什么要发展社区?&#x1f31f…

100种思维模型之决策树思维模型-004

选择决定了现状和未来,在生活中有很多选择的机会,但是真的选择对了吗?在该读书的年纪,却想着长大真好。在该工作奋斗的年纪,却后悔自己年轻时没好好读书,而悔恨。其实不是我们没有选择的权利,而…

2023年情人节浪漫表白神器(附源码下载)

2023年情人节将要来临,作为一名程序猿也不要落后了,快来用代码展示你的爱吧!下面给大家准备了6款情人节表白神器,把说不出口的话,放到代码里送给你爱的人吧!祝大家表白成功,俘获美人心&#xff…

Spring Boot集成 Swagger2 展现在线接口文档

一:swagger是什么? Swagger是一款RESTFUL接口的文档在线自动生成功能测试功能软件。Swagger是一个规范和完整的框架,用于生成、描述、调用和可视化RESTful风格的Web服务。目标是使客户端和文件系统作为服务器以同样的速度来更新文件的方法,参数和模型紧…

浅析云边端协同框架下的“AI+”视频融合能力以及场景应用

随着边缘侧与终端侧业务的规模化落地部署,很多新的业务场景已经逐渐不满足于中心化的云端计算模式。尤其是在AI人工智能技术进一步落地应用的趋势下,基于云边端深度融合与协同的“AI”模式,在满足用户对视频服务的智能识别需求上,…

centos环境docker安装nexus3搭建maven私有仓库

拉取最新nexus3镜像docker pull nexus3创建宿主机上的映射文件目录,并授权[root1-0002 ~]# mkdir -p /mnt/lckj/nexus/data [root1-0002 ~]# chmod -R 777 /mnt/lckj/nexus/data运行redis,生成相应容器-d 后台启动[root1-0002 ~]# docker run -d --name …

Go并发读取string的Panic问题

上问题,先看下panic的函数栈信息,说现实strings.Count()发生了panic,来看下函数 第一个参数是字符串s,再结合函数栈信息的十六进制,0x0、0x9表示字符串s的地址和长度 这里来看一下string的底层数据结构:…

Spring Security OAuth2.0认证授权

目录 1.基本概念 1.1什么是认证 1.2什么是会话? 1.2什么是授权 1.3授权的数据模型 1.4.1基于角色的访问控制 1.4.2基于资源的访问控制 2.基于Session的认证方式 2.1认证流程 分布式系统认证方案 什么事分布式系统? 分布式认证需求 分布式认证…

行业安全解决方案 | 能源行业如何在新时期建设新安全?

伴随5G、人工智能、大数据、云计算等新技术的蓬勃发展,数智化成为传统电力能源转型发展的重要方向。与此同时,伴随着能源行业数字技术与电力技术、业务生产的愈发深度的融合,新时期的能源行业网络安全形势有了新变化,网络边界威胁…

DPDK实现的用户态协议栈(UDP)

DPDK实现的用户态协议栈背景NIC与DPDK的比较环境配置Windowe下配置静态IP表代码实现总结背景 DPDK接管NIC之后,接收到的数据都是原始数据,要实现一个协议栈就必须解析协议包和打包协议包,DPDK提供了丰富的API可以使用。 以UDP协议为例&#…

redis分布式集群

文章目录一、redis持久化1.1.RDB持久化1.1.1.执行时机1.1.2.RDB原理1.1.3.小结1.2.AOF持久化1.2.1.AOF原理1.2.2.AOF配置1.2.3.AOF文件重写1.2.4.小结1.3.RDB与AOF对比二、Redis主从集群2.1.集群结构2.2.准备实例和配置2.3.启动2.4.开启主从关系2.5.测试2.6.主从数据同步原理2.…

MMLAB学习笔记-DAY1

一、机器学习 1.机器学习的典型范式 监督学习:数据是由人工标注的,数据之间存在某种映射关系,目的是让机器学习到数据和标签之间的关系无监督学习:数据是没有标签的,通过对数据分析,运用聚类等方法探索出…

六、循环语句

一、while循环 1.语法 while 条件:条件成⽴重复执⾏的代码1条件成⽴重复执⾏的代码2.....2.应用 #偶数累加 i 1 resualt 0while i<100:if i % 2 0:resualt ii1print(resualt)3.break和continue 说明&#xff1a; 举例&#xff1a;⼀共吃5个苹果&#xff0c;吃完第⼀个&…

如何又快又好实现 Catalog 系统搜索能力?火山引擎 DataLeap 这样做

摘要 DataLeap 是火山引擎数智平台 VeDI 旗下的大数据研发治理套件产品&#xff0c;帮助用户快速完成数据集成、开发、运维、治理、资产、安全等全套数据中台建设&#xff0c;降低工作成本和数据维护成本、挖掘数据价值、为企业决策提供数据支撑。 火山引擎 DataLeap 的 Data…

Spring Boot + WebSocket 实时监控异常

本文已经收录到Github仓库&#xff0c;该仓库包含计算机基础、Java基础、多线程、JVM、数据库、Redis、Spring、Mybatis、SpringMVC、SpringBoot、分布式、微服务、设计模式、架构、校招社招分享等核心知识点&#xff0c;欢迎star~ Github地址&#xff1a;https://github.com/…

【笔记】容器基础-隔离与限制

Docker 项目的核心原理&#xff1a;为待创建的用户进程 1.启用 Linux Namespace 配置&#xff1a;修改进程视图 2.设置指定的 Cgroups 参数&#xff1a;为进程设置资源限制 3.切换进程的根目录&#xff08;Change Root&#xff09;&#xff1a; 容器的隔离与限制 1.启用 Linux…

MySQL性能优化四 MySQL索引优化实战一

一 查询案例 示例表 CREATE TABLE employees (id int(11) NOT NULL AUTO_INCREMENT,name varchar(24) NOT NULL DEFAULT COMMENT 姓名,age int(11) NOT NULL DEFAULT 0 COMMENT 年龄,position varchar(20) NOT NULL DEFAULT COMMENT 职位,hire_time timestamp NOT NULL DEF…

王凤英,能治好何小鹏的技术“自恋”吗?

1月30日&#xff0c;小鹏官宣一手打造长城汽车(601633)SUV战略转型的前二号人物——王凤英&#xff0c;加盟小鹏出任CEO一职。 虽然这则消息已风传多日&#xff0c;但正式公布的一刻还是在汽车圈内炸开了锅&#xff0c;主要原因有两点&#xff1a;一是王凤英刚刚加入小鹏就被委…