【PyTorch】图像多分类项目部署

news2025/1/15 23:22:37

【PyTorch】图像多分类项目

【PyTorch】图像多分类项目部署

如果需要在独立于训练脚本的新脚本中部署模型,这种情况模型和权重在内存中不存在,因此需要构造一个模型类的对象,然后将存储的权重加载到模型中。

加载模型参数,验证模型的性能,并在测试数据集上部署模型

from torch import nn
from torchvision import models

# 定义一个resnet18模型,不使用预训练参数
model_resnet18 = models.resnet18(pretrained=False)
# 获取模型的全连接层的输入特征数
num_ftrs = model_resnet18.fc.in_features
# 定义分类的类别数
num_classes=10
# 将全连接层的输出特征数改为分类的类别数
model_resnet18.fc = nn.Linear(num_ftrs, num_classes)

import torch 
path2weights="./models/resnet18_pretrained.pt"
# 加载预训练的ResNet18模型权重
model_resnet18.load_state_dict(torch.load(path2weights))
# 将ResNet-18模型设置为评估模式
model_resnet18.eval();
# 检查CUDA是否可用
if torch.cuda.is_available():
    # 如果可用,将设备设置为CUDA
    device = torch.device("cuda")
    # 将模型移动到CUDA设备上
    model_resnet18=model_resnet18.to(device)

def deploy_model(model,dataset,device, num_classes=10,sanity_check=False):

    # 获取数据集的长度
    len_data=len(dataset)
    
    # 初始化输出张量
    y_out=torch.zeros(len_data,num_classes)
    
    # 初始化真实标签张量
    y_gt=np.zeros((len_data),dtype="uint8")
    
    # 将模型移动到指定设备
    model=model.to(device)
    
    # 初始化时间列表
    elapsed_times=[]
    with torch.no_grad():
        for i in range(len_data):
            # 获取数据集中的一个样本
            x,y=dataset[i]
            # 将真实标签存入张量
            y_gt[i]=y
            # 记录开始时间
            start=time.time()    
            # 将输入数据传入模型进行预测
            yy=model(x.unsqueeze(0).to(device))
            # 将预测结果存入张量
            y_out[i]=torch.softmax(yy,dim=1)
            # 计算预测时间
            elapsed=time.time()-start
            # 将预测时间存入列表
            elapsed_times.append(elapsed)

            # 如果进行完整性检查,则跳出循环
            if sanity_check is True:
                break

    # 计算平均预测时间
    inference_time=np.mean(elapsed_times)*1000
    # 打印平均预测时间
    print("average inference time per image on %s: %.2f ms " %(device,inference_time))
    # 返回预测结果和真实标签
    return y_out.numpy(),y_gt
from torchvision import datasets
import torchvision.transforms as transforms

# 数据转换
data_transformer = transforms.Compose([transforms.ToTensor()])

path2data="./data"

# 加载数据
test0_ds=datasets.STL10(path2data, split='test', download=True,transform=data_transformer)
print(test0_ds.data.shape)

from sklearn.model_selection import StratifiedShuffleSplit

# 创建StratifiedShuffleSplit对象,设置分割次数为1,测试集大小为0.2,随机种子为0
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)

# 获取test0_ds的索引
indices=list(range(len(test0_ds)))

# 获取test0_ds的标签
y_test0=[y for _,y in test0_ds]

# 对索引和标签进行分割
for test_index, val_index in sss.split(indices, y_test0):
    # 打印测试集和验证集的索引
    print("test:", test_index, "val:", val_index)
    # 打印测试集和验证集的大小
    print(len(val_index),len(test_index))

from torch.utils.data import Subset

# 从test0_ds中选取val_index索引的子集,赋值给val_ds
val_ds=Subset(test0_ds,val_index)
# 从test0_ds中选取test_index索引的子集,赋值给test_ds
test_ds=Subset(test0_ds,test_index)
# 定义均值
mean=[0.4467106, 0.43980986, 0.40664646]
# 定义标准差
std=[0.22414584,0.22148906,0.22389975]
# 定义一个名为test0_transformer的变量,用于将一系列的图像变换操作组合在一起
test0_transformer = transforms.Compose([
    # 将图像转换为Tensor类型
    transforms.ToTensor(),
    # 对图像进行归一化操作,使用mean和std作为均值和标准差
    transforms.Normalize(mean, std),
    ])   
# 将test0_transformer赋值给test0_ds的transform属性
test0_ds.transform=test0_transformer
import time
import numpy as np

# 调用deploy_model函数,传入model_resnet18,val_ds,device和sanity_check参数,返回y_out和y_gt
y_out,y_gt=deploy_model(model_resnet18,val_ds,device=device,sanity_check=False)
# 打印y_out和y_gt的形状
print(y_out.shape,y_gt.shape)

