深度学习每周学习总结P4(猴痘识别)

news2024/11/30 0:49:39
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

–来自百度网盘超级会员V5的分享

目录

    • 0. 总结
    • 1. 数据导入部分
    • 2. 划分数据集
    • 3. 模型构建部分
      • 3.1 模型构建
      • 3.2 公式推导
    • 4. 设置超参数
    • 5. 定义训练函数
    • 6. 定义测试函数
    • 7. 训练过程
    • 8. 结果可视化
    • 9.模型的保存及预测
      • 9.1 代码部分
      • 9.2 疑问解析

0. 总结

数据导入部分:本次数据导入没有使用torchvision自带的数据集,需要将原始数据进行处理包括数据导入,数据类型转换。

划分数据集:划定训练集测试集后,再使用torch.utils.data中的DataLoader()加载数据

模型构建部分:有两个部分一个初始化部分(init())列出了网络结构的所有层,比如卷积层池化层等。第二个部分是前向传播部分,定义了数据在各层的处理过程。

设置超参数:在这之前需要定义损失函数,学习率,以及根据学习率定义优化器(例如SGD随机梯度下降),用来在训练中更新参数,最小化损失函数。

定义训练函数:函数的传入的参数有四个,分别是设置好的DataLoader(),定义好的模型,损失函数,优化器。函数内部初始化损失准确率为0,接着开始循环,使用DataLoader()获取一个批次的数据,对这个批次的数据带入模型得到预测值,然后使用损失函数计算得到损失值。接下来就是进行反向传播以及使用优化器优化参数,梯度清零放在反向传播之前或者是使用优化器优化之后都是可以的,一般是默认放在反向传播之前。

定义测试函数:函数传入的参数相比训练函数少了优化器,只需传入设置好的DataLoader(),定义好的模型,损失函数。此外除了处理批次数据时无需再设置梯度清零、返向传播以及优化器优化参数,其余部分均和训练函数保持一致。

训练过程:定义训练次数,有几次就使用整个数据集进行几次训练,初始化四个空list分别存储每次训练及测试的准确率及损失。使用model.train()开启训练模式,调用训练函数得到准确率及损失。使用model.eval()将模型设置为评估模式,调用测试函数得到准确率及损失。接着就是将得到的训练及测试的准确率及损失存储到相应list中并合并打印出来,得到每一次整体训练后的准确率及损失。

结果可视化

模型的保存,调取及使用。在PyTorch中,通常使用 torch.save(model.state_dict(), ‘model.pth’) 保存模型的参数,使用 model.load_state_dict(torch.load(‘model.pth’)) 加载参数。

需要改进优化的地方:再保证整体流程没有问题的情况下,继续细化细节研究,比如一些函数的原理及作用,如何提升训练集准确率等问题。

1. 数据导入部分

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets

import os,PIL,pathlib

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device
device(type='cuda')
# 查看数据分类
import os,PIL,random,pathlib

data_dir = './data/mpox_recognize/'
data_dir = pathlib.Path(data_dir) # 使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象。

data_paths = list(data_dir.glob('*')) # 使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中。
# classNames = [str(path).split('\\')[-1] for path in data_paths] # 通过split()函数对data_paths中的每个文件路径执行分割操作,获得各个文件所属的类别名称
# classNames = [str(path).split('\\')[2] for path in data_paths]
classNames = [path.parts[-1] for path in data_paths] # pathlib的.parts方法会返回路径各部分的一个元组
# classNames = [path.parts[2] for path in data_paths]
classNames
['Monkeypox', 'Others']
total_datadir = './data/mpox_recognize/'

# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

total_data = datasets.ImageFolder(total_datadir,transform=train_transforms)
total_data
Dataset ImageFolder
    Number of datapoints: 2142
    Root location: ./data/mpox_recognize/
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

2. 划分数据集

# 划分数据集
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset,test_dataset = torch.utils.data.random_split(total_data,[train_size,test_size])
train_dataset,test_dataset
(<torch.utils.data.dataset.Subset at 0x24b70342d00>,
 <torch.utils.data.dataset.Subset at 0x24b70342a90>)
train_size,test_size
(1713, 429)
# 使用dataloader加载数据,并设置好基本的batch_size
batch_size = 32

