DLA 神经网络的极限训练方法:gradient checkpointing

news2024/9/22 21:23:29

gradient checkpointing

        一般来说,训练的过程需要保存中间结果(不管是GPU还是CPU)。前向传播根据输入(bottom_data)计算输出(top_data),后向传播由top_diff计算bottom_diff(如果某个变量打开梯度进行训练的话)。top和bottom是包含数据和梯度的两个结构体,整个网络的每层top和bottom在训练的过程中都会保存,这消耗了大量的内存。

        如果不保存这些变量,每次传播时重新分配和计算,会大大减少内存的使用量,但是也会使得网络的训练时间无限延长。为了平衡这两个矛盾,论文Training Deep Nets with Sublinear Memory Cost 使用亚线性内存成本训练深度网络:我们提出了一种系统方法来减少深度的内存消耗 神经网络训练。具体来说,我们设计了一种成本高昂的算法 O(sqrt(n)) 内存来训练 n 层网络,只需计算成本 每个小批量的额外前向传递。每隔 sqrt(n)保留一个检查点的feature map。

CODE

  • https://pytorch.org/docs/stable/checkpoint.html
// https://discuss.pytorch.org/t/trying-to-understand-torch-utils-checkpoint/95224
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.notebook import tqdm

from torch import optim
import torchvision.models as models
from torch import nn

CHECKPOINT = True
BATCH_SIZE = 32
dev = "cuda:0"

class ImageDataset(Dataset):
    def __init__(self,length = 100000,size = 244):
        self.length = length
        self.size = 244
    def __len__(self):
        return self.length
    def __getitem__(self,idx,display = False):
        return torch.from_numpy(np.random.randn(2,3,self.size,self.size))
train = ImageDataset()
trainloader = DataLoader(
    train,
    batch_size = BATCH_SIZE,
    num_workers = 24,
    pin_memory = True
)

resnet = models.resnet50(pretrained = False)

class MODEL(nn.Module):
    def __init__(self,model):
        super(MODEL,self).__init__()
        self.model = model
        self.LR = nn.Linear(1000,1000)
    def forward(self,x):
        if CHECKPOINT == False:
            o1 = self.model(x[:,0])
            o2 = self.model(x[:,1])
        else:
            o1 = torch.utils.checkpoint.checkpoint(self.model,x[:,0])
            o2 = torch.utils.checkpoint.checkpoint(self.model,x[:,1])
        
        return torch.mean((self.LR(o1)-o2)**2)
    
resnet = MODEL(resnet).to(dev)

optimizer = optim.SGD(resnet.parameters(),lr = .001)

for T in tqdm(trainloader):
    out = torch.mean(resnet(T.float().to(dev)))
    optimizer.zero_grad()
    out.backward()
    optimizer.step()

CG

在这里插入图片描述

  • https://github.com/merrymercy/dtr-prototype

ZeRO-Offload

  • https://arxiv.org/pdf/2101.06840.pdf 大规模模型训练一直是少数人的比赛场地 需要复杂的模型重构和访问昂贵的 GPU 集群。ZeRO-Offload 通过使 几乎每个人都可以访问大型模型训练。它可以训练模型 单个 GPU 上超过 13 亿个参数,与 GPU 相比,大小增加了 10 倍 流行的框架,如PyTorch,它不需要任何模型就可以做到这一点。 从数据科学家改变或牺牲计算效率。 ZeRO-卸载通过卸载数据和计算来实现大型模型训练 中央处理器。为了保持计算效率,它旨在最大限度地减少数据 移入/移出 GPU,减少 CPU 计算时间,同时最大化内存 节省 GPU 成本。因此,ZeRO-Offload可以在单个上实现40 TFlops / GPU。 NVIDIA V100 GPU 用于 10B 参数模型,与单独使用 PyTorch 的 30TF 相比 对于 1.4B 参数模型,可以训练而不会耗尽的最大参数模型 的记忆。ZeRO-Offload 还设计为在以下情况下在多个 GPU 上进行扩展 可用,可在多达 128 个 GPU 上提供近乎线性的加速。此外,它可以 与模型并行性协同工作,训练超过 70 亿的模型 单个 DGX-2 盒子上的参数,与模型尺寸相比增加了 4.5 倍 单独使用模型并行性。通过将计算和内存效率与 易于使用,ZeRO-Offload 使大规模模型训练民主化,使其成为 即使是数据科学家也可以访问,只需访问一个 GPU。

