张量网络碎碎念:从 SO3 到 SO2

news2024/9/21 4:28:04

在上两篇文献中,我像大家介绍了 多通道 模型在 AI for Science 任务中的应用。核心思路类似 CV 中,将灰白单通道拓展到 RGB 多通道,能够提升图片表征能力。(见 图神经网络与分子表征:8. TFN)痛点在于张量积计算量太大。(见 张量网络碎碎念:CGC )

只要能解决这一性能瓶颈,我们就能更充分利用到 多通道 模型的强大表征能力。对此,2023-ICML-eSCN 首次振臂疾呼:使用 SO2 代替 SO3 能够大幅缓解这一问题!这篇文献我已经提了很多次了,也被学界广泛采纳,用于改进自己的模型。比如,2024_ICLR_EquiformerV2 以及 DeepH2。

在这一个月的时间里,我硬着头皮看了很多遍 2023-ICML-eSCN,目前自觉能消化 70% 内容,遂写文记之。(后面实在是消化不动了)

为什么张量积如此耗时

好比城市规划中无法避免的交通堵塞问题,多通道模型使用张量积来完成 不同通道内部的图卷积 。常见多通道模型的架构如下,详见图神经网络与分子表征:8. TFN:

请添加图片描述

在 TFN 以及 NequIP 等经典框架中,张量积求解借助 e3nn 库实现,详见 图神经网络与分子表征:8. TFN 和 张量网络碎碎念:CGC 。示意图如下:

请添加图片描述

这里我们简单复现下:

from e3nn import o3

irreps1 = o3.Irreps("1x1e")
irreps2 = o3.Irreps("1x1e")

tp = o3.FullTensorProduct(irreps1, irreps2)
input1 = irreps1.randn(-1)
input2 = irreps2.randn(-1)
results = tp(input1, input2)
print(f'Input 1: {input1}\nInput 2: {input2}\nResults: {results}')

可以看到,在 10 行代码内,我们就实现了一个张量积。显然,e3nn 帮我们隐去了大部分细节。事实上,e3nn 内部是使用 CGC 进行张量积计算的。在 张量网络碎碎念:CGC 中,我像大家展示了,如何借助 sympy 部分复现 e3nn 中张量积结果。这让我们能够感受到 CGC 方法的实现过程:

total_angular_momenta = [(0, 0), (1, 1), (1, 0), (1, -1), (2, 0), (2, 2), (2, 1), (2, -1), (2, -2)]
all_cgc_results = []
for idx, (l3, m3) in enumerate(total_angular_momenta):
    sub_results = []
    for idx1, m1 in enumerate([-1, 1, 0]):
        for idx2, m2 in enumerate([-1, 1, 0]):
            coefficient = clebsch_gordan_coefficients(l1=1, l2=1, m1=m1, m2=m2, l=l3, m=m3)
            a_sub_result = coefficient * input1[0][m1] * input2[0][m2]
            sub_results.append(a_sub_result)
    sum_sub_results = sum(sub_results)
    all_cgc_results.append(round(float(sum_sub_results), 4))

首先,我们对等式左边要求解的量进行了拆分,按照不同的 l 3 l_3 l3, m 3 m_3 m3 的组合拆分出 9 个待求解的量。

其次,在 l 1 l_1 l1, l 2 l_2 l2 固定的情况下,我们遍历 m 1 m_1 m1, m 2 m_2 m2,这样在最内层的循环中,我们可以确定 l 3 l_3 l3, m 3 m_3 m3, l 1 l_1 l1, l 2 l_2 l2, m 1 m_1 m1, m 2 m_2 m2 共 6 个值,在这些值全部确定后,我们代入公式求解 CGC:

            coefficient = clebsch_gordan_coefficients(l1=1, l2=1, m1=m1, m2=m2, l=l3, m=m3)

将求得的系数和输入值结合即可得到一个分量:

            a_sub_result = coefficient * input1[0][m1] * input2[0][m2]

我们在遍历 m 1 m_1 m1, m 2 m_2 m2 后,将所有分量相加,得到上述 9 个待求解量的其中一个值。

至此,我们能直观感受到,使用 CGC 方法求解,我们需要遍历 ( l 3 l_3 l3, m 3 m_3 m3), ( l 2 l_2 l2, m 2 m_2 m2), ( l 1 l_1 l1, m 1 m_1 m1)。我们知道磁动量的取值范围是由角动量值决定的:

在这里插入图片描述

上文所述,多通道模型,指每一个原子的表征,是由高阶张量组合表示的。

