【youcans论文精读】KAN 2.0:面向科学的KAN网络

news2024/9/27 23:31:08

欢迎关注『youcans论文精读』系列
本专栏内容和资源同步到 GitHub/youcans


【youcans论文精读】KAN 2.0:面向科学的KAN网络

    • 1. KAN2.0 简介
      • 1.1 KAN 2.0 论文发布
      • 1.2 KAN2.0 的新特点:
      • 1.3 KAN 回顾
    • 2. MultiKAN:用乘法增强 KAN 网络的表达能力和可解释性
    • 3. Science to KANs
      • 3.1 在KANs中添加重要特征
      • 3.2 为KAN构建模块化结构
      • 3.3 KAN编译器:将符号公式编译成 KAN
    • 4. KANs to Science
      • 4.1 识别重要特征
    • 5. KAN2.0 的应用
    • 6. KAN2.0 总结
    • 7. KAN2.0 的安装与使用
      • 7.1 下载 KAN2.0
      • 7.2 安装 KAN2.0
      • 7.3 新增的 MultKAN 模块
      • 7.4 KAN2.0 官方例程


‌KAN 2.0‌ 引入了新的功能,如带有乘法节点的MultKAN、将符号公式编译成KAN的编译器kanpiler以及将KAN架构转换为树状图的树转化器,进一步提高了 KAN 架构的可解释性和通用性。


1. KAN2.0 简介

1.1 KAN 2.0 论文发布

2024年 4月,Ziming Liu 等提出了 KAN(Kolmogorov-Arnold Networks) 网络架构。2024年8月,KAN的作者Ziming Liu 等又推出了 KAN2.0:面向科学的KAN网络。

Ziming Liu, Pingchuan Ma, Yixuan Wang, Wojciech Matusik, Max Tegmark, KAN 2.0:Kolmogorov-Arnold Networks Meet Science, Aug 2024
【论文下载地址】:KAN 2.0:Kolmogorov-Arnold Networks Meet Science
【GitHub地址】:Github-pyKAN

在这里插入图片描述

AI+Science 的一个主要挑战在于它们固有的不相容性:AI 主要基于联接主义(connectionism),而科学则依赖于符号主义(symbolism)。
为了融合 AI 与科学这两个世界,我们提出了一个框架,使Kolmogorov-Arnold网络(KANs)和科学无缝协同。该框架强调了KAN在科学发现的三个方面的使用:识别相关特征、揭示模块结构和发现符号公式。这种协同作用是双向的:科学对KAN(将科学知识融入KAN),KAN对科学(通过KAN发掘科学规律)。
KAN 2.0版本重点介绍pykan中的主要新功能:

  1. MultKAN:具有乘法节点的KAN。
  2. kanpiler:一个将符号公式编译为KAN的KAN编译器。
  3. 树转换器:将KAN(或任何神经网络)转换为树图。

基于这些工具,我们展示了KAN发现各种物理定律的能力,包括守恒量、拉格朗日量、隐藏对称性和本构方程等重要概念,为经典物理学的研究提供了新的视角。


1.2 KAN2.0 的新特点:

  1. MultKAN:在KAN中引入乘法节点,增强模型的表达能力。
  2. kanpiler:将符号公式翻译成KAN的编译器,提高模型的实用性。
  3. Tree Transformer:将KAN2.0架构转换为树形图,增强模型的可解释性。

在这里插入图片描述

KAN2.0在科学发现中的作用主要体现在三个方面:识别重要特征、揭示模块结构、发现符号公式,这些功能在原有KAN的基础上都得到了增强。

KAN2.0的可解释性更加普适,适用于化学、生物等难以用符号方程式表示的领域。用户可以将模块化结构构建到KAN2.0中,并通过与MLP神经元的交换直观地观察模块化结构。

研究团队计划将KAN2.0应用于更大规模的问题,并将其扩展到物理学以外的其他科学学科。


1.3 KAN 回顾

‌‌KAN网络(‌Kolmogorov-Arnold Networks)‌是一种创新的神经网络架构,它是受Kolmogorov-Arnold表示定理的启发而设计的。

Kolmogorov-Arnold表示定理

  • Kolmogorov-Arnold表示定理指出,任何连续的多变量函数都可以表示为有限个单变量连续函数的组合。
  • KAN网络通过将激活函数置于网络的边(连接)上,而不是传统的节点上,来实现这一理论。
  • KAN中没有任何线性权重,网络中的每个权重都变成了B-spline型单变量函数的可学习参数。

