pytorch进阶学习(八):使用训练好的神经网络模型进行图片预测

news2025/1/18 17:02:07

课程资源: 

【小学生都会的Pytorch】九、运用你的模型做预测(1)_哔哩哔哩_bilibili

 笔记:

pytorch进阶学习(四):使用不同分类模型进行数据训练(alexnet、resnet、vgg等)_好喜欢吃红柚子的博客-CSDN博客


 目录

一、原理介绍

1. 加载模型与参数

2. 读取图片

3. 图片预处理

4. 把图片转换为tensor

5. 增加batch_size的维度

6. 模型验证

6.1 模型的初步输出

 6.2 输出预测值概率最大的值和位置

 6.3 把tensor转为numpy

6.4 预测类别

二、代码

1. 对单张图片做预测

2. 对整个文件夹图片做预测


        模型在经过前面几节的训练之后,传入自己的数据进行预测,流程和训练时差不多。项目目录如下所示,pic为预测时取的照片。

一、原理介绍

1. 加载模型与参数

模型骨架使用resnet18进行训练,使用预训练好的权重文件“model_resnet18_100.pth”来进行参数的加载。

# 如果显卡可用,则用显卡进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

'''
    加载模型与参数
'''

# 加载模型
model = resnet18(pretrained=False, num_classes=5).to(device)  # 43.6%

# 加载模型参数
if device == "cpu":
    # 加载模型参数,权重文件传过来
    model.load_state_dict(torch.load("model_resnet18_100.pth", map_location=torch.device('cpu')))
else:
    model.load_state_dict(torch.load("model_resnet18_100.pth"))

2. 读取图片

我们要预测的是sunflower1这张图片。

img_path = './pic/sunflower1.jpg'

 

3. 图片预处理

Image.open打开图像,转换为RGB格式,padding_black进行图像的扩充

img = Image.open(img_path)#打开图片
img = img.convert('RGB')#转换为RGB 格式
# 扩充
img = padding_black(img)

4. 把图片转换为tensor

val_tf = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
                transform_BZ    # 标准化操作
            ])

# 图片转换为tensor
img_tensor = val_tf(img)

5. 增加batch_size的维度

如果直接把图片传入模型会发生以下报错:

 原因:

模型接收的是四维输入,但是我们图片的输入只有3维,要求的4维输入的第一维为batch_size,我们训练好的模型中batch_size=64,但是一张图片没有这个维度,所以需要给这张传入的图片再增加一个通道。

  • dim=0代表在第一个维度增加维度
# 增加batch_size维度
img_tensor = Variable(torch.unsqueeze(img_tensor, dim=0).float(), requires_grad=False).to(device)

6. 模型验证

6.1 模型的初步输出

 模型进行输出后可以看到如下结果,tensor中有5个数。

model.eval()
# 不进行梯度更新
with torch.no_grad():
    output_tensor = model(img_tensor)
    print(output_tensor)

 但是都不在0-1之间,不是我们需要的对每一个类的概率值,所以我们需要使用softmax进行归一化。使用softmax进行归一化。

# 将输出通过softmax变为概率值
    output = torch.softmax(output_tensor,dim=1)
    print(output)

可以看到进行softmax运算后,出现的结果使用的是科学计数法,5个数加起来为1. 

 6.2 输出预测值概率最大的值和位置

# 输出可能性最大的那位
    pred_value, pred_index = torch.max(output, 1)
    print(pred_value)
    print(pred_index)

输出可以看到输出概率为1,即100%,位置下标为3,即第四类,sunflower类。

 6.3 把tensor转为numpy

在上一步输出时的数据为tensor格式,所以我们需要把数字先转换为numpy,再进行后续标签下标到标签类的转换。

# 将数据从cuda转回cpu
pred_value = pred_value.detach().cpu().numpy()
pred_index = pred_index.detach().cpu().numpy()
    
print(pred_value)
print(pred_index)

打印结果可以看到已经成功转换到了numpy类,没有了tensor标志 

6.4 预测类别

写出类别的中文列表,一定要与test训练集标签中的顺序对应起来。

classes = ["daisy", "dandelion", "rose", "sunflower", "tulip"]

print("预测类别为: ",classes[pred_index[0]]," 可能性为: ",pred_value[0]*100,"%")