例如, L m a x = 2 L_{max}=2 Lmax=2,则 L = 0 , 1 , 2 L=0,1,2 L=0,1,2,所有可能的 ( l l l, m m m) 组合有 ( 0 , 0 ) , ( 1 , 1 ) , ( 1 , 0 ) , ( 1 , − 1 ) , ( 2 , 0 ) , ( 2 , 2 ) , ( 2 , 1 ) , ( 2 , − 1 ) , ( 2 , − 2 ) (0, 0), (1, 1), (1, 0), (1, -1), (2, 0), (2, 2), (2, 1), (2, -1), (2, -2) (0,0),(1,1),(1,0),(1,1),(2,0),(2,2),(2,1),(2,1),(2,2) 共计 9 种。因此,我们说 L m a x = 2 L_{max}=2 Lmax=2 下单原子的表征是 9 维。

稍加推理可得:在最高价张量为 L m a x L_{max} Lmax 情况下,单原子表征为 ( L m a x + 1 ) 2 (L_{max}+1)^2 (Lmax+1)2 维度。

如果使用 CGC 方法求解,我们需要进入 3 层 for 循环,对每一个 ( l l l, m m m) 组合进行遍历,总共需要进行 ( ( L m a x + 1 ) 2 ) 3 = ( L m a x + 1 ) 6 ((L_{max}+1)^2)^3=(L_{max}+1)^6 ((Lmax+1)2)3=(Lmax+1)6 次遍历。在实际计算中,这些遍历全都是矩阵乘积,因此张量积异常耗时,成为交通堵塞中心。

这也是论文 2023-ICML-eSCN 指出传统张量积计算量为 O ( L ) 6 O(L)^6 O(L)6 的原因。

从 SO3 到 SO2 的优化逻辑链条

在开始之前,我想再回顾一下使用 CGC 方法计算张量积的过程:

import os

import matplotlib.pyplot as plt
from e3nn import o3
from sympy import S
from sympy.physics.quantum.cg import CG


def clebsch_gordan_coefficients(l1, m1, l2, m2, l, m):
    cg = CG(S(l1), S(m1), S(l2), S(m2), S(l), S(m)).doit()
    return float(cg)


irreps1 = o3.Irreps("1x1e")
irreps2 = o3.Irreps("1x1e")

tp = o3.FullTensorProduct(irreps1, irreps2)
print(tp)
tp.visualize()
plt.show()

input1 = irreps1.randn(-1)
input2 = irreps2.randn(-1)
results = tp(input1, input2)

# Possible values of total angular momentum J and its projection M
total_angular_momenta = [(0, 0), (1, 1), (1, 0), (1, -1), (2, 0), (2, 2), (2, 1), (2, -1), (2, -2)]
all_cgc_results = []
all_counter = 0
no_0_counter = 0
for idx, (l3, m3) in enumerate(total_angular_momenta):
    sub_results = []
    for idx1, m1 in enumerate([-1, 1, 0]):
        for idx2, m2 in enumerate([-1, 1, 0]):
            coefficient = clebsch_gordan_coefficients(l1=1, l2=1, m1=m1, m2=m2, l=l3, m=m3)
            all_counter = all_counter + 1
            if coefficient != 0:
                no_0_counter = no_0_counter + 1
                if m2 == 0:
                    print(f"Clebsch-Gordan Coefficient for (l1=1, m1={m1}, l2=1, m2={m2}, l={l3}, m={m3}): {coefficient}")
            a_sub_result = coefficient * input1[idx1] * input2[idx2]
            sub_results.append(a_sub_result)
    sum_sub_results = sum(sub_results)
    all_cgc_results.append(sum_sub_results)
print('\nThe tp results:\n')
print(results)
print('\nThe cgc results:\n')
print(all_cgc_results)

print(f'All counter: {all_counter}\nNot 0 counter: {no_0_counter }')

这里我对 CGC 非 0 个数进行了统计,并打印出了 m 2 = 0 , C G C ! = 0 m_2=0, CGC!=0 m2=0,CGC!=0 情况下的 CGC 值。

简单统计可以发现:

  1. 在 81 个 CGC 数值中,仅有 18 个值非零
  2. 这些非零数值中,存在大量相等或相反的数值

请添加图片描述

这两点观察构成了从 SO3 到 SO2 的底层逻辑:

  1. 如何利用 CGC 的稀疏性降低遍历次数
  2. 如何利用 CGC 的对称性合并同类项

