使用PyTorch执行特征提取和微调的迁移学习来进行图像分类

news2025/1/11 9:03:32

使用PyTorch执行特征提取和微调的迁移学习来进行图像分类

    • 1. 效果图
    • 2 项目结构
    • 3 什么是迁移学习
    • 4 如何使用PyTorch进行迁移学习?
    • 5 花朵数据集
    • 源码
      • train_feature_extraction.py
      • fine_tune.py
      • inference.py
    • 参考

这篇博客将介绍如何使用PyTorch深度学习库执行图像分类的转移学习。
① 通过特征提取执行迁移学习
② 通过微调执行迁移学习

第①种方法通常更容易实现,在某些情况下效果很好。然而,它往往不如第二种方法准确。即模型的准确性和泛化能力都会受到影响。大多数形式的迁移学习都采用②微调。

通常建议使用特征提取方法来获得基线精度。如果准确度足以满足那就太棒了!然而,如果精度不够,那么应该进行微调,看看是否可以提高精度。
无论式通过特征提取还是微调的迁移学习,都会为你节省大量的时间和精力,而不是从头开始训练模型。

1. 效果图

5种花朵数据集,分别为雏菊、蒲公英、玫瑰、向日葵、郁金香,效果图如下:
在这里插入图片描述

特征提取效果图如下:
在这里插入图片描述
在这里插入图片描述

微调效果图如下:
在这里插入图片描述

在这里插入图片描述

2 项目结构

pip install opencv-contrib-python
pip install torch torchvision
pip install imutils matplotlib tqdm
  • 用于存储重要变量的配置脚本
  • 数据集加载器辅助函数
  • 在磁盘上构建和组织数据集的脚本,例如PyTorch的ImageFolder 和数据加载器类可以很容易地利用
  • 通过特征提取执行基本迁移学习的驱动程序脚本
  • 第二个驱动程序脚本,通过用全新的、新初始化的FC头替换预训练网络的完全连接(FC)层头来执行微调
  • 一个最终脚本,允许我们使用经过训练的模型进行推理

3 什么是迁移学习

从头开始训练卷积神经网络带来了许多挑战,最显著的是训练网络的数据量和进行训练所需的时间。

迁移学习是一种技术,允许使用为特定任务训练的模型作为不同任务的机器学习模型的起点。
例如,假设在ImageNet数据集上对模型进行图像分类训练。在这种情况下可以采用这个模型并“重新训练”它来识别它最初从未被训练来识别的类!
想象一下,你知道如何骑自行车,却想骑摩托车。你骑自行车的经验——保持平衡、保持方向、转弯和刹车——将帮助你更快地学会骑摩托车。这就是迁移学习在CNN的情况下所做的。使用迁移学习可以通过冻结参数、更改输出层和微调权重来直接使用训练有素的模型。
本质上可以缩短整个训练过程,并在很短的时间内获得高精度的模型。

4 如何使用PyTorch进行迁移学习?

迁移学习主要有两种类型:

  • 通过特征提取进行迁移学习(Transfer learning via feature extraction):从预先训练的网络中移除FC(Fully Connection)层头,并用softmax分类器替换它。这种方法非常简单,因为它允许将预先训练的CNN视为特征提取器,然后将这些特征通过Logistic回归分类器。

  • 通过微调进行迁移学习(Transfer learning via fine-tuning):在应用微调时,再次从预先训练的网络中移除FC(Fully Connection)层头,但这次构建了一个全新的、新初始化的FC层头并将其放置在网络的原始主体上。CNN主体中的权重被冻结,然后训练新的层头(通常具有非常小的学习率)。然后可以选择解冻网络的主体并训练整个网络。

第一种方法往往更容易使用,因为涉及的代码更少,需要调整的参数也更少。然而,第二种方法往往更准确,导致模型更好地推广。通过特征提取和微调的迁移学习都可以用PyTorch实现——我将在本教程的其余部分向您展示如何实现。

5 花朵数据集