from sklearn.metrics import accuracy_score

# 将y_out中的最大值索引赋值给y_pred
y_pred = np.argmax(y_out,axis=1)
# 打印y_pred和y_gt的形状
print(y_pred.shape,y_gt.shape)

# 计算并打印y_pred和y_gt的准确率
acc=accuracy_score(y_pred,y_gt)
print("accuracy: %.2f" %acc)

 

# 部署模型,得到预测结果和真实标签
y_out,y_gt=deploy_model(model_resnet18,test_ds,device=device)

# 取出预测结果中概率最大的类别
y_pred = np.argmax(y_out,axis=1)

# 计算准确率
acc=accuracy_score(y_pred,y_gt)

# 打印准确率
print(acc)

from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
np.random.seed(1)

# 定义一个函数,用于显示图像
def imshow(inp, title=None):
    # 定义图像的均值和标准差
    mean=[0.4467106, 0.43980986, 0.40664646]
    std=[0.22414584,0.22148906,0.22389975]
    # 将图像从tensor转换为numpy数组,并转置
    inp = inp.numpy().transpose((1, 2, 0))
    # 将均值和标准差转换为numpy数组
    mean = np.array(mean)
    std = np.array(std)
    # 将图像的像素值进行归一化
    inp = std * inp + mean
    # 将像素值限制在0和1之间
    inp = np.clip(inp, 0, 1)
    # 显示图像
    plt.imshow(inp)
    # 如果有标题,则显示标题
    if title is not None:
        plt.title(title)
    # 暂停0.001秒
    plt.pause(0.001) 

# 定义网格大小
grid_size=16
# 随机生成4个索引
rnd_inds=np.random.randint(1,len(test_ds),grid_size)
# 打印随机生成的索引
print("image indices:",rnd_inds)

# 根据索引获取对应的图像和标签
x_grid_test=[test_ds[i][0] for i in rnd_inds]
y_grid_test=[(y_pred[i],y_gt[i]) for i in rnd_inds]

# 将图像转换为网格
x_grid_test=utils.make_grid(x_grid_test, nrow=4, padding=2)
# 打印网格的形状
print(x_grid_test.shape)