在上例中,我们很容易就观察到了这两点,但实际计算过程中的 CGC 矩阵并没有如此好的性质。对此,作者思路是:结合实际问题,没有稀疏性可以创造稀疏性。我们再回顾下多通道网络与张量积的历史渊源:

请添加图片描述

一切问题的源头是遍历 3 个 ( l l l, m m m) 组合,作者选择逐个击破之。其中最容易控制的是张量积符号右侧,对应于上式中的 ( l 2 l_2 l2, m 2 m_2 m2) 。

在张量积符号右侧, h h h 是神经网络的位置,是可学习参数,跟遍 ( l l l, m m m) 组合没有关系。 Y Y Y 是我们引入的一个函数,用于确保等变性。具体来说,他是一个球谐函数。我们使用 e3nn 初始化一个表征:

from e3nn import o3


irreps_sh = o3.Irreps.spherical_harmonics(lmax=2)
print(f'Representation: {irreps_sh}')

输出如下:

请添加图片描述

表征在张量世界里就相当于是一个基组,因此:

请添加图片描述

我们随机初始化两个原子的位置,得到原子间向量,再嵌入到球谐基组中,得到在球谐基组表示下的一组值:

from e3nn import o3
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = "True"

import matplotlib.pyplot as plt
import numpy as np
import torch

irreps_sh = o3.Irreps.spherical_harmonics(lmax=2)
print(f'Representation: {irreps_sh}')

atom_i_pos = np.random.randn(1, 3)
atom_j_pos = np.random.randn(1, 3)

vec_i_j = atom_i_pos - atom_j_pos
normed_vec_i_j = vec_i_j / np.linalg.norm(vec_i_j)
normed_vec_i_j_tensor = torch.tensor(normed_vec_i_j)

# normed_vec_i_j_tensor = torch.tensor([0.0, 1.0, 0.0])

sh = o3.spherical_harmonics(irreps_sh, normed_vec_i_j_tensor, normalize=True, normalization='component')
print(sh)

请添加图片描述

可以看到,这组值中共用 9 个分量,分别对应 9 个 ( l l l, m m m) 组合,即:

( 0 , 0 ) , ( 1 , 1 ) , ( 1 , 0 ) , ( 1 , − 1 ) , ( 2 , 0 ) , ( 2 , 2 ) , ( 2 , 1 ) , ( 2 , − 1 ) , ( 2 , − 2 ) (0, 0), (1, 1), (1, 0), (1, -1), (2, 0), (2, 2), (2, 1), (2, -1), (2, -2) (0,0),(1,1),(1,0),(1,1),(2,0),(2,2),(2,1),(2,1),(2,2)

在随机初始化的坐标向量下,映射结果全是实值,但如果我们喂给球谐基组一个特殊的向量,比如: ( 0 , 1 , 0 ) (0, 1, 0) (0,1,0)

映射结果就会出现大量的 0 值:

normed_vec_i_j_tensor = torch.tensor([0.0, 1.0, 0.0])

sh = o3.spherical_harmonics(irreps_sh, normed_vec_i_j_tensor, normalize=True, normalization='component')
print(sh)

非零位置对应于 m = 0 m=0 m=0 的位置:

请添加图片描述

至此,我们发现了,简化 3 重遍历 ( l l l, m m m) 组合的第一丝曙光。那就是,如果我们喂给球谐基组一个特殊的向量 ( 0 , 1 , 0 ) (0, 1, 0) (0,1,0) ,张量积右侧将不再遍历 m 2 m_2 m2,因为 m 2 = 0 m_2=0 m2=0 才有非零值。

但是,分子、材料中的原子位置是随机的,我们无法控制啊。对此,2023-ICML-eSCN 作者指出:

整个张量网络具有旋转等变性,我们可以先将随机原子间向量旋转到 ( 0 , 1 , 0 ) (0, 1, 0) (0,1,0) ,代入等式计算张量积以后,再对计算结果进行逆旋转。

就好比桌子上有 10 颗随机摆放的鸡蛋,我们先将其摆放成一排,这样 1 发子弹就能打碎 10 颗蛋,再将碎片复原回原来的位置。这在原文里叫 point and shoot ,也是传统计算物理中的一个方法。

原文中公式如下:

请添加图片描述

注意到,此时,我们并没有展开张量积。如果按照 CGC 的方法,我们仍需要 3 重遍历,只是其中一层无需遍历 m m m

请添加图片描述

令人惊喜的是,CGC 的稀疏性和对称性此时发挥了作用。2023-ICML-eSCN 作者指出:

请添加图片描述

