pytorch软件封装

news2025/4/16 1:01:59

封装代码,通过传入文件名,即可输出类别信息

上一章节,我们做了关于动物图像的分类,接下来我们把程序封装,然后进行预测。

单张图片的predict文件

predict.py

'''
    按着路径,导入单张图片做预测
'''
from torchvision.models import resnet18
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import cv2 as cv
import os
import numpy as np

'''
    加载图片与格式转化
'''

# 图片标准化
transform_BZ = transforms.Normalize(
    mean=[0.5062653, 0.46558657, 0.37899864],  # 取决于数据集
    std=[0.22566116, 0.20558165, 0.21950442]
)

img_size = 224
val_tf = transforms.Compose([  ##简单把图片压缩了变成Tensor模式
    transforms.ToPILImage(),  # 将numpy数组转换为PIL图像
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transform_BZ  # 标准化操作
])


def cv_imread(file_path):
    cv_img = cv.imdecode(np.fromfile(file_path, dtype=np.uint8), cv.IMREAD_COLOR)
    return cv_img


def predict(img_path):
    '''
        获取标签名字
    '''
    # # 增加类别标签
    # dir_names = []
    # for root, dirs, files in os.walk("dataset"):
    #     if dirs:
    #         dir_names = dirs
    # 将输出保存到exel中,方便后续分析
    label_names = ['cat', 'chicken', 'cow', 'dog', 'duck',
                   'goldfish', 'lion', 'pig', 'sheep',
                   'snake']

    # 指定设备
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Using {device} device")

    """
        加载模型
    """
    model = resnet18(weights=None)
    num_ftrs = model.fc.in_features  # 获取全连接层的输入
    model.fc = nn.Linear(num_ftrs, 10)  # 全连接层改为不同的输出
    torch_data = torch.load('./logs_resnet18_adam/best.pth',
                            map_location=torch.device(device))
    model.load_state_dict(torch_data)
    model.to(device)

    '''
        读取图片
    '''
    img = cv_imread(img_path)
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    img_tensor = val_tf(img)

    # 增加batch_size维度
    img_tensor = Variable(torch.unsqueeze(img_tensor, dim=0).float(),
                          requires_grad=False).to(device)

    '''
        数据输入与模型输出转换
    '''
    model.eval()
    with torch.no_grad():
        output_tensor = model(img_tensor)

        # 将输出通过softmax变为概率值
        output = torch.softmax(output_tensor, dim=1)

        # 输出可能性最大的那位
        pred_value, pred_index = torch.max(output, 1)

        # 将数据从cuda转回cpu
        if torch.cuda.is_available() == False:
            pred_value = pred_value.detach().cpu().numpy()
            pred_index = pred_index.detach().cpu().numpy()

        result = "预测类别为: " + str(label_names[pred_index[0]]) + " 可能性为: " + str(pred_value[0].item() * 100)[:5] + "%"
        return result


if __name__ == "__main__":
    img_path = r'dataset/cat/10.jpg'
    result = predict(img_path)
    print(result)

这里可以看出,我们用的cat数据集中的图片,预测出来的结果却是是cat,虽然可能性不是很高。

torch_data=torch.load('./logs_resnet18_adam/best.pth',map_location=torch.device(device))

