知识蒸馏代码实现(以MNIST手写数字体为例,自定义MLP网络做为教师和学生网络)

news2025/1/10 23:55:26

dataloader_tools.py

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

def load_data():
    # 载入MNIST训练集
    train_dataset = torchvision.datasets.MNIST(
        root = "../datasets/",
        train=True,
        transform=transforms.ToTensor(),
        download=True
    )

    # 载入MNIST测试集
    test_dataset = torchvision.datasets.MNIST(
        root = "../datasets/",
        train=False,
        transform=transforms.ToTensor(),
        download=True
    )

    # 生成训练集和测试集的dataloader
    train_dataloader = DataLoader(dataset=train_dataset,batch_size=12,shuffle=True)
    test_dataloader = DataLoader(dataset=test_dataset,batch_size=12,shuffle=False)
    return train_dataloader,test_dataloader

models.py

import torch
from torch import nn
# 教师模型
class TeacherModel(nn.Module):
    def __init__(self,in_channels=1,num_classes=10):
        super(TeacherModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,1200)
        self.fc2 = nn.Linear(1200,1200)
        self.fc3 = nn.Linear(1200,num_classes)
        self.dropout = nn.Dropout(p=0.5) #p=0.5是丢弃该层一半的神经元.
    def forward(self,x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)
        return x

class StudentModel(nn.Module):
    def __init__(self,in_channels=1,num_classes=10):
        super(StudentModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,20)
        self.fc2 = nn.Linear(20,20)
        self.fc3 = nn.Linear(20,num_classes)
    def forward(self,x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.relu(x)

        x = self.fc3(x)
        return x

train_tools.py

from torch import nn
import time
import torch
import tqdm
import torch.nn.functional as F

def train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device):
    # ----------------------开始计时-----------------------------------
    start_time = time.time()

    # 设置参数开始训练
    best_acc, best_epoch = 0, 0
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        # 训练集上训练模型权重
        for data, targets in tqdm.tqdm(train_dataloader):
            # 把数据加载到GPU上
            data = data.to(device)
            targets = targets.to(device)

            # 前向传播
            preds = model(data)
            loss = criterion(preds, targets)

            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 测试集上评估模型性能
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_dataloader:
                x = x.to(device)
                y = y.to(device)
                preds = model(x)
                predictions = preds.max(1).indices  # 返回每一行的最大值和该最大值在该行的列索引
                num_correct += (predictions == y).sum()
                num_samples += predictions.size(0)
            acc = (num_correct / num_samples).item()
            if acc > best_acc:
                best_acc = acc
                best_epoch = epoch
                # 保存模型最优准确率的参数
                torch.save(model.state_dict(), f"../weights/{model_name}_best_acc_params.pth")
        model.train()
        print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc),f'loss={loss}')
    print(f'最优准确率的epoch为{best_epoch},值为:{best_acc},最优参数已经保存到:weights/{model_name}_best_acc_params.pth')

    # -------------------------结束计时------------------------------------
    end_time = time.time()
    run_time = end_time - start_time
    # 将输出的秒数保留两位小数
    if int(run_time) < 60:
        print(f'训练用时为:{round(run_time, 2)}s')
    else:
        print(f'训练用时为:{round(run_time / 60, 2)}minutes')

def distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device):
    # -------------------------------------开始计时--------------------------------
    start_time = time.time()

    # 定以损失函数
    hard_loss = nn.CrossEntropyLoss()
    soft_loss = nn.KLDivLoss(reduction="batchmean")
    # 定义优化器
    optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)

    best_acc,best_epoch = 0,0
    for epoch in range(epochs):
        student_model.train()
        # 训练集上训练模型权重
        for data,targets in tqdm.tqdm(train_dataloader):
            # 把数据加载到GPU上
            data = data.to(device)
            targets = targets.to(device)

            # 教师模型预测
            with torch.no_grad():
                teacher_preds = teacher_model(data)
            # 学生模型预测
            student_preds = student_model(data)
            # 计算hard_loss
            student_hard_loss = hard_loss(student_preds,targets)

            # 计算蒸馏后的预测结果及soft_loss
            ditillation_loss = soft_loss(
                F.softmax(student_preds/temp,dim=1),
                F.softmax(teacher_preds/temp,dim=1)
            )

            # 将hard_loss和soft_loss加权求和
            loss = temp * temp * alpha * student_hard_loss + (1-alpha)*ditillation_loss

            # 反向传播,优化权重
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        #测试集上评估模型性能
        student_model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x,y in test_dataloader:
                x = x.to(device)
                y = y.to(device)
                preds = student_model(x)
                predictions = preds.max(1).indices #返回每一行的最大值和该最大值在该行的列索引
                num_correct += (predictions ==y).sum()
                num_samples += predictions.size(0)
            acc = (num_correct/num_samples).item()
            if acc>best_acc:
                best_acc = acc
                best_epoch = epoch
                # 保存模型最优准确率的参数
                torch.save(student_model.state_dict(),f"../weights/{model_name}_best_acc_params.pth")
        student_model.train()
        print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))
        print(f'student_hard_loss={student_hard_loss},ditillation_loss={ditillation_loss},loss={loss}')
    print(f'最优准确率的epoch为{best_epoch},值为:{best_acc},')

    # --------------------------------结束计时----------------------------------
    end_time = time.time()
    run_time = end_time - start_time
    # 将输出的秒数保留两位小数
    if int(run_time) < 60:
        print(f'训练用时为:{round(run_time, 2)}s')
    else:
        print(f'训练用时为:{round(run_time / 60, 2)}minutes')

训练教师网络

import torch
from torchinfo import summary #用来可视化的
import models
import dataloader_tools
import train_tools

# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True

# 载入MNIST训练集和测试集
train_dataloader,test_dataloader = dataloader_tools.load_data()

# 定义教师模型
model = models.TeacherModel()
model = model.to(device)
# 打印模型的参数
summary(model)

# 定义参数并开始训练
epochs = 10
lr = 1e-4
model_name = 'teacher'
train_tools.train(epochs,model,model_name,lr,train_dataloader,test_dataloader,device)
最优准确率的epoch为9,值为:0.9868999719619751

用非蒸馏的方法训练学生网络

import torch
from torchinfo import summary #用来可视化的
import dataloader_tools
import models
import train_tools

# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True

# 生成训练集和测试集的dataloader
train_dataloader,test_dataloader = dataloader_tools.load_data()

# 从头训练学生模型
model = models.StudentModel()
model = model.to(device)
# 查看模型参数
print(summary(model))

# 定义参数并开始训练
epochs = 10
lr = 1e-4
model_name = 'student'
train_tools.train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device)
最优准确率的epoch为9,准确率为:0.9382999539375305,最优参数已经保存到:weights/student_best_acc_params.pth
训练用时为:1.74minutes

用知识蒸馏的方法训练student model

import torch
import train_tools
import models
import dataloader_tools

# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True

# 加载数据
train_dataloader,test_dataloader = dataloader_tools.load_data()


# 加载训练好的teacher model
teacher_model = models.TeacherModel()
teacher_model = teacher_model.to(device)
teacher_model.load_state_dict(torch.load('../weights/teacher_best_acc_params.pth'))
teacher_model.eval()

# 准备新的学生模型
student_model = models.StudentModel()
student_model = student_model.to(device)
student_model.train()

# 开始训练
lr = 0.0001
epochs = 20
alpha = 0.3 # hard_loss权重
temp = 7 # 蒸馏温度
model_name = 'distill_student_loss'
# 调用train_tools中的
train_tools.distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device)

最优准确率的epoch为9,值为:0.9204999804496765,
训练用时为:2.14minutes

在这里插入图片描述

loss改为:

# temp的平方乘在student_hard_loss
loss = temp * temp * alpha * student_hard_loss + (1 - alpha) * ditillation_loss
最优准确率的epoch为9,值为:0.9336999654769897,
训练用时为:2.12minutes

loss改为:

# temp的平方乘ditillation_loss
loss = alpha * student_hard_loss + temp * temp * (1 - alpha) * ditillation_loss
最优准确率的epoch为9,值为:0.9176999926567078,
训练用时为:2.09minutes

上面的几种loss,蒸馏损失都出现了负数的情况。不太对劲。
在这里插入图片描述

其它开源的知识蒸馏算法如下:

open-mmlab开源的工具箱包含知识蒸馏算法

mmrazor

github.com/open-mmlab/mmrazor

在这里插入图片描述

