【LM、LLM】浅尝二叉树在前馈神经网络上的应用

news2024/11/15 15:53:27

前言

随着大模型的发展,模型参数量暴涨,以Transformer的为组成成分的隐藏神经元数量增长的越来越多。因此,降低前馈层的推理成本逐渐进入视野。前段时间看到本文介绍的相关工作还是MNIST数据集上的实验,现在这个工作推进到BERT上面来了,再次引起兴趣记录一下。该工作将前馈神经基于二叉树结构进行改装,加速前向传播的速度,称为:快速前馈网络(FFF),然后应用FFF,取代BERT中的前馈网络(FF),实现12个神经元加速推理。

快速前馈网络算法概述

快速前馈网络(Fast Feedforward Network,FFF)是由两部分组成的:节点网络集合 N \mathcal{N} N 和叶子网络集合 L \mathcal{L} L

  • 节点网络集合 N \mathcal{N} N 包含了一组节点网络,每个节点网络都是一个 < dim ⁡ I , n , 1 > \left<\dim_I,n,1\right> dimI,n,1-前馈网络,并在输出上增加了一个 sigmoid 激活函数。这些节点网络按照平衡的可微分二叉树的形式排列,其中 N m + 1 , 2 n N_{m+1,2n} Nm+1,2n N m + 1 , 2 n + 1 N_{m+1,2n+1} Nm+1,2n+1 N m , n N_{m,n} Nm,n 的子节点。

  • 叶子网络集合 L \mathcal{L} L 包含了一组叶子网络,每个叶子网络都是一个 < dim ⁡ I , ℓ , dim ⁡ O > \left<\dim_I,\ell,\dim_O\right> dimI,,dimO-前馈网络。叶子网络没有子节点,它们的输出直接作为 FFF 的输出。

前向传播过程由下面算法控制。

算法的输入包括一个输入样本 ι \iota ι 和根节点 N 0 , 0 N_{0,0} N0,0,输出为该样本在 FFF 中的输出。

算法定义了两个函数: F o r w a r d T Forward_T ForwardT F o r w a r d I {Forward}_I ForwardI。其中, F o r w a r d T {Forward}_T ForwardT 函数用于计算节点的输出,而 F o r w a r d I {Forward}_I ForwardI 函数用于计算节点的指示值(indicator value)。

  • F o r w a r d T {Forward}_T ForwardT 函数中,如果当前节点是叶子节点,则直接调用该节点的前馈传播函数 N m , n ( ι ) N_{m,n}(\iota) Nm,n(ι) 来计算输出。否则,首先计算当前节点的输出 c m , n = N m , n ( ι ) c_{m,n}=N_{m,n}(\iota) cm,n=Nm,n(ι),然后递归地调用 F o r w a r d T {Forward}_T ForwardT 函数来计算当前节点的两个子节点的输出,并将它们加权相加作为当前节点的输出。
  • F o r w a r d I {Forward}_I ForwardI 函数中,如果当前节点是叶子节点,则直接调用该节点的前馈传播函数 N m , n ( ι ) N_{m,n}(\iota) Nm,n(ι) 来计算输出。否则,首先计算当前节点的输出 c m , n = N m , n ( ι ) c_{m,n}=N_{m,n}(\iota) cm,n=Nm,n(ι),然后根据输出值的大小决定选择哪个子节点进行递归计算。


传统前馈神经网络

快速前馈神经网络

与传统的前馈神经网络算法相比,该算法的主要区别在于引入了一个计算节点的指示值。指示值表示了当前节点的输出是否大于等于阈值(这里的阈值为0.5),根据指示值的大小来确定选择哪个子节点进行计算。这种方式可以大大减少计算量,提高前向传播的效率。同时,FFF 是一种具有平衡二叉树结构的前馈神经网络,其中节点网络和叶子网络分别用于处理中间层和输出层的计算。通过利用二叉树结构和递归计算,FFF 可以实现快速的前向传播。

UltraFastBERT