使用 PyTorch 加载一个保存的模型权重文件(best.pth),并将其映射到指定的设备(device

img_tensor=Variable(torch.unsqueeze(img_tensor,dim=0).float(),requires_grad=False).to(device)

将一个图像张量(img_tensor)进行处理,使其成为适合输入到神经网络模型中的格式,并将其移动到指定的设备(CPU 或 GPU)上

1. torch.unsqueeze(img_tensor, dim=0)
  • 作用:在张量的第 0 维(即最外层)添加一个维度。

  • 背景:神经网络模型通常期望输入数据是一个四维张量,形状为 [batch_size, channels, height, width]。如果 img_tensor 是一个三维张量(例如 [channels, height, width]),则需要在第 0 维添加一个维度,使其形状变为 [1, channels, height, width],其中 1 表示批量大小(batch_size)为 1。

2. .float()
  • 作用:将张量的数据类型转换为 float32

  • 背景:许多神经网络模型在训练和推理时使用 float32 数据类型。如果 img_tensor 的数据类型不是 float32,则需要显式转换。

3. Variable(..., requires_grad=False)
  • 作用:将张量封装为 Variable 对象,并设置 requires_grad 属性。

  • 背景

    • Variable 是 PyTorch 中的一个旧类,用于封装张量并支持自动求导。在较新的 PyTorch 版本中,Variable 已经与 Tensor 合并,因此这一步在现代代码中通常是多余的。

    • requires_grad=False 表示这个张量不需要计算梯度。这在推理阶段非常常见,因为输入数据不需要参与梯度计算。

转成ONNX,兼容各种设备

ONNX是什么?

ONNX(Open Neural Network Exchange)是一种开放的神经网络交换格式,它为深度学习模型提供了一种标准化的表示方式,使得模型可以在不同的深度学习框架之间进行转换和共享。

ONNX的作用是什么?

  • 模型转换:开发者可以将训练好的模型从一个框架(如PyTorch)转换为ONNX格式,然后在另一个框架(如TensorFlow)中加载和使用。这使得开发者可以在不同的框架之间灵活切换,利用不同框架的优势。

  • 模型部署:ONNX模型可以被导出到多种推理引擎,如ONNX Runtime。ONNX Runtime是一个高性能的推理引擎,支持多种硬件平台(如CPU、GPU、FPGA等),可以用于将模型部署到生产环境中。

  • 模型优化:通过ONNX,开发者可以对模型进行优化和量化等操作。例如,可以将模型从浮点数量化为整数,以提高模型的推理速度和降低存储需求。

import torch
from torch import nn
from torchvision.models import resnet18
# pip install onnx
# pip install onnxruntime

if __name__ == '__main__':

    # 指定设备
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")

    # 指定模型
    model = resnet18(pretrained=False)

    num_ftrs = model.fc.in_features    # 获取全连接层的输入
    model.fc = nn.Linear(num_ftrs, 10)  # 全连接层改为不同的输出

    # 模型加载权重
    torch_data = torch.load('logs_resnet18_pretrain/best.pth',
                            map_location=torch.device(device))

    model.load_state_dict(torch_data)
    model.to(device)

    # 创建一个示例输入
    dummy_input = torch.randn(1,3,224,224, device=device)
    # 指定输出文件路径
    onnx_file_path = "logs_resnet18_pretrain/model.onnx"

    # 导出onnx
    torch.onnx.export(model, dummy_input, onnx_file_path,
                      verbose=True,  # 屏幕中打印日志信息
                      input_names=['input'],
                      output_names=['output'])

    print("Model Exported Success")

Netron模型可视化

NETRON查看网络结构

如何下载可以看这篇文章网络可视化工具netron安装流程-CSDN博客

下载过后打开文件

ONNX单张图片预测

# -*- coding: utf-8 -*-
'''
    按着路径,导入单张图片做预测
'''
import onnxruntime as ort  # pip install onnxruntime onnx
import numpy as np
import torchvision.transforms as transforms
import cv2 as cv
import os


def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=1, keepdims=True)


def cv_imread(file_path):
    cv_img = cv.imdecode(np.fromfile(file_path, dtype=np.uint8), cv.IMREAD_COLOR)
    return cv_img


