图片分类实战:食物分类问题(含半监督)

news2025/3/9 8:00:53

食物分类问题

simple_class


1. 导入必要的库和模块

import random
import torch
import torch.nn as nn
import numpy as np
import os
from PIL import Image #读取图片数据
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision import transforms
import time
import matplotlib.pyplot as plt
from model_utils.model import initialize_model
  • import random: 导入Python标准库中的随机数生成器。
  • import torch: 导入PyTorch库,用于深度学习模型的构建和训练。
  • import torch.nn as nn: 导入PyTorch的神经网络模块,包含各种层和损失函数。
  • import numpy as np: 导入NumPy库,用于数值计算和数组操作。
  • import os: 导入操作系统接口模块,用于文件路径处理。
  • from PIL import Image: 导入PIL(Python Imaging Library)库,用于图像处理。
  • from torch.utils.data import Dataset, DataLoader: 导入PyTorch的数据集和数据加载器类,用于管理数据集和批量加载数据。
  • from tqdm import tqdm: 导入tqdm库,用于显示进度条。
  • from torchvision import transforms: 导入PyTorch的图像变换模块,用于对图像进行预处理。
  • import time: 导入时间模块,用于记录训练时间。
  • import matplotlib.pyplot as plt: 导入matplotlib库,用于绘制图表。
  • from model_utils.model import initialize_model: 导入自定义模块中的初始化模型函数。

2. 设置随机种子以确保结果可重复

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

seed_everything(0)
  • def seed_everything(seed):: 定义一个函数 seed_everything,用于设置所有可能影响随机性的种子。
  • torch.manual_seed(seed): 设置PyTorch的CPU随机种子。
  • torch.cuda.manual_seed(seed): 设置PyTorch的GPU随机种子。
  • torch.cuda.manual_seed_all(seed): 如果有多个GPU,设置所有GPU的随机种子。
  • torch.backends.cudnn.benchmark = False: 关闭CuDNN自动优化功能,确保每次运行的结果一致。
  • torch.backends.cudnn.deterministic = True: 设置CuDNN为确定性模式,确保结果可重复。
  • random.seed(seed): 设置Python内置随机数生成器的种子。
  • np.random.seed(seed): 设置NumPy的随机数生成器的种子。
  • os.environ['PYTHONHASHSEED'] = str(seed): 设置环境变量 PYTHONHASHSEED,确保哈希值的一致性。
  • seed_everything(0): 调用 seed_everything 函数,设置全局随机种子为0。

3. 定义图像变换

HW = 224

train_transform = transforms.Compose(
    [
        transforms.ToPILImage(),   # 将numpy.ndarray转换为PIL.Image
        transforms.RandomResizedCrop(224),  # 随机裁剪并调整大小到224x224
        transforms.RandomRotation(50),  # 随机旋转角度在[-50, 50]之间
        transforms.ToTensor()  # 将PIL.Image转换为tensor
    ]
)

val_transform = transforms.Compose(
    [
        transforms.ToPILImage(),   # 将numpy.ndarray转换为PIL.Image
        transforms.ToTensor()  # 将PIL.Image转换为tensor
    ]
)
  • HW = 224: 定义图像的高度和宽度为224像素。
  • train_transform: 定义训练集的图像变换组合:
    • transforms.ToPILImage(): 将输入的numpy数组转换为PIL图像格式。
    • transforms.RandomResizedCrop(224): 随机裁剪并调整大小到224x224像素。
    • transforms.RandomRotation(50): 随机旋转图像,角度范围在[-50, 50]度之间。
    • transforms.ToTensor(): 将PIL图像转换为PyTorch张量(tensor),并将像素值归一化到[0, 1]区间。
  • val_transform: 定义验证集的图像变换组合:
    • transforms.ToPILImage(): 将输入的numpy数组转换为PIL图像格式。
    • transforms.ToTensor(): 将PIL图像转换为PyTorch张量(tensor),并将像素值归一化到[0, 1]区间。

4. 自定义数据集类

