超大规模分类(三):KNN softmax

news2025/1/15 3:52:41

传统的分类损失计算输入数据和每个类别中心的距离,来优化模型的训练。KNN softmax通过选择和输入数据最相关的top-K个类别,仅计算输入数据和top-K个类别中心的距离,以减小计算量。

![[Pasted image 20250109103750.png]]

KNN softmax首次诞生于达摩院机器智能技术实验室发表的SIGKDD 2020年《Large-Scale Training System for 100-Million Classification at Alibaba》

简单说下论文作者:

  • Pan Pan:潘攀,拍立淘创始人,著有《深度学习图像搜索与识别》
  • Liuyihan Song、Kang Zhao、Yiming Chen、Yingya Zhang均来自拍立淘团队
  • Yinghui Xu:徐盈辉,徐盈辉-复旦大学人工智能创新与产业(AI³)研究院 (fudan.edu.cn)
  • Rong Jin:金榕(阿里巴巴原副总裁、达摩副院长)_百度百科 (baidu.com)

问题建模

一个常见的图像分类任务整体流程如下:
![[Pasted image 20250109105251.png]]

输入图像 x i x_i xi送入Convolutional Feature Learning模块,提取图像表征 f x i ∈ R 1 × D f_{x_i}\in \mathbb{R}^{1\times D} fxiR1×D(其中 D D D表示维度),再通过Fully Connected Layer,将图像表征维度 f x i f_{x_i} fxi映射到类别数 C C C上,紧接着通过Softmax Function获取 [ 0 , 1 ] [0,1] [0,1]的概率值,计算分类损失。

我们来进行公式化定义,

(1)图像表征 f x i f_{x_i} fxi通过Fully Connected Layer将维度映射到类别数 C C C,可以建模成: f x i W ∈ R C f_{x_i}W \in \mathbb{R}^C fxiWRC,其中 W ∈ R D × C W \in \mathbb{R}^{D\times C} WRD×C。一般情况下,Fully Connected Layer会有偏置 b b b,将偏置 b b b设置为0。

(2)通过Softmax Function获取 [ 0 , 1 ] [0,1] [0,1]的概率值,得到 f x i W ∑ j e x p ( f x i W j ) \frac{f_{x_i}W}{\sum_j{exp(f_{x_i}W_j)}} jexp(fxiWj)fxiW,其中 W j ∈ R D × 1 W_j \in \mathbb{R}^{D\times 1} WjRD×1,表示第 j j j列数据,也指类别表征

(3)分类损失的定义为: L = − log ⁡ ( e x p ( f x i W y i ) ∑ j e x p ( f x i W j ) ) = − log ⁡ ( e x p ( ∥ f x i ∥ ⋅ ∥ W y i ∥ ⋅ c o s ( θ y j ) ) ∑ j e x p ( ∥ f x i ∥ ⋅ ∥ W j ∥ ⋅ c o s ( θ j ) ) ) (1) \begin{equation}\begin{aligned} L&=-\log\left(\frac{exp(f_{x_i}W_{y_i})}{\sum_j{exp(f_{x_i}W_j)}}\right)\\ &=-\log\left(\frac{exp(\|f_{x_i}\|\cdot \|W_{y_i}\|\cdot cos(\theta_{y_j}))}{\sum_j{exp(\|f_{x_i}\|\cdot\|W_j\|\cdot cos(\theta_{j}))}}\right)\\ \end{aligned} \end{equation}\tag{1} L=log(jexp(fxiWj)exp(fxiWyi))=log(jexp(fxiWjcos(θj))exp(fxiWyicos(θyj)))(1),其中 y i y_i yi指的是输入图像 x i x_i xi对应的类别下标,等式上下成立的原因是向量的内积公式 a ⋅ b = ∥ a ∥ ⋅ ∥ b ∥ ⋅ cos ⁡ θ \mathbf{a} \cdot \mathbf{b} = \|\mathbf{a}\| \cdot \|\mathbf{b}\| \cdot \cos\theta ab=abcosθ