打印输出可以看到预测正确,准确率高 。

二、代码

1. 对单张图片做预测

'''
    功能:按着路径,导入单张图片做预测
    作者: Leo在这

'''
from torchvision.models import resnet18
import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.autograd import Variable

# 如果显卡可用,则用显卡进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

'''
    加载模型与参数
'''

# 加载模型
model = resnet18(weights=False, num_classes=5).to(device)  # 43.6%

# 加载模型参数
if device == "cpu":
    # 加载模型参数,权重文件传过来
    model.load_state_dict(torch.load("model_resnet18_100.pth", map_location=torch.device('cpu')))
else:
    model.load_state_dict(torch.load("model_resnet18_100.pth"))

'''
    加载图片与格式转化
'''
img_path = './pic/sunflower1.jpg'

'''
    图片进行预处理
'''
# 图片标准化
transform_BZ= transforms.Normalize(
    mean=[0.5, 0.5, 0.5],# 取决于数据集
    std=[0.5, 0.5, 0.5]
)

val_tf = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
                transform_BZ    # 标准化操作
            ])


def padding_black(img):  # 如果尺寸太小可以扩充
    w, h = img.size
    scale = 224. / max(w, h)
    img_fg = img.resize([int(x) for x in [w * scale, h * scale]])
    size_fg = img_fg.size
    size_bg = 224
    img_bg = Image.new("RGB", (size_bg, size_bg))
    img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,
                          (size_bg - size_fg[1]) // 2))
    img = img_bg
    return img

# 打开图片,转换为RGB

img = Image.open(img_path)#打开图片
img = img.convert('RGB')#转换为RGB 格式
# 扩充
img = padding_black(img)
# print(type(img))

# 图片转换为tensor
img_tensor = val_tf(img)
# print(type(img_tensor))

# 增加batch_size维度
img_tensor = Variable(torch.unsqueeze(img_tensor, dim=0).float(), requires_grad=False).to(device)


'''
    数据输入与模型输出转换
'''
model.eval()
# 不进行梯度更新
with torch.no_grad():
    output_tensor = model(img_tensor)
    print(output_tensor)
    #
    # 将输出通过softmax变为概率值
    output = torch.softmax(output_tensor,dim=1)
    print(output)

    # 输出可能性最大的那位
    pred_value, pred_index = torch.max(output, 1)
    print(pred_value)
    print(pred_index)

    # 将数据从cuda转回cpu
    pred_value = pred_value.detach().cpu().numpy()
    pred_index = pred_index.detach().cpu().numpy()
    print(pred_value)
    print(pred_index)
    # #
    # 增加类别标签
    classes = ["daisy", "dandelion", "rose", "sunflower", "tulip"]
    print("预测类别为: ",classes[pred_index[0]]," 可能性为: ",pred_value[0]*100,"%")

2. 对整个文件夹图片做预测

对根目录为pic的文件夹做图片预测,步骤和单张图片预测差不多,使用for循环遍历文件。

'''
    功能:导入文件夹做预测
    作者:Leo在这
'''

from torchvision.models import resnet18
import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.autograd import Variable

import os

# 如果显卡可用,则用显卡进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

'''
    加载模型与参数
'''

# 加载模型
model = resnet18(pretrained=False, num_classes=5).to(device)  # 43.6%

if device == "cpu":
    # 加载模型参数
    model.load_state_dict(torch.load("model_resnet18_100.pth", map_location=torch.device('cpu')))
else:
    model.load_state_dict(torch.load("model_resnet18_100.pth"))
'''
    加载图片与格式转化
'''

# 图片标准化
transform_BZ= transforms.Normalize(
    mean=[0.5, 0.5, 0.5],# 取决于数据集
    std=[0.5, 0.5, 0.5]
)

val_tf = transforms.Compose([##简单把图片压缩了变成Tensor模式
                transforms.Resize(224),
                transforms.ToTensor(),
                transform_BZ#标准化操作
            ])


def padding_black(img):  # 如果尺寸太小可以扩充
    w, h = img.size
    scale = 224. / max(w, h)
    img_fg = img.resize([int(x) for x in [w * scale, h * scale]])
    size_fg = img_fg.size
    size_bg = 224
    img_bg = Image.new("RGB", (size_bg, size_bg))
    img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,
                          (size_bg - size_fg[1]) // 2))
    img = img_bg
    return img