将用于微调实验的数据集是由TensorFlow开发团队策划的花朵图像数据集。该数据集的3670幅图像属于五种不同的花卉:

  • 雏菊:633张图片
  • 蒲公英:898张图片
  • 玫瑰:641张图片
  • 向日葵:699张图片
  • 郁金香:799张图片

目标是训练一个图像分类模型来识别这些花的每一种,将通过PyTorch应用迁移学习来实现这一目标。

源码

train_feature_extraction.py

# flower_photos: 5种花朵原始图片集
# config.py 配置文件将存储驱动程序脚本中使用的重要变量和参数。与其在每个脚本中重新定义它们只需在这里定义一次(从而使代码更干净、更容易阅读)
# create_dataloader.py help函数,Dataloader加载flower_photos
# output/ 存放训练损失图
# build_dataset.py 根据flower_photos目录构建数据集目录,将创建特殊的子目录来存储训练和验证拆分,允许PyTorch的ImageFolder脚本来解析目录并训练模型
# train_feature_extraction.py 执行特征提取的迁移学习,并把模型存储磁盘
# fine_tune.py 执行基于微调的迁移学习,并把模型存储磁盘
# inference.py 接受经过训练的PyTorch模型,并使用它对输入的花朵图像进行预测

# 要实现的第一种迁移学习方法是特征提取
# 通过特征提取进行迁移学习的工作原理如下:
# 采用预先训练的CNN(通常在ImageNet数据集上),从CNN上卸下FC(Fully Connection)层头,将网络主体的输出视为空间维度为M×N×C的任意特征提取器
# 分类器有俩个选择:
# 采用标准的逻辑回归分类器(如scikit学习库中的分类器),并根据每个图像中提取的特征对其进行训练。或者,更简单地说,将softmax分类器放在网络主体的顶部,
# 任何一种选择都是可行的,而且或多或少与另一种“相同”。
# 当提取的特征数据集适合机器的RAM时,第一个选项非常有效。这样可以加载整个数据集,实例化逻辑回归分类器模型的一个实例,然后对其进行训练。
# 当数据集太大而无法放入机器内存时,就会出现问题。当这种情况发生时,你可以使用类似在线学习的方法来训练你的逻辑回归分类器,但这只是引入了另一组库和依赖项。
# 相反,更容易的是利用PyTorch的强大功能,在提取的特征之上创建一个类似逻辑回归的分类器,然后使用PyTorch函数对其进行训练。

# 训练特征提取模型,执行该脚本后,将在输出目录中找到一个名为warmup_model.pth的文件——该文件是序列化PyTorch模型,然后可以用于在inference.py脚本中进行预测。
# 总的训练时间只有5分钟多一点,获得了84.26%的训练准确率和87.74%的验证准确率。
# USAGE
# python train_feature_extraction.py

# 导入必要的包
from pyimagesearch import config
from pyimagesearch import create_dataloaders # 从输入数据集目录创建PyTorch DataLoader的实例
from imutils import paths
from torchvision.models import resnet50 # 要使用的ImageNet的预训练模型
from torchvision import transforms # 允许定义一组预处理和/或数据增强,将依次应用于输入图像
from tqdm import tqdm # 用于创建格式良好的进度条的Python库
from torch import nn # 包含PyTorch的神经网络类和函数
import matplotlib.pyplot as plt
import numpy as np
import torch # 包含PyTorch的神经网络类和函数
import time

# 定义增强管道(使用Compose函数构建数据处理/扩充步骤,该函数位于PyTorch的transforms子模块中。
# 首先创建一个trainTransform,在给定输入图像的情况下,它将:
# 随机调整图像大小并将其裁剪为image_SIZE尺寸
# 随机执行水平翻转
# 在[-90,90]范围内随机执行旋转
# 将生成的图像转换为PyTorch张量
# 执行平均值减法和缩放,同样的用于验证数据集的 valTransform
# 请注意,我们不在验证转换器中执行数据扩充——没有必要对验证数据执行数据扩充。
# 创建了训练和验证Compose对象后,让我们应用get_datalader函数:)
trainTansform = transforms.Compose([
    transforms.RandomResizedCrop(config.IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])
valTransform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])

