4. 使用预训练的PyTorch网络进行图像分类

news2024/11/19 14:41:44

4. 使用预训练的PyTorch网络进行图像分类

这篇博客将介绍如何使用PyTorch预先训练的网络执行图像分类。利用这些网络只需几行代码就可以准确地对1000个常见对象类别进行分类。这些图像分类网络是开创性的、最先进的图像分类网络,包括VGG16、VGG19、Inception、DenseNet和ResNet。
这些模型是由负责发明和提出上述新型架构的研究人员训练的。训练完成后,这些研究人员将模型权重保存到磁盘上,然后将其发布给其他研究人员、学生和开发人员,供他们学习并在自己的项目中使用。
虽然模型可以自由使用,但请确保检查了与之相关的任何条款/条件,因为有些模型在商业应用中不能自由使用(AI领域的企业家通常通过训练模型本身而不是使用原始作者提供的预训练权重来绕过这一限制)。

图像分类允许为输入图像指定一个或多个标签,然而它并没有告诉对象在图像中的位置。要确定给定对象在输入图像中的位置,需要应用对象检测。
对象检测可以检测到图像中的对象及其位置;
就像有用于图像分类的预训练网络一样,也有用于目标检测的预训练网络。

下一篇博客将介绍如何使用PyTorch使用专门的对象检测网络检测图像中的对象。

1. 效果图

第一次运行会默认下载模型文件:

densenet121-a639ec97.pth
resnet50-0676ba61.pth
vgg16-397923af.pth
inception_v3_google-0cc3c7bd.pth
vgg19-dcbb9e9d.pth

E:\mat\py-demo-22>python classify_image.py --image images/cat.jpg
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\Administrator/.cache\torch\hub\checkpoints\vgg16-397923af.pth

vgg16 效果图如下
可以看到飞机以97.45%的可能性被成功识别,该模型的第二个顶级预测——翼,飞机有翅膀该预测也很准确。
在这里插入图片描述

vgg19效果图如下:
在这里插入图片描述

inception 效果图如下
在这里插入图片描述
densenet 效果图如下
可以看到猫以39.28%的可能性被检测到。
在这里插入图片描述

resnet 效果图如下
可以看到第2,3,4也检测的全是猫的品种;
在这里插入图片描述

2. 原理

基于预训练网络的PyTorch图像分类

2.1 什么是经过预训练的图像分类网络?

图像分类:没有比ImageNet更著名的数据集/挑战了。ImageNet的目标是将输入图像精确分类为1000个计算机视觉系统日常生活中常见的对象类别。

最流行的深度学习框架,包括PyTorch、Keras、TensorFlow和fast。人工智能和其他技术包括预先训练的网络。这些是计算机视觉研究人员在ImageNet数据集上训练的高度精确、最先进的模型。
在ImageNet上训练完成后,研究人员将其模型保存到磁盘,然后免费发布,供其他研究人员、学生和开发人员学习并在自己的项目中使用。

本文将演示如何使用PyTorch使用以下最先进的分类网络对输入图像进行分类:

  • VGG16
  • VGG19
  • Inception
  • DenseNet
  • ResNet

2.2 环境配置

pip install torch torchvision
pip install opencv-contrib-python

3. 源码

# USAGE
# python classify_image.py --image images/cat.jpg
# python classify_image.py --image images/mg.jpg --model densenet
# 使用PyTorch预训练的网络识别和分类图像


import argparse

import cv2  # opencv绑定
import imutils
import numpy as np  # 数值array计算
import torch  # 使用PyTorch API
# 导入必要的包
from pyimagesearch import config
from torchvision import models  # 包含PyTorch预训练的网络


# 接收输入图像,预处理
def preprocess_image(image):
    # 转换图像色彩空间(BGR--RGB)
    # 等比例缩放,并缩放像素值为[0,1]范围
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (config.IMAGE_SIZE, config.IMAGE_SIZE))
    image = image.astype("float32") / 255.0

    # 减去ImageNet图像均值,除以ImageNet标准偏差,
    # 设置“通道优先”排序,并添加一个维度
    image -= config.MEAN
    image /= config.STD
    image = np.transpose(image, (2, 0, 1))
    image = np.expand_dims(image, 0)

    # 返回预处理后的图像
    return image


# 构建命令行参数及解析
# --image 输入图像路径
# --model PyTorch自带的模型路径
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,
                help="path to the input image")
ap.add_argument("-m", "--model", type=str, default="vgg16",
                choices=["vgg16", "vgg19", "inception", "densenet", "resnet"],
                help="name of pre-trained network to use")
args = vars(ap.parse_args())

# 定义一个模型字典,它将--model命令行参数的名称映射到对应的PyTorch函数
# 字典的键是模型的可读名称,通过--model命令行参数传入。
# 字典的值是相应的PyTorch函数,用于加载模型,并在ImageNet上预先训练权重
# 可选择:VGG16、VGG19、Inception、DenseNet、ResNet
# (如果从未下载过模型权重,则会自动下载并缓存这些权重)
MODELS = {
    "vgg16": models.vgg16(pretrained=True),  # 指定pretrained=True标志指示PyTorch不仅加载模型体系结构定义,还下载模型的预先训练的ImageNet权重。
    "vgg19": models.vgg19(pretrained=True),
    "inception": models.inception_v3(pretrained=True),
    "densenet": models.densenet121(pretrained=True),
    "resnet": models.resnet50(pretrained=True)
}

