在 PyTorch 中使用梯度检查点在GPU 上训练更大的模型

news2024/11/20 9:28:53

作为机器学习从业者,我们经常会遇到这样的情况,想要训练一个比较大的模型,而 GPU 却因为内存不足而无法训练它。当我们在出于安全原因不允许在云计算的环境中工作时,这个问题经常会出现。在这样的环境中,我们无法足够快地扩展或切换到功能强大的硬件并训练模型。并且由于梯度下降算法的性质,通常较大的批次在大多数模型中会产生更好的结果,但在大多数情况下,由于内存限制,我们必须使用适应GPU显存的批次大小。

本文将介绍解梯度检查点(Gradient Checkpointing),这是一种可以让你以增加训练时间为代价在 GPU 中训练大模型的技术。 我们将在 PyTorch 中实现它并训练分类器模型。

梯度检查点

在反向传播算法中,梯度计算从损失函数开始,计算后更新模型权重。图中每一步计算的所有导数或梯度都会被存储,直到计算出最终的更新梯度。这样做会消耗大量 GPU 内存。梯度检查点通过在需要时重新计算这些值和丢弃在进一步计算中不需要的先前值来节省内存。

让我们用下面的虚拟图来解释。

上面是一个计算图,每个叶节点上的数字相加得到最终输出。假设这个图表示反向传播期间发生的计算,那么每个节点的值都会被存储,这使得执行求和所需的总内存为7,因为有7个节点。但是我们可以用更少的内存。假设我们将1和2相加,并在下一个节点中将它们的值存储为3,然后删除这两个值。我们可以对4和5做同样的操作,将9作为加法的结果存储。3和9也可以用同样的方式操作,存储结果后删除它们。通过执行这些操作,在计算过程中所需的内存从7减少到3。

在没有梯度检查点的情况下,使用PyTorch训练分类模型

我们将使用PyTorch构建一个分类模型,并在不使用梯度检查点的情况下训练它。记录模型的不同指标,如训练所用的时间、内存消耗、准确性等。

由于我们主要关注GPU的内存消耗,所以在训练时需要检测每批的内存消耗。这里使用nvidia-ml-py3库,该库使用nvidia-smi命令来获取内存信息。

 pip install nvidia-ml-py3

为了简单起见,我们使用简单的狗和猫分类数据集的子集。

 git clone https://github.com/laxmimerit/dog-cat-full-dataset.git

执行上述命令后会在dog-cat-full-dataset的文件夹中得到完整的数据集。

导入所需的包并初始化nvdia-smi

 importtorch
 importtorch.nnasnn
 importtorch.optimasoptim
 importnumpyasnp
 fromtorchvisionimportdatasets, models, transforms
 importmatplotlib.pyplotasplt
 importtime
 importos
 importcv2
 importnvidia_smi
 importcopy
 fromPILimportImage
 fromtorch.utils.dataimportDataset,DataLoader
 importtorch.utils.checkpointascheckpoint
 fromtqdmimporttqdm
 importshutil
 fromtorch.utils.checkpointimportcheckpoint_sequential
 device="cuda"iftorch.cuda.is_available() else"cpu"
 %matplotlibinline
 importrandom
 
 nvidia_smi.nvmlInit()

导入训练和测试模型所需的所有包。我们还初始化nvidia-smi。

定义数据集和数据加载器

 #Define the dataset and the dataloader.
 train_dataset=datasets.ImageFolder(root="/content/dog-cat-full-dataset/data/train",
                             transform=transforms.Compose([
                                 transforms.RandomRotation(30),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.RandomResizedCrop(224, scale=(0.96, 1.0), ratio=(0.95, 1.05)),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                             ]))
 
 val_dataset=datasets.ImageFolder(root="/content/dog-cat-full-dataset/data/test",
                             transform=transforms.Compose([
                                 transforms.Resize([224, 224]),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                             ]))
 
 train_dataloader=DataLoader(train_dataset,
                             batch_size=64,
                             shuffle=True,
                             num_workers=2)
 
 val_dataloader=DataLoader(val_dataset,
                             batch_size=64,
                             shuffle=True,
                             num_workers=2)

