Colab/PyTorch - 003 Transfer Learning For Image Classification

news2024/11/17 16:00:46

Colab/PyTorch - 003 Transfer Learning For Image Classification

  • 1. 源由
  • 2. 迁移学习(ResNet50)
    • 2.1 数据集准备
    • 2.2 数据增强
    • 2.3 数据加载
    • 2.4 迁移学习
    • 2.5 数据集训练&验证
    • 2.6 模型推理
  • 3. 总结
  • 4. 参考资料

1. 源由

迁移学习已经彻底改变了 PyTorch 中处理图像分类的方式。

最近,PyTorch 因其易用性和学习性而受到了广泛关注。特斯拉人工智能高级总监安德烈·卡帕西在他的推特中说了以下的话。
关于 PyTorch 的安德烈·卡帕西 - 转移学习
PyTorch 非常透明,可以帮助研究人员和数据科学家实现高生产力和可靠的结果。

其实,这种方法,有点类似螺旋式上升的旋转阶梯。大量的学习让我们积累知识螺旋式上升到某个层面。当这个层面遇到新的类似问题时,只要经过适当的调整,就能更上一层楼。

而这个适当的调整并不需要大量的推倒重来,更多的是在原有的基础上进行适配调整。

在《Colab/PyTorch - 002 Pre Trained Models for Image Classification》中,放入一个没有训练过的图片,可以看到机器学习给出了一个相近的分类,俗语说“难字读半边”,大体就是这个意思。

在《Jammy@Jetson Orin - Tensorflow & Keras Get Started: Transfer Learning & Fine Tuning》,也讨论过关于Transfer Learning的概念。

2. 迁移学习(ResNet50)

基于ImageNet用数百万张图像进行了训练的ResNet50模型,对CalTech256数据集的子集来对10种动物的图像进行分类就是本次PyTorch的一个例子。

通过据集准备、数据增强、构建分类器、使用迁移学习来利用低级图像特征,比如边缘、纹理等模型参数特征(这些特征是预训练模型ResNet50继承得来的)。最后,通过训练集来进一步训练分类器,从而学习数据集图像中的更高级别的细节,比如眼睛、腿等。

2.1 数据集准备

CalTech256数据集包含30,607张图像,分为256个不同的标签类别,还有一个“杂乱”类别。对整个数据集进行训练需要数小时。

因此,我们将使用数据集的一个子集,其中包含10种动物:熊、黑猩猩、长颈鹿、大猩猩、美洲驼、鸵鸟、豪猪、臭鼬、三角龙和斑马。这些文件夹中的图像数量从81张(对于臭鼬)到212张(对于大猩猩)不等。

我们对这些图片进行如下分类:

  • 每个类别中的前60张图像进行训练
  • 接下来的10张图像用于验证
  • 其余的图像用于下面实验中的测试

因此,最终,我们有600张训练图像、100张验证图像、409张测试图像和10个动物类别。

操作步骤如下:

  1. 下载 CalTech256数据集
  2. 创建名为 train、valid 和 test 的三个目录。
  3. 在 train/valid/test 目录中分别创建 10 个子目录。这些子目录应该命名为 bear、chimp、giraffe、gorilla、llama、ostrich、porcupine、skunk、triceratops 和 zebra。
  4. 将 Caltech256 数据集中的前60张熊的图像移动到目录 train/bear。对每种动物重复此步骤。
  5. 将 Caltech256 数据集中的下一组10张熊的图像移动到目录 valid/bear。对每种动物重复此步骤。
  6. 将熊的剩余图像(即未包含在 train 或 valid 文件夹中的图像)复制到目录 test/bear。对每种动物重复此步骤。

2.2 数据增强

在每个 epoch 中,每个输入图像会经过多次变换,通过在变换中引入一些随机性来插入一些变化。

当训练多个 epoch 时,模型会看到输入图像都是新的随机变换。这导致了数据增强,然后模型试图更加泛化。

下面看到了一张三角龙图像的变换版本的示例:
在这里插入图片描述

