Python加载 TorchScript 格式的 ResNet18 模型分类该模型进行预测并输出预测的类别和置信度

news2024/9/23 9:29:52
  • 首先加载预训练的 ResNet18 模型。
  • 将模型设置为评估模式,以确保特定层(如 Dropout 和 BatchNorm)在评估时具有确定性的行为。
  • 创建一个形状为 (1, 3, 224, 224) 的随机张量作为示例输入。
  • 使用 torch.jit.trace 函数追踪模型在给定示例输入上的行为,将模型转换为 TorchScript 格式。
  • 保存 TorchScript 格式的模型为 resnet18_torchscript.pt 文件,并打印转换成功的消息。
import torch
import torchvision.models as models

# 加载预训练的 ResNet18 模型
model = models.resnet18(pretrained=True)

# 将模型设置为评估模式
model.eval()

# 创建示例输入张量
example_input = torch.rand(1, 3, 224, 224)

# 使用 torch.jit.trace 追踪模型
traced_model = torch.jit.trace(model, example_input)

# 保存 TorchScript 模型
traced_model.save('resnet18_torchscript.pt')

print("ResNet18 模型已成功转换为 TorchScript 格式并保存。")

定义图像处理函数 process_img

    • process_img 函数接受一个图像路径作为参数。

    • 使用 cv2.imread 读取图像,将图像从 BGR 颜色空间转换为 RGB 颜色空间(因为很多深度学习模型期望输入为 RGB 格式)。

    • 将图像的像素值归一化到 [0, 1] 范围。

    • 使用 cv2.resize 将图像调整为 (224, 224) 的尺寸,这通常是 ResNet18 模型期望的输入尺寸。

    • 使用 np.transpose 将图像的维度顺序从 HWC(Height-Width-Channel)转换为 CWH(Channel-Height-Width),以符合 PyTorch 的输入要求。

    • 使用 np.expand_dims 在批量维度上扩展图像,使其形状变为 (1, C, H, W)

    • 最后将处理后的图像转换为 PyTorch 张量,并指定数据类型为 torch.float32,然后返回该张量。

def process_img(img_path):
    img=cv2.imread(img_path)

    img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    img = img / 255
    img=cv2.resize(img,dsize=(224,224))
    img=np.transpose(img,(2,0,1))#HWC-->CWH
    img=np.expand_dims(img,axis=0)
    img=torch.tensor(img,dtype=torch.float32)
    return img

图像预测部分

  1. 定义一个图像路径 img_path,并将其传入 process_img 函数,得到处理后的图像张量 img
    • 使用 torch.jit.load 加载之前保存的 TorchScript 格式的模型。
    • 将处理后的图像张量传入模型进行前向传播,得到输出张量 output
    • 使用 torch.argmax 在输出张量的维度 1 上找到具有最大值的索引,即预测的类别。
    • 最后打印出预测的类别和对应类别的置信度(输出张量中对应类别的值)。
img_path='dog.jpg'
img=process_img(img_path)
model=torch.jit.load('resnet18_torchscript.pt')#还是torchscript格式的
output=model.forward(img)
cls=torch.argmax(output,axis=1)
print('预测的类别是:',cls.item(),'置信度是',output[0][cls].item())

预测图片 

结果如下:

可以去Imgnet官网找对应的网站来查看类别 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2101359.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

渗透测试靶机----DC系列 DC-1

渗透测试靶机----DC系列 DC-1 开启靶机,依旧是登陆窗,平平无奇 扫描ip,扫描端口,服务等信息 可以看到这里存在80服务,访问看看 非常明显,这里存在一个Drupal 的cms 并且是一个登录框,思路打开 …

VMware Fusion 13.6 发布下载,新增功能概览

VMware Fusion 13.6 发布下载,新增功能概览 VMware Fusion 13.6 for Mac - 领先的免费桌面虚拟化软件 适用于基于 Intel 处理器和搭载 Apple 芯片的 Mac 的桌面虚拟化软件 请访问原文链接:https://sysin.org/blog/vmware-fusion-13/,查看最…

位运算专题——常见位运算位图的使用力扣实战应用

目录 1、常见位运算 2、算法应用【leetcode】 2.1 判断字符是否唯一【面试题 】 2.1.1 算法思想【位图】 2.1.2 算法代码 2.2 只出现一次的数字 III 2.2.1 算法思想 2.2.2 算法代码 2.3 丢失的数字 2.3.1 算法思想 2.3.2 算法代码 2.4 两整数之和 2.4.1 算法思想 2…

C语言之猜数字小游戏

哈喽,大家好!我是冰淇淋加点糖。今天我们来用前面所学的知识来开发一个猜数字的小游戏,锻炼我们的编程能力和编程思维。 猜数字小游戏功能简介 1.随机生成一个1-100的数字。 2.玩家用户开始猜数字。 > 猜大了,提醒猜大了…

【爬虫软件】采集抖音博主的主页发布作品

这是我用python开发的抖音爬虫采集软件,可自动按博主抓取已发布视频。 软件界面截图: 爬取结果截图: 几点重要说明: 软件使用演示视频: https://www.bilibili.com/video/BV1Kb42187qf 完整讲解文章: ht…

虚幻5|技能栏UI优化(3)——优化技能UI并实现显示背景UI,实现技能界面设计,实现技能栏的删除和添加

实现技能栏添加:将技能界面里的技能拖到技能栏格子 一.调整,在拖出技能的时候,还会有边框 1.打开拖拽的技能格子UI 除了技能按钮,下面的子级都放到垂直框的子级,然后删除技能按钮 2.将垂直框替换成包裹框 你会发现有…

OS_线程

