PyTorch之计算模型推理时间

news2024/11/19 8:48:36

一、参考资料

如何测试模型的推理速度
Pytorch 测试模型的推理速度

二、计算PyTorch模型推理时间

1. 计算CPU推理时间

import torch
import torchvision
import time
import tqdm
from torchsummary import summary


def calcCPUTime():
    model = torchvision.models.resnet18()
    model.eval()
    # summary(model, input_size=(3, 224, 224), device="cpu")
    dummy_input = torch.randn(1, 3, 224, 224)

    num_iterations = 1000  # 迭代次数
    # 预热, GPU 平时可能为了节能而处于休眠状态, 因此需要预热
    print('warm up ...\n')
    with torch.no_grad():
        for _ in range(100):
            _ = model(dummy_input)

    print('testing ...\n')
    total_forward_time = 0.0  # 使用time来测试
    # 记录开始时间
    start_event = time.time()
    with torch.no_grad():
        for _ in tqdm.tqdm(range(num_iterations)):
            start_forward_time = time.time()
            _ = model(dummy_input)
            end_forward_time = time.time()
            forward_time = end_forward_time - start_forward_time
            total_forward_time += forward_time * 1000  # 转换为毫秒

    # 记录结束时间
    end_event = time.time()

    elapsed_time = (end_event - start_event)  # 转换为秒
    fps = num_iterations / elapsed_time

    elapsed_time_ms = elapsed_time / (num_iterations * dummy_input.shape[0])

    avg_forward_time = total_forward_time / (num_iterations * dummy_input.shape[0])

    print(f"FPS: {fps}")
    print("elapsed_time_ms:", elapsed_time_ms * 1000)
    print(f"Avg Forward Time per Image: {avg_forward_time} ms")


if __name__ == "__main__":
    calcCPUTime()

输出结果

warm up ...

testing ...

100%|██████████| 1000/1000 [00:09<00:00, 102.13it/s]
FPS: 102.11109490533485
elapsed_time_ms: 9.793255090713501
Avg Forward Time per Image: 9.777164697647095 ms

CPU资源占用情况

在这里插入图片描述

2. 计算GPU推理时间

方法一

import torch
import torchvision
import time
import tqdm
from torchsummary import summary


def calcGPUTime():
    model = torchvision.models.resnet18()
    model.cuda()
    model.eval()
    # summary(model, input_size=(3, 224, 224), device="cuda")
    dummy_input = torch.randn(1, 3, 224, 224).cuda()

    num_iterations = 1000  # 迭代次数
    # 预热, GPU 平时可能为了节能而处于休眠状态, 因此需要预热
    print('warm up ...\n')
    with torch.no_grad():
        for _ in range(100):
            _ = model(dummy_input)

    print('testing ...\n')
    total_forward_time = 0.0  # 使用time来测试
    # 记录开始时间
    start_event = time.time() * 1000
    with torch.no_grad():
        for _ in tqdm.tqdm(range(num_iterations)):
            start_forward_time = time.time()
            _ = model(dummy_input)
            end_forward_time = time.time()
            forward_time = end_forward_time - start_forward_time
            total_forward_time += forward_time * 1000  # 转换为毫秒

    # 记录结束时间
    end_event = time.time() * 1000

    elapsed_time = (end_event - start_event) / 1000.0  # 转换为秒
    fps = num_iterations / elapsed_time

    elapsed_time_ms = elapsed_time / (num_iterations * dummy_input.shape[0])

    avg_forward_time = total_forward_time / (num_iterations * dummy_input.shape[0])

    print(f"FPS: {fps}")
    print("elapsed_time_ms:", elapsed_time_ms * 1000)
    print(f"Avg Forward Time per Image: {avg_forward_time} ms")


if __name__ == "__main__":
    calcGPUTime()

输出结果

warm up ...

testing ...

100%|██████████| 1000/1000 [00:01<00:00, 727.79it/s]
FPS: 727.1527832145586
elapsed_time_ms: 1.375226806640625
Avg Forward Time per Image: 1.3709843158721924 ms

GPU资源占用情况

在这里插入图片描述

方法二

import torch
import torchvision
import numpy as np
import tqdm