train_dl = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 1 # 用于数据加载的子进程数量。通常,将其设置为大于 0 的值可以加快数据加载速度,特别是当数据集很大时。默认值为 0,表示在主进程中加载数据。
)
test_dl = torch.utils.data.DataLoader(
    test_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 1
)
for X,y in test_dl:
    print("Shape of X [N,C,H,W]: ",X.shape)
    print("Shape of y: ",y.shape,y.dtype)
    break
Shape of X [N,C,H,W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

3. 模型构建部分

3.1 模型构建

import torch.nn.functional as F

class Network_bn(nn.Module):
    def __init__(self):
        super(Network_bn,self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 3,out_channels = 12,kernel_size = 5,stride = 1,padding = 0)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels = 12,out_channels = 12,kernel_size = 5,stride = 1,padding = 0)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2,2)
        self.conv3 = nn.Conv2d(in_channels = 12,out_channels = 24,kernel_size = 5,stride = 1,padding = 0)
        self.bn3 = nn.BatchNorm2d(24)
        self.conv4 = nn.Conv2d(in_channels = 24,out_channels = 24,kernel_size = 5,stride = 1,padding = 0)
        self.bn4 = nn.BatchNorm2d(24)
        self.dropout = nn.Dropout(p = 0.5)
        self.fc1 = nn.Linear(24*50*50,len(classNames))
    
    def forward(self,x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool(x)
        x = x.view(-1,24*50*50)
        x = F.relu(self.fc1(x))
        return x

print("Using {} device".format(device))

model = Network_bn().to(device)
model
Using cuda device





Network_bn(
  (conv1): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))
  (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(24, 24, kernel_size=(5, 5), stride=(1, 1))
  (bn4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=60000, out_features=2, bias=True)
)

3.2 公式推导

3, 224, 224(输入数据)

-> 12, 220, 220(经过卷积层1)

-> 12, 216, 216(经过卷积层2)-> 12, 108, 108(经过池化层1)

-> 24, 104, 104(经过卷积层3)

-> 24, 100, 100(经过卷积层4)-> 24, 50, 50(经过池化层2)

-> 60000 -> num_classes(2)

计算公式:

卷积维度计算公式:

  • 高度方向: H o u t = ( H i n − K e r n e l _ s i z e + 2 × p a d d i n g ) s t r i d e + 1 H_{out} = \frac{\left(H_{in} - Kernel\_size + 2\times padding\right)}{stride} + 1 Hout=stride(HinKernel_size+2×padding)+1

  • 宽度方向: W o u t = ( W i n − K e r n e l _ s i z e + 2 × p a d d i n g ) s t r i d e + 1 W_{out} = \frac{\left(W_{in} - Kernel\_size + 2\times padding\right)}{stride} + 1 Wout=stride(WinKernel_size+2×padding)+1

  • 卷积层通道数变化:数据通道数为卷积层该卷积层定义的输出通道数,例如:self.conv1 = nn.Conv2d(3,64,kernel_size = 3)。在这个例子中,输出的通道数为64,这意味着卷积层使用了64个不同的卷积核,每个核都在输入数据上独立进行卷积运算,产生一个新的通道。需要注意,卷积操作不是在单独的通道上进行的,而是跨所有输入通道(本例中为3个通道)进行的,每个卷积核提供一个新的输出通道。

池化层计算公式:

  • 高度方向: H o u t = ( H i n + 2 × p a d d i n g H − d i l a t i o n H × ( k e r n e l _ s i z e H − 1 ) − 1 s t r i d e H + 1 ) H_{out} = \left(\frac{H_{in} + 2 \times padding_H - dilation_H \times (kernel\_size_H - 1) - 1}{stride_H} + 1 \right) Hout=(strideHHin+2×paddingHdilationH×(kernel_sizeH1)1+1)

  • 宽度方向: W o u t = ( W i n + 2 × p a d d i n g W − d i l a t i o n W × ( k e r n e l _ s i z e W − 1 ) − 1 s t r i d e W + 1 ) W_{out} = \left( \frac{W_{in} + 2 \times padding_W - dilation_W \times (kernel\_size_W - 1) - 1}{stride_W} + 1 \right) Wout=(strideWWin+2×paddingWdilationW×(kernel_sizeW1)1+1)