# 创建DataLoader
(trainDS, trainLoader) = create_dataloaders.get_dataloader(config.TRAIN,
                                                           transforms=trainTansform,
                                                           batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE)
(valDS, valLoader) = create_dataloaders.get_dataloader(config.VAL,
                                                       transforms=valTransform,
                                                       batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)

# 通过特征提取为迁移学习准备ResNet50模型
# 加载预训练的ImageNet ResNet50 model
model = resnet50(pretrained=True)

# 由于使用ResNet50模型作为特征提取器,设置其参数为不可训练(默认情况下是可训练的)
for param in model.parameters():
    param.requires_grad = False

# 将一个新的分类顶部附加到我们的特征提取器并弹出它,连接到当前设备
# 创建一个由单个FC层组成的新FC层头。实际上当使用分类交叉熵损失进行训练时,这一层将作为代理softmax分类器。
# 然后,这个新层被附加到网络主体,模型本身被移动到设备(CPU或GPU)。
modelOutputFeats = model.fc.in_features
model.fc = nn.Linear(modelOutputFeats, len(trainDS.classes))
model = model.to(config.DEVICE)

# 接下来,初始化损失函数和优化方法(注意只是向优化器提供分类顶部的参数)
lossFunc = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.fc.parameters(), lr=config.LR)

# 计算训练/验证集的每一个纪元步数
trainSteps = len(trainDS) // config.FEATURE_EXTRACTION_BATCH_SIZE
valSteps = len(valDS) // config.FEATURE_EXTRACTION_BATCH_SIZE

# 初始化字典以存储训练历史
H = {"train_loss": [], "train_acc": [], "val_loss": [],
     "val_acc": []}

# 遍历纪元
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.EPOCHS)):
    # 设置模型训练模式
    model.train()

    # 初始化训练/验证损失
    totalTrainLoss = 0
    totalValLoss = 0

    # 初始化训练/验证集中的预测正确个数
    trainCorrect = 0
    valCorrect = 0

    # 遍历训练集
    # 对于trainLoader中的每一批数据,将图像和类标签移动到CPU/GPU、对数据进行预测、计算损失,计算梯度,更新模型权重,并将梯度归零
    # 累积在该时期的总训练损失、计算正确预测的总数
    for (i, (x, y)) in enumerate(trainLoader):
        # 传递输入到设备
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

        # 向前传递并计算训练损失
        pred = model(x)
        loss = lossFunc(pred, y)

        # 计算损失梯度
        loss.backward()

        # 检查是否正在更新模型参数,如果是 更新它们,并将之前累积的梯度清零
        if (i + 2) % 2 == 0:
            opt.step()
            opt.zero_grad()

        # 将损失加上迄今为止的总训练损失,同样累加正确预测的数量
        totalTrainLoss += loss
        trainCorrect += (pred.argmax(1) == y).type(
            torch.float).sum().item()

        # 关闭autograd并将模型置于评估模式中——这是使用PyTorch进行评估时的要求
        # switch off autograd
        with torch.no_grad():
            # 设置模型为评估模式
            model.eval()

            # 在valLoader中循环所有数据点,对它们进行预测,并计算总损失和正确验证预测的数量。
            # 遍历验证集
            for (x, y) in valLoader:
                # 传递输入到设备
                (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

                # 预测并计算验证损失
                pred = model(x)
                totalValLoss += lossFunc(pred, y)

                # 计算正确预测的数量
                valCorrect += (pred.argmax(1) == y).type(
                    torch.float).sum().item()

        # 以下代码块汇总训练/验证损失和准确性,更新训练历史记录,然后将损失/准确性信息打印到终端
        # 计算平均训练/验证损失
        avgTrainLoss = totalTrainLoss / trainSteps
        avgValLoss = totalValLoss / valSteps

        # 计算训练/验证准确性
        trainCorrect = trainCorrect / len(trainDS)
        valCorrect = valCorrect / len(valDS)

        # 更新训练历史
        H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
        H["train_acc"].append(trainCorrect)
        H["val_loss"].append(avgValLoss.cpu().detach().numpy())
        H["val_acc"].append(valCorrect)

        # 打印模型训练、验证信息
        print("[INFO] EPOCH: {}/{}".format(e + 1, config.EPOCHS))
        print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
            avgTrainLoss, trainCorrect))
        print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(
            avgValLoss, valCorrect))