KAN网络的基本原理

  • Kolmogorov-Arnold表示定理‌:任何连续的多变量函数都可以表示为有限个单变量连续函数的组合。
  • 网络结构‌:KAN网络将激活函数置于网络的边(连接)上,而不是传统的节点上。
  • 参数化激活函数‌:KAN网络中的激活函数被参数化为B-spline型单变量函数,这使得每个权重都变成了可学习的参数。

KAN网络的特点

  • 高精度与少参数‌:KAN网络能够以较少的参数量实现比MLP更高的预测准确性,展现出更优的参数效率。
  • 可解释性与可视化‌:KAN网络提供了MLP难以实现的模型可解释性和交互性,有助于科学规律的发现和理解。
  • 应用场景多样‌:从拟合符号公式、特殊函数到偏微分方程求解,乃至避免灾难性遗忘,展示了KAN在多个领域的潜力。

KAN 将激活函数置于网络的边(连接)上,而不是传统的节点上。这种结构变化使得 KAN 在模型精度和可解释性方面显著优于传统的多层感知器(MLP),在数学和物理定律方面具有很大的潜力。


2. MultiKAN:用乘法增强 KAN 网络的表达能力和可解释性

KAN2.0 在最初的 KAN 网络架构的基础上,引入了如图2所示的 MultKAN 乘法节点,其核心改进是引入额外的乘法层进行增强。

在这里插入图片描述

Kolmogorov-Arnold表示定理提出,任何连续高维函数都可以分解为单变量连续函数和加法的有限组合。因此,最初的 KAN 网络仅包含加法运算:

在这里插入图片描述
然而,考虑到乘法在科学和日常生活中的普遍存在,MultKAN 网络加入了乘法模块,可以更清楚地揭示数据中的乘法结构,以期增强可解释性和表达能力。

如图2所示,MultKAN 和 KAN相似,都包含标准KAN层,但区别在于 MultKAN 插入了乘法节点,对输入的子节点进行乘法运算后再进行恒等变换。MultKAN 由标准 KANLayer Φ l \Phi_l Φl 和乘法层 M l M_l Ml 组成。
用Python代码可表示为( ⨀ \bigodot 表示逐元素乘法):
在这里插入图片描述

根据图2,MultKAN网络进行的运算就可以写作:

在这里插入图片描述

如下图所示,对于当乘法任务 f ( x , y ) = x y f(x,y)=xy f(x,y)=xy ,MultKAN 确实学会了使用一个乘法节点执行简单的乘法。
在这里插入图片描述

我们将 KAN 扩展到 MultiKAN,后续将 KAN 和MultKAN视为同义词,即默认 KAN 都将允许乘法层的存在。

GitHub 仓库中的 KAN 代码(Github-pyKAN)已经更新,可以通过 pip 快捷命令直接安装使用。


3. Science to KANs

在科学领域,领域知识至关重要,让我们可以在数据稀少或不存在的情况下,也能有效工作。因此,对KAN采用基于物理的方法会很有帮助:将可用的领域知识整合到KAN中,同时保持其从数据中发现新物理规律的灵活性。

文中作者探讨了三种可以整合到KAN中的领域知识,从最粗略(最简单/相关性)到最精细(最困难/因果关系):重要特征、模块化结构和符号公式。


3.1 在KANs中添加重要特征

在回归问题中,目标是找到一个函数f,使得y=f(x1, x2, ···, xn)。假设我们希望引入一个辅助输入变量a=a(x1, x2, …, xn),将函数转化为y=f(x1, ···, xn, xa)。

尽管辅助变量a不增加新的信息,但它可以提高神经网络的表达能力。这是因为网络无需消耗资源来计算辅助变量。此外,计算可能变得更简单,从而提升可解释性。

这里,用户可以使用augment_input方法向输入添加辅助特征:

图3显示了包含辅助变量和不包含这些辅助变量的KAN:(a)由符号公式编译而成的KAN,需要5条连接边;(b)(c)包含辅助变量的KAN,仅需2或3条连接边,损失分别为10⁻⁶和10⁻⁴。


3.2 为KAN构建模块化结构

模块化在自然界中非常普遍:比如,人类大脑皮层被划分为几个功能不同的模块,每个模块负责特定任务,如感知或决策。模块化简化了对神经网络的理解,因为它允许我们整体解释神经元群集,而不是单独分析每个神经元。

