PEFD-多投影蒸馏详细论文与代码解读(Improved Feature Distillation via Projector Ensemble)

news2025/2/25 17:40:26

论文链接:https://papers.nips.cc/paper_files/paper/2022/file/4ec0b6648bdf487a2f1c815924339022-Paper-Conference.pdf
源码链接:https://github.com/chenyd7/PEFD

文章目录

  • 前言
  • 一、论文核心
  • 二、论文摘要
  • 三、论文内容
  • 四、集成投影方法
  • 五、源码环境安装
  • 六、源码修改
    • 1、源码问题
    • 2、修改代码
    • 3、代码执行效果
  • 七、代码流程解读
  • 八.projector代码解读
    • 1、模型特征提取代码解读
    • 2、projector代码解读
      • 特征返回值
      • projector结构
      • 投影loss计算


前言

昨日看到蒸馏一篇蒸馏论文PEFD文章,论文提到特征蒸馏方法,本着好奇与疑问,于是我读了,有一些启示。为此,我将记录于此,改论文重点提出投影projector帮助学生模型特征空间转换,说是缓解overvit教师,我个人认为有点借助projector作为缓冲(像辅助教师)。既然读了,我将写下论文主要内容,并结合论文代码深入解读。


一、论文核心

论文背景:narrow the gap between the student and teacher’s feature spaces.various feature distillation methods have been developed by designing more powerful objective functions and determining more effective links between the layers of the
student and the teacher。缩小teacher与student特征空间gap,研究者更多聚焦目标函数(loss)或在teacher和student的layers中有效links。

解决问题:distillation model without a projector, the student network tends to overfit the teacher’s feature distributions despite having different architecture and weights initialization.缓解student模型过拟合teacher模型。

论文方法:通过特征投影projector解决,且在student模型上使用,从一个projector增加到三个projector,结构如下图。
注:projector可理解为投影projector-->代码使用nn.Linear方法。

在这里插入图片描述

二、论文摘要

先前特征蒸馏方法主要聚焦在loss函数设计和distilled layers的links,很少研究会使用projector。我们以往经验认为增加projector的特征蒸馏方法是有效的,然后我们提出投projector。我们发现即使学生和教师feature dimensions相同,基于学生with projector是有效的。我们也证明了without projector在不同学生网络架构和赋予不同初始化权重,学生网络tends to overfit教师网络,得到较差deep feature质量,影响分类结果。with projector能让学生网络更好聚焦特征extraction,能更好利用教师guidance。我们提出an ensenble of projectors进一步改善学生网络特征提取质量。实验表明,一系列teacher-student组合实验证明我们提出方法的有效性。
已有的知识蒸馏方法可以大致分为基于logit,基于特征和基于相似度的方法。根据之前研究,与其他两种算法相比,基于特征的方法通常可以提取出更好的学生网络。

三、论文内容

本文推测,模仿教师特征的过程为学生网络训练提供了更清晰的优化方向。尽管特征提取具有更好的性能,但缩小学生模型和教师模型特征空间之间差距仍然具有挑战性。为了提升学生模型特征学习能力,已经开发了各种通过设计更强大的目标函数并确定学生和教师模型层之间更有效的的连接的特征蒸馏方法。
本文发现,从学生模型到教师模型特征空间的特征投影过程在特征提取重起着关键作用,可以重新设计以提高性能。由于学生网络的特征维度并不总是有教师模型特征尺度相同,因此通常需要投影特征映射到公共空间重进行匹配。即使学生和教师网络特征维度相同,在学生网络上安装投影也能提高蒸馏性能。本文假设当最小化学生和教师模型特征差异是,添加投影进行蒸馏有助于缓解过拟合问题。此外受到添加投影进行特征提取有效性启发,提出了一个投影集合以进一步改进。直觉是具有不同初始化的投影会生成不同转换特征。因此根据集成学习理论,使用多个投影器有助于提高学生网络泛化能力。
为了匹配教师与学生模型维度,需要一个投影器projector转换学生或教师特征。本文实验中发现,将投影器强加于教师效果较差,因为来自教师原始且信息量更大的特征分布会被破坏。因此在提出蒸馏框架中,训练时投影器添加在学生模型,蒸馏训练后在被移除。