# 设置图像的大小
plt.rcParams['figure.figsize'] = (10, 10)
# 显示网格
imshow(x_grid_test,y_grid_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1946222.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

人工智能与机器学习原理精解【6】

文章目录 数值优化基础理论凹凸性定义在国外与国内存在不同国内定义国外定义总结示例与说明注意事项 国内凹凸性二阶定义的例子凹函数例子凸函数例子 凸函数(convex function)的开口方向凸函数的二阶导数凸函数的二阶定义单变量函数的二阶定义多变量函数…

设计模式第三天|设计模式结构型:适配器模式、装饰器模式、代理模式

文章目录 设计模式的分类适配器模式概念俗话说角色具体应用(Spring MVC)图解具体步骤 装饰器模式定义核心俗话说类名表现图解具体构造代码实现简化优点缺点 代理模式(Spring AOP 面向切面)定义俗话说角色代理模式分类静态代理角色代码好处缺点实用 动态代理 AOP什么是AOP具体信…

BFF:优化前后端协作设计模式

BFF:优化前后端协作设计模式 BFF是什么 BFF即 Backends For Frontends (服务于前端的后端)。是一种介于前端和后端之间一种重要的通信设计模式。它旨在解决前端与后端协作中的复杂性问题。 背景 行业背景:传统前端应用(如Web应用、移动应…

《深入探秘Java中的枚举:掌握Enum的魔力》

目录 📝 枚举枚举的定义枚举的使用1、表示一组固定常量2、实现接口3、枚举与策略模式4、EnumSet5、EnumMap 📎 参考文章 😀 准备好了吗?让我们一起步入这座Java神奇的城堡,探寻枚举(Enum)这个强…

Ubuntu 修改源地址

注意事项:版本说明!!! Ubuntu24.04的源地址配置文件发生改变。 不再使用以前的 sources.list 文件,该文件内容变成了一行注释: # Ubuntu sources have moved to /etc/apt/sources.list.d/ubuntu.sources…

STM32-FreeRTOS快速学习

定义 FreeRTOS 满足实施系统对任务响应时间的要求。 实时操作系统、轻量级(内核小,只需要几KB的ROM和RAM)、 提供了一些内核功能,如任务管理、时间管理、内存管理和通信机制等。 和裸机的区别 裸机:无操作系统&…

产品系统的UI暗色系和浅色系模式切换是符合人体视觉工程学的设计

视觉革命:UI设计中的暗夜与黎明 UI设计如同夜空中最亮的星辰,引领着用户穿梭于信息的海洋。而今,一场视觉革命正在悄然上演,它关乎于我们的眼睛,关乎于我们的体验——那就是产品系统的UI暗色系和浅色系模式的切换。如…

【机器学习】Jupyter Notebook如何使用之基本步骤和进阶操作

引言 Jupyter Notebook 是一个交互式计算环境,它允许创建包含代码、文本和可视化内容的文档 文章目录 引言一、基本步骤1.1 启动 Jupyter Notebook1.2 使用 Jupyter Notebook 仪表板1.3 在笔记本中工作1.4 常用快捷键1.5 导出和分享笔记本 二、进阶用法2.1 组织笔…

Excel超级处理器,工作簿文件.xls/.xlsx/.csv相互批量转换

如何将.xlsx文件转成.csv文件,.xls转换成.xlsx文件,以及.xls文件转成.csv文件或.csv转换成.xlsx文件,如果是单个文件转换,那么将当前文件另存为,保存类型,选择即可。如下图所示: 如果是多个文件…

【AutoDL】AutoDL+Xftp+Xshell+VSCode配合使用教程

身边没有显卡资源或不足以训练模型时,可以租赁服务器的显卡。 1、AutoDL Step :注册账号->选择显卡->选择环境->开机启动 1.1 首先打开AutoDL官网,注册账号 1.2 租赁自己想要的显卡资源 1.3 选择基础环境。 此处,我们让其自动配置…

[网络通信原理]——TCP/IP模型—网络层

网络层 网络层概述 网络层位于OSI模型的第三层,它定义网络设备的逻辑地址,也就是我们说的IP地址,能够在不同的网段之间选择最佳数据转发路径。在网络层中有许多协议,其中主要的协议是IP协议。 IP数据包格式 IP数据报是可变长度…

Linux服务器配置Python+PyTorch+CUDA深度学习环境

参考博主Linux服务器配置PythonPyTorchCUDA深度学习环境_linux cuda环境配置-CSDN博客 https://blog.csdn.net/NSJim/article/details/115386936?ops_request_misc&request_id&biz_id102&utm_termlinux%E8%99%9A%E6%8B%9F%E7%8E%AF%E5%A2%83%E6%8C%89pytorch%20c…

微信答题小程序产品研发-需求分析与原型设计

欲知应候何时节,六月初迎大暑风。 我前面说过,我决意仿一款答题小程序,所以我做了大量的调研。 题库软件产品开发不仅仅是写代码这一环,它包含从需求调研、分析与构思、设计到开发、测试再到部署上线一系列复杂过程。 需求分析…

子数组和为k子数组和最大

题目1:子数组和为k /*给你一个整数数组 nums 和一个整数 k ,请你统计并返回 该数组中和为 k 的子数组的个数 。子数组是数组中元素的连续非空序列。示例 1:输入:nums [1,1,1], k 2 输出:2 示例 2:输入&a…

微软蓝屏事件对企业数字化转型有什么影响?

引言:从北京时间2024年7月19日(周五)下午2点多开始,全球大量Windows用户出现电脑崩溃、蓝屏死机、无法重启等情况。事发后,网络安全公司CrowdStrike称,收到大量关于Windows电脑出现蓝屏报告,公司…

make2exe:自动集成测试

模板Makefile,生成多个C/C模块的集成测试程序。

算法学习day19

一、通过删除字母匹配到字符字典中的最大值 给你一个字符串 s 和一个字符串数组 dictionary ,找出并返回 dictionary 中最长的字符串,该字符串可以通过删除 s 中的某些字符得到。 如果答案不止一个,返回长度最长且字母序最小的字符串。如果…

花几千上万学习Java,真没必要!(二十六)

1、成员内部类: package internalclass.com; //在Java中,成员内部类(也称为非静态内部类)是定义在另一个类(外部类)内部的类。 //成员内部类可以访问外部类的所有成员(包括私有成员&#xff09…

【计算机网络】网络层——IPv4地址(个人笔记)

学习日期:2024.7.24 内容摘要:IPv4地址,分类编址,子网,无分类编址 IPv4地址概述 在TCP/IP体系中,IP地址是一个最基本的概念,IPv4地址就是给因特网上的每一台主机的每一个接口分配一个在全世界…

ASP.NET Web Api 使用 EF 6,DateTime 字段如何取数据库服务器当前时间

前言 在做数据库设计时,为了方便进行数据追踪,通常会有几个字段是每个表都有的,比如创建时间、创建人、更新时间、更新人、备注等,在存储这些时间时,要么存储 WEB 服务器的时间,要么存储数据库服务器的时间…