此时,我们仅需要遍历 ( l o , m o ) (l_o,m_o) (lo,mo) 以及 l f l_f lf, l i l_i li,外加对称性,计算量大幅降低。作者推导后指出,使用这种方法能将计算量从 O ( L ) 6 O(L)^6 O(L)6 降低至 O ( L ) 3 O(L)^3 O(L)3。由于稀疏性、对称性过于高,作者完全抛弃了 CGC 的模式,仅对非零值编程处理,这些都在原文里,就不再展示了。

为了方便大家理解行文逻辑,作者展示了几个插图:

总体优化思路:

请添加图片描述

缩减后的 CGC 矩阵具有对称性:

请添加图片描述

至此,我已经完全梳理了从 SO3 到 SO2 的优化逻辑链条,但为什么叫 “从 SO3 到 SO2” 呢?

为什么叫 “从 SO3 到 SO2”

在 图神经网络与分子表征:番外——等变术语 中,我向大家介绍了为什么要等变。这里等变通常指 3 维空间内旋转等变。

白话就是:3 维空间中的分子,旋转后,再输入到网络里,结果也会旋转。

如何定义旋转呢?

在原文里,作者指出,我们可以用两个角度定义 3 维空间的旋转,一个角度定义 2 维空间的旋转:

请添加图片描述

当 3 维旋转其中一个角度固定后,就变成了 2 维空间中的旋转。这跟上述 point-and-shoot 策略有一定的交集。在 point-and-shoot 中,我们先将原子间距离向量进行了转换,固定成了 ( 0 , 1 , 0 ) (0,1,0) (0,1,0)(原文说是 y 轴方向)。这使得原有对称性要求变得宽松了起来,从 SO3 要求变成了 SO2 要求,这也是全文最后升华的地方。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1957122.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

昇思25天学习打卡营第4天|基础知识-数据变换 Transforms

目录 数据变换 Transforms 环境 Common Transforms Compose Vision Transforms Rescale Normalize HWC2CHW Text Transforms PythonTokenizer Lookup Lambda Transforms 补充知识 map方法: lambda函数 filter() 函数 数据变换 Transforms 通常情况下&…

新手小白也能做!十大跨境电商平台推荐

梦想着将你的商品推向世界吗?跨境电商不仅是一门生意,更是一场文化和创新的交流。但作为新手,面对众多平台,你可能会有些迷茫。别担心,我们为你精选了几个全球知名的跨境电商平台,从入驻条件到开店成本&…

【windows11禁止自动更新(重启自动更新)】

