pytorch之梯度累加

news2024/11/30 2:35:07

1.什么是梯度?

梯度可以理解为一个多变量函数的变化率,它告诉我们在某一点上,函数的输出如何随输入的变化而变化。更直观地说,梯度指示了最优化方向。

  • 在机器学习中的作用:在训练模型时,我们的目标是最小化损失函数,以提高模型的准确性。损失函数是衡量模型预测值与真实值之间差距的函数。梯度告诉我们如何调整模型参数,以使损失函数的值减小。

2. 模型参数的优化

考虑一个简单的线性模型:

y=wx+b

  • 其中,yy 是输出,xx 是输入,ww 是权重,bb 是偏置。
  • 为了训练模型,我们使用损失函数(例如均方误差)来衡量模型输出与真实输出之间的差距。损失函数通常定义为:

Loss=1/N∑i=1N(ypred,i−ytrue,i)^2

  • 这里 NN 是样本数,ypred是模型计算的预测值,ytrue是真实值。

3. 反向传播

反向传播是一种高效计算梯度的算法,尤其在深度学习中使用广泛。

3.1 前向传播

在前向传播中,我们将输入数据通过模型传递,计算出预测结果,并基于预测结果与真实结果计算损失。

3.2 计算梯度

反向传播通过链式法则计算梯度,以更新模型参数。通过反向传播,我们可以得到损失函数对每个参数(比如 ww 和 bb)的导数,这些导数就是梯度。

  • 链式法则:假设有一个复合函数 z=f(g(x)),则其导数为:

dz/dx=dz/dg⋅dg/dx

这个法则帮助我们逐层计算梯度。

4. 梯度的累加

4.1 为什么会累加?

在训练过程中,我们可能会处理多个训练样本进行参数更新。如果连续调用多次 loss.backward(),每次都会将计算的梯度值加到之前的梯度上。

  • 这意味着如果我们不清零梯度,梯度会随着样本数的增加而不断增加,可能导致参数更新的步幅变得非常大,影响模型的收敛。
4.2 样例代码

让我们通过具体的代码示例来更好地理解梯度的累加和为什么需要清零。

import torch
import torch.nn as nn
import torch.optim as optim

# 定义简单线性模型
model = nn.Linear(1, 1)  # 线性模型:1个输入,1个输出
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 使用 SGD 优化器

# 模拟一些数据
x = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=False)  # 输入
y = torch.tensor([[2.0], [3.0], [4.0]], requires_grad=False)  # 目标输出

# 训练循环
for epoch in range(5):  # 假设训练 5 轮
    for i in range(len(x)):  # 遍历每个训练样本
        optimizer.zero_grad()  # 清零梯度,确保只考虑当前样本

        # 前向传播
        output = model(x[i])  # 计算当前样本的输出

        # 计算损失
        loss = (output - y[i]) ** 2  # 均方误差损失

        # 反向传播
        loss.backward()  # 计算梯度,累加到 model 的参数中

        # 更新参数
        optimizer.step()  # 使用累加的梯度更新参数

        print(f"Epoch: {epoch}, Sample: {i}, Loss: {loss.item()}, W: {model.weight.data}, b: {model.bias.data}")

5. 源代码解释

  1. 清零梯度:在每次处理新的训练样本前,调用 optimizer.zero_grad() 清空梯度。这是为了确保每个训练样本只对当前的梯度产生影响。

  2. 前向传播:计算当前输入的输出。

  3. 损失计算:计算输出与真实值之间的差距。

  4. 反向传播:通过 loss.backward() 计算当前样本对模型参数的梯度并将其累加到 model.parameters() 的 grad 属性上。

  5. 参数更新:调用 optimizer.step() 进行参数更新。

在 PyTorch 中,梯度的累加是一种非常重要且实用的特性,其设计有几个原因:

6. 支持小批量(Mini-batch)训练

在实践中,由于计算资源的限制,通常使用小批量数据进行训练。这意味着我们不会一次性使用整个数据集来更新模型,而是对一小部分数据频繁进行计算。

  • 梯度累加允许我们在多个小批量上计算梯度,并在适当的时候一并更新模型参数。这种策略被称为“累积梯度”。

例如,如果我们有一个较大的数据集,可以将其分为多个小批量,然后在每个小批量上计算梯度。在所有小批量处理完成后,再进行一次参数更新。这种方法可以模拟使用更大批量数据的效果,提高模型的表现。

7. 提高训练灵活性

梯度累加允许用户在特定情况下有效地控制参数更新的频率。例如:

  • 如果处理每个样本时都立即更新权重,可能会导致训练过程不稳定。而通过在多个样本上累加梯度,可以缓解这种波动性,平滑参数的更新过程。

  • 用户可以决定什么情况下清零梯度,例如只有在处理完一个完整的训练周期(epoch)后,或在经历多个小批量后再更新一次参数。这种控制在很多情况下可以提高性能和收敛性。