NAS:神经架构搜索
剪枝:Pruning
KD: 知识蒸馏
Quantization: 量化

自定义知识蒸馏算法:
在这里插入图片描述

mmdeploy

可以把算法部署到一些厂商支持的中间格式,如ONNX,tensorRT等。

在这里插入图片描述

HobbitLong的RepDistiller

github.com/HobbitLong/RepDistiller

在这里插入图片描述
在这里插入图片描述
里面有12种最新的知识蒸馏算法。

蒸馏网络可以应用于同一种模型,将大的学习的知识蒸馏到小的上面。
如下将resnet100做教师网络,resnet32做学生网络。

在这里插入图片描述

将一种模型迁移到另一种模型上。如vgg13做教师网络,mobilNetv2做学生网络:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1273161.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Unity 注释的方法

1、单行注释&#xff1a;使用双斜线&#xff08;//&#xff09;开始注释&#xff0c;后面跟注释内容。通常注释一个属性或者方法&#xff0c;如&#xff1a; //速度 public float Speed;//打印输出 private void DoSomething() {Debug.Log("运行了我"); } …

老师旁听公开课到底听什么

经常参加公开课是老师提升自己教学水平的一种方式。那么&#xff0c;在旁听公开课时&#xff0c;老师应该听什么呢&#xff1f; 听课堂氛围 一堂好的公开课&#xff0c;应该能够让学生积极参与&#xff0c;课堂气氛活跃&#xff0c;而不是老师一个人唱独角戏。如果老师能够引导…

第16关 革新云计算:如何利用弹性容器与托管K8S实现极速服务POD扩缩容

------> 课程视频同步分享在今日头条和B站 天下武功&#xff0c;唯快不破&#xff01; 大家好&#xff0c;我是博哥爱运维。这节课给大家讲下云平台的弹性容器实例怎么结合其托管K8S&#xff0c;使用混合服务架构&#xff0c;带来极致扩缩容快感。 下面是全球主流云平台弹…

Windows系列:windows2003-建立域

windows2003-建立域 Active Directory建立DNS建立域查看日志xp 加入域 Active Directory 活动目录是一个包括文件、打印机、应用程序、服务器、域、用户账户等对象的数据库。 常见概念&#xff1a;对象、属性、容器 域组件&#xff08;Domain Component&#xff0c;DC&#x…

java操作windows系统功能案例(二)

1、打印指定文件 可以使用Java提供的Runtime类和Process类来打印指定文件。以下是一个示例代码&#xff1a; import java.io.File; import java.io.IOException;public class PrintFile {public static void main(String[] args) {if (args.length ! 1) {System.out.println(…

C# Onnx 百度飞桨开源PP-YOLOE-Plus目标检测

目录 效果 模型信息 项目 代码 下载 C# Onnx 百度飞桨开源PP-YOLOE-Plus目标检测 效果 模型信息 Inputs ------------------------- name&#xff1a;image tensor&#xff1a;Float[1, 3, 640, 640] name&#xff1a;scale_factor tensor&#xff1a;Float[1, 2] ----…

HuggingFace学习笔记--Model的使用

1--Model介绍 Transformer的 model 一般可以分为&#xff1a;编码器类型&#xff08;自编码&#xff09;、解码器类型&#xff08;自回归&#xff09;和编码器解码器类型&#xff08;序列到序列&#xff09;&#xff1b; Model Head&#xff08;任务头&#xff09;是在base模型…

Windows11如何让桌面图标的箭头消失(去掉快捷键箭头)

在Windows 11中&#xff0c;桌面图标的箭头是快捷方式图标的一个标志&#xff0c;用来表示该图标是一个指向文件、文件夹或程序的快捷方式。如果要隐藏这些箭头&#xff0c;你需要修改Windows注册表或使用第三方软件。 在此之前&#xff0c;我需要提醒你&#xff0c;修改注册表…

【unity实战】如何更加规范的创建各种Rogue-Lite(肉鸽)风格的物品和BUFF效果(附项目源码)

文章目录 前言定义基类实现不同的BUFF效果一、回血BUFF1. 简单的回血效果实现2. BUFF层数控制回血量 二、攻击附带火焰伤害三、治疗领域1. 简单的治疗领域实现2. 添加技能冷却时间 通过拾取物品获取对应的BUFF参考源码完结 前言 当创建各种Rogue-Lite&#xff08;肉鸽&#xf…

