利用Pytorch预训练模型进行图像分类

news2024/12/24 20:12:30

Use Pre-trained models for Image Classification.

# This post is rectified on the base of https://learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/

# And we have re-orginaized the code script.

预训练模型(Pre-trained models)是在ImageNet等大型基准数据集上训练的神经网络模型。深度学习社区从这些开源模型中受益匪浅。此外,预训练模型也是计算机视觉研究取得快速进展的一个重要因素。其他研究人员和从业人员可以使用这些最先进的模型,而不是从头开始重新训练。

# Here are some examples of classic pre-trained models.

在这里插入图片描述

在详细介绍如何使用预训练模型进行图像分类之前,我们先来看看有哪些预训练模型。我们将以 AlexNet 和 ResNet101 为例进行讨论。这两个网络都在 ImageNet 数据集上训练过。

ImageNet 数据集拥有超过 1400 万张由斯坦福大学维护的图像。它被广泛用于各种与图像相关的深度学习项目。这些图像属于不同的类别或标签。预训练模型(如 AlexNet 和 ResNet101)的目的是将图像作为输入并预测其类别。

这里的 "预训练 "是指,深度学习架构 AlexNet 和 ResNet101 已经在某个(庞大的)数据集上进行过训练,因此带有由此产生的权重和偏差。架构与权重和偏置之间的区别应该非常明显,因为我们将在下一节看到,TorchVision 同时拥有架构和预训练模型。

1.1 Model Inference Process

由于我们将重点讨论如何使用预先训练好的模型来预测输入的类别(标签),因此我们也来讨论一下其中涉及的过程。这个过程被称为模型推理。整个过程包括以下主要步骤:

(1) 读取输入图像;
(2) 对图像进行转换;例如resize、center crop、normalization等;
(3) 前向传递:使用预训练的模型权重来获得输出向量,而输出向量中的每个元素都描述了模型对于输入图像属于特定类别的置信度预测结果;
(4) 预测结果:基于获得的置信度分数,显示预测结果。

1.2 Loading Pre-Trained Network using TorchVision

# [Optinal Step]
# %pip install torchvision
# Load necessary packages.
from PIL import Image
import torch
import torchvision
from torchvision import models
from torchvision import transforms