# Applying Transforms to the Data
image_transforms = { 
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
}
  • transform.RandomResizedCrop 通过随机大小裁剪输入图像(在原始大小的0.8到1.0的范围内,并在默认范围内的0.75到1.33的随机宽高比)。然后将裁剪后的图像调整大小为256×256。
  • transform.RandomRotation 将图像随机旋转一个角度,角度范围为-15度到15度。
  • transform.RandomHorizontalFlip 以默认概率50%随机水平翻转图像。
  • transform.CenterCrop 从中心裁剪出一个224×224的图像。
  • transform.ToTensor 将 PIL 图像转换为浮点数张量,并将值归一化到0-1的范围,方法是将其除以255。
  • transform.Normalize 接受一个3通道张量,并按照该通道的输入均值和标准差对每个通道进行归一化。均值和标准差向量作为3元素向量输入。张量中的每个通道被规范化为 T = (T - 均值) / 标准差。

所有上述转换都使用train的transforms.Compose连接在一起。

对于验证和测试数据,我们不执行RandomResizedCrop、RandomRotation和RandomHorizontalFlip转换。相反,我们将验证图像调整大小为256×256,并裁剪出中心224×224的部分,以便能够将它们与预训练模型一起使用。最后,图像被转换为张量,并按照ImageNet中所有图像的均值和标准差进行归一化处理。

2.3 数据加载

Colab上运行,需要将制作好的数据集上传Google云存储。
在这里插入图片描述

数据存放目录如下所示:

├── caltech_10
│   ├── test
│   │   ├── bear
│   │   ├── chimp
│   │   ├── giraffe
│   │   ├── gorilla
│   │   ├── llama
│   │   ├── ostrich
│   │   ├── porcupine
│   │   ├── skunk
│   │   ├── triceratops
│   │   └── zebra
│   ├── train
│   │   ├── bear
│   │   ├── chimp
│   │   ├── giraffe
│   │   ├── gorilla
│   │   ├── llama
│   │   ├── ostrich
│   │   ├── porcupine
│   │   ├── skunk
│   │   ├── triceratops
│   │   └── zebra
│   └── valid
│       ├── bear
│       ├── chimp
│       ├── giraffe
│       ├── gorilla
│       ├── llama
│       ├── ostrich
│       ├── porcupine
│       ├── skunk
│       ├── triceratops
│       └── zebra
└── image_classification_using_transfer_learning_in_pytorch.ipynb

使用DataLoader加载用于训练的数据:

# Test on Google Drive

from google.colab import drive
drive.mount('/content/drive')

# Load the Data

# Set train and valid directory paths

dataset = '/content/drive/MyDrive/caltech_10'

train_directory = os.path.join(dataset, 'train')
valid_directory = os.path.join(dataset, 'valid')
test_directory = os.path.join(dataset, 'test')

# Batch size
bs = 32

# Number of classes
num_classes = len(os.listdir(valid_directory))  #10#2#257
print(num_classes)

# Load Data from folders
data = {
    'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),
    'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid']),
    'test': datasets.ImageFolder(root=test_directory, transform=image_transforms['test'])
}

# Get a mapping of the indices to the class names, in order to see the output classes of the test images.
idx_to_class = {v: k for k, v in data['train'].class_to_idx.items()}
print(idx_to_class)

# Size of Data, to be used for calculating Average Loss and Accuracy
train_data_size = len(data['train'])
valid_data_size = len(data['valid'])
test_data_size = len(data['test'])

# Create iterators for the Data loaded using DataLoader module
train_data_loader = DataLoader(data['train'], batch_size=bs, shuffle=True)
valid_data_loader = DataLoader(data['valid'], batch_size=bs, shuffle=True)
test_data_loader = DataLoader(data['test'], batch_size=bs, shuffle=True)

在这里插入图片描述

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(train_data_size, valid_data_size, test_data_size)

在这里插入图片描述

2.4 迁移学习

从头开始训练分类器是非常困难且耗时的。

因此,使用一个预训练的模型作为基础,并改变最后几层以根据我们想要的类别对图像进行分类。这有助于我们在一个小数据集上获得良好的结果,因为预训练模型已经从像ImageNet这样的更大数据集中学习了基本的图像特征。

