Onnx使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上

news2024/11/15 23:55:37

目录

一、整体功能概述

二、函数分析

2.1 resnet() 函数:

2.2 pre_process(img_path) 函数:

2.3 loadOnnx(img_path) 函数:

三、代码执行流程


一、整体功能概述


这段代码实现了一个图像分类系统,使用预训练的 ResNet18 模型对输入图像进行分类,并将分类结果显示在图像上。它包括以下主要步骤:
读取一个包含类别名称和对应编号的文本文件,并将其存储在字典中。
定义了几个函数,包括模型导出函数 resnet()、图像预处理函数 pre_process() 和加载 ONNX 模型进行分类的函数 loadOnnx()。
在主程序中,指定输入图像路径,调用 loadOnnx() 函数对图像进行分类并显示结果。


二、函数分析


2.1 resnet() 函数:


使用 torchvision 中的预训练 ResNet18 模型,并设置为评估模式。
生成一个随机输入张量 x,并将模型导出为 ONNX 格式,保存为 models/resnet18.onnx 文件。

def resnet():
    model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.eval()
    x=torch.randn(1,3,224,224)
    torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])


2.2 pre_process(img_path) 函数:


读取输入图像 img_path。
调整图像大小为 224x224。
将图像颜色通道从 BGR 转换为 RGB。
对图像像素值进行归一化处理。
交换图像维度顺序,并增加一个维度。
返回预处理后的图像张量。

def pre_process(img_path):
    #h w c--->224,224,3
    #归一化
    #换轴
    #增加维度
    img=cv2.imread(img_path)
    scale_image=cv2.resize(img,dsize=(224,224))
    rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)
    rgb_img=rgb_img/255
    rgb_img=np.transpose(rgb_img,(2,0,1))
    rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)
    return rgb_img


2.3 loadOnnx(img_path) 函数:


创建一个 ONNX 推理会话,加载预导出的 ResNet18 ONNX 模型。

调用 pre_process() 函数对输入图像进行预处理。
准备输入数据并进行推理。
获取推理结果中概率最大的类别编号。
根据类别编号从字典中获取对应的类别名称,并进行翻译。
在输入图像上显示分类结果,并展示图像。

def loadOnnx(img_path):
    session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])
    img=pre_process(img_path)
    img_back=cv2.imread(img_path)
    intput_feed={'input':img}
    session_out=session.run(None,intput_feed)[0]
    out=np.argmax(session_out,axis=1)[0]
    res=str(out)
    # print(dict[res])
    ans=dict[res].split(',')[1].split(']')[0].strip()
    ans = translator.translate(ans)
    cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)
    cv2.imshow('win',img_back)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    print(ans)

完整代码如下

import cv2
import numpy as np
import torch
from torchvision import models
from torchvision.models import ResNet18_Weights
import onnxruntime as ort
from translate import Translator
translator=Translator(to_lang='Chinese')#翻译成中文
dict={}
with open('类别.txt','r',encoding='utf-8') as f:
    lines=f.readlines()
    for line in lines:
        name=line.split('\t')[0]
        value=line.split('\t')[1]
        dict[name]=value
# print(dict)
def resnet():
    model=models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.eval()
    x=torch.randn(1,3,224,224)
    torch.onnx.export(model,x,'models/resnet18.onnx',input_names=['input'],output_names=['output'])
def pre_process(img_path):
    #h w c--->224,224,3
    #归一化
    #换轴
    #增加维度
    img=cv2.imread(img_path)
    scale_image=cv2.resize(img,dsize=(224,224))
    rgb_img=cv2.cvtColor(scale_image,cv2.COLOR_BGR2RGB)
    rgb_img=rgb_img/255
    rgb_img=np.transpose(rgb_img,(2,0,1))
    rgb_img=np.expand_dims(rgb_img,0).astype(np.float32)
    return rgb_img
    #RGB