print(torch.__version__)
print(torchvision.__version__)
2.0.0
0.15.0
# Check the different models and architectures available to us.
dir(models)
['AlexNet',
 'AlexNet_Weights',
 'ConvNeXt',
 'ConvNeXt_Base_Weights',
 'ConvNeXt_Large_Weights',
 'ConvNeXt_Small_Weights',
 'ConvNeXt_Tiny_Weights',
 'DenseNet',
 'DenseNet121_Weights',
 'DenseNet161_Weights',
 'DenseNet169_Weights',
 'DenseNet201_Weights',
 'EfficientNet',
 'EfficientNet_B0_Weights',
 'EfficientNet_B1_Weights',
 'EfficientNet_B2_Weights',
 'EfficientNet_B3_Weights',
 'EfficientNet_B4_Weights',
 'EfficientNet_B5_Weights',
 'EfficientNet_B6_Weights',
 'EfficientNet_B7_Weights',
 'EfficientNet_V2_L_Weights',
 'EfficientNet_V2_M_Weights',
 'EfficientNet_V2_S_Weights',
 'GoogLeNet',
 'GoogLeNetOutputs',
 'GoogLeNet_Weights',
 'Inception3',
 'InceptionOutputs',
 'Inception_V3_Weights',
 'MNASNet',
 'MNASNet0_5_Weights',
 'MNASNet0_75_Weights',
 'MNASNet1_0_Weights',
 'MNASNet1_3_Weights',
 'MaxVit',
 'MaxVit_T_Weights',
 'MobileNetV2',
 'MobileNetV3',
 'MobileNet_V2_Weights',
 'MobileNet_V3_Large_Weights',
 'MobileNet_V3_Small_Weights',
 'RegNet',
 'RegNet_X_16GF_Weights',
 'RegNet_X_1_6GF_Weights',
 'RegNet_X_32GF_Weights',
 'RegNet_X_3_2GF_Weights',
 'RegNet_X_400MF_Weights',
 'RegNet_X_800MF_Weights',
 'RegNet_X_8GF_Weights',
 'RegNet_Y_128GF_Weights',
 'RegNet_Y_16GF_Weights',
 'RegNet_Y_1_6GF_Weights',
 'RegNet_Y_32GF_Weights',
 'RegNet_Y_3_2GF_Weights',
 'RegNet_Y_400MF_Weights',
 'RegNet_Y_800MF_Weights',
 'RegNet_Y_8GF_Weights',
 'ResNeXt101_32X8D_Weights',
 'ResNeXt101_64X4D_Weights',
 'ResNeXt50_32X4D_Weights',
 'ResNet',
 'ResNet101_Weights',
 'ResNet152_Weights',
 'ResNet18_Weights',
 'ResNet34_Weights',
 'ResNet50_Weights',
 'ShuffleNetV2',
 'ShuffleNet_V2_X0_5_Weights',
 'ShuffleNet_V2_X1_0_Weights',
 'ShuffleNet_V2_X1_5_Weights',
 'ShuffleNet_V2_X2_0_Weights',
 'SqueezeNet',
 'SqueezeNet1_0_Weights',
 'SqueezeNet1_1_Weights',
 'SwinTransformer',
 'Swin_B_Weights',
 'Swin_S_Weights',
 'Swin_T_Weights',
 'Swin_V2_B_Weights',
 'Swin_V2_S_Weights',
 'Swin_V2_T_Weights',
 'VGG',
 'VGG11_BN_Weights',
 'VGG11_Weights',
 'VGG13_BN_Weights',
 'VGG13_Weights',
 'VGG16_BN_Weights',
 'VGG16_Weights',
 'VGG19_BN_Weights',
 'VGG19_Weights',
 'ViT_B_16_Weights',
 'ViT_B_32_Weights',
 'ViT_H_14_Weights',
 'ViT_L_16_Weights',
 'ViT_L_32_Weights',
 'VisionTransformer',
 'Weights',
 'WeightsEnum',
 'Wide_ResNet101_2_Weights',
 'Wide_ResNet50_2_Weights',
 '_GoogLeNetOutputs',
 '_InceptionOutputs',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '_api',
 '_meta',
 '_utils',
 'alexnet',
 'convnext',
 'convnext_base',
 'convnext_large',
 'convnext_small',
 'convnext_tiny',
 'densenet',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'detection',
 'efficientnet',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b2',
 'efficientnet_b3',
 'efficientnet_b4',
 'efficientnet_b5',
 'efficientnet_b6',
 'efficientnet_b7',
 'efficientnet_v2_l',
 'efficientnet_v2_m',
 'efficientnet_v2_s',
 'get_model',
 'get_model_builder',
 'get_model_weights',
 'get_weight',
 'googlenet',
 'inception',
 'inception_v3',
 'list_models',
 'maxvit',
 'maxvit_t',
 'mnasnet',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'mobilenetv2',
 'mobilenetv3',
 'optical_flow',
 'quantization',
 'regnet',
 'regnet_x_16gf',
 'regnet_x_1_6gf',
 'regnet_x_32gf',
 'regnet_x_3_2gf',
 'regnet_x_400mf',
 'regnet_x_800mf',
 'regnet_x_8gf',
 'regnet_y_128gf',
 'regnet_y_16gf',
 'regnet_y_1_6gf',
 'regnet_y_32gf',
 'regnet_y_3_2gf',
 'regnet_y_400mf',
 'regnet_y_800mf',
 'regnet_y_8gf',
 'resnet',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext101_64x4d',
 'resnext50_32x4d',
 'segmentation',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'shufflenet_v2_x1_5',
 'shufflenet_v2_x2_0',
 'shufflenetv2',
 'squeezenet',
 'squeezenet1_0',
 'squeezenet1_1',
 'swin_b',
 'swin_s',
 'swin_t',
 'swin_transformer',
 'swin_v2_b',
 'swin_v2_s',
 'swin_v2_t',
 'vgg',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'video',
 'vision_transformer',
 'vit_b_16',
 'vit_b_32',
 'vit_h_14',
 'vit_l_16',
 'vit_l_32',
 'wide_resnet101_2',
 'wide_resnet50_2']

