已解决YOLOv5训练报错:RuntimeError: Expected all tensors to be on the same device......

news2024/11/27 4:31:42

这是发生在集成一个yolov5中没有的检测头head的情况下发生的错误,出现的时候是已经训练起来了,在训练结束时发生的报错,下面是我的解决办法。

1、问题出现及分析排查

改yolov5的网络进行训练时出的报错:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

百思不得其解,经过反复调试最终解决了该问题,其实细心一点估计早就解决了。
在这里插入图片描述

具体问题报错如下:

Optimizer stripped from work_yolox/xs_decoupledhead/deviceerror_test9/weights/last.pt, 18.2MB
Optimizer stripped from work_yolox/xs_decoupledhead/deviceerror_test9/weights/best.pt, 18.2MB

Validating work_yolox/xs_decoupledhead/deviceerror_test9/weights/best.pt...
Fusing layers... 
YOLOv5s_yolox_s summary: 374 layers, 8942326 parameters, 0 gradients
                 Class     Images  Instances          P          R      mAP50   mAP50-95:   0%|          | 0/4 00:00
Traceback (most recent call last):
  File "train.py", line 634, in <module>
    main(opt)
  File "train.py", line 528, in main
    train(opt.hyp, opt, device, callbacks)
  File "train.py", line 411, in train
    results, _, _ = validate.run(
  File "/home/luban/miniconda3/envs/CLDet/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/nfs/private/windpaper_yolo/val.py", line 210, in run
    preds, train_out = model(im) if compute_loss else (model(im, augment=augment), None)
  File "/home/luban/miniconda3/envs/CLDet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/nfs/private/windpaper_yolo/models/yolo.py", line 304, in forward
    return self._forward_once(x, profile, visualize)  # single-scale inference, train
  File "/nfs/private/windpaper_yolo/models/yolo.py", line 197, in _forward_once
    x = m(x)  # run
  File "/home/luban/miniconda3/envs/CLDet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/nfs/private/windpaper_yolo/models/yolo.py", line 137, in forward
    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

乍一看,就是变量类别不一致问题,有的在cuda设备上,有的在cpu设备上,导致计算的时候报错。这里也可以看到,问题最终是出现在

File "/nfs/private/windpaper_yolo/models/yolo.py", line 137, in forward
    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

因为这是训练过程出现的问题,且只出现在最后阶段,无论前面训练多少个epochs都没错,所以一开始查错有一点曲折。由于训练已经完成在调用验证代码出现的问题,先找到了val.y,在下面这行代码的位置反复调试,试图找出在cpu和gpu的变量,把他们统一整合到gpu上
在这里插入图片描述
试了半天,发现都是gpu变量,包括im、model、targets等,甚至直接找模型的参数如model.parameters()、model.state_dict()等来看所在的位置,结果要么是在gpu上,要么就是看不到,后面也几乎放弃了。

最后经过仔细分析,在yolo.py中反复看,最后经过尝试,终于排查出了问题,修改后经过验证就解决了问题。

2、问题解决方法

在yolov5中的yolo.py中有这么一段代码:

def _apply(self, fn):
        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment)):
            m.stride = fn(m.stride)
            m.grid = list(map(fn, m.grid))
            if isinstance(m.anchor_grid, list):
                m.anchor_grid = list(map(fn, m.anchor_grid))
        return self

这是 class BaseModel的一个私有函数,没有找到调用的位置,具体执行方法还不太清楚。结合这个函数的提示,以及错误中涉及 self.gridself.stride,我关注到了这个函数。函数中对yolov5构建head做了处理,如Detect、Segment,我尝试把我改的head模块和这几个放一起加到代码中进行调试。我先在val.py中打断点验证了stride的类型,发现:

ipdb> model.stride.device
device(type='cpu')

可见,确实存在cpu类型的变量。然后我在yolo.py中的class BaseModel的_apply(self, fn)上调试发现:

ipdb> fn
<function Module.to.<locals>.convert at 0x7fc1a25ed280>
ipdb> m.stride
tensor([ 8., 16., 32.])
ipdb> type(m.stride)
<class 'torch.Tensor'>
ipdb> m.grid
[tensor([]), tensor([]), tensor([])]
ipdb> m.anchor_grid
[tensor([]), tensor([]), tensor([])]

把我的检测头head加到代码中调试,即把你yaml配置文件中的head的模块名称加到YOUR_HEAD_MODULE,再跑代码,就不会再报错,修改如下:

def _apply(self, fn):
        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment, YOUR_HEAD_MODULE)):
            m.stride = fn(m.stride)
            m.grid = list(map(fn, m.grid))
            if isinstance(m.anchor_grid, list):
                m.anchor_grid = list(map(fn, m.anchor_grid))
        return self

修改前后变量的变化如下调试过程:


> /nfs/private/windpaper_yolo/models/yolo.py(236)_apply()
    235         if isinstance(m, (Detect, DetectDcoupleHead, Segment)):
--> 236             m.stride = fn(m.stride)
    237             m.grid = list(map(fn, m.grid))

ipdb> m.stride
tensor([ 8., 16., 32.])
ipdb> n
> /nfs/private/windpaper_yolo/models/yolo.py(237)_apply()
    236             m.stride = fn(m.stride)
--> 237             m.grid = list(map(fn, m.grid))
    238             if isinstance(m.anchor_grid, list):

ipdb> m.stride
tensor([ 8., 16., 32.], device='cuda:0')
ipdb> m.grid
[tensor([]), tensor([]), tensor([])]
ipdb> n
> /nfs/private/windpaper_yolo/models/yolo.py(238)_apply()
    237             m.grid = list(map(fn, m.grid))
--> 238             if isinstance(m.anchor_grid, list):
    239                 m.anchor_grid = list(map(fn, m.anchor_grid))

ipdb> m.grid
[tensor([], device='cuda:0'), tensor([], device='cuda:0'), tensor([], device='cuda:0')]
ipdb> n
> /nfs/private/windpaper_yolo/models/yolo.py(239)_apply()
    238             if isinstance(m.anchor_grid, list):
--> 239                 m.anchor_grid = list(map(fn, m.anchor_grid))
    240         return self

ipdb> m.anchor_grid
[tensor([]), tensor([]), tensor([])]
ipdb> n
> /nfs/private/windpaper_yolo/models/yolo.py(240)_apply()
    239                 m.anchor_grid = list(map(fn, m.anchor_grid))
--> 240         return self
    241 

ipdb> m.anchor_grid
[tensor([], device='cuda:0'), tensor([], device='cuda:0'), tensor([], device='cuda:0')]
ipdb> n
--Return--
DetectionMode... )
    )
  )
)


可以发现,上述变量m.stride、m.grid和m.anchor_grid经过这个函数后就都加了device='cuda:0’的身份,这就会进入cuda成为cuda变量存在gpu中,从而是参数变量类型一致,我修改后训练正常,自此就完成了该bug的修改。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/764382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

步进电机控制

步进电机控制 #include "./stepper/bsp_stepper_init.h" #include "./delay/core_delay.h" #include "stm32f4xx.h"void TIM_SetTIMxCompare(TIM_TypeDef *TIMx,uint32_t channel,uint32_t compare); void TIM_SetPWM_period(TIM_TypeDef* TI…

Python爬虫学习笔记(五)————JsonPath解析

目录 1.JSONPath —— xpath在json的应用 2.JSONPath 表达式 3.jsonpath的安装及使用方式 4.jsonpath的使用 5.JSONPath语法元素和对应XPath元素的对比 6.实例 &#xff08;1&#xff09;商店案例 &#xff08;2&#xff09; 解析淘票票的“城市选择”数据 1.JSONPath…

Java8实战-总结3

Java8实战-总结3 基础知识流多线程并非易事 默认方法 基础知识 流 几乎每个Java应用都会制造和处理集合。但集合用起来并不总是那么理想。比方说&#xff0c;从一个列表中筛选金额较高的交易&#xff0c;然后按货币分组。需要写一大堆套路化的代码来实现这个数据处理命令&…

cocos creator Richtext点击事件

组件如图 添加ts自定义脚本&#xff0c;定义onClickFunc点击方法&#xff1a; import { Component, _decorator} from "cc";const { ccclass } _decorator; ccclass(RichTextComponent) export class RichTextComponent extends Component{public onClickFunc(even…

reggie优化02-SpringCache

1、SpringCache介绍 2、SpringCache常用注解 package com.itheima.controller;import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.itheima.entity.User; import com.itheima.service.UserService; import lombok.extern.slf4j.Slf4j; imp…

Mybatis:传参+提交事务(自动or手动)+sql多表关联查询(两种方法)

目录 一、参数两种类型&#xff1a; 二、传参的几种方法&#xff1a; 三、提交事务 四、sql多表关联查询(两种方法) 一、参数两种类型&#xff1a; 1.#{参数}&#xff1a;预编译方式&#xff0c;更安全&#xff0c;只用于向sql中传值&#xff1b; select * from admin w…

getattr, __getattr__, __getattribute__和__get__区别

一、getattr() 和另外三个方法都是魔法函数不同的是&#xff0c;getattr()是python内置的一个函数&#xff0c;它可以用来获取对象的属性和方法。例子如下: class A():a 5def __init__(self, x):self.x xdef hello(self):return hello funca A(10)print(getattr(a, x)) #…

2023 双非本科三个月互联网找实习心路历程

双非本科三个月互联网找实习心路历程 1、实习面试准备2、面试日历&#xff08;1&#xff09;开发投递&#xff08;2&#xff09;线下宣讲&#xff08;3&#xff09;转投测试&#xff0c;机会多多 3、同窗现状4、货拉拉 offer 的故事5、我的闲言6、我的收获(1&#xff09;勇气&a…

