深度学习:模型训练过程中Trying to backward through the graph a second time解决方案

news2024/12/28 20:55:44

1 问题描述

在训练lstm网络过程中出现如下错误:

Traceback (most recent call last):
  File "D:\code\lstm_emotion_analyse\text_analyse.py", line 82, in <module>
    loss.backward()
  File "C:\Users\lishu\anaconda3\envs\pt2\lib\site-packages\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "C:\Users\lishu\anaconda3\envs\pt2\lib\site-packages\torch\autograd\__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

2 问题分析

按照错误提示查阅相关资料了解到,实际上在大多数情况下retain_graph都应采用默认的False,除了几种特殊情况:

  • 一个网络有两个output分别执行backward进行回传的时候: output1.backward(), output2.backward().
  • 一个网络有两个loss需要分别执行backward进行回传的时候: loss1.backward(), loss2.backward().

但本项目的LSTM训练模型不属于以上情况,再次查找资料,在在pytorch的官方论坛上找到了真正的原因:

如截图中的描述,只要我们对变量进行运算了,就会加进计算图中。所以本项目的问题在于在for循环梯度反向传播中,使用了循环外部的变量h,如下所示:

epochs = 128
    step = 0
    model.train()  # 开启训练模式
    for epoch in range(epochs):
        h = model.init_hidden(batch_size)  # 初始化第一个Hidden_state

        for data in tqdm(train_loader):
            x_train, y_train = data
            x_train, y_train = x_train.to(device), y_train.to(device)
            step += 1  # 训练次数+1

            x_input = x_train.to(device)
            model.zero_grad()

            output, h = model(x_input, h)

            # 计算损失
            loss = criterion(output, y_train.float().view(-1))
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()

            if step % 10 == 0:
                print("Epoch: {}/{}...".format(epoch + 1, epochs),
                      "Step: {}...".format(step),
                      "Loss: {:.6f}...".format(loss.item()))

3 问题解决

代码修改如下:

epochs = 128
    step = 0
    model.train()  # 开启训练模式
    for epoch in range(epochs):
        h = model.init_hidden(batch_size)  # 初始化第一个Hidden_state

        for data in tqdm(train_loader):
            x_train, y_train = data
            x_train, y_train = x_train.to(device), y_train.to(device)
            step += 1  # 训练次数+1

            x_input = x_train.to(device)
            model.zero_grad()

            h = tuple([e.data for e in h])
            output, h = model(x_input, h)

            # 计算损失
            loss = criterion(output, y_train.float().view(-1))
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()

            if step % 10 == 0:
                print("Epoch: {}/{}...".format(epoch + 1, epochs),
                      "Step: {}...".format(step),
                      "Loss: {:.6f}...".format(loss.item()))

增加for循环内部变量,对外部变量进行复制,内部变量参与梯度传播,问题解决。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1044136.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Unity】LODGroup 计算公式

Unity 在配置 LodGroup 时&#xff0c;其分级切换的计算方法是按照物体在相机视野中占据的比例计算的。在运行时&#xff0c;如果相机视野范围&#xff08;Field of View&#xff09;没有改变&#xff0c;那么这个值可以直接换算成物体距离相机的距离。这里就讨论下如何计算得到…

ubuntu下用pycharm专业版连接AI服务器及其docker环境

一&#xff1a;用pycharm专业版连接AI服务器 1、首先在自己电脑上新建一个文件夹&#xff0c;后续用于映射服务器上自己所要用的项目文件 2、用pycharm专业版打开该文件夹&#xff0c;作为一个项目打开 3、然后在工具->部署->配置 4、配置中形式如下&#xff1a; 点击左…

Chatbot UI集成LocalAI实现自托管的ChatGPT

本文比惯例提前了一天发&#xff0c;因为明天一早&#xff0c;老苏就踏上回乡的路了&#xff0c;三年没回老家&#xff0c;这次专门请了 2 天的假 难得回家&#xff0c;打算多陪陪家人&#xff0c;和多年不见的朋友聚聚&#xff0c;当然如果有网络条件&#xff0c;还是会正常发…

英语单词记忆学习打卡系统 微信小程序

本单词记忆系统使用了计算机语言Java和存放数据的仓库MySQL&#xff0c;采用了微信小程序模式来实现。本系统使用了框架SSM和Uni-weixin实现了单词记忆系统应有的功能&#xff0c;系统主要角色包括管理员和用户。 关键词&#xff1a;Java&#xff1b;MySQL&#xff1b;SSM  在…

Unity实现设计模式——命令模式

Unity实现设计模式——命令模式 推荐一个Unity学习设计模式很好的GitHub地址&#xff1a;https://github.com/QianMo/Unity-Design-Pattern 有非常多的Star 一、介绍 命令模式使得请求的发送者与请求的执行者之间消除耦合&#xff0c;让对象之间的调用关系更加灵活。在命令模…

聊聊零拷贝技术原理和应用

文章目录 0. 引言1. 什么是零拷贝技术 1. 零拷贝技术在不同领域的应用2.传统拷贝技术的缺点3. 零拷贝技术的原理与实现1. sendfile系统调用2. 内核缓冲区与用户缓冲区3. DMA&#xff08;Direct Memory Access&#xff09;技术4. 文件描述符传递与共享5. Direct I/O&#xff08;…

Apache shiro RegExPatternMatcher 权限绕过漏洞 (CVE-2022-32532)

漏洞描述 2022年6月29日&#xff0c;Apache 官方披露 Apache Shiro &#xff08;CVE-2022-32532&#xff09;权限绕过漏洞。 当Apache Shiro中使用RegexRequestMatcher进行权限配置&#xff0c;且正则表达式中携带"."时&#xff0c;未经授权的远程攻击者可通过构造恶…

