1. Introduction
Model distillation, also known as knowledge distillation, is a technique for transferring the knowledge of a large, complex model (the teacher model) to a small, simple model (the student model). The following covers what model distillation is, why it emerged, and what it is used for.
(1) An introduction to model distillation
- Basic concepts
- Teacher model: a large model that has already been trained and performs well.
- Student model: a smaller, simpler model whose goal is to learn the teacher model's behavior and knowledge.
- Soft labels: the probability distributions output by the teacher model, rather than plain class labels; these distributions carry rich information about how the teacher interprets the input data.
- Training procedure
- Train the teacher model until it reaches high accuracy.
- Use the teacher model's outputs (soft labels) to train the student model.
- The student model learns from both the hard labels (the true class labels) and the soft labels, and thereby mimics the teacher model's behavior.
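The procedure above can be sketched numerically. The following is a minimal pure-Python illustration (not the article's training code, and all logit values are made up): the teacher's logits are softened into soft labels, and the student's training loss is a weighted sum of a distillation term and a hard-label cross-entropy term, the same recipe the training scripts later in this article use.

```python
import math

def softmax(z, T=1.0):
    # Softmax with optional temperature T; T > 1 smooths the distribution
    exps = [math.exp(v / T) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def kl_div(p, q):
    # KL(p || q): how far the student's distribution q is from the teacher's p
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

T = 5.0                             # distillation temperature
teacher_logits = [6.0, 2.0, 1.0]    # made-up logits for one sample
student_logits = [4.0, 3.0, 1.5]
true_label = 0

p_teacher = softmax(teacher_logits, T)   # soft labels from the teacher
p_student = softmax(student_logits, T)
distill_loss = kl_div(p_teacher, p_student) * T * T          # scaled by T^2, as is conventional
hard_loss = -math.log(softmax(student_logits)[true_label])   # cross-entropy on the hard label
loss = 0.5 * distill_loss + 0.5 * hard_loss                  # weighted combination
print(round(loss, 4))
```

The 0.5/0.5 weighting mirrors the training code shown later; in practice the weights and the temperature are hyperparameters to tune.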
(2) Why model distillation emerged
Model distillation mainly arose to address the following problems:
- Model deployment: large models are difficult to run on mobile or embedded devices with limited compute resources.
- Computational efficiency: large models require substantial compute for training and inference, making them slow and expensive.
- Energy consumption: large models consume a great deal of power when running in data centers, which conflicts with energy-saving goals.
(3) What model distillation does
- Model compression: distillation compresses a large model into a small one, reducing the parameter count and lowering storage and compute requirements.
- Performance retention: the student model stays small while getting as close as possible to the teacher model's performance.
- Faster inference: a small model runs faster at inference time, suiting applications that need quick responses.
- Lower energy use: a small model consumes fewer compute resources at run time, helping to reduce energy consumption.
- Cross-model transfer: distillation can transfer knowledge from a model in one domain to a model in another, enabling cross-domain learning.
2. Preparing the Training Code
(1) Defining the model structure
import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()
        width = int(out_channel * (width_per_group / 64.)) * groups
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64
        self.groups = groups
        self.width_per_group = width_per_group
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))
        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion
        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
(2) Training code
temperature
- This parameter controls how much the teacher's and student's output logits are softened; in the code, temperature is set to 5.0.
- During distillation, the teacher and student logits are divided by the temperature; this softening helps the student model capture the teacher's probability distribution better during training.
- A higher temperature makes the probability distribution smoother, which helps the student model learn; a lower temperature makes it sharper and closer to hard labels.
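The effect of the temperature is easy to see directly. This is a small pure-Python illustration with made-up logits, not part of the training script:

```python
import math

def softmax(logits, T):
    # Divide the logits by the temperature T before normalizing
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [5.0, 2.0, 0.5]        # made-up logits for one sample
sharp = softmax(logits, T=1.0)  # sharp, close to a hard label
soft = softmax(logits, T=5.0)   # smoother "soft label" distribution
print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

At T=1 almost all of the probability mass sits on the top class; at T=5 the smaller classes keep noticeable mass, which is exactly the extra information the student learns from.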
loss_function
- This computes the distillation loss. The code uses nn.KLDivLoss, the Kullback-Leibler divergence loss, which measures the difference between two probability distributions. reduction='batchmean' means the summed loss is divided by the batch size. Note that nn.KLDivLoss expects log-probabilities as its first argument, which is why the student logits are passed through log_softmax in the training loop.
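The 'batchmean' reduction can be mimicked in a few lines of plain Python (made-up distributions, purely to show the arithmetic; the real nn.KLDivLoss additionally takes log-probabilities as its first argument):

```python
import math

def kl(p, q):
    # KL divergence KL(p || q) for a single sample's distribution pair
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

teacher_batch = [[0.7, 0.2, 0.1], [0.5, 0.3, 0.2]]  # made-up soft labels
student_batch = [[0.6, 0.3, 0.1], [0.4, 0.4, 0.2]]  # made-up student outputs
# reduction='batchmean': sum the divergence over the batch,
# then divide by the batch size
batchmean = sum(kl(p, q) for p, q in zip(teacher_batch, student_batch)) / len(teacher_batch)
print(round(batchmean, 4))
```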
student_loss_function
- This computes the student model's classification loss against the true labels. The code uses nn.CrossEntropyLoss, the standard loss function for multi-class classification.
loss and student_loss
- loss is the distillation loss, computed by comparing the softened student logits with the softened teacher logits. student_loss is the student model's classification loss on the true labels.
- The two losses are combined as a weighted average to form the final training loss; here the distillation loss and the classification loss are both weighted 0.5.
optimizer
- This optimizes the student model's parameters. The code uses optim.Adam, an optimization algorithm with adaptive learning rates; params is the list of the student model's trainable parameters.
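For reference, a single Adam update can be sketched in plain Python. This is a simplified scalar version of the algorithm, not torch's implementation, with the same default hyperparameters Adam commonly uses:

```python
import math

def adam_step(theta, grad, m, v, t, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8):
    # One Adam update for a single scalar parameter theta
    m = b1 * m + (1 - b1) * grad           # first-moment (mean) estimate
    v = b2 * v + (1 - b2) * grad * grad    # second-moment (uncentered variance) estimate
    m_hat = m / (1 - b1 ** t)              # bias correction for step t
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = 1.0, 0.0, 0.0
theta, m, v = adam_step(theta, grad=0.5, m=m, v=v, t=1)
print(theta)  # slightly below 1.0: the step size adapts to the gradient statistics
```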
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from torchvision import models
from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # image_path = os.path.join(data_root, "data_set", "flower_data")
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Load teacher model
    teacher_net = resnet34(num_classes=5)
    teacher_model_weight_path = "resNet34.pth"
    assert os.path.exists(teacher_model_weight_path), f"File '{teacher_model_weight_path}' does not exist."
    teacher_net.load_state_dict(torch.load(teacher_model_weight_path, map_location="cpu"), strict=False)
    teacher_net.to(device)
    teacher_net.eval()  # the teacher is not trained; keep BatchNorm statistics frozen

    # Load student model
    student_net = models.resnet18(pretrained=False)
    student_model_weight_path = "resnet18-f37072fd.pth"
    assert os.path.exists(student_model_weight_path), "file {} does not exist.".format(student_model_weight_path)
    student_net.load_state_dict(torch.load(student_model_weight_path, map_location="cpu"))
    student_net.fc = nn.Linear(student_net.fc.in_features, 5)
    student_net.to(device)

    # Distillation loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')
    student_loss_function = nn.CrossEntropyLoss()

    # Optimizer for the student model
    params = [p for p in student_net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './distilled_resnet18.pth'
    train_steps = len(train_loader)
    temperature = 5.0  # Temperature for distillation
    for epoch in range(epochs):
        student_net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            with torch.no_grad():  # no gradients are needed for the teacher
                teacher_logits = teacher_net(images.to(device))
            student_logits = student_net(images.to(device))
            # Soften the logits with the temperature
            soft_teacher = teacher_logits / temperature
            soft_student = student_logits / temperature
            # Compute the distillation loss (KLDivLoss expects log-probabilities as input)
            distill_loss = loss_function(torch.nn.functional.log_softmax(soft_student, dim=1),
                                         torch.nn.functional.softmax(soft_teacher, dim=1)) * (temperature ** 2)
            # Compute the classification loss on the unsoftened logits
            student_loss = student_loss_function(student_logits, labels.to(device))
            # Combine losses
            loss = 0.5 * distill_loss + 0.5 * student_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        student_net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = student_net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(student_net.state_dict(), save_path)
    print('Finished Training')


if __name__ == '__main__':
    main()
(3) Download links for the models and dataset
Includes the resnet18 and resnet34 models, class_indices.json, images, and other related data.
https://pan.baidu.com/s/1ZDCbichDcdaiAH6kxYNsIA
Extraction code: svv5
3. Training a Custom Model and Training It with Distillation
(1) Model structure - model_10.py
import torch
from torch import nn


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # Convolutional feature extractor: five conv + max-pool stages
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # 3 input channels, 32 output channels
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))  # adaptive average pooling
        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 1 * 1, 1024),  # 512 channels, pooled to 1x1 by the adaptive pooling
            nn.ReLU(),
            nn.Linear(1024, 5)  # output layer, 5 classes
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.adaptive_pool(x)  # apply adaptive pooling
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x
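A quick sanity check on the flattened feature dimension, as plain arithmetic (assuming the 224x224 crops produced by the training transforms):

```python
size = 224              # input height/width after RandomResizedCrop(224)
for stage in range(5):  # five MaxPool2d(kernel_size=2, stride=2) layers
    size //= 2          # each pool halves the spatial size
print(size)             # prints 7: final conv feature maps are 512 x 7 x 7
# AdaptiveAvgPool2d((1, 1)) then reduces this to 1x1, so the first Linear
# layer always sees 512 * 1 * 1 features, whatever the input resolution.
```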
(2) Training the custom model - train-10.py
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model_10 import ConvNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    net = ConvNet()
    weights_path = "ConvNet.pth"
    assert os.path.exists(weights_path), f"File '{weights_path}' does not exist."
    # model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    state_dict = torch.load(weights_path, map_location="cpu")
    net.load_state_dict(state_dict, strict=False)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './ConvNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' % (epoch + 1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    print('Finished Training')


if __name__ == '__main__':
    main()
(3) Training results
Results after 60 epochs of training in total (the log below shows the final 30-epoch run); val_accuracy: 0.780 was the best the model reached.
train epoch[1/30] loss:0.971: 100%|██████████| 207/207 [00:08<00:00, 24.01it/s]
valid epoch[1/30]: 100%|██████████| 23/23 [00:00<00:00, 31.44it/s]
[epoch 1] train_loss: 0.623 val_accuracy: 0.742
train epoch[2/30] loss:0.368: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[2/30]: 100%|██████████| 23/23 [00:00<00:00, 33.18it/s]
[epoch 2] train_loss: 0.604 val_accuracy: 0.736
train epoch[3/30] loss:0.661: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[3/30]: 100%|██████████| 23/23 [00:00<00:00, 32.38it/s]
[epoch 3] train_loss: 0.614 val_accuracy: 0.723
train epoch[4/30] loss:0.797: 100%|██████████| 207/207 [00:07<00:00, 26.66it/s]
valid epoch[4/30]: 100%|██████████| 23/23 [00:00<00:00, 31.70it/s]
[epoch 4] train_loss: 0.619 val_accuracy: 0.725
train epoch[5/30] loss:0.809: 100%|██████████| 207/207 [00:07<00:00, 26.87it/s]
valid epoch[5/30]: 100%|██████████| 23/23 [00:00<00:00, 32.26it/s]
[epoch 5] train_loss: 0.594 val_accuracy: 0.698
train epoch[6/30] loss:0.302: 100%|██████████| 207/207 [00:07<00:00, 26.81it/s]
valid epoch[6/30]: 100%|██████████| 23/23 [00:00<00:00, 32.49it/s]
[epoch 6] train_loss: 0.591 val_accuracy: 0.728
train epoch[7/30] loss:0.708: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[7/30]: 100%|██████████| 23/23 [00:00<00:00, 33.09it/s]
[epoch 7] train_loss: 0.589 val_accuracy: 0.720
train epoch[8/30] loss:0.709: 100%|██████████| 207/207 [00:07<00:00, 26.73it/s]
valid epoch[8/30]: 100%|██████████| 23/23 [00:00<00:00, 32.55it/s]
[epoch 8] train_loss: 0.575 val_accuracy: 0.734
train epoch[9/30] loss:0.691: 100%|██████████| 207/207 [00:07<00:00, 26.61it/s]
valid epoch[9/30]: 100%|██████████| 23/23 [00:00<00:00, 34.43it/s]
[epoch 9] train_loss: 0.555 val_accuracy: 0.734
train epoch[10/30] loss:0.442: 100%|██████████| 207/207 [00:07<00:00, 26.81it/s]
valid epoch[10/30]: 100%|██████████| 23/23 [00:00<00:00, 32.91it/s]
[epoch 10] train_loss: 0.548 val_accuracy: 0.703
train epoch[11/30] loss:0.363: 100%|██████████| 207/207 [00:07<00:00, 26.46it/s]
valid epoch[11/30]: 100%|██████████| 23/23 [00:00<00:00, 30.53it/s]
[epoch 11] train_loss: 0.550 val_accuracy: 0.728
train epoch[12/30] loss:0.519: 100%|██████████| 207/207 [00:07<00:00, 26.19it/s]
valid epoch[12/30]: 100%|██████████| 23/23 [00:00<00:00, 33.14it/s]
[epoch 12] train_loss: 0.545 val_accuracy: 0.734
train epoch[13/30] loss:0.478: 100%|██████████| 207/207 [00:07<00:00, 26.48it/s]
valid epoch[13/30]: 100%|██████████| 23/23 [00:00<00:00, 32.75it/s]
[epoch 13] train_loss: 0.532 val_accuracy: 0.755
train epoch[14/30] loss:0.573: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[14/30]: 100%|██████████| 23/23 [00:00<00:00, 33.40it/s]
[epoch 14] train_loss: 0.542 val_accuracy: 0.747
train epoch[15/30] loss:0.595: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[15/30]: 100%|██████████| 23/23 [00:00<00:00, 34.54it/s]
[epoch 15] train_loss: 0.542 val_accuracy: 0.758
train epoch[16/30] loss:0.191: 100%|██████████| 207/207 [00:07<00:00, 26.83it/s]
valid epoch[16/30]: 100%|██████████| 23/23 [00:00<00:00, 32.04it/s]
[epoch 16] train_loss: 0.532 val_accuracy: 0.761
train epoch[17/30] loss:0.566: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[17/30]: 100%|██████████| 23/23 [00:00<00:00, 33.56it/s]
[epoch 17] train_loss: 0.523 val_accuracy: 0.739
train epoch[18/30] loss:0.509: 100%|██████████| 207/207 [00:07<00:00, 26.79it/s]
valid epoch[18/30]: 100%|██████████| 23/23 [00:00<00:00, 30.35it/s]
[epoch 18] train_loss: 0.526 val_accuracy: 0.742
train epoch[19/30] loss:0.781: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[19/30]: 100%|██████████| 23/23 [00:00<00:00, 31.58it/s]
[epoch 19] train_loss: 0.506 val_accuracy: 0.764
train epoch[20/30] loss:0.336: 100%|██████████| 207/207 [00:07<00:00, 26.64it/s]
valid epoch[20/30]: 100%|██████████| 23/23 [00:00<00:00, 33.95it/s]
[epoch 20] train_loss: 0.537 val_accuracy: 0.764
train epoch[21/30] loss:0.475: 100%|██████████| 207/207 [00:07<00:00, 26.65it/s]
valid epoch[21/30]: 100%|██████████| 23/23 [00:00<00:00, 33.27it/s]
[epoch 21] train_loss: 0.511 val_accuracy: 0.764
train epoch[22/30] loss:0.513: 100%|██████████| 207/207 [00:07<00:00, 26.53it/s]
valid epoch[22/30]: 100%|██████████| 23/23 [00:00<00:00, 32.16it/s]
[epoch 22] train_loss: 0.482 val_accuracy: 0.761
train epoch[23/30] loss:0.172: 100%|██████████| 207/207 [00:07<00:00, 26.62it/s]
valid epoch[23/30]: 100%|██████████| 23/23 [00:00<00:00, 33.02it/s]
[epoch 23] train_loss: 0.501 val_accuracy: 0.761
train epoch[24/30] loss:1.127: 100%|██████████| 207/207 [00:07<00:00, 26.54it/s]
valid epoch[24/30]: 100%|██████████| 23/23 [00:00<00:00, 34.24it/s]
[epoch 24] train_loss: 0.492 val_accuracy: 0.755
train epoch[25/30] loss:0.905: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[25/30]: 100%|██████████| 23/23 [00:00<00:00, 30.22it/s]
[epoch 25] train_loss: 0.492 val_accuracy: 0.758
train epoch[26/30] loss:1.044: 100%|██████████| 207/207 [00:07<00:00, 26.75it/s]
valid epoch[26/30]: 100%|██████████| 23/23 [00:00<00:00, 33.86it/s]
[epoch 26] train_loss: 0.476 val_accuracy: 0.777
train epoch[27/30] loss:0.552: 100%|██████████| 207/207 [00:07<00:00, 26.73it/s]
valid epoch[27/30]: 100%|██████████| 23/23 [00:00<00:00, 31.55it/s]
[epoch 27] train_loss: 0.465 val_accuracy: 0.745
train epoch[28/30] loss:0.387: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[28/30]: 100%|██████████| 23/23 [00:00<00:00, 32.30it/s]
[epoch 28] train_loss: 0.482 val_accuracy: 0.769
train epoch[29/30] loss:0.251: 100%|██████████| 207/207 [00:07<00:00, 26.69it/s]
valid epoch[29/30]: 100%|██████████| 23/23 [00:00<00:00, 32.98it/s]
[epoch 29] train_loss: 0.466 val_accuracy: 0.777
train epoch[30/30] loss:0.368: 100%|██████████| 207/207 [00:07<00:00, 26.57it/s]
valid epoch[30/30]: 100%|██████████| 23/23 [00:00<00:00, 31.95it/s]
[epoch 30] train_loss: 0.467 val_accuracy: 0.780
Finished Training
(4) Distillation training
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34
from model_10 import ConvNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # image_path = os.path.join(data_root, "data_set", "flower_data")
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Load teacher model
    teacher_net = resnet34(num_classes=5)
    teacher_model_weight_path = "resNet34.pth"
    assert os.path.exists(teacher_model_weight_path), f"File '{teacher_model_weight_path}' does not exist."
    teacher_net.load_state_dict(torch.load(teacher_model_weight_path, map_location="cpu"), strict=False)
    teacher_net.to(device)
    teacher_net.eval()  # the teacher is not trained; keep BatchNorm statistics frozen

    # Load student model
    student_net = ConvNet()
    student_model_weight_path = "ConvNet.pth"
    assert os.path.exists(student_model_weight_path), "file {} does not exist.".format(student_model_weight_path)
    student_net.load_state_dict(torch.load(student_model_weight_path, map_location="cpu"))
    student_net.to(device)

    # Distillation loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')
    student_loss_function = nn.CrossEntropyLoss()

    # Optimizer for the student model
    params = [p for p in student_net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './distilled_ConvNet.pth'
    train_steps = len(train_loader)
    temperature = 5.0  # Temperature for distillation
    for epoch in range(epochs):
        student_net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            with torch.no_grad():  # no gradients are needed for the teacher
                teacher_logits = teacher_net(images.to(device))
            student_logits = student_net(images.to(device))
            # Soften the logits with the temperature
            soft_teacher = teacher_logits / temperature
            soft_student = student_logits / temperature
            # Compute the distillation loss (KLDivLoss expects log-probabilities as input)
            distill_loss = loss_function(torch.nn.functional.log_softmax(soft_student, dim=1),
                                         torch.nn.functional.softmax(soft_teacher, dim=1)) * (temperature ** 2)
            # Compute the classification loss on the unsoftened logits
            student_loss = student_loss_function(student_logits, labels.to(device))
            # Combine losses
            loss = 0.5 * distill_loss + 0.5 * student_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        student_net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = student_net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(student_net.state_dict(), save_path)
    print('Finished Training')


if __name__ == '__main__':
    main()
No screenshots were taken for this run; feel free to try it yourself. In my test, training the custom model for 30 epochs and then continuing with 30 epochs of distillation training brought val_accuracy up to 0.81.
(5) Model files
https://pan.baidu.com/s/1gVTJPvAQ3oDEZcGYoJvuLw
Extraction code: ddk5
4. Summary
If a model's structure is simple, distillation training can be used to improve its accuracy; of course, a teacher model has to be trained first.