# TODO - 计算模型的推理时间
def calcGPUTime():

    device = 'cuda:0'
    model = torchvision.models.resnet18()
    model.to(device)
    model.eval()

    repetitions = 1000

    dummy_input = torch.rand(1, 3, 224, 224).to(device)

    # 预热, GPU 平时可能为了节能而处于休眠状态, 因此需要预热
    print('warm up ...\n')
    with torch.no_grad():
        for _ in range(100):
            _ = model(dummy_input)

    # synchronize 等待所有 GPU 任务处理完才返回 CPU 主线程
    torch.cuda.synchronize()

    # 设置用于测量时间的 cuda Event, 这是PyTorch 官方推荐的接口,理论上应该最靠谱
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    # 初始化一个时间容器
    timings = np.zeros((repetitions, 1))

    print('testing ...\n')
    with torch.no_grad():
        for rep in tqdm.tqdm(range(repetitions)):
            starter.record()
            _ = model(dummy_input)
            ender.record()
            torch.cuda.synchronize()  # 等待GPU任务完成
            curr_time = starter.elapsed_time(ender)  # 从 starter 到 ender 之间用时,单位为毫秒
            timings[rep] = curr_time

    avg = timings.sum() / repetitions
    print('\navg={}\n'.format(avg))


if __name__ == '__main__':
    calcGPUTime()

输出结果

warm up ...

testing ...

100%|██████████| 1000/1000 [00:01<00:00, 627.50it/s]

avg=1.4300348817110062

GPU资源占用情况

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1573187.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深入浅出 -- 系统架构之分布式多形态的存储型集群

一、多形态的存储型集群 在上阶段&#xff0c;我们简单聊了下集群的基本知识&#xff0c;以及快速过了一下逻辑处理型集群的内容&#xff0c;下面重点来看看存储型集群&#xff0c;毕竟这块才是重头戏&#xff0c;集群的形态在其中有着多种多样的变化。 逻辑处理型的应用&…

线程池详解并使用Go语言实现 Pool

写在前面 在线程池中存在几个概念&#xff1a;核心线程数、最大线程数、任务队列。 核心线程数指的是线程池的基本大小&#xff1b;也就是指worker的数量最大线程数指的是&#xff0c;同一时刻线程池中线程的数量最大不能超过该值&#xff1b;实际上就是指task任务的数量。任务…

9_springboot_shiro_jwt_多端认证鉴权_整合jwt

1. Shiro框架回顾 到目前为之&#xff0c;Shiro框架本身的知识点已经介绍完了。web环境下&#xff0c;整个框架从使用的角度我们需要关注的几个点&#xff1a; 要使用Shiro框架&#xff0c;就要创建核心部件securityManager 对象。 SpringBoot项目中&#xff0c;引入shiro-spr…

python接入AI 实现微信自动回复

import numpy as np # 引入numpy库&#xff0c;目的是将读取的数据转换为列表 import pandas as pd # 引入pandas库&#xff0c;用来读取csv数据 from uiautomation import WindowControl # 引入uiautomation库中的WindowControl类&#xff0c;用来进行图像识别和模拟操作 i…

go | 上传文件分析 | http协议分析 | 使用openssl 实现 https 协议 server.key、server.pem

是这样的&#xff0c;现在分析抓包数据 test.go package mainimport ("fmt""log""github.com/gin-gonic/gin" )func main() {r : gin.Default()// Upload single filer.MaxMultipartMemory 8 << 20r.POST("/upload", func(c *g…

Apache Log4j2 Jndi RCE CVE-2021-44228漏洞原理讲解

Apache Log4j2 Jndi RCE CVE-2021-44228漏洞原理讲解 一、什么是Log4j2二、环境搭建三、简单使用Log4j2四、JDNI和RMI4.1、启动一个RMI服务端4.2、启动一个RMI客户端4.3、ldap 五、漏洞复现六、Python批量检测 参考视频&#xff1a;https://www.bilibili.com/video/BV1mZ4y1D7K…

基于Socket简单的TCP网络程序

⭐小白苦学IT的博客主页 ⭐初学者必看&#xff1a;Linux操作系统入门 ⭐代码仓库&#xff1a;Linux代码仓库 ❤关注我一起讨论和学习Linux系统 TCP单例模式的多线程版本的英汉互译服务器 我们先来认识一下与udp服务器实现的不同的接口&#xff1a; TCP服务器端 socket()&…

【C++初阶】String在OJ中的使用(一):仅仅反转字母、字符串中的第一个唯一字母、字符串最后一个单词的长度、验证回文串、字符串相加

前言&#xff1a; &#x1f3af;个人博客&#xff1a;Dream_Chaser &#x1f388;博客专栏&#xff1a;C &#x1f4da;本篇内容&#xff1a;仅仅反转字母、字符串中的第一个唯一字母、字符串最后一个单词的长度、验证回文串、字符串相加 目录 917.仅仅反转字母 题目描述&am…

【stm32】软件I2C读写MPU6050