# 绘制训练历史,序列化模型到磁盘
# 展示训练模型的总耗时
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

# 绘制训练/验证损失和准确性图
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.plot(H["train_acc"], label="train_acc")
plt.plot(H["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(config.WARMUP_PLOT)

# 序列化模型到磁盘
torch.save(model, config.WARMUP_MODEL)

fine_tune.py

# ① 通过特征提取执行迁移学习
# ② 通过微调执行迁移学习
#
# ①在某些情况下效果很好,但其简单性也有缺点,即模型的准确性和泛化能力都会受到影响。大多数形式的迁移学习都采用②微调。

# 与特征提取类似,首先从网络中移除FC层头,但这次创建了一个全新的层头,其中包含一组线性、ReLU和丢弃层,类似于您在现代最先进的CNN上看到的内容。
# 然后执行以下组合:
# 冻结网络主体中的所有层并训练层头
# 冻结所有层,训练层头,然后解冻身体并训练
# 只需将所有图层解冻并一起训练即可
# 确切地说,你使用哪种方法是你自己进行的实验——一定要测量哪种方法的损失最小,准确度最高!

# 通过PyTorch的迁移学习应用微调
# 由于模型更为复杂(由于在网络主体中添加了新的FC层头),现在训练需要大约6.5分钟。然而在图4中获得了比简单特征提取方法更高的精度(分别为90.83%/90.19%和84.26%/87.74%)
# 虽然执行微调确实需要更多的工作,但通常会发现精度更高,模型会更好地推广。
# USAGE
# python fine_tune.py

# import the necessary packages
from pyimagesearch import config
from pyimagesearch import create_dataloaders
from imutils import paths
from torchvision.models import resnet50
from torchvision import transforms
from tqdm import tqdm
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import shutil
import torch
import time
import os

# 定义训练和验证转换,和对特征提取所做的相同
# 定义增强管道(使用Compose函数构建数据处理/扩充步骤,该函数位于PyTorch的transforms子模块中。
# 首先创建一个trainTransform,在给定输入图像的情况下,它将:
# 随机调整图像大小并将其裁剪为image_SIZE尺寸
# 随机执行水平翻转
# 在[-90,90]范围内随机执行旋转
# 将生成的图像转换为PyTorch张量
# 执行平均值减法和缩放,同样的用于验证数据集的 valTransform
# 请注意,我们不在验证转换器中执行数据扩充——没有必要对验证数据执行数据扩充。
# 创建了训练和验证Compose对象后,让我们应用get_datalader函数:)
trainTansform = transforms.Compose([
    transforms.RandomResizedCrop(config.IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])
valTransform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])

# 创建DataLoader
(trainDS, trainLoader) = create_dataloaders.get_dataloader(config.TRAIN,
                                                           transforms=trainTansform,
                                                           batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE)
(valDS, valLoader) = create_dataloaders.get_dataloader(config.VAL,
                                                       transforms=valTransform,
                                                       batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)


# 加载ResNet模型,其中包含在ImageNet数据集上预先训练的权重。在这个微调中将构建一个新的FC层头,然后同时训练FC层头和网络主体。
# 然而首先需要密切关注网络架构中的批处理规范化层。这些层具有特定的平均值和标准偏差值,这些值是在最初在ImageNet数据集上训练网络时获得的。
# 不想在训练期间更新这些统计数据,冻结了BatchNorm2d的任何实例。构建新的headModel,它由一系列FC=>RELU=>DROPOUT层组成,最后一个线性层的输出是数据集中的类的数量,最后将新的headModel添加到网络中,从而替换旧的FC层头。
# 加载预训练的ImageNet ResNet50 model
# 真正的变化来自于从磁盘加载ResNet并修改体系结构本身
model = resnet50(pretrained=True)
numFeatures = model.fc.in_features

# 遍历模型的模块,设置批量归一化为非训练状态
for module, param in zip(model.modules(), model.parameters()):
    if isinstance(module, nn.BatchNorm2d):
        param.requires_grad = False

# 定义网络头,添加到模型
headModel = nn.Sequential(
    nn.Linear(numFeatures, 512),
    nn.ReLU(),
    nn.Dropout(0.25),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, len(trainDS.classes))
)
model.fc = headModel