结构模块化的特点是连接群集,其中特征是群集内的连接远强于群集间的连接。为此,作者引入了module方法:保留群集内的连接,同时去除群集间的连接。

模块由用户来指定,语法是:

model.module(start_layer_id,[nodes_id]->[subnodes_id]->[nodes_id]...)

具体而言,模块化有两种类型:可分性和对称性。

  • 可分性:如果说一个函数是可分的,那么它就可以表示为非重叠变量组的函数的和或积。
  • 对称性:如果f(x1, x2, x3, ···)=g(h(x1, x2), x3, ···),则这个函数在变量(x1, x2)上是对称的。因为只要h(x1, x2)保持不变,即使x1和x2发生变化,f的值仍然保持不变。

3.3 KAN编译器:将符号公式编译成 KAN

为了结合「符号方程」和「神经网络」的优势,提出了一个两步程序:(1)将符号方程编译成KAN,(2)使用数据微调这些KAN。第一步可以将已知的领域知识嵌入到KAN中,而第二步则专注于从数据中学习新的「物理」知识。

在这里插入图片描述

具体来说,首先提出了用于将符号公式编译成KAN的kanpiler(KAN编译器)。如图5(a)所示,该过程包括三个主要步骤:
(1) 符号公式被解析为树结构,其中节点表示表达式,边表示操作/函数。
(2) 然后修改此树以与KAN图的结构对齐。修改包括通过虚拟边将所有叶子节点移动到输入层,并添加虚拟子节点/节点以匹配KAN架构。这些虚拟边/节点/子节点仅执行身份转换。
(3) 变量在第一层中组合在一起,有效地将树转换为图。

为了便于表达,图中将 1D 曲线放置在边缘上以表示函数。我们在费曼数据集上对 kanpiler 进行了基准测试,它成功地处理了所有120个方程。

示例如图5(b)所示。kanpiler 接受输入变量(作为sympy符号)和输出表达式(作为sympty表达式),并返回一个如下的 KAN模型:

model = kanpiler(input_variables, output_expression)

返回的KAN模型处于符号模式,即符号函数被精确编码。如果我们使用三次样条来近似这些符号函数,我们得到MSE损失 ℓ ∝ N − 8 ℓ \propto N^{−8} N8,其中N是网格间隔的数量(与模型参数的数量成正比)。

宽度/深度扩展以提高表达能力:由kanpiler生成的KAN网络结构紧凑,没有多余的边缘,这可能会限制其表达能力并阻碍进一步的微调。为了解决这个问题,我们提出了expand_width和expand_depth方法来扩展网络,使其变得更宽、更深。扩展方法最初添加了零激活函数,这些函数在训练过程中会出现零梯度。因此,应使用扰动方法将这些零函数扰动为非零值,使其可以用非零梯度进行训练。


4. KANs to Science

今天的黑箱的深度神经网络功能强大,但解释这些模型仍然具有挑战性。科学家们不仅寻求高性能的模型,还寻求从模型中提取有意义知识的能力。在本节中,我们将重点介绍如何提高KAN科学目的的可解释性。
我们将探讨从KAN中提取知识的三个层次,从最基本到最复杂:重要特征(第4.1节)、模块结构(第4.2节)和符号公式(第4.3节)。

4.1 识别重要特征

给定一个回归模型f,有 y ≈ f ⁢ ( x 1 , x 2 , … , x n ) y≈f⁢(x1,x2,…,xn) yf(x1,x2,,xn),我们的目标是为输入变量分配重要性分数。

之前所使用的L1范数只考虑到了局部信息。基于 KAN网络,提出了一种更有效的归因分数,能更好反映变量的重要性,还可以根据这种归因分数对网络进行剪枝。

在真实数据集中,输入维度可能很大,但只有少数变量可能是相关的。为了解决这个问题,我们建议根据归因得分修剪掉不相关的特征,这样我们就可以专注于最相关的特征。用户可以应用prune_input来仅保留最相关的变量。


### 4.2 识别模块化结构

归因分数可以告诉我们哪些边或节点更有价值,但它没有揭示模块化结构,即重要的边和节点如何连接。

模块化结构可以分为两种:解剖模块化(anatomical modularity)和功能模块化(functional modularity)。