软件I2C读写MPU6050(文章最后附上源码) 编码 概况 首先建立通信层的.c和.h模块 在通信层里写好I2C底层的GPIO初始化 以及6个时序基本单元 起始、终值、发送一个字节、接收一个字节、发送应答、接收应答 写好I2C通信层之后&#xff0c;再建立MPU6050的.c和.h模块 基于I2C通…

软考116-上午题-【计算机网络】-LINUX命令

一、真题 真题1&#xff1a; 真题2&#xff1a; 权限通常分为三类&#xff1a; 读&#xff08;r&#xff09;&#xff1a;允许读取文件内容或列出目录内容。写&#xff08;w&#xff09;&#xff1a;允许修改文件内容或在目录中创建/删除文件。执行&#xff08;x&#xff09;&…

stm32开发之threadx使用记录(主逻辑分析)

前言 threadx的相关参考资料 论坛资料、微软官网本次使用的开发板为普中科技–麒麟&#xff0c;核心芯片为 stm32f497zgt6开发工具选择的是stm32cubemx(代码生成工具)clion(代码编写工具)编译构建环境选择的是arm-none-gcc编译 本次项目结构 CMakeList对应的配置 set(CMAKE_…

SD-WAN国际网络专线:高效、合规且可靠的跨境连接解决方案

在数字化时代&#xff0c;企业对跨境网络连接的需求日益增长。SD-WAN技术作为一种新兴的解决方案&#xff0c;正逐渐成为构建跨境网络连接的首选。本文将探讨SD-WAN国际网络专线的发展现状、合规性要求以及选择时需要考虑的关键因素。 SD-WAN技术&#xff1a;跨境网络连接的新…

如何在没有备份的情况下从 iPad 恢复照片?

有很多操作都可能导致iPad照片丢失&#xff0c;包括误删除、出厂设置、iPad的iOS更新等。如果没有备份&#xff0c;似乎没有办法找回它们。然而&#xff0c;即使您将备份保留在 iCloud 或iTunes上&#xff0c;这些方式也需要您的 iPad 首先重置&#xff0c;从而用备份内容覆盖当…

堆排序解读

在算法世界中&#xff0c;排序算法一直是一个热门话题。推排序&#xff08;Heap Sort&#xff09;作为一种基于堆这种数据结构的有效排序方法&#xff0c;因其时间复杂度稳定且空间复杂度低而备受青睐。本文将深入探讨推排序的原理、实现方式&#xff0c;以及它在实际应用中的价…

lua学习笔记5(分支结构和循环的学习)

print("*****************分支结构和循环的学习******************") print("*****************if else语句******************") --if 条件 then end a660 b670 --单分支 if a<b thenprint(a) end --双分支 if a>b thenprint("满足条件")…

机器学习模型——逻辑回归

https://blog.csdn.net/qq_41682922/article/details/85013008 https://blog.csdn.net/guoziqing506/article/details/81328402 https://www.cnblogs.com/cymx66688/p/11363163.html 参数详解 逻辑回归的引出&#xff1a; 数据线性可分可以使用线性分类器&#xff0c;如果…

c# wpf LiveCharts 简单试验

1.概要 1.1 说明 1.2 环境准备 NuGet 添加插件安装 2.代码 <Window x:Class"WpfApp3.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"…

WindowsPowerShell安装配置Vim的折腾记录

说明 vim一直以来都被称为编辑器之神一样的存在。但用不用vim完全取决于你自己&#xff0c;但是作为一个学计算机的同学来说&#xff0c;免不了会和Linux打交道&#xff0c;而大部分的Linux操作系统都预装了vim作为编辑器&#xff0c;如果是简单的任务&#xff0c;其实vim只要会…

电商技术揭秘八:搜索引擎中的SEO内部链接建设与外部推广策略

文章目录 引言一、 内部链接结构优化1.1 清晰的导航链接1. 简洁明了的菜单项2. 逻辑性的布局3. 避免深层次的目录结构4. 使用文本链接5. 突出当前位置6. 移动设备兼容性 1.2 面包屑导航1. 显示当前页面位置2. 可点击的链接3. 简洁性4. 适当的分隔符5. 响应式设计6. 避免重复主页…

图像分割-RSPrompter

文章目录 前言1. 自动化提示器1.1 多尺度特征增强器1.2 RSPrompterAnchor-based PrompterQuery-based Prompter 2. SAM的扩展3. 结果WHU数据集NWPU数据集SSDD数据集 前言 《RSPrompter: Learning to prompt for remote sensing instance segmentation based on visual foundati…