def loadOnnx(img_path):
    session=ort.InferenceSession(r'models\resnet18.onnx',providers=['CPUExecutionProvider'])
    img=pre_process(img_path)
    img_back=cv2.imread(img_path)
    intput_feed={'input':img}
    session_out=session.run(None,intput_feed)[0]
    out=np.argmax(session_out,axis=1)[0]
    res=str(out)
    # print(dict[res])
    ans=dict[res].split(',')[1].split(']')[0].strip()
    ans = translator.translate(ans)
    cv2.putText(img_back,ans,(100,100),fontFace=1,fontScale=2.0,color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)
    cv2.imshow('win',img_back)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    print(ans)
    pass
if __name__ == '__main__':
    img_path='dog.png'
    # resnet()#导出模型
    loadOnnx(img_path)


三、代码执行流程


在 if __name__ == '__main__': 部分:
定义输入图像路径 img_path。
可以选择调用 resnet() 函数导出模型(注释状态,通常只在第一次运行或模型更新时使用)。
调用 loadOnnx(img_path) 函数对输入图像进行分类和显示结果。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2078768.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

error C2375: “WSAAsyncGetHostByName”: 重定义;不同的链接

error C2375: “WSAAsyncGetHostByName”: 重定义;不同的链接 win11 vs2015 背景:当项目中使用到了开源库,开源库使用WinSock2.h,同时windows项目又有包含Windows.h, 编译时常常会出现一堆编译错误,方法重定义等等。 问题原因: 默认windows.h头文件会包含winsock.h //…

【Linux】快速入门(第一篇)

