pytorch混合精度训练及其例子(不显著影响模型精度的前提下,提高计算效率和减少显存占用。)

news2024/12/26 6:38:04

混合精度计算(Mixed Precision Training)是一种在深度学习训练过程中同时使用单精度浮点数(FP32)和半精度浮点数(FP16)进行计算的方法。其主要目的是在不显著影响模型精度的前提下,提高计算效率和减少显存占用。

混合精度计算的优势

  1. 提高计算效率:FP16计算在现代GPU(如NVIDIA的Volta、Turing和Ampere架构)上可以显著加速,因为这些GPU对FP16运算进行了优化。
  2. 减少显存占用:FP16数据占用的显存空间是FP32的一半,这意味着可以在相同的显存中处理更大的模型或更大的批量数据。

实现混合精度计算的关键组件

  1. 自动混合精度(AMP):PyTorch提供了torch.cuda.amp模块来简化混合精度训练的实现。AMP会自动管理FP16和FP32的转换,确保计算的稳定性和精度。
  2. 梯度缩放(Gradient Scaling):由于FP16的动态范围较小,梯度值可能会变得非常小,从而导致数值下溢。梯度缩放通过在反向传播之前将梯度值放大,在更新权重之前再缩小,来避免这个问题。

在这个示例中,我们使用了torch.cuda.amp.autocast来自动进行混合精度计算,并使用torch.cuda.amp.GradScaler来缩放损失值,从而避免数值不稳定性。这样可以显著减少显存占用,同时加速训练过程。

代码如下所示

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 初始化模型、损失函数和优化器
model = SimpleModel().cuda()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 创建GradScaler对象
scaler = GradScaler()

# 模拟训练数据
inputs = torch.randn(32, 10).cuda()
targets = torch.randn(32, 1).cuda()

# 训练循环
for epoch in range(10):
    model.train()
    optimizer.zero_grad()

    # 使用autocast进行前向和后向传播
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # 使用GradScaler进行反向传播和优化
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")

核心代码:

from torch.cuda.amp import autocast, GradScaler

……

# 使用autocast进行前向和后向传播
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # 使用GradScaler进行反向传播和优化
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2265681.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

自动驾驶3D目标检测综述(六)

停更了好久终于回来了(其实是因为博主去备考期末了hh) 这一篇接着(五)的第七章开始讲述第八章的内容。第八章主要介绍的是三维目标检测的高效标签。 目录 第八章 三维目标检测高效标签 一、域适应 (一)…

100V宽压输入反激隔离电源,适用于N道沟MOSFET或GaN或5V栅极驱动器,无需光耦合

说明: PC4411是一个隔离的反激式控制器在宽输入电压下具有高效率范围为2.7V至100V。它直接测量初级侧反激输出电压波形,不需要光耦合器或第三方用于调节的绕组。设置输出只需要一个电阻器电压。PC4411提供5V栅极驱动驱动外部N沟道MOSFET的电压或GaN。内部补偿和软启…

1.系统学习-线性回归

系统学习-线性回归 前言线性回归介绍误差函数梯度下降梯度下降示例 回归问题常见的评价函数1. MAE, mean absolutely error2. MSE, mean squared error3. R square (决定系数或R方) 机器学习建模流程模型正则化拓展阅读作业 链接: 2.系统学习-逻辑回归 …

windows使用zip包安装MySQL

windows通过zip包安装MySQL windows通过zip包安装MySQL下载MySQL的zip安装包创建安装目录和数据目录解压zip安装包创建配置目录 etc 和 配置文件 my.ini安装MySQL进入解压后的bin目录执行命令初始化执行命令安装 验证安装查看服务已安装 启动MySQL查看服务运行情况修改密码创建…

【Postgresql】数据库忘记密码时,重置密码 + 局域网下对外开放访问设置

【Postgresql】数据库忘记密码时,重置密码 + 局域网下对外开放访问设置 问题场景数据库忘记密码时,重置密码局域网下对外开放访问设置问题场景 Postgresql可支持复杂查询,支持较多的数据类型,在生产中较为使用。但有时在局域网下,想通过外部连接使用数据库,可能会出现数…

大模型-使用Ollama+Dify在本地搭建一个专属于自己的聊天助手与知识库

大模型-使用OllamaDify在本地搭建一个专属于自己的知识库 1、本地安装Dify2、本地安装Ollama并解决跨越问题3、使用Dify搭建聊天助手4、使用Dify搭建本地知识库 1、本地安装Dify 参考往期博客:https://guoqingru.blog.csdn.net/article/details/144683767 2、本地…

UE5 崩溃问题汇总!!!

Using bundled DotNet SDK version: 6.0.302 ERROR: UnrealBuildTool.dll not found in "..\..\Engine\Binaries\DotNET\UnrealBuildTool\UnrealBuildTool.dll" 在你遇到这种极奇崩溃的BUG ,难以解决的时候。 尝试了N种方法,都不行的解决方法。…