(4)在常规实践中,图像表征 f x i f_{x_i} fxi和类别表征 W j W_j Wj一般都事先归一化好,仅需要考虑两个表征间的余弦距离。同时,需要乘上一个缩放因子,用于控制训练的激进程度,例如
L = − log ⁡ ( e x p ( α ⋅ c o s ( θ y j ) ) ∑ j e x p ( α ⋅ c o s ( θ j ) ) ) = − log ⁡ ( e x p ( α ⋅ f x i ∥ f x i ∥ ⋅ W y i ∥ W y i ∥ ) ∑ j e x p ( α ⋅ f x i ∥ f x i ∥ ⋅ W j ∥ W j ∥ ) = − log ⁡ ( e x p ( α ⋅ f x i n o r m ⋅ W i n o r m ) ∑ j e x p ( α ⋅ f x i n o r m ⋅ W j n o r m ) (2) \begin{equation}\begin{aligned} L&=-\log\left(\frac{exp(\alpha \cdot cos(\theta_{y_j}))}{\sum_j{exp(\alpha \cdot cos(\theta_{j}))}}\right)\\ &=-\log\left(\frac{exp(\alpha \cdot \frac{f_{x_i}}{\|f_{x_i}\|}\cdot \frac{W_{y_i}}{\|W_{y_i}\|})}{\sum_j{exp(\alpha \cdot \frac{f_{x_i}}{\|f_{x_i}\|}\cdot \frac{W_{j}}{\|W_{j}\|}}}\right)\\ &=-\log\left(\frac{exp(\alpha \cdot f_{x_i}^{norm} \cdot W_{_i}^{norm})}{\sum_j{exp(\alpha \cdot f_{x_i}^{norm}\cdot W_{j}^{norm}}}\right)\\ \end{aligned} \end{equation}\tag{2} L=log(jexp(αcos(θj))exp(αcos(θyj)))=log jexp(αfxifxiWjWjexp(αfxifxiWyiWyi) =log(jexp(αfxinormWjnormexp(αfxinormWinorm))(2)
,这个就是CLIP用的损失函数的形式了。

KNN softmax

全连接层的模型并行

如果特征维度是512维,分类1个亿的全连接层参数有 512 × 100000000 = 5.12 ∗ 1 0 10 512\times 100000000=5.12*10^{10} 512×100000000=5.121010。若参数存储形式为fp32,即1个参数需要4个字节,那么占用的显存为 5.12 × 1 0 10 ∗ 4 1024 × 1024 × 1024 = 191.1 G B \frac{5.12\times 10^{10}*4}{1024\times 1024\times 1024}=191.1GB 1024×1024×10245.12×10104=191.1GB

很显然,单块显卡装不下。于是,本文将全连接层参数均分到每一块显卡上。假设我们有256块V100显卡,每块显卡只需要装 191.1 G B 256 = 0.74 G B \frac{191.1 GB}{256}=0.74GB 256191.1GB=0.74GB,很显然,每块显卡的负担小得多了。

![[Pasted image 20250109210809.png]]
做法如上图所示,包括数据并行和模型并行。

  • 数据并行指的是Convolutional Feature Learning模块参数复制到每块GPU上,只有数据均分成 N N N份,送入不同GPU中。
  • 模型并行特指全连接层参数均分成 N N N份,存储到不同GPU中。
    具体流程如下:
    (1)数据均分成 N N N份,送到不同GPU中。
    (2)每块GPU上,通过Convolutional Feature Learning模块提取图像表征,再执行all-gather操作,将不同GPU的表征汇聚到每一块GPU上。(假设有3块GPU,每块GPU提取了 R 2 × 512 \mathbb{R}^{2\times 512} R2×512表征,执行all-gather操作后,将3块GPU的表征汇聚起来,分发到所有GPU上,每块GPU提取的表征变为 R 6 × 512 \mathbb{R}^{6\times 512} R6×512
    (3)第 i i i块GPU将图像表征送到第 i i i份全连接层参数上
    (4)执行分布式softmax计算,以及损失的计算
    (5)每块GPU参数反向传播,在反向传播至Convolutional Feature Learning模块前,汇聚梯度,再进一步向前传播。
    (6)参数更新时,第 i i i份全连接层参数仅通过第 i i i块GPU的梯度进行更新;Convolutional Feature Learning模块则通过全GPU的梯度进行更新。

尽管做了全连接层的模型并行,但是全连接层的计算量级实在太大,越80%的训练时间消耗在全连接层的操作上(全连接层前向传播,softmax前向传播,softmax反向传播,全连接层反向传播)

top-K类别选择

在公式(2)中,有 L = − log ⁡ ( e x p ( α ⋅ f x i n o r m ⋅ W y i n o r m ) ∑ j e x p ( α ⋅ f x i n o r m ⋅ W j n o r m ) L=-\log\left(\frac{exp(\alpha \cdot f_{x_i}^{norm} \cdot W_{y_i}^{norm})}{\sum_j{exp(\alpha \cdot f_{x_i}^{norm}\cdot W_{j}^{norm}}}\right) L=log(jexp(αfxinormWjnormexp(αfxinormWyinorm)),分类损失需要计算输入表征 f x i n o r m f_{x_i}^{norm} fxinorm和所有类别表征的余弦距离。由于类别数特别大,计算难度特别高,所以选择从中挑选 K K K个类别,进行分母的计算。

这是一个典型的检索场景,文中利用输入数据类别 y i y_i yi的类别表征 W y i W_{y_i} Wyi去检索所有类别中心表征,得到top-K个相似度最高的类别,用于分类损失的分母计算。

分布式KNN图构建

KNN图的建立可以理解为:给定query集合,以及doc集合,建立每个query到doc内最相近top-k个样本的关系。

在1亿类别分类场景,query和doc集合都等于1亿类别,建KNN图流程就特指:将每1个类别中心作为query,检索1亿个类别中心内,最相似的top-k个类别中心,构成 1 亿 × k 1亿\times k 1亿×k的相似度矩阵。

大规模检索场景常用的策略为ANN检索(Approximate Nearest Neighbor,近似最近邻检索)。但作者发现ANN对召回影响较大,导致损失偏差较大,效果不好,推荐采用暴力检索(brute-force)。

暴力检索不影响召回率,但很耗时,所以无法每个iteration更新一次,本文是每隔一个epoch更新一次KNN图。

因为模型并行,已经将全连接层均分到每块GPU上,建立KNN图是需要考虑该因素。传统的建图策略是:将所有GPU上的类别表征聚合到每块GPU上,得到完整的doc集合。计算每块GPU上的类别表征与完整doc集合的相似度矩阵,很显然,对显存消耗很高。

采用分布式建图,策略为:假设将GPU(id=0)作为query,计算KNN图,流程有:

  • 在GPU(id=0)上,计算query到GPU(id=0)上类别表征的top-k,结果传播到GPU(id=1)上
  • 在GPU(id=1)上,计算query到GPU(id=1)上类别表征的top-k,结果传播到GPU(id=2)上
  • 最后,将最终结果返回到GPU(id=0)上
    这样的处理方式对显存消耗非常小,并且GPU间的通信量也少。在这里插入图片描述

具体实现时,类别中心的存储由fp32改为fp16,并且采用TensorCore进行相似度计算加速(较原方法能加速3倍)。fp16的精度低于fp32,为平衡速度和效果,首先用fp16精度从全类别中心里搜top- k ′ k^{'} k,再利用fp32精度从top- k ′ k^{'} k中搜出top-k。

经过上述一通操作,1亿类别中心的KNN建图时间仅需0.75h。

采用和全连接层模型并行类似的策略,将KNN图按照query维度均分到每块GPU上,平均每块GPU仅需承担 372 G B / 256 = 1.45 G B 372GB/256=1.45GB 372GB/256=1.45GB,在可承受范围内。

效果比较

分别用了1百万类、1千万类、1亿类的数据进行训练,统计分类准确率和吞吐量,结果如下:
分类准确率:![[Pasted image 20250110095143.png]]

  • selective softmax:分母中通过Hashing Forest来选择k个类别,未采用KNN方式选择
  • MACH:一种加速策略,速度快,但效果不好
  • Full Softmax指的是分类损失中,分布用全类别表征计算得到
    吞吐量:![[Pasted image 20250110095201.png]]

,表明KNN Softmax能够有效提升吞吐量,类别越多,提升幅度越大。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2276785.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ubuntu官方软件包网站 字体设置

在https://ubuntu.pkgs.org/22.04/ubuntu-universe-amd64/xl2tpd_1.3.16-1_amd64.deb.html搜索找到需要的软件后,点击,下滑, 即可在Links和Download找到相关链接,下载即可, 但是找不到ros的安装包, 字体设…

项目实战——使用python脚本完成指定OTA或者其他功能的自动化断电上电测试

前言 在嵌入式设备的OTA场景测试和其他断电上电测试过程中,有的场景发生在夜晚或者随时可能发生,这个时候不可能24h人工盯着,需要自动化抓取串口日志处罚断电上电操作。 下面的python脚本可以实现自动抓取串口指定关键词,然后触发…

电脑分辨率调到为多少最佳?电脑分辨率最佳设置

电脑分辨率是指电脑屏幕上显示的像素点的数量,通常用水平和垂直方向的像素点数来表示,例如19201080。像素点越多,显示的内容就越清晰,但也会占用更多的系统资源和电力。那么多电脑分辨率多少最佳?以及电脑分辨率如何调…

代码随想录算法【Day20】

Day20 二叉搜索树 235. 二叉搜索树的最近公共祖先 理解只要当前节点的值在p和q节点的值的中间,那这个值就是最近的公共祖先,绝对不是次近的,这个题就好做了。 递归法 二叉搜索树本身是有序的,所以不涉及到前中后序的遍历 cl…

【SpringBoot】@Value 没有注入预期的值

问题复现 在装配对象成员属性时,我们常常会使用 Autowired 来装配。但是,有时候我们也使用 Value 进行装配。不过这两种注解使用风格不同,使用 Autowired 一般都不会设置属性值,而 Value 必须指定一个字符串值,因为其…

车联网安全 -- 数字证书到底证明了什么?

在车联网安全--TLS握手过程详解里面,我们了解到握手时,Server会向Client发送Server Certificate,用于证明自己的身份合法,为什么会有这一步呢? 我们回顾一下数字签名的过程: Bob将使用自己的公钥对“Hello…

Elasticsarch:使用全文搜索在 ES|QL 中进行过滤 - 8.17

8.17 在 ES|QL 中引入了 match 和 qstr 函数,可用于执行全文过滤。本文介绍了它们的作用、使用方法、与现有文本过滤方法的区别、当前的限制以及未来的改进。 ES|QL 现在包含全文函数,可用于使用文本查询过滤数据。我们将回顾可用的文本过滤方法&#xf…

【HTML+CSS+JS+VUE】web前端教程-31-css3新特性

圆角 div{width: 100px;height: 100px;background-color: saddlebrown;border-radius: 5px;}阴影 div{width: 200px;height: 100px;background-color: saddlebrown;margin: 0 auto;box-shadow: 10px 10px 20px rgba(0, 0, 0, 0.5);}

Spring Boot 项目自定义加解密实现配置文件的加密

在Spring Boot项目中, 可以结合Jasypt 快速实现对配置文件中的部分属性进行加密。 完整的介绍参照: Spring Boot Jasypt 实现application.yml 属性加密的快速示例 但是作为一个技术强迫症,总是想着从底层开始实现属性的加解密,…

若依前后端分离项目部署(使用docker)

文章目录 一、搭建后端1.1 搭建流程:1.2 后端零件:1.2.1 mysql容器创建:1.2.2 redis容器创建:1.2.3 Dockerfile内容:1.2.4 构建项目镜像:1.2.5 创建后端容器: 二、前端搭建:2.1 搭建流程&#x…

Vue2+OpenLayers使用Overlay实现点击获取当前经纬度信息(提供Gitee源码)

目录 一、案例截图 二、安装OpenLayers库 三、代码实现 关键参数: 实现思路: 核心代码: 完整代码: 四、Gitee源码 一、案例截图 二、安装OpenLayers库 npm install ol 三、代码实现 覆盖物(Overlay&#xf…

Oracle 终止正在执行的SQL

目录 一. 背景二. 操作简介三. 投入数据四. 效果展示 一. 背景 项目中要求进行性能测试,需要向指定的表中投入几百万条数据。 在数据投入的过程中发现投入的数据不对,需要紧急停止SQL的执行。 二. 操作简介 👉需要DBA权限👈 ⏹…

Oopsie【hack the box】

Oopsie 解题流程 文件上传 首先开启机器后,我们先使用 nmap -sC -SV来扫描一下IP地址: -sC:使用 Nmap 的默认脚本扫描(通常是 NSE 脚本,Nmap Scripting Engine)。这个选项会自动执行一系列常见的脚本&am…

V少JS基础班之第四弹

一、 前言 第四弹内容是操作符。 本章结束。第一个月的内容就完成了, 是一个节点。 下个月我们就要开始函数的学习了。 我们学习完函数之后。很多概念就可以跟大家补充说明了。 OK,那我们就开始本周的操作符学习 本系列为一周一更,计划历时6…

【STM32-学习笔记-7-】USART串口通信

文章目录 USART串口通信Ⅰ、硬件电路Ⅱ、常见的电平标准Ⅲ、串口参数及时序Ⅳ、STM32的USART简介数据帧起始位侦测数据采样波特率发生器 Ⅴ、USART函数介绍Ⅵ、USART_InitTypeDef结构体参数1、USART_BaudRate2、USART_WordLength3、USART_StopBits4、USART_Parity5、USART_Mode…

Docker 安装开源的IT资产管理系统Snipe-IT

一、安装 1、创建docker-compose.yaml version: 3services:snipeit:container_name: snipeitimage: snipe/snipe-it:v6.1.2restart: alwaysports:- "8000:80"volumes:- ./logs:/var/www/html/storage/logsdepends_on:- mysqlenv_file:- .env.dockernetworks:- snip…

达梦8-DMSQL程序设计学习笔记1-DMSQL程序简介

1、DMSQL程序简介 DMSQL程序是达梦数据库对标准SQL语言的扩展,是一种过程化SQL语言。在DMSQL程序中,包括一整套数据类型、条件结构、循环结构和异常处理结构等,DMSQL程序中可以执行SQL语句,SQL语句中也可以使用DMSQL函数。 DMSQ…

NLP中常见的分词算法(BPE、WordPiece、Unigram、SentencePiece)

文章目录 一、基本概念二、传统分词方法2.1 古典分词方法2.2 拆分为单个字符 三、基于子词的分词方法(Subword Tokenization)3.1 主要思想3.2 主流的 Subword 算法3.3 Subword 与 传统分词方法的比较 四、Byte Pair Encoding (BPE)4.1 主要思想4.2 算法过…

第三十六章 Spring之假如让你来写MVC——拦截器篇

Spring源码阅读目录 第一部分——IOC篇 第一章 Spring之最熟悉的陌生人——IOC 第二章 Spring之假如让你来写IOC容器——加载资源篇 第三章 Spring之假如让你来写IOC容器——解析配置文件篇 第四章 Spring之假如让你来写IOC容器——XML配置文件篇 第五章 Spring之假如让你来写…

PyTorch 深度学习框架快速入门 (小土堆)

PyTorch 深度学习框架快速入门 深度学习框架常用模块数据集存取图片数据处理库 —— PILOS 模块实例 Tensorboard 记录机器学习的过程Transform 进行图像变换数据集的下载DataLoaderModule 自定义网络前向传播卷积层卷积简单应用 最大池化非线性层线性层 简单的整合基于现有网络…