在这里插入图片描述
正如上面图像中所看到的,内部层保持与预训练ResNet50模型相同,只有最后几层被更改以适应我们的类别数量。

# Load pretrained ResNet50 Model
resnet50 = models.resnet50(pretrained=True)
resnet50 = resnet50.to(device)

Canziani等人列出了许多预训练模型,用于各种实际应用,分析了每个模型的准确性和推断所需的时间。

ResNet50是一个在准确性和推断时间之间有良好折衷的模型之一。当在PyTorch中加载模型时,默认情况下,所有参数的’requires_grad’字段都设置为true。这意味着对参数值的每一次更改都将被存储以用于训练时使用的反向传播图中。这增加了内存需求。由于我们预训练模型中的大多数参数已经训练过,我们将’requires_grad’字段重置为false。

# Freeze model parameters
for param in resnet50.parameters():
    param.requires_grad = False

接下来,我们将ResNet50模型的最后一层替换为一小组Sequential层。ResNet50的最后一个全连接层的输入被馈送到一个线性层。它有256个输出,然后被馈送到ReLU和Dropout层。然后是一个256×10的线性层,它有10个输出,对应于我们CalTech子集中的10个类别。

# Change the final layer of ResNet50 Model for Transfer Learning
fc_inputs = resnet50.fc.in_features

resnet50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, num_classes), # Since 10 possible outputs
    nn.LogSoftmax(dim=1) # For using NLLLoss()
)

由于我们将在GPU上进行训练,我们准备好将模型移到GPU上。

# Convert model to be used on GPU
resnet50 = resnet50.to(device)

接下来,我们定义用于训练的损失函数和优化器。PyTorch提供了多种损失函数。我们使用负对数似然损失函数,因为它对于多类别分类很有用。PyTorch还支持多种优化器。我们使用Adam优化器。Adam是最流行的优化器之一,因为它可以针对每个参数单独调整学习率。

# Define Optimizer and Loss Function
loss_func = nn.NLLLoss()
optimizer = optim.Adam(resnet50.parameters())

2.5 数据集训练&验证

训练是在固定的epoch数量内进行的,每个图像在单个epoch中处理一次。训练数据加载器以批量方式加载数据。我们给定了一个批量大小为32。这意味着每个批次最多可以有32张图像。

对于每个批次,输入图像被传递到模型中,即前向传播,以获取输出。然后使用提供的损失函数或成本函数使用真实值和计算得到的输出来计算损失。使用反向函数计算损失相对于可训练参数的梯度。

注:使用迁移学习时,只需要计算属于模型末尾的少数新添加层的一小部分参数的梯度。对模型的摘要函数调用可以显示实际参数数量和可训练参数的数量。这种方法中的优势是只需要训练总模型参数的大约十分之一。

梯度计算是使用自动求导和反向传播完成的,在图中使用链式法则进行微分。PyTorch在反向传播过程中累积所有梯度。因此,在训练循环的开始时将它们清零是至关重要的。这可以通过优化器的zero_grad函数实现。最后,在反向传播过程中计算了梯度后,使用优化器的step函数更新参数。

对整个批次计算总损失和准确度,然后对所有批次进行平均,以获得整个epoch的损失和准确度值。

随着训练进行的增加,模型倾向于过度拟合数据,导致其在新的测试数据上表现不佳。保持一个单独的验证集很重要,这样我们就可以在合适的时机停止训练,并防止过拟合。在每个epoch的训练循环之后立即进行验证。由于在验证过程中我们不需要进行任何梯度计算,所以它是在一个torch.no_grad()块内完成的。

对于每个验证批次,输入和标签被传输到GPU(如果cuda可用,否则它们被传输到CPU)。输入经过前向传播,然后对批次和循环结束时的整个epoch进行损失和准确度计算。