VS2022使用Vim按键

VS2022使用Vim按键 在插件管理里面搜索VsVim 点击安装&#xff0c;重启VS 工具->选项->VsVim 配置按键由谁处理&#xff0c;建议Ctrl C之类常用的使用VS处理&#xff0c;其它使用Vim处理

shell编程系列(7)-使用wc进行文本统计

文章目录 前言wc命令的使用wc命令的参数说明&#xff1a;统计字数统计行数打印文本行号 结语 前言 统计功能也是我们在shell编程中经常碰到的一个需求&#xff0c;wc命令可以适用于任何需要统计的数据&#xff0c;不只是统计文本&#xff0c;配合ls命令我们可以统计文件的个数…

electron调用dll问题总汇

通过一天的调试安装&#xff0c;electron调用dll成功&#xff0c;先列出当前的环境&#xff1a;node版本: 18.12.0&#xff0c;32位的&#xff08;因为dll为32位的&#xff09; VS2019 python node-gyp 1、首先要查看报错原因&#xff0c;通常在某一行会有提示&#xff0c;常…

在Linux上安装KVM虚拟机

一、搭建KVM环境 KVM&#xff08;Kernel-based Virtual Machine&#xff09;是一个基于内核的系统虚拟化模块&#xff0c;从Linux内核版本2.6.20开始&#xff0c;各大Linux发行版就已经将其集成于发行版中。KVM与Xen等虚拟化相比&#xff0c;需要硬件支持的完全虚拟化。KVM由内…

vue3 router-view 使用keep-alive报错parentcomponent.ctx.deactivate is not a function

问题 如下图&#xff0c;在component组件上添加v-if判断&#xff0c;会报错: parentcomponent.ctx.deactivate is not a function 解决方法 去除v-if&#xff0c;将key直接添加上。由于有的公用页面&#xff0c;需要刷新&#xff0c;不希望缓存&#xff0c;所以需要添加key…

2023/11/30JAVAweb学习

数组json形式 想切换实现类,只需要只在你需要的类上添加 Component 如果在同一层,可以更改扫描范围,但是不推荐这种方法 注入时存在多个同类型bean解决方式

C 中的结构 - 存储、指针、函数和自引用结构

0. 结构体的内存分配 当声明某种类型的结构变量时&#xff0c;结构成员被分配连续&#xff08;相邻&#xff09;的内存位置。 struct student{char name[20];int roll;char gender;int marks[5];} stu1; 此处&#xff0c;内存将分配给name[20]、roll、gender和marks[5]。st1这…

Redis学习文档

目录 一、概念1、特征2、关系型数据库和非关系型数据库的区别3、键的结构4、Redis的Java客户端5、缓存更新策略5.1、概念5.2、代码 6、缓存穿透6.1、含义6.2、解决办法6.3、缓存空值代码举例6.4、布隆过滤器代码举例 7、缓存击穿7.1、概念7.2、解决办法7.3、互斥锁代码举例7.4、…

卡码网语言基础课 | 17. 判断集合成员

目录 一、 set 集合 二、 创建集合 2.1 引入头文件 2.2 创建 2.3 插入元素 2.4 删除元素 三、 find的用法 四、 实现基本解题 五、 延伸拓展 题目&#xff1a;编写一个程序&#xff0c;判断给定的整数 n 是否存在于给定的集合中。 输入描述&#xff1a; 有多组测试…

Pycharm中使用matplotlib绘制动态图形

Pycharm中使用matplotlib绘制动态图形 最终效果 最近用pycharm学习D2L时发现官方在jupyter notebook交互式环境中能动态绘制图形&#xff0c;但是在pycharm脚本环境中只会在最终 plt.show() 后输出一张静态图像。于是有了下面这段自己折腾了一下午的代码&#xff0c;用来在pych…

jetson nano SSH远程连接(使用MobaXterm)

文章目录 SSH远程连接1.SSH介绍2.准备工作3.连接步骤3.1 IP查询3.2 新建会话和连接 SSH远程连接 本节课的实现&#xff0c;需要将Jetson Nano和电脑保持在同一个局域网内&#xff0c;也就是连接同一个路 由器&#xff0c;通过SSH的方式来实现远程登陆。 1.SSH介绍 SSH是一种网…