Overhaul Distillation(ICCV 2019)原理与代码解析

news2024/11/24 2:10:36

paper:A Comprehensive Overhaul of Feature Distillation

official implementation:GitHub - clovaai/overhaul-distillation: Official PyTorch implementation of "A Comprehensive Overhaul of Feature Distillation" (ICCV 2019)

本文的创新点

本文研究了知识蒸馏的各个方面,并提出了一种新的特征蒸馏方法,使蒸馏损失在教师特征变换、学生特征变换、特征蒸馏位置、距离函数各方面之间协同作用。具体来说,本文提出的蒸馏损失包括一个新设计的margin relu特征变换方法、一个新的蒸馏位置、以及一个partial L2距离函数。在ImageNet中,本文提出的方法使得ResNet-50取得了21.65%的top-1 error,优于教师网络ResNet-152的精度。

方法介绍 

蒸馏位置

激活函数是神经网络的重要组成部分,它使网络具有了非线性。但之前的大部分蒸馏方法都没有考虑到激活函数,蒸馏位置大都在于某各个layer或某个block的尾端,却没有考虑和激活函数如ReLU的关系。本文提出的方法中,蒸馏位置位于某个layer的尾端和第一个ReLU之间,如下图所示

 

pre-ReLU的位置可以使学生接触到教师模型通过ReLU之间的信息,避免了信息的分解和丢失。

损失函数

由于蒸馏位置是在ReLU之前,因此特征中的正值包含教师会利用的信息而负值没有,如果教师网络中的值是正值,学生网络应该生成和教师一样的值,如果教师网络中是负值,学生也应该生成负值从而使得激活状态和教师一致。因此作者提出的教师变换函数,保存正值同时有一个负值margin 

其中 \(m\) 是一个小于0的margin值,作者取名为margin ReLu。\(m\) 的具体值定义为每个通道的负响应值的期望,如下

 

\(m\) 一方面可以在训练过程中直接计算,如https://github.com/clovaai/overhaul-distillation/issues/7所示。也可以通过前一个BN层的参数来计算,作者在附录中给出了具体计算方法。

对于一个通道 \(\mathcal{C}\) 和教师特征 \(F^{i}_{t}\) 的第 \(i\) 个元素,该通道的margin值 \(m_{c}\) 为训练图片的期望值即式(3)。通常我们不知道 \(F^{i}_{t}\) 的分布,所以只能通过训练过程中的平均值来得到期望。但是ReLU前的BN层决定了一个batch中的特征 \(F^{i}_{t}\) 的分布,BN层将每个通道的特征归一化为均值 \(\mu\) 方差 \(\sigma\) 的高斯分布,即

每个通道的均值方差 \((\mu,\sigma)\) 对应BN层的参数 \((\beta,\gamma)\),因此利用 \(F^{i}_{t}\) 的分布可以直接计算边际值

 

利用高斯分布的概率密度函数pdf进行积分就可以得到期望,其中范围小于0。积分的结果可以通过正太分布的cdf累积分布函数 \(\Phi(\cdot)\) 进行简单的表示。

 

在官方实现中,也是通过这种方式即式(10)来计算margin值的。

因为蒸馏的位置是在ReLU函数前,negative response没有经过ReLU的过滤,因此蒸馏损失函数需要考虑到ReLU。在教师特征中,positive response实际上被网络使用,这意味着教师的正响应应该通过具体的值来传递,但负响应却不是。对于教师的负响应,如果学生的响应值高于目标值应该降低,而如果低于目标值不需要增加,因为不管具体值是多少都会被ReLU过滤掉。因此,本文提出了partial L2 distance函数,如下

完整的蒸馏损失函数如下

 

其中 \(\sigma_{m_{c}}\) 是教师转换函数margin ReLU,\(r\) 是是学生转换函数1x1 conv + BN,\(d_{p}\) 是距离函数patial L2 distance。

实验结果

在CIFAR-100数据集上,不同的教师网络和学生网络的结果如表2所示

不同的教师-学生网络组合,本文的方法和其它蒸馏方法的结果对比如下,可以看出,在所有组合下,本文提出的方法都得到了最低的error。