def predict(img_path):
    '''
        获取标签名字
    '''
    # dir_names = []
    # for root, dirs, files in os.walk("dataset"):
    #     if dirs:
    #         dir_names = dirs
    # label_names = dir_names

    label_names = ['cat', 'chicken', 'cow', 'dog', 'duck',
                   'goldfish', 'lion', 'pig', 'sheep',
                   'snake']

    '''
        加载图片与格式转化
    '''

    # 图片标准化
    transform_BZ = transforms.Normalize(
        mean=[0.5062653, 0.46558657, 0.37899864],  # 取决于数据集
        std=[0.22566116, 0.20558165, 0.21950442]
    )

    img_size = 224
    val_tf = transforms.Compose([  # 简单把图片压缩了变成Tensor模式
        transforms.ToPILImage(),  # 将numpy数组转换为PIL图像
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transform_BZ  # 标准化操作
    ])

    # 读取图片
    img = cv_imread(img_path)
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    img_tensor = val_tf(img)

    # 将图片转换为ONNX运行时所需的格式
    img_numpy = img_tensor.numpy()
    img_numpy = np.expand_dims(img_numpy, axis=0)  # 增加batch_size维度

    # 加载ONNX模型
    onnx_model_path = r'logs_resnet18_pretrain/model.onnx'  # 替换为ONNX模型的路径
    ort_session = ort.InferenceSession(onnx_model_path)

    # 运行ONNX模型
    outputs = ort_session.run(None, {'input': img_numpy})
    output = outputs[0]

    # 应用softmax
    probabilities = softmax(output)

    # 获得预测结果
    pred_index = np.argmax(probabilities, axis=1)
    pred_value = probabilities[0][pred_index[0]]

    result = "预测类别为: " + str(label_names[pred_index[0]]) + " 可能性为: " + str(pred_value * 100)[:5] + "%"
    return result


if __name__ == "__main__":
    img_path = r'dataset/cat/10.jpg'
    result = predict(img_path)
    print(result)

这个没什么好讲的,就是可以直接封装成了一个onnx,可以不用安装pytorch库

PyQt5做预测模型

接下来先请大家准备一些库,看一看下面这篇文章PyCharm配置外部工具PyQtDesigner、PyUIC、Pyrcc_pycharm外部工具-CSDN博客

我把所有的文件封装了一下,大家要记得改一改路径

main_one_thread,py

# -*- coding: utf-8 -*-
from mainwindow import Ui_MainWindow
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog
import sys
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtWidgets import *
from predict封装 import predict


class UiMain(QMainWindow, Ui_MainWindow):
    def __init__(self, parent=None):
        super(UiMain, self).__init__(parent)
        self.setupUi(self)
        self.fileBtn.clicked.connect(self.loadImage)


    # 打开文件功能
    def loadImage(self):
        self.fname, _ = QFileDialog.getOpenFileName(self, '请选择图片','.','图像文件(*.jpg *.jpeg *.png)')
        if self.fname:
            print(self.fname)
            self.Infolabel.setText("文件打开成功\n"+self.fname)
            jpg = QtGui.QPixmap(self.fname).scaled(self.Imglabel.width(),
                                                   self.Imglabel.height())

            self.Imglabel.setPixmap(jpg)

            result = predict(self.fname)
            self.Infolabel.setText(result)

        else:
            # print("打开文件失败")
            self.Infolabel.setText("打开文件失败")


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ui = UiMain()
    ui.show()
    sys.exit(app.exec_())

运行结果

但是这个文件如果打包,别人不一定能用

main_one_thread_onnx.py

from mainwindow import Ui_MainWindow
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog
import sys
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtWidgets import *
from predict_onnx import predict


class UiMain(QMainWindow, Ui_MainWindow):
    def __init__(self, parent=None):
        super(UiMain, self).__init__(parent)
        self.setupUi(self)
        self.fileBtn.clicked.connect(self.loadImage)


    # 打开文件功能
    def loadImage(self):
        self.fname, _ = QFileDialog.getOpenFileName(self, '请选择图片','.','图像文件(*.jpg *.jpeg *.png)')
        if self.fname:
            print(self.fname)
            self.Infolabel.setText("文件打开成功\n"+self.fname)
            jpg = QtGui.QPixmap(self.fname).scaled(self.Imglabel.width(),
                                                   self.Imglabel.height())

            self.Imglabel.setPixmap(jpg)

            result = predict(self.fname)
            self.Infolabel.setText(result)

        else:
            # print("打开文件失败")
            self.Infolabel.setText("打开文件失败")


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ui = UiMain()
    ui.show()
    sys.exit(app.exec_())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2335616.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【多线程-第四天-自己模拟SDWebImage的下载图片功能-看SDWebImage的Demo Objective-C语言】