# 加载网络,并闪存到当前设备,设置为评估模式
# 指示PyTorch处理特殊层,如退出和批量规范化,这与训练期间处理这些层的方式不同。在进行预测之前,将模型置于评估模式至关重要的
print("[INFO] loading {}...".format(args["model"]))
model = MODELS[args["model"]].to(config.DEVICE)
model.eval()

# 从磁盘加载图像,克隆,预处理
print("[INFO] loading image...")
image = cv2.imread(args["image"])
image = imutils.resize(image, width=500)
orig = image.copy()
image = preprocess_image(image)

# 将图像从NumPy阵列转换为PyTorch张量,传递到当前设备
image = torch.from_numpy(image)
image = image.to(config.DEVICE)

# 加载预处理的ImageNet labels
print("[INFO] loading ImageNet labels...")
imagenetLabels = dict(enumerate(open(config.IN_LABELS, 'r', encoding='utf-8')))

# 执行网络的前向传递,从而产生网络的输出
# 分类图像,提取预测结果
print("[INFO] classifying image with '{}'...".format(args["model"]))
logits = model(image)
# 通过Softmax函数来获得模型训练时可能用到的1000个类别标签的预测概率。
probabilities = torch.nn.Softmax(dim=-1)(logits)
sortedProba = torch.argsort(probabilities, dim=-1, descending=True)

# 遍历预测结果值,并显示前5个预测结果,关联结果到终端
# 使用imagenetLabels字典查找类标签的名称显示预测概率
for (i, idx) in enumerate(sortedProba[0, :5]):
    print("{}. {}: {:.2f}%".format
          (i, imagenetLabels[idx.item()].strip(),
           probabilities[0, idx.item()] * 100))

# 将最高预测结果绘制在图像上并显示
(label, prob) = (imagenetLabels[probabilities.argmax().item()],
                 probabilities.max().item())