解剖模块化是指,空间上彼此靠近的神经元相比距离较远的神经元具有更强的连接趋势。论文采用了「神经元交换」方法 auto_swap,可以在保留网络功能的同时缩短连接,有助于识别模块。图7展示了两个成功识别模块的auto_swap任务:多任务匹配和分层多数投票。其中,KAN的模块结构相比MLP更加简单且富有组织性。

解剖模块化不能分析网络全局的模块化结构和整体功能,功能模块化分析通过输入和输出的前向和后向传递来收集有关信息。
图8定义了三种类型的功能模块化:可分性、一般可分性和一般对称性。
在这里插入图片描述
图8:检测KANs中的功能模块化。
(a) 我们研究了三种类型的函数模块化:可分离性(加性或乘性)、一般可分离力和对称性。
(b) 递归应用这些测试可以将函数转换为树。这里的函数可以是符号函数(顶部)、KAN(中间)或MLP(底部)。KAN和MLP在训练结束时都会生成正确的树图,但显示出不同的训练动态。


### 4.3 识别符号公式

符号公式是最有信息量的,因为它们可以直接、清楚地揭示重要特征和模块结构。在Liu等人[57]中,作者展示了一系列示例,他们可以从中提取符号公式,并在需要时提供一些先验知识。有了上面提出的新工具(特征重要性、模块化结构和符号公式),用户可以利用这些新工具轻松地与KAN进行交互和协作,使符号回归更容易。

图9展示了与KAN进行交互协作进行符号回归的3个技巧:
1.发现并利用模块化结构
2.稀疏初始化
3.假设检验

在这里插入图片描述


5. KAN2.0 的应用

除了进行原理层面的说明,论文还讲解了多个具体案例,如何将KAN融入到现实的科学研究中,比如发现新的物理概念和定律。本文给出的案例包括守恒量、拉格朗日量、隐藏对称性和本构方程等。

案例1. 发现守恒量

守恒量是随时间保持恒定的物理量。守恒量至关重要,因为它们通常对应于物理系统中的对称性,并且可以通过降低系统的维数来简化计算。传统上,用纸和笔推导守恒量可能非常耗时,并且需要广泛的领域知识。

机器学习方法可以将守恒量参数化,转化为求解微分方程的问题。此处所用的方法基本类似于作者Ziming Liu等人2022年发表的论文,但将其中的MLP网络换成了KAN。

在这里插入图片描述

本案例的例程,详见:
Physics_2A_conservation_law
Physics_2B_conservation_law_2D


案例2. 发现拉格朗日方程
描述了如何从实验数据中推断出拉格朗日量。

本案例的例程,详见:
Physics_1_Lagrangian

案例3. 发现隐藏的对称性
发现Schwarzschild黑洞中的隐藏对称性。

本案例的例程,详见:
Physics_3_blackhole

案例4. 发现本构定律
本构定律通过模拟材料对外力或变形的响应,定义材料的行为和属性,比如描述弹簧的胡克定律。

本案例的例程,详见:
Physics_4A_constitutive_laws_P11
Physics_4B_constitutive_laws_P12_with_prior
/Physics_4C_constitutive_laws_P12_without_prior


6. KAN2.0 总结

Kolmogorov-Arnold Networks(KANs)与其他神经网络的关键区别在于它们具有更大的可解释性,这允许用户进行操作。KANs网络具有:(1)可学习性(好),使它们能够从数据中学习新事物,以及(2)可解释性降低(坏),随着网络规模的增加,它们变得不那么可解释和可控。

效率提高
最初的pykan包效率很低。我们采用了一些技术来提高它的效率。

  1. 高效的样条评估。
    受Efficient KAN的启发,我们通过避免不必要的输入扩展来优化样条求值。对于具有L层、每层N个神经元和网格大小G的KAN,内存使用量已从O(LN2G)减少到O(LNG)。
  2. 仅在需要时启用符号分支。
    KAN图层包含样条曲线分支和符号分支。符号分支比样条分支耗时得多,因为它不能并行化(需要灾难性的双环)。然而,在许多应用程序中,符号分支是不必要的,因此我们可以在可能的情况下跳过它,从而显著减少运行时间,特别是在网络较大的情况下。
  3. 仅在需要时保存中间激活。
    要绘制KAN图,必须保存中间激活。最初,默认情况下会保存激活,导致运行时间变慢和内存使用过多。我们现在只在需要时保存中间激活(例如,用于在训练中绘制或应用正则化)。用户只需一行代码即可实现这些效率提升:model.speed()。
  4. GPU加速。
    最初,由于问题的小规模性质,所有模型都在CPU上运行。我们现在已经使模型GPU兼容。例如,使用Adam训练 [4100100100,1] 网络100步,在CPU上花费了一整天的时间(在实现1,2,3之前),现在在CPU上需要20秒,而在GPU上不到一秒。然而,KAN在效率方面仍然落后于MLP,尤其是在大规模方面。
    当面临1.0(交互式和通用性)和2.0(高效和特定性)之间的权衡时,我们优先考虑交互性和通用性而不是效率。例如,我们将缓存的数据存储在模型中(这会消耗额外的内存),因此用户可以简单地调用model.plot()来生成KAN图,而无需手动进行前向传递来收集数据。