UltraFastBERT,一种BERT变体,在推理过程中使用0.3%的神经元,同时表现 与类似的BERT模型相当。UltraFastBERT选择性地使用4095个神经元中的12个(有选择的执行矩阵乘法(CMM))进行每层推理。这是通过用快速前馈网络(FFFs)取代前馈网络来实现的。

FFF_BMM代码

import torch
from torch import nn
import math

class FFF(nn.Module):
	def __init__(self, input_width: int, depth: int, output_width: int, *args, **kwargs):
		super().__init__(*args, **kwargs)

		self.input_width = input_width
		self.depth = depth
		self.output_width = output_width

		self.n_nodes = 2 ** (depth + 1) - 1
		self.initialise_weights()

	def initialise_weights(self):
		init_factor_l1 = 1.0 / math.sqrt(self.input_width)
		init_factor_l2 = 1.0 / math.sqrt(self.depth + 1)
		self.w1s = nn.Parameter(torch.empty(self.n_nodes, self.input_width).uniform_(-init_factor_l1, +init_factor_l1), requires_grad=True)
		self.w2s = nn.Parameter(torch.empty(self.n_nodes, self.output_width).uniform_(-init_factor_l2, +init_factor_l2), requires_grad=True)

	def forward(self, x):
		# the shape of x is (batch_size, input_width)
		# retrieve the indices of the relevant elements
		batch_size = x.shape[0]
		current_nodes = torch.zeros((batch_size,), dtype=torch.long, device=x.device)
		all_nodes = torch.zeros(batch_size, self.depth+1, dtype=torch.long, device=x.device)
		all_logits = torch.empty((batch_size, self.depth+1), dtype=torch.float, device=x.device)

		for i in range(self.depth+1):
			all_nodes[:, i] = current_nodes
			plane_coeffs = self.w1s.index_select(dim=0, index=current_nodes)			# (batch_size, input_width)
			plane_coeff_score = torch.bmm(x.unsqueeze(1), plane_coeffs.unsqueeze(-1))	# (batch_size, 1, 1)
			plane_score = plane_coeff_score.squeeze(-1).squeeze(-1) 					# (batch_size,)
			all_logits[:, i] = plane_score
			plane_choices = (plane_score >= 0).long()									# (batch_size,)

			current_nodes = current_nodes * 2 + plane_choices + 1						# (batch_size,)

		# get the weights
		selected_w2s = self.w2s.index_select(0, index=all_nodes.flatten()) \
			.view(batch_size, self.depth+1, self.output_width)	# (batch_size, depth+1, output_width)

		# forward pass
		mlp1 = torch.nn.functional.gelu(all_logits)				# (batch_size, depth+1)
		mlp2 = torch.bmm(mlp1.unsqueeze(1), selected_w2s) 		# (batch_size, output_width)
		
		# done
		return mlp2
	

从代码可以看出,与传统的批矩阵乘法(BMM)不同的是,在forward中,基于二叉树的结构,通过迭代计算节点的索引和权重,使用激活函数(GeLU)对结果进行处理,并最终得到输出。

结果

在推理过程中仅使用0.3%的神经元,同时表现与类似的BERT模型相当(下游任务没有降很多点);实现78倍CPU加速,实现40倍PyTorch加速。

总结

该工作很有趣,将传统前馈神经网络定义成一棵二叉树,其叶子是小型神经网络,在每个非叶子节点处都有一个微小的神经网络(单个神经元也可以工作)来决定走哪条路径取决于在输入上。在训练期间,它们对所选路径进行加权平均值,从而得出树的所有叶子(在输入上评估为神经网络)的总加权平均值,但在推理过程中,它们可以只遵循投票最高的分支,从而得出建议的结果指数加速。并且,基于FFF的思想,将工作推到BERT这种语言模型上,初步证明了大模型的前馈层的神经元并不是都需要参与推理。

文章及公开的代码还介绍了条件矩阵乘法的详细细节,因此感兴趣可以深入研究一下。

参考文献