class food_Dataset(Dataset):
    def __init__(self, path, mode="train"):
        self.mode = mode
        if mode == "semi":
            self.X = self.read_file(path)
        else:
            self.X, self.Y = self.read_file(path)
            self.Y = torch.LongTensor(self.Y)  # 标签转为长整形

        if mode == "train":
            self.transform = train_transform
        else:
            self.transform = val_transform

    def read_file(self, path):
        if self.mode == "semi":
            file_list = os.listdir(path)
            xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
            for j, img_name in enumerate(file_list):
                img_path = os.path.join(path, img_name)
                img = Image.open(img_path)
                img = img.resize((HW, HW))
                xi[j, ...] = np.array(img)
            print("读到了%d个数据" % len(xi))
            return xi
        else:
            for i in tqdm(range(11)):
                file_dir = path + "/%02d" % i
                file_list = os.listdir(file_dir)
                xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
                yi = np.zeros(len(file_list), dtype=np.uint8)
                for j, img_name in enumerate(file_list):
                    img_path = os.path.join(file_dir, img_name)
                    img = Image.open(img_path)
                    img = img.resize((HW, HW))
                    xi[j, ...] = np.array(img)
                    yi[j] = i
                if i == 0:
                    X = xi
                    Y = yi
                else:
                    X = np.concatenate((X, xi), axis=0)
                    Y = np.concatenate((Y, yi), axis=0)
            print("读到了%d个数据" % len(Y))
            return X, Y

    def __getitem__(self, item):
        if self.mode == "semi":
            return self.transform(self.X[item]), self.X[item]
        else:
            return self.transform(self.X[item]), self.Y[item]

    def __len__(self):
        return len(self.X)
  • class food_Dataset(Dataset):: 定义一个继承自 Dataset 的自定义数据集类 food_Dataset
  • def __init__(self, path, mode="train"):: 初始化方法,接受数据集路径和模式(默认为“train”)作为参数。
    • self.mode = mode: 记录数据集的模式。
    • if mode == "semi":: 如果是半监督模式,则仅读取未标记的图像数据。
    • else:: 否则,读取带有标签的图像数据,并将标签转换为长整型。
    • if mode == "train":: 如果是训练模式,使用 train_transform 进行图像变换。
    • else:: 否则,使用 val_transform 进行图像变换。
  • def read_file(self, path):: 定义一个读取文件的方法,根据不同的模式读取图像数据。
    • if self.mode == "semi":: 如果是半监督模式,读取未标记的图像数据:
      • file_list = os.listdir(path): 获取目录下的所有文件名。
      • xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8): 创建一个零数组用于存储图像数据。
      • for j, img_name in enumerate(file_list):: 遍历每个文件名,打开图像并调整大小,然后将其存储在 xi 中。
      • print("读到了%d个数据" % len(xi)): 打印读取到的图像数量。
      • return xi: 返回图像数据。
    • else:: 否则,读取带有标签的图像数据:
      • for i in tqdm(range(11)):: 使用 tqdm 显示进度条,遍历每个类别(假设共有11个类别)。
      • file_dir = path + "/%02d" % i: 构建类别目录路径。
      • file_list = os.listdir(file_dir): 获取该类别目录下的所有文件名。
      • xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8): 创建一个零数组用于存储图像数据。
      • yi = np.zeros(len(file_list), dtype=np.uint8): 创建一个零数组用于存储标签。
      • for j, img_name in enumerate(file_list):: 遍历每个文件名,打开图像并调整大小,然后将其存储在 xi 中,并将对应的标签存储在 yi 中。
      • if i == 0:: 如果是第一个类别,初始化 XY
      • else:: 否则,将当前类别的图像和标签连接到已有数据中。
      • print("读到了%d个数据" % len(Y)): 打印读取到的图像数量。
      • return X, Y: 返回图像数据和标签。
  • def __getitem__(self, item):: 定义获取指定索引的数据项的方法。
    • if self.mode == "semi":: 如果是半监督模式,返回变换后的图像及其原始图像。
    • else:: 否则,返回变换后的图像及其标签。
  • def __len__(self):: 定义返回数据集长度的方法,即图像的数量。

5. 半监督数据集类