这里我们用torchvision数据集的ImageFolder类定义数据集。还在数据集上定义了某些转换,如RandomRotation, RandomHorizontalFlip等。最后对图片进行归一化,并且设置batch_size=64

定义训练和测试函数

 deftrain_model(model,loss_func,optimizer,train_dataloader,val_dataloader,epochs=10):
 
     model.train()
     #Training loop.
     forepochinrange(epochs):
       model.train()
       forimages, targetintqdm(train_dataloader):
           images, target=images.to(device), target.to(device)
           images.requires_grad=True
           optimizer.zero_grad()
           output=model(images)
           loss=loss_func(output, target)
           loss.backward()
           optimizer.step()
       ifos.path.exists('grad_checkpoints/') isFalse:
         os.mkdir('grad_checkpoints')
       torch.save(model.state_dict(), 'grad_checkpoints/epoch_'+str(epoch)+'.pt')
 
 
       #Test the model on validation data.
       train_acc,train_loss=test_model(model,train_dataloader)
       val_acc,val_loss=test_model(model,val_dataloader)
 
       #Check memory usage.
       handle=nvidia_smi.nvmlDeviceGetHandleByIndex(0)
       info=nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
       memory_used=info.used
       memory_used=(memory_used/1024)/1024
 
       print(f"Epoch={epoch} Train Accuracy={train_acc} Train loss={train_loss} Validation accuracy={val_acc} Validation loss={val_loss} Memory used={memory_used} MB")
 
 
 
 deftest_model(model,val_dataloader):
   model.eval()
   test_loss=0
   correct=0
   withtorch.no_grad():
       forimages, targetinval_dataloader:
           images, target=images.to(device), target.to(device)
           output=model(images)
           test_loss+=loss_func(output, target).data.item()
           _, predicted=torch.max(output, 1)
           correct+= (predicted==target).sum().item()
   
   test_loss/=len(val_dataloader.dataset)
 
   returnint(correct/len(val_dataloader.dataset) *100),test_loss

上面创建了一个简单的训练和测试循环来训练模型。最后还通过调用nvidia-smi计算内存使用。

训练

 torch.manual_seed(0)
 
 #Learning rate.
 lr=0.003
 
 #Defining the VGG16 sequential model.
 vgg16=models.vgg16()
 vgg_layers_list=list(vgg16.children())[:-1]
 vgg_layers_list.append(nn.Flatten())
 vgg_layers_list.append(nn.Linear(25088,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,2))
 model=nn.Sequential(*vgg_layers_list)
 model=model.to(device)
 
 
 
 #Num of epochs to train
 num_epochs=10
 
 #Loss
 loss_func=nn.CrossEntropyLoss()
 
 # Optimizer 
 # optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
 optimizer=optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
 
 
 #Training the model.
 model=train_model(model, loss_func, optimizer,
                        train_dataloader,val_dataloader,num_epochs)

我们使用VGG16模型进行分类。下面是模型的训练日志。

可以从上面的日志中看到,在没有检查点的情况下,训练64个批大小的模型大约需要5分钟,占用内存为14222.125 mb。

使用带有梯度检查点的PyTorch训练分类模型