2024.07.01:操作系统线程学习笔记 第7节 线程 7.1 线程的基本概念7.1.1 单线程地址空间 vs 多线程地址空间 7.2 线程的状态与转换(照搬进程设计)7.2.1 线程的3种基本状态 7.3 线程管理中的数据结构7.3.1 线程的用户栈7.3.2 线程的内核栈 7.4 …

用户时长进入“零和”时代,App们如何借Push胜出?

作者 | 曾响铃 文 | 响铃说 激烈的市场竞争、快速变化的用户需求、层出不穷的新赛道新玩法……对如今移动互联网的App开发者而言,寻求用户增长的压力变得越来越大。 而与此同时,Push作为促进用户增长的关键手段之一,其执行过程中的诸多问题…

坑——fastjson将字符串转到带枚举的java对象

fastjson将不同的字符串转换成带枚举的java对象时&#xff0c;不同的字符串值转换成java对象的结果不同&#xff1b; 测试用fastjson版本&#xff1a; <dependency> <groupId>com.alibaba</groupId> <artifactId>fastjson</artifactId> <ver…

[有彩蛋]大模型独角兽阶跃星辰文生图模型Step-1X上线,效果具说很炸裂?快来看一手实测!

先简单介绍一下阶跃星辰吧 公司的创始人兼CEO是姜大昕博士&#xff0c;他在微软担任过全球副总裁&#xff0c;同时也是微软亚洲互联网工程研究院的副院长和首席科学家。 2024年3月&#xff0c;阶跃星辰发布了Step-2万亿参数MoE语言大模型预览版&#xff0c;这是国内初创公司首…

公司监控上网记录怎么实现?监控公司局域网内电脑的上网记录,这4个方法不妨一试!

在繁忙的办公室里&#xff0c;每位员工的手指在键盘上跳跃&#xff0c;屏幕闪烁间&#xff0c;他们究竟在忙碌什么&#xff1f;是沉浸在工作的海洋中&#xff0c;还是悄然滑向了网络的另一端&#xff1f; 为了解答这一疑问&#xff0c;确保工作效率与信息安全&#xff0c;公司纷…

webpack--处理资源

在webpack.config.js中进行配置 const path require(path) module.exports {// 入口entry: ./src/main.js,// 输出output: {// 文件的输出路径path: path.resolve(__dirname, dist),// 入口文件打包输出的文件名filename: js/main.js,// 自动清空上次打包结果 原理&#xff…

Spring中基于redis stream 的消息队列实现方法

本文主要介绍了消息队列的概念性质和应用场景&#xff0c;介绍了kafka、rabbitMq常用消息队列中间件的应用模型及消息队列的实现方式&#xff0c;并实战了在Spring中基于redis stream 的消息队列实现方法。 一、消息队列 消息队列是一种进程间通信或者同一个进程中不同线程间的…

Netlify 为静态站点部署 Waline 评论系统

目录 1 准备工作2 简介2.1 Netlify2.2 Waline2.3 Leancloud 3 开始搭建3.1 Fork 仓库3.2 设置 Leancloud3.3 部署 Netlify3.4 查看评论系统 从我建成个人网站以来&#xff0c;一个月了&#xff0c;竟然还没配置过评论系统&#xff0c;一直用的别人的 awa。 那么今天就稍微研究…

B站up主全程教学趋动云部署大模型:Meta新开源【Llama3.1-70B-Instruct】!

Llama 3.1 的指令调优版本&#xff08;8B、70B、405B&#xff09;针对多语言对话用例进行了优化&#xff0c;在比8种支持语言更广泛的语言集合上进行了训练&#xff0c;预训练模型可以适应多种自然语言生成任务。 Llama 3.1 模型集合还支持利用其模型的输出来改进其他模型&…

基于yolov8的红绿灯目标检测训练与Streamlit部署(代码+教程)

项目背景 随着智能交通系统的快速发展&#xff0c;自动驾驶技术逐渐成为研究的热点。在自动驾驶领域中&#xff0c;准确识别道路上的交通信号灯是确保车辆安全行驶的关键技术之一。近年来&#xff0c;深度学习技术的发展为交通信号灯的识别提供了强大的支持。YOLO&#xff08;…

云微客短视频矩阵系统,如何让企业赢在起跑线?

在这个信息爆炸的时代&#xff0c;传统的营销方式已经无法满足现代企业的快速发展的需求了。那么如何让企业的品牌和产品脱颖而出呢&#xff1f;云微客短视频矩阵系统&#xff0c;就是这样一个创新的解决方法。 但是很多企业认为&#xff0c;在这个短视频盛行的时代&#xff0c…

cr2怎么转换成jpg?分享这五款好用软件!

在数字摄影时代&#xff0c;CR2作为佳能相机常用的RAW格式&#xff0c;虽然能够保留更多的图像细节和色彩信息&#xff0c;但在日常分享和编辑中&#xff0c;JPG格式因其兼容性和便捷性而更受欢迎。今天&#xff0c;我们就来分享五款好用的软件&#xff0c;帮助你轻松将CR2格式…

数据中心代理IP的使用指南:提升网络体验的秘密武器

在互联网的广阔海洋中&#xff0c;数据中心代理IP是一种常见且实用的工具。无论是个人用户还是企业&#xff0c;使用数据中心代理IP都能带来诸多好处。本文将详细介绍数据中心代理IP的概念、优势以及使用技巧&#xff0c;让你在网络世界中游刃有余。 什么是数据中心代理IP&…

用自定义类级注解校验两字段不能同时为空

背景&#xff1a; 有下面这么一个类&#xff0c;要校验两字段query、image不能同时为空&#xff0c;应该怎么实现&#xff1f;已知的NotBlank等标签都只能检验单个字段。 import jakarta.validation.constraints.NotBlank; import lombok.Data; import org.springframework.h…