其中:

  • H i n H_{in} Hin W i n W_{in} Win 是输入的高度和宽度。
  • p a d d i n g H padding_H paddingH p a d d i n g W padding_W paddingW 是在高度和宽度方向上的填充量。
  • k e r n e l _ s i z e H kernel\_size_H kernel_sizeH k e r n e l _ s i z e W kernel\_size_W kernel_sizeW 是卷积核或池化核在高度和宽度方向上的大小。
  • s t r i d e H stride_H strideH s t r i d e W stride_W strideW 是在高度和宽度方向上的步长。
  • d i l a t i o n H dilation_H dilationH d i l a t i o n W dilation_W dilationW 是在高度和宽度方向上的膨胀系数。

请注意,这里的膨胀系数 $dilation \times (kernel_size - 1) $实际上表示核在膨胀后覆盖的区域大小。例如,一个 $3 \times 3 $ 的核,如果膨胀系数为2,则实际上它覆盖的区域大小为$ 5 \times 5 $(原始核大小加上膨胀引入的间隔)。

计算流程:

输入数据:( 3 ∗ 224 ∗ 224 3*224*224 3224224)

conv1计算:卷积核数12,输出的通道也为12。-> ( 12 ∗ 220 ∗ 220 ) (12*220*220) (12220220)
输出维度 = ( 224 − 5 + 2 × 0 ) 1 + 1 = 220 \text{输出维度} = \frac{\left(224 - 5 + 2 \times 0\right)}{1} + 1 = 220 输出维度=1(2245+2×0)+1=220

conv2计算:-> ( 12 ∗ 216 ∗ 216 ) (12*216*216) (12216216)
输出维度 = ( 220 − 5 + 2 × 0 ) 1 + 1 = 216 \text{输出维度} = \frac{\left(220 - 5 + 2 \times 0\right)}{1} + 1 = 216 输出维度=1(2205+2×0)+1=216

pool1计算:通道数不变,步长为2-> ( 12 ∗ 108 ∗ 108 ) (12*108*108) (12108108)
输出维度 = ( 216 + 2 × 0 − 1 × ( 2 − 1 ) − 1 2 + 1 ) = 107 + 1 = 108 \text{输出维度} = \left(\frac{216 + 2 \times 0 - 1 \times \left(2 - 1\right) - 1}{2} + 1 \right) = 107 +1 = 108 输出维度=(2216+2×01×(21)1+1)=107+1=108

conv3计算:-> ( 24 ∗ 104 ∗ 104 ) (24*104*104) (24104104)
输出维度 = ( 108 − 5 + 2 × 0 ) 1 + 1 = 104 \text{输出维度} = \frac{\left(108 - 5 + 2 \times 0\right)}{1} + 1 = 104 输出维度=1(1085+2×0)+1=104

conv4计算:-> ( 24 ∗ 100 ∗ 100 ) (24*100*100) (24100100)
输出维度 = ( 104 − 5 + 2 × 0 ) 1 + 1 = 100 \text{输出维度} = \frac{\left(104 - 5 + 2 \times 0\right)}{1} + 1 = 100 输出维度=1(1045+2×0)+1=100

pool2计算:-> ( 24 ∗ 50 ∗ 50 ) (24*50*50) (245050)
输出维度 = ( 100 + 2 × 0 − 1 × ( 2 − 1 ) − 1 2 + 1 ) = 49 + 1 = 50 \text{输出维度} = \left(\frac{100 + 2 \times 0 - 1 \times \left(2 - 1\right) - 1}{2} + 1 \right) = 49 +1 = 50 输出维度=(2100+2×01×(21)1+1)=49+1=50

flatten层:-> 60000 60000 60000

n u m _ c l a s s e s ( 2 ) num\_classes(2) num_classes(2)

4. 设置超参数

loss_fn = torch.nn.CrossEntropyLoss()
learn_rate = 1e-4
opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

5. 定义训练函数