def train_and_validate(model, loss_criterion, optimizer, epochs=25):
    '''
    Function to train and validate
    Parameters
        :param model: Model to train and validate
        :param loss_criterion: Loss Criterion to minimize
        :param optimizer: Optimizer for computing gradients
        :param epochs: Number of epochs (default=25)
  
    Returns
        model: Trained Model with best validation accuracy
        history: (dict object): Having training loss, accuracy and validation loss, accuracy
    '''
    
    start = time.time()
    history = []
    best_loss = 100000.0
    best_epoch = None

    for epoch in range(epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch+1, epochs))
        
        # Set to training mode
        model.train()
        
        # Loss and Accuracy within the epoch
        train_loss = 0.0
        train_acc = 0.0
        
        valid_loss = 0.0
        valid_acc = 0.0
        
        for i, (inputs, labels) in enumerate(train_data_loader):

            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Clean existing gradients
            optimizer.zero_grad()
            
            # Forward pass - compute outputs on input data using the model
            outputs = model(inputs)
            
            # Compute loss
            loss = loss_criterion(outputs, labels)
            
            # Backpropagate the gradients
            loss.backward()
            
            # Update the parameters
            optimizer.step()
            
            # Compute the total loss for the batch and add it to train_loss
            train_loss += loss.item() * inputs.size(0)
            
            # Compute the accuracy
            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))
            
            # Convert correct_counts to float and then compute the mean
            acc = torch.mean(correct_counts.type(torch.FloatTensor))
            
            # Compute total accuracy in the whole batch and add to train_acc
            train_acc += acc.item() * inputs.size(0)
            
            #print("Batch number: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}".format(i, loss.item(), acc.item()))

        
        # Validation - No gradient tracking needed
        with torch.no_grad():

            # Set to evaluation mode
            model.eval()

            # Validation loop
            for j, (inputs, labels) in enumerate(valid_data_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass - compute outputs on input data using the model
                outputs = model(inputs)

                # Compute loss
                loss = loss_criterion(outputs, labels)

                # Compute the total loss for the batch and add it to valid_loss
                valid_loss += loss.item() * inputs.size(0)

                # Calculate validation accuracy
                ret, predictions = torch.max(outputs.data, 1)
                correct_counts = predictions.eq(labels.data.view_as(predictions))

                # Convert correct_counts to float and then compute the mean
                acc = torch.mean(correct_counts.type(torch.FloatTensor))

                # Compute total accuracy in the whole batch and add to valid_acc
                valid_acc += acc.item() * inputs.size(0)

                #print("Validation Batch number: {:03d}, Validation: Loss: {:.4f}, Accuracy: {:.4f}".format(j, loss.item(), acc.item()))
        if valid_loss < best_loss:
            best_loss = valid_loss
            best_epoch = epoch

        # Find average training loss and training accuracy
        avg_train_loss = train_loss/train_data_size 
        avg_train_acc = train_acc/train_data_size

        # Find average training loss and training accuracy
        avg_valid_loss = valid_loss/valid_data_size 
        avg_valid_acc = valid_acc/valid_data_size

        history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])
                
        epoch_end = time.time()
    
        print("Epoch : {:03d}, Training: Loss - {:.4f}, Accuracy - {:.4f}%, \n\t\tValidation : Loss - {:.4f}, Accuracy - {:.4f}%, Time: {:.4f}s".format(epoch, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))
        
        # Save if the model has best accuracy till now
        torch.save(model, dataset+'_model_'+str(epoch)+'.pt')
            
    return model, history, best_epoch
Epoch: 1/30
Epoch : 000, Training: Loss - 1.6488, Accuracy - 46.3333%, 
		Validation : Loss - 0.6532, Accuracy - 94.0000%, Time: 69.4381s
Epoch: 2/30
Epoch : 001, Training: Loss - 0.5886, Accuracy - 87.1667%, 
		Validation : Loss - 0.3524, Accuracy - 89.0000%, Time: 8.7085s
Epoch: 3/30
Epoch : 002, Training: Loss - 0.3438, Accuracy - 92.1667%, 
		Validation : Loss - 0.2165, Accuracy - 97.0000%, Time: 9.7899s
Epoch: 4/30
Epoch : 003, Training: Loss - 0.2306, Accuracy - 94.8333%, 
		Validation : Loss - 0.2157, Accuracy - 93.0000%, Time: 8.2054s