8. 节省内存

对于一些深度学习模型,特别是当模型较大,或者在训练过程中使用大量数据时,清零梯度后再进行反向传播通常需要的内存较少。没有累加的梯度能够避免内存的额外消耗,进而提高整个训练过程的效率。

9. 灵活的梯度管理

开发者可以基于需求自定义梯度累加的策略。例如,有时我们可能希望实现一些特殊的训练策略,比如调整学习率、动态更改模型的训练方式等。在这些情况下,梯度的管理就显得至关重要。

10. 应用在不同的训练模式

在一些变种训练方式中,如强化学习或一些优化器的特殊需求,可能需要在更新权重前手动控制梯度。这对开发者提供了更大的灵活性和更丰富的训练策略。

数学推导

假设我们有一个线性回归模型,其数学表达式为:

y = W \cdot x + b

均方误差(MSE)损失函数:

\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2

梯度计算:

对于权重W

\frac{\partial \mathcal{L}}{\partial W} = \frac{2}{N} \sum_{i=1}^{N} x_i (\hat{y}_i - y_i)


对于偏置 b

\frac{\partial \mathcal{L}}{\partial b} = \frac{2}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i)

假设我们将数据集分成若干小批量,每个小批量包含 m个样本。我们累加这些梯度,并在累积一定数量的小批量k后更新参数。

对于第j个小批量,计算梯度:

对于权重W

\nabla W_j = \frac{2}{m} \sum_{i=1}^{m} x_i (\hat{y}_i - y_i)

对于偏置 b

\nabla b_j = \frac{2}{m} \sum_{i=1}^{m} (\hat{y}_i - y_i)

梯度累加

\nabla W_{\text{accumulated}} = \sum_{j=1}^{k} \nabla W_j

\nabla b_{\text{accumulated}} = \sum_{j=1}^{k} \nabla b_j

g更新参数

W \leftarrow W - \eta \cdot \nabla W_{\text{accumulated}}

b \leftarrow b - \eta \cdot \nabla b_{\text{accumulated}}

import torch
import torch.nn as nn
import torch.optim as optim

# 创建数据集
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0], [10.0]])

# 简单的线性回归模型
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

# 初始化模型、损失函数和优化器
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 定义小批量大小和累积步数
batch_size = 2
accumulation_steps = 2