数字IC前端学习笔记:脉动阵列的设计方法学(四)

相关阅读 数字IC前端https://blog.csdn.net/weixin_45791458/category_12173698.html?spm1001.2014.3001.5482 引言 脉动结构(也称为脉动阵列)表示一种有节奏地计算并通过系统传输数据的处理单元(PEs)网络。这些处理单元有规律地泵入泵出数据以保持规则…

软件工程-【软件项目管理】--期末复习题汇总

一、单项选择题 (1)赶工一个任务时,你应该关注( C ) A. 尽可能多的任务 B. 非关键任务 C. 加速执行关键路径上的任务 D. 通过成本最低化加速执行任务 (2)下列哪个不是项目管理计划的一部分&…

【Git学习】windows系统下git init后没有看到生成的.git文件夹

[问题] git init 命令后看不到.git文件夹 [原因] 文件夹设置隐藏 [解决办法] Win11 win10

《Posterior Collapse and Latent Variable Non-identifiability》

看起来像一篇很有用的paper,而且还是23年的 没看完 后边看不懂了 Abstract 现有的解释通常将后验崩塌归因于由于变分近似而使用神经网络或优化问题。 而本文认为后验崩塌是潜在变量不可识别性的问题(a problem of latent variable non-identifiability) 本文证明了…

申请腾讯混元的API Key并且使用LobeChat调用混元AI

申请腾讯混元的API Key并且使用LobeChat调用混元AI 之前星哥写了一篇文章《手把手教拥有你自己的大模型ChatGPT和Gemini等应用-开源lobe-chat》搭建的开源项目,今天这篇文章教大家如何添加腾讯云的混元模型,并且使用LobeChat调用腾讯混元AI。 申请腾讯混…

Navicat通过ssh连接mysql

navicat 通过ssh连接mysql 对搭建完的mysql连接时,通过ssh连接的方法 需要确保mysql默认端口3306没有被防火墙阻拦 第一步 第二步 35027448270)] 需要注意的是乌班图系列的默认root的ssh是禁止的,应该用别的账户登录

【NACOS插件】使用官网插件更换NACOS数据库

说明 nacos 2.3.1默认支持mysql和derby数据库,如果想要支持其他数据库,可以通过使用插件方式实现。对于该插件的使用,官方说明文档较为粗略(不过也没问题,实际上整个过程就是很简单,只是使用者想复杂了),网…

mysql基础(jdbc)

1.Java连接数据库步骤 1.注册驱动 Class<?> driverManagerClass.forName("com.mysql.cj.jdbc.Driver"); 2.获取连接 Connection conDriverManager.getConnection("jdbc:mysql://localhost:3306/studymysql","root","123456"); …

ROM修改进阶教程------修改刷机包init.rc 自启用户自定义脚本的一些基本操作 代码格式与注意事项

在很多定制化固件中。我们需要修改系统的rc文件来启动自己的一些脚本。但有时候修改会不起作用,其具体原因在于权限与代码格式的问题。博文将系统的解析代码操作编写的注意事项与各种权限分别。了解以上. 轻松编写自定义启动脚本. 通过博文了解💝💝💝 1-------💝💝…

硬件模块常使用的外部中断

对于STM32来说&#xff0c;想要获取的信号是外部驱动的很快的突发信号 例1&#xff1a;旋转编码器的输出信号&#xff1a; 可能很久都不会拧它&#xff0c;不需要STM32做任何事情但是一拧它&#xff0c;就会有很多脉冲波形需要STM32接收信号是突发的&#xff0c;STM32不知道什…

3D布展平台主要有哪些功能?有什么特点?

3D布展平台是一种利用3D技术和虚拟现实&#xff08;VR&#xff09;技术&#xff0c;为用户提供线上虚拟展览和展示服务的平台。这些平台通常允许用户创建、设计和发布3D虚拟展厅&#xff0c;从而提供沉浸式的展览体验。以下是对3D布展平台的详细介绍&#xff1a; 一、主要功能 …

大恒相机开发(2)—Python软触发调用采集图像

大恒相机开发&#xff08;2&#xff09;—Python软触发调用采集图像 完整代码详细解读和功能说明扩展学习 这段代码是一个Python程序&#xff0c;用于从大恒相机采集图像&#xff0c;通过软件触发来采集图像。 完整代码 咱们直接上python的完整代码&#xff1a; # version:…

VTK知识学习(27)- 图像基本操作(二)

1、图像类型转换 1&#xff09;vtkImageCast 图像数据类型转换在数字图像处理中会频繁用到。一些常用的图像算子(例如梯度算子)在计算时出于精度的考虑&#xff0c;会将结果存储为float或double类型&#xff0c;但在图像显示时&#xff0c;一般要求图像为 unsigned char 类型,…