Epoch: 5/30
Epoch : 004, Training: Loss - 0.1882, Accuracy - 95.5000%, 
		Validation : Loss - 0.1668, Accuracy - 96.0000%, Time: 9.5612s
Epoch: 6/30
Epoch : 005, Training: Loss - 0.1682, Accuracy - 95.3333%, 
		Validation : Loss - 0.1474, Accuracy - 97.0000%, Time: 8.0876s
Epoch: 7/30
Epoch : 006, Training: Loss - 0.1850, Accuracy - 94.8333%, 
		Validation : Loss - 0.1665, Accuracy - 96.0000%, Time: 8.8183s
Epoch: 8/30
Epoch : 007, Training: Loss - 0.1645, Accuracy - 96.5000%, 
		Validation : Loss - 0.1560, Accuracy - 94.0000%, Time: 8.5604s
Epoch: 9/30
Epoch : 008, Training: Loss - 0.1490, Accuracy - 94.3333%, 
		Validation : Loss - 0.1529, Accuracy - 95.0000%, Time: 8.7029s
Epoch: 10/30
Epoch : 009, Training: Loss - 0.1236, Accuracy - 95.8333%, 
		Validation : Loss - 0.1578, Accuracy - 96.0000%, Time: 8.4867s
Epoch: 11/30
Epoch : 010, Training: Loss - 0.0958, Accuracy - 97.8333%, 
		Validation : Loss - 0.2119, Accuracy - 92.0000%, Time: 9.2300s
Epoch: 12/30
Epoch : 011, Training: Loss - 0.1213, Accuracy - 96.5000%, 
		Validation : Loss - 0.1623, Accuracy - 95.0000%, Time: 8.8091s
Epoch: 13/30
Epoch : 012, Training: Loss - 0.1710, Accuracy - 94.0000%, 
		Validation : Loss - 0.2524, Accuracy - 89.0000%, Time: 9.2230s
Epoch: 14/30
Epoch : 013, Training: Loss - 0.1252, Accuracy - 95.6667%, 
		Validation : Loss - 0.1514, Accuracy - 96.0000%, Time: 8.6822s
Epoch: 15/30
Epoch : 014, Training: Loss - 0.0881, Accuracy - 97.6667%, 
		Validation : Loss - 0.1434, Accuracy - 95.0000%, Time: 8.8090s
Epoch: 16/30
Epoch : 015, Training: Loss - 0.0828, Accuracy - 96.8333%, 
		Validation : Loss - 0.1271, Accuracy - 96.0000%, Time: 8.4651s
Epoch: 17/30
Epoch : 016, Training: Loss - 0.0705, Accuracy - 97.8333%, 
		Validation : Loss - 0.1536, Accuracy - 96.0000%, Time: 9.4177s
Epoch: 18/30
Epoch : 017, Training: Loss - 0.0834, Accuracy - 97.6667%, 
		Validation : Loss - 0.1726, Accuracy - 94.0000%, Time: 8.0583s
Epoch: 19/30
Epoch : 018, Training: Loss - 0.0740, Accuracy - 98.1667%, 
		Validation : Loss - 0.1585, Accuracy - 95.0000%, Time: 9.0694s
Epoch: 20/30
Epoch : 019, Training: Loss - 0.0659, Accuracy - 98.3333%, 
		Validation : Loss - 0.2178, Accuracy - 93.0000%, Time: 8.6098s
Epoch: 21/30
Epoch : 020, Training: Loss - 0.0819, Accuracy - 97.8333%, 
		Validation : Loss - 0.1866, Accuracy - 95.0000%, Time: 8.9363s
Epoch: 22/30
Epoch : 021, Training: Loss - 0.0921, Accuracy - 96.5000%, 
		Validation : Loss - 0.1907, Accuracy - 95.0000%, Time: 8.2608s
Epoch: 23/30
Epoch : 022, Training: Loss - 0.0662, Accuracy - 98.0000%, 
		Validation : Loss - 0.1494, Accuracy - 95.0000%, Time: 11.4841s