class semiDataset(Dataset):
    def __init__(self, no_label_loader, model, device, thres=0.99):
        x, y = self.get_label(no_label_loader, model, device, thres)
        if x == []:
            self.flag = False
        else:
            self.flag = True
            self.X = np.array(x)
            self.Y = torch.LongTensor(y)
            self.transform = train_transform

    def get_label(self, no_label_loader, model, device, thres):
        model = model.to(device)
        pred_prob = []
        labels = []
        x = []
        y = []
        soft = nn.Softmax(dim=1)
        with torch.no_grad():
            for bat_x, _ in no_label_loader:
                bat_x = bat_x.to(device)
                pred = model(bat_x)
                pred_soft = soft(pred)
                pred_max, pred_value = pred_soft.max(1)
                pred_prob.extend(pred_max.cpu().numpy().tolist())
                labels.extend(pred_value.cpu().numpy().tolist())

        for index, prob in enumerate(pred_prob):
            if prob > thres:
                x.append(no_label_loader.dataset[index][0])
                y.append(labels[index])
        return x, y

    def __getitem__(self, item):
        return self.transform(self.X[item]), self.Y[item]

    def __len__(self):
        return len(self.X)
  • class semiDataset(Dataset):: 定义一个继承自 Dataset 的半监督数据集类 semiDataset
  • def __init__(self, no_label_loader, model, device, thres=0.99):: 初始化方法,接受未标记数据加载器、模型、设备和置信度阈值作为参数。
    • x, y = self.get_label(no_label_loader, model, device, thres): 调用 get_label 方法获取高置信度的伪标签样本。
    • if x == []:: 如果没有找到符合条件的样本,设置 flagFalse
    • else:: 否则,设置 flagTrue,并将样本数据和标签存储在 self.Xself.Y 中,并使用 train_transform 进行图像变换。
  • def get_label(self, no_label_loader, model, device, thres):: 定义一个获取伪标签的方法。
    • model = model.to(device): 将模型移动到指定设备(CPU或GPU)。
    • pred_prob = []: 初始化预测概率列表。
    • labels = []: 初始化标签列表。
    • x = []: 初始化图像数据列表。
    • y = []: 初始化标签列表。
    • soft = nn.Softmax(dim=1): 初始化Softmax函数,用于将模型输出转换为概率分布。
    • with torch.no_grad():: 禁用梯度计算,减少内存占用和加速推理。
      • for bat_x, _ in no_label_loader:: 遍历未标记数据加载器中的每个批次。
        • bat_x = bat_x.to(device): 将输入数据移动到指定设备。
        • pred = model(bat_x): 使用模型进行前向传播,得到预测结果。
        • pred_soft = soft(pred): 使用Softmax函数将预测结果转换为概率分布。
        • pred_max, pred_value = pred_soft.max(1): 获取每个样本的最大概率及其对应的类别。
        • pred_prob.extend(pred_max.cpu().numpy().tolist()): 将最大概率值添加到 pred_prob 列表中。
        • labels.extend(pred_value.cpu().numpy().tolist()): 将对应的类别标签添加到 labels 列表中。
    • for index, prob in enumerate(pred_prob):: 遍历每个样本的概率值。
      • if prob > thres:: 如果概率值大于设定的阈值,则认为该样本的预测结果是可靠的。
        • x.append(no_label_loader.dataset[index][0]): 将该样本的图像数据添加到 x 列表中。
        • y.append(labels[index]): 将该样本的预测标签添加到 y 列表中。
    • return x, y: 返回筛选出的图像数据和标签。
  • def __getitem__(self, item):: 定义获取指定索引的数据项的方法,返回变换后的图像及其标签。
  • def __len__(self):: 定义返回数据集长度的方法,即图像的数量。

明白了,我们将从“6. 获取半监督数据加载器”继续逐句分析代码,并保持详细的解释风格。

6. 获取半监督数据加载器

def get_semi_loader(no_label_loder, model, device, thres):
    semiset = semiDataset(no_label_loder, model, device, thres)
    if semiset.flag == False:
        return None
    else:
        semi_loader = DataLoader(semiset, batch_size=16, shuffle=False)
        return semi_loader
  • get_semi_loader:定义了一个函数,用于创建包含伪标签样本的数据加载器。
    • no_label_loder:未标记数据的加载器。
    • model:当前训练的模型,用于对未标记数据进行预测。
    • device:设备类型(CPU或GPU)。
    • thres:置信度阈值,用于选择高置信度样本。
  • semiset:使用 semiDataset 类创建一个包含伪标签样本的数据集对象。
  • if semiset.flag == False:如果 semiDataset 对象中没有满足条件的样本,则返回 None
  • else:否则,使用 DataLoader 创建一个新的数据加载器 semi_loader,批次大小为16且不打乱数据。

明白了,让我们重新详细解析 myModel 类的定义部分,并继续进入训练和验证函数的解析。

7. 定义模型