可解释性
尽管KAN中的可学习单变量函数比MLP中的权重矩阵更具可解释性,但可扩展性仍然是一个挑战。
随着KAN模型的扩展,即使所有样条函数都可以单独解释,管理这些1D函数的组合输出也变得越来越困难。因此,只有当网络规模相对较小时,KAN才可能保持可解释性。
值得注意的是,可解释性取决于内在因素(与模型本身相关)和外在因素(与可解释性方法相关)。先进的可解释性方法应该能够在各个层面处理可解释性。例如,通过用符号回归、模块化发现和特征归因来解释KAN,可解释性与规模的帕累托前沿超出了KAN单独可以实现的范围。未来研究的一个有前景的方向是开发更先进的可解释性方法,进一步推动当前的帕累托前沿。

未来的工作
本文介绍了一个将KAN与科学知识相结合的框架,主要关注与物理相关的小规模示例。展望未来,两个有前景的方向包括将这一框架应用于更大规模的问题,并将其扩展到物理学以外的其他科学学科。


7. KAN2.0 的安装与使用

7.1 下载 KAN2.0

【论文下载地址】:KAN 2.0:Kolmogorov-Arnold Networks Meet Science
【GitHub地址】:Github-pyKAN


7.2 安装 KAN2.0

Pykan 可以通过 PyPI 安装,或直接从 GitHub 安装。

通过 Git 安装:

git clone https://github.com/KindXiaoming/pykan.git
cd pykan
pip install -e .
Installation via github

通过 PyPI 安装:

pip install git+https://github.com/KindXiaoming/pykan.git

从 GitHub 安装:

pip install pykan

依赖项的安装

# python==3.9.7
matplotlib==3.6.2
numpy==1.24.4
scikit_learn==1.1.3
setuptools==65.5.0
sympy==1.11.1
torch==2.2.2
tqdm==4.66.2

激活虚拟环境后,可以按如下方式安装所需的 Python包:

pip install -r requirements.txt

也可以通过 Conda环境来安装(可选的):

conda create --name pykan-env python=3.9.7
conda activate pykan-env
pip install git+https://github.com/KindXiaoming/pykan.git  # For GitHub installation
# or
pip install pykan  # For PyPI installation

7.3 新增的 MultKAN 模块

KAN2.0 增加了 MultKAN模块,通过程序 pykan/kan/MultKAN.py 实现。

