pytorch 测量模型运行时间,GPU时间和CPU时间,model.eval()介绍

news2024/11/29 16:41:12

文章目录

    • 1. 测量时间的方式
    • 2. model.eval(), model.train(), torch.no_grad()方法介绍
      • 2.1 model.train()和model.eval()
      • 2.2 model.eval()和torch.no_grad()
    • 3. 模型推理时间方式
    • 4. 一个完整的测试模型推理时间的代码
    • 5. 参考:

1. 测量时间的方式

time.time()
time.perf_counter()
time.process_time()

time.time() 和time.perf_counter() 包括sleep()time 可以用作一般的时间测量,time.perf_counter()精度更高一些
time.process_time()当前进程的系统和用户CPU时间总和的值

测试代码:

def show_time():
    print('我是time()方法:{}'.format(time.time()))
    print('我是perf_counter()方法:{}'.format(time.perf_counter()))
    print('我是process_time()方法:{}'.format(time.process_time()))
    t0 = time.time()
    c0 = time.perf_counter()
    p0 = time.process_time()
    r = 0
    for i in range(10000000):
        r += i
    time.sleep(2)
    print(r)
    t1 = time.time()
    c1 = time.perf_counter()
    p1 = time.process_time()
    spend1 = t1 - t0
    spend2 = c1 - c0
    spend3 = p1 - p0
    print("time()方法用时:{}s".format(spend1))
    print("perf_counter()用时:{}s".format(spend2))
    print("process_time()用时:{}s".format(spend3))
    print("测试完毕")

测试结果:
在这里插入图片描述
更详细解释参考
Python3.7中time模块的time()、perf_counter()和process_time()的区别

2. model.eval(), model.train(), torch.no_grad()方法介绍

2.1 model.train()和model.eval()

我们知道,在pytorch中,模型有两种模式可以设置,一个是train模式、另一个是eval模式。

model.train()的作用是启用 Batch Normalization 和 Dropout。在train模式,Dropout层会按照设定的参数p设置保留激活单元的概率,如keep_prob=0.8,Batch Normalization层会继续计算数据的mean和var并进行更新。

model.eval()的作用是不启用 Batch Normalization 和 Dropout。在eval模式下,Dropout层会让所有的激活单元都通过,而Batch Normalization层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。

在使用model.eval()时就是将模型切换到测试模式,在这里,模型就不会像在训练模式下一样去更新权重。但是需要注意的是model.eval()不会影响各层的梯度计算行为,即会和训练模式一样进行梯度计算和存储,只是不进行反向传播。

2.2 model.eval()和torch.no_grad()

在讲model.eval()时,其实还会提到torch.no_grad()。

torch.no_grad()用于停止autograd的计算,能起到加速和节省显存的作用,但是不会影响Dropout层和Batch Normalization层的行为。

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation的结果;而with torch.zero_grad()则是更进一步加速和节省gpu空间。因为不用计算和存储梯度,从而可以计算得更快,也可以使用更大的batch来运行模型。

3. 模型推理时间方式

在测量时间的时候,与一般测试不同,比如下面的代码不正确:

start = time.time()
result = model(input)
end = time.time()

而应该采用:

torch.cuda.synchronize()
start = time.time()
result = model(input)
torch.cuda.synchronize()
end = time.time()

因为在pytorch里面,程序的执行都是异步的。
如果采用代码1,测试的时间会很短,因为执行完end=time.time()程序就退出了,后台的cu也因为python的退出退出了。
如果采用代码2,代码会同步cu的操作,等待gpu上的操作都完成了再继续成形end = time.time()

4. 一个完整的测试模型推理时间的代码

一般是首先model.eval()不启用 Batch Normalization 和 Dropout, 不启用梯度更新
然后利用mode创建模型,和初始化输入数据(单张图像)

