ImageNet预训练图像分类模型预测单张图像

news2024/12/23 14:18:36

导入基础工具包

import os

import cv2

import pandas as pd
import numpy as np

import torch

import matplotlib.pyplot as plt
%matplotlib inline

计算设备确定

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

载入预训练模型

from torchvision import models
# 载入预训练图像分类模型

model = models.resnet18(pretrained=True) 

# model = models.resnet152(pretrained=True)
model = model.eval() #将模型设为eval
model = model.to(device)

图像预处理,比较固定的四个部分,其他分类任务也可以用。

四步:

  1. 缩放裁剪
  2. 中心获取
  3. 转为Tensor
  4. 归一化处理:更近似于正态分布,易于神经网络处理。mean、std这六个数也是通用的。
from torchvision import transforms

# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

载入图片

# img_path = 'test_img/banana1.jpg'
# img_path = 'test_img/husky1.jpeg'
img_path = 'test_img/basketball_shoe.jpeg'

# img_path = 'test_img/cat_dog.jpg'


# 用 pillow 载入
from PIL import Image
img_pil = Image.open(img_path)

执行图像分类预测:

input_img = test_transform(img_pil) # 预处理,将图片传入图片与处理的函数

 转换模型所需要的维度:

input_img = input_img.unsqueeze(0).to(device)
input_img.shape

运行后为:

torch.Size([1, 3, 224, 224]),即一张3通道224*224的图片

执行前向预测:
 

# 执行前向预测,得到所有类别的 logit 预测分数
pred_logits = model(input_img) 
pred_logits.shape

结果为:

torch.Size([1, 1000])

利用softmax对分数大小进行比较:

import torch.nn.functional as F
pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算
pred_softmax.shape

预测结果分析

对softmax结果画一个柱状图:

plt.figure(figsize=(8,4))

x = range(1000)
y = pred_softmax.cpu().detach().numpy()[0]