len(train_dl.dataset),len(train_dl),len(test_dl.dataset),len(test_dl)
(1713, 54, 429, 14)
# 训练循环
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset) # 训练集的大小,一共1713张图片
    num_batches = len(dataloader) # 批次数目,54(1713/32)
    
    train_loss,train_acc = 0,0 # 初始化训练损失和正确率
    
    for X,y in dataloader: # 获取图片及其标签
        X,y = X.to(device),y.to(device)
        
        # 计算预测误差
        pred = model(X) # 网络输出
        loss = loss_fn(pred,y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 返乡传播
        optimizer.zero_grad() # grad属性归零
        loss.backward() # 反向传播
        optimizer.step() # 每一步自动更新
        
        # 记录acc与loss
        train_acc += (pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss += loss.item()
        
    train_acc /= size
    train_loss /= num_batches
    
    return train_acc,train_loss

6. 定义测试函数

def test(dataloader,model,loss_fn):
    size = len(dataloader.dataset) # 测试集的大小,一共429张图片
    num_batches = len(dataloader) # 批次数目,14(429/32=13.4,向上取整)
    
    test_acc,test_loss = 0,0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for X,y in dataloader:
            X,y = X.to(device),y.to(device)

            # 计算预测误差
            pred = model(X)
            loss = loss_fn(pred,y)
            
            # 记录acc与loss
            test_acc += (pred.argmax(1)==y).type(torch.float).sum().item()
            test_loss += loss.item()
    
    test_acc /= size
    test_loss /= num_batches
    
    return test_acc,test_loss

7. 训练过程

epochs = 20
train_acc = []
train_loss = []
test_acc = []
test_loss = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)
    
    model.eval()
    epoch_test_acc,epoch_test_loss = test(test_dl,model,loss_fn)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss))
    
print('Done!')
Epoch: 1,Train_acc:57.0%,Train_loss:0.669,Test_acc:60.1%,Test_loss:0.661
Epoch: 2,Train_acc:65.8%,Train_loss:0.615,Test_acc:64.8%,Test_loss:0.618
Epoch: 3,Train_acc:70.1%,Train_loss:0.568,Test_acc:67.6%,Test_loss:0.594
Epoch: 4,Train_acc:74.1%,Train_loss:0.530,Test_acc:70.6%,Test_loss:0.589
Epoch: 5,Train_acc:76.3%,Train_loss:0.500,Test_acc:70.9%,Test_loss:0.561
Epoch: 6,Train_acc:78.6%,Train_loss:0.472,Test_acc:73.2%,Test_loss:0.532
Epoch: 7,Train_acc:81.1%,Train_loss:0.442,Test_acc:74.1%,Test_loss:0.518
Epoch: 8,Train_acc:82.2%,Train_loss:0.421,Test_acc:75.3%,Test_loss:0.516
Epoch: 9,Train_acc:83.1%,Train_loss:0.401,Test_acc:76.0%,Test_loss:0.497
Epoch:10,Train_acc:84.2%,Train_loss:0.388,Test_acc:77.6%,Test_loss:0.495
Epoch:11,Train_acc:85.6%,Train_loss:0.371,Test_acc:77.6%,Test_loss:0.478
Epoch:12,Train_acc:86.4%,Train_loss:0.363,Test_acc:78.6%,Test_loss:0.487
Epoch:13,Train_acc:87.9%,Train_loss:0.342,Test_acc:78.3%,Test_loss:0.476
Epoch:14,Train_acc:89.0%,Train_loss:0.331,Test_acc:80.2%,Test_loss:0.457
Epoch:15,Train_acc:89.8%,Train_loss:0.317,Test_acc:79.5%,Test_loss:0.480
Epoch:16,Train_acc:90.3%,Train_loss:0.313,Test_acc:80.9%,Test_loss:0.472
Epoch:17,Train_acc:90.7%,Train_loss:0.298,Test_acc:80.9%,Test_loss:0.442
Epoch:18,Train_acc:91.4%,Train_loss:0.289,Test_acc:82.1%,Test_loss:0.440
Epoch:19,Train_acc:91.1%,Train_loss:0.286,Test_acc:82.5%,Test_loss:0.429
Epoch:20,Train_acc:91.2%,Train_loss:0.277,Test_acc:81.6%,Test_loss:0.427
Done!

8. 结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

