知识蒸馏实战代码教学二(代码实战部分)

news2024/11/22 23:38:00

一、上章原理回顾

具体过程:

        (1)首先我们要先训练出较大模型既teacher模型。(在图中没有出现)

        (2)再对teacher模型进行蒸馏,此时我们已经有一个训练好的teacher模型,所以我们能很容易知道teacher模型输入特征x之后,预测出来的结果teacher_preds标签。

        (3)此时,求到老师预测结果之后,我们需要求解学生在训练过程中的每一次结果student_preds标签。

        (4)先求hard_loss,也就是学生模型的预测student_preds与真实标签targets之间的损失。

        (5)再求soft_loss,也就是学生模型的预测student_preds与教师模型teacher_preds的预测之间的损失。

        (6)求出hard_loss与soft_loss之后,求和总loss=a*hard_loss + (1-a)soft_loss,a是一个自己设置的权重参数,我在代码中设置为a=0.3。

        (7)最后反向传播继续迭代。

二、代码实现

1、数据集

        数据集采用的是手写数字的数据集mnist数据集,如果没有下载,代码部分中会进行下载,只需要把download改成True,然后就会保存在当前目录中。该数据集将其分成80%的训练集,20%的测试集,最后返回train_dataset和test_datatset。

class MyDataset(Dataset):
    def __init__(self,opt):
        self.opt = opt

    def MyData(self):
        ## mnist数据集下载0
        mnist = datasets.MNIST(
            root='../datasets/', train=True, download=False, transform=transforms.Compose(
                [transforms.Resize(self.opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        )

        dataset_size = len(mnist)
        train_size = int(0.8 * dataset_size)
        test_size = dataset_size - train_size

        train_dataset, test_dataset = random_split(mnist, [train_size, test_size])

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.opt.batch_size,
            shuffle=True,
        )

        test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.opt.batch_size,
            shuffle=False,  # 在测试集上不需要打乱顺序
        )
        return train_dataloader,test_dataloader

2、teacher模型和训练实现

       (1) 首先是teacher模型构造,经过三次线性层。

import torch.nn as nn
import torch

img_area = 784

class TeacherModel(nn.Module):
    def __init__(self,in_channel=1,num_classes=10):
        super(TeacherModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(img_area,1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, img_area)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

        (2)训练teacher模型

        老师模型训练完成后其权重参数会保存在teacher.pth当中,为以后调用。

import torch.nn as nn
import torch


## 创建文件夹
from tqdm import tqdm

from dist.TeacherModel import TeacherModel

weight_path = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/teacher.pth'
## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速


class TeacherTrainer():
    def __init__(self,opt,train_dataloader,test_dataloader):
        self.opt = opt
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader

    def trainer(self):
        # 老师模型
        opt = self.opt
        train_dataloader = self.train_dataloader
        test_dataloader = self.test_dataloader

        teacher_model = TeacherModel()
        teacher_model = teacher_model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

        for epoch in range(opt.n_epochs):  ## epoch:50
            teacher_model.train()

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                preds = teacher_model(data)
                loss = criterion(preds, targets)

                optimizer_teacher.zero_grad()
                loss = criterion(preds, targets)
                loss.backward()
                optimizer_teacher.step()

            teacher_model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = teacher_model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()

            torch.save(teacher_model.state_dict(), weight_path)

        teacher_model.train()
        print('teacher: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

        (3)训练teacher模型

        模型参数都在paras()当中设置好了,直接调用teacher_model就行,然后将其权重参数会保存在teacher.pth当中。

import argparse

import torch

from dist.DistillationTrainer import DistillationTrainer
from dist.MyDateLoader import MyDataset
from dist.TeacherTrainer import TeacherTrainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def paras():
    ## 超参数配置
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")
    opt = parser.parse_args()
    ## opt = parser.parse_args(args=[])                 ## 在colab中运行时,换为此行
    print(opt)
    return opt


if __name__ == '__main__':
    opt = paras()
    data = MyDataset(opt)
    train_dataloader, test_dataloader = data.MyData()

    # 训练Teacher模型
    teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)
    teacher_trainer.trainer()







 3、学生模型的构建

        学生模型也是经过了三次线性层,但是神经元没有teacher当中多。所以student模型会比teacher模型小很多。

import torch.nn as nn
import torch

img_area = 784

class StudentModel(nn.Module):
    def __init__(self,in_channel=1,num_classes=10):
        super(StudentModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(img_area,20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)

    def forward(self, x):
        x = x.view(-1, img_area)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        # x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

4、知识蒸馏训练

(1)首先读取teacher模型。

        将teacher模型中的权重参数teacher.pth放入模型当中。

 #拿取训练好的模型
        teacher_model = TeacherModel()
        if os.path.exists(weights):
            teacher_model.load_state_dict(torch.load(weights))
            print('successfully')
        else:
            print('not loading')
        teacher_model = teacher_model.to(device)

(2)设置损失求解的函数

        hard_loss用的就是普通的交叉熵损失函数,而soft_loss就是用的KL散度。

        # hard_loss
        hard_loss = nn.CrossEntropyLoss()
        # hard_loss权重
        alpha = 0.3

        # soft_loss
        soft_loss = nn.KLDivLoss(reduction="batchmean")

(3)之后再进行蒸馏训练,温度为7

  •         先求得hard_loss就是用学生模型预测的标签和真实标签进行求得损失。
  •         再求soft_loss就是用学生模型预测的标签和老师模型预测的标签进行求得损失。使用softmax时候还需要进行除以温度temp。
  •         最后反向传播,求解模型
       for epoch in range(opt.n_epochs):  ## epoch:5

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                # 老师模型预测
                with torch.no_grad():
                    teacher_preds = teacher_model(data)

                # 学生模型预测
                student_preds = model(data)
                # 计算hard_loss
                student_loss = hard_loss(student_preds, targets)

                # 计算蒸馏后的预测损失
                ditillation_loss = soft_loss(
                    F.softmax(student_preds / temp, dim=1),
                    F.softmax(teacher_preds / temp, dim=1)
                )

                loss = alpha * student_loss + (1 - alpha) * ditillation_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()

        model.train()
        print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

(4)整个蒸馏训练代码

import torch.nn as nn
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm

from dist.StudentModel import StudentModel
from dist.TeacherModel import TeacherModel

weights = 'C:/Users/26394/PycharmProjects/untitled1//dist/params/teacher.pth'

# D_weight_path = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/distillation.pth'
## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速


class DistillationTrainer():
    def __init__(self,opt,train_dataloader,test_dataloader):
        self.opt = opt
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader


    def trainer(self):
        opt = self.opt
        train_dataloader = self.train_dataloader
        test_dataloader = self.test_dataloader

        #拿取训练好的模型
        teacher_model = TeacherModel()
        if os.path.exists(weights):
            teacher_model.load_state_dict(torch.load(weights))
            print('successfully')
        else:
            print('not loading')
        teacher_model = teacher_model.to(device)
        teacher_model.eval()

        model = StudentModel()
        model = model.to(device)

        temp = 7

        # hard_loss
        hard_loss = nn.CrossEntropyLoss()
        # hard_loss权重
        alpha = 0.3

        # soft_loss
        soft_loss = nn.KLDivLoss(reduction="batchmean")

        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

        for epoch in range(opt.n_epochs):  ## epoch:5

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                # 老师模型预测
                with torch.no_grad():
                    teacher_preds = teacher_model(data)

                # 学生模型预测
                student_preds = model(data)
                # 计算hard_loss
                student_loss = hard_loss(student_preds, targets)

                # 计算蒸馏后的预测损失
                ditillation_loss = soft_loss(
                    F.softmax(student_preds / temp, dim=1),
                    F.softmax(teacher_preds / temp, dim=1)
                )

                loss = alpha * student_loss + (1 - alpha) * ditillation_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()

        model.train()
        print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

(5)蒸馏训练的主函数

        该部分大致与teacher模型训练类似,只是调用不同。

import argparse

import torch

from dist.DistillationTrainer import DistillationTrainer
from dist.MyDateLoader import MyDataset
from dist.TeacherTrainer import TeacherTrainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def paras():
    ## 超参数配置
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")
    opt = parser.parse_args()
    ## opt = parser.parse_args(args=[])                 ## 在colab中运行时,换为此行
    print(opt)
    return opt


if __name__ == '__main__':
    opt = paras()
    data = MyDataset(opt)
    train_dataloader, test_dataloader = data.MyData()

    # 训练Teacher模型
    # teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)
    # teacher_trainer.trainer()

    distillation_trainer = DistillationTrainer(opt,train_dataloader,test_dataloader)
    distillation_trainer.trainer()





三、总结

        总的来说,知识蒸馏是一种有效的模型压缩技术,可以通过在模型训练过程中引入额外的监督信号来训练简化的模型,从而获得与大型复杂模型相近的性能,但具有更小的模型尺寸和计算开销。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1457910.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++ 二分模版 数的范围

给定一个按照升序排列的长度为 n 的整数数组,以及 q 个查询。 对于每个查询,返回一个元素 k 的起始位置和终止位置(位置从 0 开始计数)。 如果数组中不存在该元素,则返回 -1 -1。 输入格式 第一行包含整数 n 和 q &…

QT串口通讯上位机_数据超时接收功能及定时发送功能设计

目录 1.概述2.本次内容最终实现3.代码部分4.完整工程文件下载 1.概述 基于《串口开发基础》 在该基础上增加超时时间接收功能,加入定时器循环; 例如,接收数据开始后,在100ms内未接收到任何数据,视作本次数据接收结束&…

数据结构第3章 串

名人说:莫道桑榆晚,为霞尚满天。——刘禹锡(刘梦得,诗豪) 本篇笔记整理:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 目录 0、思维导图1、基本概念1)主…

Java+SpringBoot:农业疾病防治新选择

✍✍计算机编程指导师 ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java实战 |…

NBlog个人博客部署维护过程记录 -- 后端springboot + 前端vue

项目是fork的Naccl大佬NBlog项目,页面做的相当漂亮,所以选择了这个。可以参考2.3的效果图 惭愧,工作两年了也没个自己的博客系统,趁着过年时间,开始搭建一下. NBlog原项目的github链接:Naccl/NBlog: &#…

leetcode(动态规划)53.最大子数组和(C++详细解释)DAY12

文章目录 1.题目示例提示 2.解答思路3.实现代码结果 4.总结 1.题目 给你一个整数数组 nums ,请你找出一个具有最大和的连续子数组(子数组最少包含一个元素),返回其最大和。 子数组 是数组中的一个连续部分。 示例 提示 2.解答思…

Sora一出 哪里又要裁员了?

上班前夕迎来大新闻,那就是Sora了,Sora是什么,有什么牛逼之处,怎么实现的,我们跟着官方文档透露出来的一点点信息,简单的捋一捋。 一、Sora是什么 官方给出的定义是:世界模拟器。这很明显有夸大…

数据结构:动态内存分配+内存分区+宏+结构体

一、作业 1.定义一个学生结构体&#xff0c;包含结构体成员&#xff1a;身高&#xff0c;姓名&#xff0c;成绩&#xff1b;定义一个结构体数组有7个成员&#xff0c;要求终端输入结构体成员的值&#xff0c;根据学生成绩&#xff0c;进行冒泡排序。 #include <stdio.h>…

Qt C++春晚刘谦魔术约瑟夫环问题的模拟程序

什么是约瑟夫环问题&#xff1f; 约瑟夫问题是个有名的问题&#xff1a;N个人围成一圈&#xff0c;从第一个开始报数&#xff0c;第M个将被杀掉&#xff0c;最后剩下一个&#xff0c;其余人都将被杀掉。例如N6&#xff0c;M5&#xff0c;被杀掉的顺序是&#xff1a;5&#xff…

如何利用Idea创建一个Servlet项目(新手向)

&#x1f495;"Echo"&#x1f495; 作者&#xff1a;Mylvzi 文章主要内容&#xff1a;如何利用Idea创建一个Servlet项目(新手向) Servlet是tomcat的api,利用Servlet进行webapp开发很方便,本文将介绍如何通过Idea创建一个Servlet项目(一共分为七步,这可能是我们写过的…

备战蓝桥杯---动态规划(应用1)

话不多说&#xff0c;直接看题&#xff1a; 首先我们考虑暴力&#xff0c;用二维前缀和即可&#xff0c;复杂度为o(n^4). 其实&#xff0c;我们不妨枚举任意2行&#xff0c;枚举以这个为边界的最大矩阵。 我们把其中的每一列前缀和维护出来&#xff0c;相当于把一个矩阵压缩成…

观察者模式和发布订阅模式的区别

从下图中可以看出&#xff0c;观察者模式中观察者和目标直接进行交互&#xff0c;而发布订阅模式中统一由调度中心进行处理&#xff0c;订阅者和发布者互不干扰。这样一方面实现了解耦&#xff0c;还有就是可以实现更细粒度的一些控制。比如发布者发布了很多消息&#xff0c;但…

【Vue3】搭建Pinia环境及其基本使用

下载 npm i pinia引入并注册 App.vue import { createApp } from vue import { createPinia } from pinia import App from ./App.vue // 1. 引入 import { createPinia } from piniaconst app createApp(App) // 2. 创建 const pinia createPinia() // 3. 注册 app.use(p…

python----面向对象

这里写目录标题 面向对象思想类类的定义类名的定义类的构造函数的定义类的属性类的方法定义 继承语法关于构造函数问题 文件操作绝对路径相对路径pycharm获取绝对路径和相对路径文件读写读文件open&#xff08;&#xff09;read&#xff08;&#xff09;readline&#xff08;&a…

2021年CSP-J认证 CCF信息学奥赛中小学初级组 第一轮真题-单项选择题解析

2021年 中小学信息学奥赛CSP-J真题解析 1、以下不属于面向对象程序设计语言的是 A、c B、python C、java D、c 答案&#xff1a;D 考点分析&#xff1a;主要考查编程语言&#xff0c;ABC都是面向对象语言&#xff0c;D选项c语言是面向过程语言&#xff0c;答案D 2、以下奖…

202427读书笔记|《猫的自信:治愈系生活哲学绘本》——吸猫指南书,感受猫咪的柔软慵懒与治愈

202427读书笔记|《猫的自信&#xff1a;治愈系生活哲学绘本》——吸猫指南书&#xff0c;感受猫咪的柔软慵懒与治愈 《猫的自信&#xff1a;治愈系生活哲学绘本》作者林行瑞&#xff0c;治愈系小漫画绘本&#xff0c;10分钟可以读完的一本书&#xff0c;线条明媚&#xff0c;自…

SQL注入工具之SQLmap入门操作

了解SQLmap 基础操作 SQLmap是一款自动化的SQL注入工具&#xff0c;可以用于检测和利用SQL注入漏洞。 以下是SQLmap的入门操作步骤&#xff1a; 1.下载SQLmap&#xff1a;可以从官方网站&#xff08;https://sqlmap.org/&#xff09;下载最新版本的SQLmap。 2.打开终端&#…

CDP和Chrome

CDP和Chrome CDP和WebDriver Protocol WebDriver和 Chrome DevTools Protocol&#xff08;CDP&#xff09; 是用于自动化浏览器的两个主要协议&#xff0c;大多数的浏览器自动化工具都是基于上述其中之一来实现的。可以通过这两种形式来和浏览器交互&#xff0c;通过代码来控…

使用maven集成spring在测试的时候报出了如下的异常:version 60

使用maven集成spring在测试的时候报出了如下的异常&#xff1a; Caused by: java.lang.IllegalArgumentException: Unsupported class file major version 60 解决&#xff1a;

MAC M1安装vmware和centos7虚拟机并配置静态ip

一、下载vmware和centos7镜像 1、VMWare Fusion 官网的下载地址是&#xff1a;下载地址 下载好之后注册需要秘钥&#xff0c;在官网注册后使用免费的个人秘钥 2、centos7 下载地址&#xff1a; https://biosyxh.cn:5001/sharing/pAlcCGNJf 二、虚拟机安装 直接将下…