# 将一个新的分类顶部附加到微调模型并连接到当前设备
model = model.to(config.DEVICE)

# 初始化损失函数和优化方法(注意只是向优化器提供分类顶部的参数)
lossFunc = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=config.LR)

# 计算训练/验证集的每一个纪元步数
trainSteps = len(trainDS) // config.FINETUNE_BATCH_SIZE
valSteps = len(valDS) // config.FINETUNE_BATCH_SIZE

# 初始化字典以存储训练历史
H = {"train_loss": [], "train_acc": [], "val_loss": [],
     "val_acc": []}

# 遍历纪元
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.EPOCHS)):
    # 设置模型为训练模式
    model.train()

    # 初始化总训练和验证损失
    totalTrainLoss = 0
    totalValLoss = 0

    # 初始化训练/验证的正确预测数
    trainCorrect = 0
    valCorrect = 0

    # 遍历训练集
    for (i, (x, y)) in enumerate(trainLoader):
        # 将输入图像及标签传递给设备
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

        # 向前传递并计算训练损失
        pred = model(x)
        loss = lossFunc(pred, y)

        # 计算梯度
        loss.backward()

        # 检查是否正在更新模型参数,如果是 更新它们,并将之前累积的梯度清零
        if (i + 2) % 2 == 0:
            opt.step()
            opt.zero_grad()

        # 将损失加上迄今为止的总训练损失,同样累加正确预测的数量
        totalTrainLoss += loss
        trainCorrect += (pred.argmax(1) == y).type(
            torch.float).sum().item()

        # 关闭autograd并将模型置于评估模式中——这是使用PyTorch进行评估时的要求
        # switch off autograd
        with torch.no_grad():
            # 设置模型为评估模式
            model.eval()

            # 在valLoader中循环所有数据点,对它们进行预测,并计算总损失和正确验证预测的数量。
            # 遍历验证集
            for (x, y) in valLoader:
                # 把输入传递到模型
                (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

                # 预测及计算验证损失
                pred = model(x)
                totalValLoss += lossFunc(pred, y)

                # 计算正确预测数
                valCorrect += (pred.argmax(1) == y).type(
                    torch.float).sum().item()

        # 计算训练和验证平均损失
        avgTrainLoss = totalTrainLoss / trainSteps
        avgValLoss = totalValLoss / valSteps

        # 计算训练和验证精度
        trainCorrect = trainCorrect / len(trainDS)
        valCorrect = valCorrect / len(valDS)

        # 更新训练历史值
        H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
        H["train_acc"].append(trainCorrect)
        H["val_loss"].append(avgValLoss.cpu().detach().numpy())
        H["val_acc"].append(valCorrect)

        # 打印训练和验证信息
        print("[INFO] EPOCH: {}/{}".format(e + 1, config.EPOCHS))
        print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
            avgTrainLoss, trainCorrect))
        print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(
            avgValLoss, valCorrect))

# 打印训练的最终耗时
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