在ImageNet数据集上和其它方法的对比如表4,可以看出本文的方法error也是最低的。 

代码解析

实现代码主要在distiller.py中,本文的第一个创新点在蒸馏的位置,即ReLU前,实现如下 

t_feats, t_out = self.t_net.extract_feature(x, preReLU=True)
s_feats, s_out = self.s_net.extract_feature(x, preReLU=True)

学生特征的转换为1x1卷积+BN,即实现中的self.Connectors,具体实现如下

def build_feature_connector(t_channel, s_channel):
    C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False),
         nn.BatchNorm2d(t_channel)]

    for m in C:
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()

    return nn.Sequential(*C)

教师特征的转换为本文提出的margin ReLU,边际margin值的计算如下,即上述的式(10)

def get_margin_from_BN(bn):
    margin = []
    std = bn.weight.data
    mean = bn.bias.data
    for (s, m) in zip(std, mean):
        s = abs(s.item())
        m = m.item()
        if norm.cdf(-m / s) > 0.001:
            margin.append(- s * math.exp(- (m / s) ** 2 / 2) / math.sqrt(2 * math.pi) / norm.cdf(-m / s) + m)
        else:
            margin.append(-3 * s)

    return torch.FloatTensor(margin).to(std.device)

蒸馏损失函数实现如下,其中第一行就是教师特征的转换函数,即式(2)

def distillation_loss(source, target, margin):
    target = torch.max(target, margin)
    loss = torch.nn.functional.mse_loss(source, target, reduction="none")
    loss = loss * ((source > target) | (target > 0)).float()
    return loss.sum()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/643034.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【状态估计】基于数据模型融合的电动车辆动力电池组状态估计研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

LVS负载均衡与DR模式

LVS负载均衡与DR模式 一、DR模式的特点二、LVS-DR中的ARP问题1.VIP地址相同导致响应冲突2.返回报文时源地址使用VIP,导致网关设备的ARP缓存表紊乱 三、DR模式 LVS负载均衡群集部署实验准备实验部署实验步骤1.配置负载调度器(192.168.30.10)2.…

荣登第一,亚马逊云科技帮助用户实现云上快速部署,轻松维护不同类型的数据库

近期,Gartner发布了2022年全球数据库管理系统(Database Management System,DBMS)市场份额报告,在这一排名中出现了微妙变化,那就是亚马逊云科技超过微软,登上了第一“宝座”,占据了市…

MySQL数据库常用命令

mysql是不见 分号 不执行,分号表示结束。\c可以终止命令的输入。 1.登录数据库 mysql -u root -p然后在输入密码 root 2.查看数据库(以分号结尾) show databases; 3.创建数据库 pk create database pk; 4.使用数据库pk use pk; 5.删除数据库pk drop database…

【2023电工杯】B题人工智能对大学生学习影响的评价26页论文及python代码

【2023电工杯】B题人工智能对大学生学习影响的评价26页论文及python代码 1 题目 B题 人工智能对大学生学习影响的评价 人工智能简称AI,最初由麦卡锡、明斯基等科学家于1956年在美国达特茅斯学院开会研讨时提出。 2016年,人工智能AlphaGo 4:1战胜韩国…

5分钟让你明白什么是面向对象编程

相信很多刚开始接触编程的小伙伴,对于什么是面向对象,什么是面向过程都是一脸懵逼的。 网上关于这两个的回答真的很多,但是都有一个共同特点:------------不容易懂。 让我们来看看某百科给出的定义: 能不能好好说话!…

浮点数在内存中的运算

他们力量的源泉,是值得信赖的搭档以及想要保护的对象还有强大的敌人 本文收录于青花雾气-计算机基础 往期回顾 从汇编代码探究函数栈帧的创建和销毁的底层原理 从0到1搞定在线OJ 数据在内存中的存储 计算机存储的大小端模式 目录 浮点数的二进制转化及存储规…

pySCENIC单细胞转录因子分析更新:数据库、软件更新

***pySCENIC全部往期精彩系列:1、PySCENIC(一):python版单细胞转录组转录因子分析2、PySCENIC(二):pyscenic单细胞转录组转录因子分析3、PySCENIC(三):pyscen…