梯度累积

        训练时大的batch一般能得到更稳定的训练效果,梯度累积训练方法是一种用于训练深度神经网络的技术,旨在减少显存需求并提高训练效果。在传统的训练方法中,模型的参数是通过单个批次(batch)的数据计算得到的梯度平均值进行更新。但在梯度累积训练中,模型的参数更新是通过多个批次的梯度累积得到的。

以下是梯度累积训练的基本步骤:

  1. 设置梯度累积步数(accumulation steps),它决定了要累积多少个批次的梯度。

  2. 初始化模型的参数。

  3. 对于每个训练批次(batch):

    • 使用当前批次的数据进行前向传播计算损失。
    • 对损失进行反向传播计算梯度。
    • 累积当前批次的梯度到之前的梯度值上。
  4. 当累积达到设置的步数时,将累积的梯度应用于模型参数的更新:

    • 通过将累积的梯度平均化来获得参数的更新值。
    • 使用更新值来更新模型的参数。
  5. 重复步骤3和4,直到完成所有的训练批次。

梯度累积训练的主要优势在于能够降低每个批次所需的显存量,允许在具有有限显存的硬件上训练更大的模型。此外,梯度累积还可以改善模型的收敛性,提高模型的性能和泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/839930.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AcWing 4310:树的DFS ← vector、auto、邻接表

【题目来源】https://www.acwing.com/problem/content/description/4313/【题目描述】 给定一棵 n 个节点的树。 节点的编号为 1∼n,其中 1 号节点为根节点,每个节点的编号都大于其父节点的编号。 现在,你需要回答 q 个询问。 每个询问给定两…

Godot 4 源码分析 - Path2D与PathFollow2D

学习演示项目dodge_the_creeps,发现里面多了一个Path2D与PathFollow2D 研究GDScript代码发现,它主要用于随机生成Mob var mob_spawn_location get_node(^"MobPath/MobSpawnLocation")mob_spawn_location.progress randi()# Set the mobs dir…

面试热题(x的平方根)

给你一个非负整数 x ,计算并返回 x 的 算术平方根 。 由于返回类型是整数,结果只保留 整数部分 ,小数部分将被 舍去 。 注意:不允许使用任何内置指数函数和算符,例如 pow(x, 0.5) 或者 x ** 0.5 。 这道题虽然是简单题…

Laravel 5 报错信息存在严重漏洞

靠自己生活,灵魂都是安宁的。 简介 Laravel是一套简洁、优雅的PHPweb开发程序框架,并且具有简洁的表达,是一个比较容易理解且强大的,它提供了强大的工具用以开发大型网站的应用。 漏洞复现 使用Whoops错误库来显示\Whoops\Han…

mysql报错:name ‘_mysql‘ is not defined

原因是: Mysqldb 不兼容 python3.5 以后的版本 解决办法: 使用pymysql代替MySQLdb 在项目应用下的__init__.py 添加上去 import pymysqlpymysql.version_info (1, 4, 13, "final", 0) pymysql.install_as_MySQLdb()

找不到org.apache.http.annotation.NotThreadSafe的类文件

问题现象 最近在做一个调用包含请求体的GET请求功能的时候,引用了 httpclient-4.5.2.jar 和 httpcore-4.4.5.jar。 在工程编译环节报错如下: 原因分析 httpcore 和 httpclient 的版本不匹配。 解决方案 将 httpcore 的版本由 4.4.5 改为 4.4.4&am…

常用工具的常用操作1

写在前面 记录可能用到的各种工具常见技巧。 1:sublime 1.1:操作多列 首先选中要操作的列所在的行: 然后点击selection,spit lines: 接下来移动左右键就可以操作了,删除或者批量添加内容: 1…

客户流失分析预测案例 -- 机器学习项目基础篇(7)

客户流失 它是指现有的客户、用户、订阅者或任何类型的回头客停止与公司开展业务或结束与公司的关系。 客户流失的类型 合同客户流失:当客户签订了服务合同并决定取消服务时,例如有线电视,SaaS。自愿流失:当用户自愿取消服务时…

QtWebApp同时开启http服务和https服务,接受来自客户端的不同请求并进行相应的处理

