在上两篇文献中,我像大家介绍了 多通道 模型在 AI for Science 任务中的应用。核心思路类似 CV 中,将灰白单通道拓展到 RGB 多通道,能够提升图片表征能力。(见 图神经网络与分子表征:8. TFN)痛点在于张量积计算量太大。(见 张量网络碎碎念:CGC )
只要能解决这一性能瓶颈,我们就能更充分利用到 多通道 模型的强大表征能力。对此,2023-ICML-eSCN 首次振臂疾呼:使用 SO2 代替 SO3 能够大幅缓解这一问题!这篇文献我已经提了很多次了,也被学界广泛采纳,用于改进自己的模型。比如,2024_ICLR_EquiformerV2 以及 DeepH2。
在这一个月的时间里,我硬着头皮看了很多遍 2023-ICML-eSCN,目前自觉能消化 70% 内容,遂写文记之。(后面实在是消化不动了)
为什么张量积如此耗时
好比城市规划中无法避免的交通堵塞问题,多通道模型使用张量积来完成 不同通道内部的图卷积 。常见多通道模型的架构如下,详见图神经网络与分子表征:8. TFN:
在 TFN 以及 NequIP 等经典框架中,张量积求解借助 e3nn 库实现,详见 图神经网络与分子表征:8. TFN 和 张量网络碎碎念:CGC 。示意图如下:
这里我们简单复现下:
from e3nn import o3
irreps1 = o3.Irreps("1x1e")
irreps2 = o3.Irreps("1x1e")
tp = o3.FullTensorProduct(irreps1, irreps2)
input1 = irreps1.randn(-1)
input2 = irreps2.randn(-1)
results = tp(input1, input2)
print(f'Input 1: {input1}\nInput 2: {input2}\nResults: {results}')
可以看到,在 10 行代码内,我们就实现了一个张量积。显然,e3nn 帮我们隐去了大部分细节。事实上,e3nn 内部是使用 CGC 进行张量积计算的。在 张量网络碎碎念:CGC 中,我像大家展示了,如何借助 sympy 部分复现 e3nn 中张量积结果。这让我们能够感受到 CGC 方法的实现过程:
total_angular_momenta = [(0, 0), (1, 1), (1, 0), (1, -1), (2, 0), (2, 2), (2, 1), (2, -1), (2, -2)]
all_cgc_results = []
for idx, (l3, m3) in enumerate(total_angular_momenta):
sub_results = []
for idx1, m1 in enumerate([-1, 1, 0]):
for idx2, m2 in enumerate([-1, 1, 0]):
coefficient = clebsch_gordan_coefficients(l1=1, l2=1, m1=m1, m2=m2, l=l3, m=m3)
a_sub_result = coefficient * input1[0][m1] * input2[0][m2]
sub_results.append(a_sub_result)
sum_sub_results = sum(sub_results)
all_cgc_results.append(round(float(sum_sub_results), 4))
首先,我们对等式左边要求解的量进行了拆分,按照不同的 l 3 l_3 l3, m 3 m_3 m3 的组合拆分出 9 个待求解的量。
其次,在 l 1 l_1 l1, l 2 l_2 l2 固定的情况下,我们遍历 m 1 m_1 m1, m 2 m_2 m2,这样在最内层的循环中,我们可以确定 l 3 l_3 l3, m 3 m_3 m3, l 1 l_1 l1, l 2 l_2 l2, m 1 m_1 m1, m 2 m_2 m2 共 6 个值,在这些值全部确定后,我们代入公式求解 CGC:
coefficient = clebsch_gordan_coefficients(l1=1, l2=1, m1=m1, m2=m2, l=l3, m=m3)
将求得的系数和输入值结合即可得到一个分量:
a_sub_result = coefficient * input1[0][m1] * input2[0][m2]
我们在遍历 m 1 m_1 m1, m 2 m_2 m2 后,将所有分量相加,得到上述 9 个待求解量的其中一个值。
至此,我们能直观感受到,使用 CGC 方法求解,我们需要遍历 ( l 3 l_3 l3, m 3 m_3 m3), ( l 2 l_2 l2, m 2 m_2 m2), ( l 1 l_1 l1, m 1 m_1 m1)。我们知道磁动量的取值范围是由角动量值决定的:
上文所述,多通道模型,指每一个原子的表征,是由高阶张量组合表示的。
例如, L m a x = 2 L_{max}=2 Lmax=2,则 L = 0 , 1 , 2 L=0,1,2 L=0,1,2,所有可能的 ( l l l, m m m) 组合有 ( 0 , 0 ) , ( 1 , 1 ) , ( 1 , 0 ) , ( 1 , − 1 ) , ( 2 , 0 ) , ( 2 , 2 ) , ( 2 , 1 ) , ( 2 , − 1 ) , ( 2 , − 2 ) (0, 0), (1, 1), (1, 0), (1, -1), (2, 0), (2, 2), (2, 1), (2, -1), (2, -2) (0,0),(1,1),(1,0),(1,−1),(2,0),(2,2),(2,1),(2,−1),(2,−2) 共计 9 种。因此,我们说 L m a x = 2 L_{max}=2 Lmax=2 下单原子的表征是 9 维。
稍加推理可得:在最高价张量为 L m a x L_{max} Lmax 情况下,单原子表征为 ( L m a x + 1 ) 2 (L_{max}+1)^2 (Lmax+1)2 维度。
如果使用 CGC 方法求解,我们需要进入 3 层 for 循环,对每一个 ( l l l, m m m) 组合进行遍历,总共需要进行 ( ( L m a x + 1 ) 2 ) 3 = ( L m a x + 1 ) 6 ((L_{max}+1)^2)^3=(L_{max}+1)^6 ((Lmax+1)2)3=(Lmax+1)6 次遍历。在实际计算中,这些遍历全都是矩阵乘积,因此张量积异常耗时,成为交通堵塞中心。
这也是论文 2023-ICML-eSCN 指出传统张量积计算量为 O ( L ) 6 O(L)^6 O(L)6 的原因。
从 SO3 到 SO2 的优化逻辑链条
在开始之前,我想再回顾一下使用 CGC 方法计算张量积的过程:
import os
import matplotlib.pyplot as plt
from e3nn import o3
from sympy import S
from sympy.physics.quantum.cg import CG
def clebsch_gordan_coefficients(l1, m1, l2, m2, l, m):
cg = CG(S(l1), S(m1), S(l2), S(m2), S(l), S(m)).doit()
return float(cg)
irreps1 = o3.Irreps("1x1e")
irreps2 = o3.Irreps("1x1e")
tp = o3.FullTensorProduct(irreps1, irreps2)
print(tp)
tp.visualize()
plt.show()
input1 = irreps1.randn(-1)
input2 = irreps2.randn(-1)
results = tp(input1, input2)
# Possible values of total angular momentum J and its projection M
total_angular_momenta = [(0, 0), (1, 1), (1, 0), (1, -1), (2, 0), (2, 2), (2, 1), (2, -1), (2, -2)]
all_cgc_results = []
all_counter = 0
no_0_counter = 0
for idx, (l3, m3) in enumerate(total_angular_momenta):
sub_results = []
for idx1, m1 in enumerate([-1, 1, 0]):
for idx2, m2 in enumerate([-1, 1, 0]):
coefficient = clebsch_gordan_coefficients(l1=1, l2=1, m1=m1, m2=m2, l=l3, m=m3)
all_counter = all_counter + 1
if coefficient != 0:
no_0_counter = no_0_counter + 1
if m2 == 0:
print(f"Clebsch-Gordan Coefficient for (l1=1, m1={m1}, l2=1, m2={m2}, l={l3}, m={m3}): {coefficient}")
a_sub_result = coefficient * input1[idx1] * input2[idx2]
sub_results.append(a_sub_result)
sum_sub_results = sum(sub_results)
all_cgc_results.append(sum_sub_results)
print('\nThe tp results:\n')
print(results)
print('\nThe cgc results:\n')
print(all_cgc_results)
print(f'All counter: {all_counter}\nNot 0 counter: {no_0_counter }')
这里我对 CGC 非 0 个数进行了统计,并打印出了 m 2 = 0 , C G C ! = 0 m_2=0, CGC!=0 m2=0,CGC!=0 情况下的 CGC 值。
简单统计可以发现:
- 在 81 个 CGC 数值中,仅有 18 个值非零
- 这些非零数值中,存在大量相等或相反的数值
这两点观察构成了从 SO3 到 SO2 的底层逻辑:
- 如何利用 CGC 的稀疏性降低遍历次数
- 如何利用 CGC 的对称性合并同类项
在上例中,我们很容易就观察到了这两点,但实际计算过程中的 CGC 矩阵并没有如此好的性质。对此,作者思路是:结合实际问题,没有稀疏性可以创造稀疏性。我们再回顾下多通道网络与张量积的历史渊源:
一切问题的源头是遍历 3 个 ( l l l, m m m) 组合,作者选择逐个击破之。其中最容易控制的是张量积符号右侧,对应于上式中的 ( l 2 l_2 l2, m 2 m_2 m2) 。
在张量积符号右侧, h h h 是神经网络的位置,是可学习参数,跟遍 ( l l l, m m m) 组合没有关系。 Y Y Y 是我们引入的一个函数,用于确保等变性。具体来说,他是一个球谐函数。我们使用 e3nn 初始化一个表征:
from e3nn import o3
irreps_sh = o3.Irreps.spherical_harmonics(lmax=2)
print(f'Representation: {irreps_sh}')
输出如下:
表征在张量世界里就相当于是一个基组,因此:
我们随机初始化两个原子的位置,得到原子间向量,再嵌入到球谐基组中,得到在球谐基组表示下的一组值:
from e3nn import o3
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = "True"
import matplotlib.pyplot as plt
import numpy as np
import torch
irreps_sh = o3.Irreps.spherical_harmonics(lmax=2)
print(f'Representation: {irreps_sh}')
atom_i_pos = np.random.randn(1, 3)
atom_j_pos = np.random.randn(1, 3)
vec_i_j = atom_i_pos - atom_j_pos
normed_vec_i_j = vec_i_j / np.linalg.norm(vec_i_j)
normed_vec_i_j_tensor = torch.tensor(normed_vec_i_j)
# normed_vec_i_j_tensor = torch.tensor([0.0, 1.0, 0.0])
sh = o3.spherical_harmonics(irreps_sh, normed_vec_i_j_tensor, normalize=True, normalization='component')
print(sh)
可以看到,这组值中共用 9 个分量,分别对应 9 个 ( l l l, m m m) 组合,即:
( 0 , 0 ) , ( 1 , 1 ) , ( 1 , 0 ) , ( 1 , − 1 ) , ( 2 , 0 ) , ( 2 , 2 ) , ( 2 , 1 ) , ( 2 , − 1 ) , ( 2 , − 2 ) (0, 0), (1, 1), (1, 0), (1, -1), (2, 0), (2, 2), (2, 1), (2, -1), (2, -2) (0,0),(1,1),(1,0),(1,−1),(2,0),(2,2),(2,1),(2,−1),(2,−2)
在随机初始化的坐标向量下,映射结果全是实值,但如果我们喂给球谐基组一个特殊的向量,比如: ( 0 , 1 , 0 ) (0, 1, 0) (0,1,0)
映射结果就会出现大量的 0 值:
normed_vec_i_j_tensor = torch.tensor([0.0, 1.0, 0.0])
sh = o3.spherical_harmonics(irreps_sh, normed_vec_i_j_tensor, normalize=True, normalization='component')
print(sh)
非零位置对应于 m = 0 m=0 m=0 的位置:
至此,我们发现了,简化 3 重遍历 ( l l l, m m m) 组合的第一丝曙光。那就是,如果我们喂给球谐基组一个特殊的向量 ( 0 , 1 , 0 ) (0, 1, 0) (0,1,0) ,张量积右侧将不再遍历 m 2 m_2 m2,因为 m 2 = 0 m_2=0 m2=0 才有非零值。
但是,分子、材料中的原子位置是随机的,我们无法控制啊。对此,2023-ICML-eSCN 作者指出:
整个张量网络具有旋转等变性,我们可以先将随机原子间向量旋转到 ( 0 , 1 , 0 ) (0, 1, 0) (0,1,0) ,代入等式计算张量积以后,再对计算结果进行逆旋转。
就好比桌子上有 10 颗随机摆放的鸡蛋,我们先将其摆放成一排,这样 1 发子弹就能打碎 10 颗蛋,再将碎片复原回原来的位置。这在原文里叫 point and shoot ,也是传统计算物理中的一个方法。
原文中公式如下:
注意到,此时,我们并没有展开张量积。如果按照 CGC 的方法,我们仍需要 3 重遍历,只是其中一层无需遍历 m m m:
令人惊喜的是,CGC 的稀疏性和对称性此时发挥了作用。2023-ICML-eSCN 作者指出:
此时,我们仅需要遍历 ( l o , m o ) (l_o,m_o) (lo,mo) 以及 l f l_f lf, l i l_i li,外加对称性,计算量大幅降低。作者推导后指出,使用这种方法能将计算量从 O ( L ) 6 O(L)^6 O(L)6 降低至 O ( L ) 3 O(L)^3 O(L)3。由于稀疏性、对称性过于高,作者完全抛弃了 CGC 的模式,仅对非零值编程处理,这些都在原文里,就不再展示了。
为了方便大家理解行文逻辑,作者展示了几个插图:
总体优化思路:
缩减后的 CGC 矩阵具有对称性:
至此,我已经完全梳理了从 SO3 到 SO2 的优化逻辑链条,但为什么叫 “从 SO3 到 SO2” 呢?
为什么叫 “从 SO3 到 SO2”
在 图神经网络与分子表征:番外——等变术语 中,我向大家介绍了为什么要等变。这里等变通常指 3 维空间内旋转等变。
白话就是:3 维空间中的分子,旋转后,再输入到网络里,结果也会旋转。
如何定义旋转呢?
在原文里,作者指出,我们可以用两个角度定义 3 维空间的旋转,一个角度定义 2 维空间的旋转:
当 3 维旋转其中一个角度固定后,就变成了 2 维空间中的旋转。这跟上述 point-and-shoot 策略有一定的交集。在 point-and-shoot 中,我们先将原子间距离向量进行了转换,固定成了 ( 0 , 1 , 0 ) (0,1,0) (0,1,0)(原文说是 y 轴方向)。这使得原有对称性要求变得宽松了起来,从 SO3 要求变成了 SO2 要求,这也是全文最后升华的地方。