label = str(label).split(":")[1]
cv2.putText(orig, "Label: {}, {:.2f}%".format(label.strip(), prob * 100),
            (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
cv2.imshow("Classification " + args["model"], orig)
cv2.waitKey(0)
cv2.destroyAllWindows()

参考

  • https://pyimagesearch.com/2021/07/26/pytorch-image-classification-with-pre-trained-networks/
  • human readable ImageNet Labels

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/163327.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows下 pytorch的安装(gpu版本以及cpu版本)

一. 查看是否有gpu 打开cmd 输入nvidia-smi 是以下这种情况的就是有gpu 没有gpu的话就会报错 下载安装cuda以及cudnn(安装cpu版本可以跳过此步骤直接进行pytorch的安装) 下载cuda 看清楚两个箭头指的地方 一个是11.3.0 一个是日期 后面下载cudnn的时…

ProEssentials Pro 9.8.0.32 Crack

ProEssentials .Net图表组件用于对您的科学、工程和金融图表进行评估和选择! Winforms 图表, WPF 图表, C/MFC/VCL 图表. Gigasoft拥有20多年帮助企业开发大型客户端和嵌入式图表项目的经验 为何选择ProEssentials? 我们真诚地希望您能针对您的具体实施…

day03 链表 | 203、移除链表元素 707、设计链表 206、反转链表

题目 203、移除链表元素 删除链表中等于给定值 val 的所有节点。 示例 1: 输入:head [1,2,6,3,4,5,6], val 6 输出:[1,2,3,4,5] 示例 2: 输入:head [], val 1 输出:[] 示例 3: 输入&am…

Pytorch Kaggle实战:House Prices - Advanced Regression Techniques

通过Kaggle比赛,将所学知识付诸实践 目录 1、下载和缓存数据集 2、访问和读取数据集 3、数据预处理 3、训练 4、K折交叉验证 5、模型选择 6、提交Kaggle预测 1、下载和缓存数据集 建立字典DATA_HUB,它可以将数据集名称的字符串映射到数据集相关的二元组上&am…

网络抓包-抓包工具tcpdump的使用与数据分析

1.测试背景 本次测试选用两台不同的服务器,ip分别为.233和.246,233服务器为客户端,246服务器为服务端。利用tcp协议就行socket通信。socket网络编程部分示例代码为基本的通信代码,需要了解tcp网络通讯的基本协议与过程。服务器上采用tcpdump…

【学习笔记】【Pytorch】八、池化层

【学习笔记】【Pytorch】八、池化层学习地址主要内容一、最大池化操作示例二、nn.MaxPool2d类的使用1.使用说明2.代码实现三、池化公式学习地址 PyTorch深度学习快速入门教程【小土堆】. 主要内容 一、最大池化操作示例 二、nn.MaxPool2d类的使用 作用:对于输入信…

Min_25筛

概述 Min_25是日本一个ACM选手的ID,这个筛法是他发明的,所以称之为Min_25筛。它能在亚线性复杂度求出一类积性函数的 fff 的前缀和,前提 是这个积性函数在质数和质数的幂位置的函数值比较好求。借助埃拉托色尼筛的思想 将原问题转化成与质因…

华为PIM-SM 动态RP实验配置

目录 建立PIM SM邻居 配置DR 配置动态RP 组成员端DR上配置IGMP 配置PIM安全 配置SPT切换 配置Anycast RP 配置接口的IP地址,并配置路由协议使得全网互通 建立PIM SM邻居 AR5操作 multicast routing-enable 开启组播路由转发功能 int g0/0/0 pim sm …

MacOS对文件夹加密的方法

背景 MacOS没有那种类似于windows那种对文件夹加解密的软件,MacOS自带有一种加解密,但是其实使用体验上跟windows那种很不一样。 win上的加解密都很快,就好像仅仅对文件夹进行加解密(我估计是安全性较低的,因为加密过…

【JavaSE】异常的初步认识

目录 1、初步认识异常 1、算数异常 2、空指针异常 3、数组越界异常 2、异常的结构体系 3、异常的分类 1、编译时异常/受查异常 2、运行时异常/非受查异常 4、异常的处理 1、处理异常的编程方式(防御式编程) 1、事前防御性(LBYL&a…

【软件测试】软件测试基础知识

1. 什么是软件测试 软件测试就是验证软件产品特性是否满足用户的需求 2. 调试与测试的区别 目的不同 调试:发现并解决软件中的缺陷测试:发现软件中的缺陷 参与角色不同 调试:开发人员测试:测试人员,开发人员等&a…

软件测试复习04:动态测试——黑盒测试

作者:非妃是公主 专栏:《软件测试》 个性签:顺境不惰,逆境不馁,以心制境,万事可成。——曾国藩 文章目录等价划分法边值分析法错误推测法因果图法示例习题等价划分法 等价类:一个几何&#xf…

如何快速搭建自己的阿里云服务器(宝塔)并且部署springboot+vue项目(全网最全)

📢欢迎点赞👍收藏⭐留言📝如有错误敬请指正! 文章目录📢欢迎点赞👍收藏⭐留言📝如有错误敬请指正!一、前言二、准备工作1、新手申请2、安全组设置3、修改实例4.这里可以 直接用阿里云…

【图像处理OpenCV(C++版)】——4.2 对比度增强之线性变换

前言: 😊😊😊欢迎来到本博客😊😊😊 🌟🌟🌟 本专栏主要结合OpenCV和C来实现一些基本的图像处理算法并详细解释各参数含义,适用于平时学习、工作快…

【数据结构】5.7 哈夫曼树及其应用

文章目录前言5.7.1 哈夫曼树的基本概念哈夫曼树的特点5.7.2 哈夫曼树的构造算法哈夫曼树的构造过程哈夫曼算法的实现算法思路算法实现5.7.3 哈夫曼编码哈夫曼编码思想前缀编码哈夫曼编码哈夫曼编码的性质哈夫曼编码的算法实现文件的编码和解码前言 编程:将学生的百…

【精品】k8s(Kubernetes)由基础到实战学法指南

轻松快速学会k8s四招 图1 k8s四招 学完本篇,您会获得什么惊喜? 从初学k8s,到帮助别人学会的过程中,发现朋友们和我,并非不努力,而是没有掌握更好的方法。有方法可让我们学的更快更轻松,这篇文章,以一个networkpolicy的题目,来逐步讲解,帮助大家建立一种,自己可以根…

深入了解延迟队列 DelayQueue

1. 前言 前面我们了解了基于数组,链表实现的阻塞队列,以及优先级队列。今天我们来了解下基于优先级队列的延迟队列,而且今天的内容很核心哦。 大家快搬好小板凳做好,听我慢慢分析 2. 简单实例 Task 类 public class Task implem…

数据结构(字符串)

字符串简称串,由零个或多个字符组成的有限序列,一般记为s=“a0 a1a2…an-1”,(n≥0)。其中s称作串名,用双引号括起来的字符序列是串的值。字符ai(0≤i≤n-1)可以是字母、数字或其它字…

开发第三天(Day 03)

首先对ipl.nas进行修改: ; haribote-ipl ; TAB4ORG 0x7c00 ; 这个程序被读入哪里; 以下是标准FAT12格式软盘的描述JMP entryDB 0x90DB "HARIBOTE" ; 可以自由地写引导扇区的名字 (8字节)DW 512 ; 1扇区…

【动态内存管理】-关于动态内存你只知道四个函数是不够的,这里还有题目教你怎么正确使用函数,还不进来看看??

🎇作者:小树苗渴望变成参天大树 💦作者宣言:认真写好每一篇博客 💢 作者gitee:link 如 果 你 喜 欢 作 者 的 文 章 ,就 给 作 者 点 点 关 注 吧! 🎊动态内存管理&…