欢迎关注『youcans论文精读』系列
本专栏内容和资源同步到 GitHub/youcans
【youcans论文精读】KAN 2.0:面向科学的KAN网络
- 1. KAN2.0 简介
- 1.1 KAN 2.0 论文发布
- 1.2 KAN2.0 的新特点:
- 1.3 KAN 回顾
- 2. MultiKAN:用乘法增强 KAN 网络的表达能力和可解释性
- 3. Science to KANs
- 3.1 在KANs中添加重要特征
- 3.2 为KAN构建模块化结构
- 3.3 KAN编译器:将符号公式编译成 KAN
- 4. KANs to Science
- 4.1 识别重要特征
- 5. KAN2.0 的应用
- 6. KAN2.0 总结
- 7. KAN2.0 的安装与使用
- 7.1 下载 KAN2.0
- 7.2 安装 KAN2.0
- 7.3 新增的 MultKAN 模块
- 7.4 KAN2.0 官方例程
KAN 2.0 引入了新的功能,如带有乘法节点的MultKAN、将符号公式编译成KAN的编译器kanpiler以及将KAN架构转换为树状图的树转化器,进一步提高了 KAN 架构的可解释性和通用性。
1. KAN2.0 简介
1.1 KAN 2.0 论文发布
2024年 4月,Ziming Liu 等提出了 KAN(Kolmogorov-Arnold Networks) 网络架构。2024年8月,KAN的作者Ziming Liu 等又推出了 KAN2.0:面向科学的KAN网络。
Ziming Liu, Pingchuan Ma, Yixuan Wang, Wojciech Matusik, Max Tegmark, KAN 2.0:Kolmogorov-Arnold Networks Meet Science, Aug 2024
【论文下载地址】:KAN 2.0:Kolmogorov-Arnold Networks Meet Science
【GitHub地址】:Github-pyKAN
AI+Science 的一个主要挑战在于它们固有的不相容性:AI 主要基于联接主义(connectionism),而科学则依赖于符号主义(symbolism)。
为了融合 AI 与科学这两个世界,我们提出了一个框架,使Kolmogorov-Arnold网络(KANs)和科学无缝协同。该框架强调了KAN在科学发现的三个方面的使用:识别相关特征、揭示模块结构和发现符号公式。这种协同作用是双向的:科学对KAN(将科学知识融入KAN),KAN对科学(通过KAN发掘科学规律)。
KAN 2.0版本重点介绍pykan中的主要新功能:
- MultKAN:具有乘法节点的KAN。
- kanpiler:一个将符号公式编译为KAN的KAN编译器。
- 树转换器:将KAN(或任何神经网络)转换为树图。
基于这些工具,我们展示了KAN发现各种物理定律的能力,包括守恒量、拉格朗日量、隐藏对称性和本构方程等重要概念,为经典物理学的研究提供了新的视角。
1.2 KAN2.0 的新特点:
- MultKAN:在KAN中引入乘法节点,增强模型的表达能力。
- kanpiler:将符号公式翻译成KAN的编译器,提高模型的实用性。
- Tree Transformer:将KAN2.0架构转换为树形图,增强模型的可解释性。
KAN2.0在科学发现中的作用主要体现在三个方面:识别重要特征、揭示模块结构、发现符号公式,这些功能在原有KAN的基础上都得到了增强。
KAN2.0的可解释性更加普适,适用于化学、生物等难以用符号方程式表示的领域。用户可以将模块化结构构建到KAN2.0中,并通过与MLP神经元的交换直观地观察模块化结构。
研究团队计划将KAN2.0应用于更大规模的问题,并将其扩展到物理学以外的其他科学学科。
1.3 KAN 回顾
KAN网络(Kolmogorov-Arnold Networks)是一种创新的神经网络架构,它是受Kolmogorov-Arnold表示定理的启发而设计的。
Kolmogorov-Arnold表示定理
- Kolmogorov-Arnold表示定理指出,任何连续的多变量函数都可以表示为有限个单变量连续函数的组合。
- KAN网络通过将激活函数置于网络的边(连接)上,而不是传统的节点上,来实现这一理论。
- KAN中没有任何线性权重,网络中的每个权重都变成了B-spline型单变量函数的可学习参数。
KAN网络的基本原理
- Kolmogorov-Arnold表示定理:任何连续的多变量函数都可以表示为有限个单变量连续函数的组合。
- 网络结构:KAN网络将激活函数置于网络的边(连接)上,而不是传统的节点上。
- 参数化激活函数:KAN网络中的激活函数被参数化为B-spline型单变量函数,这使得每个权重都变成了可学习的参数。
KAN网络的特点
- 高精度与少参数:KAN网络能够以较少的参数量实现比MLP更高的预测准确性,展现出更优的参数效率。
- 可解释性与可视化:KAN网络提供了MLP难以实现的模型可解释性和交互性,有助于科学规律的发现和理解。
- 应用场景多样:从拟合符号公式、特殊函数到偏微分方程求解,乃至避免灾难性遗忘,展示了KAN在多个领域的潜力。
KAN 将激活函数置于网络的边(连接)上,而不是传统的节点上。这种结构变化使得 KAN 在模型精度和可解释性方面显著优于传统的多层感知器(MLP),在数学和物理定律方面具有很大的潜力。
2. MultiKAN:用乘法增强 KAN 网络的表达能力和可解释性
KAN2.0 在最初的 KAN 网络架构的基础上,引入了如图2所示的 MultKAN 乘法节点,其核心改进是引入额外的乘法层进行增强。
Kolmogorov-Arnold表示定理提出,任何连续高维函数都可以分解为单变量连续函数和加法的有限组合。因此,最初的 KAN 网络仅包含加法运算:
然而,考虑到乘法在科学和日常生活中的普遍存在,MultKAN 网络加入了乘法模块,可以更清楚地揭示数据中的乘法结构,以期增强可解释性和表达能力。
如图2所示,MultKAN 和 KAN相似,都包含标准KAN层,但区别在于 MultKAN 插入了乘法节点,对输入的子节点进行乘法运算后再进行恒等变换。MultKAN 由标准 KANLayer
Φ
l
\Phi_l
Φl 和乘法层
M
l
M_l
Ml 组成。
用Python代码可表示为(
⨀
\bigodot
⨀表示逐元素乘法):
根据图2,MultKAN网络进行的运算就可以写作:
如下图所示,对于当乘法任务
f
(
x
,
y
)
=
x
y
f(x,y)=xy
f(x,y)=xy ,MultKAN 确实学会了使用一个乘法节点执行简单的乘法。
我们将 KAN 扩展到 MultiKAN,后续将 KAN 和MultKAN视为同义词,即默认 KAN 都将允许乘法层的存在。
GitHub 仓库中的 KAN 代码(Github-pyKAN)已经更新,可以通过 pip 快捷命令直接安装使用。
3. Science to KANs
在科学领域,领域知识至关重要,让我们可以在数据稀少或不存在的情况下,也能有效工作。因此,对KAN采用基于物理的方法会很有帮助:将可用的领域知识整合到KAN中,同时保持其从数据中发现新物理规律的灵活性。
文中作者探讨了三种可以整合到KAN中的领域知识,从最粗略(最简单/相关性)到最精细(最困难/因果关系):重要特征、模块化结构和符号公式。
3.1 在KANs中添加重要特征
在回归问题中,目标是找到一个函数f,使得y=f(x1, x2, ···, xn)。假设我们希望引入一个辅助输入变量a=a(x1, x2, …, xn),将函数转化为y=f(x1, ···, xn, xa)。
尽管辅助变量a不增加新的信息,但它可以提高神经网络的表达能力。这是因为网络无需消耗资源来计算辅助变量。此外,计算可能变得更简单,从而提升可解释性。
这里,用户可以使用augment_input方法向输入添加辅助特征:
图3显示了包含辅助变量和不包含这些辅助变量的KAN:(a)由符号公式编译而成的KAN,需要5条连接边;(b)(c)包含辅助变量的KAN,仅需2或3条连接边,损失分别为10⁻⁶和10⁻⁴。
3.2 为KAN构建模块化结构
模块化在自然界中非常普遍:比如,人类大脑皮层被划分为几个功能不同的模块,每个模块负责特定任务,如感知或决策。模块化简化了对神经网络的理解,因为它允许我们整体解释神经元群集,而不是单独分析每个神经元。
结构模块化的特点是连接群集,其中特征是群集内的连接远强于群集间的连接。为此,作者引入了module方法:保留群集内的连接,同时去除群集间的连接。
模块由用户来指定,语法是:
model.module(start_layer_id, ‘[nodes_id]->[subnodes_id]->[nodes_id]...’)
具体而言,模块化有两种类型:可分性和对称性。
- 可分性:如果说一个函数是可分的,那么它就可以表示为非重叠变量组的函数的和或积。
- 对称性:如果f(x1, x2, x3, ···)=g(h(x1, x2), x3, ···),则这个函数在变量(x1, x2)上是对称的。因为只要h(x1, x2)保持不变,即使x1和x2发生变化,f的值仍然保持不变。
3.3 KAN编译器:将符号公式编译成 KAN
为了结合「符号方程」和「神经网络」的优势,提出了一个两步程序:(1)将符号方程编译成KAN,(2)使用数据微调这些KAN。第一步可以将已知的领域知识嵌入到KAN中,而第二步则专注于从数据中学习新的「物理」知识。
具体来说,首先提出了用于将符号公式编译成KAN的kanpiler(KAN编译器)。如图5(a)所示,该过程包括三个主要步骤:
(1) 符号公式被解析为树结构,其中节点表示表达式,边表示操作/函数。
(2) 然后修改此树以与KAN图的结构对齐。修改包括通过虚拟边将所有叶子节点移动到输入层,并添加虚拟子节点/节点以匹配KAN架构。这些虚拟边/节点/子节点仅执行身份转换。
(3) 变量在第一层中组合在一起,有效地将树转换为图。
为了便于表达,图中将 1D 曲线放置在边缘上以表示函数。我们在费曼数据集上对 kanpiler 进行了基准测试,它成功地处理了所有120个方程。
示例如图5(b)所示。kanpiler 接受输入变量(作为sympy符号)和输出表达式(作为sympty表达式),并返回一个如下的 KAN模型:
model = kanpiler(input_variables, output_expression)
返回的KAN模型处于符号模式,即符号函数被精确编码。如果我们使用三次样条来近似这些符号函数,我们得到MSE损失 ℓ ∝ N − 8 ℓ \propto N^{−8} ℓ∝N−8,其中N是网格间隔的数量(与模型参数的数量成正比)。
宽度/深度扩展以提高表达能力:由kanpiler生成的KAN网络结构紧凑,没有多余的边缘,这可能会限制其表达能力并阻碍进一步的微调。为了解决这个问题,我们提出了expand_width和expand_depth方法来扩展网络,使其变得更宽、更深。扩展方法最初添加了零激活函数,这些函数在训练过程中会出现零梯度。因此,应使用扰动方法将这些零函数扰动为非零值,使其可以用非零梯度进行训练。
4. KANs to Science
今天的黑箱的深度神经网络功能强大,但解释这些模型仍然具有挑战性。科学家们不仅寻求高性能的模型,还寻求从模型中提取有意义知识的能力。在本节中,我们将重点介绍如何提高KAN科学目的的可解释性。
我们将探讨从KAN中提取知识的三个层次,从最基本到最复杂:重要特征(第4.1节)、模块结构(第4.2节)和符号公式(第4.3节)。
4.1 识别重要特征
给定一个回归模型f,有 y ≈ f ( x 1 , x 2 , … , x n ) y≈f(x1,x2,…,xn) y≈f(x1,x2,…,xn),我们的目标是为输入变量分配重要性分数。
之前所使用的L1范数只考虑到了局部信息。基于 KAN网络,提出了一种更有效的归因分数,能更好反映变量的重要性,还可以根据这种归因分数对网络进行剪枝。
在真实数据集中,输入维度可能很大,但只有少数变量可能是相关的。为了解决这个问题,我们建议根据归因得分修剪掉不相关的特征,这样我们就可以专注于最相关的特征。用户可以应用prune_input来仅保留最相关的变量。
### 4.2 识别模块化结构
归因分数可以告诉我们哪些边或节点更有价值,但它没有揭示模块化结构,即重要的边和节点如何连接。
模块化结构可以分为两种:解剖模块化(anatomical modularity)和功能模块化(functional modularity)。
解剖模块化是指,空间上彼此靠近的神经元相比距离较远的神经元具有更强的连接趋势。论文采用了「神经元交换」方法 auto_swap,可以在保留网络功能的同时缩短连接,有助于识别模块。图7展示了两个成功识别模块的auto_swap任务:多任务匹配和分层多数投票。其中,KAN的模块结构相比MLP更加简单且富有组织性。
解剖模块化不能分析网络全局的模块化结构和整体功能,功能模块化分析通过输入和输出的前向和后向传递来收集有关信息。
图8定义了三种类型的功能模块化:可分性、一般可分性和一般对称性。
图8:检测KANs中的功能模块化。
(a) 我们研究了三种类型的函数模块化:可分离性(加性或乘性)、一般可分离力和对称性。
(b) 递归应用这些测试可以将函数转换为树。这里的函数可以是符号函数(顶部)、KAN(中间)或MLP(底部)。KAN和MLP在训练结束时都会生成正确的树图,但显示出不同的训练动态。
### 4.3 识别符号公式
符号公式是最有信息量的,因为它们可以直接、清楚地揭示重要特征和模块结构。在Liu等人[57]中,作者展示了一系列示例,他们可以从中提取符号公式,并在需要时提供一些先验知识。有了上面提出的新工具(特征重要性、模块化结构和符号公式),用户可以利用这些新工具轻松地与KAN进行交互和协作,使符号回归更容易。
图9展示了与KAN进行交互协作进行符号回归的3个技巧:
1.发现并利用模块化结构
2.稀疏初始化
3.假设检验
5. KAN2.0 的应用
除了进行原理层面的说明,论文还讲解了多个具体案例,如何将KAN融入到现实的科学研究中,比如发现新的物理概念和定律。本文给出的案例包括守恒量、拉格朗日量、隐藏对称性和本构方程等。
案例1. 发现守恒量
守恒量是随时间保持恒定的物理量。守恒量至关重要,因为它们通常对应于物理系统中的对称性,并且可以通过降低系统的维数来简化计算。传统上,用纸和笔推导守恒量可能非常耗时,并且需要广泛的领域知识。
机器学习方法可以将守恒量参数化,转化为求解微分方程的问题。此处所用的方法基本类似于作者Ziming Liu等人2022年发表的论文,但将其中的MLP网络换成了KAN。
本案例的例程,详见:
Physics_2A_conservation_law
Physics_2B_conservation_law_2D
案例2. 发现拉格朗日方程
描述了如何从实验数据中推断出拉格朗日量。
本案例的例程,详见:
Physics_1_Lagrangian
案例3. 发现隐藏的对称性
发现Schwarzschild黑洞中的隐藏对称性。
本案例的例程,详见:
Physics_3_blackhole
案例4. 发现本构定律
本构定律通过模拟材料对外力或变形的响应,定义材料的行为和属性,比如描述弹簧的胡克定律。
本案例的例程,详见:
Physics_4A_constitutive_laws_P11
Physics_4B_constitutive_laws_P12_with_prior
/Physics_4C_constitutive_laws_P12_without_prior
6. KAN2.0 总结
Kolmogorov-Arnold Networks(KANs)与其他神经网络的关键区别在于它们具有更大的可解释性,这允许用户进行操作。KANs网络具有:(1)可学习性(好),使它们能够从数据中学习新事物,以及(2)可解释性降低(坏),随着网络规模的增加,它们变得不那么可解释和可控。
效率提高
最初的pykan包效率很低。我们采用了一些技术来提高它的效率。
- 高效的样条评估。
受Efficient KAN的启发,我们通过避免不必要的输入扩展来优化样条求值。对于具有L层、每层N个神经元和网格大小G的KAN,内存使用量已从O(LN2G)减少到O(LNG)。 - 仅在需要时启用符号分支。
KAN图层包含样条曲线分支和符号分支。符号分支比样条分支耗时得多,因为它不能并行化(需要灾难性的双环)。然而,在许多应用程序中,符号分支是不必要的,因此我们可以在可能的情况下跳过它,从而显著减少运行时间,特别是在网络较大的情况下。 - 仅在需要时保存中间激活。
要绘制KAN图,必须保存中间激活。最初,默认情况下会保存激活,导致运行时间变慢和内存使用过多。我们现在只在需要时保存中间激活(例如,用于在训练中绘制或应用正则化)。用户只需一行代码即可实现这些效率提升:model.speed()。 - GPU加速。
最初,由于问题的小规模性质,所有模型都在CPU上运行。我们现在已经使模型GPU兼容。例如,使用Adam训练 [4100100100,1] 网络100步,在CPU上花费了一整天的时间(在实现1,2,3之前),现在在CPU上需要20秒,而在GPU上不到一秒。然而,KAN在效率方面仍然落后于MLP,尤其是在大规模方面。
当面临1.0(交互式和通用性)和2.0(高效和特定性)之间的权衡时,我们优先考虑交互性和通用性而不是效率。例如,我们将缓存的数据存储在模型中(这会消耗额外的内存),因此用户可以简单地调用model.plot()来生成KAN图,而无需手动进行前向传递来收集数据。
可解释性
尽管KAN中的可学习单变量函数比MLP中的权重矩阵更具可解释性,但可扩展性仍然是一个挑战。
随着KAN模型的扩展,即使所有样条函数都可以单独解释,管理这些1D函数的组合输出也变得越来越困难。因此,只有当网络规模相对较小时,KAN才可能保持可解释性。
值得注意的是,可解释性取决于内在因素(与模型本身相关)和外在因素(与可解释性方法相关)。先进的可解释性方法应该能够在各个层面处理可解释性。例如,通过用符号回归、模块化发现和特征归因来解释KAN,可解释性与规模的帕累托前沿超出了KAN单独可以实现的范围。未来研究的一个有前景的方向是开发更先进的可解释性方法,进一步推动当前的帕累托前沿。
未来的工作
本文介绍了一个将KAN与科学知识相结合的框架,主要关注与物理相关的小规模示例。展望未来,两个有前景的方向包括将这一框架应用于更大规模的问题,并将其扩展到物理学以外的其他科学学科。
7. KAN2.0 的安装与使用
7.1 下载 KAN2.0
【论文下载地址】:KAN 2.0:Kolmogorov-Arnold Networks Meet Science
【GitHub地址】:Github-pyKAN
7.2 安装 KAN2.0
Pykan 可以通过 PyPI 安装,或直接从 GitHub 安装。
通过 Git 安装:
git clone https://github.com/KindXiaoming/pykan.git
cd pykan
pip install -e .
Installation via github
通过 PyPI 安装:
pip install git+https://github.com/KindXiaoming/pykan.git
从 GitHub 安装:
pip install pykan
依赖项的安装
# python==3.9.7
matplotlib==3.6.2
numpy==1.24.4
scikit_learn==1.1.3
setuptools==65.5.0
sympy==1.11.1
torch==2.2.2
tqdm==4.66.2
激活虚拟环境后,可以按如下方式安装所需的 Python包:
pip install -r requirements.txt
也可以通过 Conda环境来安装(可选的):
conda create --name pykan-env python=3.9.7
conda activate pykan-env
pip install git+https://github.com/KindXiaoming/pykan.git # For GitHub installation
# or
pip install pykan # For PyPI installation
7.3 新增的 MultKAN 模块
KAN2.0 增加了 MultKAN模块,通过程序 pykan/kan/MultKAN.py 实现。
# pykan/kan/MultKAN.py
class MultKAN(nn.Module):
'''
KAN class
Attributes:
-----------
grid : int
the number of grid intervals
k : int
spline order
act_fun : a list of KANLayers
symbolic_fun: a list of Symbolic_KANLayer
depth : int
depth of KAN
width : list
number of neurons in each layer.
Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
With multiplication nodes, [2,[5,3],[5,1],3] means besides the [2,5,53] KAN, there are 3 (1) mul nodes in layer 1 (2).
mult_arity : int, or list of int lists
multiplication arity for each multiplication node (the number of numbers to be multiplied)
grid : int
the number of grid intervals
k : int
the order of piecewise polynomial
base_fun : fun
residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
symbolic_fun : a list of Symbolic_KANLayer
Symbolic_KANLayers
symbolic_enabled : bool
If False, the symbolic front is not computed (to save time). Default: True.
width_in : list
The number of input neurons for each layer
width_out : list
The number of output neurons for each layer
base_fun_name : str
The base function b(x)
grip_eps : float
The parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)
node_bias : a list of 1D torch.float
node_scale : a list of 1D torch.float
subnode_bias : a list of 1D torch.float
subnode_scale : a list of 1D torch.float
symbolic_enabled : bool
when symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero)
affine_trainable : bool
indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
sp_trainable : bool
indicate whether the overall magnitude of splines is trainable
sb_trainable : bool
indicate whether the overall magnitude of base function is trainable
save_act : bool
indicate whether intermediate activations are saved in forward pass
node_scores : None or list of 1D torch.float
node attribution score
edge_scores : None or list of 2D torch.float
edge attribution score
subnode_scores : None or list of 1D torch.float
subnode attribution score
cache_data : None or 2D torch.float
cached input data
acts : None or a list of 2D torch.float
activations on nodes
auto_save : bool
indicate whether to automatically save a checkpoint once the model is modified
state_id : int
the state of the model (used to save checkpoint)
ckpt_path : str
the folder to store checkpoints
round : int
the number of times rewind() has been called
device : str
'''
def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
'''
initalize a KAN model
Args:
-----
width : list of int
Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
grid : int
number of grid intervals. Default: 3.
k : int
order of piecewise polynomial. Default: 3.
mult_arity : int, or list of int lists
multiplication arity for each multiplication node (the number of numbers to be multiplied)
noise_scale : float
initial injected noise to spline.
base_fun : str
the residual function b(x). Default: 'silu'
symbolic_enabled : bool
compute (True) or skip (False) symbolic computations (for efficiency). By default: True.
affine_trainable : bool
affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
grid_eps : float
When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
grid_range : list/np.array of shape (2,))
setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default updata_grid=True)
sp_trainable : bool
If true, scale_sp is trainable. Default: True.
sb_trainable : bool
If true, scale_base is trainable. Default: True.
device : str
device
seed : int
random seed
save_act : bool
indicate whether intermediate activations are saved in forward pass
sparse_init : bool
sparse initialization (True) or normal dense initialization. Default: False.
auto_save : bool
indicate whether to automatically save a checkpoint once the model is modified
state_id : int
the state of the model (used to save checkpoint)
ckpt_path : str
the folder to store checkpoints. Default: './model'
round : int
the number of times rewind() has been called
device : str
Returns:
--------
self
Example
-------
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
checkpoint directory created: ./model
saving model version 0.0
'''
...
def to(self, device):
'''
move the model to device
Args:
-----
device : str or device
Returns:
--------
self
Example
-------
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
checkpoint directory created: ./model
saving model version 0.0
Example
-------
>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.to(device)
'''
...
7.4 KAN2.0 官方例程
教程中的示例可以在单个CPU上运行,通常不到10分钟。本文中的所有示例都可以在单个CPU上运行。如果任务规模较大,建议使用GPU。
Quickstart: Hello, KAN!
KANs in Action: API Demos, Examples
API (advanced): API
KAN2.0 论文中的相关案例,详见:
KAN2.0_Physics
下载相关案例后,就可以运行 ipynb 程序。
以【案例1:发现守恒量】为例,运行结果如下图/下式。
from kan import *
from kan.utils import batch_jacobian, create_dataset_from_data
import numpy as np
model = KAN(width=[2,1], seed=42)
# the model learns the Hamiltonian H = 1/2 * (x**2 + p**2)
x = torch.rand(1000,2) * 2 - 1
flow = torch.cat([x[:,[1]], -x[:,[0]]], dim=1)
def pred_fn(model, x):
grad = batch_jacobian(model, x, create_graph=True)
grad_normalized = grad/torch.linalg.norm(grad, dim=1, keepdim=True)
return grad_normalized
loss_fn = lambda grad_normalized, flow: torch.mean(torch.sum(flow * grad_normalized, dim=1)**2)
dataset = create_dataset_from_data(x, flow)
model.fit(dataset, steps=20, pred_fn=pred_fn, loss_fn=loss_fn);
'''
checkpoint directory created: ./model
saving model version 0.0
| train_loss: 1.07e-04 | test_loss: 1.17e-04 | reg: 4.12e+00 | : 100%|█| 20/20 [00:01<00:00, 16.52it
saving model version 0.1
'''
model.plot()
'''
运行结果参见下图
'''
model.auto_symbolic()
'''
fixing (0,0,0) with x^2, r2=1.0000003576278687, c=2
fixing (0,1,0) with x^2, r2=1.0000004768371582, c=2
saving model version 0.2
'''
from kan.utils import ex_round
ex_round(model.symbolic_formula()[0][0], 3)
'''
运行结果参见下式
'''
运行结果如下图/下式。
【本节完】
版权声明:
欢迎关注『youcans论文精读』系列
转发请注明原文链接:
【youcans论文精读】KAN 2.0:面向科学的KAN网络
Copyright 2023 youcans, XUPT
Crated:2024-08-23