【猫狗分类】Pytorch VGG16 实现猫狗分类5-预测新图片

news2024/10/6 12:25:25

背景
 

好了,现在开尝试预测新的图片,并且让vgg16模型判断是狗还是猫吧。

声明:整个数据和代码来自于b站,链接:使用pytorch框架手把手教你利用VGG16网络编写猫狗分类程序_哔哩哔哩_bilibili

预测

1、导包

from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from net import vgg16

2、设置新照片的路径

test_pth=r'.\img.png'#设置可以检测的图像
test=Image.open(test_pth)

3、处理图片:图片变成tensor

transform=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])
image=transform(test)
  • transforms.Compose:这是一个类,可以将多个变换操作组合在一起。当你需要对数据执行一系列变换时,就会用到它。它接受一个变换函数列表作为参数。

4、设置设备

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")#CPU与GPU的选择

5、加载网络(vgg16net)

net =vgg16()#输入网络

6、加载模型(权重模型)

model=torch.load(r".\DogandCat5.pth",map_location=device)#已训练完成的结果权重输入
net.load_state_dict(model)#模型导入

网络是网络,模型是模型!模型是训练出来的权重模型!网络是认为设定的!

7、模式选择(是训练模式还是推理模式)

net.eval()#设置为推测模式
  • 在PyTorch中,net.eval()是一个非常重要的方法调用,它用于改变模型的状态,使其从训练模式切换到推理(推测)模式。理解这一点很重要,因为模型在两种模式下的行为有所不同:
  • 训练模式 (net.train()): 在这种模式下,模型中的所有层都会处于活跃状态,包括像Dropout和Batch Normalization这样的层,它们会在每次前向传播时根据训练数据进行更新,引入随机性和依赖于批次的统计信息。这对于学习模型参数是非常必要的。

  • 推理模式 (net.eval()): 调用net.eval()后,模型会进入推理模式。这时,Dropout层将不起作用(即总是通过),而Batch Normalization层会使用在训练过程中计算得到的移动平均和方差,而不是 mini-batch 中的统计信息。这意味着模型的输出对于相同的输入将变得确定性,这对于测试和预测非常重要,因为你希望对同一输入多次运行模型时得到相同的结果。

  • 总结来说,当你准备好使用训练好的模型对新数据进行预测,而不是继续修改模型参数时,就应该调用net.eval()来确保模型以正确、一致的方式进行推理

8、传图片到网络,调整输入维度为四维张量

image=torch.reshape(image,(1,3,224,224))#四维图形,RGB三个通道

在PyTorch中,使用torch.reshape或者更常用的torch.Tensor.view方法可以改变张量的形状。对于图像数据,特别是当您准备将图像输入到深度学习模型时,将其调整为适合模型输入维度的四维张量是很常见的操作。

9、开始预测

with torch.no_grad():
    out=net(image)
out=F.softmax(out,dim=1) #softmax转为概率学问题
out=out.data.cpu().numpy()
print(out)
a=int(out.argmax(1))#输出最大值位置
  • with torch.no_grad():: 这一行代码用来指示PyTorch在接下来的代码块中不记录任何梯度信息。这对于推理(预测)阶段是非常重要的,因为它可以减少内存使用并加速计算过程,因为不需要为反向传播做准备。

  • out=net(image): 在上下文管理器torch.no_grad()内,将处理过的图像image输入到神经网络模型net中进行前向传播,得到模型的原始输出out。这个输出通常是未经处理的概率分布,对于分类任务,它通常代表每个类别的得分

  • out=F.softmax(out, dim=1): 使用F.softmax函数对模型输出out进行处理,该函数会将每一行的数据转换为概率分布,确保所有元素之和为1。这里dim=1表示沿着类别维度(通常对应于神经网络输出的最后一维)进行softmax操作,使得每个样本的预测结果可以解释为各类别的概率。

例举:假设你有一个简单的分类任务,模型需要区分猫、狗、鸟三种动物,即共有3个类别。你使用一个神经网络模型进行预测,对于一个批次内单个样本的输出可能看起来像这样(在未经过softmax处理前):

out_before_softmax = torch.tensor([2.0, 1.0, 0.5], dtype=torch.float32)

这里的输出张量out_before_softmax表示模型对于这个样本属于三个类别的原始打分或logits。注意,这些数值没有直接的概率意义,它们可以是任意实数。

应用Softmax

为了将这些原始分数转化为概率分布,你将使用F.softmax函数,并且指定dim=1,因为在这个一维张量的情况下,类别维度自然就是最后一维。执行操作后:

import torch.nn.functional as F
out_after_softmax = F.softmax(out_before_softmax, dim=1)
print(out_after_softmax)