class myModel(nn.Module):
    def __init__(self, num_class):
        super(myModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)

        self.layer1 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(25088, 1000)
        self.relu2 = nn.ReLU()
        self.fc2 = nn.Linear(1000, num_class)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool2(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.relu2(x)
        x = self.fc2(x)
        return x
__init__ 方法
  • super(myModel, self).init():调用父类 nn.Module 的构造函数。
  • self.conv1:定义第一个卷积层,输入通道数为3(RGB图像),输出通道数为64,卷积核大小为3x3,步长为1,填充为1。
  • self.bn1:定义第一个批量归一化层,用于归一化卷积层的输出。
  • self.relu:定义ReLU激活函数。
  • self.pool1:定义第一个最大池化层,池化窗口大小为2x2。
  • self.layer1:定义第一个卷积块,包含一个卷积层、批量归一化层、ReLU激活函数和最大池化层。卷积层将输入通道数从64变为128。
  • self.layer2:定义第二个卷积块,与 layer1 类似,但将输入通道数从128变为256。
  • self.layer3:定义第三个卷积块,与 layer2 类似,但将输入通道数从256变为512。
  • self.pool2:定义第二个最大池化层,池化窗口大小为2x2。
  • self.fc1:定义第一个全连接层,输入特征数为25088(经过前面的卷积和池化操作后的特征图大小),输出特征数为1000。
  • self.relu2:定义第二个ReLU激活函数。
  • self.fc2:定义第二个全连接层,输入特征数为1000,输出特征数为 num_class(类别数量)。
forward 方法
  • def forward(self, x):定义前向传播过程。
    • x = self.conv1(x):对输入数据 x 进行第一次卷积操作。
    • x = self.bn1(x):对卷积结果进行批量归一化。
    • x = self.relu(x):应用ReLU激活函数。
    • x = self.pool1(x):对激活结果进行最大池化操作。
    • x = self.layer1(x):通过第一个卷积块。
    • x = self.layer2(x):通过第二个卷积块。
    • x = self.layer3(x):通过第三个卷积块。
    • x = self.pool2(x):对第三个卷积块的结果进行最大池化操作。
    • x = x.view(x.size()[0], -1):将多维张量展平成二维张量,以便输入到全连接层中。
    • x = self.fc1(x):通过第一个全连接层。
    • x = self.relu2(x):应用ReLU激活函数。
    • x = self.fc2(x):通过第二个全连接层,输出最终预测结果。
    • return x:返回模型的预测结果。

8. 训练和验证函数

接下来我们继续解析 train_val 函数:

def train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path):
    model = model.to(device)
    semi_loader = None
    plt_train_loss = []
    plt_val_loss = []
    plt_train_acc = []
    plt_val_acc = []
    max_acc = 0.0

    for epoch in range(epochs):
        train_loss = 0.0
        val_loss = 0.0
        train_acc = 0.0
        val_acc = 0.0
        start_time = time.time()
初始化部分
  • model = model.to(device):将模型移动到指定设备(CPU或GPU)。
  • semi_loader = None:初始化半监督数据加载器为 None
  • plt_train_lossplt_val_lossplt_train_accplt_val_acc:分别存储训练和验证的损失及准确率。
  • max_acc = 0.0:初始化最大验证准确率为0.0。
每个epoch的训练循环
for epoch in range(epochs):
    train_loss = 0.0
    val_loss = 0.0
    train_acc = 0.0
    val_acc = 0.0
    start_time = time.time()

    model.train()
    for batch_x, batch_y in train_loader:
        x, target = batch_x.to(device), batch_y.to(device)
        pred = model(x)
        train_bat_loss = loss(pred, target)
        train_bat_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += train_bat_loss.cpu().item()
        train_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())
  • for epoch in range(epochs):遍历每个epoch。
  • train_loss = 0.0val_loss = 0.0train_acc = 0.0val_acc = 0.0:初始化每个epoch的损失和准确率。
  • start_time = time.time():记录当前epoch的开始时间。
  • model.train():设置模型为训练模式。
  • for batch_x, batch_y in train_loader:遍历训练数据加载器中的每个批次。
    • x, target = batch_x.to(device), batch_y.to(device):将输入数据和标签移动到指定设备。
    • pred = model(x):前向传播,计算模型输出。
    • train_bat_loss = loss(pred, target):计算批次损失。
    • train_bat_loss.backward():反向传播,计算梯度。
    • optimizer.step():更新模型参数。
    • optimizer.zero_grad():清空梯度,避免累积。
    • train_loss += train_bat_loss.cpu().item():累加批次损失。
    • train_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy()):计算并累加批次准确率。