【1】paper:Exponentially Faster Language Modelling,https://arxiv.org/abs/2311.10770
【2】code:https://github.com/pbelcak/fastfeedforward
【3】paper:Fast Feedforward Networks,https://arxiv.org/abs/2308.14711

【4】code:https://github.com/pbelcak/UltraFastBERT
【5】model:https://huggingface.co/pbelcak/UltraFastBERT-1x11-long

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1249880.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java程序员面试专栏 专业技能篇】Java SE核心面试指引(三):核心机制策略

关于Java SE部分的核心知识进行一网打尽,包括四部分:基础知识考察、面向对象思想、核心机制策略、Java新特性,通过一篇文章串联面试重点,并且帮助加强日常基础知识的理解,全局思维导图如下所示 本篇Blog为第三部分:核心机制策略,子节点表示追问或同级提问 异常处理 …

机器学习之自监督学习(四)MoCo系列翻译与总结(一)

Momentum Contrast for Unsupervised Visual Representation Learning Abstract 我们提出了“动量对比”&#xff08;Momentum Contrast&#xff0c;MoCo&#xff09;来进行无监督的视觉表示学习。从对比学习的角度来看&#xff0c;我们将其视为字典查找&#xff0c;通过构建…

Spring - Mybatis-设计模式总结

Mybatis-设计模式总结 1、Builder模式 2、工厂模式 3、单例模式 4、代理模式 5、组合模式 6、模板方法模式 7、适配器模式 8、装饰者模式 9、迭代器模式 虽然我们都知道有26个设计模式&#xff0c;但是大多停留在概念层面&#xff0c;真实开发中很少遇到&#xff0c;…

Day31| Leetcode 455. 分发饼干 Leetcode 376. 摆动序列 Leetcode 53. 最大子数组和

进入贪心了&#xff0c;我觉得本专题是最烧脑的专题 Leetcode 455. 分发饼干 题目链接 455 分发饼干 让大的饼干去满足需求量大的孩子即是本题的思路&#xff1a; class Solution { public:int findContentChildren(vector<int>& g, vector<int>& s) {…

【差分放大电路分析】2021-12-31

缘由有哪位愿意帮助一下的-嵌入式-CSDN问答 截图&#xff0c;数值自己去计算。上2图是接电阻&#xff0c;下2图是接三极管。

JVM内存模型及调优

本文将为大家详细介绍JVM内存模型及如何对JVM内存进行调优。我们将分为以下几个部分进行讲解&#xff1a; JVM内存模型概述JVM内存区域及作用JVM内存调优方法实战案例与优化技巧 一、JVM内存模型概述 在深入了解JVM内存模型之前&#xff0c;我们需要先了解一下Java内存模型&am…

01、Tensorflow实现二元手写数字识别

01、Tensorflow实现二元手写数字识别&#xff08;二分类问题&#xff09; 开始学习机器学习啦&#xff0c;已经把吴恩达的课全部刷完了&#xff0c;现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣&#xff0c;作为入门的素材非常合适。 基于Tensorflow 2.10.0 1、…

NeurIPS 2023|AI Agents先行者CAMEL:第一个基于大模型的多智能体框架

AI Agents是当下大模型领域备受关注的话题&#xff0c;用户可以引入多个扮演不同角色的LLM Agents参与到实际的任务中&#xff0c;Agents之间会进行竞争和协作等多种形式的动态交互&#xff0c;进而产生惊人的群体智能效果。本文介绍了来自KAUST研究团队的大模型心智交互CAMEL框…

浅谈安科瑞无线测温设备在挪威某项目的应用

摘要&#xff1a;安科瑞无线温度设备装置通过无线温度收发器和各无线温度传感器直接进行温度值的传输&#xff0c;并采用液晶显示各无线温度传感器所测温度。 Absrtact:Acre wireless temperature device directly transmits the temperature value through the wireless temp…

Nginx安装与配置、使用Nginx负载均衡及动静分离、后台服务部署、环境准备、系统拓扑图