# 绘制训练和损失精确度图
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.plot(H["train_acc"], label="train_acc")
plt.plot(H["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(config.FINETUNE_PLOT)

# 序列化模型到磁盘
torch.save(model, config.FINETUNE_MODEL)

inference.py

# 使用PyTorch应用迁移学习的两种方法:特征提取、微调,这两种方法都使模型获得了80-90%的准确率

# 微调迁移学习能获得更好的结果

# USAGE
# python inference.py --model output/warmup_model.pth
# python inference.py --model output/finetune_model.pth

import argparse  # 解析命令行参数

import matplotlib.pyplot as plt  # 绘制输出图像及预测结果
import torch  # PyTorch绑定函数及方法
from torchvision import transforms  # 通过顺序的方式执行一系列数据预处理

# 导入必要的包
from pyimagesearch import config  # 全局配置文件
from pyimagesearch import create_dataloaders  # 帮助函数以根据图像目录创建DataLoader对象得到dataset/val文件夹

# 构建命令行参数及解析
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=False, default="output/warmup_model.pth",
                help="path to trained model model")
args = vars(ap.parse_args())

# 创建 数据预处理管道
# 调整图像大小并将其裁剪为IMAGE_SIZE尺寸
# 将生成的图像转换为PyTorch张量
# 执行平均值缩放
testTransform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])

# 计算反均值和标准差 calculate the inverse mean and standard deviation
invMean = [-m / s for (m, s) in zip(config.MEAN, config.STD)]
invStd = [1 / s for s in config.STD]

# 定义去归一化变换 define our de-normalization transform 以展示图片到屏幕
deNormalize = transforms.Normalize(mean=invMean, std=invStd)

# 初始化测试集和DataLoader
print("[INFO] loading the dataset...")
(testDS, testLoader) = create_dataloaders.get_dataloader(config.VAL,
                                                         transforms=testTransform, batchSize=config.PRED_BATCH_SIZE,
                                                         shuffle=True)

# 检查是否有可用GPU,如果是,定义相应地图位置
if torch.cuda.is_available():
    map_location = lambda storage, loc: storage.cuda()

# 否则,使用cpu训练模型
else:
    map_location = "cpu"

# 加载模型
print("[INFO] loading the model...")
model = torch.load(args["model"], map_location=map_location)

# 设置模型为cpu/gpu训练,设置为评估模式
model.to(config.DEVICE)
model.eval()

# 获取一批测试数据集
batch = next(iter(testLoader))
(images, labels) = (batch[0], batch[1])

# 初始化图像
fig = plt.figure("Results", figsize=(10, 10))

# switch off autograd
with torch.no_grad():
    # 把图像传递到设备
    images = images.to(config.DEVICE)

    # 执行预测
    print("[INFO] performing inference...")
    preds = model(images)

    # 遍历所有批次
    for i in range(0, config.PRED_BATCH_SIZE):
        # 初始化一个子图以绘制图像和预测结果
        # ax = plt.subplot(config.PRED_BATCH_SIZE, 1, i + 1) # 4行1列
        ax = plt.subplot(config.PRED_BATCH_SIZE / 2, 2, i + 1)  # 2行2列

        # 获取图像,反归一化,缩放原始像素为[0,255] 并更改通道,从第一个通道到最后一个通道排序
        # 通过“撤销”平均缩放和交换颜色通道顺序来取消图像的标准化
        image = images[i]
        image = deNormalize(image).cpu().numpy()
        image = (image * 255).astype("uint8")
        image = image.transpose((1, 2, 0))

        # 获取正确的标签
        idx = labels[i].cpu().numpy()
        gtLabel = testDS.classes[idx]

        # 获取预测的标签
        pred = preds[i].argmax().cpu().numpy()
        predLabel = testDS.classes[pred]

        # 添加真实标签及预测标签到图像上
        info = "Ground Truth: {}, Predicted: {}".format(gtLabel,
                                                        predLabel)
        plt.imshow(image)
        plt.title(info)
        plt.axis("off")

    # 展示
    plt.tight_layout()
    plt.show()