处理半监督数据
if semi_loader is not None:
    for batch_x, batch_y in semi_loader:
        x, target = batch_x.to(device), batch_y.to(device)
        pred = model(x)
        semi_bat_loss = loss(pred, target)
        semi_bat_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += semi_bat_loss.cpu().item()
        train_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())
    print("半监督数据集的训练准确率为", train_acc / len(semi_loader.dataset))
  • if semi_loader is not None:如果存在半监督数据加载器,则处理这些数据。
    • for batch_x, batch_y in semi_loader:遍历半监督数据加载器中的每个批次。
    • x, target = batch_x.to(device), batch_y.to(device):将输入数据和标签移动到指定设备。
    • pred = model(x):前向传播,计算模型输出。
    • semi_bat_loss = loss(pred, target):计算批次损失。
    • semi_bat_loss.backward():反向传播,计算梯度。
    • optimizer.step():更新模型参数。
    • optimizer.zero_grad():清空梯度,避免累积。
    • train_loss += semi_bat_loss.cpu().item():累加批次损失。
    • train_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy()):计算并累加批次准确率。
    • print(“半监督数据集的训练准确率为”, train_acc / len(semi_loader.dataset)):打印半监督数据集的训练准确率。
验证过程
model.eval()
with torch.no_grad():
    for batch_x, batch_y in val_loader:
        x, target = batch_x.to(device), batch_y.to(device)
        pred = model(x)
        val_bat_loss = loss(pred, target)
        val_loss += val_bat_loss.cpu().item()
        val_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())
  • model.eval():设置模型为评估模式。
  • with torch.no_grad():禁用梯度计算,节省内存和计算资源。
  • for batch_x, batch_y in val_loader:遍历验证数据加载器中的每个批次。
    • x, target = batch_x.to(device), batch_y.to(device):将输入数据和标签移动到指定设备。
    • pred = model(x):前向传播,计算模型输出。
    • val_bat_loss = loss(pred, target):计算批次损失。
    • val_loss += val_bat_loss.cpu().item():累加批次损失。
    • val_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy()):计算并累加批次准确率。
更新半监督数据加载器和保存最佳模型
if epoch % 3 == 0 and plt_val_acc[-1] > 0.6:
    semi_loader = get_semi_loader(no_label_loader, model, device, thres)

if val_acc / len(val_loader.dataset) > max_acc:
    torch.save(model, save_path)
    max_acc = val_acc / len(val_loader.dataset)
  • if epoch % 3 == 0 and plt_val_acc[-1] > 0.6:每3个epoch检查一次是否需要更新半监督数据加载器。
    • semi_loader = get_semi_loader(no_label_loader, model, device, thres):调用 get_semi_loader 函数获取新的半监督数据加载器。
  • if val_acc / len(val_loader.dataset) > max_acc:如果当前验证准确率高于历史最高,则保存当前模型。
    • torch.save(model, save_path):保存模型到指定路径。
    • max_acc = val_acc / len(val_loader.dataset):更新最大验证准确率。
打印训练结果
print('[%03d/%03d] %2.2f sec(s) TrainLoss : %.6f | valLoss: %.6f Trainacc : %.6f | valacc: %.6f' % \
      (epoch, epochs, time.time() - start_time, plt_train_loss[-1], plt_val_loss[-1], plt_train_acc[-1], plt_val_acc[-1]))
  • print:打印每个epoch的训练和验证结果,包括epoch编号、耗时、训练损失、验证损失、训练准确率和验证准确率。
9、绘制损失和准确率曲线
plt.plot(plt_train_loss)
plt.plot(plt_val_loss)
plt.title("loss")
plt.legend(["train", "val"])
plt.show()

plt.plot(plt_train_acc)
plt.plot(plt_val_acc)
plt.title("acc")
plt.legend(["train", "val"])
plt.show()
  • plt.plot(plt_train_loss)plt.plot(plt_val_loss):绘制训练和验证的损失变化曲线。
  • plt.title(“loss”):设置图表标题为“loss”。
  • plt.legend([“train”, “val”]):添加图例,区分训练和验证曲线。
  • plt.show():显示图表。
  • plt.plot(plt_train_acc)plt.plot(plt_val_acc):绘制训练和验证的准确率变化曲线。
  • plt.title(“acc”):设置图表标题为“acc”。
  • plt.legend([“train”, “val”]):添加图例,区分训练和验证曲线。
  • plt.show():显示图表。
    好的,让我们详细解析你提供的代码段,并解释每个部分的功能和作用。

