Pytorch梯度累积实现

news2024/11/24 12:24:19

前言

主要用于解决显卡内存不足的问题。
梯度累积可以使用单卡实现增大batchsize的效果

梯度累积原理

按顺序执行Mini-Batch,同时对梯度进行累积,累积的结果在最后一个Mini-Batch计算后求平均更新模型变量。
a c c u m u l a t e d = ∑ i = 0 N g r a d i \color{green}accumulated=\sum_{i=0}^{N}grad_{i} accumulated=i=0Ngradi

梯度累积是一种训练神经网络的数据Sample样本按Batch拆分为几个小Batch的方式,然后按顺序计算。
在不更新模型变量的时候,实际上是把原来的数据Batch分成几个小的Mini-Batch,每个step中使用的样本实际上是更小的数据集。
在N个step内不更新变量,使所有Mini-Batch使用相同的模型变量来计算梯度,以确保计算出来得到相同的梯度和权重信息,算法上等价于使用原来没有切分的Batch size大小一样。即:
θ i = θ i − 1 − l r ∗ ∑ i = 0 N g r a d i \color{green}\theta _{i}=\theta _{i-1}-lr*\sum_{i=0}^{N}grad_{i} θi=θi1lri=0Ngradi
在这里插入图片描述

代码实现

不加梯度累加的代码

for i, (images, labels) in enumerate(train_data):
    # 1. forwared 前向计算
    outputs = model(images)
    loss = criterion(outputs, labels)

    # 2. backward 反向传播计算梯度
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

加了梯度累加的代码

# 梯度累加参数
accumulation_steps = 4


for i, (images, labels) in enumerate(train_data):
    # 1. forwared 前向计算
    outputs = model(imgaes)
    loss = criterion(outputs, labels)

    # 2.1 loss regularization loss正则化
    loss += loss / accumulation_steps

    # 2.2 backward propagation 反向传播计算梯度
    loss.backward()

    # 3. update parameters of net
    if ((i+1) % accumulation)==0:
        # optimizer the net
        optimizer.step()
        optimizer.zero_grad() # reset grdient

代码中设置accumulation_steps = 4,意思就是变相扩大batch_size四倍。因为代码中每隔4次迭代才清空梯度,更新参数。
loss = loss/accumulation_steps,梯度累加了四次,那就要取平均除以4。同时,因为累计了4个batch,那学习率也应该扩大4倍,让更新的步子跨大点。
参考博客:1、pytorch骚操作之梯度累加,变相增大batch size
2、如何通透理解梯度累加

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1044172.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

nat综合实验

路漫漫其修远兮,吾将上下而求索。 实验目的如图 实验思路:配置内网,再配置外网,再做nat clien1配置 clien2配置 pc3配置 lsw1配置 sysname lsw1 # vlan batch 10 20 30 # interface MEth0/0/1 # interface Eth-Trunk1port link-type trunkp…

【Linux】IO操作

IO 典型 IO 模型阻塞 IO非阻塞 IO信号驱动 IO异步 IO常见问题 多路转接模型select 模型poll 模型epoll 模型 典型 IO 模型 IO 操作指的就是数据的输入输出操作;IO 过程可以分为两个步骤:等待 IO 就绪、数据拷贝 阻塞 IO 发起 IO 操作,若当…

【面试高高手】 —— Java基础(36题)

文章目录 1. 八大基本数据类型分类2. 重写和重载的区别3. int和integer区别4. Java的关键字5. 什么是自动装箱和拆箱?6. 什么是Java的多态性?7. 接口和抽象类的区别?8. Java中如何处理异常?9. Java中的final关键字有什么作用&…

iview 的table表格组件使单元格可编辑和输入

表格的列定义中&#xff0c;在需要编辑的字段下使用render函数 template表格组件 <Table border :data"data" :columns"tableColumns" :loading"loading"></Table>data中定义table对象 table: {tableColumns: [{title: 商品序号,k…

服务断路器_Resilience4j的断路器

断路器&#xff08;CircuitBreaker&#xff09;相对于前面几个熔断机制更复杂&#xff0c;CircuitBreaker通常存在三种状态&#xff08;CLOSE、OPEN、HALF_OPEN&#xff09;&#xff0c;并通过一个时间或数量窗口来记录当前的请求成功率或慢速率&#xff0c;从而根据这些指标来…

【JVM】第三篇 JVM对象创建与内存分配机制深度剖析

目录 一. JVM对象创建过程详解1. 类加载检查2. 分配内存2.1 如何划分内存?2.2 并发问题 3. 初始化4. 设置对象头5. 执行<init>方法 二. 对象头和指针压缩详解三. JVM对象内存分配详解四.逃逸分析 & 栈上分配 & 标量替换详解1. 逃逸分析 & 栈上分配2. 标量替…

用纹理图集优化3D场景性能【Texture Atlas】