为了用梯度检查点训练模型,只需要编辑train_model函数。

 deftrain_with_grad_checkpointing(model,loss_func,optimizer,train_dataloader,val_dataloader,epochs=10):
 
 
     #Training loop.
     forepochinrange(epochs):
       model.train()
       forimages, targetintqdm(train_dataloader):
           images, target=images.to(device), target.to(device)
           images.requires_grad=True
           optimizer.zero_grad()
           #Applying gradient checkpointing
           segments=2
 
           # get the modules in the model. These modules should be in the order
           # the model should be executed
           modules= [modulefork, moduleinmodel._modules.items()]
 
           # now call the checkpoint API and get the output
           output=checkpoint_sequential(modules, segments, images)
           loss=loss_func(output, target)
           loss.backward()
           optimizer.step()
       ifos.path.exists('checkpoints/') isFalse:
         os.mkdir('checkpoints')
       torch.save(model.state_dict(), 'checkpoints/epoch_'+str(epoch)+'.pt')
 
 
       #Test the model on validation data.
       train_acc,train_loss=test_model(model,train_dataloader)
       val_acc,val_loss=test_model(model,val_dataloader)
 
       #Check memory.
       handle=nvidia_smi.nvmlDeviceGetHandleByIndex(0)
       info=nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
       memory_used=info.used
       memory_used=(memory_used/1024)/1024
 
       print(f"Epoch={epoch} Train Accuracy={train_acc} Train loss={train_loss} Validation accuracy={val_acc} Validation loss={val_loss} Memory used={memory_used} MB")
 
 deftest_model(model,val_dataloader):
   model.eval()
   test_loss=0
   correct=0
   withtorch.no_grad():
       forimages, targetinval_dataloader:
           images, target=images.to(device), target.to(device)
           output=model(images)
           test_loss+=loss_func(output, target).data.item()
           _, predicted=torch.max(output, 1)
           correct+= (predicted==target).sum().item()
   
   test_loss/=len(val_dataloader.dataset)
 
   returnint(correct/len(val_dataloader.dataset) *100),test_lossdeftest_model(model,val_dataloader)

我们将函数名修改为train_with_grad_checkpointing。也就是不通过模型(图)运行训练,而是使用checkpoint_sequential函数进行训练,该函数有三个输入:modules, segments, input。modules是神经网络层的列表,按它们执行的顺序排列。segments是在序列中创建的段的个数,使用梯度检查点进行训练以段为单位将输出用于重新计算反向传播期间的梯度。本文设置segments=2。input是模型的输入,在我们的例子中是图像。这里的checkpoint_sequential仅用于顺序模型,对于其他一些模型将产生错误。

使用梯度检查点进行训练,如果你在notebook上执行所有的代码。建议重新启动,因为nvidia-smi可能会获得以前代码中的内存消耗。

 torch.manual_seed(0)
 
 lr=0.003
 
 # model = models.resnet50()
 # model=model.to(device)
 
 vgg16=models.vgg16()
 vgg_layers_list=list(vgg16.children())[:-1]
 vgg_layers_list.append(nn.Flatten())
 vgg_layers_list.append(nn.Linear(25088,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,4096))
 vgg_layers_list.append(nn.ReLU())
 vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
 vgg_layers_list.append(nn.Linear(4096,2))
 model=nn.Sequential(*vgg_layers_list)
 model=model.to(device)
 
 
 
 
 num_epochs=10
 
 #Loss
 loss_func=nn.CrossEntropyLoss()
 
 # Optimizer 
 # optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
 optimizer=optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
 
 
 #Fitting the model.
 model=train_with_grad_checkpointing(model, loss_func, optimizer,
                        train_dataloader,val_dataloader,num_epochs)

输出如下:

从上面的输出可以看到,每个epoch的训练大约需要6分45秒。但只需要10550.125 mb的内存,也就是说我们用时间换取了空间,并且这两种情况下的精度都是79,因为在梯度检查点的情况下模型的精度没有损失。

总结

梯度检查点是一个非常好的技术,它可以帮助在小显存的情况下完整模型的训练。经过我们的测试,一般情况下梯度检查点会将训练时间延长20%左右,但是时间长点总比不能用要好,对吧。

本文的源代码:

https://avoid.overfit.cn/post/a13e29c312c741ac94d4a5079fb9f8af

作者:Vikas Kumar Ojha

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/191314.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

万向节锁问题