9.模型的保存及预测

9.1 代码部分

# 模型保存
# torch.save(model.'mpox_rec_model.pth') # 保存整个模型

# 模型加载
# model2 = torch.load('mpox_rec_model.pth') 
# model2 = model2.to(device) # 理论上在哪里保存模型,加载模型也会优先在哪里,指定一下确保不会出错
# 状态字典保存
torch.save(model.state_dict(),'mpox_rec_model_state_dict.pth') # 仅保存状态字典

# 加载状态字典到模型
model2 = Network_bn().to(device) # 重新定义模型
model2.load_state_dict(torch.load('mpox_rec_model_state_dict.pth')) # 加载状态字典到模型
<All keys matched successfully>
# 指定路径图片预测
from PIL import Image
import torchvision.transforms as transforms

classes = list(total_data.class_to_idx) # classes = list(total_data.class_to_idx)

def predict_one_image(image_path,model,transform,classes):
    
    test_img = Image.open(image_path).convert('RGB')
    # plt.imshow(test_img) # 展示待预测的图片
    
    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)
    
    model.eval()
    output = model(img)
    print(output) # 观察模型预测结果的输出数据
    
    _,pred = torch.max(output,1)
    pred_class = classes[pred]
    print(f'预测结果是:{pred_class}')
# 预测训练集中的某张照片
predict_one_image(image_path='./data/mpox_recognize/Monkeypox/M01_01_00.jpg', 
                  model=model, 
                  transform=train_transforms, 
                  classes=classes)
tensor([[1.7220, 1.2390]], device='cuda:0', grad_fn=<ReluBackward0>)
预测结果是:Monkeypox

9.2 疑问解析

问题一:您好,请为我分析一下这段代码的含义,“classes = list (total_data.class_to_idx)”,其中total_data是使用torchvision.datasets.ImageFolder()处理得到的。请问具体的处理过程是怎样的?不同文件夹的名称处理后会自动作为分类处理吗?

问题解答:

您提供的代码行’ classes = list(total_data.class_to_idx) '通常用于PyTorch中加载和处理图像数据集的上下文中,特别是在使用"torchvision.datasets.ImageFolder"类时。让我们分解这个过程和具体的代码行来理解它的含义,以及它如何适应PyTorch图像数据处理的大背景。

理解“torchvision.datasets.ImageFolder”

"torchvision.datasets.ImageFolder"是PyTorch的torchvision包提供的一个实用程序类,旨在简化存储在磁盘目录结构中的图像数据的加载、转换和批处理。对于图像分类任务,它特别有用,因为您将数据集组织到子目录中,每个子目录以您希望模型学习的一个类命名。

目录结构

在深入研究代码之前,有必要了解预期的目录结构。假设您有一个包含三个类别的数据集:“猫”、“狗”和“鸟”。目录结构应该是这样的:

/path/to/data/
├── cats
│   ├── cat001.jpg
│   ├── cat002.jpg
│   └── ...
├── dogs
│   ├── dog001.jpg
│   ├── dog002.jpg
│   └── ...
└── birds
    ├── bird001.jpg
    ├── bird002.jpg
    └── ...

使用ImageFolder

当你实例化一个带有数据路径的’ ImageFolder '对象时,PyTorch会做以下操作:

  1. 扫描目录:递归扫描镜像文件的目录结构。
  2. 为类分配索引:每个子目录被视为一个单独的类,PyTorch自动为每个类分配索引。类名(子目录名)到它们分配的索引的映射存储在名为’class_to_idx '的字典属性中。
  3. 加载和转换图像:它提供了动态加载图像和应用任何指定转换(调整大小,规范化等)的机制。

分析特定的代码行

classes = list(total_data.class_to_idx)

鉴于上述背景:

  • ‘total_data’:这是’ ImageFolder '类的一个实例,它已经按照描述的方式处理了目录结构。

  • ‘total_data.class_to_idx’:这是一个字典,其中键是类名(派生自子目录的名称),值是它们对应的索引(整数)。索引是根据类名的排序顺序分配的。

  • ‘list(total_data.class_to_idx)’:此代码将’ class_to_idx '字典的键转换为列表。生成的“classes”列表包含按顺序排序的类名(子目录名)。此列表通常用于根据代码中其他地方的索引引用类名,特别是在解释模型的输出时。

