食物分类问题
simple_class
1. 导入必要的库和模块
import random
import torch
import torch.nn as nn
import numpy as np
import os
from PIL import Image #读取图片数据
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision import transforms
import time
import matplotlib.pyplot as plt
from model_utils.model import initialize_model
import random
: 导入Python标准库中的随机数生成器。import torch
: 导入PyTorch库,用于深度学习模型的构建和训练。import torch.nn as nn
: 导入PyTorch的神经网络模块,包含各种层和损失函数。import numpy as np
: 导入NumPy库,用于数值计算和数组操作。import os
: 导入操作系统接口模块,用于文件路径处理。from PIL import Image
: 导入PIL(Python Imaging Library)库,用于图像处理。from torch.utils.data import Dataset, DataLoader
: 导入PyTorch的数据集和数据加载器类,用于管理数据集和批量加载数据。from tqdm import tqdm
: 导入tqdm库,用于显示进度条。from torchvision import transforms
: 导入PyTorch的图像变换模块,用于对图像进行预处理。import time
: 导入时间模块,用于记录训练时间。import matplotlib.pyplot as plt
: 导入matplotlib库,用于绘制图表。from model_utils.model import initialize_model
: 导入自定义模块中的初始化模型函数。
2. 设置随机种子以确保结果可重复
def seed_everything(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
seed_everything(0)
def seed_everything(seed):
: 定义一个函数seed_everything
,用于设置所有可能影响随机性的种子。torch.manual_seed(seed)
: 设置PyTorch的CPU随机种子。torch.cuda.manual_seed(seed)
: 设置PyTorch的GPU随机种子。torch.cuda.manual_seed_all(seed)
: 如果有多个GPU,设置所有GPU的随机种子。torch.backends.cudnn.benchmark = False
: 关闭CuDNN自动优化功能,确保每次运行的结果一致。torch.backends.cudnn.deterministic = True
: 设置CuDNN为确定性模式,确保结果可重复。random.seed(seed)
: 设置Python内置随机数生成器的种子。np.random.seed(seed)
: 设置NumPy的随机数生成器的种子。os.environ['PYTHONHASHSEED'] = str(seed)
: 设置环境变量PYTHONHASHSEED
,确保哈希值的一致性。seed_everything(0)
: 调用seed_everything
函数,设置全局随机种子为0。
3. 定义图像变换
HW = 224
train_transform = transforms.Compose(
[
transforms.ToPILImage(), # 将numpy.ndarray转换为PIL.Image
transforms.RandomResizedCrop(224), # 随机裁剪并调整大小到224x224
transforms.RandomRotation(50), # 随机旋转角度在[-50, 50]之间
transforms.ToTensor() # 将PIL.Image转换为tensor
]
)
val_transform = transforms.Compose(
[
transforms.ToPILImage(), # 将numpy.ndarray转换为PIL.Image
transforms.ToTensor() # 将PIL.Image转换为tensor
]
)
HW = 224
: 定义图像的高度和宽度为224像素。train_transform
: 定义训练集的图像变换组合:transforms.ToPILImage()
: 将输入的numpy数组转换为PIL图像格式。transforms.RandomResizedCrop(224)
: 随机裁剪并调整大小到224x224像素。transforms.RandomRotation(50)
: 随机旋转图像,角度范围在[-50, 50]度之间。transforms.ToTensor()
: 将PIL图像转换为PyTorch张量(tensor),并将像素值归一化到[0, 1]区间。
val_transform
: 定义验证集的图像变换组合:transforms.ToPILImage()
: 将输入的numpy数组转换为PIL图像格式。transforms.ToTensor()
: 将PIL图像转换为PyTorch张量(tensor),并将像素值归一化到[0, 1]区间。
4. 自定义数据集类
class food_Dataset(Dataset):
def __init__(self, path, mode="train"):
self.mode = mode
if mode == "semi":
self.X = self.read_file(path)
else:
self.X, self.Y = self.read_file(path)
self.Y = torch.LongTensor(self.Y) # 标签转为长整形
if mode == "train":
self.transform = train_transform
else:
self.transform = val_transform
def read_file(self, path):
if self.mode == "semi":
file_list = os.listdir(path)
xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
for j, img_name in enumerate(file_list):
img_path = os.path.join(path, img_name)
img = Image.open(img_path)
img = img.resize((HW, HW))
xi[j, ...] = np.array(img)
print("读到了%d个数据" % len(xi))
return xi
else:
for i in tqdm(range(11)):
file_dir = path + "/%02d" % i
file_list = os.listdir(file_dir)
xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
yi = np.zeros(len(file_list), dtype=np.uint8)
for j, img_name in enumerate(file_list):
img_path = os.path.join(file_dir, img_name)
img = Image.open(img_path)
img = img.resize((HW, HW))
xi[j, ...] = np.array(img)
yi[j] = i
if i == 0:
X = xi
Y = yi
else:
X = np.concatenate((X, xi), axis=0)
Y = np.concatenate((Y, yi), axis=0)
print("读到了%d个数据" % len(Y))
return X, Y
def __getitem__(self, item):
if self.mode == "semi":
return self.transform(self.X[item]), self.X[item]
else:
return self.transform(self.X[item]), self.Y[item]
def __len__(self):
return len(self.X)
class food_Dataset(Dataset):
: 定义一个继承自Dataset
的自定义数据集类food_Dataset
。def __init__(self, path, mode="train"):
: 初始化方法,接受数据集路径和模式(默认为“train”)作为参数。self.mode = mode
: 记录数据集的模式。if mode == "semi":
: 如果是半监督模式,则仅读取未标记的图像数据。else:
: 否则,读取带有标签的图像数据,并将标签转换为长整型。if mode == "train":
: 如果是训练模式,使用train_transform
进行图像变换。else:
: 否则,使用val_transform
进行图像变换。
def read_file(self, path):
: 定义一个读取文件的方法,根据不同的模式读取图像数据。if self.mode == "semi":
: 如果是半监督模式,读取未标记的图像数据:file_list = os.listdir(path)
: 获取目录下的所有文件名。xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
: 创建一个零数组用于存储图像数据。for j, img_name in enumerate(file_list):
: 遍历每个文件名,打开图像并调整大小,然后将其存储在xi
中。print("读到了%d个数据" % len(xi))
: 打印读取到的图像数量。return xi
: 返回图像数据。
else:
: 否则,读取带有标签的图像数据:for i in tqdm(range(11)):
: 使用tqdm
显示进度条,遍历每个类别(假设共有11个类别)。file_dir = path + "/%02d" % i
: 构建类别目录路径。file_list = os.listdir(file_dir)
: 获取该类别目录下的所有文件名。xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
: 创建一个零数组用于存储图像数据。yi = np.zeros(len(file_list), dtype=np.uint8)
: 创建一个零数组用于存储标签。for j, img_name in enumerate(file_list):
: 遍历每个文件名,打开图像并调整大小,然后将其存储在xi
中,并将对应的标签存储在yi
中。if i == 0:
: 如果是第一个类别,初始化X
和Y
。else:
: 否则,将当前类别的图像和标签连接到已有数据中。print("读到了%d个数据" % len(Y))
: 打印读取到的图像数量。return X, Y
: 返回图像数据和标签。
def __getitem__(self, item):
: 定义获取指定索引的数据项的方法。if self.mode == "semi":
: 如果是半监督模式,返回变换后的图像及其原始图像。else:
: 否则,返回变换后的图像及其标签。
def __len__(self):
: 定义返回数据集长度的方法,即图像的数量。
5. 半监督数据集类
class semiDataset(Dataset):
def __init__(self, no_label_loader, model, device, thres=0.99):
x, y = self.get_label(no_label_loader, model, device, thres)
if x == []:
self.flag = False
else:
self.flag = True
self.X = np.array(x)
self.Y = torch.LongTensor(y)
self.transform = train_transform
def get_label(self, no_label_loader, model, device, thres):
model = model.to(device)
pred_prob = []
labels = []
x = []
y = []
soft = nn.Softmax(dim=1)
with torch.no_grad():
for bat_x, _ in no_label_loader:
bat_x = bat_x.to(device)
pred = model(bat_x)
pred_soft = soft(pred)
pred_max, pred_value = pred_soft.max(1)
pred_prob.extend(pred_max.cpu().numpy().tolist())
labels.extend(pred_value.cpu().numpy().tolist())
for index, prob in enumerate(pred_prob):
if prob > thres:
x.append(no_label_loader.dataset[index][0])
y.append(labels[index])
return x, y
def __getitem__(self, item):
return self.transform(self.X[item]), self.Y[item]
def __len__(self):
return len(self.X)
class semiDataset(Dataset):
: 定义一个继承自Dataset
的半监督数据集类semiDataset
。def __init__(self, no_label_loader, model, device, thres=0.99):
: 初始化方法,接受未标记数据加载器、模型、设备和置信度阈值作为参数。x, y = self.get_label(no_label_loader, model, device, thres)
: 调用get_label
方法获取高置信度的伪标签样本。if x == []:
: 如果没有找到符合条件的样本,设置flag
为False
。else:
: 否则,设置flag
为True
,并将样本数据和标签存储在self.X
和self.Y
中,并使用train_transform
进行图像变换。
def get_label(self, no_label_loader, model, device, thres):
: 定义一个获取伪标签的方法。model = model.to(device)
: 将模型移动到指定设备(CPU或GPU)。pred_prob = []
: 初始化预测概率列表。labels = []
: 初始化标签列表。x = []
: 初始化图像数据列表。y = []
: 初始化标签列表。soft = nn.Softmax(dim=1)
: 初始化Softmax函数,用于将模型输出转换为概率分布。with torch.no_grad():
: 禁用梯度计算,减少内存占用和加速推理。for bat_x, _ in no_label_loader:
: 遍历未标记数据加载器中的每个批次。bat_x = bat_x.to(device)
: 将输入数据移动到指定设备。pred = model(bat_x)
: 使用模型进行前向传播,得到预测结果。pred_soft = soft(pred)
: 使用Softmax函数将预测结果转换为概率分布。pred_max, pred_value = pred_soft.max(1)
: 获取每个样本的最大概率及其对应的类别。pred_prob.extend(pred_max.cpu().numpy().tolist())
: 将最大概率值添加到pred_prob
列表中。labels.extend(pred_value.cpu().numpy().tolist())
: 将对应的类别标签添加到labels
列表中。
for index, prob in enumerate(pred_prob):
: 遍历每个样本的概率值。if prob > thres:
: 如果概率值大于设定的阈值,则认为该样本的预测结果是可靠的。x.append(no_label_loader.dataset[index][0])
: 将该样本的图像数据添加到x
列表中。y.append(labels[index])
: 将该样本的预测标签添加到y
列表中。
return x, y
: 返回筛选出的图像数据和标签。
def __getitem__(self, item):
: 定义获取指定索引的数据项的方法,返回变换后的图像及其标签。def __len__(self):
: 定义返回数据集长度的方法,即图像的数量。
明白了,我们将从“6. 获取半监督数据加载器”继续逐句分析代码,并保持详细的解释风格。
6. 获取半监督数据加载器
def get_semi_loader(no_label_loder, model, device, thres):
semiset = semiDataset(no_label_loder, model, device, thres)
if semiset.flag == False:
return None
else:
semi_loader = DataLoader(semiset, batch_size=16, shuffle=False)
return semi_loader
- get_semi_loader:定义了一个函数,用于创建包含伪标签样本的数据加载器。
- no_label_loder:未标记数据的加载器。
- model:当前训练的模型,用于对未标记数据进行预测。
- device:设备类型(CPU或GPU)。
- thres:置信度阈值,用于选择高置信度样本。
- semiset:使用
semiDataset
类创建一个包含伪标签样本的数据集对象。 - if semiset.flag == False:如果
semiDataset
对象中没有满足条件的样本,则返回None
。 - else:否则,使用
DataLoader
创建一个新的数据加载器semi_loader
,批次大小为16且不打乱数据。
明白了,让我们重新详细解析 myModel
类的定义部分,并继续进入训练和验证函数的解析。
7. 定义模型
class myModel(nn.Module):
def __init__(self, num_class):
super(myModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.pool1 = nn.MaxPool2d(2)
self.layer1 = nn.Sequential(
nn.Conv2d(64, 128, 3, 1, 1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.layer2 = nn.Sequential(
nn.Conv2d(128, 256, 3, 1, 1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.layer3 = nn.Sequential(
nn.Conv2d(256, 512, 3, 1, 1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(25088, 1000)
self.relu2 = nn.ReLU()
self.fc2 = nn.Linear(1000, num_class)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.pool1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.pool2(x)
x = x.view(x.size()[0], -1)
x = self.fc1(x)
x = self.relu2(x)
x = self.fc2(x)
return x
__init__
方法
- super(myModel, self).init():调用父类
nn.Module
的构造函数。 - self.conv1:定义第一个卷积层,输入通道数为3(RGB图像),输出通道数为64,卷积核大小为3x3,步长为1,填充为1。
- self.bn1:定义第一个批量归一化层,用于归一化卷积层的输出。
- self.relu:定义ReLU激活函数。
- self.pool1:定义第一个最大池化层,池化窗口大小为2x2。
- self.layer1:定义第一个卷积块,包含一个卷积层、批量归一化层、ReLU激活函数和最大池化层。卷积层将输入通道数从64变为128。
- self.layer2:定义第二个卷积块,与
layer1
类似,但将输入通道数从128变为256。 - self.layer3:定义第三个卷积块,与
layer2
类似,但将输入通道数从256变为512。 - self.pool2:定义第二个最大池化层,池化窗口大小为2x2。
- self.fc1:定义第一个全连接层,输入特征数为25088(经过前面的卷积和池化操作后的特征图大小),输出特征数为1000。
- self.relu2:定义第二个ReLU激活函数。
- self.fc2:定义第二个全连接层,输入特征数为1000,输出特征数为
num_class
(类别数量)。
forward
方法
- def forward(self, x):定义前向传播过程。
- x = self.conv1(x):对输入数据
x
进行第一次卷积操作。 - x = self.bn1(x):对卷积结果进行批量归一化。
- x = self.relu(x):应用ReLU激活函数。
- x = self.pool1(x):对激活结果进行最大池化操作。
- x = self.layer1(x):通过第一个卷积块。
- x = self.layer2(x):通过第二个卷积块。
- x = self.layer3(x):通过第三个卷积块。
- x = self.pool2(x):对第三个卷积块的结果进行最大池化操作。
- x = x.view(x.size()[0], -1):将多维张量展平成二维张量,以便输入到全连接层中。
- x = self.fc1(x):通过第一个全连接层。
- x = self.relu2(x):应用ReLU激活函数。
- x = self.fc2(x):通过第二个全连接层,输出最终预测结果。
- return x:返回模型的预测结果。
- x = self.conv1(x):对输入数据
8. 训练和验证函数
接下来我们继续解析 train_val
函数:
def train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path):
model = model.to(device)
semi_loader = None
plt_train_loss = []
plt_val_loss = []
plt_train_acc = []
plt_val_acc = []
max_acc = 0.0
for epoch in range(epochs):
train_loss = 0.0
val_loss = 0.0
train_acc = 0.0
val_acc = 0.0
start_time = time.time()
初始化部分
- model = model.to(device):将模型移动到指定设备(CPU或GPU)。
- semi_loader = None:初始化半监督数据加载器为
None
。 - plt_train_loss、plt_val_loss、plt_train_acc、plt_val_acc:分别存储训练和验证的损失及准确率。
- max_acc = 0.0:初始化最大验证准确率为0.0。
每个epoch的训练循环
for epoch in range(epochs):
train_loss = 0.0
val_loss = 0.0
train_acc = 0.0
val_acc = 0.0
start_time = time.time()
model.train()
for batch_x, batch_y in train_loader:
x, target = batch_x.to(device), batch_y.to(device)
pred = model(x)
train_bat_loss = loss(pred, target)
train_bat_loss.backward()
optimizer.step()
optimizer.zero_grad()
train_loss += train_bat_loss.cpu().item()
train_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())
- for epoch in range(epochs):遍历每个epoch。
- train_loss = 0.0、val_loss = 0.0、train_acc = 0.0、val_acc = 0.0:初始化每个epoch的损失和准确率。
- start_time = time.time():记录当前epoch的开始时间。
- model.train():设置模型为训练模式。
- for batch_x, batch_y in train_loader:遍历训练数据加载器中的每个批次。
- x, target = batch_x.to(device), batch_y.to(device):将输入数据和标签移动到指定设备。
- pred = model(x):前向传播,计算模型输出。
- train_bat_loss = loss(pred, target):计算批次损失。
- train_bat_loss.backward():反向传播,计算梯度。
- optimizer.step():更新模型参数。
- optimizer.zero_grad():清空梯度,避免累积。
- train_loss += train_bat_loss.cpu().item():累加批次损失。
- train_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy()):计算并累加批次准确率。
处理半监督数据
if semi_loader is not None:
for batch_x, batch_y in semi_loader:
x, target = batch_x.to(device), batch_y.to(device)
pred = model(x)
semi_bat_loss = loss(pred, target)
semi_bat_loss.backward()
optimizer.step()
optimizer.zero_grad()
train_loss += semi_bat_loss.cpu().item()
train_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())
print("半监督数据集的训练准确率为", train_acc / len(semi_loader.dataset))
- if semi_loader is not None:如果存在半监督数据加载器,则处理这些数据。
- for batch_x, batch_y in semi_loader:遍历半监督数据加载器中的每个批次。
- x, target = batch_x.to(device), batch_y.to(device):将输入数据和标签移动到指定设备。
- pred = model(x):前向传播,计算模型输出。
- semi_bat_loss = loss(pred, target):计算批次损失。
- semi_bat_loss.backward():反向传播,计算梯度。
- optimizer.step():更新模型参数。
- optimizer.zero_grad():清空梯度,避免累积。
- train_loss += semi_bat_loss.cpu().item():累加批次损失。
- train_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy()):计算并累加批次准确率。
- print(“半监督数据集的训练准确率为”, train_acc / len(semi_loader.dataset)):打印半监督数据集的训练准确率。
验证过程
model.eval()
with torch.no_grad():
for batch_x, batch_y in val_loader:
x, target = batch_x.to(device), batch_y.to(device)
pred = model(x)
val_bat_loss = loss(pred, target)
val_loss += val_bat_loss.cpu().item()
val_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())
- model.eval():设置模型为评估模式。
- with torch.no_grad():禁用梯度计算,节省内存和计算资源。
- for batch_x, batch_y in val_loader:遍历验证数据加载器中的每个批次。
- x, target = batch_x.to(device), batch_y.to(device):将输入数据和标签移动到指定设备。
- pred = model(x):前向传播,计算模型输出。
- val_bat_loss = loss(pred, target):计算批次损失。
- val_loss += val_bat_loss.cpu().item():累加批次损失。
- val_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy()):计算并累加批次准确率。
更新半监督数据加载器和保存最佳模型
if epoch % 3 == 0 and plt_val_acc[-1] > 0.6:
semi_loader = get_semi_loader(no_label_loader, model, device, thres)
if val_acc / len(val_loader.dataset) > max_acc:
torch.save(model, save_path)
max_acc = val_acc / len(val_loader.dataset)
- if epoch % 3 == 0 and plt_val_acc[-1] > 0.6:每3个epoch检查一次是否需要更新半监督数据加载器。
- semi_loader = get_semi_loader(no_label_loader, model, device, thres):调用
get_semi_loader
函数获取新的半监督数据加载器。
- semi_loader = get_semi_loader(no_label_loader, model, device, thres):调用
- if val_acc / len(val_loader.dataset) > max_acc:如果当前验证准确率高于历史最高,则保存当前模型。
- torch.save(model, save_path):保存模型到指定路径。
- max_acc = val_acc / len(val_loader.dataset):更新最大验证准确率。
打印训练结果
print('[%03d/%03d] %2.2f sec(s) TrainLoss : %.6f | valLoss: %.6f Trainacc : %.6f | valacc: %.6f' % \
(epoch, epochs, time.time() - start_time, plt_train_loss[-1], plt_val_loss[-1], plt_train_acc[-1], plt_val_acc[-1]))
- print:打印每个epoch的训练和验证结果,包括epoch编号、耗时、训练损失、验证损失、训练准确率和验证准确率。
9、绘制损失和准确率曲线
plt.plot(plt_train_loss)
plt.plot(plt_val_loss)
plt.title("loss")
plt.legend(["train", "val"])
plt.show()
plt.plot(plt_train_acc)
plt.plot(plt_val_acc)
plt.title("acc")
plt.legend(["train", "val"])
plt.show()
- plt.plot(plt_train_loss) 和 plt.plot(plt_val_loss):绘制训练和验证的损失变化曲线。
- plt.title(“loss”):设置图表标题为“loss”。
- plt.legend([“train”, “val”]):添加图例,区分训练和验证曲线。
- plt.show():显示图表。
- plt.plot(plt_train_acc) 和 plt.plot(plt_val_acc):绘制训练和验证的准确率变化曲线。
- plt.title(“acc”):设置图表标题为“acc”。
- plt.legend([“train”, “val”]):添加图例,区分训练和验证曲线。
- plt.show():显示图表。
好的,让我们详细解析你提供的代码段,并解释每个部分的功能和作用。
10、数据集路径设置与数据加载器初始化
train_path = r"F:\pycharm\beike\classification\food_classification\food-11_sample\training\labeled"
val_path = r"F:\pycharm\beike\classification\food_classification\food-11_sample\validation"
no_label_path = r"F:\pycharm\beike\classification\food_classification\food-11_sample\training\unlabeled\00"
train_set = food_Dataset(train_path, "train")
val_set = food_Dataset(val_path, "val")
no_label_set = food_Dataset(no_label_path, "semi")
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=True)
no_label_loader = DataLoader(no_label_set, batch_size=16, shuffle=False)
- train_path、val_path 和 no_label_path:定义了训练集、验证集和未标记数据集的路径。
- food_Dataset:自定义的数据集类,用于加载和预处理图像数据。它接受路径和模式(“train”、“val” 或 “semi”)作为参数。
- train_set:创建一个训练数据集对象。
- val_set:创建一个验证数据集对象。
- no_label_set:创建一个未标记数据集对象。
- DataLoader:PyTorch中的数据加载器类,用于批量加载数据。
- train_loader:训练数据加载器,批次大小为16,且打乱数据。
- val_loader:验证数据加载器,批次大小为16,且打乱数据。
- no_label_loader:未标记数据加载器,批次大小为16,不打乱数据。
11、模型初始化
# model = myModel(11)
model, _ = initialize_model("vgg", 11, use_pretrained=True)
- myModel(11):注释掉的行表示使用自定义模型
myModel
,类别数为11。 - initialize_model(“vgg”, 11, use_pretrained=True):调用一个函数来初始化预训练的VGG模型,类别数为11,并使用预训练权重。
12、超参数设置
lr = 0.001
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"
save_path = "model_save/best_model.pth"
epochs = 15
thres = 0.99
- lr:学习率设置为0.001。
- loss:损失函数使用交叉熵损失
nn.CrossEntropyLoss()
。 - optimizer:优化器使用AdamW优化器
torch.optim.AdamW
,并设置了学习率和权重衰减参数。 - device:检查是否有可用的GPU,如果没有则使用CPU。
- save_path:保存最佳模型的路径。
- epochs:训练轮数设置为15。
- thres:置信度阈值设置为0.99,用于半监督学习中选择高置信度样本。
13、训练和验证
train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path)
- train_val:调用训练和验证函数,传入模型、数据加载器、设备类型、训练轮数、优化器、损失函数、置信度阈值和保存路径。
数据集加载
train_set = food_Dataset(train_path, "train")
val_set = food_Dataset(val_path, "val")
no_label_set = food_Dataset(no_label_path, "semi")
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=True)
no_label_loader = DataLoader(no_label_set, batch_size=16, shuffle=False)
- food_Dataset:假设这是一个自定义的数据集类,负责读取和预处理图像数据。
- DataLoader:用于高效加载数据,支持多线程和批处理。
模型初始化
model, _ = initialize_model("vgg", 11, use_pretrained=True)
- initialize_model:假设这是另一个自定义函数,用于初始化预训练的VGG模型,并根据需要调整输出层以适应11个分类任务。
超参数配置
lr = 0.001
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"
save_path = "model_save/best_model.pth"
epochs = 15
thres = 0.99
- lr:学习率设置为0.001。
- loss:使用交叉熵损失函数,适用于多分类问题。
- optimizer:使用AdamW优化器,结合了Adam的优点并添加了权重衰减,有助于防止过拟合。
- device:自动检测并选择合适的计算设备(GPU或CPU)。
- save_path:指定保存最佳模型的文件路径。
- epochs:训练轮数设置为15。
- thres:置信度阈值设置为0.99,用于筛选高质量伪标签样本。
调用训练和验证函数
train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path)
- train_val:假设这是一个包含训练和验证逻辑的函数,负责在给定的训练轮数内迭代地训练模型,并在每个epoch结束后进行验证。
总结
这段代码实现了从数据集加载到模型训练和验证的完整流程,具体步骤包括:
- 数据集加载:通过自定义的
food_Dataset
类加载训练、验证和未标记数据集,并使用DataLoader
进行批处理和数据打乱。 - 模型初始化:使用预训练的VGG模型,并根据任务需求调整输出层。
- 超参数配置:设置学习率、损失函数、优化器等超参数,并确定训练轮数和设备类型。
- 训练和验证:调用
train_val
函数执行训练过程,并在每个epoch结束后进行验证,保存最佳模型。