Epoch: 24/30
Epoch : 023, Training: Loss - 0.0618, Accuracy - 98.0000%, 
		Validation : Loss - 0.1616, Accuracy - 94.0000%, Time: 9.5806s
Epoch: 25/30
Epoch : 024, Training: Loss - 0.0767, Accuracy - 97.0000%, 
		Validation : Loss - 0.1904, Accuracy - 94.0000%, Time: 8.1954s
Epoch: 26/30
Epoch : 025, Training: Loss - 0.0476, Accuracy - 98.8333%, 
		Validation : Loss - 0.1830, Accuracy - 94.0000%, Time: 9.1184s
Epoch: 27/30
Epoch : 026, Training: Loss - 0.0575, Accuracy - 98.3333%, 
		Validation : Loss - 0.2433, Accuracy - 94.0000%, Time: 8.7180s
Epoch: 28/30
Epoch : 027, Training: Loss - 0.0831, Accuracy - 97.1667%, 
		Validation : Loss - 0.2413, Accuracy - 93.0000%, Time: 8.5477s
Epoch: 29/30
Epoch : 028, Training: Loss - 0.0649, Accuracy - 98.0000%, 
		Validation : Loss - 0.2413, Accuracy - 93.0000%, Time: 9.0394s
Epoch: 30/30
Epoch : 029, Training: Loss - 0.0647, Accuracy - 97.5000%, 
		Validation : Loss - 0.1586, Accuracy - 94.0000%, Time: 9.5677s

在这里插入图片描述在这里插入图片描述
如上图所示,对于这个数据集,验证和训练损失都相当快地稳定下来。准确度也迅速提高到了0.9左右的范围。随着epoch数量的增加,训练损失进一步减小,导致过拟合,但验证结果并没有显著改善。因此,我们选择了具有更高准确度和较低损失的epoch的模型。最好是在早期停止以防止过拟合训练数据。在我们的情况下,我们选择了第8个epoch,其验证准确度为96%。

2.6 模型推理

一旦我们有了模型,我们就可以对单个测试图像或整个测试数据集进行推断,以获取测试准确度。测试集准确度的计算与验证代码类似,只是它是在测试数据集上进行的。我们已经在Python笔记本中包含了computeTestSetAccuracy函数来完成相同的工作。让我们在下面讨论如何为给定的测试图像找到输出类别。

首先,输入图像经历用于验证/测试数据的所有转换。然后将结果张量转换为四维张量,并通过模型传递,该模型输出不同类别的对数概率。对模型输出的指数给出了类别概率。然后我们选择具有最高概率的类作为我们的输出类别。选择具有最高概率的类作为我们的输出类别。

def predict(model, test_image_name):
    '''
    Function to predict the class of a single test image
    Parameters
        :param model: Model to test
        :param test_image_name: Test image

    '''
    
    transform = image_transforms['test']


    test_image = Image.open(test_image_name)
    plt.imshow(test_image)
    
    test_image_tensor = transform(test_image)
    if torch.cuda.is_available():
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224).cuda()
    else:
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224)
    
    with torch.no_grad():
        model.eval()
        # Model outputs log probabilities
        out = model(test_image_tensor)
        ps = torch.exp(out)

        topk, topclass = ps.topk(3, dim=1)
        cls = idx_to_class[topclass.cpu().numpy()[0][0]]
        score = topk.cpu().numpy()[0][0]

        for i in range(3):
            print("Predcition", i+1, ":", idx_to_class[topclass.cpu().numpy()[0][i]], ", Score: ", topk.cpu().numpy()[0][i])
# Test a particular model on a test image
! wget https://cdn.pixabay.com/photo/2018/10/01/12/28/skunk-3716043_1280.jpg -O skunk.jpg
dataset = '/content/drive/MyDrive/caltech_10'
model = torch.load("{}_model_{}.pt".format(dataset, best_epoch))
predict(model, 'skunk.jpg')

# Load Data from folders
#computeTestSetAccuracy(model, loss_func)

在这里插入图片描述判断的准确率99.97%。

3. 总结

使用一个在ImageNet的1000个类别上预训练的模型,非常有效地应用了ResNet50模型,通过小数据集的训练,对感兴趣的10个不同类别的图像进行了分类。