10、数据集路径设置与数据加载器初始化

train_path = r"F:\pycharm\beike\classification\food_classification\food-11_sample\training\labeled"
val_path = r"F:\pycharm\beike\classification\food_classification\food-11_sample\validation"
no_label_path = r"F:\pycharm\beike\classification\food_classification\food-11_sample\training\unlabeled\00"

train_set = food_Dataset(train_path, "train")
val_set = food_Dataset(val_path, "val")
no_label_set = food_Dataset(no_label_path, "semi")

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=True)
no_label_loader = DataLoader(no_label_set, batch_size=16, shuffle=False)
  • train_pathval_pathno_label_path:定义了训练集、验证集和未标记数据集的路径。
  • food_Dataset:自定义的数据集类,用于加载和预处理图像数据。它接受路径和模式(“train”、“val” 或 “semi”)作为参数。
    • train_set:创建一个训练数据集对象。
    • val_set:创建一个验证数据集对象。
    • no_label_set:创建一个未标记数据集对象。
  • DataLoader:PyTorch中的数据加载器类,用于批量加载数据。
    • train_loader:训练数据加载器,批次大小为16,且打乱数据。
    • val_loader:验证数据加载器,批次大小为16,且打乱数据。
    • no_label_loader:未标记数据加载器,批次大小为16,不打乱数据。

11、模型初始化

# model = myModel(11)
model, _ = initialize_model("vgg", 11, use_pretrained=True)
  • myModel(11):注释掉的行表示使用自定义模型 myModel,类别数为11。
  • initialize_model(“vgg”, 11, use_pretrained=True):调用一个函数来初始化预训练的VGG模型,类别数为11,并使用预训练权重。

12、超参数设置

lr = 0.001
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"
save_path = "model_save/best_model.pth"
epochs = 15
thres = 0.99
  • lr:学习率设置为0.001。
  • loss:损失函数使用交叉熵损失 nn.CrossEntropyLoss()
  • optimizer:优化器使用AdamW优化器 torch.optim.AdamW,并设置了学习率和权重衰减参数。
  • device:检查是否有可用的GPU,如果没有则使用CPU。
  • save_path:保存最佳模型的路径。
  • epochs:训练轮数设置为15。
  • thres:置信度阈值设置为0.99,用于半监督学习中选择高置信度样本。

13、训练和验证

train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path)
  • train_val:调用训练和验证函数,传入模型、数据加载器、设备类型、训练轮数、优化器、损失函数、置信度阈值和保存路径。
数据集加载
train_set = food_Dataset(train_path, "train")
val_set = food_Dataset(val_path, "val")
no_label_set = food_Dataset(no_label_path, "semi")

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=True)
no_label_loader = DataLoader(no_label_set, batch_size=16, shuffle=False)
  • food_Dataset:假设这是一个自定义的数据集类,负责读取和预处理图像数据。
  • DataLoader:用于高效加载数据,支持多线程和批处理。
模型初始化
model, _ = initialize_model("vgg", 11, use_pretrained=True)
  • initialize_model:假设这是另一个自定义函数,用于初始化预训练的VGG模型,并根据需要调整输出层以适应11个分类任务。
超参数配置
lr = 0.001
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"
save_path = "model_save/best_model.pth"
epochs = 15
thres = 0.99
  • lr:学习率设置为0.001。
  • loss:使用交叉熵损失函数,适用于多分类问题。
  • optimizer:使用AdamW优化器,结合了Adam的优点并添加了权重衰减,有助于防止过拟合。
  • device:自动检测并选择合适的计算设备(GPU或CPU)。
  • save_path:指定保存最佳模型的文件路径。
  • epochs:训练轮数设置为15。
  • thres:置信度阈值设置为0.99,用于筛选高质量伪标签样本。
调用训练和验证函数
train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path)
  • train_val:假设这是一个包含训练和验证逻辑的函数,负责在给定的训练轮数内迭代地训练模型,并在每个epoch结束后进行验证。

总结