推荐&#xff1a;用 NSDT编辑器 快速搭建可编程3D场景 在 Unity 中开发移动应用程序时&#xff0c;确保一切都得到优化始终至关重要。 最大化帧速率使我们能够专注于优化脚本、烘焙灯光、修改对象等。 当我们将移动应用程序带入虚拟现实时&#xff0c;这一点变得更加重要。 虽…

嵌入式Linux应用开发-文件 IO

嵌入式Linux应用开发-文件 IO 第四章 文件 IO4.1 文件从哪来&#xff1f;4.2 怎么访问文件&#xff1f;4.2.1 通用的 IO 模型&#xff1a;open/read/write/lseek/close4.2.2 不是通用的函数&#xff1a;ioctl/mmap 4.3 怎么知道这些函数的用法&#xff1f;4.4 系统调用函数怎么…

基于微信小程序的健身小助手打卡预约教学系统(源码+lw+部署文档+讲解等)

文章目录 前言系统主要功能&#xff1a;用户的功能设计为&#xff1a;管理员的功能设计为&#xff1a;健身房的功能设计为&#xff1a;具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获…

QFrame类学习笔记

1、QFrame的作用 QFrame类继承于QWidget类&#xff0c;被QAbstractScrollArea, QLabel, QLCDNumber, QSplitter, QStackedWidget, and QToolBox等类继承。 QFrame作为许多基础控件的基类&#xff0c;提供许多成员方法给子类&#xff0c;实现子类的框架样式的设计。框架样式主要…

Android 13 定制化开发--开启相机或麦克风时,去掉状态栏上的绿色图标

Android 12 或更高版本的设备上&#xff0c;当应用使用麦克风或相机时&#xff0c;图标会出现在状态栏中。如果应用处于沉浸模式&#xff0c;图标会出现在屏幕的右上角。用户可以打开“快捷设置”&#xff0c;并选择图标以查看哪些应用当前正在使用麦克风或摄像头。图 1 显示了…

Ubuntu 安装Kafka

在本指南中&#xff0c;我们将逐步演示如何在 Ubuntu 22.04 上安装 Apache Kafka。 在大数据中&#xff0c;数以百万计的数据源生成了大量的数据记录流&#xff0c;这些数据源包括社交媒体平台、企业系统、移动应用程序和物联网设备等。如此庞大的数据带来的主要挑战有两个方面…

【数据结构】插入排序:直接插入排序、折半插入排序、希尔排序的学习知识总结

目录 1、排序的基本概念 2、直接插入排序 2.1 算法思想 2.2 代码实现 3、折半插入排序 3.1 算法思想 3.2 代码实现 4、希尔排序 4.1 算法思想 4..2 代码实现 1、排序的基本概念 排序是将一组数据按照预定的顺序排列的过程&#xff0c;排序的基本概念包括以下内容…

自学WEB后端01-安装Express+Node.js框架完成Hello World!

一、前言&#xff0c;网站开发扫盲知识 1.网站搭建开发包括什么&#xff1f; 前端 前端开发主要涉及用户界面&#xff08;UI&#xff09;和用户体验&#xff08;UX&#xff09;&#xff0c;负责实现网站的外观和交互逻辑。前端开发使用HTML、CSS和JavaScript等技术来构建网页…

数据结构--快速排序

文章目录 快速排序的概念Hoare版本挖坑法前后指针法快速排序的优化三数取中法小区间用插入排序 非递归的快速排序 快速排序的概念 快速排序是通过二叉树的思想&#xff0c;先设定一个值&#xff0c;通过比较&#xff0c;比它大的放在它的右边&#xff0c;比它小的放在它的左边…

Python中的数据常见问题

数据可视化在Python中是一个非常重要的主题&#xff0c;它可以帮助我们更好地理解和分析数据。无论是探索数据的特征&#xff0c;还是向其他人展示数据的结果&#xff0c;数据可视化都起到了关键作用。然而&#xff0c;在进行数据可视化时可能会遇到一些常见问题。本文将为您分…

基于微信小程序的同城家政服务预约系统(源码+lw+部署文档+讲解等)

文章目录 前言系统主要功能&#xff1a;具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计…

DC电源模块关于宽电压输入和输出的范围

BOSHIDA DC电源模块关于宽电压输入和输出的范围 DC电源模块是一种电子设备&#xff0c;能够将输入的直流电源转换成所需的输出电源&#xff0c;用于供电各种电子设备。其中&#xff0c;关于宽电压输入和输出的范围&#xff0c;是DC电源模块常见的设计要求之一。本文将详细介绍…

嵌入式Linux应用开发-基础知识及GCC 编译器的使用

嵌入式Linux应用开发-基础知识及GCC 编译器的使用 第一章 HelloWorld 背后没那么简单1.1 交叉编译 hello.c1.2 请回答这几个问题1.3 演示 (...) 第二章 GCC 编译器的使用2.1 配套视频内容大纲2.1.1 GCC 编译过程(精简版)2.1.2 常用编译选项2.1.3 怎么编译多个文件2.1.4 制作、使…