FSP:Flow of Solution Procedure (CVPR 2017) 原理与代码解析

news2024/11/15 8:54:07

paper:A Gift From Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning

code:https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/FSP.py

背景

深度神经网络DNN逐层生成特征。更高层的特征更接近于任务的有用特征。如果我们把DNN的输入看作问题,把输出看作答案,我们就可以把DNN中间生成的特征看作是求解过程中的中间结果。根据这一想法,FitNets可以让学生网络简单地模拟教师网络的中间结果。然而在DNN中,有许多方法或途径来解决从输入生成输出的问题。因此,模拟教师网络生成的特征对学生网络来说是一个硬约束hard constraint。就人而言,老师解释问题的解决过程,学生学习解决问题的流程。当输入特定的问题时,学生网络不一定需要学习中间输出,但当遇到特定类型的问题时,学生网络可以学习这一类问题的通用解决方法。因此作者认为,对于知识蒸馏中的教师网络,演示问题的解决过程比演示中间结果具有更好的泛化性

本文的创新点

本文将神经网络中层与层之间的信息流动定义为需要蒸馏的知识,并通过计算两个特征层之间的内积来得到这种知识。当将这种层之间的流动作为知识传递给学生网络时,作者通过实验得到了三个结论:

  1. 从教师网络学习这种蒸馏知识的学生网络比原始网络的优化(收敛)速度快得多。

  1. 学习这种蒸馏知识的学生网络比原始网络的性能更好。

  1. 即使教师网络是在一个不同的任务或数据集上训练得到的,学生网络也可以从教师网络中学习到这种知识,并且比从头训练的效果更好。

下图是本文提出的知识蒸馏方法的概念图

本文的贡献如下:

  1. 提出了一种知识蒸馏的新方法。

  1. 这种知识对于快速优化非常有用。

  1. 利用所提出的蒸馏知识定义网络的初始权重可以提高小模型的性能。

  1. 即使学生网络接受了与教师网络不同的训练任务,所提出的蒸馏知识也能提高学生网络的表现。

方法介绍

作者设计了网络中两个相邻层之间的FSP(flow of solution procedure)矩阵来表示问题的求解过程,对于挑选的层1输出的feature map表示为 \(F^{1}\in \mathbb{R}^{h\times w\times m}\),其中 \(h,w,m\) 分别表示特征图的高、宽、通道数。层2表示为 \(F^{2}\in \mathbb{R}^{h\times w\times n}\),则FSP矩阵 \(G\in \mathbb{R}^{m\times n}\) 可通过下式求得

其中 \(x\) 表示输入图片,\(W\) 表示网络权重参数。

对于残差网络,网络在一些位置的spatial size发生变化,我们选择教师网络和学生网络对应位置具有相同spatial size的特征图来生成FSP matrix,下图是一个示例

计算教师网络和学生网络对应FSP矩阵的L2损失,完整是损失函数如下

其中 \(\lambda_{i}\) 表示每一对FSP矩阵损失的权重,文中设定所有层计算的FSP之间的损失权重相等。\(N\) 表示所有的采样点。

代码解析

forward函数的输入g_sg_t分别表示学生网络和教师网络中所有用来计算FSP矩阵的层,在compute_fsp中每一层都与相邻层计算fsp矩阵,注意这里的相邻并不是说在原始网络中这两层的相邻的。这里相邻层之间计算fsp矩阵需要保证spatial size相等,如果不相等通过自适应平均池化使之相等。

from __future__ import print_function

import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class FSP(nn.Module):
    """A Gift from Knowledge Distillation:
    Fast Optimization, Network Minimization and Transfer Learning"""
    def __init__(self, s_shapes, t_shapes):
        super(FSP, self).__init__()
        assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
        s_c = [s[1] for s in s_shapes]
        t_c = [t[1] for t in t_shapes]
        if np.any(np.asarray(s_c) != np.asarray(t_c)):
            raise ValueError('num of channels not equal (error in FSP)')

    def forward(self, g_s, g_t):
        # [(64,32,32,32),(64,64,32,32),(64,128,16,16),(64,256,8,8)]
        # [(64,32,32,32),(64,64,32,32),(64,128,16,16),(64,256,8,8)]
        s_fsp = self.compute_fsp(g_s)
        t_fsp = self.compute_fsp(g_t)
        loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
        return loss_group

    @staticmethod
    def compute_loss(s, t):
        return (s - t).pow(2).mean()

    @staticmethod
    def compute_fsp(g):
        fsp_list = []
        for i in range(len(g) - 1):
            bot, top = g[i], g[i + 1]  # (64,32,32,32),(64,64,32,32)
            b_H, t_H = bot.shape[2], top.shape[2]
            if b_H > t_H:
                bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
            elif b_H < t_H:
                top = F.adaptive_avg_pool2d(top, (b_H, b_H))
            else:
                pass
            bot = bot.unsqueeze(1)  # (64,1,32,32,32)
            top = top.unsqueeze(2)  # (64,64,1,32,32)
            bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)  # (64,1,32,1024)
            top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)  # (64,64,1,1024)
            fsp = (bot * top).mean(-1)  # (64,64,32,1024)->(64,64,32)
            fsp_list.append(fsp)
        return fsp_list

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/376594.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