我的创作纪念日之这四年的收获与体会

第一次来写自己的创作纪念哈,不知不觉都已经过去整整四年了,好与不好还请大家担待: 1、机缘 1. 记得是大一、大二的时候就听学校的大牛说,可以通过写 CSDN 博客,来提升自己的代码和逻辑能力,以及后面工作…

图解LeetCode——994. 腐烂的橘子

一、题目 在给定的 m x n 网格 grid 中,每个单元格可以有以下三个值之一: 值 0 代表空单元格;值 1 代表新鲜橘子;值 2 代表腐烂的橘子。 每分钟,腐烂的橘子 周围 4 个方向上相邻 的新鲜橘子都会腐烂。 返回直到单元格…

醒醒吧,连新来的实习生都在进阶自动化,你还在点点点吗,聪明人都在提升自己!

5年测试老兵了,真的很迷茫,觉得自己不再提升自己,真的会被实习生替代。 很多朋友跟我吐槽,说自己虽然已经工作3-4年,可工作依旧是点点点,新来的实习生用一周的时间就把工作内容学会了,他的压力…

让博客支持使用 ChatGPT 生成文章摘要是一种什么样的体验?

让博客支持使用 ChatGPT 生成文章摘要是一种什么样的体验? 起因 Sakurairo 主题支持了基于 ChatGPT 的 AI 摘要功能,我有点眼红,但是因为那是个主题限定功能,而我用的又是 Argon,遂想着让 Argon 也支持 AI 摘要功能。…

【spring】spring是什么?详解它的特点与模块

作者:Insist-- 个人主页:insist--个人主页 作者会持续更新网络知识和python基础知识,期待你的关注 目录 一、spring介绍 二、spring的特点(七点) 1、简化开发 2、AOP的支持 3、声明式事务的支持 4、方便测试 5、…

springcloud 父项目建立(一)

我们开发项目,现在基本都用到maven,以及用父子项目,以及公共模块依赖,来构建方便扩展的项目体系; 首先我们建立父项目 microservice ,主要是一个pom,管理module,以及管理依赖&#x…

shell实现多并发控制

背景: 遇到一个业务需求,一个上位机需要向多个下位机传送文件,当前的实现是for循环遍历所有下位机,传送文件,但是此种方法耗时太久,需要优化。因此可以通过并发的方式向下位机传送文件。 这边写一段测试代…

【Vue3 第二十七章】路由和状态管理

一、路由 1.1 服务端路由 与 客户端路由 服务端路由 服务端路由指的是服务器根据用户访问的 URL 路径返回不同的响应结果。当我们在一个传统的服务端渲染的 web 应用中点击一个链接时,浏览器会从服务端获得全新的 HTML,然后重新加载整个页面。客户端路…

人机交互学习-2 人机交互基础知识

人机交互基础知识 交互框架作用执行/评估活动周期 EEC四个组成部分七个阶段和两个步骤 执行隔阂&评估隔阂扩展EEC模型四个部分两个阶段 交互形式命令行交互菜单驱动界面基于表格的界面直接操纵问答界面隐喻界面自然语言交互交互形式小结 理解用户信息处理模型信号处理机人类…

“秩序与自由”——超详细的低代码开发B端产品前端页面设计规范

Hi,我们是钟茂林和李星潮,来自万应低代码 UI 设计团队。 编辑搜图 编辑搜图 左:钟茂林 右:李星潮 在过去,B 端应用通常只在企业内部员工中使用,与 C 端产品数以千万计的用户相比显得少之…

Pycharm 配置Django 框架(详解篇)

首先你必须具备pycharm 专业版 / 社区版也可以 打开pycharm专业版 找到在最下方菜单栏找到 Terminal 第二步:检查自己的python版本 python --version 第三步: 寻找和自己python版本匹配的django版本 (图片来源: 化雨随风 …

【NLP模型】文本建模(2)TF-IDF关键词提取原理

一、说明 tf-idf是个可以提取文章关键词的模型;他是基于词频,以及词的权重综合因素考虑的词价值刻度模型。一般地开发NLP将包含三个层次单元:最大数据单元是语料库、语料库中有若干文章、文章中有若干词语。这样从词频上说,就有词…