学习pytorch15 优化器

news2024/12/30 3:44:17

优化器

  • 官网
  • 如何构造一个优化器
  • 优化器的step方法
  • code
  • running log
    • 出现下面问题如何做反向优化?

官网

https://pytorch.org/docs/stable/optim.html

在这里插入图片描述
提问:优化器是什么 要优化什么 优化能干什么 优化是为了解决什么问题
优化模型参数

如何构造一个优化器

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # momentum SGD优化算法用到的参数
optimizer = optim.Adam([var1, var2], lr=0.0001)
  1. 选择一个优化器算法,如上 SGD 或者 Adam
  2. 第一个参数 需要传入模型参数
  3. 第二个及后面的参数是优化器算法特定需要的,lr 学习率基本每个优化器算法都会用到

优化器的step方法

会利用模型的梯度,根据梯度每一轮更新参数
optimizer.zero_grad() # 必须做 把上一轮计算的梯度清零,否则模型会有问题

for input, target in dataset:
    optimizer.zero_grad()  # 必须做 把上一轮计算的梯度清零,否则模型会有问题
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

or 把模型梯度包装成方法再调用

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

code

import torch
import torchvision
from torch import nn, optim
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_set = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                        download=True)

dataloader = DataLoader(test_set, batch_size=1)

class MySeq(nn.Module):
    def __init__(self):
        super(MySeq, self).__init__()
        self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
                                 MaxPool2d(2),
                                 Flatten(),
                                 Linear(1024, 64),
                                 Linear(64, 10)
                                 )

    def forward(self, x):
        x = self.model1(x)
        return x

# 定义loss
loss = nn.CrossEntropyLoss()
# 搭建网络
myseq = MySeq()
print(myseq)
# 定义优化器
optmizer = optim.SGD(myseq.parameters(), lr=0.001, momentum=0.9)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        # print(imgs.shape)
        output = myseq(imgs)
        optmizer.zero_grad()  # 每轮训练将梯度初始化为0  上一次的梯度对本轮参数优化没有用
        result_loss = loss(output, targets)
        result_loss.backward()  # 优化器需要每个参数的梯度, 所以要在backward() 之后执行
        optmizer.step()  # 根据梯度对每个参数进行调优
        # print(result_loss)
        # print(result_loss.grad)
        # print("ok")
        running_loss += result_loss
    print(running_loss)

running log

loss由小变大最后到nan的解决办法:

  1. 降低学习率
  2. 使用正则化技术
  3. 增加训练数据
  4. 检查网络架构和激活函数

出现下面问题如何做反向优化?

