【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

news2025/1/12 17:40:19

【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

  • 报错详情
  • 错误产生背景
  • 原理
  • 解决方案

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

报错详情

  模型在backward时,发现如下报错:
请添加图片描述
  即RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

  其大概意思是说,当在计算梯度时,某个变量已经被操作修改了,这会导致随后的计算梯度的过程中该变量的值发生变化,从而导致计算梯度出现问题。

错误产生背景

  起因是我要复现一种层级多标签分类的网络结构:
在这里插入图片描述
  当输入序列 x x x经过一次BERT模型之后,得到当前预测的一级标签,然后拼接到输入序列 x x x上,再次输入到BERT模型里以预测二级标签。

  出错版本的模型结构如下:

def forward(self, x, label_A_emb):
        context = x[0]  # 输入的句子
        mask = x[2]  

        d1 = self.bert(context, attention_mask=mask)
        logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]
        idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]
        extra = label_A_emb[idx]

        context[:, -3:] = extra
        mask[:, -3:] = 1

        d2 = self.bert(context, attention_mask=mask)
        logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]

        return logit1, logit2

  在计算梯度时,由于contextmask的值被中间修改过一次,所以会报错。

原理

请添加图片描述
  图中 w 1 w_1 w1的梯度计算如上图,损失函数为 E t o t a l E_{total} Etotal,最终 w 1 w_1 w1的梯度里是需要用到原始输入 i 1 i_1 i1的。

  所以在上面贴的模型结构代码中,输入在经过神经网络之后,又作了一次改动,然后再经过神经网络。但是梯度计算会计算两次的梯度,可是发现输入只有改动后的值了,改动前的值已经被覆盖。

计算梯度时的版本号机制是PyTorch中用于跟踪张量操作历史的一种机制。它允许PyTorch在需要计算梯度时有效地管理和跟踪相关的操作,以便进行自动微分。每个张量都有一个版本号,记录了该张量的操作历史。当对一个张量执行就地操作(inplace operation)时,例如修改张量的值或重新排列元素的顺序,版本号会增加。这种就地操作可能导致计算梯度时出现问题,因为梯度计算依赖于操作历史。

解决方案

  把即将改动的变量深拷贝一份,最终优化的代码如下:

def forward(self, x, label_A_emb):
        context = x[0]  # 输入的句子
        mask = x[2]  

        d1 = self.bert(context, attention_mask=mask)
        logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]
        idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]
        extra = label_A_emb[idx]

        context_B = copy.deepcopy(context)
        mask_B = copy.deepcopy(mask)

        context_B[:, -3:] = extra
        mask_B[:, -3:] = 1

        d2 = self.bert_A(context_B, attention_mask=mask_B)
        logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]

        return logit1, logit2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1173132.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构与算法 | 第三章:栈与队列

本文参考网课为 数据结构与算法 1 第三章栈,主讲人 张铭 、王腾蛟 、赵海燕 、宋国杰 、邹磊 、黄群。 本文使用IDE为 Clion,开发环境 C14。 更新:2023 / 11 / 5 数据结构与算法 | 第三章:栈与队列 栈概念示例 实现顺序栈类定义…

谈谈MySQL的底层存储

这个题目启的很大,但其实只是最近在复习MySQL知识的一点心得,比较零散。 更新数据时,底层page的变化 下面这个图,我还需要解释么? 上面的绿色是b数的索引块,分别说明了101号page的最大id是7,102号page的…

ACM MM 2023 | 清华、华为联合提出MISSRec:兴趣感知的多模态序列推荐预训练

©PaperWeekly 原创 作者 | 王锦鹏 单位 | 清华大学深圳国际研究生院 研究方向 | 多模态检索、推荐系统 序列推荐是一种主流的推荐范式,目的是从用户的历史行为中推测用户偏好,并为之推荐感兴趣的物品。现有的大部分模型都是基于 ID 和类目等信息做…

相机滤镜软件Nevercenter CameraBag Photo mac中文版特点介绍

Nevercenter CameraBag Photo mac是一款相机和滤镜应用程序,它提供了一系列先进的滤镜、调整工具和预设,可以帮助用户快速地优化和编辑照片。 Nevercenter CameraBag Photo mac软件特点介绍 1. 滤镜:Nevercenter CameraBag Photo提供了超过2…

【嵌入式 – GD32开发实战指南(ARM版本)】第2部分 外设篇 - 第2章 温湿度传感器AHT10

1 理论分析 1.1 AHT10介绍 AHT10,新一代温湿度传感器在尺寸与智能方面建立了新的标准:它嵌入了适于回流焊的双列扁平无引脚SMD封装,底面4 x 5mm ,高度1.6mm。传感器输出经过标定的数字信号,标准I2C格式。 AHT10 配有一个全新设计的ASIC专用芯片、一个经过改进的MEMS半导体…

难题来了:分库分表后,查询太慢了,如何优化?

说在前面: 尼恩社群中,很多小伙伴反馈, Sharding-JDBC 分页查询的速度超级慢, 怎么处理? 反馈这个问题的小伙伴,很多很多。 而且这个问题,也是面试的核心难题。前段时间,有小伙伴…

