神经网络 torch.nn---优化器的使用

news2024/11/28 3:07:30

torch.optim - PyTorch中文文档 (pytorch-cn.readthedocs.io)

torch.optim — PyTorch 2.3 documentation

反向传播可以求出神经网路中每个需要调节参数的梯度(grad),优化器可以根据梯度进行调整,达到降低整体误差的作用。下面我们对优化器进行介绍。

构造优化器

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
  • 选择优化器的算法optim.SGD
  • 之后在优化器中放入模型参数model.parameters(),这一步是必备
  • 还可在函数中设置一些参数,如学习速率lr=0.01(这是每个优化器中几乎都会有的参数)

调用优化器中的step方法

step()方法就是利用我们之前获得的梯度,对神经网络中的参数进行更新。

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
  • 步骤optimizer.zero_grad()是必选的。对每个参数的梯度进行清零

  • 我们的输入经过了模型,并得到了输出output

  • 之后计算输出和target之间的误差loss

  • 调用误差的反向传播loss.backwrd更新每个参数对应的梯度

  • 调用optimizer.step()对卷积核中的参数进行优化调整

  • 之后继续进入for循环,使用函数optimizer.zero_grad()对每个参数的梯度进行清零,防止上一轮循环中计算出来的梯度影响下一轮循环。

优化器的使用

优化器中算法共有的参数(其他参数因优化器的算法而异):

  • params: 传入优化器模型中的参数

  • lr: learning rate,即学习速率

关于学习速率

  • 一般来说,学习速率设置得太大,模型运行起来会不稳定

  • 学习速率设置得太小,模型训练起来会过

  • 建议在最开始训练模型的时候,选择设置一个较大的学习速率;训练到后面的时候,再选择一个较小的学习速率

代码实战:

选择CIFAR10

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear, Sequential, Flatten
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=1)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = Sequential(
            Conv2d(3, 32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(32, 32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(32, 64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),
            Linear(64 * 4 * 4, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        # print(outputs)
        # print(targets)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        # print(result_loss)
        running_loss = running_loss + result_loss

    print(running_loss)

输出:

总结使用优化器训练的训练套路):

  • 设置损失函数loss function

  • 定义优化器optim

  • 从使用循环dataloader中的数据:for data in dataloder

    • 取出图片imgs,标签targets:imgs,targets=data

    • 将图片放入神经网络,并得到一个输出:output=model(imgs)

    • 计算误差:loss_result=loss(output,targets)

    • 使用优化器,初始化参数的梯度为0:optim.zero_grad()

    • 使用反向传播求出梯度:loss_result.backward()

    • 根据梯度,对每一个参数进行更新:optim.step()

  • 进入下一个循环,直到完成训练所需的循环次数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1797745.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

通过龙讯旷腾PWmat发《The Journal of Chemical Physics》 :基于第一性原理分子动力学热力学积分的离子溶剂化自由能计算

背景导读 离子溶解是电化学中一个重要的过程。电化学反应中许多重要的参数,例如电化学还原电位、无限稀释活度系数、亨利定律溶解常数和离子溶解度等,都与离子的溶剂化能有关。然而,由于测量技术和数据处理的困难,离子溶剂化能的…

LabVIEW与Arm控制器之间的通讯

LabVIEW是一个强大的图形化编程环境,广泛应用于自动化控制、数据采集和测试测量等领域。而Arm控制器则是嵌入式系统中常用的处理器架构,广泛用于各种控制和计算任务。将LabVIEW与Arm控制器进行通讯控制,可以结合二者的优势,实现高…

笔记96:前馈控制 + 航向误差

1. 回顾 对于一个 系统而言,结构可以画作: 如果采用 这样的控制策略,结构可以画作:(这就是LQR控制) 使用LQR控制器,可以通过公式 和 构建一个完美的负反馈系统; a a 但是有上…

学习笔记——网络参考模型——TCP/IP模型(网络层)

三、TCP/IP模型-网络层 1、IPV4报头 (1)IPV4报文格式 IP Packet(IP数据包),其包头主要内容如下∶ Version版本∶4 bit,4∶表示为IPv4; 6∶表示为IPv6。 Header Length首部长度∶4 bit,代表IP报头的长度(首部长度),如果不带Opt…

安卓手机平板使用JuiceSSH无公网IP远程连接本地服务器详细流程

文章目录 前言1. Linux安装cpolar2. 创建公网SSH连接地址3. JuiceSSH公网远程连接4. 固定连接SSH公网地址5. SSH固定地址连接测试 前言 处于内网的虚拟机如何被外网访问呢?如何手机就能访问虚拟机呢? 本文就和大家分享一下如何使用 cpolarJuiceSSH 实现手机端远程连接Linux…

conntrack如何限制您的k8s网关

1.1 conntrack 介绍 对于那些不熟悉的人来说,conntrack简单来说是Linux内核的一个子系统,它跟踪所有进入、出去或通过系统的网络连接,允许它监控和管理每个连接的状态,这对于诸如NAT(网络地址转换)、防火墙和保持会话连续性等任务至关重要。它作为Netfilter的一部分运行,…

Ubuntu的启动过程

尽管通常情况下Ubuntu的启动并不需要用户过多地参与,但是Ubuntu系统的启动本身是一个非常复杂的过程。在这个过程中,有硬件的检测、系统内核的准备以及各种系统服务的启动等。作为系统管理员,需要深入了解其中所经历的阶段,才能在…

Cesium开发环境搭建(二)

由于win7搭建很费事,重新安装了OS,win10的。 记录一下,搭建步骤: 1.下载node.js。 百度搜索即可下载对应的版本。下载cesium。 2.安装node.js。 安装后,输入node -v,显示版本信息,表示安装…

【面试干货】索引的作用

【面试干货】索引的作用 1、索引的作用 💖The Begin💖点点关注,收藏不迷路💖 1、索引的作用 索引 可以协助 快速查询、更新数据库表中数据。 通过使用索引,数据库系统能够快速定位到符合查询条件的数据,提…

1.6T模块与DSP技术的演进

近日,光通信行业市场机构LightCounting在市场报告中指出,去年的模块供应商已经展示了首批1.6T光学模块的风采,而今年,DSP供应商更是着眼于第二代1.6T模块设计的未来。这些前沿技术的突破,不仅代表了数据传输速度的飞跃…

【set】集合总结

一、Set Set集合是Collection的子接口,代表一种集合,此种集合是元素不重复. 有两个常用实现类 HashSet 是元素不重复,无序,主要是指遍历顺序和插入顺序不一致 TreeSet 是元素不重复,排序 LinkedHashSet不常用 二、HashSet 1.1 介绍 HashSet是Set的实现类 底层是由哈希表实…

Python数据分析I

目录 注:简单起见,下文中"df"均写为"表名","函数"均写为"HS","属性"均写为"SX","范围"均写为"FW"。 1.数据分析常用开源库 注释…

物联网-高性能时序数据库QuestDB

高性能时序数据库QuestDB 开源地址:https://github.com/questdb/questdb 官网:https://questdb.io/ 当前 13.9k start 自带免费可视化管理界面 支持各种语言客户端 C & C .NET Go Java Node.js Python Rust 上手容易可兼容 Postgresql InfluxDB …

浅浅写一个Word、PowerPoint、Excel文档转PDF工具

前言 最近在搞知识库,需要把各种 Word、PowerPoint、Excel 文件转换成 PDF 文件,不然 Word 中的表格中的文字提取会出现一些问题;使用 Office 或者 WPS 将大量文件转换成 PDF 需要频繁重复打开文件,点击保存为PDF,然后…

Qt_C++ RFID网络读卡器Socket Udp通讯示例源码

本示例使用的设备&#xff1a; WIFI/TCP/UDP/HTTP协议RFID液显网络读卡器可二次开发语音播报POE-淘宝网 (taobao.com) #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow> #include <QHostInfo> #include <QNetworkInterface> #include <…

【机器学习】训练GNN图神经网络模型进行节点分类

1. 引言 1.1 图神经网络GNN概述 图神经网络&#xff08;Graph Neural Network&#xff0c;GNN&#xff09;是一种专门用于处理图结构数据的神经网络方法。它起源于2005年&#xff0c;当时Gori等人首次提出了GNN的概念&#xff0c;用于学习图中的节点特征以及它们之间的关系。…

Doris 少数SQL在Datagrip无法执行,而在DorisUI或程序调用可以执行的问题

问题&#xff1a;Doris 少数SQL在Datagrip无法执行&#xff0c;而在DorisUI或程序调用可以执行 解决&#xff1a;Datagrip 执行SQL切分异常&#xff0c;设置默认执行语句方式&#xff0c;将分句改为整句执行 但是 支持多SQL批量分开执行更好用

英伟达的数字孪生地球是什么

1 英伟达的数字孪生地球 Earth-2是一个全栈式开放平台&#xff0c;包含&#xff1a;ICON 和 IFS 等数值模型的物理模拟&#xff1b;多种机器学习模型&#xff0c;例如 FourCastNet、GraphCast 和通过 NVIDIA Modulus 实现的深度学习天气预测 (DLWP)&#xff1b;以及通过 NVIDI…

大学电工基础与电子设计试题及答案,分享几个实用搜题和学习工具 #其他#经验分享

学习和考试是大学生生活中不可避免的一部分&#xff0c;而在这个信息爆炸的时代&#xff0c;如何快速有效地获取学习资源和解答问题成为了大学生们共同面临的难题。为了解决这个问题&#xff0c;搜题和学习软件应运而生。今天&#xff0c;我将为大家介绍几款备受大学生青睐的搜…

Vue进阶之Vue无代码可视化项目(一)

Vue无代码可视化项目 项目搭建初始步骤拓展:工程项目从0-1项目规范化package.jsoncpell.jsoncustom-words.txtts-eslint规则.eslintrc.cjsgit钩子检查有没有问题type-checkspellchecklint:stylehusky操作安装pre-commitpnpm的commit规范package.json:commitlint.config.cjs安装…