以AlexNet为例,我们可以看到还有一个名称为alexnet的条目。其中,大写的名称是Python类(AlexNet),而alexnet是一个便于操作的函数(convenience function),用于从AlexNet类返回实例化的模型。

这些方便函数也可以有不同的参数集,例如:densenet121、densenet161、densenet169以及densenet201,都是DenseNet的实例,但层数分别为121,161,169和201.

1.3. Using AlexNet for Image Classification

AlexnetNet是图像识别领域早期的一个突破性网络结构,相关文章可以参考Understanding Alexnet。该网络架构如下:

在这里插入图片描述

Step 1: Load the pre-trained model

# Create an instance of the network.
alexnet = models.alexnet(pretrained=True)
/home/wsl_ubuntu/anaconda3/envs/xy_trans/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/wsl_ubuntu/anaconda3/envs/xy_trans/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
# Note: Pytorch模型的扩展名通常为.pt或.pth
# Check the model details.
print(alexnet)
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Step 2: Specify image transformations

# Use transforms to compose all the data transformations.
transform = transforms.Compose([
    transforms.Resize(256), 
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])     # Three numbers for RGB Channels.
# transforms.Resize: Resize the input images to 256x256 pixels.
# transforms.CenterCrop: Crop the image to 224×224 pixels about the center.
# transforms.Normalize: Normalize the image by setting its mean and standard deviation to the specified values.
# transforms.ToTensor: Convert the image to Pytorch tensor datatype.

Step 3: Load the input image and pre-process it.

# Download image
# !wget https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg -O dog.jpg
img = Image.open("./dog.jpg")
img

在这里插入图片描述

# Pre-process the image.
trans_img = transform(img)

img_batch = torch.unsqueeze(trans_img, 0)

Step 4: Model Inference

# Set the model to eval model.
alexnet.eval()

out = alexnet(img_batch)
print(out.shape)
torch.Size([1, 1000])
# Download classes text file
!wget https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt
--2023-12-14 21:30:09--  https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 0.0.0.0, ::
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|0.0.0.0|:443... failed: Connection refused.
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|::|:443... failed: Connection refused.
# Load labels.
with open('imagenet_classes.txt') as f:
    classes = [line.strip() for line in f.readlines()]
# Find out the maximum score.
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
print(classes[index[0]], percentage[index[0]].item())
Labrador retriever 41.58513259887695
# The model predicts the image to be of a Labrador Retriever with a 41.58% confidence.
_, indices = torch.sort(out, descending=True)
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
[('Labrador retriever', 41.58513259887695),
 ('golden retriever', 16.59164810180664),
 ('Saluki, gazelle hound', 16.286897659301758),
 ('whippet', 2.8539111614227295),
 ('Ibizan hound, Ibizan Podenco', 2.39247727394104)]

1.4. Using ResNet for Image Classification

# Load the resnet101 model.
resnet = models.resnet101(pretrained=True)

# Set the model to eval mode.
resnet.eval()

# carry out model inference.
out = resnet(img_batch)

# Print the top 5 classes predicted by the model.
_, indices = torch.sort(out, descending=True)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /home/wsl_ubuntu/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:51<00:00, 3.47MB/s] 