作为多任务学习的特征蒸馏,近期方法,SRRL和CID组合基于特征和基于logit损失提升性能。由于蒸馏方法对超参数和教师-学生组合敏感,额外的目标将增加系数调整的训练成本。为了缓解这个问题,本文特征蒸馏简单使用方向对齐(Direction Alignment, DA)损失:
在这里插入图片描述
本文假设没有投影器的学生网络训练过程可以被视作为在相同特征空间的多任务学习(蒸馏和分类任务)。此时学生特征倾向于过拟合教师特征,从而降低分类判别力。这里用两种测量方法验证这一假设。一个是测量学生和教师特征的差异:
在这里插入图片描述
显然,由于学生特征会直接与教师特征交互,因此在不同种子中,没有投影器的学生MDA性能显著差于有投影器的学生模型。然而通过研究学生特征空间中类别间余弦相似度,发现在没有投影器的情况下提取学生特征判别力较小。类间余弦相似度:
在这里插入图片描述
下图中可知,与没有投影器的学生网络相比,有投影器的学生网络产生了更多的判别特征。图中所示,在没有投影器情况下,学生模型往往会过度拟合在教师的特征空间。由于分类和蒸馏在同一个特征空间中执行。由于分类和蒸馏任务是在同一个特征空间中执行,因此生成的特征对于分类来说不太可区分。
在这里插入图片描述

四、集成投影方法

上述分析表明,投影器可以提高学生模型蒸馏性能。受此启发,提出了集成投影器进行进一步改进。使用多个投影器有两个动机。首先,具有不同初始化的投影器提供不同转换特征,这有利于学生的可推广性。其次,由于使用ReLU函数使投影器能够执行非线性特征提取时投影的学生特征可能包含0,而教师模型由于CNN中常用的平均池化层操作不太可能为0。也就是说,在单个投影层情况下,教师和学生模型之间特征分布差距很大,因此使用集成学习是训练误差和泛化能力之间实现良好平衡的自然方式。
在这里插入图片描述

五、源码环境安装

github上下载源码,直接安装环境对应的torch版本。
我缺少tensorflow,直接使用:

pip install tensorflow -i   https://pypi.tuna.tsinghua.edu.cn/simple some-package

我是windows10使用安装,环境即可完成。

六、源码修改

1、源码问题

我遇到问题是源码缺少self.train_data与self.train_labels,如下:

img, target = self.train_data[index], self.train_labels[index]

代码在cifar100.py第34行左右,其部分代码如下:

class CIFAR100Instance(datasets.CIFAR100):
    """CIFAR100Instance Dataset.
    """
    def __getitem__(self, index):
        if self.train:

            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image

实际是cifar图片加载问题,或许是我缺少环境,若你们能跑通,请直接忽略,否则我们进行下一步修改。

2、修改代码

尽然是图片加载除了问题,我们将修改图片加载代码即可,我采用torchvision方法加载cifar数据,修改train_student.py第170含左右,
源代码如下:

    if opt.dataset == 'cifar100':
        
        train_loader, val_loader, n_data = get_cifar100_dataloaders(batch_size=opt.batch_size,
                                                                    num_workers=opt.num_workers,
                                                                    is_instance=True)



        n_cls = 100

注释或删除上面数据加载代码,修改后代码如下:

    if opt.dataset == 'cifar100':
        import torchvision.datasets
        from torch.utils.data import DataLoader
        train_data = torchvision.datasets.CIFAR100(root="./data", train=True,
                                                   transform=torchvision.transforms.ToTensor(),
                                                   download=True)

        train_loader = DataLoader(train_data, batch_size=64)
        val_loader=train_loader

        # train_loader, val_loader, n_data = get_cifar100_dataloaders(batch_size=opt.batch_size,
        #                                                             num_workers=opt.num_workers,
        #                                                             is_instance=True)



        n_cls = 100

我们修改了数据加载部分,自然也得调整一下模型加载数据格式,位置在loops.py第90行:
源代码如下:

    end = time.time()
    for idx, data in enumerate(train_loader):
        
        input, target, index = data
        data_time.update(time.time() - end)

