Contents
- Introduction
- Speculative Decoding
- Standardized Sampling
- Speculative Sampling
- Analysis
- Number of Generated Tokens
- Calculating $\alpha$
- Walltime Improvement
- Number of Arithmetic Operations
- Choosing $\gamma$
- Experiments
- References
Introduction
- To speed up inference for autoregressive LLMs, the authors propose speculative decoding: a small language model is used to accelerate the LLM, so that a single decoding step can produce multiple tokens without changing the LLM's output distribution. Two observations motivate this: (1) inference from large models is often not bottlenecked on arithmetic operations, but rather on memory bandwidth and communication; (2) computing the logits of a short continuation of $K$ tokens in parallel has a very similar latency to that of sampling a single token.
- Speculative decoding speeds up T5-XXL (11B) inference by 2X-3X.
Speculative Decoding
- Notation. $M_p$ is the target model whose inference we want to accelerate, and $p(x)$ denotes the distribution $p(x_t|x_{<t})$ that $M_p$ outputs for a prefix $x_{<t}$. $M_q$ is the approximation model, and $q(x)$ denotes the distribution $q(x_t|x_{<t})$ that $M_q$ outputs for the same prefix $x_{<t}$.
Standardized Sampling
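- Sampling methods such as argmax, top-k, and nucleus can all be cast as standard sampling from an adjusted probability distribution. For example, argmax sampling is equivalent to setting every non-maximal entry of the model's output distribution to 0, renormalizing, and then sampling from the result. The authors therefore only consider standard sampling from a probability distribution in what follows, since each of these methods reduces to that case.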
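The reduction above can be sketched in a few lines. This is a minimal illustration, not the paper's code: the function name `adjust_distribution` and its parameters `method`, `k`, and `top_p` are hypothetical, and distributions are plain Python lists.

```python
def adjust_distribution(probs, method="standard", k=None, top_p=None):
    """Return the adjusted distribution that a sampling method implicitly samples from.

    Hypothetical helper for illustration; `probs` is a list of probabilities.
    """
    # Token indices sorted from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if method == "standard":
        keep = order
    elif method == "argmax":
        keep = order[:1]
    elif method == "top_k":
        keep = order[:k]
    elif method == "nucleus":
        # Keep the smallest set of top tokens whose total mass reaches top_p.
        keep, mass = [], 0.0
        for i in order:
            keep.append(i)
            mass += probs[i]
            if mass >= top_p:
                break
    else:
        raise ValueError(f"unknown method: {method}")
    kept = set(keep)
    # Zero out everything that was not kept, then renormalize.
    adjusted = [p if i in kept else 0.0 for i, p in enumerate(probs)]
    total = sum(adjusted)
    return [a / total for a in adjusted]


if __name__ == "__main__":
    probs = [0.5, 0.3, 0.15, 0.05]
    print(adjust_distribution(probs, "argmax"))
    print(adjust_distribution(probs, "top_k", k=2))
    print(adjust_distribution(probs, "nucleus", top_p=0.7))
```

Once both $p$ and $q$ have been adjusted this way, the rest of the method only needs standard sampling.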
Speculative Sampling
- Speculative Sampling. First sample $x\sim q(x)$. If $q(x)\leq p(x)$, accept the sample; otherwise reject it with probability $1-\frac{p(x)}{q(x)}$ and, on rejection, resample from the adjusted distribution $p'(x)=\mathrm{norm}(\max(0,p(x)-q(x)))$. One can show that the $x$ obtained this way satisfies $x\sim p(x)$ (see “Correctness of Speculative Sampling”).
- Speculative Decoding Step. First, $M_q$ autoregressively samples $\gamma$ tokens. These are fed to $M_p$ together with the prompt, so that $M_p$ outputs the distributions $p(x)$ for $\gamma+1$ positions in parallel. If all $\gamma$ tokens are accepted, one more token $t$ is sampled from $p_{\gamma+1}(x)$. If some token is rejected, find the first rejected token among the $\gamma$ drafted tokens (say the $(n+1)$-th), resample it as token $t$ from the adjusted distribution $p'(x)$, and accept the first $n$ tokens plus $t$. A single step therefore decodes between $1$ and $\gamma+1$ tokens.
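The sampling rule and the decoding step described above can be sketched as follows. This is a toy illustration under stated assumptions, not the paper's implementation: `target` and `draft` are hypothetical callables that map a prefix (a list of token ids) to a list of next-token probabilities, and the $\gamma+1$ target evaluations are done in a Python loop where a real system would use one parallel forward pass.

```python
import random


def sample(probs, rng):
    """Standard sampling of a token id from a list of probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def residual(p, q):
    """The adjusted distribution p'(x) = norm(max(0, p(x) - q(x)))."""
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    return [d / total for d in diff]


def speculative_step(target, draft, prefix, gamma, rng):
    """One decoding step; returns between 1 and gamma + 1 new tokens."""
    # 1. The draft model M_q samples gamma tokens autoregressively.
    drafted, q_dists = [], []
    for _ in range(gamma):
        q = draft(prefix + drafted)
        q_dists.append(q)
        drafted.append(sample(q, rng))

    # 2. The target model M_p scores all gamma + 1 positions
    #    (sequential here, parallel in a real system).
    p_dists = [target(prefix + drafted[:i]) for i in range(gamma + 1)]

    # 3. Accept drafted tokens with probability min(1, p/q)
    #    until the first rejection.
    n = 0
    while n < gamma:
        x = drafted[n]
        if rng.random() < min(1.0, p_dists[n][x] / q_dists[n][x]):
            n += 1
        else:
            break

    # 4. All accepted: sample one extra token from p_{gamma+1}.
    #    Otherwise: resample the rejected position from p'.
    if n == gamma:
        final = sample(p_dists[gamma], rng)
    else:
        final = sample(residual(p_dists[n], q_dists[n]), rng)
    return drafted[:n] + [final]


if __name__ == "__main__":
    rng = random.Random(0)
    # Toy context-independent models over a 3-token vocabulary.
    target = lambda prefix: [0.6, 0.3, 0.1]
    draft = lambda prefix: [0.3, 0.3, 0.4]
    print(speculative_step(target, draft, [], gamma=3, rng=rng))
```

With these toy models, the first token of each step should be distributed according to the target's $[0.6, 0.3, 0.1]$ rather than the draft's distribution, which is what "does not change the output" means in practice.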
Correctness of Speculative Sampling
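The argument can be summarized in a few lines. The following is a sketch reconstructed from the definitions of speculative sampling, using the identity $\max(0, p - q) = p - \min(p, q)$; it is meant as a reading aid rather than a verbatim copy of the paper's proof. Here $x'$ is any fixed token and $x$ is the token finally emitted.

```latex
\begin{aligned}
P(x = x') &= P(\text{accepted},\, x = x') + P(\text{rejected})\, p'(x') \\
P(\text{accepted},\, x = x') &= q(x') \min\!\left(1, \tfrac{p(x')}{q(x')}\right) = \min(p(x'), q(x')) \\
P(\text{rejected}) &= 1 - \sum_{x} \min(p(x), q(x)) \\
p'(x') &= \frac{p(x') - \min(p(x'), q(x'))}{\sum_{x}\left(p(x) - \min(p(x), q(x))\right)}
        = \frac{p(x') - \min(p(x'), q(x'))}{P(\text{rejected})} \\
\Rightarrow\ P(x = x') &= \min(p(x'), q(x')) + \left(p(x') - \min(p(x'), q(x'))\right) = p(x')
\end{aligned}
```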
Analysis
Number of Generated Tokens
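- First define the acceptance rate $\beta$: given a prefix $x_{<t}$, the probability that a token drafted from $q(x_t|x_{<t})$ is accepted by speculative sampling.
- Assume the $\beta$s are i.i.d. and let $\alpha=E(\beta)$. Then the number of generated tokens is a capped geometric variable with success probability $1-\alpha$ and cap $\gamma+1$, and the expected number of tokens produced by a single run of Algorithm 1 is $E(\#\,\text{generated tokens})=\frac{1-\alpha^{\gamma+1}}{1-\alpha}$.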
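The expectation of a capped geometric variable with success probability $1-\alpha$ and cap $\gamma+1$ is $\sum_{k=0}^{\gamma}\alpha^k=\frac{1-\alpha^{\gamma+1}}{1-\alpha}$. A small Monte Carlo check, assuming i.i.d. acceptances with probability $\alpha$ (the function names are illustrative only):

```python
import random


def expected_tokens(alpha, gamma):
    """Closed form (1 - alpha^(gamma+1)) / (1 - alpha) for 0 <= alpha < 1."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def simulate_tokens(alpha, gamma, runs, rng):
    """Average tokens per step when each drafted token is accepted
    independently with probability alpha."""
    total = 0
    for _ in range(runs):
        accepted = 0
        # Count consecutive acceptances, capped at gamma drafted tokens.
        while accepted < gamma and rng.random() < alpha:
            accepted += 1
        # Every step emits one more token: the resampled one after a
        # rejection, or the extra one when all gamma were accepted.
        total += accepted + 1
    return total / runs


if __name__ == "__main__":
    rng = random.Random(0)
    print(expected_tokens(0.8, 4))
    print(simulate_tokens(0.8, 4, 100_000, rng))
```

For $\alpha=0.8$ and $\gamma=4$ the closed form gives about 3.36 tokens per step, and the simulated average should land close to that.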
Calculating $\alpha$
Is a summation sign missing inside the expectation in the last expression of Corollary 3.6?
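For a single prefix, the probability that speculative sampling accepts a drafted token is $\sum_x q(x)\min(1, \frac{p(x)}{q(x)})=\sum_x\min(p(x),q(x))$, which equals one minus the total variation distance between $p$ and $q$. That would be consistent with reading the expectation as containing a sum over $x$. A minimal numeric check (function names are illustrative):

```python
def acceptance_rate(p, q):
    """Per-prefix acceptance probability: sum over tokens of min(p(x), q(x))."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))


def total_variation(p, q):
    """Total variation distance: half the L1 distance between p and q."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


if __name__ == "__main__":
    p = [0.6, 0.3, 0.1]
    q = [0.3, 0.3, 0.4]
    print(acceptance_rate(p, q))          # about 0.7
    print(1 - total_variation(p, q))      # same value
```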
Walltime Improvement
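- The authors assume enough compute is available to support the increased concurrency, i.e., having the LLM verify $\gamma+1$ tokens in parallel does not increase walltime, so the only extra cost of speculative decoding is running the approximation model $M_q$.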
- Cost coefficient $c$: the ratio between the time for a single run of $M_q$ and the time for a single run of $M_p$. In our experiments where $M_q$ is typically a couple of orders of magnitude smaller than $M_p$, $c$ was always less than 0.05 and often negligibly close to 0.
- Expected improvement factor in total walltime. If $c$ is negligible, the expected improvement factor can reach at most $\frac{1}{1-\alpha}$ (as $\gamma\rightarrow\infty$).
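For concreteness, the paper's Theorem 3.8 gives the expected walltime improvement factor as $\frac{1-\alpha^{\gamma+1}}{(1-\alpha)(\gamma c+1)}$: the expected tokens per step divided by the relative cost of one step ($\gamma$ draft runs at cost $c$ each, plus one target run). A short sketch with an illustrative function name:

```python
def improvement_factor(alpha, gamma, c):
    """Expected walltime improvement factor:
    (1 - alpha^(gamma+1)) / ((1 - alpha) * (gamma * c + 1)), for 0 <= alpha < 1."""
    return (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (gamma * c + 1))


if __name__ == "__main__":
    # With a free draft model (c = 0) the factor grows toward 1 / (1 - alpha).
    for gamma in (1, 4, 16, 64):
        print(gamma, round(improvement_factor(0.8, gamma, 0.0), 4))
    # A non-zero cost coefficient lowers the factor for the same gamma.
    print(round(improvement_factor(0.8, 4, 0.05), 4))
```

With $\alpha=0.8$ and $c=0$ the values approach $1/(1-0.8)=5$ as $\gamma$ grows, matching the limit stated above.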
Number of Arithmetic Operations
Choosing $\gamma$
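- Given $c$ and $\alpha$, and assuming enough compute is available, the optimal $\gamma$ is the one that maximizes the walltime improvement factor (i.e., Theorem 3.8). Since $\gamma$ is an integer, the optimum is easy to find numerically.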
- Trade-off between inference speed and the total number of arithmetic operations (assuming $c = \hat c = 0$).
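Because $\gamma$ is a small integer, the numerical search mentioned above can be a plain scan. The sketch below assumes the improvement factor $\frac{1-\alpha^{\gamma+1}}{(1-\alpha)(\gamma c+1)}$ from the paper's Theorem 3.8; the function names and the search bound `max_gamma` are illustrative choices, not part of the paper.

```python
def improvement_factor(alpha, gamma, c):
    """Expected walltime improvement factor for 0 <= alpha < 1."""
    return (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (gamma * c + 1))


def best_gamma(alpha, c, max_gamma=50):
    """Scan gamma = 1..max_gamma and return the value maximizing the factor.

    max_gamma is an arbitrary search bound chosen for this sketch.
    """
    return max(range(1, max_gamma + 1),
               key=lambda g: improvement_factor(alpha, g, c))


if __name__ == "__main__":
    for alpha in (0.5, 0.7, 0.8, 0.9):
        g = best_gamma(alpha, c=0.05)
        print(alpha, g, round(improvement_factor(alpha, g, 0.05), 3))
```

Higher $\alpha$ or lower $c$ pushes the optimal $\gamma$ up, since longer drafts are more likely to be accepted and cost less to produce.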
Experiments
- Empirical Walltime Improvement. $M_p$: T5-XXL (11B). $M_q$: T5-large (800M), T5-base (250M), and T5-small (77M).
- Theoretical Predictions vs. Empirical Runtimes
- Empirical $\alpha$ Values. For every model, $\alpha$ under standard sampling is lower than $\alpha$ under $\arg\max$ (the sharper the adjusted distribution, the higher the $\alpha$ values).
References
- Leviathan, Yaniv, Matan Kalman, and Yossi Matias. “Fast inference from transformers via speculative decoding.” International Conference on Machine Learning. PMLR, 2023.