2.5 线性表的建表

1. 顺序表建表 #include <iostream>/// <summary> /// 数组最大长度 /// </summary> const int MAX_SIZE 10;/// <summary> /// 顺序表建表 /// </summary> /// <param name"arr">数组</param> /// <param name"…

万达商管IPO:看似轻舟已过万重山,实则负重前行?

近日&#xff0c;继万达商管债券发行计划被终止、证监会质疑万达商场销售数据真实性、珠海万达商管的股权被法院冻结后又解冻&#xff0c;万达商管又遇“水逆”——惠誉发布报告下调万达商管的评级&#xff0c;并认为珠海万达商管可能无法在2023年底前完成上市。 纷至沓来的负…

什么是链路跟踪 Skywarking

什么是链路跟踪 Skywarking 链路跟踪&#xff08;Link Tracing&#xff09;是一种用于追踪分布式系统中请求路径和性能的技术。SkyWalking 是一个开源的 APM&#xff08;Application Performance Monitoring&#xff09;系统&#xff0c;它提供了链路跟踪功能。 SkyWalking 的…

ceph----应用

文章目录 一、创建 CephFS 文件系统 MDS 接口1.1 服务端操作1.2 客户端操作 二、创建 Ceph 块存储系统 RBD 接口三、OSD 故障模拟与恢复 一、创建 CephFS 文件系统 MDS 接口 1.1 服务端操作 1&#xff09;在管理节点创建 mds 服务 cd /etc/ceph ceph-deploy mds create node0…

Java编程-基本排序算法

冒泡排序 图解 &#xff08;注&#xff1a;图片来源网络&#xff09; Java代码 package suanfa_Ja;import org.apache.hadoop.security.SaslOutputStream;// 基本排序算法&#xff0c;冒泡排序 时间复杂度 O(n^2) 空间复杂度O(1) public class BubbleSort {public static v…

blender 建模马拉松

效果展示 蘑菇模型创建&#xff1a; 创建蘑菇头 shift A &#xff0c;创建立方体&#xff1b; 右下工具栏添加细分修改器&#xff08;视图层级&#xff1a;2&#xff0c;渲染&#xff1a;2&#xff09;&#xff1b;tab键进入编辑模式&#xff0c;alt z 进入透显模式&…

Python项目依赖项管理的秘诀:requirements.txt文件

一、背景 公司里面很多时候我们开发的Python项目都不只是我们一个人使用&#xff0c;而是整体团队使用。Python项目需要在别人的电脑环境中运行&#xff0c;则需要别人的电脑环境中也要安装上我们项目需要的python库。那么项目中到底用到了哪些Python库&#xff0c;每个库具体…

12.matlab数据分析——多项式的建立 (matlab程序)

1.简述 多项式及其建立 在运算中我们经常接触到的就是所谓的多项式&#xff0c;比如很常见的一个多项式&#xff1a; 这里我们就说这是一个x的多项式&#xff0c;最高次是2次&#xff0c;常数项是3&#xff0c;二次项的系数是1&#xff0c;一次项的系数是2&#xff0c;相信这些…

流程管理是什么?“流程管理”到底管什么?

流程管理&#xff08;process management&#xff09;&#xff0c;是一种以规范化的构造端到端的卓越业务流程为中心&#xff0c;以持续的提高组织业务绩效为目的的系统化方法。 任正非曾在一次访谈时说到&#xff1a; “权力要放进流程中&#xff0c;流程才有权力&#xff0c…

【Django学习】(十四)自定义action_router

之前我们的视图类可以继承GenericViewSet或者ModelViewSet&#xff0c;我们不用再自定义通用的action方法&#xff0c;但是有时候我们需要自定义action&#xff0c;我们该如何设计呢&#xff1f; 自定义action 1、手写视图逻辑 1.1、先在视图集里自定义action方法&#xff0…

LeetCode 790. 多米诺和托米诺平铺 - 二维空间的动态规划

多米诺和托米诺平铺 中等 304 相关企业 有两种形状的瓷砖&#xff1a;一种是 2 x 1 的多米诺形&#xff0c;另一种是形如 “L” 的托米诺形。两种形状都可以旋转。 给定整数 n &#xff0c;返回可以平铺 2 x n 的面板的方法的数量。返回对 109 7 取模 的值。 平铺指的是每个…

icp许可证 办理流程(icp资质申请条件)

icp许可证 办理流程(icp资质申请条件)是什么&#xff1f; ICP经营许可证是可以线上无忧办理的&#xff0c;包下证&#xff0c;流程也很简单&#xff0c;只需要你提供企业营业执照、法人身份证这些基础材料就可以。加急10-20工作日拿证&#xff0c;普通20-60工作日拿证。 在了解…