基于Springboot实现毕业生信息招聘平台管理系统演示【项目源码+论文说明】分享

基于Springboot实现毕业生信息招聘平台管理系统演示 摘要 随着社会的发展&#xff0c;社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 毕业生信息招聘平台&#xff0c;主要的模块包括查看管理员&#xff1b;首页、个人中心、企…

Nginx 可视化管理工具与 cpolar 配置:实现远程访问本地服务的优化

文章目录 前言1. docker 一键安装2. 本地访问3. Linux 安装cpolar4. 配置公网访问地址5. 公网远程访问6. 固定公网地址 前言 Nginx Proxy Manager 是一个开源的反向代理工具&#xff0c;不需要了解太多 Nginx 或 Letsencrypt 的相关知识&#xff0c;即可快速将你的服务暴露到外…

服务断路器_服务雪崩解决方案之服务降级

什么是服务降级 两种场景: 当下游的服务因为某种原因响应过慢&#xff0c;下游服务主动停掉一些不太重要的业务&#xff0c;释放出服务器资源&#xff0c;增加响应速度&#xff01;当下游的服务因为某种原因不可用&#xff0c;上游主动调用本地的一些降级逻辑&#xff0c;避免…

SPA移动端解决方案参考

企业在实现SAP移动化时遇到的一些挑战&#xff0c;如果我们利用自己开发团队来进行应用程序的开发&#xff0c;可能会陷入规划&#xff0c;开发&#xff0c;调试&#xff0c;测试的循环中&#xff0c;最后仍一无所获。那如果企业寻找第三方咨询公司进行开发的话&#xff0c;又担…

【高阶数据结构】哈希的应用 {位图;std::bitset;位图的应用;布隆过滤器;布隆过滤器的应用}

一、位图 1.1 位图概念 面试题 给40亿个不重复的无符号整数&#xff0c;没排过序。给一个无符号整数&#xff0c;如何快速判断一个数是否在这40亿个数中。【腾讯】 遍历查找&#xff1a;内存中无法存放40亿个整数&#xff08;约占内存15-16G&#xff09;&#xff1b;时间复杂…

项目经理工具箱

新项目经理误区 要解决的关键点 事&#xff1a;范围&#xff0c;进度&#xff0c;成本&#xff0c;质量 人&#xff1a;项目干系人&#xff0c;团队&#xff0c;外包成员&#xff1b; 干系人管理计划&#xff0c;沟通管理计划 技术和管理区别和联系 非暴力沟通 结构思考力 重…

正点原子lwIP学习笔记——NTP实时时间实验

1. NTP简介 NTP&#xff08;Network Time Protocol&#xff09;网络时间协议基于UDP&#xff0c;用于网络时间同步的协议&#xff0c;使网 络中的计算机时钟同步到UTC&#xff0c;再配合各个时区的偏移调整就能实现精准同步对时功能。 NTP 服务器&#xff08;Network Time Pr…

ERROR in docs.42140ac.js from UglifyJs webpack打包报错

ERROR in docs.42140ac.js from UglifyJs 原因是UglifyJs 针对js压缩 不支持es6语法&#xff08;或者引入的第三方插件存在es6语法&#xff09; ERROR in docs.42140ac.js from UglifyJs 使用的 uglifyjs-webpack-plugin 解决方法 降低uglifyjs-webpack-plugin的版本 “ugl…

系统化思考,从初级到高级书单推荐

用思考工具进行系统思考&#xff0c;解决复杂问题&#xff0c;成为某个领域的高手&#xff0c;下面这几本书就是补充你脑海的系统思考的工具&#xff0c;一定要保存。 《简单的逻辑学》 作者&#xff1a;麦克伦尼 一切的系统源自于逻辑&#xff0c;如果你没有逻辑分析的能力&…

[谷粒商城笔记]07、Linux环境-虚拟机网络设置

1.本机cmd,输入命令ipconfig,查看本地ip 192.168.56.1是虚拟机的ip 2.自定义虚拟机ip 修改这个文件下的 这里&#xff0c;把ip换成 192.168.56.‘10’ 引号内数字自定义 3.在本机和虚拟机命令行&#xff0c;互相ping IP 查看是否设置成功

静态NAT,动态NAT,NAPT(实验配置+原理讲解)

目录 静态NAT,动态NAT&#xff0c;NAPT 实验一&#xff1a;静态NAT地址转换 实验二&#xff1a;动态NAT配置 实验三&#xff1a;NAPT配置 静态NAT,动态NAT&#xff0c;NAPT 静态地址转换&#xff1a;只能实现一个私网与一个公网的一对一映射 动态地址转换&#xff1a;创建…

Python 编程基础 | 第一章-预备知识 | 1.5、开发工具

一、开发工具 - VSCode VSCode是一个相当优秀的IDE&#xff0c;具备开源、跨平台、模块化、插件丰富、轻量化、启动时间快、颜值高的特质。 1、下载VSCode VSCode下载地址&#xff1a;https://code.visualstudio.com/ 2、安装VSCode 载软件包&#xff0c;一步步安装即可&#x…

CSS笔记——基本语法及相关知识

CSS层叠样式表是用于定义 HTML 或 XML 文档的样式和布局的语言。它可以让开发者更加灵活地控制页面元素的样式和排版&#xff0c;从而提高页面的可读性和用户体验 一、css样式书写顺序和规范 CSS样式的书写顺序和规范是为了让代码更易读、易维护和易扩展。下面是一些常见的规…