注释或删除上面数据加载代码,修改后代码如下:

    end = time.time()
    for idx, data in enumerate(train_loader):
        input, target = data
        # input, target, index = data
        data_time.update(time.time() - end)

3、代码执行效果

记得修改如下参数:

 parser.add_argument('--path_t', type=str, default='./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth', help='teacher model snapshot')

    # distillation
    parser.add_argument('--distill', type=str, default='ours', choices=['kd', 'ours'])
    parser.add_argument('--trial', type=str, default='1', help='trial id')

    parser.add_argument('-r', '--gamma', type=float, default=1, help='weight for classification')
    parser.add_argument('-a', '--alpha', type=float, default=0, help='weight balance for KD')
    parser.add_argument('-b', '--beta', type=float, default=25, help='weight balance for other losses')

按照以上方法,执行代码效果如下:
在这里插入图片描述

七、代码流程解读

代码流程解读,直接告知数据加载一块格式,若出现问题,只要将数据格式改成我给的格式,也可以是模型运行。
在这里插入图片描述
按照这样数据输入模型,即可运行。

八.projector代码解读

1、模型特征提取代码解读

我以源码resnet的backbone为列解读特征提取。

    def forward(self, x, is_feat=False, preact=False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)  # 32x32
        f0 = x

        x, f1_pre = self.layer1(x)  # 32x32 x最后输出进行了relu而f1_pre没有进行relu
        f1 = x
        x, f2_pre = self.layer2(x)  # 16x16
        f2 = x
        x, f3_pre = self.layer3(x)  # 8x8
        f3 = x

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        f4 = x
        x = self.fc(x)

        if is_feat:
            if preact:
                return [f0, f1_pre, f2_pre, f3_pre, f4], x
            else:
                return [f0, f1, f2, f3, f4], x
        else:
            return x

以上代码可知,特征提取一种是有激活函数后每层特征,返回值为return [f0, f1, f2, f3, f4], x,另一种为五激活函数前特征提取,返回值为return [f0, f1_pre, f2_pre, f3_pre, f4], x。

2、projector代码解读

我以源码resnet的backbone为列解读特征提取。
首先教师模型与学生模型特征返回值代码解读。

特征返回值

preact = False
        feat_s, logit_s = model_s(input, is_feat=True, preact=preact)
        with torch.no_grad():
            feat_t, logit_t = model_t(input, is_feat=True, preact=preact)
            feat_t = [f.detach() for f in feat_t]     

以上教师返回feat_s=[f0, f1_pre, f2_pre, f3_pre, f4], logit_s=x,学生网络与教师类似。

projector结构

projector实际是nn.linear结构,其代码如下:

class Reg(nn.Module):
    """Linear regressor"""
    def __init__(self, dim_in=1024, dim_out=1024):
        super(Reg, self).__init__()
        self.linear = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.linear(x)
        return x

投影loss计算

我们解释一下,以下loss计算公式,若为原来KD蒸馏方式为只需将opt.alpha赋权重值,opt.beta赋值为0可实现原有蒸馏方式;若使用本论文蒸馏方式,需将opt.alpha赋值为0,opt.beta赋值;

loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd

论文多个投影方法代码如下:

        # cls + kl div
        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)

        # other kd beyond KL divergence
        if opt.distill == 'kd':
            loss_kd = 0       
        elif opt.distill == 'ours':  # 1 - cos(theta_i): average different projections 
            f_t = feat_t[-1]
            
            relu = torch.nn.ReLU() 
            # linear Regress
            f_s1 = feat_s[-1]     # 64 512
            f_s1 = module_list[1](f_s1)  # 64 256
            f_s1 = relu(f_s1)  # 64 256
            f_s2 = feat_s[-1]  # 64 512
            f_s2 = module_list[2](f_s2)            
            f_s2 = relu(f_s2)     # 64 256
            f_s3 = feat_s[-1]
            f_s3 = module_list[3](f_s3)            
            f_s3 = relu(f_s3)
            f_s = (f_s1 + f_s2 + f_s3) / 3  # 64 256
            bsz = f_s.shape[0]
            bdm = f_s.shape[1]
                        
            # inner product (normalize first and inner product)
            normft = f_t.pow(2).sum(1, keepdim=True).pow(1. / 2)
            outft = f_t.div(normft)            
            normfs = f_s.pow(2).sum(1, keepdim=True).pow(1. / 2)
            outfs = f_s.div(normfs)
            
            cos_theta = (outft * outfs).sum(1, keepdim=True)
            G_diff = 1 - cos_theta
            loss_kd = (G_diff).sum() / bsz      
        else:
            raise NotImplementedError(opt.distill)

        loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd

以上代码可知,投影实际将学生模型最后输出[batch,classs_n]通过projector结构转换,总共执行了三次,将其平均,将得到本论文提的集成投影的特征空间蒸馏。

值得注意是:论文方法没有使用loss_div = criterion_div(logit_s, logit_t)此loss。


三、四内容参考链接:https://blog.csdn.net/qgh1223/article/details/130724222

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/863070.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大厂急了,30+名企“报复性招人”

📢紧急通知!紧急通知! 2024届秋招已全面开启! 没错!你没听错!!!2024届秋招真的开始了,你还沉浸在暑假温柔乡,有些人已经拿到了offer!惊不惊喜意…

nodejs+vue+elementui小区物业管理系统_78ahx

课题主要分为四大模块:即管理员模块,物业管理模块、业主模块和维修员模块,主要功能包括:个人中心、物业管理、业主管理、维修员管理、小区公告管理、小区信息管理、房产信息管理、车位信息管理、停车位管理、停车信息管理、缴费信…

性能测试工具RunnerGo中如何管理接口

RunnerGo是一款基于go语言开发的开源测试平台,支持接口管理、自动化测试、性能测试等3大测试模块,今天给大家带来如何使用RunnerGo进行接口管理。 RunnerGo的接口管理类似于 Apipost,满足绝大多数接口管理需求。以下是使用 RunnerGo 进行接口…

储能pcb的布局注意事项与制造难点

随着新能源需求的不断增长和能源结构的转型,储能技术的市场规模不断扩大。储能PCB作为储能系统中电池模块的重要组成部分,对整个系统的安全性和性能起到关键作用。今天我们就来聊聊,储能pcb有什么特征。 什么是储能:储能是指能量…

封装Ellipsis组件,亲测使用各种场景

自己封装了Ellipsis组件 基于reacttaro,以下是实现代码,分为JSX和CSS文件 JSX代码如下: import { FC, Fragment, JSX, useState } from react; import { Image, StandardProps, Text, View } from tarojs/components;import iconDropDown fr…

2023年初中信息技术学科暑假备课

目录 2023年初中信息技术学科暑假备课1. 创意空间1.1 教师的空间1.2 学生的空间1.3 关于FTP服务器设置 2. 什么是编程2.1 编程语言2.2 人人都应学好编程2.3. 编程难吗?2.4 python用途 3. 开发环境3.1 打开IDLE3.2 IDLE窗口3.2.1 shell窗口和编辑窗口 4. 项目式教学4…

SysML V1.2 P1.概述

Scope SysML重用UML 2的子集,并提供额外的扩展来满足语言的需求。该规范根据UML 2中被重用的部分和UML 2的扩展,记录了语言体系结构。该规范包括完整语言的具体语法(符号),并指定UML 2的扩展。UML 2规范的可重用部分没有直接包含在规范中&…

解析Python面向对象:从小白到面向对象大师的蜕变之路

文章目录 一 类和对象的概念二 类的认识2.1 类的定义和使用语法2.2 成员变量和成员方法 三 类和对象3.1 类和对象的关系3.2 构造方法3.3 魔术方法概述(内置类方法)3.4 内置方法详解 四 面向对象三大特性4.1 封装4.1.1 封装的理解4.1.2 私有成员变量和方法…

致谢丨感谢有你,JumpServer开源项目九周年致谢名单

2014年到2023年,JumpServer开源项目已经走过了九年的时间。感谢以下社区贡献者对JumpServer项目的帮助和支持。 因为有你,一切才能成真。 JumpServer开源项目贡献者奖杯将于近日邮寄到以上贡献者手中,同时JumpServer开源项目组还准备了一份小…

2023年计算机科学与信息技术国际会议(ECCSIT 2023)