1. Linux简介 1.操作系统概念 Linux 也是众多操作系统之一,要想知道 Linux 是什么,首先得说一说什么是操作系统。 计算机是一台机器,它按照用户的要求接收信息、存储数据、处理数据,然后再将处理结果输出(文字、图片…

.ipynb文件:交互式 Jupyter Notebook

Python 接口文件(带有扩展名的文件.pyi),或称为 Python 存根文件,在使用类型提示增强 Python 代码方面发挥着至关重要的作用。 当你遇到名称以 .ipynb、.pyi、.pyc 等结尾的 Python 文件时,你是否会感到困惑&#xff…

Adobe ME软件安装win/mac下载与使用教程

目录 一、Adobe ME软件介绍 1.1 软件概述 1.2 主要功能 1.3 软件优势 二、系统要求 2.1 Windows系统要求 2.2 macOS系统要求 三、安装步骤 3.1 Windows系统安装 3.2 macOS系统安装 四、使用教程 4.1 基本界面介绍 4.2 视频编码与转码 4.3 音频和字幕处理 4.4 高…

快来领取迅雷加速器7天会员,让你的《黑神话·悟空》更新速度嗖嗖嗖!⚡️

嘿,各位《黑神话悟空》的小伙伴们!😆 最近大家肯定都在Steam上体验这款国产3A大作吧?游戏的画质、玩法是不是让你眼前一亮?😍 但是!😫 大家有没有发现,游戏加载和更新时…

谷歌的有害链接是什么?

有害链接,顾名思义,是指那些可能对你网站的Google排名产生负面影响的链接,但,真的存在会对网站造成坏影响的链接吗? 所谓的有害链接,更多是现在很多seo工具所定义出来的,事实上,自从…

豆瓣评分9.4!最适合Python入门后进阶的Python食谱!

Python是一个高层次的结合了解释性、编译性、互动性和面向对象的脚本语言。Python的设计具有很强的可读性,相比其他语言经常使用英文关键字,其他语言的一些标点符号,它具有比其他语言更有特色语法结构。 今天给小伙伴们分享的这份手册&#x…

Python + Playwright(23):处理 iframe (内嵌框架)「详细介绍」

Python Playwright(23):处理 iframe 内嵌框架「详细介绍」 简介1. 理解 iframe 的特性2. 处理 iframe 的方法2.1 使用 page.frames 遍历所有的 iframe2.2 通过 page.frames 访问特定 iframe2.2 通过 page.frames 的索引访问 iframe2.3 通过 …

企业如何实现多个分公司组网方案

在现代商业环境中,企业往往需要连接多个分公司以实现高效的资源共享和协同工作。以下是一个全面的多个分公司组网策略供参考。 一、确定网络架构和布局 总部作为核心数据中心:总部应配备高性能的网络设备和完善的安全防护措施,承担数据存储和…

医疗器械管理软件 符合新规 免费升级

盘谷医疗器械管理软件具有对采购、收货、验收、贮存、销售、出库、复核、退货等各经营环节进行实时质量控制的功能;具有权限管理功能,确保各类数据的录入、修改、保存等操作应当符合授权范围、管理制度和操作规程的要求,保证数据真实、准确、…

深入理解ARM64的函数调用标准与栈布局

一、引言 随着计算机技术的飞速发展,人们对计算机的性能要求越来越高,为了突破32位架构的4GB地址空间限制,并实现更好的性能提升。ARM公司推出了一种64位处理器架构,也就是我们今天所要讨论的ARM64。ARM64(也称ARMv8)面世以来,在…

leetcode 3146 两个字符串的排列差

leetcode 3146 两个字符串的排列差 正文题目描述解题思路方法1 Python 处理字符串的思路方法2 正文 题目描述 解题思路 直接 for 循环遍历第一个字符串,在第二个字符串中找出第一个字符串中的对应字符的位置,做差,再取绝对值,最…

Clickhouse集群化(三)集群化部署

1. 准备 clickhouse支持副本和分片的能力,但是自身无法实现需要借助zookeeper或者clickhouse-keeper来实现不同节点之间数据同步,同时clickhouse的数据是最终一致性 。 2. Zookeeper 副本的写入流程 没有主从概念 平等地位 互为副本 2.1. 部署zookeep…

高效能低延迟:EasyCVR平台WebRTC支持H.265在远程监控中的优势

TSINGSEE青犀视频EasyCVR视频汇聚平台在WebRTC方面确实支持H.265编码,尽管标准的WebRTC API在大多数浏览器中默认并不支持H.265(也称为HEVC,高效视频编码)编码。EasyCVR平台通过一系列创新的技术手段,实现了在WebRTC协…

区块链应用,密码学会议书籍推荐以及隐私保护知识整理

基于区块链技术的安全多方计算项目示例 1. iCube——全球首个安全多方计算区块链金融项目 iCube团队通过与美国普渡大学区块链人工智能实验室深度合作,实现了区块链的安全多方计算。iCube建立了面向信息的终极抽象基础层和基于个人工智能的算法模型层,…

互联网盲盒小程序,提高企业市场竞争力

盲盒作为一种休闲娱乐的方式,受到了大众的追捧,各大消费者争相购买,市场一时火热非凡! 随着互联网电商的出现,盲盒也开始在线上发展,当消费者距离盲盒门店较远或者没有时间下,就可以在小程序上…

Linux:Socket网络编程

目录 1. 理解源 IP 地址和目的 IP 地址 2:认识端口号 3:端口号范围划分 4:理解源端口号和目的端口号 5:理解Socket(套接字) 6:两个传输协议 (TCP/UDP) 6.1:User Datagram Prot…

重磅!尤文图斯携手Fortinet打造足球界的网络安全堡垒

近日,尤文图斯足球俱乐部与推动网络与安全融合的全球网络安全领导者 Fortinet(NASDAQ:FTNT)正式宣布建立合作伙伴关系,并签署了一项为期至2026年的赞助协议。在此框架下,Fortinet荣膺尤文图斯未来两个赛季的…

深度学习/机器学习软件教学平台

1、基本介绍 机器学习与深度学习教学系统是基于业界应用广泛的在线机器学习和深度学习建模开发框架JupyterLab开发的,面向高校数据分析、机器学习、深度学习,以及人工智能相关专业教学和实训的教学系统。 2、系统特色 系统首页 系统主界面 在线实验界面…

LLM推理端实现

LLM推理端是什么 Large Language Model,大语言模型。典型代表ChatGPT。 推理端:模型训练出来后,用于模型应用和部署的interface。 推理端实现了本地环境中部署大语言模型。可以实现LLM的基本功能,包括生成文本、自动摘要、语言…