# pykan/kan/MultKAN.py
class MultKAN(nn.Module):
    '''
    KAN class
        Attributes:
    -----------
        grid : int
            the number of grid intervals
        k : int
            spline order
        act_fun : a list of KANLayers
        symbolic_fun: a list of Symbolic_KANLayer
        depth : int
            depth of KAN
        width : list
            number of neurons in each layer.
            Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
            With multiplication nodes, [2,[5,3],[5,1],3] means besides the [2,5,53] KAN, there are 3 (1) mul nodes in layer 1 (2). 
        mult_arity : int, or list of int lists
            multiplication arity for each multiplication node (the number of numbers to be multiplied)
        grid : int
            the number of grid intervals
        k : int
            the order of piecewise polynomial
        base_fun : fun
            residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
        symbolic_fun : a list of Symbolic_KANLayer
            Symbolic_KANLayers
        symbolic_enabled : bool
            If False, the symbolic front is not computed (to save time). Default: True.
        width_in : list
            The number of input neurons for each layer
        width_out : list
            The number of output neurons for each layer
        base_fun_name : str
            The base function b(x)
        grip_eps : float
            The parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)
        node_bias : a list of 1D torch.float
        node_scale : a list of 1D torch.float
        subnode_bias : a list of 1D torch.float
        subnode_scale : a list of 1D torch.float
        symbolic_enabled : bool
            when symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero)
        affine_trainable : bool
            indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
        sp_trainable : bool
            indicate whether the overall magnitude of splines is trainable
        sb_trainable : bool
            indicate whether the overall magnitude of base function is trainable
        save_act : bool
            indicate whether intermediate activations are saved in forward pass
        node_scores : None or list of 1D torch.float
            node attribution score
        edge_scores : None or list of 2D torch.float
            edge attribution score
        subnode_scores : None or list of 1D torch.float
            subnode attribution score
        cache_data : None or 2D torch.float
            cached input data
        acts : None or a list of 2D torch.float
            activations on nodes
        auto_save : bool
            indicate whether to automatically save a checkpoint once the model is modified
        state_id : int
            the state of the model (used to save checkpoint)
        ckpt_path : str
            the folder to store checkpoints
        round : int
            the number of times rewind() has been called
        device : str
    '''
    def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
        '''
        initalize a KAN model
        
        Args:
        -----
            width : list of int
                Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
                With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
            grid : int
                number of grid intervals. Default: 3.
            k : int
                order of piecewise polynomial. Default: 3.
            mult_arity : int, or list of int lists
                multiplication arity for each multiplication node (the number of numbers to be multiplied)
            noise_scale : float
                initial injected noise to spline.
            base_fun : str
                the residual function b(x). Default: 'silu'
            symbolic_enabled : bool
                compute (True) or skip (False) symbolic computations (for efficiency). By default: True. 
            affine_trainable : bool
                affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
            grid_eps : float
                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
            grid_range : list/np.array of shape (2,))
                setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default updata_grid=True)
            sp_trainable : bool
                If true, scale_sp is trainable. Default: True.
            sb_trainable : bool
                If true, scale_base is trainable. Default: True.
            device : str
                device
            seed : int
                random seed
            save_act : bool
                indicate whether intermediate activations are saved in forward pass
            sparse_init : bool
                sparse initialization (True) or normal dense initialization. Default: False.
            auto_save : bool
                indicate whether to automatically save a checkpoint once the model is modified
            state_id : int
                the state of the model (used to save checkpoint)
            ckpt_path : str
                the folder to store checkpoints. Default: './model'
            round : int
                the number of times rewind() has been called
            device : str
            
        Returns:
        --------
            self
            
        Example
        -------
        >>> from kan import *
        >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
        checkpoint directory created: ./model
        saving model version 0.0
        '''
        ...
    def to(self, device):
        '''
        move the model to device
        
        Args:
        -----
            device : str or device

        Returns:
        --------
            self
        Example
        -------
        >>> from kan import *
        >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
        checkpoint directory created: ./model
        saving model version 0.0

        Example
        -------
        >>> from kan import *
        >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
        >>> model.to(device)
        '''
        ...

7.4 KAN2.0 官方例程

教程中的示例可以在单个CPU上运行,通常不到10分钟。本文中的所有示例都可以在单个CPU上运行。如果任务规模较大,建议使用GPU。

Quickstart: Hello, KAN!

KANs in Action: API Demos, Examples

API (advanced): API

KAN2.0 论文中的相关案例,详见:
KAN2.0_Physics

下载相关案例后,就可以运行 ipynb 程序。

以【案例1:发现守恒量】为例,运行结果如下图/下式。

from kan import *
from kan.utils import batch_jacobian, create_dataset_from_data
import numpy as np

model = KAN(width=[2,1], seed=42)

# the model learns the Hamiltonian H = 1/2 * (x**2 + p**2)
x = torch.rand(1000,2) * 2 - 1
flow = torch.cat([x[:,[1]], -x[:,[0]]], dim=1)

def pred_fn(model, x):
    grad = batch_jacobian(model, x, create_graph=True)
    grad_normalized = grad/torch.linalg.norm(grad, dim=1, keepdim=True)
    return grad_normalized

loss_fn = lambda grad_normalized, flow: torch.mean(torch.sum(flow * grad_normalized, dim=1)**2)

dataset = create_dataset_from_data(x, flow)
model.fit(dataset, steps=20, pred_fn=pred_fn, loss_fn=loss_fn);

'''
checkpoint directory created: ./model
saving model version 0.0
| train_loss: 1.07e-04 | test_loss: 1.17e-04 | reg: 4.12e+00 | : 100%|█| 20/20 [00:01<00:00, 16.52it
saving model version 0.1
'''

model.plot()
'''
运行结果参见下图
'''