第一种: 鼠标点击我的电脑点管理,然后就打开了计算机管理,将Windows Update改成禁用,然后确定,重启。 第二种方法: 下载地址:下载地址 下载好之后点击下载的那个应用程序**(点了之后就可以删掉了&#…

Mysql 输出本月初至当前的全部天数

代码&#xff1a; SELECT DATE_FORMAT(DATE_ADD(CONCAT(DATE_FORMAT(CURDATE(), "%Y-%m-01")),INTERVAL (CAST(help_topic_id AS SIGNED)) DAY),%Y-%m-%d) as DATE FROMmysql.help_topic WHERE help_topic_id < TIMESTAMPDIFF(DAY, CONCAT(DATE_FORMAT(CURDATE…

【软件测试】--接口测试

1. 接口用例设计 接口测试的测试点 功能测试 单接口功能&#xff1a; 手工测试中的单个业务模块&#xff0c;一般对应一个接口 登陆业务 --> 登陆接口加入购物车业务 --> 加入购物车接口订单业务 --> 订单接口支付业务 --> 支付接口 借助工具、代码。绕开前端界面…

Mybatis框架基础知识

Mybatis 1.1什么是Mybatis 1.MyBatis 是一款优秀的持久层框架&#xff0c;它支持自定义 SQL、存储过程以及高级映射。 2.MyBatis 免除了几乎所有的 JDBC 代码以及设置参数和获取结果集的工作。 3.MyBatis 可以通过简单的 XML 或注解来配置和映射原始类型、接口和 Java POJO…

Celeron® J1900/N2807/2930 +FPGA PCI104模块,支持 WinXP(无硬件图形加速)操作系统、宽温

Celeron J1900/N2807/2930 PCI104模块 规格产品类型PCI 104 嵌入式主板芯片组SOCCPUIntel Celeron Processor J1900/N2807/2930内存板载2GB DDR3L双通道内存BIOSAMI 显示 L VDS 18/24-bit&#xff0c;VGA L VDS 支持最大分辨率为 1366768&#xff0c;VGA 支持最大分辨率为2048…

【分享】三种有效的PDF转Word技巧

在日常工作和学习中&#xff0c;有时我们需要将PDF文件转换为Word文档&#xff0c;或者转为Word文档后&#xff0c;可以更方便编辑和修改其内容。今天分享3个可以快速转换格式的工具&#xff0c;一起来看看吧&#xff01; 方法1&#xff1a;使用PDF编辑器 PDF编辑器一般具备格…

SpringBoot整合jasypt加密和解密yml配置文件

使用场景 在微服务架构中&#xff0c;配置管理是一个重要的问题。通常&#xff0c;我们会在配置文件中存放一些敏感信息&#xff0c;如数据库连接字符串、API 密钥等。这些敏感信息如果明文存储在配置文件中&#xff0c;存在较大的安全隐患。为了提高安全性&#xff0c;我们需…

虚拟串口下载破解

文章目录 文章介绍下载链接安装教程 文章介绍 下载虚拟串口并破解 下载链接 下载链接 安装教程 安装完成后不要运行&#xff0c;将Cracked中的文件复制 替换文件安装目录中的这两个文件

历史性突破:200层以上存储芯片率先量产,领先国外芯片巨头

中国芯片技术取得了历史性进展&#xff0c;我们已经可以率先实现200层以上存储芯片的量产&#xff0c;走在了国外存储芯片行业前列。 这一突破性的进展&#xff0c;有望使我们在全球芯片行业中占据领先地位。 后起之秀&#xff0c;鱼跃龙门 存储芯片的生产&#xff0c;不仅是…

三菱GX WORKS3中如何使用恒定周期程序实现定时中断?

三菱GX WORKS3中如何使用恒定周期程序实现定时中断? 如下图所示,打开GX WOKRS3新建一个项目,恒定周期项目树下,添加一个程序本体,在程序中编写简单的自加1指令INC, 如下图所示,在参数—CPU参数中找到 中断设置—恒定周期间隔设置,我这里使用的是I31定时中断,所以这里我…

浅入浅出MyBatis(二)简单实现MyBatis底层机制

MyBatis底层机制示意图 mybatis-config.xml mybatis-config.xml 是MyBatis全局配置文件&#xff0c;在项目中只能有一份。通过该配置文件可以得到SqlSessionFactory对象 SqlSessionFactory 通过SqlSessionFactory可以得到SqlSession&#xff0c;拿到SqlSession就可以操作数据…

【C语言】简易版扫雷游戏(数组、函数的练习)

目录 一、分析和设计 1.1、扫雷游戏的功能分析 1.2、文件结构设计&#xff08;多文件的练习&#xff09; 1.3、数据结构的设计 二、代码 三、效果展示 三、优化 一、分析和设计 1.1、扫雷游戏的功能分析 以在线版的扫雷游戏为参考&#xff0c;分析它的功能&#xff1a;扫…

JAVA中的多线程详解

1.概念 进程(Process)&#xff1a; 进程是一个包含自身执行地址的程序&#xff0c;多线程使程序可以同时存在多个执行片段&#xff0c;这些执行片段根据不同的条件和环境同步或者异步工作&#xff0c;由于转换的数独很快&#xff0c;使人感觉上进程像是在同时运行。 现在的计…

Kafka知识总结(事务+数据存储+请求模型+常见场景)

文章收录在网站&#xff1a;http://hardyfish.top/ 文章收录在网站&#xff1a;http://hardyfish.top/ 文章收录在网站&#xff1a;http://hardyfish.top/ 文章收录在网站&#xff1a;http://hardyfish.top/ 事务 事务Producer保证消息写入分区的原子性&#xff0c;即这批消…

nodejs与npm版本对应表

Node.js — Node.js 版本 (nodejs.org)

GB28181国标视频汇聚平台EasyCVR视频管理系统如何更改GIS地图的默认位置?

GB28181国标视频汇聚平台EasyCVR视频管理系统以其强大的拓展性、灵活的部署方式、高性能的视频能力和智能化的分析能力&#xff0c;为各行各业的视频监控需求提供了优秀的解决方案。通过简单的配置和操作流程&#xff0c;用户可以轻松地进行远程视频监控、存储和查看&#xff0…

【Qt开发】No matching signal for on_toolButton_clicked() 解决方案

【Qt开发】No matching signal for on_toolButton_clicked() 解决方案 文章目录 No matching signal for xxx 解决方案附录&#xff1a;C语言到C的入门知识点&#xff08;主要适用于C语言精通到Qt的C开发入门&#xff09;C语言与C的不同C中写C语言代码C语言到C的知识点Qt开发中…