目录 1. 系统拓扑图 2. 环境准备 3. 服务器安装 3.1 mysql&#xff0c;tomcat 3.2 Nginx的安装 4. 部署 4.1 后台服务部署 4.2 Nginx配置负载均衡及静态资源部署 1. 系统拓扑图 说明&#xff1a; 用户请求达到Nginx若请求资源为静态资源&#xff0c;则将请求转发至静态…

【蓝桥杯省赛真题47】Scratch小猫踩球 蓝桥杯scratch图形化编程 中小学生蓝桥杯省赛真题讲解

目录 scratch小猫踩球 一、题目要求 编程实现 二、案例分析 1、角色分析

vue3.0使用leaflet

1、获取天地图密钥&#xff1b; 访问:https://www.tianditu.gov.cn/ 注册并登录&#xff0c;访问开发资源 》地图API 》 地图服务》申请key 应用管理》创建新应用》获取到对应天地图key 2、引入leaflet组件 参考资料&#xff1a;https://leafletjs.com/reference.html#pa…

一盏茶的时间,入门 Node.js

一、.什么是 Node.js&#xff1f; Node.js 是一个基于 Chrome V8 引擎的 JavaScript 运行时&#xff0c;用于构建高性能、可伸缩的网络应用。 它采用事件驱动、非阻塞 I/O 模型&#xff0c;使其在处理并发请求时表现出色。 二、安装 Node.js 首先&#xff0c;让我们从 Node.…

CSS3新特性(2-1)

CSS3新特性 前言border&#xff1a;radius标签属性选择器box-sizing透明度 前言 本文主要讲解CSS3有哪些新的特性和内容&#xff0c;那么好&#xff0c;本文正式开始. border&#xff1a;radius 新增了圆角边框概念&#xff0c;可以通过具体数值或者百分比&#xff0c;来让边…

互联网上门洗鞋店小程序

上门洗鞋店小程序门店版是基于原平台版进行增强的&#xff0c;结合洗鞋行业的线下实际运营经验和需求&#xff0c;专为洗鞋人和洗鞋店打造的高效、实用、有价值的管理软件系统。 它能够帮助洗鞋人建立自己的私域流量&#xff0c;实现会员用户管理&#xff0c;实现用户与商家的点…

电源控制系统架构(PCSA)之电源控制框架概览

目录 6 电源控制框架 6.1 电源控制框架概述 6.1.1 电源控制框架低功耗接口 6.1.2 电源控制框架基础设施组件 6 电源控制框架 电源控制框架是标准基础设施组件、接口和相关方法的集合&#xff0c;可用于构建SoC电源管理所需的基础设施。 本章介绍框架的主要组件和低功耗接…

FFmpeg零基础学习(一)——初步介绍与环境搭建

目录 前言正文一、开发环境二、搭建环境二、测试代码 参考 前言 FFmpeg是一个开源的跨平台多媒体处理框架&#xff0c;它包含了一组用于处理音频、视频、字幕等多媒体数据的库和工具。FFmpeg提供了强大的功能和灵活性&#xff0c;被广泛用于多媒体应用开发、视频编辑、流媒体传…

每日OJ题_算法_双指针_力扣11. 盛最多水的容器

力扣11. 盛最多水的容器 11. 盛最多水的容器 - 力扣&#xff08;LeetCode&#xff09; 难度 中等 给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成…

Windows核心编程 跨进程操作

目录 进程A拿到进程B句柄是否能用 句柄的权限 关于句柄表 跨进程使用句柄-继承 CreateProcess&#xff1a;bInheritHandles OpenProcess FindWinodw GetCurrentProcess 跨进程使用句柄-拷贝 跨进程操作内存 WriteProcessMemory VirtualProtectEx ReadProcessMemo…

<蓝桥杯软件赛>零基础备赛20周--第7周--栈和二叉树

报名明年4月蓝桥杯软件赛的同学们&#xff0c;如果你是大一零基础&#xff0c;目前懵懂中&#xff0c;不知该怎么办&#xff0c;可以看看本博客系列&#xff1a;备赛20周合集 20周的完整安排请点击&#xff1a;20周计划 每周发1个博客&#xff0c;共20周&#xff08;读者可以按…