以前一直听说过万向节锁当时觉得问题太难就没去认真分析最近在B站找了一些视频看懂了。简单来说旋转是有顺序的,比如transform面板有三个旋转分量,你先调整y,再调整x,最后调整z按照正常思路来说,调整x轴是在y轴旋转的基…

想成为数据分析师,看这里,数据分析必备的43个Excel函数

目录 前言 函数分类: 关联匹配类清洗处理类逻辑运算类计算统计类时间序列类 前言 Excel是我们工作中经常使用的一种工具,对于数据分析来说,这也是处理数据最基础的工具。 很多传统行业的数据分析师甚至只要掌握Excel和SQL即可。 对于初学者…

【DataX】datax | datax-web | win搭建datax-web环境

一、环境准备 1、jdk8 2、maven 3、mysql7 4、python3 5、window10 6、idea 7、2345解压(win支持tar.gz解压) 8、git 二、操作步骤 1、datax操作步骤 1)下载datax http://datax-opensource.oss-cn-hangzhou.aliyuncs.com/datax.tar.gz 2&am…

ES6 环境下 Openlayers 集成使用 ol-ext 以及在线示例

ES6 环境下 Openlayers 集成使用 ol-ext 以及在线示例ol-ext 简介版本说明打包后体积集成方式在线示例最近打算重新封装一下 Openlayers,方便前端人员使用,基础功能没什么可说的,毕竟 Openlayers 的示例和 API 已经非常友好了。 想增加一些地…

2023-01-31 CSDN问答中如何防止和惩罚 “偷代码操作“

CSDN问答中如何防止和惩罚 "偷代码操作"前言一. 代码隐藏保护(CSDN官方回复目前无此功能)二. 先占位后抄袭的处理三. 编辑记录是照妖镜总结前言 随着问答的参与时间累积, 逐渐的碰到了一些问题, 常在河边走, 怎能不湿鞋, 原先看到抄代码结果原创没被采纳, 只能报以同…

AcWing 10. 有依赖的背包问题(分组背包问题 + 树形DP)

AcWing 10. 有依赖的背包问题(分组背包问题 树形DP)一、问题二、分析1、整体分析2、状态表示3、状态转移4、循环设计5、初末状态三、代码一、问题 二、分析 1、整体分析 这道题其实就是作者之前讲解过的一道题:AcWing 487. 金明的预算方案…

【双向链表】数据结构双向链表的实现

前言: 前一期我们已经学习过单链表了,今天我们来学习链表中的双向链表! 目录1.概念以及结构2.双向链表结点结构体3.接口实现3.1动态申请一个结点3.2初始化链表3.3打印链表3.4双向链表尾插3.5 双向链表尾删3.6双向链表头插3.7双向链表头删3.8双…

Linux常用命令——pvscan命令