内存数据库的设计与实现(已在大型项目中应用)

一、概况 1、设计总图 组成,由Redis集群缓存,普通缓存,传统数据库,各类数据驱动 2、内存数据库的增删改查,分页查询 组成,由数据查询,分页查询,数据存储,数据修改,数据删除 3、内存数据库的驱动 组成,由驱动适配器,普通缓存驱动,Redis缓存驱动 4、内存数据库与…

C++常见类型及占用内存表

GPS生产厂家在定义数据的时候都会有一定的数据类型&#xff0c;例如double、int、float等&#xff0c;我们知道它们在内存中都对应了一定的字节大小&#xff0c;而我在实际使用时涉及到了端序的问题&#xff08;大端序高字节在前&#xff0c;小端序低字节在前&#xff09;&…

redis主从同步:如何实现数据一致

Redis 提供了主从库模式&#xff0c;以保证数据副本的一致&#xff0c;主从库之间采用的是读写分离的方式。读操作&#xff1a;主库、从库都可以接收&#xff1b;写操作&#xff1a;首先到主库执行&#xff0c;然后&#xff0c;主库将写操作同步给从库。和mysql差不多。但是同步…

自动驾驶专题介绍 ———— 毫米波雷达

文章目录介绍工作原理特点性能参数应用厂家介绍 毫米波雷达是工作在毫米波波段探测的雷达&#xff0c;与普通雷达相似&#xff0c;是通过发射无线电信号并接收反射信号来测量物体间的距离。毫米波雷达工作频率为30~300GHz(波长为1 - 10mm)&#xff0c;波长介于厘米波和光波之间…

【数据挖掘实战】——家用电器用户行为分析及事件识别(BP神经网络)

项目地址&#xff1a;Datamining_project: 数据挖掘实战项目代码 目录 一、背景和挖掘目标 1、问题背景 2、原始数据 3、挖掘目标 二、分析方法与过程 1、初步分析 2、总体流程 第一步&#xff1a;数据抽取 第二步&#xff1a;探索分析 第三步&#xff1a;数据的预处…

为什么负责任的技术始于数据治理

每个组织都处理数据&#xff0c;但并非每个组织都将其数据用作业务资产。但是&#xff0c;随着数据继续呈指数级增长&#xff0c;将数据视为业务资产正在成为竞争优势。 埃森哲的一项研究发现&#xff0c;只有 33% 的公司“足够信任他们的数据&#xff0c;能够有效地使用它并从…

色环电阻的阻值如何识别

这种是色环电阻&#xff0c;其外表有一圈圈不同颜色的色环&#xff0c;现在在一些电器和电源电路中还有使用。下面的两种色环电阻它颜色还不一样&#xff0c;一个蓝色&#xff0c;一个土黄色&#xff0c;其实这个蓝色的属于金属膜色环电阻&#xff0c;外表涂的是一层金属膜&…

Qt新手入门指南 - 如何创建模型/视图(四)

每个UI开发人员都应该了解ModelView编程&#xff0c;本教程的目标是为大家提供一个简单易懂的介绍。Qt 是目前最先进、最完整的跨平台C开发工具。它不仅完全实现了一次编写&#xff0c;所有平台无差别运行&#xff0c;更提供了几乎所有开发过程中需要用到的工具。如今&#xff…

AJAX介绍及其应用

1.1 AJAX 简介 AJAX全称为 Asynchronous JavaScript and XML &#xff0c;就是异步的js和xml。通过AJAX可以在浏览器中向服务器发送异步请求&#xff0c;最大的优势&#xff0c;无刷新获取数据。AJAX不是新的编程语言&#xff0c;而是一种现有的标准组合再一起使用的新方式 应…