输出解释

执行上述代码后,你可能会看到类似以下的输出(具体数值可能因四舍五入略有不同):

tensor([0.5561, 0.2476, 0.1963])

现在,out_after_softmax中的每个元素代表样本属于对应类别的概率,且所有概率之和为1(或接近1,由于浮点运算的精度限制)。例如,这里模型认为该样本有大约55.61%的概率是猫,24.76%的概率是狗,以及19.63%的概率是鸟。

总结

通过指定dim=1,你告诉softmax函数沿张量的最后一维进行操作,这在多分类任务中至关重要,因为它确保了每个样本的预测能够被合理地解释为各类别的概率分布。

  • out=out.data.cpu().numpy(): 将张量out从GPU(如果有的话)复制到CPU上,并转换为numpy数组,以便于进一步的处理和显示。这样做是因为后续的操作可能涉及到非PyTorch的库,如matplotlib用于绘图。

  • a=int(out.argmax(1)): 找出概率最大的类别索引,即预测的类别。argmax(1)沿着第1维度(类别维度)找到最大值的索引。argmax函数是用来找出数组或张量中最大值所在的位置(索引)。

10、显示图像

plt.figure()
list=['Cat','Dog']
plt.suptitle("Classes:{}:{:.1%}".format(list[a],out[0,a]))
plt.imshow(test)
plt.show()

  • list=['Cat','Dog']: 定义了一个类别标签列表,这里简化为猫和狗两类。
  • plt.suptitle(...): 设置图表的主标题,显示预测的类别名称及最高概率的百分比。
  • plt.imshow(test): 显示原始测试图像。
  • plt.show(): 显示整个图表,包括图像和标题。

给一张柴犬的照片,预测下:

                        

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1829400.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DataWhale - 吃瓜教程学习笔记(一)

学习视频:第1章-绪论_哔哩哔哩_bilibili 西瓜书对应章节: 第一章 机器学习三观 What:什么是机器学习? 关键词:“学习算法” Why: 为什么要学机器学习? #### 1. 机器学习理论研究#### 2. 机器学习系统开…

.net8 blazor auto模式很爽(五)读取sqlite并显示(2)

在BlazorApp1增加文件夹data&#xff0c;里面增加类dbcont using SharedLibrary.Models; using System.Collections.Generic; using Microsoft.EntityFrameworkCore;namespace BlazorApp1.data {public class dbcont : DbContext{public dbcont(DbContextOptions<dbcont>…

Servlet基础(续集2)

HttpServletResponse web服务器接收到客户端的http的请求&#xff0c;针对这个请求&#xff0c;分别创建一个代表请求的HttpServletRequest对象&#xff0c;代表响应的一个HttpServletResponse 如果要获取客户端请求过来的参数&#xff1a;找HttpServletRequest如果要给客户端…

梦想编织者Luna:COZE从童话绘本到乐章的奇妙转化

前言 Coze是什么&#xff1f; Coze扣子是字节跳动发布的一款AI聊天机器人构建平台&#xff0c;能够快速创建、调试和优化AI聊天机器人的应用程序。只要你有想法&#xff0c;无需有编程经验&#xff0c;都可以用扣子快速、低门槛搭建专属于你的 Chatbot&#xff0c;并一键发布…

Web前端项目-交互式3D魔方【附源码】

交互式3D魔方 ​ 3D魔方游戏是一款基于网页技术的三维魔方游戏。它利用HTML、CSS和JavaScript前端技术来实现3D效果&#xff0c;并在网页上呈现出逼真的魔方操作体验。 运行效果&#xff1a; 一&#xff1a;index.html <!DOCTYPE html> <html><head><…

独辟蹊径:我是如何用Java自创一套工作流引擎的(上)

作者&#xff1a;后端小肥肠 创作不易&#xff0c;未经允许严谨转载。 目录 1. 前言 2. 我为什么要自创一套工作流引擎 3. 表结构设计及关系讲解 3.1. 流程类别business_approval_workflow 3.1.1. 表结构 3.1.2. 表关系说明 3.2. 流程定义business_approval_workflow_de…

Oracle--存储结构

总览 一、逻辑存储结构 二、物理存储结构 1.数据文件 2.控制文件 3.日志文件 4.服务器参数文件 5.密码文件 总览 一、逻辑存储结构 数据块是Oracle逻辑存储结构中的最小的逻辑单位&#xff0c;一个数据库块对应一个或者多个物理块&#xff0c;大小由参数DB_BLOCK_SIZE决…

【详细介绍下PostgreSQL】

&#x1f308;个人主页: 程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

matlab-1-函数图像的绘制