在线Linux命令查询工具(http://www.lzltool.com/LinuxCommand) pvscan 扫描系统中所有硬盘的物理卷列表 补充说明 pvscan命令会扫描系统中连接的所有硬盘,列出找到的物理卷列表。使用pvscan命令的-n选项可以显示硬盘中的不属于任何卷组的物理卷,这些…

OAuth2代码演示

目录 1 创建项目结构 1.1 客户 1.2 认证服务器 1.3 资源拥有者 1.4 资源服务器 client 客户 authorization-server 认证服务 resource-owner 资源所有者 resource-server 资源服务器 工作流程: 客户向资源所有者申请授权码 资源所有者下发授权码 客户拿到授权…

springboot+mongodb初体验

MongoDB 是一个基于分布式文件存储的数据库。由 C 语言编写,旨在为 WEB 应用提供可扩展的高性能数据存储解决方案。 MongoDB 是一个介于关系数据库和非关系数据库之间的产品,是非关系数据库当中功能最丰富,最像关系数据库的。 1、mongodb服务…

JavaScript 算术运算符

JavaScript 算术运算符 加减乘除以及取模&#xff08;求余数&#xff09;、、– <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" />…

代码随想录算法训练营第十六天 | 104.二叉树的最大深度、559.n叉树的最大深度,111.二叉树的最小深度,222.完全二叉树的节点个数

Day15 周日休息一、参考资料二叉树的最大深度 &#xff08;优先掌握递归&#xff09;题目链接/文章讲解/视频讲解&#xff1a; https://programmercarl.com/0104.%E4%BA%8C%E5%8F%89%E6%A0%91%E7%9A%84%E6%9C%80%E5%A4%A7%E6%B7%B1%E5%BA%A6.html 二叉树的最小深度 &#xff08…

车载网络 - BootLoader - CAN/CANFD刷写前提

刷写作为车载网络测试极其重要的一个模块一直拖到今天才开始写,之前确实没有一个太好的想法怎么介绍这一块,虽然现在也没有想出来怎么写能够更好的介绍这块的内容,不过我也尽量用通俗的语言让大家看懂。 刷写流程 刷写流程我也根据用例的设计分为3个阶段:前置条件、刷写程序…

UDP+有穷自动状态机构造网络指令系统

UDP有穷自动状态机构造网络指令系统 项目背景 某展厅的小项目&#xff0c;使用Unity制作了一个视频播放器&#xff0c;作为受控端&#xff0c;需要接收解说员手中的“PAD”或“触控屏电脑”等设备发来的控制指令。要求指令系统满足以下功能&#xff1a; 能够随意切换要播放的…

剑指Offer 第17天 Top K问题 优先级队列解决数据流中位数

目录 剑指 Offer 40. 最小的k个数 剑指 Offer 41. 数据流中的中位数 剑指 Offer 40. 最小的k个数 输入整数数组 arr &#xff0c;找出其中最小的 k 个数。例如&#xff0c;输入4、5、1、6、2、7、3、8这8个数字&#xff0c;则最小的4个数字是1、2、3、4。 示例 1&#xff1a; …

图像处理中的微分算子

摘要 微分算子在图像处理中的作用主要是用在图像的边缘检测&#xff0c;而图像边缘检测必须满足两个条件&#xff1a;一能有效的抑制噪声&#xff0c;二能必须尽量精确定位边缘位置。现在常用的微分算子主要有&#xff1a;Sobel算子&#xff0c;Robert算子&#xff0c;Prewitt…

【数据结构-JAVA】堆和优先级队列

前面介绍过队列&#xff0c;队列是一种先进先出(FIFO)的数据结构&#xff0c;但有些情况下&#xff0c;操作的数据可能带有优先级&#xff0c;一般出队 列时&#xff0c;可能需要优先级高的元素先出队列&#xff0c;该中场景下&#xff0c;使用队列显然不合适&#xff0c;比如&…

Hugo博客教程(一)

秋风阁——北溪入江流&#xff1a;https://focus-wind.com/ 秋风阁——计算机视觉实验&#xff1a;边缘提取与特征检测 文章目录Hugo博客教程&#xff08;一&#xff09;博客静态博客静态博客的优缺点常见的静态博客HexoHugo动态博客动态博客的优缺点常见的动态博客WordPressTy…

sql进阶教程

sql进阶教程第一章、神奇的sql1.1 CASE 表达式将已有编号方式转换为新的方式并统计用一条 SQL 语句进行不同条件的统计用 CHECK 约束定义多个列的条件关系在 UPDATE 语句里进行条件分支表之间的数据匹配在 CASE 表达式中使用聚合函数本节要点1.2 自连接的用法面向集合语言SQL可…

shiro(二):springboot整合shiro

1. 整合思路 2. 加入jsp相关配置方便测试 2.1 加入依赖&#xff1a; <!--引入JSP解析依赖--> <dependency><groupId>org.apache.tomcat.embed</groupId><artifactId>tomcat-embed-jasper</artifactId></dependency> <dependenc…