评估修改后的YOLOv8模型的参数量和速度

news2024/11/25 20:33:30

 YOLOv8公布了自己每个模型的速度和参数量

那么如果我们自己对YOLOv8做了一些修改,又怎么样自己写代码统计一下修改后的模型的参数量和速度呢?

其实评估这些东西,大多数情况下不需要我们从头自己写一个函数来评估

一般来说,只要是作者公布的这些分数。作者都会把这些分数的实现代码放在它的代码中。你所需要的仅仅是找到这段代码所在的位置,然后直接调用作者的代码,计算即可。

上述计算要求的diam被放在了了下面这个位置(ultralytircs8.0.40这个版本是这样,其他版本有可能会不同)

\ultralytics\yolo\utils\torch_utils.py line 165

计算参数量的函数

def model_info(model, verbose=False, imgsz=640):
    # Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
    n_p = get_num_params(model)
    n_g = get_num_gradients(model)  # number gradients
    if verbose:
        print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    flops = get_flops(model, imgsz)
    fused = ' (fused)' if model.is_fused() else ''
    fs = f', {flops:.1f} GFLOPs' if flops else ''
    m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
    LOGGER.info(f"{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")


def get_num_params(model):
    return sum(x.numel() for x in model.parameters())


def get_num_gradients(model):
    return sum(x.numel() for x in model.parameters() if x.requires_grad)


def get_flops(model, imgsz=640):
    try:
        model = de_parallel(model)
        p = next(model.parameters())
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
        return flops
    except Exception:
        return 0

计算速度的函数

line354

