DETR模型计算量(FLOPs)参数量(Params)

news2025/2/28 4:59:23

前言

关于计算量(FLOPs)参数量(Params)的一个直观理解,便是计算量对应时间复杂度,参数量对应空间复杂度,即计算量要看网络执行时间的长短,参数量要看占用显存的量。

计算量: FLOPs,FLOP时指浮点运算次数,s是指秒,即每秒浮点运算次数的意思,考量一个网络模型的计算量的标准。

参数量: Params,是指网络模型中需要训练的参数总数。

在这里插入图片描述

了解以上概念后,接下来便是如何计算这两个值。
一个很常见的方法便是通过ptflos包来实现。

# -- coding: utf-8 --
import torchvision
from ptflops import get_model_complexity_info

model = torchvision.models.alexnet(pretrained=False)
flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
print('flops: ', flops, 'params: ', params)

这段代码可以说是即插即用。

DAB-DETR模型

博主以DAB-DETR模型为例,运行时报错,这是由于权重文件于模型配置文件不匹配导致的

权重文件与模型配置不匹配

RuntimeError: Error(s) in loading state_dict for DABDeformableDETR:
	size mismatch for input_proj.0.0.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1]).
	size mismatch for input_proj.1.0.weight: copying a param with shape torch.Size([256, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).
	size mismatch for input_proj.2.0.weight: copying a param with shape torch.Size([256, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for input_proj.3.0.weight: copying a param with shape torch.Size([256, 2048, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).

修改num_channels的值即可,原本为【128,256,512】

  if return_interm_layers:
        # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
        self.strides = [8, 16, 32]
        self.num_channels = [512, 1024, 2048]

推理代码

推理代码如下:几乎所有的DETR类模型的推理代码都是可以通用的。

import json
import os, sys
import torch
import numpy as np

from models import build_DABDETR
from models.dab_deformable_detr import build_dab_deformable_detr
from util.slconfig import SLConfig
from datasets import build_dataset
from util.visualizer import COCOVisualizer
from util import box_ops
model_config_path = "D:/graduate/others/DAB-DETR/config.json" # change the path of the model config file
model_checkpoint_path = "D:/graduate/others/DAB-DETR/checkpoint.pth" # change the path of the model checkpoint
# See our Model Zoo section in README.md for more details about our pretrained models.

args = SLConfig.fromfile(model_config_path)
model, criterion, postprocessors = build_DABDETR(args)
checkpoint = torch.load(model_checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model'])
_ = model.eval()
with open('util/coco_id2name.json') as f:
    id2name = json.load(f)
    id2name = {int(k): v for k, v in id2name.items()}
from PIL import Image
import datasets.transforms as T
image = Image.open("./figure/4.jpg").convert("RGB") # load image
# transform images
transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image, _ = transform(image, None)
from ptflops import get_model_complexity_info
model=model.to(args.device)
flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
print('flops: ', flops, 'params: ', params)
# predict images
with torch.no_grad():
    output = model.cuda()(image[None].cuda())
  # visualize outputs
output = postprocessors['bbox'](output, torch.Tensor([[1.0, 1.0]]).cuda())[0]
thershold = 0.5  # set a thershold
vslzr = COCOVisualizer()
scores = output['scores']
print(len(scores))
labels = output['labels']
boxes = box_ops.box_xyxy_to_cxcywh(output['boxes'])
select_mask = scores > thershold

box_label = [id2name[int(item)] for item in labels[select_mask]]
pred_dict = {
      'boxes': boxes[select_mask],
      'size': torch.Tensor([image.shape[1], image.shape[2]]),
      'box_label': box_label
}

vslzr.visualize(image, pred_dict, savedir=None, dpi=120)

DN-DETR模型

DN-DETR模型推理代码与DAB-DETR模型推理代码大同小异,但问题却不尽相同。

空值问题

indicator0 = torch.zeros([num_queries * num_patterns, 1]).cuda()
TypeError: unsupported operand type(s) for *: 'int' and 'NoneType'

空值问题,给num_patterns赋值=1即可

CPU与GPU运算问题

boxes = boxes * scale_fct[:, None, :]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

数据有的在cpu上,有的在gpu上,在boxes = boxes * scale_fct[:, None, :]后面加上.cuda()

tuple转换问题

此外,还会报错tuple的转换问题

TypeError: tuple indices must be integers or slices, not str

将下面的代码

out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

改为:

out_logits=outputs[0]['pred_logits']
out_bbox = outputs[0]['pred_boxes']

参数量计算问题

至此,DN-DETR模型推理代码修改无误,但在计算参数量时却出现问题:

File "D:\Anaconda\envs\deformable_detr\lib\site-packages\ptflops\pytorch_ops.py", line 162, in multihead_attention_counter_hook
    q, k, v = input
ValueError: not enough values to unpack (expected 3, got 2)

这里可以看到报错是参数数量出现了问题,我们找到原来的代码,将q, k, v = input改为:

q, k= input, v=k

GPU与CPU运算问题

同样的,这里也报了数据计算位置不一致的问题,如法炮制即可。

 File "E:\graduate\papers\DN-DETR\DN-DETR-main\models\DN_DAB_DETR\DABDETR.py", line 458, in forward
    boxes = boxes * scale_fct[:, None, :]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

DN-DAB-Deformable-DETR模型

参数量运算问题

由于DN-DAB-Deformable-DETR与DN-DAB-DETR共用一套代码,这里出了问题。

    q, k= input
ValueError: too many values to unpack (expected 2)

我们查看一下input的长度,共有三个值,那么原本的写法就没有问题了,改为原本写法即可。

q, k, v= input

报错batch-size问题,其实很好解决,因为我们只是推理,只有一张图片,那么只需要设置为1即可。

至此,模型推理与计算量,参数量计算解决了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/862382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

云端剪切板,让你的数据同步无界

云端剪切板,让你的数据同步无界! 每个人都应该保护自己的数据,同时使它易于访问和共享。这就是我们的云剪切板网站诞生的原因!无论你在哪里,只要登录我们的网站,就可以随时随地使用你的剪切板数据。 你可…

从支付或退款之回调处理的设计,看一看抽象类的使用场景

一、背景 抽象类,包含抽象方法和实例方法,抽象方法待继承类去实例化,正是利用该特性,以满足不同支付渠道的差异化需求。 我们在做多渠道支付的时候,接收支付或退款的回调报文,然后去处理。这就意味着&…

对话音视频牛哥:如何设计功能齐全的跨平台低延迟RTMP播放器

开发背景 2015年,我们在做移动单兵应急指挥项目的时候,推送端采用了RTMP方案,这在当时算是介入RTMP比较早的了,RTMP推送模块做好以后,我们找了市面上VLC还有Vitamio,来测试整体延迟,实际效果真…

大数据传输的定义与大数据传输解决方案的选择

当我们需要处理大量的数据时,我们就要把数据从一个地方移动到另一个地方。这个过程就叫做大数据传输。它通常需要用到高速的网络连接、分散的存储系统和数据传输协议,以保证数据的快速、可靠和安全的移动。常用的大数据传输技术有Hadoop分布式文件系统&a…

servlet三大类HttpSevlet,HttpServletRequest,HttpServletResponse介绍

一、HttpServlet HttpServlet类是一个被继承的方法,可以看做一个专门用来响应http请求的类,这个类的所有方法都是为响应http请求服务的,要对一个某个路径谁知http响应时,需要写一个类来继承HttpServlet类,并重写里面的…

BGP基础建邻+宣告实验

实验要求及拓扑 一、实验思路 1.编写静态路由使R1、R2之间可通和使R4、R5之间可通。 2.使用OSPF使R2、R3、R4之间可通。 3.各自宣告AS区域,中间区域两两之间建邻。 4.注意建邻所使用的端口,外部BGP邻居关系和内部BGP邻居关系的区别。 二、上虚拟机操…

企业微信web登录(扫二维码登录)

记录一下企业微信web扫码登录的使用过程。 按惯例,先看登录流程: 步骤 首先, 企业微信后台开启“企业微信授权登陆功能”,“设置授权回调域名” ,授权回调域名必须与访问链接的域名完全一致。(访问链接的域名就是扫码…

【Kubernetes】Kubernetes的调度

K8S调度 一、Kubernetes 调度1. Pod 调度介绍2. Pod 启动创建过程3. Kubernetes 的调度过程3.1 调度需要考虑的问题3.2 具体调度过程 二、影响kubernetes调度的因素1. nodeName2. nodeSelector3. 亲和性3.1 三种亲和性的区别3.2 键值运算关系3.3 节点亲和性3.4 Pod 亲和性3.5 P…

高忆管理:创业板股票涨跌幅?

创业板股票涨跌幅限制大于主板商场,为何呈现这样的现象?从多个角度剖析,其中包含方针因素、商场走势、职业危险等多个方面。 首要,方针因素是导致股票涨跌幅波动的一个重要因素。在新的方针环境下,相关部门关于创业板股…

ModaHub魔搭社区——Milvus Cloud向量数据库

向量数据库:在AI时代的快速发展与应用 摘要: 随着人工智能技术的不断进步,向量数据库在处理大规模数据方面发挥着越来越重要的作用。本文介绍了向量数据库的基本概念、应用场景和技术挑战,并详细阐述了Milvus Cloud作为典型的向量数据库产品的技术特点、性能优化和应用案例…

拼多多秋招 考试内容详解和备考技巧

拼多多秋招内容简介 作为线上销售行业的知名企业之一,拼多多的销售模式也得到了越来越多的人认可,而伴随着企业规模的不断扩大,拼多多也需要能力杰出、认可自己公司文化的新员工,从目前的招聘情况来看,拼多多的岗位需…

拿下美团校招:MySQL InnoDB非聚簇索引知识点解析!

大家好,我是你们的小米,在这里欢迎大家来到《小米的技术小屋》!今天,我将和大家一起来揭开一个有趣且有深度的话题,那就是来自美团校招面试的一道问题:“MySQL中的InnoDB在什么情况下使用非聚簇索引&#x…

SpringBoot禁用Swagger3

Swagger3默认是启用的&#xff0c;即引入包就启用。 <dependency><groupId>io.springfox</groupId><artifactId>springfox-boot-starter</artifactId><version>3.0.0</version> </dependency> <dependency><groupId…

纤维素衍生物辅料行业分析-市场规模达15.67亿美元

纤维素衍生物辅料行业分析&#xff1a;2022年全球纤维素合成生物辅料市场规模达15.67亿美元 关注医药行业的纤维素衍生物辅料。药用辅料是生产药品和调配处方时所用的赋形剂和附加剂&#xff0c;是药物制剂的重要组成部分。纤维素衍生物作为天然高分子衍生材料&#xff0c;具有…

Uniapp使用腾讯地图并进行标点创建和设置保姆教程

使用Uniapp内置地图 首先我们需要创建一个uniapp项目 首先我们需要创建一个uniapp项目 我们在HBuilder左上角点击文件新建创建一个项目 然后下面这张图的话就是uniapp创建项目过程当中需要注意的一些点和具体的操作 然后我们创建完项目之后进入到项目pages文件夹下&#xff…

面试热题(二叉树的锯齿形层次遍历)

给你二叉树的根节点 root &#xff0c;返回其节点值的 锯齿形层序遍历 。&#xff08;即先从左往右&#xff0c;再从右往左进行下一层遍历&#xff0c;以此类推&#xff0c;层与层之间交替进行&#xff09; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;[[3…

Spring @Profile注解使用和源码解析

使用 带有Profile的注解的bean的不会被注册进IOC容器&#xff0c;需要为其设置环境变量激活&#xff0c;才能注册进IOC容器&#xff0c;如下通过setActiveProfiles设置了dev值&#xff0c;那么这三个值所对应的Bean会被注册进IOC容器。当然&#xff0c;我们在实际使用中&#…

Grafana技术文档--基本安装-docker安装并挂载数据卷-《十分钟搭建》-附带监控服务器

阿丹&#xff1a; Prometheus技术文档--基本安装-docker安装并挂载数据卷-《十分钟搭建》_一单成的博客-CSDN博客 在正确安装了Prometheus之后开始使用并安装Grafana作为Prometheus的仪表盘。 一、拉取镜像 搜索可拉取版本 docker search Grafana拉取镜像 docker pull gra…

cmake (更新中)

概述 关于 CMake CMake 是一个可扩展的开源系统&#xff0c;以一种与操作系统和编译器无关的方式来管理构建过程。与许多跨平台系统不同&#xff0c;CMake 被设计为与本机构建环境配合使用。在每个源代码目录中放置简单的配置文件&#xff08;称为 CMakeLists.txt 文件&#xf…

VLOOKUP函数使用

在Excel中&#xff0c;VLOOKUP函数用于在一个范围内查找某个值&#xff0c;并返回该值所在行的指定列的内容。VLOOKUP函数的基本语法如下&#xff1a; VLOOKUP(lookup_value, table_array, col_index_num, [range_lookup])参数说明&#xff1a; lookup_value&#xff1a;要查…