# 训练过程
for epoch in range(5):
    optimizer.zero_grad()  # 清零梯度
    for i in range(0, len(x_train), batch_size):
        # 获取小批量数据
        x_batch = x_train[i:i + batch_size]
        y_batch = y_train[i:i + batch_size]

        # 前向传播
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)

        # 反向传播,累加梯度
        loss.backward()

        # 每处理完指定的累积步骤后,更新参数并清零梯度
        if (i // batch_size + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # 打印损失
    print(f'Epoch [{epoch + 1}/5], Loss: {loss.item():.4f}')

# 打印模型参数
print(f'Final Parameters: W: {model.linear.weight.item():.4f}, b: {model.linear.bias.item():.4f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2188918.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

day2网络编程项目的框架

基于终端的 UDP云聊天系统 开发环境 Linux 系统GCCUDPmakefilesqlite3 功能描述 通过 UDP 网络使服务器与客户端进行通信吗,从而实现云聊天。 Sqlite数据库 用户在加入聊天室前,需要先进行用户登录或注册操作,并将注册的用户信息&#xf…

P4、P4D、HelixSwarm 各种技术问题咨询

多年大型项目P4仓库运维经验,为你解决各种部署以及标准工业化流程问题。 Perforce 官网SDPHelixCore GuideHelixSwarm GuideHelixSwarm Download

SpringBoot基础(三):Logback日志

SpringBoot基础系列文章 SpringBoot基础(一):快速入门 SpringBoot基础(二):配置文件详解 SpringBoot基础(三):Logback日志 目录 一、日志依赖二、日志格式1、记录日志2、默认输出格式3、springboot默认日志配置 三、日志级别1、基础设置2、…

家长们,你们认为孩子沉迷游戏严重还是沉迷Linux严重呢

matrix禁食 ​ 计算机技术与软件专业技术资格证持证人 ​ 关注 谢邀 Hieronymus no-sh 218 人赞同了该回答 十年前,你还能得到一个自己能控制的计算机系统,现在,窗口期早走过了。普通人不懂软件,但因该懂人心啊,人心一…

使用Apifox创建接口文档,部署第一个简单的基于Vue+Axios的前端项目

前言 在当今软件开发的过程中,接口文档的创建至关重要,它不仅能够帮助开发人员更好地理解系统架构,还能确保前后端开发的有效协同。Apifox作为一款集API文档管理、接口调试、Mock数据模拟为一体的工具,能够大幅度提高开发效率。在…

武汉自闭症儿童寄宿学校:开启学习与成长的新篇章

武汉与广州的自闭症教育之光:星贝育园开启学习与成长新篇章 在自闭症儿童教育的广阔领域,寄宿学校以其独特的教育模式和全方位的关怀,为这些特殊孩子提供了学习、成长与融入社会的宝贵机会。虽然本文标题提及了武汉自闭症儿童寄宿学校&#…

【HTML+CSS】仿电子美学打造响应式留言板

创建一个响应式的留言板 在这篇文章中,我们将学习如何创建一个简单而美观的留言板,它将包括基本的样式和动画效果,以及响应式设计,确保在不同设备上都能良好显示。 HTML 结构 首先,我们创建基本的HTML结构。留言板由…

8646 基数排序

### 思路 基数排序是一种非比较型排序算法,通过逐位(从最低位到最高位)对数字进行排序。每次分配和收集后输出当前排序结果。 ### 伪代码 1. 读取输入的待排序关键字个数n。 2. 读取n个待排序关键字并存储在数组中。 3. 对数组进行基数排序&…

MinIO 在windows环境下载和安装

目录 1.MinIO(windows)下载链接: 2. 启动MinIO (1)直接启动MinIo (2)指定端口号启动MinIo 3.通过创建.bat文件帮助启动MinIO 1.MinIO(windows)下载链接:…

国外电商系统开发-运维系统批量添加服务器

您可以把您准备的txt文件,安装要求的格式,复制粘贴到里面就可以了。注意格式! 如果是“#” 开头的,则表示注释!

Python数据可视化--Matplotlib--入门

我生性自由散漫,不喜欢拘束。我谁也不爱,谁也不恨。我没有欺骗这个,追求那个;没有把这个取笑,那个玩弄。我有自己的消遣。 -- 塞万提斯 《堂吉诃德》 Matplotlib介绍 1. Matplotlib 是 Python 中常用的 2D 绘图库&a…

ArkTS语法

一、声明 格式:关键字 变量/常量名 : 类型注释 = 值 变量声明 let count : number = 0; count = 40; 常量声明 const MAX_COUNT : number = 100; 二、数据类型 基本数据类型:string、number、boolean等 引用数据类型:Object、Array、自定义类等 …

【笔记】选择题笔记+数据结构笔记

文章目录 2014 41方法一先序遍历方法二 连通分量是极大连通子图 一个连通图的生成树是一个极小连通子图 无向图的邻接表中,第i个顶点的度为第i个链表中的结点数 邻接表和邻接矩阵对不同的操作各有优势。 最短路径算法: 单源最短路径 已知图G(V,E),我们…

深入理解Linux内核网络(二):内核与用户进程的协作

内核在协议栈接收处理完输入包以后,要能通知到用户进程,让用户进程能够收到并处理这些数据。进程和内核配合有很多种方案,第一种是同步阻塞的方案,第二种是多路复用方案。本文以epoll为例 部分内容来源于 《深入理解Linux网络》、…

认知杂谈72《别让梦想只是梦!7步跃过现实高墙的终极攻略!》

内容摘要:         梦想的实现是一场与现实的较量,需要坚持和突破。学习路线图对于掌握技能至关重要,如学编程应从基础语法开始,逐步深入。 面对难题,积极搜索、提问和实践是关键。坚持和专注是成功的核心&#…

《Windows PE》4.1.3 IAT函数地址表

IAT(Import Address Table)表又称为函数地址表,是Windows可执行文件中的一个重要数据结构,用于存储导入函数的实际入口地址。 在可执行文件中,当一个模块需要调用另一个模块中的函数时,通常会使用导入函数…

十、敌人锁定

方法:通过寻找最近的敌人,使玩家的面朝向始终朝向敌人,进行攻击 1、代码 在这个方法中使用的是局部变量,作为临时声明和引用 public void SetActorAttackRotation() {Enemys GameObject.FindGameObjectsWithTag("Enemy&qu…

机器学习-树模型算法

机器学习-树模型算法 一、Bagging1.1 RF1.2 ET 二、Boosting2.1 GBDT2.2 XGB2.3 LGBM 仅个人笔记使用,感谢点赞关注 一、Bagging 1.1 RF 1.2 ET 二、Boosting 2.1 GBDT 2.2 XGB 2.3 LGBM LightGBM(Light Gradient Boosting Machine) 基本算法原理…

2024企业网盘排行榜,十大企业网盘深度评测【part 2】

在当今数字化时代,企业网盘已成为提升工作效率、保障数据安全的重要工具。从Box到腾讯企业网盘,再到Egnyte、Amazon Drive、金山文档(WPS)和Huddle,每款产品都有其独特的功能和应用场景。然而,在众多选择中…