【深度学习】pytorch快速得到mobilenet_v2 pth 和onnx

news2024/12/30 1:30:35

在linux执行这个程序:

import torch
import torch.onnx
from torchvision import transforms, models
from PIL import Image
import os

# Load MobileNetV2 model
model = models.mobilenet_v2(pretrained=True)
model.eval()

# Download an example image from the PyTorch website
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try:
    os.system(f"wget {url} -O {filename}")
except Exception as e:
    print(f"Error downloading image: {e}")

# Preprocess the input image
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_image = Image.open(filename)
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension

# Perform inference on CPU
with torch.no_grad():
    output = model(input_tensor)

# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
print(output[0])

# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)

# Download ImageNet labels using wget
os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")

# Read the categories
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

# Save the PyTorch model
torch.save(model.state_dict(), "mobilenet_v2.pth")

# Convert the PyTorch model to ONNX with specified input and output names
dummy_input = torch.randn(1, 3, 224, 224)
onnx_path = "mobilenet_v2.onnx"
input_names = ['input']
output_names = ['output']
torch.onnx.export(model, dummy_input, onnx_path, input_names=input_names, output_names=output_names)

print(f"PyTorch model saved to 'mobilenet_v2.pth'")
print(f"ONNX model saved to '{onnx_path}'")

# Load the ONNX model
import onnx
import onnxruntime

onnx_model = onnx.load(onnx_path)
onnx_session = onnxruntime.InferenceSession(onnx_path)

# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()

# Perform inference on ONNX with the correct input name
onnx_output = onnx_session.run(['output'], {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)

# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("\nTop categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):
    print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())


得到:

在这里插入图片描述
用本地pth推理:

import torch
from torchvision import transforms, models
from PIL import Image

# Load MobileNetV2 model
model = models.mobilenet_v2()
model.load_state_dict(torch.load("mobilenet_v2.pth", map_location=torch.device('cpu')))
model.eval()

# Preprocess the input image
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension

# Perform inference on CPU
with torch.no_grad():
    output = model(input_tensor)

# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
# print(output[0])

# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)

# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

用onnx推理:

import torch
import onnxruntime
from torchvision import transforms
from PIL import Image

# Load the ONNX model
onnx_path = "mobilenet_v2.onnx"
onnx_session = onnxruntime.InferenceSession(onnx_path)

# Preprocess the input image
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension

# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()

# Perform inference on ONNX
onnx_output = onnx_session.run(None, {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)

# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("Top categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):
    print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1221296.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

牛客——OR36 链表的回文结构(C语言,配图,快慢指针)

本题是没有对C的支持的,但因为Cpp支持C,所以这里就用C写了,可以面向更多用户 链表的回文结构_牛客题霸_牛客网 (nowcoder.com) 思路一:链表翻转 简单的想想整形我们怎么比较,就是将整形A 依次取尾,放到整形…

C语言之深入指针(三)(详细教程)

C语言之深入指针 在学习这篇博客之前建议先看看这篇博客C语言之深入指针(二) 里面详细介绍了指针的 传值调用和传址调用数组名的理解使用指针访问数组⼀维数组传参的本质 文章目录 C语言之深入指针1 二级指针1.1 二级指针的介绍1.2 二级指针的使用 2 指…

Redis持久化策略之RDB与AOF

文章目录 1.RDB1)基本介绍2)自动触发3)手动触发4)RDB文件5)优点缺点 2.AOF1)基本介绍2)使用方式3)工作流程4)重写机制5)AOF文件6)优点缺点 3.RDB AOF 我们都知道,redis 是一个基于内存的数据库。基于内存的好处是访问速度快,缺点是“不持久”——当数据…

响应数据web

get请求 package com.example.demo.controller.poio;import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController;import java.util.*;//字符串格式返回 RestController public class ResponseBody {Req…

使用FFmpeg合并多个ts视频文件转为mp4格式

前言 爬取完视频发现都是ts文件,而且都是几百KB的视频片段,.ts 全名叫:MPEG Transport Stream,它是一个万能的多媒体容器,可以装下音频、视频、字幕。有时我们需要将.ts文件转换为其他更加广泛被支持的格式&#xff0…

基于Vue+SpringBoot的海南旅游景点推荐系统 开源项目

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户端2.2 管理员端 三、系统展示四、核心代码4.1 随机景点推荐4.2 景点评价4.3 协同推荐算法4.4 网站登录4.5 查询景点美食 五、免责说明 一、摘要 1.1 项目介绍 基于VueSpringBootMySQL的海南旅游推荐系统&#xff…

pytorch单精度、半精度、混合精度、单卡、多卡(DP / DDP)、FSDP、DeepSpeed模型训练