总结:

因此,总而言之,您提供的代码提取并列出了“ImageFolder”自动识别和索引的类名(来自数据集中的子目录名)。“ImageFolder”的这个功能简化了加载图像数据进行分类的任务,确保每个类别,由一个文件夹表示,被自动视为一个单独的类。




本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1586445.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

关于AI Agent、RAG技术揭秘:如何让人工智能更懂你?

人工智能技术正以前所未有的速度改变着我们的世界。从深度学习算法的突破到自动化和机器学习技术的进步。在这个变革的时代&#xff0c;几种前沿技术尤其引人注目&#xff0c;其中包括RAG&#xff08;Retrieval-Augmented Generation&#xff09;、AI Agent以及多模态技术。 近…

python之文件操作与管理

1、文件操作 通过open&#xff08;&#xff09;操作&#xff0c;来创建文件对象&#xff0c;下面是open&#xff08;&#xff09;函数语法如下&#xff1a; open&#xff08;file,mode r,buffering -1 , encoding None ,errors None , newline None,closefd True,opener …

2024年Flink CDC 实时同步数据(MySQL到MySQL)

#准备工作# 看到一下图片说明执行成功&#xff01;&#xff01;&#xff01; 异常处理及分析&#xff1a; Could not execute SQL statement. Reason: org.apache.flink.sql.parser.impl.ParseException: Encountered "\connector\" at line 21, column 3. Was expec…

记录一次浅拷贝导致的缺陷

背景&#xff1a;假期表查询&#xff0c;表中存放的工作日信息是按照月份维度的&#xff0c;例如1月的假期表信息是NNNYYYYYNN.....一共31天&#xff0c;如果是工作日那么就是Y&#xff0c;如果非工作日那就是N。获取指定日期的下一个工作日&#xff0c;就会先查出这个月份的这…

Qt学习记录(C++)——Day 2

目录 一、作业 要求&#xff1a; 实现&#xff1a; 1.创建新的窗口类 2. 主窗口中实现 二、 窗口菜单设计 效果展示图 三、图片资源的导入 步骤&#xff1a; 举例&#xff1a; 四、 对话框 1.模拟对话框 2. 非模态对话框 3.错误对话框 4.信息对话框 5.提问对话…

边缘计算采集网关如何助力制造企业解决数采问题-天拓四方

一、企业背景 某大型制造企业&#xff0c;位于国内某经济发达的工业园区内&#xff0c;拥有多个生产线和智能化设备&#xff0c;致力于提高生产效率、降低运营成本。随着企业规模的扩大和生产自动化的推进&#xff0c;该企业面临着海量数据处理、实时响应和网络安全等多重挑战…

Linux下网络编程基础知识--协议

网络基础 这一个课程的笔记 相关文章 协议 Socket编程 高并发服务器实现 线程池 协议 一组规则, 数据传输和数据的解释的规则。 比如说依次发送文件的文件名, 文件的大小, 以及实际的文件, 这样规定发送一个文件的顺序以及发送的每一个部分的格式等可以算是一种协议 型协议 …

(学习日记)2024.04.12:UCOSIII第四十节:软件定时器函数接口讲解

写在前面&#xff1a; 由于时间的不足与学习的碎片化&#xff0c;写博客变得有些奢侈。 但是对于记录学习&#xff08;忘了以后能快速复习&#xff09;的渴望一天天变得强烈。 既然如此 不如以天为单位&#xff0c;以时间为顺序&#xff0c;仅仅将博客当做一个知识学习的目录&a…

CUDA 12.4文档1:使用GPU好处CUDA可扩展编程模型

本博客参考官方文档进行介绍&#xff0c;全网仅此一家进行中文翻译&#xff0c;走过路过不要错过。 官方网址&#xff1a;https://docs.nvidia.com/cuda/cuda-c-programming-guide/ 本文档分成多个博客进行介绍&#xff0c;在本人专栏中含有所有内容&#xff1a; https://bl…

Linux网卡:连接虚拟与现实的桥梁