def measure_inference_speed(model, data, max_iter=200, log_interval=50):
    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0
    fps = 0

    # benchmark with 2000 image and take the average
    for i in range(max_iter):

        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(*data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(
                    f'Done image [{i + 1:<3}/ {max_iter}], '
                    f'fps: {fps:.1f} img / s, '
                    f'times per image: {1000 / fps:.1f} ms / img',
                    flush=True)

        if (i + 1) == max_iter:
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(
                f'Overall fps: {fps:.1f} img / s, '
                f'times per image: {1000 / fps:.1f} ms / img',
                flush=True)
            break
    return fps
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = 'cpu'
    print(device)
    img_channel = 3
    width = 32

    enc_blks = [2, 2, 4, 8]
    middle_blk_num = 12
    dec_blks = [2, 2, 2, 2]

    width = 16
    enc_blks = [1, 1, 1]
    middle_blk_num = 1
    dec_blks = [1, 1, 1]

    net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
                 enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
    net = net.to(device)

    data = [torch.rand(1, 3, 256, 256).to(device)]
    fps = measure_inference_speed(net, data)
    print('fps:', fps)

在这里插入图片描述

5. 参考:

https://blog.csdn.net/weixin_44317740/article/details/104651434
https://zhuanlan.zhihu.com/p/547033884
https://deci.ai/blog/measure-inference-time-deep-neural-networks/
https://github.com/xinntao/BasicSR

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/528186.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用qt creator编译zlib

zlib被设计为一个免费的&#xff0c;通用的&#xff0c;法律上不受限制的-即不受任何专利保护的无损数据压缩库&#xff0c;几乎可以在任何计算机硬件和操作系统上使用。 官网&#xff1a;http://www.zlib.net/ 下载zlib源码:http://www.zlib.net/zlib1213.zip 备用地址&#x…

关于使用API接口获取商品数据的那些事

随着电商行业的不断发展&#xff0c;越来越多的企业和个人需要获取各大电商平台上的商品数据。而最常用的方法是使用API接口获取商品数据。本文将为您介绍使用API接口获取商品数据的步骤和注意事项。 一、选择API接口 首先需要了解各大电商平台提供的API接口&#xff0c;目前…

由浅入深理解java集合(一)——集合框架 Collection、Map

Java 提供了一套完整的集合类&#xff08;也可以叫做容器类&#xff09;来管理一组长度可变的对象&#xff08;也就是集合的元素&#xff09;&#xff0c;其中常见的类型包括 List、Set、Queue 和 Map。从我个人的编程经验来看&#xff0c;List 的实现类 ArrayList 和 Map 的实…

华为OD机试 - 查找树中元素(Python)

题目描述 已知树形结构的所有节点信息,现要求根据输入坐标(x,y)找到该节点保存的内容值,其中x表示节点所在的层数,根节点位于第0层,根节点的子节点位于第1层,依次类推;y表示节点在该层内的相对偏移,从左至右,第一个节点偏移0,第二个节点偏移1,依次类推; 举例:…

手把手教你用代码画架构图 | 京东云技术团队

作者&#xff1a;京东物流 覃玉杰 1. 前言 本文将给大家介绍一种简洁明了软件架构可视化模型——C4模型&#xff0c;并手把手教大家如何使用代码绘制出精美的C4架构图。 阅读本文之后&#xff0c;读者画的架构图将会是这样的&#xff1a; 注&#xff1a;该图例仅作绘图示例使…

【入土级】详解C++类对象(中篇)

目录 前言&#xff1a;类的6个默认成员函数 一&#xff0c; 构造函数 1. 概念 2. 特性 二&#xff0c; 析构函数 2.1 概念 2.2 特性 2.3 牛刀小试 三&#xff0c; 拷贝构造函数 3.1概念 3. 2 特点 四&#xff0c; 赋值运算符重载 4. 1 运算符重载 五&#xff0…

网站测试的主要方法

网站测试的主要方法 网站测试是保证网站质量的重要手段&#xff0c;通过对网站进行测试可以及时发现问题并修复&#xff0c;提高用户体验和网站的可靠性。本文将介绍网站测试的主要方法。 1.功能测试&#xff1a;测试网站的所有功能是否正常。通过模拟用户的操作&#xff0c;确…

最新水果FLStudio21中文版下载及快捷键操作教程

任何一款软件&#xff0c;其快捷键永远扮演着至关重要的作用&#xff0c;熟练运用快捷键不仅能够节省时间&#xff0c;提高工作效率&#xff0c;还有助于熟练掌握所使用的软件。作为一款功能强大的音乐编曲软件&#xff0c;FL Studio有着大量的快捷键&#xff0c;这些快捷键在一…

【ArcGIS Pro二次开发】(28):用地图斑导出用地用海汇总表

本工具的作用是将现状用地或规划用地导出Excel格式的用地用海汇总表。 实现这个功能的Arcpy脚本工具我之前已经做过&#xff0c;详见&#xff1a;ArcGisPro脚本工具【8】——用地图斑导出用地用海汇总表 这次试着在ArcGIS Pro SDK中来实现同样的功能。 一、要实现的功能 如上…

【存储数据恢复】H3C存储卷中的数据恢复案例

存储数据恢复环境&故障&#xff1a; H3C FlexStorage某型号存储&#xff0c;25块磁盘组建的RAID5&#xff0c;其中包含一块热备盘。 工作人员误操作将存储设备中原先的2个卷删除&#xff0c;删除之后又使用和删除2个卷同样大小的空间重建了一个卷。用户希望恢复删除的2个卷…

springboot缓存

1. 认识缓存 一级缓存 - 缓存是一种介于数据永久存储介质与数据应用之间的数据临时存储介质 - 使用缓存可以有效的减少低速数据读取过程的次数&#xff0c;提高系统性能 Service public class BookServiceImplCache implements BookService {Autowiredprivate BookDao book…

玩转服务器之环境篇:PHP和Python环境部署指南 | 京东云技术团队

前几篇文章中讲解了如何搭建docker和Java Web环境的方法&#xff0c;本篇文章来教大家搭建一个好的PHP和Python环境&#xff0c;可以帮助开发和运行PHP和Python应用程序&#xff0c;使其更加高效和稳定。 一、 PHP环境介绍 好的开发环境无疑会大大提升编码效率&#xff0c;近…

搭载RK3588的Orange Pi 5 Plus来了!首发“亲民价”649元起!

Orange Pi 5 Plus来了。令人惊艳的Orange Pi 5 家族将又添一个重量级新成员! 作为香橙派首款搭载瑞芯微RK3588的开发板&#xff0c;Orange Pi 5 Plus将成为顶级高性能开发板的臻选。先说几个令人惊艳的亮点&#xff1a;两个PCIe扩展的2.5G以太网接口&#xff1b;eMMC 闪存插座…

黑马Redis笔记高级篇 | 分布式缓存

黑马Redis笔记高级篇 | 分布式缓存 1、Redis持久化&#xff08;解决数据丢失&#xff09;1.1 RDB持久化1.1.1 定义1.1.2 异步持久化bgsave原理 1.2 AOF持久化1.3 RDB和AOF比较 2、Redis主从&#xff08;解决并发问题&#xff09;2.1 搭建主从架构2.2 主从数据同步原理2.2.1 全量…

这样连交换机和路由器,多少网工没试过?

大家好&#xff0c;我是老杨。 交换机已经说过很多遍了&#xff0c;园区组网案例、交换机级联、光模块配件解读等等…… 有群友说想要看更多的配置案例&#xff0c;这次就给你说说&#xff0c;交换机和路由器对接上网配置。 还搞不明白交换机和路由器的&#xff0c;复习一下…

webpack开发服务器配置

&#x1f482; 个人网站:[【紫陌】【笔记分享网】](http://zimo.aizhaiyu.com/) &#x1f485; 想寻找共同学习交流、共同成长的伙伴&#xff0c;[请点击【前端学习交流群】](http://zimo.aizhaiyu.com/wechat/wechat.html) 文章最后有作者l联系方式&#xff08;备注进群&am…

写给程序员Android Framework 开发,

前言 在 Android 开发者技能中&#xff0c;如果想进大厂&#xff0c;一般拥有较好的学历可能有优势一些。但是如果你靠硬实力也是有机会的&#xff0c;例如死磕Framework。Framework 知识广泛应用在Android各个领域中&#xff0c;重要性显而易见。 成为一名Android Framework…

【2023程序员必看】大数据行业分析

1、政策重点扶持&#xff0c;市场前景广阔 2014年&#xff0c;大数据首次写入政府工作报告&#xff0c;大数据逐渐成为各级政府关注的热点。 2015年9月&#xff0c;国务院发布《促进大数据发展的行动纲要》&#xff0c;大数据正式上升至国家战略层面&#xff0c;十九大报告提…

网络突发环路,原来可以这么解决啊

大家好&#xff0c;我是老杨。 我相信&#xff0c;任何一个网工都遇到过网络环路&#xff0c;遇到这个情况&#xff0c;你的应对方法是什么&#xff1f; 我了解到大部分的初阶网工&#xff0c;最开始都只能用拔插网线和重启观测法来排除回路。 简单来说&#xff0c;就是先给…

下一个超级生态节点openGauss ——鲲鹏开发者峰会2023 openGauss技术专题回顾

摘要 2023年5月6日&#xff0c;一场面向计算产业开发者的技术盛会“ 鲲鹏开发者峰会2023 ”在东莞松山湖正式拉开帷幕&#xff0c;在其中的“openGauss技术专场”上&#xff0c;openGauss相关专家和伙伴围绕openGauss社区进展、openGauss5.0版本技术创新&#xff0c;基于op…