常识 如何建一个新文件 创建新文件&#xff0c;点击新建&#xff0c;我们就可以开始写代码了 为什么要在代码开头加入clear 假如我们有2个文件&#xff0c;第一个文件里面给x赋值100&#xff0c;第二个文件为输出x 依次运行&#xff1a; 结果输出100&#xff0c;这是因为它们…

WPF/C#:异常处理

什么是异常&#xff1f; 在C#中&#xff0c;异常是在程序执行过程中发生的特殊情况&#xff0c;例如尝试除以零、访问不存在的文件、网络连接中断等。这些情况会中断程序的正常流程。 当C#程序中发生这种特殊情况时&#xff0c;会创建一个异常对象并将其抛出。这个异常对象包…

Floyd-Warshall

应用场景 要求出每两点之间的最短路。或判断两点之间的连通性&#xff08;两点之间是否有路径&#xff09;。 板子 代码&#xff08;必背!!!&#xff09; for(int k 1; k < n; k)for(int i 1; i < n; i)for(int j 1; j < n; j)d[i][j] min(d[i][j], d[i][k] …

堆的基本概念

堆 堆是一个完全二叉树 完全二叉树的要求&#xff0c;除了最后一层&#xff0c;其他层的节点个数都是满的&#xff0c;最后一层的节点都靠左排列 堆中每一个节点的值都必须大于等于(或小于等于)其子树中每个节点的值 堆中每个节点的值都大于等于(或者小于等于)其左右子节点的值…

C#(C Sharp)学习笔记_封装【十八】

什么是封装&#xff1f; 封装是面向对象思维的三大特性之一。封装是将数据和对数据进行操作的函数绑定到一起的机制。它隐藏了对象的内部状态和实现细节&#xff0c;只对外提供必要的接口&#xff0c;从而确保对象内部状态的完整性和安全性。封装的主要目的是增强安全性和简化…

登录MySQL方式

登录MySQL方式 方式一&#xff1a;通过MySQL自带的客户端 MySQL 客户端输入命令即可 方式二&#xff1a;通过window自带的客户端 从命令端&#xff08;cmd&#xff09;进入 mysql -h localhost -P 3306 -u root -p Enter password:密码登录方式&#xff1a; mysql -h 主…

【LeetCode最详尽解答】11-盛最多水的容器 Container-With-Most-Water

欢迎收藏Star我的Machine Learning Blog:https://github.com/purepisces/Wenqing-Machine_Learning_Blog。如果收藏star, 有问题可以随时与我交流, 谢谢大家&#xff01; 链接&#xff1a; 11-盛最多水的容器 直觉 这个问题可以通过可视化图表来理解和解决。 通过图形化这个…

基于51单片机万年历设计—显示温度农历

基于51单片机万年历设计 &#xff08;仿真&#xff0b;程序&#xff0b;原理图&#xff0b;设计报告&#xff09; 功能介绍 具体功能&#xff1a; 本系统采用单片机DS1302时钟芯片LCD1602液晶18b20温度传感器按键蜂鸣器设计而成。 1.可以显示年月日、时分秒、星期、温度值。…

mySql的事务(操作一下)

目录 1. 简介2. 事务操作3. 四大特性4. 并发事务问题5. 脏读6. 不可重复读7. 幻读事务隔离级别参考链接 1. 简介 事务是一组操作的集合&#xff0c;它是一个不可分割的工作单位&#xff0c;事务会把所有的操作作为一个整体一起向系统提交或撤销操作请求&#xff0c;即这些操作…

机器学习(V)--无监督学习(二)主成分分析

当数据的维度很高时&#xff0c;很多机器学习问题变得相当困难&#xff0c;这种现象被称为维度灾难&#xff08;curse of dimensionality&#xff09;。 在很多实际的问题中&#xff0c;虽然训练数据是高维的&#xff0c;但是与学习任务相关也许仅仅是其中的一个低维子空间&am…

【Java】Object、Objects、包装类、StringBuilder、StringJoiner

目录 1.API2.Object类3.Objects类4.包装类4.1包装类概述4.2包装类的其他常见操作 5.StringBuilder 可变字符串5.1概述5.2StringBuilder案例 6.StringJoiner 1.API API&#xff1a;应用程序编程接口&#xff0c;全称application programing interface&#xff0c;即Java已经写好…

分享一个 .NET Core 使用选项方式读取配置内容的详细例子

前言 在 .NET Core 中&#xff0c;可以使用选项模式&#xff08;Options Pattern&#xff09;来读取和管理应用程序的配置内容。 选项模式通过创建一个 POCO&#xff08;Plain Old CLR Object&#xff09;来表示配置选项&#xff0c;并将其注册到依赖注入容器中&#xff0c;方…