参考

  • https://pyimagesearch.com/2021/10/11/pytorch-transfer-learning-and-image-classification/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/657189.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

U盘重装系统Win10详细步骤和方法

当前超多的用户都在使用Win10系统,有些用户想使用U盘来重装一下Win10系统,但不知道具体怎么操作,其实操作起来难度不会很大,可以按照以下小编给大家分享的U盘重装系统Win10详细步骤和方法,就能轻松顺利完成U盘重装系统…

Jetson TX2 NX的GPIO引脚使用方式

Jetson TX2 NX是一款高性能的嵌入式AI计算平台,其中引脚的设计和使用对于开发人员来说非常重要。在本文中,我们将会介绍Jetson TX2 NX的引脚并说明其功能和使用方式。 官方文档官方文档 引脚概述 Jetson TX2 NX具有许多不同类型的引脚,包…

C++ 类的构造函数和析构函数

目录 类的构造函数和析构函数构造函数声明构造函数定义构造函数使用构造函数默认构造函数 析构函数析构函数的声明析构函数的定义 改进Stock类(加入构造函数和析构函数) 类的构造函数和析构函数 构造函数 常规的初始化语法不适用类的初始化 例如: int a 10;//整…

Deepin20.9 安装Mysql

文章目录 mysql下载查看 mysql 状态卸载卸载mysql:清理残留数据检查是否删除完毕 mysql Deepin 安装 下载 从网上下载 https://dev.mysql.com/get/mysql-apt-config_0.8.23-1_all.deb 安装 mysql-apt-config 下载文件名: mysql-apt-config_0.8.23-1_all.deb …

PoseiSwap IDO 即将开启,一览 $POSE 经济模型

以太坊创始人 Vitalik Buterin 曾在今年以太坊黑山大会上,进行了以“以太坊的三个技术挑战:扩容、隐私和用户安全”为主题的演讲,阐明了具有隐私性、可扩展性和安全性的且易访问的区块链生态将是行业发展趋势,或许重复造轮子正在变…

【探索 Kubernetes|作业管理篇 系列 10】Pod 健康检查和恢复机制

前言 大家好,我是秋意零。 上一篇中介绍了,Pod 的服务对象,从而对 Pod 有了更深的理解; 今天的主题是 Pod 健康检查和恢复机制,我们将结束 Pod 的内容。 最近搞了一个扣扣群,旨在技术交流、博客互助&am…

图像中提取文本

将从此图像中提取文本。我使用得是 PyCharm,您随意编辑器或IDE 1、下载所需得库和exe文件: tesseract-ocr 可执行exe文件下载后,安装时无需指定安装目录。 http://jaist.dl.sourceforge.net/project/tesseract-ocr-alt/tesseract-ocr-setup-3…

代码随想录二刷day25 | 回溯 之 216.组合总和III 17.电话号码的字母组合

216.组合总和III 题目链接 解题思路: 选取过程如图: 图中,可以看出,只有最后取到集合(1,3)和为4 符合条件。 递归三部曲 确定递归函数参数 和77. 组合 一样,依然需要一维数组path…

走进人工智能|深度学习 算法的创世纪

前言: 深度学习通过训练深层神经网络模型,可以自动学习和提取数据的特征,包括更准确的图像识别、自然语言处理、医学诊断等方面的应用。 文章目录 序言背景算法的创世纪技术支持应用领域程序员如何学总结 序言 深度学习是一种机器学习方法&a…

easyui05(datagrid数据新增)

一.对话框&#xff1a;Dialog 加载页面 <div id"myDialog" style"display:none"></div> 二.editGoods.jsp 表单 myForm <head> <meta http-equiv"Content-Type" content"text/html; charsetUTF-8"> <tit…

2023年互联网Java面试复习大纲:ZK+Redis+MySQL+Java基础+架构