dir_loc = r"./pic"
model.eval()
with torch.no_grad():
    for a,b,c in os.walk(dir_loc):
        for filei in c:
            full_path = os.path.join(a,filei)
            # print(full_path)
            # img_path = './pic/sunflower3.jpg'

            img = Image.open(full_path)#打开图片
            img = img.convert('RGB')#转换为RGB 格式
            img = padding_black(img)
            # print(type(img))

            img_tensor = val_tf(img)
            # print(type(img_tensor))

            # 增加batch_size维度
            img_tensor = Variable(torch.unsqueeze(img_tensor, dim=0).float(), requires_grad=False).to(device)


            '''
                数据输入与模型输出转换
            '''

            output_tensor = model(img_tensor)
            # 将输出通过softmax变为概率值
            output = torch.softmax(output_tensor,dim=1)

            # 输出可能性最大的那位
            pred_value, pred_index = torch.max(output, 1)

            pred_value = pred_value.detach().cpu().numpy()
            pred_index = pred_index.detach().cpu().numpy()

            # 增加类别标签
            classes = ["daisy", "dandelion", "rose", "sunflower", "tulip"]

            print("预测类别为: ",classes[pred_index[0]]," 可能性为: ",pred_value[0]*100,"%")

结果如下所示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/422751.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

免费远程桌面连接工具合集

随着科技的进步和通信技术的发展,远程办公变得越来越普遍。这种办公模式有助于提高工作效率,对于员工来说很友好的是,上班变得更加灵活了。 今天就给大家推荐几款远程桌面连接工具,不仅可以电脑连接,手机也可以直接连…

读写分离导致读不到刚插入的数据

背景 前两天在做一个功能的时候,需要先插表,如果数据重复则从数据库中查询出这条数据,这段代码在测试环境并没有什么问题,但是到生产之后就会偶现的报一些错,就是读不到已插入的数据,导致后续业务出现问题…

超详细Django+vue+vscode前后端分离搭建

文章目录一、Django后端搭建1.1 创建项目和app1.2 注册app1.3 运行项目1.4 配置mysql数据库1.5 创建数据库类1.6 使用Django后台进行数据管理2、Django rest framework配置2.1 序列化2.2 添加视图2.3 添加路由2.4 在项目根目录下的urls中加入如下代码2.5 api测试2.6 筛选和搜索…

BGP协议解析(白话版)

之前一直没搞明白BGP有啥用,加了跟没加没啥区别,专门查资料写了这篇《BGP协议解析》。 下面使用eNSP模拟器演示! IBGP与EBGP的区别 BGP分为两种:IBGP与EBGP。 两个路由器的BGP号相同,建立邻居关系叫IBGP&#xff0…

树莓派连接串口时无法开机

树莓派连接串口时无法开机我的情况我的思考我的解决过程重点参考我的情况 因为项目需要,因此需要使用树莓派控制电机,而电机是一上电就会给树莓派发送数据,而这时树莓派还正处于开机时,结果就是开机失败。当将串口断开时就又可以…

PHP快速入门05-时间日期与时区,附30个常用案例

文章目录前言一、时间日期与时区1.1 时间与日期1.2 时区二、 30个日期时间函数的用法示例2.1 获取当前的时间戳2.2 将时间戳格式化为日期时间2.3 获取当前的日期2.4 获取当前的时间2.5 获取当前年份2.6 获取当前月份2.7 获取当前日期的第几天2.8 计算两个日期之间的天数差2.9 计…

央媒报道的长与短

传媒如春雨,润物细无声,大家好,我是51媒体 胡老师。 在最近的媒体服务中,遇到一个问题,与大家讨论下,很多媒体特别是央媒,在活动报道中不会完全按照新闻稿通稿的内容去报道,有的会根…

MQ选型,kafka、RocketMQ、RabbitMQ、ActiveMQ

MQ(Message Queue),是基础数据结构中“先进先出”的一种数据结构。指把要传输的数据(消息)放在队列中,用队列机制来实现消息传递——生产者产生消息并把消息放入队列,然后由消费者去处理。消费者…

java SimpleDateFormat和Calendar日期类