测试代码:003 Image Classification using Transfer Learning in Pytorch

4. 参考资料

【1】Colab/PyTorch - Getting Started with PyTorch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1664530.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

shared_ptr 引用计数相关问题

前言 智能指针是 C11 增加的非常重要的特性&#xff0c;并且也是面试的高频考点&#xff0c;本文主要解释以下几个问题&#xff1a; 引用计数是怎么共享的、怎么解决并发问题的资源释放时&#xff0c;控制块的内存释放吗weak_ptr 怎么判断对象是否已经释放 文中源码用的是 L…

从零自制docker-12-【overlayfs】

文章目录 overlayfsexec.Command("tar", "-xvf", busyboxTarURL, "-C", busyboxURL).CombinedOutput()exec.Command格式差异 挂载mount卸载unmount代码地址结果演示 overlayfs 就是联合文件系统&#xff0c;将多个文件联合在一起成为一个统一的…

【VTKExamples::Rendering】第五期 环形阵列Rotations

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例环形阵列Rotations,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO 1. Rotations

程序环境和预处理、编译链接过程、编译的几个阶段、运行环境、预定义符号等的介绍

文章目录 前言一、程序的翻译环境和执行环境二、编译链接过程三、编译的几个阶段四、运行环境五、预定义符号总结 前言 程序环境和预处理、编译链接过程、编译的几个阶段、运行环境、预定义符号的介绍。 一、程序的翻译环境和执行环境 在 ANSI C 的任何一种实现中&#xff0c…

DDR5和LPDDR4/5 命令解析

关键名称介绍 DDR5 SDRAM和LPDDR4/5都采用了高级的命令集来支持更高效的内存管理和操作,其中“Multi-purpose command (MPC)”、“Mode Register Read (MRR)”、“Mode Register Write (MRW)”,以及“Write Pattern Command”是几种关键的命令类型,它们在内存初始化、配置和…

力扣 5-11

704. 二分查找 给定一个 n 个元素有序的&#xff08;升序&#xff09;整型数组 nums 和一个目标值 target &#xff0c;写一个函数搜索 nums 中的 target&#xff0c;如果目标值存在返回下标&#xff0c;否则返回 -1。 这道题目的前提是数组为有序数组&#xff0c;同时题目还强…

028.实现 strStr()

题意 给你两个字符串 haystack 和 needle &#xff0c;请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标&#xff08;下标从 0 开始&#xff09;。如果 needle 不是 haystack 的一部分&#xff0c;则返回 -1 。 难度 简单 示例 例 1 输入&#xff1a;hays…

Java----数组的定义和使用

1.数组的定义 在Java中&#xff0c;数组是一种相同数据类型的集合。数组在内存中是一段连续的空间。 2.数组的创建和初始化 2.1数组的创建 在Java中&#xff0c;数组创建的形式与C语言又所不同。 Java中数组创建的形式 T[] 数组名 new T[N]; 1.T表示数组存放的数据类型…

1290.二进制链表转整数

给你一个单链表的引用结点 head。链表中每个结点的值不是 0 就是 1。已知此链表是一个整数数字的二进制表示形式。 请你返回该链表所表示数字的 十进制值 。 示例 1&#xff1a; 输入&#xff1a;head [1,0,1] 输出&#xff1a;5 解释&#xff1a;二进制数 (101) 转化为十进制…

静态住宅代理 IP 的影响

在不断发展的在线业务和数字营销领域&#xff0c;保持领先地位势在必行。在业界掀起波澜的最新创新之一是静态住宅代理 IP 的利用。这些知识产权曾经是为精通技术的个人保留的利基工具&#xff0c;现在正在成为各行业企业的游戏规则改变者。 一、静态住宅代理IP到底是什么&…

LeetCode/NowCoder-链表经典算法OJ练习1

目录 说在前面 题目一&#xff1a;移除链表元素 题目二&#xff1a;反转链表 题目三&#xff1a;合并两个有序链表 题目四&#xff1a;链表的中间节点 SUMUP结尾 说在前面 dear朋友们大家好&#xff01;&#x1f496;&#x1f496;&#x1f496;数据结构的学习离不开刷题…

