Model Compression and Transfer: A Hands-On Tutorial on Knowledge Distillation


1. Introduction

        Model distillation (also known as knowledge distillation) is a technique for transferring the knowledge of a large, complex model (the teacher model) to a small, simple model (the student model). The following covers what model distillation is, why it emerged, and what it is used for:

(1) What is model distillation

  1. Basic concepts

    • Teacher model: a large model that has already been trained and performs well.
    • Student model: a smaller, simpler model whose goal is to learn the teacher model's behavior and knowledge.
    • Soft labels: the probability distributions output by the teacher model, rather than plain class labels; these distributions carry rich information about how the teacher views the input data.
  2. Training process

    • Train the teacher model until it reaches high accuracy.
    • Use the teacher model's outputs (soft labels) to train the student model.
    • The student model learns from both the hard labels (the true class labels) and the soft labels, so as to mimic the teacher model's behavior (see the sketch below).
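
To make the recipe concrete, here is a minimal sketch of a combined distillation loss. The tensor names, the temperature T, and the weight alpha are illustrative assumptions, not values fixed by the method:

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=5.0, alpha=0.5):
    # Soft-label term: KL divergence between the temperature-softened
    # distributions, scaled by T**2 so its gradient magnitude stays
    # comparable as T changes
    soft_loss = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                         F.softmax(teacher_logits / T, dim=1),
                         reduction='batchmean') * (T ** 2)
    # Hard-label term: ordinary cross-entropy against the true classes
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss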

(2) Why model distillation emerged

        Model distillation emerged mainly to address the following problems:

  1. Model deployment: large models are hard to run on mobile devices or embedded systems, where compute resources are limited.
  2. Computational efficiency: large models require substantial compute for training and inference, making them slow and expensive.
  3. Energy consumption: large models consume a great deal of power when run in data centers, which runs counter to energy-saving goals.

(3) What model distillation provides

  1. Model compression: distillation compresses a large model into a small one, reducing the parameter count and lowering storage and compute requirements.
  2. Performance retention: the student model stays small while matching the teacher model's performance as closely as possible.
  3. Faster inference: a small model is quicker at inference time, which suits applications that need fast responses.
  4. Lower energy use: a small model consumes fewer compute resources at run time, helping to cut energy consumption.
  5. Cross-model transfer: distillation can also move knowledge from a model in one domain to a model in another, enabling cross-domain learning.

2. Preparing the Training Code

(1) Defining the model structure

import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
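
As a quick sanity check of the structure above (an illustrative snippet, assuming the code is saved as model.py), the model can be run on a dummy batch:

import torch
from model import resnet34

net = resnet34(num_classes=5)
dummy = torch.randn(2, 3, 224, 224)  # batch of two 224x224 RGB images
print(net(dummy).shape)  # expected: torch.Size([2, 5])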

(2) Training code

temperature

  • This parameter controls how much the teacher's and student's output logits are softened. In the code, temperature is set to 5.0.
  • During distillation, both the teacher's and the student's logits are divided by the temperature; this helps the student capture the teacher's full probability distribution during training.
  • A higher temperature makes the distribution smoother, which helps the student learn; a lower temperature makes it sharper and closer to hard labels (see the short illustration below).
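
A small, self-contained illustration of the temperature effect (the logits here are made up):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])
for T in [1.0, 5.0, 20.0]:
    # Higher T flattens the distribution, exposing the relative weight
    # the model puts on non-target classes; lower T sharpens it
    print(T, F.softmax(logits / T, dim=0))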

loss_function

  • This function computes the distillation loss. The code uses nn.KLDivLoss, the Kullback-Leibler divergence loss, which measures the difference between two probability distributions.
  • reduction='batchmean' means the loss is averaged over all samples in the batch.

student_loss_function

  • This computes the student model's classification loss against the true labels. The code uses nn.CrossEntropyLoss, the standard loss for multi-class classification.

loss and student_loss

  • loss is the distillation loss, computed by comparing the softened student logits with the softened teacher logits.
  • student_loss is the student model's classification loss on the true labels.
  • The two losses are combined as a weighted sum to form the final training loss; here both the distillation loss and the classification loss are weighted 0.5.

optimizer

  • This optimizer updates the student model's parameters. The code uses optim.Adam, an optimization algorithm with adaptive learning rates.
  • params is the list of student-model parameters that require gradients.
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from torchvision import models
from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # image_path = os.path.join(data_root, "data_set", "flower_data")
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # Load teacher model
    teacher_net = resnet34(num_classes=5)
    teacher_model_weight_path = "resNet34.pth"
    assert os.path.exists(teacher_model_weight_path), f"File '{teacher_model_weight_path}' does not exist."
    teacher_net.load_state_dict(torch.load(teacher_model_weight_path, map_location="cpu"), strict=False)
    teacher_net.to(device)
    teacher_net.eval()  # freeze BatchNorm/Dropout; the teacher is not trained here


    # Load student model
    student_net = models.resnet18(weights=None)  # torchvision >= 0.13; use pretrained=False on older versions
    student_model_weight_path = "resnet18-f37072fd.pth"
    assert os.path.exists(student_model_weight_path), "file {} does not exist.".format(student_model_weight_path)
    student_net.load_state_dict(torch.load(student_model_weight_path, map_location="cpu"))
    student_net.fc = nn.Linear(student_net.fc.in_features, 5)  # replace the 1000-class head with a 5-class one
    student_net.to(device)
    

    # Distillation loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')
    student_loss_function = nn.CrossEntropyLoss()

    # Optimizer for the student model
    params = [p for p in student_net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './distilled_ConvNet.pth'
    train_steps = len(train_loader)
    temperature = 5.0  # Temperature for distillation

    for epoch in range(epochs):
        student_net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            with torch.no_grad():  # the teacher is frozen; no gradients needed
                teacher_logits = teacher_net(images.to(device))
            student_logits = student_net(images.to(device))

            # Soften the logits with the distillation temperature
            soft_teacher = teacher_logits / temperature
            soft_student = student_logits / temperature

            # Compute the distillation loss, scaled by T^2 to keep
            # gradient magnitudes comparable across temperatures
            loss = loss_function(torch.nn.functional.log_softmax(soft_student, dim=1),
                                 torch.nn.functional.softmax(soft_teacher, dim=1)) * (temperature ** 2)

            # Compute the classification loss on the original (unsoftened) logits
            student_loss = student_loss_function(student_logits, labels.to(device))

            # Combine losses
            loss = 0.5 * loss + 0.5 * student_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        student_net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = student_net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(student_net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
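
After training, the saved student weights can be loaded back for inference. A minimal sketch, assuming the same 5-class ResNet-18 head and the class_indices.json written by the script above:

import json
import torch
import torch.nn as nn
from torchvision import models

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = models.resnet18(weights=None)
net.fc = nn.Linear(net.fc.in_features, 5)
net.load_state_dict(torch.load('./distilled_ConvNet.pth', map_location=device))
net.to(device).eval()

with open('class_indices.json') as f:
    idx_to_class = json.load(f)  # maps "0".."4" to flower class names

# `batch` is assumed to be an (N, 3, 224, 224) tensor preprocessed with
# the same "val" transforms as in the training script:
# preds = net(batch.to(device)).argmax(dim=1)
# print([idx_to_class[str(i.item())] for i in preds])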

(3) Download links for the models and dataset

        Includes the resnet18 and resnet34 models, class_indices.json, the image data, and other related files.

https://pan.baidu.com/s/1ZDCbichDcdaiAH6kxYNsIA

Extraction code: svv5

3. Training a Self-Built Model with Distillation

(1) Model structure - model_10.py

import torch
from torch import nn


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        
        # Convolutional feature extractor: five Conv2d + MaxPool2d blocks
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # 3 input channels, 32 output channels
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))  # adaptive average pooling to 1x1
        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 1 * 1, 1024),  # 512 channels after adaptive pooling to 1x1
            nn.ReLU(),
            nn.Linear(1024, 5)  # output layer, 5 classes
        )
    
    def forward(self, x):
        x = self.conv_layers(x)
        x = self.adaptive_pool(x)  # apply adaptive pooling
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x
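
To see the compression angle concretely, the parameter counts of teacher and student can be compared (an illustrative check, assuming the files are named model.py and model_10.py as above):

from model import resnet34
from model_10 import ConvNet

def count_params(m):
    # total number of parameters, trainable or not
    return sum(p.numel() for p in m.parameters())

print(f"teacher (resnet34): {count_params(resnet34(num_classes=5)):,} parameters")
print(f"student (ConvNet):  {count_params(ConvNet()):,} parameters")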

(2) Training the self-built model - train-10.py

import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model_10 import ConvNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    
    
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    
    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    
    net = ConvNet()
    weights_path = "ConvNet.pth"
    # Resume from an earlier checkpoint if one exists (the script saves to the
    # same path, so a second run continues training from the previous best)
    if os.path.exists(weights_path):
        state_dict = torch.load(weights_path, map_location="cpu")
        net.load_state_dict(state_dict, strict=False)
    net.to(device)
    
    # define loss function
    loss_function = nn.CrossEntropyLoss()
    
    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)
    
    epochs = 30
    best_acc = 0.0
    save_path = './ConvNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()
            
            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)
        
        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)

        
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))
        
        
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    
    print('Finished Training')


if __name__ == '__main__':
    main()

(3) Training results

        Results after training for 60 epochs in total (the script resumes from ConvNet.pth, so the 30-epoch log below is the second run); a val_accuracy of 0.780 was the best the model reached.

train epoch[1/30] loss:0.971: 100%|██████████| 207/207 [00:08<00:00, 24.01it/s]
valid epoch[1/30]: 100%|██████████| 23/23 [00:00<00:00, 31.44it/s]
[epoch 1] train_loss: 0.623  val_accuracy: 0.742
train epoch[2/30] loss:0.368: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[2/30]: 100%|██████████| 23/23 [00:00<00:00, 33.18it/s]
[epoch 2] train_loss: 0.604  val_accuracy: 0.736
train epoch[3/30] loss:0.661: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[3/30]: 100%|██████████| 23/23 [00:00<00:00, 32.38it/s]
[epoch 3] train_loss: 0.614  val_accuracy: 0.723
train epoch[4/30] loss:0.797: 100%|██████████| 207/207 [00:07<00:00, 26.66it/s]
valid epoch[4/30]: 100%|██████████| 23/23 [00:00<00:00, 31.70it/s]
[epoch 4] train_loss: 0.619  val_accuracy: 0.725
train epoch[5/30] loss:0.809: 100%|██████████| 207/207 [00:07<00:00, 26.87it/s]
valid epoch[5/30]: 100%|██████████| 23/23 [00:00<00:00, 32.26it/s]
[epoch 5] train_loss: 0.594  val_accuracy: 0.698
train epoch[6/30] loss:0.302: 100%|██████████| 207/207 [00:07<00:00, 26.81it/s]
valid epoch[6/30]: 100%|██████████| 23/23 [00:00<00:00, 32.49it/s]
[epoch 6] train_loss: 0.591  val_accuracy: 0.728
train epoch[7/30] loss:0.708: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[7/30]: 100%|██████████| 23/23 [00:00<00:00, 33.09it/s]
[epoch 7] train_loss: 0.589  val_accuracy: 0.720
train epoch[8/30] loss:0.709: 100%|██████████| 207/207 [00:07<00:00, 26.73it/s]
valid epoch[8/30]: 100%|██████████| 23/23 [00:00<00:00, 32.55it/s]
[epoch 8] train_loss: 0.575  val_accuracy: 0.734
train epoch[9/30] loss:0.691: 100%|██████████| 207/207 [00:07<00:00, 26.61it/s]
valid epoch[9/30]: 100%|██████████| 23/23 [00:00<00:00, 34.43it/s]
[epoch 9] train_loss: 0.555  val_accuracy: 0.734
train epoch[10/30] loss:0.442: 100%|██████████| 207/207 [00:07<00:00, 26.81it/s]
valid epoch[10/30]: 100%|██████████| 23/23 [00:00<00:00, 32.91it/s]
[epoch 10] train_loss: 0.548  val_accuracy: 0.703
train epoch[11/30] loss:0.363: 100%|██████████| 207/207 [00:07<00:00, 26.46it/s]
valid epoch[11/30]: 100%|██████████| 23/23 [00:00<00:00, 30.53it/s]
[epoch 11] train_loss: 0.550  val_accuracy: 0.728
train epoch[12/30] loss:0.519: 100%|██████████| 207/207 [00:07<00:00, 26.19it/s]
valid epoch[12/30]: 100%|██████████| 23/23 [00:00<00:00, 33.14it/s]
[epoch 12] train_loss: 0.545  val_accuracy: 0.734
train epoch[13/30] loss:0.478: 100%|██████████| 207/207 [00:07<00:00, 26.48it/s]
valid epoch[13/30]: 100%|██████████| 23/23 [00:00<00:00, 32.75it/s]
[epoch 13] train_loss: 0.532  val_accuracy: 0.755
train epoch[14/30] loss:0.573: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[14/30]: 100%|██████████| 23/23 [00:00<00:00, 33.40it/s]
[epoch 14] train_loss: 0.542  val_accuracy: 0.747
train epoch[15/30] loss:0.595: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[15/30]: 100%|██████████| 23/23 [00:00<00:00, 34.54it/s]
[epoch 15] train_loss: 0.542  val_accuracy: 0.758
train epoch[16/30] loss:0.191: 100%|██████████| 207/207 [00:07<00:00, 26.83it/s]
valid epoch[16/30]: 100%|██████████| 23/23 [00:00<00:00, 32.04it/s]
[epoch 16] train_loss: 0.532  val_accuracy: 0.761
train epoch[17/30] loss:0.566: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[17/30]: 100%|██████████| 23/23 [00:00<00:00, 33.56it/s]
[epoch 17] train_loss: 0.523  val_accuracy: 0.739
train epoch[18/30] loss:0.509: 100%|██████████| 207/207 [00:07<00:00, 26.79it/s]
valid epoch[18/30]: 100%|██████████| 23/23 [00:00<00:00, 30.35it/s]
[epoch 18] train_loss: 0.526  val_accuracy: 0.742
train epoch[19/30] loss:0.781: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[19/30]: 100%|██████████| 23/23 [00:00<00:00, 31.58it/s]
[epoch 19] train_loss: 0.506  val_accuracy: 0.764
train epoch[20/30] loss:0.336: 100%|██████████| 207/207 [00:07<00:00, 26.64it/s]
valid epoch[20/30]: 100%|██████████| 23/23 [00:00<00:00, 33.95it/s]
[epoch 20] train_loss: 0.537  val_accuracy: 0.764
train epoch[21/30] loss:0.475: 100%|██████████| 207/207 [00:07<00:00, 26.65it/s]
valid epoch[21/30]: 100%|██████████| 23/23 [00:00<00:00, 33.27it/s]
[epoch 21] train_loss: 0.511  val_accuracy: 0.764
train epoch[22/30] loss:0.513: 100%|██████████| 207/207 [00:07<00:00, 26.53it/s]
valid epoch[22/30]: 100%|██████████| 23/23 [00:00<00:00, 32.16it/s]
[epoch 22] train_loss: 0.482  val_accuracy: 0.761
train epoch[23/30] loss:0.172: 100%|██████████| 207/207 [00:07<00:00, 26.62it/s]
valid epoch[23/30]: 100%|██████████| 23/23 [00:00<00:00, 33.02it/s]
[epoch 23] train_loss: 0.501  val_accuracy: 0.761
train epoch[24/30] loss:1.127: 100%|██████████| 207/207 [00:07<00:00, 26.54it/s]
valid epoch[24/30]: 100%|██████████| 23/23 [00:00<00:00, 34.24it/s]
[epoch 24] train_loss: 0.492  val_accuracy: 0.755
train epoch[25/30] loss:0.905: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[25/30]: 100%|██████████| 23/23 [00:00<00:00, 30.22it/s]
[epoch 25] train_loss: 0.492  val_accuracy: 0.758
train epoch[26/30] loss:1.044: 100%|██████████| 207/207 [00:07<00:00, 26.75it/s]
valid epoch[26/30]: 100%|██████████| 23/23 [00:00<00:00, 33.86it/s]
[epoch 26] train_loss: 0.476  val_accuracy: 0.777
train epoch[27/30] loss:0.552: 100%|██████████| 207/207 [00:07<00:00, 26.73it/s]
valid epoch[27/30]: 100%|██████████| 23/23 [00:00<00:00, 31.55it/s]
[epoch 27] train_loss: 0.465  val_accuracy: 0.745
train epoch[28/30] loss:0.387: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[28/30]: 100%|██████████| 23/23 [00:00<00:00, 32.30it/s]
[epoch 28] train_loss: 0.482  val_accuracy: 0.769
train epoch[29/30] loss:0.251: 100%|██████████| 207/207 [00:07<00:00, 26.69it/s]
valid epoch[29/30]: 100%|██████████| 23/23 [00:00<00:00, 32.98it/s]
[epoch 29] train_loss: 0.466  val_accuracy: 0.777
train epoch[30/30] loss:0.368: 100%|██████████| 207/207 [00:07<00:00, 26.57it/s]
valid epoch[30/30]: 100%|██████████| 23/23 [00:00<00:00, 31.95it/s]
[epoch 30] train_loss: 0.467  val_accuracy: 0.780
Finished Training

(4) Distillation training

import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34
from model_10 import ConvNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # image_path = os.path.join(data_root, "data_set", "flower_data")
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))


    # Load teacher model
    teacher_net = resnet34(num_classes=5)
    teacher_model_weight_path = "resNet34.pth"
    assert os.path.exists(teacher_model_weight_path), f"File '{teacher_model_weight_path}' does not exist."
    teacher_net.load_state_dict(torch.load(teacher_model_weight_path, map_location="cpu"), strict=False)
    teacher_net.to(device)
    teacher_net.eval()  # freeze BatchNorm/Dropout; the teacher is not trained here

    
    # Load student model
    student_net = ConvNet()
    student_model_weight_path = "ConvNet.pth"
    assert os.path.exists(student_model_weight_path), "file {} does not exist.".format(student_model_weight_path)
    student_net.load_state_dict(torch.load(student_model_weight_path, map_location="cpu"))
    student_net.to(device)

    # Distillation loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')
    student_loss_function = nn.CrossEntropyLoss()

    # Optimizer for the student model
    params = [p for p in student_net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './distilled_ConvNet.pth'
    train_steps = len(train_loader)
    temperature = 5.0  # Temperature for distillation

    for epoch in range(epochs):
        student_net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            with torch.no_grad():  # the teacher is frozen; no gradients needed
                teacher_logits = teacher_net(images.to(device))
            student_logits = student_net(images.to(device))

            # Soften the logits with the distillation temperature
            soft_teacher = teacher_logits / temperature
            soft_student = student_logits / temperature

            # Compute the distillation loss, scaled by T^2 to keep
            # gradient magnitudes comparable across temperatures
            loss = loss_function(torch.nn.functional.log_softmax(soft_student, dim=1),
                                 torch.nn.functional.softmax(soft_teacher, dim=1)) * (temperature ** 2)

            # Compute the classification loss on the original (unsoftened) logits
            student_loss = student_loss_function(student_logits, labels.to(device))

            # Combine losses
            loss = 0.5 * loss + 0.5 * student_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        student_net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = student_net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(student_net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

        No screenshot was saved, so try it yourself: in my test, training the self-built model for 30 epochs and then distillation-training it for another 30 epochs brought val_accuracy up to 0.81.

(5) Model files

https://pan.baidu.com/s/1gVTJPvAQ3oDEZcGYoJvuLw

Extraction code: ddk5

4. Summary

        If the model structure is simple, distillation training can be used to improve its accuracy; of course, a teacher model has to be trained first.