Files already downloaded and verified
MySeq(
  (model1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)
tensor(18622.4551, grad_fn=<AddBackward0>)
tensor(16121.4092, grad_fn=<AddBackward0>)
tensor(15442.6416, grad_fn=<AddBackward0>)
tensor(16387.4531, grad_fn=<AddBackward0>)
tensor(18351.6152, grad_fn=<AddBackward0>)
tensor(20915.9785, grad_fn=<AddBackward0>)
tensor(23081.5254, grad_fn=<AddBackward0>)
tensor(24841.8359, grad_fn=<AddBackward0>)
tensor(25401.1602, grad_fn=<AddBackward0>)
tensor(26187.4961, grad_fn=<AddBackward0>)
tensor(28283.8633, grad_fn=<AddBackward0>)
tensor(30156.9316, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1183303.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Proteus仿真】【51单片机】水质监测报警系统设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器&#xff0c;使用按键、LED、蜂鸣器、LCD1602、PCF8591 ADC、PH传感器、浑浊度传感器、DS18B20温度传感器、继电器模块等。 主要功能&#xff1a; 系统运行后&…

链表经典面试题之一讲

什么是链表&#xff1f; 链表是一种物理存储结构上非连续、非顺序的存储结构&#xff0c;数据元素的逻辑顺序是通过链表中的指针链接次序实现的 。 今天给大家分享一道经典的单链表面试题 力扣题目——反转链表https://leetcode.cn/problems/reverse-linked-list/ 只给了头…

Wsl2 Ubuntu在不安装Docker Desktop情况下使用Docker

目录 1. 前提条件 2.安装Distrod 3. 常见问题 3.1.docker compose 问题无法使用问题 3.1. docker-compose up报错 参考文档 1. 前提条件 win10 WSL2 Ubuntu(截止202308最新版本是20.04.xx) 有不少的博客都是建议直接安装docker desktop&#xff0c;这样无论在windows…

实体属性映射框架mapstruct

1. 框架介绍 mapstruct框架是一种实体类间的映射框架&#xff0c;能够通过JAVA注解的形式将一个实体类的属性安全的赋值给另一个实体类。通过一系列注解可以定义实体类属性之间的映射关系&#xff0c;mapstruct会在编译期间生成映射实现类&#xff0c;而非通过反射的方式进行实…

Dcoker学习笔记(一)

Dcoker学习笔记一 一、 初识Docker1.1 简介1.2 虚拟机和docker的区别1.3 Docker架构1.4 安装Docker&#xff08;Linux&#xff09; 二、 Dcoker基本操作2.1 镜像操作2.2 容器操作练习 2.3 数据卷volume&#xff08;容器数据管理&#xff09;简介数据卷语法数据卷挂载 2.4 自定义…

【Git】Git基础命令操作速记

【Git】Git基础命令操作速记 文章目录 【Git】Git基础命令操作速记1. 初始化1.1 设置用户名和邮箱1.2 初始化仓库 2. 基础命令2.1 add和commit2.2 reset2.3 查看日志2.4 删除/找回本地仓库文件2.5 找回暂存区文件2.6 diff命令(找不同) 3. 分支命令3.1 查看分支3.2 创建分支3.3 …

深度学习之基于YoloV5的火灾检测系统

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 火灾检测系统基于YoloV5的介绍 火灾检测是一项重要的安全任务&#xff0c;它旨在及时发现和报警火灾风险。基于深度…

postgresql|数据库|提升查询性能的物化视图解析

前言&#xff1a; 我们一般认为数字的世界是一个虚拟的世界&#xff0c;OK&#xff0c;但我们其实有些需求是和现实世界一模一样的&#xff0c;比如&#xff0c;数据库尤其是关系型数据库&#xff0c;希望在使用的数据库能够更快&#xff08;查询速度&#xff09;&#xff0c;…

《008.Springboot+vue之自习室选座系统》

[火]《008.Springbootvue之自习室选座系统》 项目简介 [1]本系统涉及到的技术主要如下&#xff1a; 推荐环境配置&#xff1a;DEA jdk1.8 Maven MySQL 前后端分离; 后台&#xff1a;SpringBootMybatisredis; 前台&#xff1a;vueElementUI; [2]功能模块展示&#xff1a; 前端…

云端生成式 AI – 基于 Amazon EKS 的 Stable Diffusion 图像生成方案

Stable Diffusion 是当下生成式 AI 领域最受欢迎的开源多模态语言-图像模型&#xff0c;由于其易用的接口和良好的使用体验&#xff0c;受到了开源社区和广大设计行业从业者的追捧。Stable Diffusion 模型版本正在快速迭代&#xff0c;并带动了各行各业的生产力变革。目前市场上…

vmware16.1.2安装 windows7后 VMVMware tools 灰色 需要手动安装操作详情

问题1&#xff1a; 问题描述&#xff1a; 在Windows7镜像安装完成后&#xff0c;安装"VMware Tools"时出现&#xff1a;安装程序无法继续&#xff0c;需要将操作系统更新到SP1.2 重新安装后也没办法解决。 证明问题没有出在操作系统上&#xff1b;那么&#xff0c…

RISC Zero的Bonsai证明服务

1. 引言 Bonsai为通用ZKP网络&#xff0c;其支持任意链、任意协议、以及任意应用&#xff0c;利用ZKP来扩容、隐私和互操作。Bonsai的目标是为每条链都提供无限计算的能力。 借助Bonsai&#xff0c;可仅需数天的开发&#xff0c;即可实现对以太坊、L1链、Cosmos app链、L2 ro…

Mybatis-Plus使用Wrapper自定义SQL

文章目录 准备工作Mybatis-Plus使用Wrapper自定义SQL注意事项目录结构如下所示domain层Controller层Service层ServiceImplMapper层UserMapper.xml 结果如下所示&#xff1a;单表查询条件构造器单表查询&#xff0c;Mybatis-Plus使用Wrapper自定义SQL联表查询不用&#xff0c;My…

Java进击框架:Spring-数据存取(七)

Java进击框架&#xff1a;Spring-数据存取&#xff08;七&#xff09; 前言事务管理声明式事务管理 DAO支持JDBC的数据访问使用JdbcTemplate控制数据库连接JDBC批处理操作封装 SQL 语句中的参数 使用R2DBC进行数据访问对象关系映射(ORM)数据访问HibernateJPA XML模式 前言 参考…

目标检测算法 - YOLOv1

文章目录 1. 作者简介2. 目标检测综述3. YOLOv1算法3.1 预测阶段3.2 预测阶段后处理3.3 训练阶段 YOLO的全称是you only look once&#xff0c;指只需要浏览一次就可以识别出图中的物体的类别和位置。 YOLO是目标检测模型。目标检测是计算机视觉中比较简单的任务&#xff0c;用…

10-26 maven配置

打开idea 打开setting 基于Idea创建idea项目 加载jar包&#xff1a;(一般需要自己去手动加入&#xff0c;本地仓库是没有的)

【HarmonyOS】HarmonyOS备案获取公钥和指纹

【关键字】 HarmonyOS应用、鸿蒙应用、元服务、应用备案 HarmonyOS应用在华为云等平台进行应用备案时&#xff0c;平台需要提供用公钥和签名指纹的信息&#xff0c;Android可以直接通过keystore或jks签名文件进行签名信息获取&#xff0c;HarmonyOS签名方式与Android不同&…

LangChain之关于RetrievalQA input_variables 的定义与使用

最近在使用LangChain来做一个LLMs和KBs结合的小Demo玩玩&#xff0c;也就是RAG&#xff08;Retrieval Augmented Generation&#xff09;。 这部分的内容其实在LangChain的官网已经给出了流程图。 我这里就直接偷懒了&#xff0c;准备对Webui的项目进行复刻练习&#xff0c;那么…

Spring Cloud - 手写 Gateway 源码,实现自定义局部 FilterFactory

目录 一、FilterFactory 分析 1.1、前置知识 1.2、分析源码 1.2.1、整体分析 1.2.2、源码分析 1.3、手写源码 1.3.1、基础框架 1.3.2、实现自定义局部过滤器 1.3.3、加参数的自定义局部过滤器器 一、FilterFactory 分析 1.1、前置知识 前面的学习我们知道&#xff0c…

云服务器搭建flink集群

文章目录 1.集群配置2.修改集群配置3. 访问Web UI4. 提交作业方式5.Yarn部署模式配置5.1 会话模式部署&#xff08;Session Mode&#xff09;5.2 单作业模式(Per-job Mode)5.3 应用模式部署&#xff08;推荐&#xff09;5.3.1 上传HDFS提交&#xff08;推荐&#xff09; 5.4 历…