[('Labrador retriever', 48.255577087402344),
 ('dingo, warrigal, warragal, Canis dingo', 7.900773048400879),
 ('golden retriever', 6.91691780090332),
 ('Eskimo dog, husky', 3.6434390544891357),
 ('bull mastiff', 3.046128273010254)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1310716.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SpringBoot】进阶之自定义starter(一起了解自定义starter的魅力)

&#x1f389;&#x1f389;欢迎来到我的CSDN主页&#xff01;&#x1f389;&#x1f389; &#x1f3c5;我是君易--鑨&#xff0c;一个在CSDN分享笔记的博主。&#x1f4da;&#x1f4da; &#x1f31f;推荐给大家我的博客专栏《SpringBoot开发》。&#x1f3af;&#x1f3af;…

解决设备能耗管理问题,易点易动来帮忙!

设备能耗管理是现代企业可持续发展的重要环节&#xff0c;然而&#xff0c;许多企业在设备能耗管理方面面临一系列问题&#xff1a; 能耗数据收集困难&#xff1a;企业需要监控和管理大量设备的能耗情况&#xff0c;但传统的手动方式收集能耗数据耗时耗力&#xff0c;无法实时获…

Python中的TesserOCR:文字识别的全方位指南

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com 文字识别在图像处理领域中起到了至关重要的作用&#xff0c;而TesserOCR&#xff08;Tesseract OCR的Python封装&#xff09;为开发者提供了一个强大的工具&#xff0c;使得文字识别变得更加便捷。本文将通过详细…

电声器件是什么

电声器件 电子元器件百科 文章目录 电声器件前言一、电声器件是什么二、电声器件的类别三、电声器件的应用实例四、电声器件的作用原理总结前言 电声器件在多种应用中起着重要作用,如家庭娱乐系统、音响设备、通信设备、汽车音响、舞台表演、声音检测和录音等领域。它们的设计…

新版Spring Security6.2案例 - Authentication用户名密码

前言&#xff1a; 前面有翻译了新版Spring Security6.2架构&#xff0c;包括总体架构&#xff0c;Authentication和Authorization&#xff0c;感兴趣可以直接点链接&#xff0c;这篇翻译官网给出的关于Authentication的Username/Password这页。 首先呢&#xff0c;官网就直接…

RHEL7.5编译openssl1.1.1w源码包到rpm包

openssl1.1.1w下载地址 https://www.openssl.org/source/ 安装依赖包 yum -y install curl which make gcc perl perl-WWW-Curl rpm-build wget http://mirrors.aliyun.com/centos-vault/7.5.1804/os/x86_64/Packages/perl-WWW-Curl-4.15-13.el7.x86_64.rpm rpm -ivh pe…

tuxera2023破解版免费下载 NTFS for Mac读写工具(附序列号)

Tuxera ntfs 2023 破解安装包是一个mac读写ntfs磁盘工具允许您访问&#xff0c;它允许您访问NFTS 驱动器上的文件。 该应用程序提供访问访问Mac 设备中NFTS 格式文件的驱动力&#xff0c;因此您有权基于格式文件进行无困难的访问Windows 数据。 在发生电力灾难或断电时使用防损…

重新认识Word——给图、表、公式等自动编号

重新认识Word——给图、表、公式等自动编号 给图增加题注题注失败的情况给图添加“如图xx-xx所示” 给公式插入题注第一步——先加题注第二步——设置两个制表符 解决题注“图一-1”的问题 前面我们已经学习了如何引用多级列表自动编号了&#xff0c;现在我们有第二个问题&…

汽车清除积碳和清洗节气门

汽车清除积碳和清洗节气门 汽车需要清除积碳的部位检查积碳方法&#xff1a; 清除积碳和清洗节气门风险&#xff1a;燃油宝 第一次清除积碳1万公里2万公里3万公里--5万公里6万公里以上 汽车需要清除积碳的部位 节气门喷油嘴进气道燃烧室 检查积碳方法&#xff1a; 建议每3到5…

基于javaweb实现的实践教学基地管理系统

一、系统架构 前端&#xff1a;html | js | css | bootstrap 后端&#xff1a;spring | springmvc | mybatis-plus 环境&#xff1a;jdk1.8 | mysql8 | tomcat | maven 二、代码及数据库 三、功能介绍 01. web-首页1 02. web-首页2 03. web-首页3 04. web-首页4 05. 管…

智能冶钢厂环境监控与设备控制系统(边缘物联网网关)

目录 1、项目背景 2、项目功能介绍 3、模块框架 3.1 架构框图 3.2 架构介绍 4、系统组成与工作原理 4.1 数据采集 4.2 指令控制 4.3 其他模块 4.3.1 网页、qt视频流 4.3.2 qt搜索进程 5、成果呈现 6、问题解决 7、项目总结 1、项目背景 这个项目的背景是钢铁行业的…

【算法Hot100系列】无重复字符的最长子串

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

学习MS Dynamics AX 2012编程开发 1. 了解Dynamics AX 2012

在本章中&#xff0c;您将了解开发环境的结构以及Microsoft Dynamics AX中的开发人员可以访问哪些工具。在本书的第一步演练之后&#xff0c;您将很容易理解著名的Hello World代码&#xff0c;您将知道应用程序对象树中的不同节点代表什么。 以下是您将在本章中学习的一些主题…

Python-docx 深入word源码 自定义字符间距

代码和实现效果 from docx import Document from docx.oxml import OxmlElement from docx.oxml.ns import qn from docx.shared import Pt# 调整pt设置字间距 def SetParagraphCharSpaceByPt(run, pt1):通过修改word源码方式, 添加w:spacing标签直接通过调整pt来设置字符间距…

软件设计中如何画各类图之七了解组件图:系统架构的关键视角

目录 1 前言2 组件图基本介绍3 画组件图的步骤4 组件图的用途5 场景及实际场景举例6 结语 1 前言 组件图是一种UML的图形化表示工具&#xff0c;为系统架构提供了重要视角。它描述了系统中各个组件以及它们之间的依赖关系和连接。用于展示系统中的组件、软件模块、以及它们之间…

成绩分级 C语言xdoj53

问题描述 给出一个百分制的成绩&#xff0c;要求输出成绩等级A,B,C,D,E。90分以上为A&#xff0c;80~89分为B,70~79分为C,60~69分为D&#xff0c;60分以下为E。 输入说明 输入一个正整数m&#xff08;0<m<100&#xff09; 输出说明 输出一个字符 输入样例 …

麦肯锡:2023年最被关注的科技趋势

1 近期&#xff0c;麦肯锡咨询公司公布了颇具影响力的《McKinsey Technology Trends Outlook 2023》报告&#xff0c;旨在通过其技术委员会的洞察力&#xff0c;揭示2023年可能改变商业舞台的15个技术趋势。报告的编撰不仅为企业和投资者提供了宝贵的方向指引&#xff0c;同时…

『 Linux 』重新理解挂起状态

文章目录 &#x1f984; 前言新建状态 &#x1f40b;挂起状态 &#x1f40b;唤入唤出 &#x1f40b;进程与操作系统间的联系 &#x1f40b; &#x1f984; 前言 『 Linux 』使用fork函数创建进程与进程状态的查看中提到了对挂起状态的一个理解&#xff1b; ​ 挂起状态相比于其…

【计算机组成体系结构】只读存储器ROM

一、ROM分类 二、计算机中重要的ROM 运行时操作系统在主存中&#xff0c;但是由于RAM断电后数据会丢失&#xff0c;所以操作系统都存储在辅存中&#xff0c;在开机时由CPU读入主存&#xff0c;而BIOS芯片就是用来存储自举装入程序的&#xff0c;它用于开机时引导把操作系统装入…

VS2022 将项目打包,导出为exe运行

我有一个在 VS2022 上开发的程序&#xff0c;基于.net 6框架, 想打包成 .exe程序&#xff0c;以在另一个没有安装VS的机器上运行&#xff0c;另一个机器是Win7系统&#xff0c;上面安装了.net 6框架。 虽然网上很多教程&#xff0c;需要安装Project Installer&#xff0c;配置A…