model.auto_symbolic()

'''
fixing (0,0,0) with x^2, r2=1.0000003576278687, c=2
fixing (0,1,0) with x^2, r2=1.0000004768371582, c=2
saving model version 0.2
'''

from kan.utils import ex_round
ex_round(model.symbolic_formula()[0][0], 3)
'''
运行结果参见下式
'''

运行结果如下图/下式。

在这里插入图片描述
在这里插入图片描述


【本节完】


版权声明:
欢迎关注『youcans论文精读』系列
转发请注明原文链接:
【youcans论文精读】KAN 2.0:面向科学的KAN网络
Copyright 2023 youcans, XUPT
Crated:2024-08-23


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2067093.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

足底筋膜炎专用药

足底筋膜炎专用药“古顺*敷堂筋膜*贴”通过其独特的药效和用法&#xff0c;能够针对足底筋膜炎进行有效治疗&#xff0c;缓解患者疼痛和不适感&#xff0c;促进炎症消退和肌肉恢复。长时间站立、行走或进行高强度的跑步、跳跃等活动&#xff0c;会使足底筋膜受到持续的牵拉和压…

MEMS 传感器 4GDTU 说明书

本系统经过精心设计&#xff0c;可无缝对接三石峰的振动管理系统平台。通过该平台&#xff0c;用户可直观查看传感器数据、分析振动趋势、预警潜在故障&#xff0c;并依据分析结果制定针对性的维护策略&#xff0c;从而有效提升设备运行的可靠性与安全性。 本产品广泛应用于工…

日常开发规范

日常开发规范 一.git提交规范 开发代码之前&#xff0c;需有管理员通过系统新建功能分支&#xff0c;如feature/one&#xff0c; 此时开发人员方可拉取feature/one到本地进行开发&#xff0c; 开发人员在本地环境测试稳定后&#xff0c;方可由管理员通过系统发布到开发环境…

宠物空气净化器不是智商税!希喂、352宠物空气净化器真实测评

前端时间我出差了&#xff0c;把小猫寄养在朋友家里&#xff0c;回来后去接它们&#xff0c;结果到朋友家差点没认出来...碰上换毛季猫咪疯狂脱毛&#xff0c;朋友没有及时清理&#xff0c;就全堆在身上了&#xff0c;简直是胖若两猫。到家后&#xff0c;我连忙用梳子把它身上的…

Wi-Fi发射功率简介

目录 一、概念 1.1 射频发射与组合功率 1.2 天线增益 1.3 信道影响 二、常用单位及转换 2.1 dB 与 dBm 2.2 dBi 与 dBd 三、发射功率 3.1 发射功率调节 3.1.1 TPC 3.2 国家码与信道功率 一、概念 ① 和 ⑦ 表示射频发送端处的功率,单位是 dBm。其中 ① 表示AP端的…

《通义千问AI落地—下》:WebSocket详解

一、前言 文本源自 微博客 且已获授权,请尊重版权。 《通义千问AI落地——下篇》如约而至。Websocket在这一类引用中,起到前后端通信的作用。因此,本文将介绍websocket在这类应用场景下的配置、使用、注意事项以及ws连接升级为wss连接等;如下图,本站已经使用了wss连接…

ssrf,csrf漏洞复现

印象深刻的csrf利用&#xff1a; 在phpwind下&#xff1a;漏洞点&#xff08;但是都是在后台的漏洞&#xff09; 代码追&#xff1a; task到unserialize&#xff0c;然后重写PwDelayRun的构造函数&#xff0c;给callback和args赋值&#xff0c;然后当程序执行结束&#xff0c…

请问lammps怎么做两种金属连接的原子浓度分布图??

&#x1f3c6;本文收录于《CSDN问答解惑-专业版》专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收…

未知单播泛洪原因

未知单播&#xff1a;交换机是收到数据包后&#xff0c;读取数据包的目的MAC&#xff0c;并查找自已的MAC表&#xff0c;查找目的MAC对应的端口&#xff0c;从而判断从哪个口端口转发出此数据包&#xff0c;若MAC表里没有此目的MAC&#xff0c;那对于此交换机来说就是未知单播&…

Day46 | 101孤岛的总面积 102沉没孤岛 103水流问题 104建造最大岛屿