实现树莓派DS18B20读取温度(OneWire)

简介 使用的是树莓派3B, Go编程实现OneWire方式读取DS18B20温度。 接线 DS18B20 包含经典三线&#xff0c; VCC和GND自不必说&#xff0c; 主要的是DQ线&#xff0c; 需要接4.7K的上拉电阻&#xff0c; 即4.7K欧姆的电阻接到DQ和VCC&#xff0c; 否则树莓派识别不到DS18B20&am…

2024kali linux上安装java8

1 kali下载Java 8安装包 访问Oracle官网或其他可信的Java下载站点&#xff0c;如华为云的开源镜像站&#xff08;例如&#xff1a;https://repo.huaweicloud.com/java/jdk/8u202-b08/jdk-8u202-linux-x64.tar.gz&#xff09;。 确保下载的是与你的Kali Linux系统架构&#xf…

Covalent Network(CQT)通过 “新曙光” 计划实现重要里程碑,增强以太坊时光机,提供 30% 的年化质押收益率

Covalent Network&#xff08;CQT&#xff09;作为集成超过 280 条区块链&#xff0c;并服务于超过 2.8 亿个钱包的领先结构化数据基础设施层&#xff0c;宣布了其战略计划 “新曙光” 中的一个重要进展。随着网络升级并完成了准备工作的 75%&#xff0c;这将为即将部署的以太坊…

2024数维杯数学建模B题完整论文讲解(含每一问python代码+结果+可视化图)

大家好呀&#xff0c;从发布赛题一直到现在&#xff0c;总算完成了2024数维杯数学建模挑战赛生物质和煤共热解问题的研究完整的成品论文。 本论文可以保证原创&#xff0c;保证高质量。绝不是随便引用一大堆模型和代码复制粘贴进来完全没有应用糊弄人的垃圾半成品论文。 B题论…

基于单片机的直流电机检测与控制系统

摘要&#xff1a; 文章设计一款流电机控制系统&#xff0c;以 STC89C51 作为直流电机控制系统的主控制器&#xff0c;采用 LM293 做为驱动器实现 对直流电机的驱动&#xff0c;采用霍尔实现对直流电机速度的检测&#xff1b;本文对直流电机控制系统功能分析&#xff0c;选择确…

探索Linux:深入理解各种指令与用法

文章目录 cp指令mv指令cat指令more指令less指令head指令tail指令与时间相关的指令date指令 cal指令find指令grep指令zip/unzip指令总结 上一个Linux文章我们介绍了大部分指令&#xff0c;这节我们将继续介绍Linux的指令和用法。 cp指令 功能&#xff1a;复制文件或者目录 语法…

TMS320F280049 CLB模块--FSM(3)

功能框图 FSM有效状态机内部框图如下图所示&#xff0c;可以看到内部有S0 / S1两个状态和下一状态的跳转查找表。还有个输出查找表。 下图是FSM LUT的示意框图。FSM还可以工作在3输入或4输入的查找表模式下。对于输入&#xff0c;EXTRA_EXT_IN1/0可以替换S0/1。 寄存器 参考文…

词令蚂蚁庄园今日答案如何在微信小程序查看蚂蚁庄园今天问题的正确答案?

词令蚂蚁庄园今日答案如何在微信小程序查看蚂蚁庄园今天问题的正确答案&#xff1f; 1、打开微信&#xff0c;点击搜索框&#xff1b; 2、打开搜索页面&#xff0c;选择小程序搜索&#xff1b; 3、在搜索框&#xff0c;输入词令搜索点击进入词令微信小程序&#xff1b; 4、打开…

vivado Kintex-7 配置存储器器件

Kintex-7 配置存储器器件 下表所示闪存器件支持通过 Vivado 软件对 Kintex -7 器件执行擦除、空白检查、编程和验证等配置操作。 本附录中的表格所列赛灵思系列非易失性存储器将不断保持更新 &#xff0c; 并支持通过 Vivado 软件对其中所列非易失性存储器 进行擦除、…