零、前言 在 QtWebApp开发https服务器,完成客户端与服务器基于ssl的双向认证,纯代码操作 一文中已经用纯代码的形式完成了客户端和服务端的 https 协议交互。 不过,只是开放了https服务,更多情况下,http服务和https服…

解决 Android Studio 的 Gradle 面板上只有关于测试的 task

文章目录 问题描述解决办法 笔者出问题时的运行环境: Android Studio Flamingo | 2022.2.1 Android SDK 33 Gradle 8.0.1 JDK 17 问题描述 笔者最近发现一个奇怪的事情。笔者的 Android Studio 的 Gradle 面板上居然除了用于测试的 task 之外,其它什…

Spring Security6入门及自定义登录

一、前言 Spring Security已经更新到了6.x,通过本专栏记录以下Spring Security6学习过程,当然大家可参考Spring Security5专栏对比学习 Spring Security5专栏地址:security5 Spring Security是spring家族产品中的一个安全框架,核心功能包括…

驱动工作原理

驱动原理 在Linux操作系统中,硬件驱动程序中实现对硬件直接操作,而用户空间,通过通用的系统调用接口(open() 打开相应的驱动设备,ioctl()控制相应的功能等),实现对硬件操作,应用程序没有直接操作…

百度Apollo规划算法——OBB障碍物检测代码解析

百度Apollo规划算法——Box障碍物检测代码解析 前言代码代码分析f1f2f3f4f5f6 参考 前言 本文主要分析Apollo代码中函数bool Box::HasOverlap(const Box2d &box) const {}的数学原理。 在阅读此部分代码时,第一遍没看懂return的一堆什么意思,百度之后…

work weekly

每周汇报:围绕着项目范围及需求内容完成情况多少、人力资源情况、整体进度情况、成本情况、【范围】多少工作、【资源】投入多少人、【时间】花费多少时间、【成本】花了多少钱 【质量】一般没有特别要求的默认软件开发过程规范要求响应时间 【沟通】这里不说了 …

31 对集合中的字符串,按照长度降序排列

思路&#xff1a;使用集合的sort方法&#xff0c;新建一个Comparator接口&#xff0c;泛型是<String>&#xff0c;重写里面的compare方法。 package jiang.com; import java.util.Arrays; import java.util.Comparator; import java.util.List;public class Practice4 {…

ppt压缩文件怎么压缩最小?文件压缩技巧分享

在日常的工作和学习中&#xff0c;难免会遇到PPT太大&#xff0c;需要将其压缩变小的情况&#xff0c;但很多朋友还不知道怎么压缩PPT文件&#xff0c;下面就给大家分享几个简单的方法&#xff0c;分分钟缩小过大的PPT文件。 一、PowerPoint PowerPoint就是微软公司的演示文稿…

微信小程序的自定义TabBar及Vant的使用

一、安装Vant 1、在 资源管理器 空白位置&#xff0c;点右键打开 在外部终端窗口打开 2、初始化NPM npm init -y 3、安装命令 npm i vant/weapp1.3.3 -S --production 4、构建NPM包 在 工具 里选择构建NPM包 5、删除style:v2 在app.json里&#xff0c;删除"style"…

【中断机制】什么是中断?使用中断的原因、注意事项

目录 一、为什么需要中断 二、什么是中断 1、中断的概念 2、中断的分类 3、中断的处理流程 三、中断处理程序要少用延时的原因 一、为什么需要中断 以网卡为例&#xff0c;CPU 如果要从网卡获取数据&#xff0c;不可能时时盯着网卡啥时候会有数据。当网卡收到数据时&…

炼钢工艺流程(2)

1. 轧制单元 更换前后两个工作辊之间的轧制对象称为轧制单元&#xff0c;对应一个轧制计划。两个 支撑辊之间的轧制对象是由多个轧制单元组成&#xff0c;称为轧制单元组&#xff0c;对应多个轧制计 划。 轧制单元的结构 每个计划开始的部分板坯按照宽度非减的方向排列来加热轧…

Linux中安装Node

安装 先从 官方网站 下载安装包&#xff0c;有时 node 版本太新会导致失败&#xff0c;详见下方的常见问题第2点 cd /home // 创建目录&#xff0c;将下载好的 node 安装包上传到此目录 mkdir Download mkdir /usr/local/lib/node解压 // 解压&#xff0c;前面是文件当前路径…