pytorch单精度、半精度、混合精度、单卡、多卡(DP / DDP)、FSDP、DeepSpeed(环境没搞起来)模型训练代码,并对比不同方法的训练速度以及GPU内存的使用 代码:pytorch_model_train FairScale(你真…

前端调取摄像头并实现拍照功能

前言 最近接到的一个需求十分有意思,设计整体实现了前端仿 微信扫一扫 的功能。整理了一下思路,做一个分享。 tips: 如果想要实现完整扫一扫的功能,你需要掌握一些前置知识,这次我们先讲如何实现拍照并且保存的功能。 一. windo…

Diffusion Models CLIP

Introduction to Diffusion Models 生成模型 主要指的是无监督学习中的生成模型,在无监督学习中的主要任务是让机器学习给定的样本,然后生成一些新的东西出来。比如:给机器看一些图片,能够生成一些新的图片出来,给机器…

element el-date-picker报错Prop being mutated:“placement“快速解决方式

报错信息 Avoid mutating a prop directly since the value will be overwritten whenever the parent component re-renders. Instead, use a data or computed property based on the prop’s value. Prop being mutated: “placement” 报错版本 element-ui 2.15.6 和 2.15…

干扰项目成本估算精准度的5大因素

干扰项目成本估算精准度的因素有很多,这些因素可能导致成本估算的不准确性,增加成本偏差和额外的成本投入,从而对项目的进度和预算产生影响。因此,在进行项目成本估算时,需要充分考虑这些因素,并采取相应的…

NFS共享

目录 三种存储类型 作用: FTP文本传输协议 原理 FTP服务状态码 用户认证 常见FTP相关软件 vsftpd 软件介绍 用户和其共享目录 基础操作 安装服务端 客户端连接服务端 登录成功 匿名用户登录 1.服务端配置 2.客户端配置 3.服务端查看 匿名用户下载 删除…

Taro.navigateTo 使用URL传参数和目标页面参数获取

文章目录 1. Taro.navigateTo 简介2. 通过 URL 传递参数3. 目标页面参数获取4. 拓展与分析4.1 拓展4.2 URL参数的类型4.3 页面间通信 5. 总结 🎉欢迎来到Java学习路线专栏~Taro.navigateTo 使用URL传参数和目标页面参数获取 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&#x…

共聚焦显微镜的应用特点

共聚焦显微镜具有高分辨率和高灵敏度的特点,适用于多种不同样品的成像和分析,能够产生结果和图像清晰,易于分析。这些特性使共聚焦显微镜成为现代科学研究中不可或缺的重要工具,同时为人们解析微观世界提供了一种强大的手段。 作…

python递归求数字各个位数相加_和

python递归求数字的各项和&#xff0c;例如数字一千零二十四&#xff1a;“1024”&#xff0c;输出结果为“10247” 第一种方法&#xff1a; def sum(a): #求一个数字各项和&#xff0c;第一种递归方法if 0<a<9: #从前到最后一个&#xff0c;出循环…

无重复最长字符串(最长无重复子字符串),剑指offer,力扣

目录 原题&#xff1a; 力扣地址&#xff1a; 我们直接看题解吧&#xff1a; 解题方法&#xff1a; 难度分析&#xff1a; 难度算中下吧&#xff0c;这个总体不算很难&#xff0c;而且滑动窗口&#xff0c;以及哈希都比较常见 审题目事例提示&#xff1a; 解题思路&#xff08;…

windows下安装make工具

我使用的windows的gitbash命令行终端。未安装时&#xff0c;发现系统没有make命令。 make -h bash: make: command not found使用windows的包管理工具winget安装&#xff1a; winget.exe install gnuwin32.make2. 将安装的make的bin目录添加到环境变量&#xff1a; setx PATH …

HTTP Error 500.31 - Failed to load ASP.NET Core runtime

在winserver服务器上部署net6应用后&#xff0c;访问接口得到以下提示&#xff1a; 原因是因为没有安装net6的运行时和环境&#xff0c;我们可以在windows自带的 “事件查看器” 查看原因。 可以直接根据给出的地址去官网下载sdk环境&#xff0c;安装即可 下载对应的net版本…

前端反卷计划-组件库-03-组件样式

Hi, 大家好&#xff01;我是程序员库里。 今天开始分享如何从0搭建UI组件库。这也是前端反卷计划中的一项。 在接下来的日子&#xff0c;我会持续分享前端反卷计划中的每个知识点。 以下是前端反卷计划的内容&#xff1a; 目前这些内容持续更新到了我的 学习文档 中。感兴趣…

C++中静态成员变量和普通成员变量、私有成员变量和公有成员变量的区别

本文主要介绍和记录C中静态成员变量和普通成员变量、私有成员变量和公有成员变量的区别&#xff0c;并给出相关示例程序&#xff0c;最后结合相关工程应用中编译报错给出报错原因及介绍思路 一、静态成员变量和普通成员变量 C中&#xff0c;静态成员变量和普通成员变量有一些重…