会议简介 Brief Introduction 2023年计算机科学与信息技术国际会议(ECCSIT 2023) 会议时间:2023年12月15日-17日 召开地点:中国北海 大会官网:www.eccsit.org 2023年计算机科学与信息技术国际会议(ECCSIT 2023)由国际电气、电子与能源工程协会…

Redis_哨兵模式

9. 哨兵模式 9.1 简介 当主库宕机,在从库中选择一个,切换为主库。 问题: 主库是否真正宕机?哪一个从库可以作为主库使用?如何实现将新的主库的信息通过给从库和客户端? 9.2 基本流程 哨兵主要任务: 监控选择主库通知 会有…

100个Java工具类之42:序列化工具类Apache之SerializationUtils

序列化工具类Apache之 org.apache.commons.lang3.SerializationUtils 众所周知:Java中序列化是指,将Java对象转换为可存储传输的字节序列的过程。 序列化作用: 1、网络传输:网络可以传输字节化的java对象 2、数据安全&#xf…

jQuery EasyUI datagrid 无记录时,增加“暂无数据“提示

我们只需要在onLoadSuccess中添加如下代码&#xff1a; if (data.total 0) {var body $(this).data().datagrid.dc.body2;body.find(table tbody).append(<tr><td width" body.width() " style"height: 35px; text-align: center;"><h…

20230811导出Redmi Note12Pro 5G手机的录音机APP的录音

20230811导出Redmi Note12Pro 5G手机的录音机APP的录音 2023/8/11 10:54 redmi note12 pro 录音文件 位置 貌似必须导出录音&#xff0c;录音的源文件不知道存储到哪里了&#xff01; 参考资料&#xff1a; https://jingyan.baidu.com/article/b87fe19e9aa79b1319356842.html 红…

以数据为中心的标记语言--yaml

&#x1f600;前言 本篇博文是关于以数据为中心的配置文件yaml的说明和应用&#xff0c;希望能够帮助到您&#x1f60a; &#x1f3e0;个人主页&#xff1a;晨犀主页 &#x1f9d1;个人简介&#xff1a;大家好&#xff0c;我是晨犀&#xff0c;希望我的文章可以帮助到大家&…

2023年电赛---运动目标控制与自动追踪系统(E题)—— 视觉部分

文章目录 一、前言二、视觉部分2.1&#xff1a;k210识别激光点2.2&#xff1a;k210识别方框和4个角点 三、总结 一、前言 &#x1f337;此次电赛我负责的部分主要是视觉&#xff0c;所以我着重和详细讲解一下视觉部分&#xff0c;不止限于此次电赛&#xff0c;而是从这次电赛视…

canfestival_主站发送同步对象触发主站PDO发送

1.入口处 2.开启定时器 3.调用定时器函数 4.切换到初始化状态&#xff0c;自动切换到预操作状态&#xff0c;最后进入操作状态 看到在预操作状态下&#xff0c;进行了通信状态的切换&#xff0c;调用相应的函数&#xff0c;如下&#xff1a; 5.调用开启SYNC的函数 查找对象字典…

让机器人懂得人类“常识”,3D语义地图能做到吗?

机器人需要一张保姆级地图。 随着机器人的智能化技术不断迭代&#xff0c;对于复杂的行为决策、人机交互等任务仅感知环境的空间几何信息已无法满足要求&#xff0c;它需要让机器人能够像人一样&#xff0c;懂得环境中的物体类别及其位置&#xff0c;即环境的语义信息。以扫地机…

字节跳动基于火山引擎DataLeap的一站式数据治理架构实践

更多技术交流、求职机会&#xff0c;欢迎关注字节跳动数据平台微信公众号&#xff0c;回复【1】进入官方交流群 在7月22日举行的 ArchSummit 全球架构师峰会&#xff08;深圳站&#xff09;上&#xff0c;来自火山引擎DataLeap的技术专家为大家带来了字节跳动基于火山引擎DataL…

day16:static、final、常量(static final)、

一、static 特点&#xff1a;属于类 、存储在方法去、只有一份或者只执行一次、随类运行执行 可以修饰静态变量 、静态方法 、静态代码块 静态变量能否继承&#xff1f; 静态变量属于类&#xff0c;是共享的资源&#xff0c;不认为是被继承的 静态变量不可以定义在静态方法中…