def profile(input, ops, n=10, device=None):
    """ YOLOv8 speed/memory/FLOPs profiler
    Usage:
        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # profile over 100 iterations
    """
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
          f"{'input':>24s}{'output':>24s}")

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, 'to') else m  # device
            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
            except Exception:
                flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()
                    y = m(x)
                    t[1] = time_sync()
                    try:
                        _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()
                    except Exception:  # no backward method
                        # print(e)  # for debug
                        t[2] = float('nan')
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                print(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/819203.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【云存储】使用OSS快速搭建个人网盘教程(阿里云)

使用OSS快速搭建个人网盘 一、基础概要1. 主要的存储类型1.1 块存储1.2 文件存储1.3 对象存储 2. 对象存储OSS2.1 存储空间2.2 地域2.3 对象2.4 读写权限2.5 访问域名(Endpoint)2.6 访问密钥2.7 常用功能(1)创建存储空间&#xff…

HCIP-datacom-831题库

考取HCIP数通证书可以胜任中到大型企业网络工程师岗位,需要掌握中到大型网络的特点和通用技术,具备使用华为数通设备进行中到大型企业网络的规划设计、部署运维、故障定位的能力,并能针对网络应用设计出较高安全性、可用性和可靠性的解决方案…

RedisJava的Java客户端

目录 1.Jedis的使用 前置工作-ssh进行端口转发 JedisAPI的使用 Jedis连接池 2.SpringDataRedis的使用 1.创建项目 2.配置文件 3.注入RedisTemplate对象 4.编写代码 3.SpringRedisTemplate 哈希结构用法 ​总结 1.Jedis的使用 Jedis:以Redis命令作为方法…

途乐证券:沪指强势拉升涨0.63%,券商等板块走强,传媒板块活跃

31日早盘,两市股指全线走高,沪指一度涨超1%收复3300点,上证50指数盘中涨逾2%;随后涨幅有所收窄;两市成交额显着放大,北向资金净买入超90亿元。 到午间收盘,沪指涨0.63%报3296.58点,深…

Python多线程与GIL锁

Python多线程与GIL锁 python多线程 Python的多线程编程可以在单个进程内创建多个线程来同时执行多个任务,从而提高程序的效率和性能。Python的多线程实现依赖于操作系统的线程调度器,并且受到全局解释器锁(GIL)的限制&#xff0c…

如何在 Ubuntu 22.04 下编译 StoneDB for MySQL 8.0 | StoneDB 使用教程 #1

作者:双飞(花名:小鱼) 杭州电子科技大学在读硕士 StoneDB 内核研发实习生 ❝ 大家好,我是 StoneDB 的实习生小鱼,目前正在做 StoneDB 8.0 内核升级相关的一些事情。刚开始接触数据库开发没多久&#xff0c…

第55步 深度学习图像识别:CNN特征层和卷积核可视化(TensorFlow)

基于WIN10的64位系统演示 一、写在前面 (1)CNN可视化 在理解和解释卷积神经网络(CNN)的行为方面,可视化工具起着重要的作用。以下是一些可以用于可视化的内容: (a)激活映射&…

多目标关联(分配)最近邻法

多目标关联(分配)最近邻法 最近邻数据关联 适用于两帧图片的中多目标位置关联,目标轨迹与新目标之间的关联、固定位置下的动目标跟踪关联等问题。 新目标与被跟踪目标的预测位置“最邻近”的观测点作为与航迹相关联的观测。 如有三批目标T…

活字格性能优化技巧-如何在大规模数据量的场景下提升数据访问效率

在上节内容中我们介绍了如何利用数据库主键提升访问性能,本节内容我们继续为大家介绍如何在大规模数据量的场景下提升数据访问效率。 在开始之前先做个小小的实验: 1. 准备一张数据表,内置1000W行记录。 2. 直观感受一下这个表的规模。使用…

数据结构入门指南:单链表(附源码)

目录 前言 尾删 头删 查找 位置前插入 位置后插入 位置删除 位置后删除 链表销毁 总结 前言 前边关于链表的基础如果已经理解透彻,那么接下来就是对链表各功能的实现,同时也希望大家能把这部分内容熟练于心,这部分内容对有关链表部分的…

CustomeG6-canvas

目录 简介 scss 快速上手 语雀 简介 antv/g6是一款基于JavaScript的图形可视化引擎,由阿里巴巴的AntV团队开发。 创建各种类型的图形,如流程图、关系图、树形图等。 G6采用了自己的绘图模型和渲染引擎,使其具备高性能的图形渲染能力。…

npm更新和管理已发布的包

目录 1、更改包的可见性 1.1 将公共包设为私有 ​编辑 使用网站 使用命令行 1.2 将私有包公开 使用网站 使用命令行 2、将协作者添加到用户帐户拥有的私有包 2.1 授予对Web上私有用户包的访问权限 2.2 从命令行界面授予私有包访问权限 2.3 授予对私有组织包的访问权限…

InfiniBand,到底是个啥?

对于InfiniBand,很多搞数通的同学肯定不会陌生。 进入21世纪以来,随着云计算、大数据的不断普及,数据中心获得了高速发展。而InfiniBand,就是数据中心里的一项关键技术,地位极为重要。 尤其是今年以来,以Ch…

春秋云镜 CVE-2021-32682

春秋云镜 CVE-2021-32682 elFinder RCE 靶标介绍 elFinder是一套基于Drupal平台的、开源的AJAX文件管理器。该产品提供多文件上传、图像缩放等功能;elFinder 存在安全漏洞,攻击者可利用该漏洞在托管elFinder PHP连接器的服务器上执行任意代码和命令。 启动场景 漏…

金蝶管易云 X Hologres:新一代全渠道电商ERP最佳实践

业务简介 金蝶管易云是金蝶集团旗下专注提供电商企业管理软件服务的子公司,成立于2008年,是国内最早的电商ERP服务商之一,目前已与300主流电商平台建有合作关系,以企业数据为驱动,深度融合线上线下数据,为…

pytorch学习——正则化技术——丢弃法(dropout)

一、概念介绍 在多层感知机(MLP)中,丢弃法(Dropout)是一种常用的正则化技术,旨在防止过拟合。(效果一般比前面的权重衰退好) 在丢弃法中,随机选择一部分神经元并将其输出…

HCIP中期实验

1、该拓扑为公司网络,其中包括公司总部、公司分部以及公司骨干网,不包含运营商公网部分。 2、设备名称均使用拓扑上名称改名,并且区分大小写。 3、整张拓扑均使用私网地址进行配置。 4、整张网络中,运行OSPF协议或者BGP协议的设备…

Hadoop 之 Hive 4.0.0-alpha-2 搭建(八)

Hadoop 之 Hive 搭建与使用 一.Hive 简介二.Hive 搭建1.下载2.安装1.解压并配置 HIVE2.修改 hive-site.xml3.修改 hadoop 的 core-site.xml4.启动 三.Hive 测试 一.Hive 简介 Hive 是基于 Hadoop 的数据仓库工具,可以提供类 SQL 查询能力 二.Hive 搭建 1.下载 H…

linux 安装FTP

检查是否已经安装 $] rpm -qa |grep vsftpd vsftpd-3.0.2-29.el7_9.x86_64出现 vsftpd 信息表示已经安装,无需再次安装 yum安装 $] yum -y install vsftpd此命令需要root执行或有sudo权限的账号执行 /etc/vsftpd 目录 ftpusers # 禁用账号列表 user_list # 账号列…