这段代码实现了从数据集加载到模型训练和验证的完整流程,具体步骤包括:

  1. 数据集加载:通过自定义的 food_Dataset 类加载训练、验证和未标记数据集,并使用 DataLoader 进行批处理和数据打乱。
  2. 模型初始化:使用预训练的VGG模型,并根据任务需求调整输出层。
  3. 超参数配置:设置学习率、损失函数、优化器等超参数,并确定训练轮数和设备类型。
  4. 训练和验证:调用 train_val 函数执行训练过程,并在每个epoch结束后进行验证,保存最佳模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2312033.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

每日一练之移除链表元素

题目: 画图解析: 方法:双指针 解答代码(注:解答代码带解析): //题目给的结构体 /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* }…

力大砖飞,纯暴力搜索——蓝桥p2110(写着玩的)

#include<bits/stdc.h>const int N1000000;using namespace std;bool mp[2][N];int cnt0; int n;void dfs(int row,int col){cntcnt%1000000007;if(coln && row2){cnt;return ;}if(row>2){ //下一列 dfs(0,col1);return;}if(mp[row][col]1){ //下一行 dfs(row…

如何计算两个向量的余弦相似度

参考笔记&#xff1a; https://zhuanlan.zhihu.com/p/677639498 日常学习之&#xff1a;如何计算两个向量或者矩阵的余弦相似度-CSDN博客 1.余弦相似度定理 百度的解释&#xff1a;余弦相似度&#xff0c;又称为余弦相似性&#xff0c;是通过计算两个向量的夹角余弦值来评估…

OkHttp:工作原理 拦截器链深度解析

目录 一、OKHttp 的基本使用 1. 添加依赖 2. 发起 HTTP 请求 3. 拦截器&#xff08;Interceptor&#xff09; 4. 高级配置 二、OKHttp 核心原理 1. 责任链模式&#xff08;Interceptor Chain&#xff09; 2. 连接池&#xff08;ConnectionPool&#xff09; 3. 请求调度…

python: DDD+ORM using oracle 21c