目录一、SimpleDateFormat使用二、Calendar使用一、SimpleDateFormat使用 使用Date直接输出日期时,是使用系统默认的格式输出,所以需要使用SimpleDateFormat来格式化日期。 那么SimpleDateFormat类怎么使用呢,我们需要先了解此类的格式化符号…

Codeforces Round 866 (Div. 2) 题解

目录 A. Yuras New Name(构造) 思路: 代码: B. JoJos Incredible Adventures(构造) 思路: 代码: C. Constructive Problem(思维) 思路: 代…

一、计算机的发展历史

一、计算机的发展历史 第一台现代计算机 ENIAC:世界上第一台现代通用电子数字计算机,诞生于1946年2月14日的美国宾夕法尼亚大学。研制电子计算机的想法产生于第二次世界大战进行期间。当时激战正酣,各国的武器装备还很差,占主要地…

Java垃圾收集原理

程序计数器、虚拟机栈、本地方法栈这三个区域随线程而灭,栈中栈帧的内存大小也是在确定的。这几个区域的内存分配和回收都具有确定性,因此不需要过多考虑如何回收。 Java堆和方法区这两个区域有着很显著的不确定性 一个接口的实现类需要的内存可能不一…

软考第七章 下一代互联网

下一代互联网 1.IPv6 IPv4的缺陷: 网络地址短缺路由速度慢,IPv4头部多达13个字段,路由器处理的信息量很大缺乏安全功能不支持新的业务模式 关于PIv6的研究成果都包含在1998年12月发表的RFC 2460文档中 1.1 IPv6分组格式 版本&#xff1a…

量子退火Python实战(3):投资组合优化(Portfolio) MathorCup2023特供PyQUBO教程

文章目录前言一、什么是投资组合优化?二、投资组合优化建模1. 目标函数:回报2.约束函数:风险3.最终优化目标函数三、基于PyQUBO实现1. 获取数据2. 数据处理3. 目标函数PyQUBO实现4. OpenJij实施优化总结前言 提示:包含pyQUBO用法…

硬件语言Verilog HDL牛客刷题day11 A里部分 和 Z兴部分

1.VL72 全加器 1.题目: ① 请用题目提供的半加器实现全加器电路① 半加器的参考代码如下,可在答案中添加并例化此代码。 2. 解题思路 (可以看代码) 2.1 先看 半加器 s 是加位 , C 是进位。 2.2 再看全加器 …

2023年新手如何选择云服务器配置来部署自己的网站?

现在做网站的人越来越少了,没有以前那种百万网站站长的势头。但是,不论个人站长还是企业,只要网上开展业务其实都会需要自己网站或小程序、APP等平台。如今,很少有人使用虚拟主机,但是独立服务器成本高,一般…

【2023】Kubernetes-网络原理

目录kubernetes网络模型kubernetes网络实现容器到容器之间通信Pod之间的通信Pod到Service之间的通信集群内部与外部组件之间的通信开源容器网络方案FlannelCalicokubernetes网络模型 Kubernetes网络模型设计的一个基础原则是:每个Pod都拥有一个独立的IP地址&#x…

异地远程访问本地SQL Server数据库【无公网IP内网穿透】

文章目录1.前言2.本地安装和设置SQL Server2.1 SQL Server下载2.2 SQL Server本地连接测试2.3 Cpolar内网穿透的下载和安装2.3 Cpolar内网穿透的注册3.本地网页发布3.1 Cpolar云端设置3.2 Cpolar本地设置4.公网访问测试5.结语转发自CSDN远程穿透的文章:无需公网IP&a…

哪吒探针 - Windows 和Linux端agent安装(详细注意版)

一、Windows端agent安装配置 环境准备 环境: Windows 服务器软件:哪吒探针点击下载、nssm 点击下载(探针agent和nssm都要下载准备好) 设置环境变量下载软件后,解压到任意位置,然后按 winR 打开运行窗口,输入 sysdm.cpl 打开系统属性–>高级…

基于GIS/SCADA的智慧燃气数字孪生Web3D可视化系统

在低碳经济快速发展的今天,天然气在我国能源结构的占比逐年提高,安全供气成为关乎民生福祉、经济发展和社会和谐的大事。 自我国开展燃气铺设以来,经过长期运营的家用燃气和工业燃气设备管道设施设备基础差、检维修难度大,且传统燃…