在介绍Linux网卡之前&#xff0c;让我们先迈入时光机&#x1f570;️&#xff0c;回到1980年代末期&#xff0c;互联网正在逐步从一个科研网络向公众网络转变&#xff0c;Linux——一个自由和开源的操作系统诞生了&#x1f427;。Linux的出现&#xff0c;对于计算机科学领域来说…

【微命令】git 如何修改某个分支的名字(git branch -m newbranch)

简要信息&#xff0c;快速记录 命令 # 切换到某个需要修改的分支 git checkout oldbranch# 修改分支名字 git branch -m newbranch假设作为git设计者&#xff0c;要用来修改branch的命令&#xff0c;那么就是 git branch作为前缀&#xff0c;然后进一步修改的命令是branch相关…

CorelDRAW21.2.4中文最新官方和谐版下载

CorelDRAW是一款由加拿大Corel公司出品的平面设计软件&#xff0c;也被称为CDR。它是一款功能强大的矢量图形制作和排版软件&#xff0c;主要面向绘图设计师和印刷输出人员。该软件提供了矢量插图、页面布局、图片编辑和设计工具&#xff0c;广泛应用于排版印刷、矢量图形编辑及…

24/04/11总结

IO流(First edition): IO流&#xff1a;用于读入写出文件中的数据 流的方向&#xff08;输入指拿出来,输出指写进去) 输入流:读取 输出流:写出 操作文件类型 字节流:所有类型文件 字符流:纯文本 字节流: InputStream的子类:FileInputStream:操作本地文件的字节输入流 OutputSt…

【HTML】制作一个简单的实时字体时钟

目录 前言 HTML部分 CSS部分 JS部分 效果图 总结 前言 无需多言&#xff0c;本文将详细介绍一段HTML代码&#xff0c;具体内容如下&#xff1a; 开始 首先新建文件夹&#xff0c;创建一个文本文档&#xff0c;两个文件夹&#xff0c;其中HTML的文件名改为[index.html]&am…

二分查找-图文详解,看不懂你来打我。。。

一、查找算法 在计算机科学和算法领域&#xff0c;搜索是一项基本的任务。在海量数据中寻找特定的元素是一项常见的任务&#xff0c;而二分查找&#xff08;Binary Search&#xff09;是一种非常高效的搜索算法&#xff0c;特别适用于有序数组。 二、二分查找 二分查找是一种…

【阿里淘天笔试题汇总】2024-04-10-阿里淘天春招笔试题-三语言题解(CPP/Python/Java)

&#x1f36d; 大家好这里是KK爱Coding &#xff0c;一枚热爱算法的程序员 ✨ 本系列打算持续跟新淘天近期的春秋招笔试题汇总&#xff5e; &#x1f4bb; ACM银牌&#x1f948;| 多次AK大厂笔试 &#xff5c; 编程一对一辅导 &#x1f44f; 感谢大家的订阅➕ 和 喜欢&#x1f…

【前端Vue】Vue3+Pinia小兔鲜电商项目第5篇:整体认识和路由配置,本资源由 收集整理【附代码文档】

Vue3ElementPlusPinia开发小兔鲜电商项目完整教程&#xff08;附代码资料&#xff09;主要内容讲述&#xff1a;认识Vue3&#xff0c;使用create-vue搭建Vue3项目1. Vue3组合式API体验,2. Vue3更多的优势,1. 认识create-vue,2. 使用create-vue创建项目,1. setup选项的写法和执行…

基于Springboot的箱包存储系统(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的箱包存储系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&…

linux的io的知识大全

C语言的io操作 写文件 #include<stdio.h> #include<string.h>#define FILE_NAME "log.txt" int main() {FILE * fp fopen(FILE_NAME, "w");if(fpNULL){printf("fopen error!\n");}const char* msg "hello zk\n";int c…

【复现】浙大恩特客户资源管理系统 SQL注入漏洞_71

目录 一.概述 二 .漏洞影响 三.漏洞复现 1. 漏洞一&#xff1a; 四.修复建议&#xff1a; 五. 搜索语法&#xff1a; 六.免责声明 一.概述 浙大恩特客户资源管理系统是一款针对企业客户资源管理的软件产品。该系统旨在帮助企业高效地管理和利用客户资源&#xff0c;提升…