sql script: create table GEOVINDU.School --創建表 ( SchoolId char(5) NOT NULL, -- SchoolName nvarchar2(500) NOT NULL, SchoolTelNo varchar(8) NULL, PRIMARY KEY (SchoolId) --#主鍵 );create table GEOVINDU.Teacher ( TeacherId char(5) NOT NULL , TeacherFirstNa…

基于 LeNet 网络的 MNIST 数据集图像分类

1.LeNet的原始实验数据集MNIST 名称&#xff1a;MNIST手写数字数据集 数据类型&#xff1a;灰度图 &#xff08;一通道&#xff09; 图像大小&#xff1a;28*28 类别数&#xff1a;10类&#xff08;数字0-9&#xff09; 1.通过torchvision.datasets.MNIST下载并保存到本地…

Day4 C语言与画面显示练习

文章目录 1. harib01a例程2. harib01b例程3. harib01e例程4. harib01f例程5. harib01h例程 1. harib01a例程 上一章主要是将画面搞成黑屏&#xff0c;如果期望做点什么图案&#xff0c;只需要再VRAM里写点什么就好了&#xff0c;使用nask汇编语言实现一个函数write_mem8&#…

一周热点-OpenAI 推出了 GPT-4.5,这可能是其最后一个非推理模型

在人工智能领域,大型语言模型一直是研究的热点。OpenAI 的 GPT 系列模型在自然语言处理方面取得了显著成就。GPT-4.5 是 OpenAI 在这一领域的又一力作,它在多个方面进行了升级和优化。 1 新模型的出现 GPT-4.5 目前作为研究预览版发布。与 OpenAI 最近的 o1 和 o3 模型不同,…

《UE5_C++多人TPS完整教程》学习笔记34 ——《P35 网络角色(Network Role)》

本文为B站系列教学视频 《UE5_C多人TPS完整教程》 —— 《P35 网络角色&#xff08;Network Role&#xff09;》 的学习笔记&#xff0c;该系列教学视频为计算机工程师、程序员、游戏开发者、作家&#xff08;Engineer, Programmer, Game Developer, Author&#xff09; Stephe…

手写简易Tomcat核心实现:深入理解Servlet容器原理

目录 一、Tomcat概况 1. tomcat全局图 2.项目结构概览 二、实现步骤详解 2.1 基础工具包&#xff08;com.qcby.util&#xff09; 2.1.1 ResponseUtil&#xff1a;HTTP响应生成工具 2.1.2 SearchClassUtil&#xff1a;类扫描工具 2.1.3 WebServlet&#xff1a;自定义注解…

mac本地安装运行Redis-单机

记录一下我以前用的连接服务器的跨平台SSH客户端。 因为还要准备毕设...... 服务器又过期了&#xff0c;只能把redis安装下载到本地了。 目录 1.github下载Redis 2.安装homebrew 3.更新GCC 4.自行安装Redis 5.通过 Homebrew 安装 Redis 安装地址&#xff1a;https://git…

【ThreeJS Basics 09】Debug

文章目录 简介从 dat.GUI 到 lil-gui例子安装 lil-gui 并实例化不同类型的调整改变位置针对非属性的调整复选框颜色 功能/按钮调整几何形状文件夹调整 GUI宽度标题关闭文件夹隐藏按键切换 结论 简介 每一个创意项目的一个基本方面是能够轻松调整。开发人员和参与项目的其他参与…

【笔记】STM32L4系列使用RT-Thread Studio电源管理组件(PM框架)实现低功耗

硬件平台&#xff1a;STM32L431RCT6 RT-Thread版本&#xff1a;4.1.0 目录 一.新建工程 二.配置工程 ​编辑 三.移植pm驱动 四.配置cubeMX 五.修改驱动文件&#xff0c;干掉报错 六.增加用户低功耗逻辑 1.设置唤醒方式 2.设置睡眠时以及唤醒后动作 ​编辑 3.增加测试命…

类和对象:

1. 类的定义&#xff1a; 1. 类定义格式&#xff1a; 对于我们的类的话&#xff0c;我们是把类看成一个整体&#xff0c;我们的函数里面没有找到我们的成员变量&#xff0c;我们就在我们的类里面找。 我们看我们的第二点&#xff1a; 我们的类里面&#xff0c;我们通常会对…

【十三】Golang 通道

&#x1f4a2;欢迎来到张胤尘的开源技术站 &#x1f4a5;开源如江河&#xff0c;汇聚众志成。代码似星辰&#xff0c;照亮行征程。开源精神长&#xff0c;传承永不忘。携手共前行&#xff0c;未来更辉煌&#x1f4a5; 文章目录 通道通道声明初始化缓冲机制无缓冲通道代码示例 带…

软考中级_【软件设计师】知识点之【面向对象】

简介&#xff1a; 软件设计师考试中&#xff0c;面向对象模块为核心考点&#xff0c;涵盖类与对象、继承、封装、多态等基础概念&#xff0c;重点考查UML建模&#xff08;类图/时序图/用例图&#xff09;、设计模式&#xff08;如工厂、单例模式&#xff09;及SOLID设计原则。要…

分布式锁—7.Curator的分布式锁一

大纲 1.Curator的可重入锁的源码 2.Curator的非可重入锁的源码 3.Curator的可重入读写锁的源码 4.Curator的MultiLock源码 5.Curator的Semaphore源码 1.Curator的可重入锁的源码 (1)InterProcessMutex获取分布式锁 (2)InterProcessMutex的初始化 (3)InterProcessMutex.…

《UE5_C++多人TPS完整教程》学习笔记35 ——《P36 武器类(Weapon Class)》

本文为B站系列教学视频 《UE5_C多人TPS完整教程》 —— 《P36 武器类&#xff08;Weapon Class&#xff09;》 的学习笔记&#xff0c;该系列教学视频为计算机工程师、程序员、游戏开发者、作家&#xff08;Engineer, Programmer, Game Developer, Author&#xff09; Stephen …

[密码学实战]Java实现国密TLSv1.3单向认证

一、代码运行结果 1.1 运行环境 1.2 运行结果 1.3 项目架构 二、TLS 协议基础与国密背景 2.1 TLS 协议的核心作用 TLS(Transport Layer Security) 是保障网络通信安全的加密协议,位于 TCP/IP 协议栈的应用层和传输层之间,提供: • 数据机密性:通过对称加密算法(如 AE…

最小栈 _ _

一&#xff1a;题目 二&#xff1a;思路 解释&#xff1a;一个栈名为st&#xff0c;其用来正常的出入栈&#xff0c;一个栈名为minst&#xff0c;其的栈顶元素一定是最小的元素 入栈&#xff1a;第一个元素&#xff0c;两个栈一起入&#xff0c;后面再入栈&#xff0c;只有入栈…