解决使用copy.deepcopy()拷贝Tensor或model时报错只支持用户显式创建的Tensor问题

news2024/11/16 7:32:59

模型训练过程中常需边训练边做validation或在训练完的模型需要做测试,通常的做法当然是先创建model实例然后掉用load_state_dict()装载训练出来的权重到model里再调用model.eval()把模型转为测试模式,这样写对于训练完专门做测试时当然是比较合适的,但是对于边训练边做validation使用这种方式就需要写一堆代码,如果能使用copy.deepcopy()直接深度拷贝训练中的model用来做validation显然是比较简洁的写法,但是由于copy.deepcopy()的限制,写model里代码时如果没注意,调用copy.deepcopy(model)时可能就会遇到这个错误:Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment,详细错误信息如下:

 File "/usr/local/lib/python3.6/site-packages/prc/framework/model/validation.py", line 147, in init_val_model
    val_model = copy.deepcopy(model)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib64/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib64/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib64/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib64/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib64/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib64/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 161, in deepcopy
    y = copier(memo)
  File "/root/.local/lib/python3.6/site-packages/torch/_tensor.py", line 55, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

这个错误简单地说就是copy.deepcopy()不支持拷贝requires_grad=True的Tensor(在网络中一般是非叶子结点Tensor, grad_fn不为None),开始以为真的哪个地方Tensor的requires_grad没有按要求设置,熬了几个夜去检查调试网络代码没发现什么线索很郁闷,后来想既然是copy.deepcopy()里报错的,源码也有那就去它里面debug看是拷贝网络的那部分时抛出的Exception吧,折腾了一阵发现里面这个地方加breakpoint比较合适:

   if dictiter is not None:
        if deep:
            for key, value in dictiter:
                key = deepcopy(key, memo)
                value = deepcopy(value, memo)
                y[key] = value
        else:
            for key, value in dictiter:
                y[key] = value

我这个网络的结构是使用的python dict方式定义的,运行时使用注册机制动态创建出来的,既然是dict,这里的key和value就是对应配置文件里的定义网络每层结构的dict的key和value,在这里加bp可以比较清楚地跟踪看到是在哪个地方导致的抛出Exception,结果发现原因是因为有个实现分割功能的head类的内部有个成员变量保存了这层的输出结果Tensor用于后面计算loss,模型每层的输出数据Tensor自然是requires_grad=True,把这个成员变量去掉,改成forward()输出结果,然后在网络的主类里接收它并传入计算Loss的函数,然后deepcopy(model)就不报上面的错了!

另外,显式创建一个Tensor时指定requires_grad=True(默认是False)并不会导致copy.deepcopy()报错,不管这个Tensor是在cpu上还是gpu上,关键是用户自己创建的Tensor是叶子结点Tensor,它的grad_fn是None,在这个Tensor上做切片或者加载到gpu上等操作得到的新的Tensor就不是叶子结点了,pytorch认为requires_grad=Trued的Tensor经过运算得到新的Tensor是需要求导的会自动加上grad_fn而不管这个Tensor是不是网络的一部分,这时再使用copy.deepcopy()深度拷贝新的Tensor时会抛出上面的错误,看完下面的示例就知道了:

>>> t = torch.tensor([1,2,3.5],dtype=torch.float32, requires_grad=True, device='cuda:0')
>>> t
tensor([1.0000, 2.0000, 3.5000], device='cuda:0', requires_grad=True)
>>> x = copy.deepcopy(t)
>>> x
tensor([1.0000, 2.0000, 3.5000], device='cuda:0', requires_grad=True)
>>> t1 = t[:2]
>>> t1
tensor([1., 2.], device='cuda:0', grad_fn=<SliceBackward0>)
>>> x = copy.deepcopy(t1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/python3.8/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/root/.local/lib/python3.8/site-packages/torch/_tensor.py", line 85, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

>>> t = torch.tensor([1,2,3.5],dtype=torch.float32, requires_grad=True)
>>> t1 = t.cuda()
>>> t1
tensor([1.0000, 2.0000, 3.5000], device='cuda:0', grad_fn=<ToCopyBackward0>)
>>> x = copy.deepcopy(t1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/python3.8/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/root/.local/lib/python3.8/site-packages/torch/_tensor.py", line 85, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

>>> t = torch.tensor([1,2,3.5],dtype=torch.float32, requires_grad=False)
>>> t
tensor([1.0000, 2.0000, 3.5000])
>>> x = copy.deepcopy(t)
>>> x
tensor([1.0000, 2.0000, 3.5000])
>>> t1 = t[:2]  
>>> t1
tensor([1., 2.])
>>> x = copy.deepcopy(t1)

为何deepcopy()不直接支持有梯度的Tensor,按理要支持复制一个当时的瞬间值应该也没问题,看到https://discuss.pytorch.org/t/copy-deepcopy-vs-clone/55022/10这里这个经常回答问题的胡子哥给了个猜测:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/184820.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ssm学生心理健康测评网的规划与设计

摘 要 1 Abstract 1 1 绪论 1 1.1 课题背景 1 1.2 课题研究现状 1 1.3 初步设计方法与实施方案 2 1.4 本文研究内容 2 2 系统开发环境 4 2.1 JSP技术介绍 4 2.2 B/S模式 4 2.3 MySQL环境配置 5 3 系统分析 6 3.1 系统可行性分析 6 3.1.…

【模糊神经网络】基于模糊神经网络的倒立摆轨迹跟踪控制

临近春节没啥事做&#xff0c;突然想起前两年未完成的模糊神经网络&#xff0c;当时是学了一段时间&#xff0c;但是到最后矩阵求偏导那块始终不对&#xff0c;最后也不了了之了&#xff0c;趁最近有空&#xff0c;想重新回顾回顾&#xff0c;看看会不会产生新的想法。经过不断…

elasticsearch基本操作

elasticsearch基本操作基础两种模式:ik分词器词库拓展索引库操作mapping映射属性typeindexanalyzerproperties索引库的CRUD创建修改查询删除文档操作创建查询修改删除基础 本教程使用es8.6.0与kibana作为测试环境 打开开发工具 ## 1.查看节点信息 GET /_cat/nodes?v ## 2.查…

【JUC并发编程】深入浅出Java并发基石——AQS

【JUC并发编程】深入浅出Java并发基石——AQS 参考资料&#xff1a; RedSpider社区——第十一章 AQS 深入剖析并发之AQS独占锁 1.5w字&#xff0c;30图带你彻底掌握 AQS&#xff01; 深入浅出AbstractQueuedSynchronizer 我画了35张图就是为了让你深入 AQS 动画演示AQS的核心原…

遍历 “可变参数模板” 的模板参数

类模板和函数模板&#xff0c;只能包含固定数量的模板参数&#xff0c;C11支持模板参数可变&#xff0c;那么在不知道模板参数有多少个的情况下&#xff0c;如何遍历模板参数&#xff1f; 目录 一、可变参数模板的声明 二、可变参数模板的遍历 1、递归遍历 2、非递归遍历 …

idea使用DataBase连接数据库 Free MyBatis Tool自动生成 实体类工具使用

DataBase DataBase连接数据库 设置DataSources Host 》IP地址Port 》端口号User 》用户名Password 》密码Database 》连接的数据库 设置驱动 Drives tables 文件夹中即所连接数据库中表 Free MyBatis Tool自动生成 实体类&#xff0c;Mapper &#xff0c;以及mapper.xml 选…

CleanMyMac4.12.3最新版本Mac系统清理工具

CleanMyMac可以为Mac腾出空间&#xff0c;软件已经更新到CleanMyMac X 支持最新版Macos 10.14系统。CleanMyMac具有一系列巧妙的新功能&#xff0c;可让您安全&#xff0c;智能地扫描和清理整个系统&#xff0c;删除大量未使用的文件&#xff0c;减小iPhoto图库的大小&#xff…

非类型模板参数/模板的特化/模板的分离编译

上一篇文章中&#xff0c;我们对模板有了初步的认识&#xff0c;接下来我们便对模板进一步地学习&#xff01; 1.非类型模板参数 模板参数分为类型形参与非类型形参&#xff1a; ①类型形参即&#xff1a;出现在模板参数列表中&#xff0c;跟在class或者typename之类的参数类…

20 个杀手级的 JavaScript 单行代码,可以节省你的编码时间

使用这些基本的单行代码将您的 JavaScript 技能提升到一个新的水平&#xff0c;这也将节省您的编码时间 &#x1f680;1) 查找数组中的最大值&#xff1a;Math.max(...array)2&#xff09;从数组中删除重复项&#xff1a;[...newSet(array)]3&#xff09;生成一个1到100之间的随…

Java开发环境搭建实践

前言 刚刚弄完python的环境搭建&#xff0c;今年打算也要好好学习Java&#xff0c;所以把Java的环境弄起来 搭建过程 jdk下载和安装 下载 官网&#xff1a;Oracle 甲骨文中国 | 云应用和云平台 打开官网 点击产品后下拉找到Java点进去。 下载Java 我就下载最新的jdk把…

Spring Boot学习之任务学习【异步、定时、邮件】

文章目录一 异步任务1.1 创建spring Boot项目&#xff0c;选择Spring Web1.2 创建AsyncService类1.3 编写controller类1.4 在启动类上开启异步功能1.5 测试结果二 定时任务2.1 基础知识2.2 项目创建2.3 创建一个ScheduledService2.4 在主程序上增加EnableScheduling 开启定时任…

Hbuilder打包成苹果IOS-App的详解

本文相关主要记录一下使用Hbuilder打包成苹果IOS-App的详细步骤。介绍一下个人开发者账号&#xff1a;再说下什么是免费的苹果开发者账号&#xff0c;就是你没交688年费的就是免费账号&#xff0c;如果你想变成付费开发者账号&#xff0c;提交申请付费就行&#xff0c;账号都是…

【C++】priority_queue使用模拟实现

priority_queue使用 http://www.cplusplus.com/reference/queue/priority_queue/ 文档介绍 优先级队列是一种容器适配器,根据严格的弱排序标准,它的第一个元素总是它所包含的元素中最大的(大堆为例) 在堆中可以随时插入元素,并且只能检索最大堆元素(优先队列中位于顶部的元素)…

应用系统基于OAuth2实现单点登录的解决方案

1、OAuth2单点认证原理 基于OAuth2的认证方式包含四种&#xff0c;其中单点登录最常用的是授权码模式&#xff0c;其基本的认证过程如下&#xff1a; 用户访问业务应用&#xff0c;业务应用进行登录检查&#xff1b;业务应用重定向到OAuth2认证服务器&#xff0c;调用获取授权…

米哈伊年终奖是32万,我的年终奖是彩虹屁!

数据来源沉默王二 | 数据报表小熊绘制 年都过完了&#xff0c;年终奖结果也都出来了&#xff0c;我这个年没有过好&#xff0c;每次想到就难受&#xff0c;在看王二整理出来的年终奖&#xff0c;整个人都不好了。 本次统计基于49条数据的不准确统计&#xff0c;仅抽取部分公司部…

Lesson 4.4 随机梯度下降与小批量梯度下降

文章目录一、损失函数理论基础二、随机梯度下降&#xff08;Stochastic Gradient Descent&#xff09;1. 随机梯度下降计算流程2. 随机梯度下降的算法特性3. 随机梯度下降求解线性回归4. 随机梯度下降算法评价三、小批量梯度下降&#xff08;Mini-batch Gradient Descent&#…

SpringMVC执行流程和原理

1、用户发送出请求到前端控制器DispatcherServlet。 2、DispatcherServlet收到请求调用HandlerMapping&#xff08;处理器映射器&#xff09;。 3、HandlerMapping找到具体的处理器(可查找xml配置或注解配置)&#xff0c;生成处理器对象及处理器拦截器 (如果有)&#xff0c;再…

51单片机学习笔记-3模块化编程

3 模块化编程 [toc] 注&#xff1a;笔记主要参考B站江科大自化协教学视频“51单片机入门教程-2020版 程序全程纯手打 从零开始入门”。 3.1 模块化编程 传统方式编程&#xff1a;所有的函数均放在main.c里&#xff0c;若使用的模块比较多&#xff0c;则一个文件内会有很多的…

1604_linux环境下使用命令行把网页转换成pdf

全部学习汇总&#xff1a; GreyZhang/toolbox: 常用的工具使用查询&#xff0c;非教程&#xff0c;仅作为自我参考&#xff01; (github.com) 使用的工具很容易在彼此之间产生隔离性障碍&#xff0c;比如我最近使用的墨水屏阅读的最合适的文件格式我觉得是pdf&#xff0c;但是我…

路由工具之路由策略router-policy、acl列表与ip-prefix前缀列表的区别、过滤列表filter-policy

3.0.0 路由工具之路由策略router-policy、acl列表与ip-prefix前缀列表的区别、过滤列表filter-policy 目录IP-Prefix前缀列表前缀列表与ACLrouter-policy路由策略应用路由策略过滤路由1、环境介绍2、配置OSPF3、过滤路由&#xff08;1&#xff09;ACL匹配路由方式过滤&#xff…