多数的公司总体上面试都是以自我介绍项目介绍项目细节/难点提问基础知识点考核算法题这个流程下来的。有些公司可能还会问几个实际的场景类的问题&#xff0c;这个环节阿里是必问的&#xff0c;这种问题通常是没有正确答案的&#xff0c;就看个人的理解&#xff0c;个人的积累了…

Vue练手项目之仿京东到家主页

目录 概述1.效果展示2.使用原始HtmlCSS实现3.使用Vue.js进行组件化3.1 Header部分组件实现3.1.1图标的展示3.1.2 定义Vue调试的名称3.1.3 使用scoped隔离组件间的css影响 3.2 附近店铺部分实现3.3 底部导航栏组件的实现3.4 将组件组成一个整体页面 4.代码地址 概述 本人是一个…

【微信小程序开发】第 9 课 - 小程序的协同工作和发布

欢迎来到博主 Apeiron 的博客&#xff0c;祝您旅程愉快 &#xff01; 时止则止&#xff0c;时行则行。动静不失其时&#xff0c;其道光明。 目录 1、协同工作 1.1、了解权限管理需求 1.2、了解项目成员的组织结构 1.3、小程序的开发流程 2、小程序成员管理 2.1、成员管…

【Unity Shader】Special Effects(八)Wireframe 线框化(UI)

更新日期:2023年6月17日。 Github源码:[点我获取源码] 索引 Wireframe 线框化思路分析Sobel算子片元输入数据结构-定义片元输入数据结构-填充片元输入数据结构-传入属性定义求梯度值方法求边缘方法范围控制线框化渐变动画Wireframe 线框化 线框化效果可以将一张图像根据纹理…

从618「技术暗战」,看乡村振兴的未来「赛点」

作者 | 曾响铃 文 | 响铃说 作为消费复苏后的首个消费节点&#xff0c;从“史上消费者福利最大的618”“史上投入最大的一届618”等口号&#xff0c;都能感觉到这届618的火药味比以往要浓得多。 有业内人士透露&#xff0c;这次的年中大促无论从商品种类、数量还是提供的服务…

【自动化测试】是否有必要做自动化测试?

‍目录 一、前言 二、自动化目的 三、自动化分类 四、自动化实现 一、前言 在一些测试交流群经常会看到有小伙伴在问&#xff0c;"怎么做自动化测试&#xff1f;学习自动化测试有什么资料吗&#xff1f;自动化测试是不是很牛逼&#xff1f;" &#xff0c;甚至有…

Python之面向对象和继承

一、关于None和判断的总结 1.1、None是什么&#xff1f; 与C和JAVA不同&#xff0c;python中是没有NULL的&#xff0c;取而代之的是None。None是一个特殊的常量&#xff0c;表示变量没有指向任何对象。在Python中&#xff0c;None本身实际上也是对象&#xff0c;有自己的类型N…

浅谈自动化测试框架开发

在自动化测试项目中&#xff0c;为了实现更多功能&#xff0c;我们需要引入不同的库、框架。 首先&#xff0c;你需要将常用的这些库、框架都装上。 pip install requests pip install selenium pip install appium pip install pytest pip install pytest-rerunfailures pip …

【深度学习】基于pytorch的FER2013人脸表情图像识别(ResNet/VGG/DenseNet)

题目要求 1.1. 任务要求 数据集&#xff1a;Facial Expression Recognition Challenge&#xff0c;共有7类&#xff1a;生气、恶心、害怕、快乐、悲伤、惊讶、中性。 基本要求&#xff08;50%&#xff09;&#xff1a;构建ResNet分类模型18层。 改进&#xff08;30%&#x…

Disruptor(1):Disruptor简介

1 什么是Disruptor Martin Fowler在自己网站上写了一篇LMAX架构的文章&#xff0c;在文章中他介绍了LMAX是一种新型零售金融交易平台&#xff0c;它能够以很低的延迟产生大量交易。这个系统是建立在JVM平台上&#xff0c;其核心是一个业务逻辑处理器&#xff0c;它能够在一个线…