语言 Java 101.孤岛的总面积 101. 孤岛的总面积 题目 题目描述 给定一个由 1&#xff08;陆地&#xff09;和 0&#xff08;水&#xff09;组成的矩阵&#xff0c;岛屿指的是由水平或垂直方向上相邻的陆地单元格组成的区域&#xff0c;且完全被水域单元格包围。孤岛是那些…

植物大战僵尸杂交版v2.3.7最新版本(附下载链接)

新版本更新啦&#xff01; B站游戏作者潜艇伟伟迷于8月19日更新了植物大战僵尸杂交版2.3.7版本&#xff01;&#xff01;&#xff01; v2.3.7版本更新内容&#xff1a; 游戏分辨率扩充&#xff0c;UI界面翻新&#xff0c;卡槽数量提升至16个&#xff0c;修复大量BUG&#xff0c…

网络协议与IO模型

1、说一说网络模型&#xff08;OSI、TCP/IP模型&#xff09; OSI采用了分层的结构化技术&#xff0c;共分七层&#xff0c; 物理层、数据链路层、网络层、传输层、会话层、表示层、应用层 。 Open System Interconnect 简称OSI&#xff0c;是国际标准化组织(ISO)和国际电报电…

【Windows脚本】如何测试远程主机某个端口是否开放?

概要 如何测试远程主机某个端口是否开放&#xff1f; 1、PowerShell脚本 使用Test-NetConnection 指令&#xff0c;命令如下。 Test-NetConnection RemoteIP -Port 80 -InformationLevel Detailed 2、tcping工具 下载地址&#xff1a;https://download.csdn.net/download/…

工具(1)查看YUV 图

#灵感# 没啥灵感&#xff0c;就是脑子越来越健忘&#xff0c;就啥都记一笔。 工具名字&#xff1a;YUVPlayer 操作流程&#xff1a; 1、打开YUVPlayer, 把YUV文件拖进来。 2、如果拖进来失败&#xff0c;需要先设置属性, 尤其是YUV类型。 3、成功打开图片后&#xff0c;如…

Linux批量验证代理IP的实用方法

在网络管理和优化过程中&#xff0c;批量验证代理IP的有效性是一个常见需求。无论是为了提高网络访问速度&#xff0c;还是为了确保代理IP的可用性&#xff0c;批量验证代理IP都是一项重要的任务。本文将详细介绍如何在Linux环境下批量验证代理IP&#xff0c;帮助你高效地管理和…

短剧小程序源码2023 短剧影视付费查看小视频会员收益系统全开源

本文来自&#xff1a;短剧小程序源码2023 短剧影视付费查看小视频会员收益系统全开源 - 源码1688 应用介绍 演示后台&#xff1a;http://duan.hengchuang.top/VwmRIfEYDH.php 后台账号&#xff1a;admin 后台密码&#xff1a;123456 功能介绍&#xff1a; 1&#xff0c;内容…

《白蛇:浮生》后劲不足,国漫败走2024暑期档

截止到8月19日中午&#xff0c;上映10天的动画电影《白蛇&#xff1a;浮生》票房终于突破3亿。 客观来说&#xff0c;3亿票房在今年暑期档不算差&#xff0c;但对于上映首日就拿到1.29亿票房的《白蛇&#xff1a;浮生》而言&#xff0c;后期票房走势确实没有达到预期&#xff…

4 nesjs IOC控制反转 DI依赖注入

在 NestJS 中&#xff0c;IOC&#xff08;控制反转&#xff09;和 DI&#xff08;依赖注入&#xff09;是核心概念&#xff0c;它们使得应用程序的模块化和解耦变得更加容易。 控制反转&#xff08;IOC&#xff0c;Inversion of Control&#xff09; 控制反转是一个设计原则&…

clickhouse中使用ReplicatedMergeTree表引擎数据去重问题

问题&#xff1a;使用ReplicatedMergeTree表引擎&#xff0c;该引擎逻辑上是不会对于主键相同的数据&#xff0c;进行去重合并操作。如果想要去重&#xff0c;可以使用ReplacingReplicatedMergeTree表引擎。然后使用ReplicatedMergeTree表引擎进行数据insert 插入数据&#xff…

数据防泄密之源代码防泄密的七大要则!

在数字化时代&#xff0c;源代码的安全保护对企业至关重要。它是企业创新和竞争力的核心&#xff0c;一旦泄露&#xff0c;可能会带来不可估量的损失。因此&#xff0c;选择一款合适的源代码加密软件成为了企业信息安全的关键。SDC沙盒防泄密软件以其独特的技术优势和全面的功能…