ax = plt.bar(x, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
plt.ylim([0, 1.0]) # y轴取值范围
# plt.bar_label(ax, fmt='%.2f', fontsize=15) # 置信度数值

plt.xlabel('Class', fontsize=20)
plt.ylabel('Confidence', fontsize=20)
plt.tick_params(labelsize=16) # 坐标文字大小
plt.title(img_path, fontsize=25)

plt.show()

取置信度最大的n个结果:

n = 10
top_n = torch.topk(pred_softmax, n)
top_n

out:

torch.return_types.topk(
values=tensor([[0.5988, 0.3556, 0.0064, 0.0047, 0.0041, 0.0041, 0.0037, 0.0025, 0.0022,
         0.0022]], device='cuda:0', grad_fn=<TopkBackward0>),
indices=tensor([[430, 514, 522, 630, 502, 770, 427, 768, 805,  35]], device='cuda:0'))

解析出类别:

# 解析出类别
pred_ids = top_n[1].cpu().detach().numpy().squeeze()
pred_ids

out:

array([430, 514, 522, 630, 502, 770, 427, 768, 805,  35])

如何知道430、514是哪一类?

df = pd.read_csv('imagenet_class_index.csv')

将分类结果写在原图上:

# 用 opencv 载入原图
img_bgr = cv2.imread(img_path)


for i in range(n):
    class_name = idx_to_labels[pred_ids[i]][1] # 获取类别名称
    confidence = confs[i] * 100 # 获取置信度
    text = '{:<15} {:>.4f}'.format(class_name, confidence)
    print(text)
    
    # !图片,添加的文字,左上角坐标,字体,字号,bgr颜色,线宽
    img_bgr = cv2.putText(img_bgr, text, (25, 50 + 40 * i), cv2.FONT_HERSHEY_SIMPLEX, 1.25, (0, 0, 255), 3)

# 保存图像
cv2.imwrite('output/img_pred.jpg', img_bgr)

# 载入预测结果图像
img_pred = Image.open('output/img_pred.jpg')
img_pred

 

预测结果用表格输出:

pred_df = pd.DataFrame() # 预测结果表格
for i in range(n):
    class_name = idx_to_labels[pred_ids[i]][1] # 获取类别名称
    label_idx = int(pred_ids[i]) # 获取类别号
    wordnet = idx_to_labels[pred_ids[i]][0] # 获取 WordNet
    confidence = confs[i] * 100 # 获取置信度
    pred_df = pred_df.append({'Class':class_name, 'Class_ID':label_idx, 'Confidence(%)':confidence, 'WordNet':wordnet}, ignore_index=True) # 预测结果表格添加一行
display(pred_df) # 展示预测结果表格

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1422328.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pytorch学习笔记(十二)

以下代码是以CIFAR10这个10分类的图片数据集训练过程的完整的代码。 训练部分 train.py主要包含以下几个部件&#xff1a; 准备训练、测试数据集用DateLoader加载两个数据集&#xff0c;要设置好batchsize创建网络模型&#xff08;具体模型在model.py中&#xff09;设置损失函…

深入理解G0和G1指令:C++中的实现与激光雕刻应用

系列文章 ⭐深入理解G0和G1指令&#xff1a;C中的实现与激光雕刻应用⭐基于二值化图像转GCode的单向扫描实现⭐基于二值化图像转GCode的双向扫描实现⭐基于二值化图像转GCode的斜向扫描实现基于二值化图像转GCode的螺旋扫描实现基于OpenCV灰度图像转GCode的单向扫描实现基于Op…

Apple Pencil如何连接iPad?这里提供详细步骤

如果你刚拿起一支Apple Pencil&#xff0c;想和iPad一起使用&#xff0c;你需要先连接设备。将Apple Pencil与iPad配对的方法因你拥有的铅笔而异。 一旦你将Apple Pencil连接到iPad&#xff0c;你就可以利用这些方便的功能。你可以记下手写笔记&#xff0c;使用Scribble功能&a…

H5 嵌套iframe并设置全屏

H5 嵌套iframe并设置全屏 上图上代码 <template><view class"mp-large-screen-box"><view class"mp-large-screen-count"><view class"mp-mini-btn-color mp-box-count" hover-class"mp-mini-btn-hover" clic…

QEMU - e1000全虚拟化前端与TAP/TUN后端流程简析

目录 1. Host -> Guest 2.Guest ->Host 3. 如何修改以支持TUN设备的后端&#xff1f; 4. 相关 QEMU 源码 5. 实验 1. Host -> Guest 2.Guest ->Host 3. 如何修改以支持TUN设备的后端&#xff1f; 1. 简单通过后端网卡名字来判断是TUN还是TAP。 2. 需要前端全…

gdp调试—Linux

目录 介绍 使用 介绍 代码分为debug模式和release模式 如果一份代码要被调试&#xff0c;这份代码必须是debug Linux下编译代码默认是是release模式 如果你想代码是debug模式 必须加上 - g 小提&#xff1a; vim默认&#xff1a;命令模式 gcc默认&#xff1a;releas…

华为数通方向HCIP-DataCom H12-831题库(简答题01-27)

第01题 第02题 第03题 第04题 第05题 IS-IS是链路状态路由协议,使用SPF算法进行路由计算。某园区同时部署了IPV4和IPv6并运行IS-IS实现网络的互联与通。如图所示,该网络IPV4和IPV6开销相同,R1和R4只支持IPV4缺省情况下,计算形成的IPV6最短路径树中,R2访问R6的下一跳设备是…

【C++初阶】C++入门(2)

&#x1f525;博客主页&#xff1a; 小羊失眠啦. &#x1f3a5;系列专栏&#xff1a;《C语言》 《数据结构》 《C》 《Linux》 《Cpolar》 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 文章目录 一、函数重载1.1 函数重载的概念1.2 函数重载的种类1.3 C支持函数重载的原理 二…

最全前端 HTML 面试知识点

一、HTML 1.1 HTML 1.1.1 定义 超文本标记语言&#xff08;英语&#xff1a;HyperTextMarkupLanguage&#xff0c;简称&#xff1a;HTML&#xff09;是一种用于创建网页的标准标记语言 HTML元素是构建网站的基石 标记语言&#xff08;markup language &#xff09; 由无数个…

橱窗宝石 - 华为OD统一考试

OD统一考试&#xff08;C卷&#xff09; 分值&#xff1a; 100分 题解&#xff1a; Java / Python / C 题目描述 橱窗里有一排宝石&#xff0c;不同的宝石对应不同的价格&#xff0c;宝石的价格标记为 gems[i],0<i<n, n gems.length 宝石可同时出售0个或多个&#xff…

Mysql+MybatisPlus+Vue实现基础增删改查CRUD

数据库 设计数据库 设计几个字段&#xff0c;主键id自动增长且不可为空 create table if not exists user (id bigint(20) primary key auto_increment comment 主键id,username varchar(255) not null comment 用户名,sex char(1) not null comment 性…

十一、常用API——练习

常用API——练习 练习1 键盘录入&#xff1a;练习2 算法水题&#xff1a;练习3 算法水题&#xff1a;练习4 算法水题&#xff1a;练习5 算法水题&#xff1a; 练习1 键盘录入&#xff1a; 键盘录入一些1~100之间的整数&#xff0c;并添加到集合中。 直到集合中所有数据和超过2…

【新课】安装部署系列Ⅲ—Oracle 19c Data Guard部署之两节点RAC部署实战

本课程由云贝教育-刘峰老师出品&#xff0c;感谢关注 课程介绍 Oracle Real Application Clusters (RAC) 是一种跨多个节点分布数据库的企业级解决方案。它使组织能够通过实现容错和负载平衡来提高可用性和可扩展性&#xff0c;同时提高性能。本课程基于当前主流版本Oracle 1…

017 JavaDoc生成文档

什么是JavaDoc 示例 package se;/*** 用于学习Java* author Admin* version 1.0* since 1.8*/ public class Test20240119 {/*** 主方法* param args*/public static void main(String[] args) {// 你好&#xff0c;世界System.out.println("Hello world");} } 写一…

故障诊断 | 一文解决,GRU门控循环单元故障诊断(Matlab)

文章目录 效果一览文章概述专栏介绍模型描述源码设计参考资料效果一览 文章概述 故障诊断 | 一文解决,GRU门控循环单元故障诊断(Matlab) 专栏介绍 订阅【故障诊断】专栏,不定期更新机器学习和深度学习在故障诊断中的应用;订阅

opencv#40 图像细化

图像细化原理 作用&#xff1a;图像细化是将图像的线条从多像素宽度减少到单位像素宽度的过程&#xff0c;又被称为“骨架化”&#xff0c;删除像素点的标准&#xff1a; 通常情况下&#xff0c;我们使用二值化图像&#xff0c;我们在判断是否要删除某些像素点时&#xff0c;要…

遍历删除空文件夹

文章目录 遍历删除空文件夹概述笔记END 遍历删除空文件夹 概述 在手工整理openssl3.2编译完的源码工程中的文档, 其中有好多空文件夹. 做了一个小工具, 将空文件夹都遍历删除掉. 这样找文档方便一些. 删除后比对了一下, 空文件夹还真不少. 笔记 // EmptyDirRemove.cpp : 此…

音视频数字化(数字与模拟-音频广播)

在互联网飞速发展的今天,每晚能坐在电视机前面的人越来越少,但是每天收听广播仍旧是很多人的习惯。 从1906年美国费森登在实验室首次进行无线电广播算起,“广播”系统已经陪伴人们115年了。1916年,收音机开始上市,收音机核心是“矿石”。1920年开始“调幅”广播,1941年开…

Uniapp小程序端打包优化实践

背景描述&#xff1a; 在我们最近开发的一款基于uniapp的小程序项目中&#xff0c;随着功能的不断丰富和完善&#xff0c;发现小程序包体积逐渐增大&#xff0c;加载速度也受到了明显影响。为了提升用户体验&#xff0c;团队决定对小程序进行一系列打包优化。 项目优化点&…

Optimism的挑战期

1. 引言 前序博客&#xff1a; Optimism的Fault proof 用户将资产从OP主网转移到以太坊主网时需要等待一周的时间。这段时间称为挑战期&#xff0c;有助于保护 OP 主网上存储的资产。 而OP测试网的挑战期仅为60秒&#xff0c;以简化开发过程。 2. OP与L1数据交互 L1&#xf…