一、我们打开之前我们写的异步下载网络图片的项目,把刚刚我们写好的分类拖进来 1.我们这个分类包含哪些文件: 1)HMDownloaderOperation类, 2)HMDownloaderOperationManager类, 3)NSString+Sandbox分类, 4)UIImageView+WebCache分类, 这四个文件吧,把它们拖过来…

电脑提示“找不到mfc140u.dll“的完整解决方案:从原因分析到彻底修复

当你启动某个软件或游戏时,突然遭遇"无法启动程序,因为计算机中丢失mfc140u.dll"的错误提示,这确实令人沮丧。mfc140u.dll是Microsoft Foundation Classes(MFC)库的重要组成部分,属于Visual C Re…

图像变换方式区别对比(Opencv)

1. 变换示例 import cv2 import matplotlib.pyplot as plotimg cv2.imread(url) img_cut img[100:200, 200:300] img_rsize cv2.resize(img, (50, 50)) (hight,width) img.shape[:2] rotate_matrix cv2.getRotationMatrix2D((hight//2, width//2), 50, 1) img_wa cv2.wa…

图像颜色空间对比(Opencv)

1. 颜色转换 import cv2 import matplotlib.pyplot as plotimg cv2.imread("tmp.jpg") img_r cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img_g cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img_h cv2.cvtColor(img, cv2.COLOR_BGR2HSV) img_l cv2.cvtColor(img, cv2.C…

每天学一个 Linux 命令(15):man

可访问网站查看,视觉品味拉满:http://www.616vip.cn/15/index.html 每天学一个 Linux 命令(15):man 命令简介 man(Manual)是 Linux 中最核心的命令之一,用于查看命令、系统调用、库函数等的手册文档。它是用户和开发者获取帮助的核心工具,几乎覆盖了系统中的所有功…

必刷算法100题之计算右侧小于当前元素的个数

题目链接 315. 计算右侧小于当前元素的个数 - 力扣(LeetCode) 题目解析 计算数组里面所有元素右侧比它小的数的个数, 并且组成一个数组,进行返回 算法原理 归并解法(分治) 当前元素的后面, 有多少个比我小(降序) 我们要找到第一比左边小的元素, 这样…

Python依赖注入完全指南:高效解耦、技术深析与实践落地

Python依赖注入完全指南:高效解耦、技术深析与实践落地 摘要 依赖注入(DI)不仅是一种设计技术,更是一种解耦的艺术。它通过削减模块间的强耦合性,为系统提供了更高的灵活性和可测试性,特别是在 FastAPI 等…

深度学习ResNet模型提取影响特征

大家好,我是带我去滑雪! 影像组学作为近年来医学影像分析领域的重要研究方向,致力于通过从医学图像中高通量提取大量定量特征,以辅助疾病诊断、分型、预后评估及治疗反应预测。这些影像特征涵盖了形状、纹理、灰度统计及波形变换等…

【Qt】Qt Creator开发基础:项目创建、界面解析与核心概念入门

🍑个人主页:Jupiter. 🚀 所属专栏:QT 欢迎大家点赞收藏评论😊 目录 Qt Creator 新建项⽬认识 Qt Creator 界⾯项⽬⽂件解析Qt 编程注意事项认识对象模型(对象树)Qt 窗⼝坐标体系 Qt Creator 新…

制造业项目管理如何做才能更高效?制造企业如何选择适配的数字化项目管理系统工具?

一、制造企业项目管理过程中面临的痛点有哪些? 制造企业在项目管理过程中面临的痛点通常涉及跨部门协作、资源调配、数据整合、风险控制等多个维度,且与行业特性(如离散制造vs流程制造)紧密相关。 进度失控多项目资源冲突信息孤…

Python批量处理PDF图片详解(插入、压缩、提取、替换、分页、旋转、删除)

目录 一、概述 二、 使用工具 三、Python 在 PDF 中插入图片 3.1 插入图片到现有PDF 3.2 插入图片到新建PDF 3.3 批量插入多张图片到PDF 四、Python 提取 PDF 图片及其元数据 五、Python 替换 PDF 图片 5.1 使用图片替换图片 5.2 使用文字替换图片 六、Python 实现 …

七种驱动器综合对比——《器件手册--驱动器》

九、驱动器 名称 功能与作用 工作原理 优势 应用 隔离式栅极驱动器 隔离式栅极驱动器用于控制功率晶体管(如MOSFET、IGBT、SiC或GaN等)的开关,其核心功能是将控制信号从低压侧传输到高压侧的功率器件栅极,同时在输入和输出之…

redis系列--1.redis是什么

国际惯例,想了解一个东西,首先就要看看官方提供了什么。redis的官网是https://redis.io 。以下这段话就是redis的简介了: Redis is an open source (BSD licensed), in-memory data structure store, used as a database, cache, and message…

CSS 过渡与变形:让交互更丝滑

在网页设计中,动效能让用户交互更自然、流畅,提升使用体验。本文将通过 CSS 的 transition(过渡)和 transform(变形)属性,带你入门基础动效设计,结合案例演示如何实现颜色渐变、元素…

MecAgent Copilot:机械设计师的AI助手,开启“氛围建模”新时代

MecAgent Copilot作为机械设计师的AI助手,正通过多项核心技术推动机械设计进入“氛围建模”新时代。以下从功能特性、技术支撑和应用场景三方面解析其创新价值: 一、核心功能特性 ​​智能草图生成与参数化建模​​ 支持自然语言输入生成设计草图和3D模型,如输入“剖面透视…

【prometheus+Grafana篇】Prometheus与Grafana:深入了解监控架构与数据可视化分析平台

💫《博主主页》:奈斯DB-CSDN博客 🔥《擅长领域》:擅长阿里云AnalyticDB for MySQL(分布式数据仓库)、Oracle、MySQL、Linux、prometheus监控;并对SQLserver、NoSQL(MongoDB)有了解 💖如果觉得文章对你有所帮…

【后端开发】初识Spring IoC与SpringDI、图书管理系统

文章目录 图书管理系统用户登录需求分析接口定义前端页面代码服务器代码 图书列表展示需求分析接口定义前端页面部分代码服务器代码Controller层service层Dao层modle层 Spring IoC定义传统程序开发解决方案IoC优势 Spring DIIoC &DI使用主要注解 Spring IoC详解bean的存储五…

git在IDEA中使用技巧

git在IDEA中使用技巧 merge和rebase 参考:IDEA小技巧-Git的使用 git回滚、强推、代码找回 参考:https://www.bilibili.com/video/BV1Wa411a7Ek?spm_id_from333.788.videopod.sections&vd_source2f73252e51731cad48853e9c70337d8e cherry pick …

榕壹云无人共享系统:基于SpringBoot+MySQL+UniApp的物联网共享解决方案

无人共享经济下的技术革新 随着无人值守经济模式的快速发展,传统共享设备面临管理成本高、效率低下等问题。榕壹云无人共享系统依托SpringBootMySQLUniApp技术栈,结合物联网与移动互联网技术,为商家提供低成本、高可用的无人化运营解决方案。…

ARCGIS PRO DSK 利用两期地表DEM数据计算工程土方量

利用两期地表DEM数据计算工程土方量需要准许以下数据: 当前地图有3个图层,两个栅格图层和一个矢量图层 两个栅格图层:beforeDem为工程施工前的地表DEM模型 afterDem为工程施工后的地表DEM模型 一个矢量图层&#xf…