scanpy 单细胞分析API接口使用案例

参考&#xff1a;https://zhuanlan.zhihu.com/p/537206999 https://scanpy.readthedocs.io/en/stable/api.html scanpy python包主要分四个模块&#xff1a; 1&#xff09;read 读写模块、 https://scanpy.readthedocs.io/en/stable/api.html#reading 2&#xff09;pp Prepr…

springBoot自动装配原理探究springBoot配置类Thymeleaf模板引擎

微服务 微服务是一种架构风格&#xff0c;由于单体架构不利于团队协作完成并且代码量较大&#xff0c;后期维护成本较高&#xff0c;逐渐有了微服务架构。微服务是将一个项目拆分成不同的服务&#xff0c;各个服务之间相互独立互不影响&#xff0c;互相通过轻量级机制通信比如…

(转载)STM32与LAN9252构建EtherCAT从站

目录 &#xff08;一&#xff09;&#xff1a;项目简介 EtherCAT及项目简述 LAN9252工作模式 整体开发流程 移植要处理的问题 代码层面的工作 开发中使用的工具 &#xff08;二&#xff09;&#xff1a;SSC的使用 SSC简介和下载 SSC构建协议栈文件和XML &#xff08…

爬虫数据解析-正则表达式

数据解析-正则表达式 正则表达式 正则编写规则简介 字符含义.匹配除换行符以外的任意字符|A|B表示&#xff1a;匹配正则表达式条件A或B^匹配字符串的开始(在集合[]里表示"非"&#xff09;的意思$匹配字符串的结束{n}重复n次{,n}重复小于n次{n,}重复n次或更多次{n,…

2023软件测试金三银四常见的软件测试面试题-【抓包和网络协议篇】

八、抓包与网络协议 8.1 抓包工具怎么用 我原来的公司对于抓包这块&#xff0c;在App的测试用得比较多。我们会使用fiddler抓取数据检查结果&#xff0c;定位问题&#xff0c;测试安全&#xff0c;制造弱网环境; 如&#xff1a;抓取数据通过查看请求数据&#xff0c;请求行&…

经验 // 指标异常了怎么办?

本文参考了数据万花筒的文章&#xff0c;结合我自己工作经验。希望给大家一些帮助。 指标异常排查&#xff0c;是数据分析师的工作重点之一&#xff0c;是各行各业数据分析师都绕不开的话题。 本文试图回答&#xff1a; 1、指标波动的影响因素有哪些&#xff1f; 2、如何快速…

Web3中文|泰勒·斯威夫特演唱会票务闹乌龙,NFT票务急需普及

2022年底&#xff0c;美国艺人Taylor Swift&#xff08;泰勒斯威夫特&#xff09;的2023年巡回演唱会Eras Tour门票开始出票。作为当今世界最受欢迎的流行歌手之一&#xff0c;四年多没举办大型巡演无疑积攒了大量的粉丝需求。但是在2022年11月15日开放预售的当天&#xff0c;售…

数据驱动下的物种保护,拯救生命的“特效药”

如果给出这样      一张猎豹的图片      我们能否通过图中有限的信息      判断它的年龄、健康状况      以及所属族群?      如果你是一名研究动物的专家,你可能会从其花纹和斑点中获取一定量的信息,但对于大多数人以及一线的动物保护者来说,它可能只是一…

imx6ull——I2C驱动

I2C基本介绍 SCL 为高电平&#xff0c;SDA 出现下降沿:起始位 SCL 位高电平&#xff0c;SDA出现上升沿:停止位 主机——从机地址&#xff08;ack&#xff09;——寄存器地址&#xff08;ack&#xff09;——数据&#xff08;ack&#xff09; 重点&#xff1a;先是写&#xff0c…

context.Context

context.Context前言一、为什么要context二、context有什么用三、基本数据结构3.1、context包的整体工作机制3.2 基本接口和结构体3.3 API函数3.4 辅助函数3.5 context用法3.6 使用 context 传递数据的争议总结参考资料前言 context是go语言的一个并发包&#xff0c;一个标准库…

平台总线开发(id和设备树匹配)

目录 一、ID匹配之框架代码 二、ID匹配之led驱动​​​​​​​ 三、设备树匹配 四、设备树匹配之led驱动 五、一个编写驱动用的宏 一、ID匹配之框架代码 id匹配&#xff08;可想象成八字匹配&#xff09;&#xff1a;一个驱动可以对应多个设备 ------优先级次低 注意事项…