一看就懂的java对象内存布局

前言 Java 中一切皆对象,同时对象也是 Java 编程中接触最多的概念,深入理解 Java 对象能够更帮助我们深入地掌握 Java 技术栈。在这篇文章里,我们将从内存的视角,带你深入理解 Java 对象在虚拟机中的表现形式。 学习路线图&…

2023第二届全国大学生数据分析大赛A题思路

某电商平台用户行为分析与挖掘 背景:电商是当今用户最大的交易市场之一,电商行业也逐渐成熟, 所有市场中可售卖的商品全都在平台中存在,并且在网络和疫情的影 响下,在线上的消费行为满足全年龄段用户。 用户的交易行为…

unittest 通过TextTestRunner(buffer=True)打印断言失败case的输出内容

buffer是unittest.TextTestRunner的一个参数,它决定了测试运行时是否将输出结果缓存,并在测试完成后一次性打印。 当buffer设置为True时,测试运行期间的输出结果会被缓存起来,并在测试完成后一次性打印。这对于一些输出频繁的测试…

Lamport Clock算法

Lamport Clock 是一种表达逻辑时间的逻辑时钟(logical clock),能够计算得到历史事件的时间偏序关系。 假设 P0进程是分布式集群中心节点中的监控者,用于统一管理分布式系统中事件的顺序。其他进程在发送消息之前和接受事件消息之后…

操作系统——内存映射文件(王道视频p57)

1.总体概述: 2.传统文件访问方式: 我认为,这种方式最大的劣势在于,如果要对整个文件的不同部分进行多次操作的话,这样确实开销可能会大一些,而且程序员还要指定对应的“分块”载入到内存中 3.内存映射文件…

Qt的事件

2023年11月5日,周日上午 还没写完,不定期更新 目录 事件处理函数的字体特点Qt事件处理的工作原理一些常用的事件处理函数Qt中的事件类型QEvent类的type成员函数可以用来判断事件的类型事件的类型有哪些?有多少种事件类 事件处理函数的字体特…

unittest 通过TextTestRunner(failfast=True),失败或错误时停止执行case

failfast是unittest.TextTestRunner的一个参数,它用于控制测试运行过程中遇到第一个失败或错误的测试方法后是否立即停止执行。 当failfast设置为True时,一旦发现第一个失败或错误的测试方法,测试运行就会立即停止,并输出相应的失…

插值表达式 {{}}

前言 持续学习总结输出中,今天分享的是插值表达式 {{}} Vue插值表达式是一种Vue的模板语法,我们可以在模板中动态地用插值表达式渲染出Vue提供的数据绑定到视图中。插值表达式使用双大括号{{ }}将表达式包裹起来。 1.作用: 利用表达式进行…

教你烧录Jetson Orin Nano的ubuntu20.04镜像

Jetson Orin Nano烧录镜像 视频讲解 教你烧录Jetson Orin Nano的ubuntu20.04镜像 1. 下载sdk manager https://developer.nvidia.com/sdk-manager sudo dpkg -i xxxx.deb2. 进入recovery 插上typeC后,短接J14的FORCE_RECOVERY和GND,上电 如下图&#…

【调度算法】单机调度遗传算法

问题描述 工件ABCDEFG工件编号0123456加工时间4765835到达时间3245321交货期10153024141320 目标函数 最小化交货期总延时时间 运算结果 最佳调度顺序: [6, 3, 2, 5, 0, 1, 4] 最小交货期延时时间: 47python代码 import random import numpy as np…

自动驾驶行业观察之2023上海车展-----智驾供应链(3)

智驾解决方案商发展 华为:五项重磅技术更新,重点发布华为ADS 2.0和鸿蒙OS 3.0 1)产品方案:五大解决方案都有了全面的升级,分别推出了ADS 2.0、鸿蒙OS 3.0、iDVP智能汽车数字平台、智能车云服务和华为车载光最新 产品…

linux下使用vscode对C++项目进行编译

项目的目录结构 头文件swap.h 在自定义的头文件中写函数的声明。 // 函数的声明 void swap(int a,int b);swap.cpp 导入函数的声明&#xff0c;写函数的定义 #include "swap.h" // 双引号表示自定义的头文件 #include <iostream> using namespace std;// 函…

2023年中国商业密码行业研究报告

第一章 行业概况 1.1 定义及分类 根据《密码法》相关规定&#xff0c;密码是指采用特定变换的方法对信息等进行加密保护、安全认证的技术、产品和服务。 密码产业是指为了保障信息安全&#xff0c;提供加密保护、安全认证相关技术、产品和服务的相关行业总称&#xff0c;主 要…

为机器学习算法准备数据(Machine Learning 研习之八)

本文还是同样建立在前两篇的基础之上的&#xff01; 属性组合实验 希望前面的部分能让您了解探索数据并获得洞察力的几种方法。您发现了一些数据怪癖&#xff0c;您可能希望在将数据提供给机器学习算法之前对其进行清理&#xff0c;并且发现了属性之间有趣的相关性&#xff0c…