在KAN学习Day1——模型框架解析及HelloKAN中,我对KAN模型的基本原理进行了简单说明,并将作者团队给出的入门教程hellokan跑了一遍;
在KAN 学习 Day2 —— utils.py及spline.py 代码解读及测试中,我对项目的基本模块代码进行了解释,并以单元测试的形式深入理解模块功能,其中还发现了一个细小的错误。
在KAN 学习 Day3 —— KANLayer.py 与 Symbolic_KANLayer.py 代码解读及测试中,我对两种KAN层的实现进行了解读,它们分别是 “基于B样条曲线的KAN层” 和 “基于 的KAN层” 。(在下文中就称 B样条KAN层 和 符号KAN层)
今天我们开始对完整的KAN网络进行剖析,根据之前的经验,MultKAN类应该包括网络初始化、层之间网格参数传递、反向传播参数更新、网络剪枝、画图等等操作。
目录
一、kan目录
二、MultKAN.py
2.1 类注释
2.2 构造函数 __init__
2.3 节点数计算
2.4 前向传播 forward
0. 方法定义及注释
1. 初始化阶段
2. 前向传播循环
2.5 训练方法 fit
三、总结
一、kan目录
kan目录结构如下,包括了模型源码、检查点、实验以及assets等
先了解一下这些文件/文件夹的大致信息:
- kan\__init__.py:用于初始化Python包,方便使用时导入模块
- kan\compiler.py:用于编译模型
- kan\experiment.py:实验代码
- kan\feynman.py:费曼函数,根据传入“name”的值确定函数,暂时没找到这个在哪里用到
- kan\hypothesis.py:将函数进行线性分离,还包含一些画图函数
- kan\KANLayer.py:KAN层的实现,使用B样条曲线作为激活函数
- kan\LBFGS.py:这个文件名似乎昨天见过,训练时的opt参数。L-BFGS是一种用于无约束优化问题的算法,它是一种拟牛顿方法,特别适用于大型稀疏问题。
- kan\MLP.py:作者自己实现了一个MLP,应该使来与KAN做对比的
- kan\MultKAN.py:在KANLayer的基础上实现的KAN类的定义,提供了关于构建和配置这种网络的详细信息。
- kan\spline.py:样条函数的实现
- kan\Symbolic_KANLayer.py:符号化的KAN层,使用四参线性函数作为激活函数
- kan\utils.py:通用模块
- kan\.ipynb_checkpoints:看目录名,这个文件夹下存放的应该是检查点文件,但是似乎和模型的实现代码区别不大,没遇到过,还不知道有什么用。
- kan\assets:这个目录下存放了两张图片,一张加号一张乘号,应该是对函数进行线性分离后,可视化时用的
- kan\experiments:这个目录下是experiment1.ipynb,和昨天跑的hellokan差不多,今天再跑一下
二、MultKAN.py
import torch
import torch.nn as nn
import numpy as np
from .KANLayer import KANLayer
#from .Symbolic_MultKANLayer import *
from .Symbolic_KANLayer import Symbolic_KANLayer
from .LBFGS import *
import os
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import copy
#from .MultKANLayer import MultKANLayer
import pandas as pd
from sympy.printing import latex
from sympy import *
import sympy
import yaml
from .spline import curve2coef
from .utils import SYMBOLIC_LIB
from .hypothesis import plot_tree
导入的这些依赖中,只有 LBFGS 和 plot_tree 我们还没介绍,这两部分内容我也没打算深入研究
- LBFGS(Limited-memory BFGS)是一种优化算法,它主要用于求解无约束优化问题。
- plot_tree则是画出网络的树状图
2.1 类注释
class MultKAN(nn.Module):
'''
KAN class
Attributes:
-----------
grid : int
the number of grid intervals
k : int
spline order
act_fun : a list of KANLayers
symbolic_fun: a list of Symbolic_KANLayer
depth : int
depth of KAN
width : list
number of neurons in each layer.
Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
With multiplication nodes, [2,[5,3],[5,1],3] means besides the [2,5,53] KAN, there are 3 (1) mul nodes in layer 1 (2).
mult_arity : int, or list of int lists
multiplication arity for each multiplication node (the number of numbers to be multiplied)
grid : int
the number of grid intervals
k : int
the order of piecewise polynomial
base_fun : fun
residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
symbolic_fun : a list of Symbolic_KANLayer
Symbolic_KANLayers
symbolic_enabled : bool
If False, the symbolic front is not computed (to save time). Default: True.
width_in : list
The number of input neurons for each layer
width_out : list
The number of output neurons for each layer
base_fun_name : str
The base function b(x)
grip_eps : float
The parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)
node_bias : a list of 1D torch.float
node_scale : a list of 1D torch.float
subnode_bias : a list of 1D torch.float
subnode_scale : a list of 1D torch.float
symbolic_enabled : bool
when symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero)
affine_trainable : bool
indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
sp_trainable : bool
indicate whether the overall magnitude of splines is trainable
sb_trainable : bool
indicate whether the overall magnitude of base function is trainable
save_act : bool
indicate whether intermediate activations are saved in forward pass
node_scores : None or list of 1D torch.float
node attribution score
edge_scores : None or list of 2D torch.float
edge attribution score
subnode_scores : None or list of 1D torch.float
subnode attribution score
cache_data : None or 2D torch.float
cached input data
acts : None or a list of 2D torch.float
activations on nodes
auto_save : bool
indicate whether to automatically save a checkpoint once the model is modified
state_id : int
the state of the model (used to save checkpoint)
ckpt_path : str
the folder to store checkpoints
round : int
the number of times rewind() has been called
device : str
'''
这段代码定义了一个名为 MultKAN
的类,它是基于 nn.Module
构建的,这个类具有众多的属性,用于描述和控制其行为和特征:
grid:
网格的间隔数(使用网格进行参数优化)k
:分段多项式的阶数,或者说B样条的控制点数act_fun:B样条KAN层列表
symbolic_fun
:符号KAN层列表。depth
:表示模型的深度。width
:描述了各层神经元的数量。mult_arity
:与乘法节点的乘法运算的元数有关。base_fun
:公式中的。- symbolic_enabled:布尔值,是否使用符号KAN层
width_in
和width_out
:分别表示各层的输入和输出神经元数量。base_fun_name
:基础函数的名称。grip_eps
:可能用于在均匀网格和自适应网格之间进行插值。- 各种与偏差、缩放、训练相关的属性,如
node_bias
、node_scale
等,用于控制模型的训练和参数调整。 - 各种与分数、缓存、自动保存、设备等相关的属性,用于模型的评估、数据存储、模型保存和硬件设置等方面。
嘛,就是说这里好多注释又重复了......
2.2 构造函数 __init__
def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
'''
initalize a KAN model
Args:
-----
width : list of int
Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
grid : int
number of grid intervals. Default: 3.
k : int
order of piecewise polynomial. Default: 3.
mult_arity : int, or list of int lists
multiplication arity for each multiplication node (the number of numbers to be multiplied)
noise_scale : float
initial injected noise to spline.
base_fun : str
the residual function b(x). Default: 'silu'
symbolic_enabled : bool
compute (True) or skip (False) symbolic computations (for efficiency). By default: True.
affine_trainable : bool
affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
grid_eps : float
When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
grid_range : list/np.array of shape (2,))
setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default updata_grid=True)
sp_trainable : bool
If true, scale_sp is trainable. Default: True.
sb_trainable : bool
If true, scale_base is trainable. Default: True.
device : str
device
seed : int
random seed
save_act : bool
indicate whether intermediate activations are saved in forward pass
sparse_init : bool
sparse initialization (True) or normal dense initialization. Default: False.
auto_save : bool
indicate whether to automatically save a checkpoint once the model is modified
state_id : int
the state of the model (used to save checkpoint)
ckpt_path : str
the folder to store checkpoints. Default: './model'
round : int
the number of times rewind() has been called
device : str
Returns:
--------
self
Example
-------
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
checkpoint directory created: ./model
saving model version 0.0
'''
这段代码是 MultKAN
类的构造函数 __init__
的定义。构造函数用于初始化一个 MultKAN
模型实例,并为其设置各种参数。
参数说明:
width
:一个整数列表,指定了每一层的神经元数量。如果没有乘法节点,列表中的每个元素代表相应层的神经元数量;如果有乘法节点,列表中的元素是一个包含神经元数量和乘法节点数量的元组。grid
:网格间隔的数量,默认为3。k
:分段多项式的阶数,默认为3。mult_arity
:每个乘法节点的乘法运算的元数,可以是单个整数或整数列表。noise_scale
:注入到样条函数中的初始噪声的缩放比例。scale_base_mu
和scale_base_sigma
:基础函数的缩放参数的均值和标准差。base_fun
:残差函数b(x)
的类型,默认为 'silu'。symbolic_enabled
:是否启用符号计算,默认为True。affine_trainable
:是否更新仿射参数,包括节点缩放、节点偏差、子节点缩放和子节点偏差。grid_eps
:用于在均匀网格和自适应网格之间进行插值的参数。grid_range
:设置网格范围的列表或NumPy数组。sp_trainable
:如果为真,则spline
的缩放是可训练的。sb_trainable
:如果为真,则基础函数的缩放是可训练的。device
:指定设备,如 'cpu' 或 'cuda'。seed
:随机种子,用于初始化权重。save_act
:指示是否在正向传递中保存中间激活。sparse_init
:是否进行稀疏初始化。auto_save
:指示是否在修改模型后自动保存检查点。state_id
:模型的当前状态,用于保存检查点。ckpt_path
:存储检查点的文件夹路径。round
:rewind()
被调用次数。device
:设备类型。
代码说明:
super(MultKAN, self).__init__()
- 调用父类
MultKAN
的初始化方法,用于设置一些基本的属性或执行一些初始化操作。
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
-
这三行代码设置了随机数种子,确保每次运行代码时生成的随机数序列相同,这对于测试和调试非常有用。据说将seed设置为3407会将模型的性能提升1%
### initializeing the numerical front ###
self.act_fun = []
self.depth = len(width) - 1
- 这里初始化了激活函数列表
self.act_fun
和模型的深度self.depth
,深度是通过宽度列表的长度减一得到的。
for i in range(len(width)):
if type(width[i]) == int:
width[i] = [width[i],0]
self.width = width
- 遍历宽度列表,如果宽度为整数,则将其转换为列表形式,形式为
[宽度, 0]。
- 将宽度列表赋值给
self.width
属性。 - 注意到,注释中的width属性是有两种形式的,这几行代码使其都转化为了第二种形式,即如果有乘法节点,列表中的元素是一个包含神经元数量和乘法节点数量的元组。
# if mult_arity is just a scalar, we extend it to a list of lists
# e.g, mult_arity = [[2,3],[4]] means that in the first hidden layer, 2 mult ops have arity 2 and 3, respectively;
# in the second hidden layer, 1 mult op has arity 4.
if isinstance(mult_arity, int):
self.mult_homo = True # when homo is True, parallelization is possible
else:
self.mult_homo = False # when home if False, for loop is required.
self.mult_arity = mult_arity
- 如果
mult_arity
是一个标量(即单个数字),代码将把它扩展为一个列表的列表。这样做通常是为了将单一的参数应用到多个乘法操作上。 - 例如,如果
mult_arity = [[2,3],[4]]
,这意味着在第一个隐藏层中有两个乘法操作,它们的参数分别是 2 和 3;在第二个隐藏层中有一个乘法操作,其参数是 4。 - 这里检查
mult_arity
是否是一个整数。如果是,那么所有乘法操作的参数都是相同的,这意味着它们是同质的。在这种情况下,可以将这些操作并行化,以提高计算效率。因此,将self.mult_homo
设置为True
。 - 如果
mult_arity
不是一个整数,那么它可能是一个列表的列表,其中包含不同层级的不同参数。在这种情况下,不能并行化乘法操作,因为每个操作的参数可能不同。因此,将self.mult_homo
设置为False
,这意味着可能需要使用循环来处理每个操作。 - 最后,将处理后的
mult_arity
参数赋值给self.mult_arity
,这样模型就可以使用这个参数来定义其乘法操作了。
width_in = self.width_in
width_out = self.width_out
调用了两个方法,获得了KAN层真正的输入输出节点数。
self.base_fun_name = base_fun
if base_fun == 'silu':
base_fun = torch.nn.SiLU()
elif base_fun == 'identity':
base_fun = torch.nn.Identity()
elif base_fun == 'zero':
base_fun = lambda x: x*0.
- 将传入的
base_fun
参数赋值给实例变量self.base_fun_name
。这意味着base_fun
是一个字符串,它表示想要使用的基础函数的名称。 - 如果
base_fun_name
是字符串'silu'
,那么代码将创建一个torch.nn.SiLU()
对象。SiLU
(Sigmoid-weighted Linear Unit)是一个激活函数,通常用于神经网络中。 - 如果
base_fun_name
是字符串'identity'
,那么代码将创建一个torch.nn.Identity()
对象。Identity
函数是一个恒等函数,它直接返回其输入值,通常用作默认激活函数或不改变输入的层。 - 如果
base_fun_name
是字符串'zero'
,那么代码将创建一个匿名函数(lambda 函数),这个函数将任何输入x
乘以 0,从而输出 0。这可能表示一个“关闭”激活状态的函数,不激活任何神经元。
self.grid_eps = grid_eps
self.grid_range = grid_range
- 将网格相关的参数赋值给
self.grid_eps
和self.grid_range
。 grid_eps
:控制网格细化策略的浮点数,默认为0.02。当grid_eps = 1
时,网格是均匀的;当grid_eps = 0
时,它使用样本的百分位数进行分区。0 < grid_eps < 1 插值在两种极端之间。
for l in range(self.depth):
# splines
sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid, k=k, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
self.act_fun.append(sp_batch)
- 这段代码在循环中为每一层创建一个
KANLayer
实例,并把这些实例添加到一个列表中,以便后续可以用于KAN网络模型。
self.node_bias = []
self.node_scale = []
self.subnode_bias = []
self.subnode_scale = []
- 初始化用于节点和子节点的偏差和缩放参数的列表。
globals()['self.node_bias_0'] = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)
exec('self.node_bias_0' + " = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)")
globals()
返回当前全局命名空间中的所有全局变量。self.node_bias_0
是类的一个属性,这个属性在类的定义中尚未明确定义(即,它不是类的内部成员,而是通过全局命名空间访问的)。torch.nn.Parameter(torch.zeros(3,1))
创建了一个PyTorch的参数(Parameter
对象),该对象是张量,用于在神经网络中存储权重,并且支持梯度计算。.requires_grad_(False)
设置了该参数对象不进行梯度计算,即不会追踪其在计算图中的操作,这对于不需要计算梯度的参数(如偏置项)来说是合理的。exec()
是Python的内置函数,用于执行字符串形式的Python代码。在这里,它被用来动态地创建或更新类属性。'self.node_bias_0'
是一个字符串,表示类中要创建或更新的属性名。torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)
是一个字符串表达式,创建了一个新的PyTorch参数并设置了其梯度计算为False。
这种做法在某些情况下非常有用,比如在定义神经网络模型时,需要动态地为特定的参数创建属性,或者在模型中为某些不需要梯度计算的参数(如偏置项)创建独立的属性。
但是!我没找到这两行代码有啥用处。
for l in range(self.depth):
exec(f'self.node_bias_{l} = torch.nn.Parameter(torch.zeros(width_in[l+1])).requires_grad_(affine_trainable)')
exec(f'self.node_scale_{l} = torch.nn.Parameter(torch.ones(width_in[l+1])).requires_grad_(affine_trainable)')
exec(f'self.subnode_bias_{l} = torch.nn.Parameter(torch.zeros(width_out[l+1])).requires_grad_(affine_trainable)')
exec(f'self.subnode_scale_{l} = torch.nn.Parameter(torch.ones(width_out[l+1])).requires_grad_(affine_trainable)')
exec(f'self.node_bias.append(self.node_bias_{l})')
exec(f'self.node_scale.append(self.node_scale_{l})')
exec(f'self.subnode_bias.append(self.subnode_bias_{l})')
exec(f'self.subnode_scale.append(self.subnode_scale_{l})')
- 通过循环,它为模型的每一层创建了节点偏置、节点缩放、子节点偏置和子节点缩放参数,并将这些参数存储在类的属性中,以便后续使用。
- 通过
affine_trainable
参数来控制哪些参数是可训练的。
self.act_fun = nn.ModuleList(self.act_fun)
self.grid = grid
self.k = k
self.base_fun = base_fun
这几个基础的设置就不解释了。
### initializing the symbolic front ###
self.symbolic_fun = []
for l in range(self.depth):
sb_batch = Symbolic_KANLayer(in_dim=width_in[l], out_dim=width_out[l+1])
self.symbolic_fun.append(sb_batch)
刚刚创建了B样条KAN层,现在创建符号KAN层。
self.symbolic_fun = nn.ModuleList(self.symbolic_fun)
self.symbolic_enabled = symbolic_enabled
self.affine_trainable = affine_trainable
self.sp_trainable = sp_trainable
self.sb_trainable = sb_trainable
- 将符号层加入列表
- 设置符号层是否可用
- 设置符号层线性函数的四个参数是否可训练
- 设置激活函数中的参数 是否可训练,分为了sp和sb两种,sp为B样条KAN层的,sb为符号KAN层的
self.save_act = save_act
self.node_scores = None
self.edge_scores = None
self.subnode_scores = None
self.cache_data = None
self.acts = None
self.auto_save = auto_save
self.state_id = 0
self.ckpt_path = ckpt_path
self.round = round
一些中间结果的存储变量和保存操作设置,保存的具体操作如下:
if auto_save:
if first_init:
if not os.path.exists(ckpt_path):
# Create the directory
os.makedirs(ckpt_path)
print(f"checkpoint directory created: {ckpt_path}")
print('saving model version 0.0')
history_path = self.ckpt_path+'/history.txt'
with open(history_path, 'w') as file:
file.write(f'### Round {self.round} ###' + '\n')
file.write('init => 0.0' + '\n')
self.saveckpt(path=self.ckpt_path+'/'+'0.0')
else:
self.state_id = state_id
我们在hellokan中就见识过,模型在训练过程中会保存中间数据、状态和历史信息等内容
self.input_id = torch.arange(self.width_in[0],)
- 给输入节点编号 ,从0开始
self.device = device
self.to(device)
def to(self, device):
'''
move the model to device
Args:
-----
device : str or device
Returns:
--------
self
Example
-------
>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.to(device)
'''
super(MultKAN, self).to(device)
self.device = device
for kanlayer in self.act_fun:
kanlayer.to(device)
for symbolic_kanlayer in self.symbolic_fun:
symbolic_kanlayer.to(device)
return self
- 选择计算设备
测试:
from kan import *
torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = KAN(width=[2,[5,3],[5,1],3], mult_arity=[0,[2,3,4],[2],0],grid=3, k=3, seed=42, device=device)
model.input_id
cuda
checkpoint directory created: ./model
saving model version 0.0tensor([0, 1])
2.3 节点数计算
2.3.1 width_in
@property
def width_in(self):
'''
The number of input nodes for each layer
'''
width = self.width
width_in = [width[l][0]+width[l][1] for l in range(len(width))]
return width_in
这段代码定义了一个属性 width_in
,它的作用是计算并返回模型每一层的输入节点数量。
- 首先,获取了模型的宽度信息
width
。 - 然后,通过列表推导式计算每一层输入节点的数量,计算方式是将每一层的总和维度
width[l][0]
和乘法操作维度width[l][1]
相加。 - 最后,返回计算得到的输入节点数量列表。
所以每层节点数=设置的节点数+乘法操作次数
测试:
width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]
print(model.width)
model.width_in
[[2, 0], [5, 3], [5, 1], [3, 0]]
[2, 8, 6, 3]
2.3.2 width_out
@property
def width_out(self):
'''
The number of output subnodes for each layer
'''
width = self.width
if self.mult_homo == True:
width_out = [width[l][0]+self.mult_arity*width[l][1] for l in range(len(width))]
else:
width_out = [width[l][0]+int(np.sum(self.mult_arity[l])) for l in range(len(width))]
return width_out
这段代码定义了一个属性 width_out
,其目的是计算并返回模型每一层的输出子节点数量。
- 首先,获取了模型的宽度信息
width
。然后根据self.mult_homo
的值来决定计算输出节点数量的方式。 - 如果
self.mult_homo
为True
,则使用列表推导式计算每一层的输出节点数量。计算方式是将每一层的总和维度width[l][0]
与乘法操作维度width[l][1]
的结果乘以mult_arity
的值相加。mult_arity
是一个数组,表示每一层的乘法操作的幅度。 - 如果
self.mult_homo
为False
,则使用列表推导式计算每一层的输出节点数量。计算方式是将每一层的总和维度width[l][0]
与mult_arity[l]
的元素之和相加。mult_arity[l]
是一个数组,表示每一层的乘法操作的幅度。 - 最后,返回计算得到的输出节点数量列表。
测试:
width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]
print(model.width)
model.width_out
[[2, 0], [5, 3], [5, 1], [3, 0]]
[2, 14, 7, 3]
2.3.3 n_sum
@property
def n_sum(self):
'''
The number of addition nodes for each layer
'''
width = self.width
n_sum = [width[l][0] for l in range(1,len(width)-1)]
return n_sum
这段代码定义了一个属性 n_sum
,用于计算并返回除了第一层和最后一层之外,每一层的总和维度 width[l][0]
所组成的列表。
首先,获取了模型的宽度信息 width
。然后通过列表推导式,从第二层到倒数第二层,提取出每一层的 width[l][0]
,并将这些值组成一个新的列表 n_sum
,最后返回这个列表。
测试:
width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]
print(model.width)
model.n_sum
[[2, 0], [5, 3], [5, 1], [3, 0]]
[5, 5]
2.3.4 n_mult
@property
def n_mult(self):
'''
The number of multiplication nodes for each layer
'''
width = self.width
n_mult = [width[l][1] for l in range(1,len(width)-1)]
return n_mult
这段代码定义了一个属性 n_mult
,用于计算并返回除了第一层和最后一层之外,每一层的乘法节点数量。这里 width
是一个包含多层宽度信息的数据结构,每一层的信息以列表的形式存储,其中 width[l][1]
表示第 l
层的乘法节点数量。
通过列表推导式,代码遍历从第二层到倒数第二层的所有层,提取每一层的乘法节点数量,并将这些数量组成一个新的列表 n_mult
。最后,这个列表被返回给调用者。
测试:
width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]
print(model.width)
model.n_mult
[[2, 0], [5, 3], [5, 1], [3, 0]]
[3, 1]
2.3.5 feature_score
@property
def feature_score(self):
'''
attribution scores for inputs
'''
self.attribute()
if self.node_scores == None:
return None
else:
return self.node_scores[0]
这段代码定义了一个名为 feature_score
的属性。其功能是计算输入的归因分数。
首先调用了 self.attribute()
方法。然后判断 self.node_scores
是否为 None
,如果是,则直接返回 None
;如果不是,则返回 self.node_scores
中的第一个元素。
这意味着只有在 self.node_scores
不为空的情况下,才会返回其第一个元素作为特征分数。
2.4 前向传播 forward
这个前向传播有点诡异,总感觉跟论文中的对不上,这次我们一边解释一边测试!
先来个简单的:width=[2,5,5,3],mult_arity = 2
from kan import *
torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = KAN(width=[2,5,5,3], mult_arity=2,grid=3, k=3, seed=42, device=device)
model.input_id
cuda
checkpoint directory created: ./model
saving model version 0.0
tensor([0, 1])
测试数据:
x = torch.tensor([[1,2],
[3,4],
[5,6],
[7,8],
[9,10]]).float()
x = x.to(device)
0. 方法定义及注释
def forward(self, x, singularity_avoiding=False, y_th=10.):
'''
forward pass
Args:
-----
x : 2D torch.tensor
inputs
singularity_avoiding : bool
whether to avoid singularity for the symbolic branch
y_th : float
the threshold for singularity
Returns:
--------
None
Example1
--------
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> x = torch.rand(100,2)
>>> model(x).shape
Example2
--------
>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> x = torch.tensor([[1],[-0.01]])
>>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)
>>> print(model(x))
>>> print(model(x, singularity_avoiding=True))
>>> print(model(x, singularity_avoiding=True, y_th=1.))
'''
参数说明:
x
: 2Dtorch.tensor
,输入数据。singularity_avoiding
: bool,默认为False
。如果为True
,则在符号分支中避免奇异点。y_th
: float,默认为10.
。用于判断是否避免奇异点的阈值。
返回值:
None
:方法执行后不返回任何值
1. 初始化阶段
x = x[:,self.input_id.long()]
assert x.shape[1] == self.width_in[0]
# cache data
self.cache_data = x
self.acts = [] # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
self.acts_premult = []
self.spline_preacts = []
self.spline_postsplines = []
self.spline_postacts = []
self.acts_scale = []
self.acts_scale_spline = []
self.subnode_actscale = []
self.edge_actscale = []
# self.neurons_scale = []
self.acts.append(x) # acts shape: (batch, width[l])
-
数据选择与验证:
- 选择输入数据
x
的特定列,并验证其形状是否符合模型的输入宽度要求。 - 缓存输入数据
x
。
- 选择输入数据
-
初始化变量:
- 初始化用于存储不同层激活、尺度因子等的列表。
2. 前向传播循环
for l in range(self.depth):
- 循环遍历模型中的每一层,其中
self.depth
是模型的层数。
x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
#print(preacts, postacts_numerical, postspline)
- 使用第
l
层的激活函数act_fun[l]
对输入x
进行处理。 - 这里的激活函数是B样条KAN层的激活函数,详情见KANLayer
- 处理结果包括数值分支的输出
x_numerical
、预激活输出preacts
、后激活输出postacts_numerical
和样条函数的输出postspline
。(对应的是y, preacts, postacts, postspline)
if self.symbolic_enabled == True:
x_symbolic, postacts_symbolic = self.symbolic_fun[l](x, singularity_avoiding=singularity_avoiding, y_th=y_th)
else:
x_symbolic = 0.
postacts_symbolic = 0.
- 可使用符号KAN层时,同样进行计算
x = x_numerical + x_symbolic
这里要注意了,作者将两种层的计算结果相加了!也就是把B样条和线性函数同时叠加使用!
# subnode affine transform
x = self.subnode_scale[l][None,:] * x + self.subnode_bias[l][None,:]
- 对激活函数的计算结果进行缩放,并增加偏置常数b
对以上这一部分内容做测试:
x = x[:,model.input_id.long()]
assert x.shape[1] == model.width_in[0]
for l in range(model.depth):
x_numerical, preacts, postacts_numerical, postspline = model.act_fun[l](x)
#print(preacts, postacts_numerical, postspline)
if model.symbolic_enabled == True:
x_symbolic, postacts_symbolic = model.symbolic_fun[l](x, singularity_avoiding=False, y_th=10)
else:
x_symbolic = 0.
postacts_symbolic = 0.
x = x_numerical + x_symbolic
x = model.subnode_scale[l][None,:] * x + model.subnode_bias[l][None,:]
print(x)
print(x.shape)
tensor([[ 1.2935, -0.7047, -1.1071, 0.1673, 0.7162],
[ 3.6752, -2.1692, -2.5181, 0.0475, 2.4007],
[ 5.9759, -3.6323, -3.8994, -0.1110, 4.0748],
[ 8.2045, -5.0443, -5.2469, -0.2554, 5.6880],
[10.4148, -6.4431, -6.5866, -0.3956, 7.2852]], device='cuda:0',
grad_fn=<AddBackward0>)
torch.Size([5, 5])
tensor([[-0.1832, 0.2447, 0.2546, 0.0981, -0.0997],
[-0.8389, 1.2821, 1.2113, 0.4335, -0.0337],
[-1.3981, 2.3204, 2.1626, 0.7829, -0.0287],
[-1.9293, 3.2605, 3.0434, 1.1072, -0.0466],
[-2.4603, 4.1674, 3.9076, 1.4242, -0.0795]], device='cuda:0',
grad_fn=<AddBackward0>)
torch.Size([5, 5])
tensor([[ 0.0064, 0.0481, -0.1441],
[ 0.4862, 0.2570, -0.7088],
[ 1.0812, 0.6035, -1.4443],
[ 1.6532, 0.8613, -2.0778],
[ 2.1898, 1.0638, -2.6628]], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([5, 3])
所有中间结果的形状都没有问题。
if self.save_act:
# save subnode_scale
self.subnode_actscale.append(torch.std(x, dim=0).detach())
if self.save_act:
postacts = postacts_numerical + postacts_symbolic
# self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))
#grid_reshape = self.act_fun[l].grid.reshape(self.width_out[l + 1], self.width_in[l], -1)
input_range = torch.std(preacts, dim=0) + 0.1
output_range_spline = torch.std(postacts_numerical, dim=0) # for training, only penalize the spline part
output_range = torch.std(postacts, dim=0) # for visualization, include the contribution from both spline + symbolic
# save edge_scale
self.edge_actscale.append(output_range)
self.acts_scale.append((output_range / input_range).detach())
self.acts_scale_spline.append(output_range_spline / input_range)
self.spline_preacts.append(preacts.detach())
self.spline_postacts.append(postacts.detach())
self.spline_postsplines.append(postspline.detach())
self.acts_premult.append(x.detach())
如果启用了保存激活函数尺度因子的选项,则计算并保存以下内容:
- 子节点尺度因子(标准差)。
- 边尺度因子(输出范围)。
- 激活函数输出的尺度因子(输出范围与输入范围的比例)。
- 样条部分的尺度因子。
- 预激活输出、后激活输出和样条函数输出的副本。
但是我很好奇,这不是self.spline_postacts嘛,但是存的是postacts = postacts_numerical + postacts_symbolic,保存样条的激活输出为什么不只保存postacts_numerical。
还有就是都是判断 save_act,为啥用两个if。有时候就挺不能理解的
接下来介绍的这个东西,非常重要!它在基础节点的基础上引入了乘法操作。并且分为同质和非同质两种。
# multiplication
dim_sum = self.width[l+1][0]
dim_mult = self.width[l+1][1]
- 获取下一次节点数以及乘法操作次数
- self.width[l+1][0]是下一层的节点数
- self.width[l+1][1]是乘法操作次数
对于上面的例子,有
x.shape: torch.Size([5, 5])
dim_sum: 5
dim_mult: 0
x.shape: torch.Size([5, 5])
dim_sum: 5
dim_mult: 0
x.shape: torch.Size([5, 3])
dim_sum: 3
dim_mult: 0
if self.mult_homo == True:
for i in range(self.mult_arity-1):
if i == 0:
x_mult = x[:,dim_sum::self.mult_arity] * x[:,dim_sum+1::self.mult_arity]
else:
x_mult = x_mult * x[:,dim_sum+i+1::self.mult_arity]
- 当本层乘法参数都相同,则进行矩阵运算,即处理同质(homogeneous)乘法操作:
- 在第一次循环(
i == 0
)中:x_mult = x[:,dim_sum::self.mult_arity] * x[:,dim_sum+1::self.mult_arity]
:对x
的特定部分进行逐元素乘法。这里x[:,dim_sum::self.mult_arity]
表示从dim_sum
开始,每隔self.mult_arity
个元素取一个元素,形成一个新的张量。同理x[:,dim_sum+1::self.mult_arity]
表示从dim_sum+1
开始取元素。这两个张量逐元素相乘得到x_mult
。
- 在后续的循环中(
i != 0
):x_mult = x_mult * x[:,dim_sum+i+1::self.mult_arity]
:将上一次乘法的结果与x
的另一部分相乘。
- 在第一次循环(
对于我们width=[2,5,5,3],mult_arity = 2这个例子,有model.mult_homo == True,但结果如下:
tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])
tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])
tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])
由于dim_mult = 0,所以不进行乘法运算,代码中表现为dim_sum超出index,所以dim_sum::model.mult_arity都为0,自然乘积也为0。
测试升级:
设置width=[2,[5,2],[5,3],3], mult_arity=3,这是一个同质运算,由第一层向第二层传递时,会做乘法运算,次数为mult_arity-1=2,而乘法运算结果维度为dim_mult,然后与原始的dim_sum维度拼接,参数设置 width=[2,[5,1],[5,3],3], mult_arity=3,拼接操作:
if self.width[l+1][1] > 0:
x = torch.cat([x[:,:dim_sum], x_mult], dim=1)
- 将x中未参与乘法计算的部分与乘法计算结果进行拼接,恢复原始张量形状
x.shape: torch.Size([5, 8])
x_mult.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x.shape: torch.Size([5, 6])
x.shape: torch.Size([5, 14])
x_mult.shape: torch.Size([5, 3])
x_mult.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 8])
x.shape: torch.Size([5, 3])
x_mult.shape: torch.Size([5, 0])
x_mult.shape: torch.Size([5, 0])
x.shape: torch.Size([5, 3])
测试再次升级:
我用数据展示第二层的计算:
x = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14]])
dim_sum = 5
dim_mult = 3
mult_arity = 3
if model.mult_homo == True:
for i in range(2):
print(f"第{i+1}次乘法:")
if i == 0:
print(x[:,dim_sum::mult_arity])
print(x[:,dim_sum+1::mult_arity])
x_mult = x[:,dim_sum::mult_arity] * x[:,dim_sum+1::mult_arity]
else:
print(x_mult)
print(x[:,dim_sum+i+1::mult_arity])
x_mult = x_mult * x[:,dim_sum+i+1::mult_arity]
print(x_mult)
if dim_mult > 0:
x = torch.cat([x[:,:dim_sum], x_mult], dim=1)
print(x)
print("x.shape:",x.shape)
print()
第1次乘法:
tensor([[ 6, 9, 12]])
tensor([[ 7, 10, 13]])
tensor([[ 42, 90, 156]])
第2次乘法:
tensor([[ 42, 90, 156]])
tensor([[ 8, 11, 14]])
tensor([[ 336, 990, 2184]])
tensor([[ 1, 2, 3, 4, 5, 336, 990, 2184]])
x.shape: torch.Size([1, 8])
这下就完全理解它的乘法是如何运算的了。同质运算使用了矩阵运算以加快运算速度,这建立在mult_arity为常数的情况下,而当mult_arity的元素为列表时,只能进行遍历运算,如下:
else:
for j in range(dim_mult):
acml_id = dim_sum + np.sum(self.mult_arity[l+1][:j])
for i in range(self.mult_arity[l+1][j]-1):
if i == 0:
x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]
else:
x_mult_j = x_mult_j * x[:,[acml_id+i+1]]
if j == 0:
x_mult = x_mult_j
else:
x_mult = torch.cat([x_mult, x_mult_j], dim=1)
- 当本层乘法参数不相同,则进行遍历运算,即处理非同质(non-homogeneous)乘法操作:
for j in range(dim_mult):
循环遍历dim_mult
次,dim_mult
表示乘法操作的次数。- 在每次循环中,计算
acml_id
:acml_id = dim_sum + np.sum(self.mult_arity[l+1][:j])
:计算当前乘法操作的起始索引。
- 然后对每个乘法操作:
for i in range(self.mult_arity[l+1][j]-1):
:循环遍历当前维度的乘法操作次数。- 在第一次循环(
i == 0
)中:x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]
:对x
的特定部分进行逐元素乘法。
- 在后续的循环中(
i != 0
):x_mult_j = x_mult_j * x[:,[acml_id+i+1]]
:将上一次乘法的结果与x
的另一部分相乘。
- 如果是第一个乘法操作(
j == 0
):x_mult = x_mult_j
:将第一个乘法操作的结果赋值给x_mult
。
- 如果不是第一个乘法操作:
x_mult = torch.cat([x_mult, x_mult_j], dim=1)
:将当前乘法操作的结果与之前的结果在最后一个维度上连接。
测试:
参数设置 width=[2,[5,1],[5,3],3], mult_arity=[[0],[2],[2,3,4],[0]]
x = torch.tensor([[1,2],
[3,4],
[5,6],
[7,8],
[9,10]]).float()
x = x.to(device)
x = x[:,model.input_id.long()]
assert x.shape[1] == model.width_in[0]
for l in range(model.depth):
x_numerical, preacts, postacts_numerical, postspline = model.act_fun[l](x)
#print(preacts, postacts_numerical, postspline)
if model.symbolic_enabled == True:
x_symbolic, postacts_symbolic = model.symbolic_fun[l](x, singularity_avoiding=False, y_th=10)
else:
x_symbolic = 0.
postacts_symbolic = 0.
x = x_numerical + x_symbolic
x = model.subnode_scale[l][None,:] * x + model.subnode_bias[l][None,:]
#print(x)
print("x.shape:",x.shape)
# multiplication
dim_sum = model.width[l+1][0]
dim_mult = model.width[l+1][1]
#print("dim_sum:",dim_sum)
#print("dim_mult:",dim_mult)
if model.mult_homo == False:
for j in range(dim_mult):
acml_id = dim_sum + np.sum(model.mult_arity[l+1][:j])
for i in range(model.mult_arity[l+1][j]-1):
if i == 0:
x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]
else:
x_mult_j = x_mult_j * x[:,[acml_id+i+1]]
print("x_mult_j.shape:",x_mult_j.shape )
if j == 0:
x_mult = x_mult_j
else:
x_mult = torch.cat([x_mult, x_mult_j], dim=1)
print("x_mult.shape:",x_mult.shape)
if model.width[l+1][1] > 0:
x = torch.cat([x[:,:dim_sum], x_mult], dim=1)
x.shape: torch.Size([5, 7])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x.shape: torch.Size([5, 6])
x.shape: torch.Size([5, 14])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 2])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 8])
x.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 3])
来逐一分析:
- 第一层到第二层计算,经过B样条KAN层和符号KAN层,x形状为[batch_size,dim_sum+sum(mult_arity[l+1])],其中sum(mult_arity[l+1])=dim_mult*mult_arity[l+1],因为mult_arity[l+1]只有一个元素。然后进行了一次乘法运算,并将结果拼接在x[:,:dim_sum]后面
- 第二层到第三次计算:经过B样条KAN层和符号KAN层,x形状为[batch_size,dim_sum+sum(mult_arity[l+1])],其中sum(mult_arity[l+1])=np.sum(model.mult_arity[l+1][:j]),对于mult_arity[l+1]列表中的每一个元素,都执行其数值减一的乘法运算,运算结果x_mult_j的形状为[batch_size,1],最终获得的x_mult都是由x_mult_j拼接来的,最后将x_mult拼接在x[:,:dim_sum]后面。
- 第三层到第四层同理。
使用数据展示第二层的计算:
x = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14]])
dim_sum = 5
dim_mult = 3
mult_arity = [2,3,4]
print(x)
if model.mult_homo == False:
for j in range(dim_mult):
print(f"第{j+1}次运算:")
acml_id = dim_sum + np.sum(mult_arity[:j])
for i in range(mult_arity[j]-1):
if i == 0:
print(x[:,[acml_id]])
print(x[:,[acml_id+1]])
x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]
else:
print(x_mult_j)
print(x[:,[acml_id+i+1]])
x_mult_j = x_mult_j * x[:,[acml_id+i+1]]
print("x_mult_j:",x_mult_j)
if j == 0:
x_mult = x_mult_j
else:
x_mult = torch.cat([x_mult, x_mult_j], dim=1)
print("x_mult:",x_mult)
if dim_mult > 0:
x = torch.cat([x[:,:dim_sum], x_mult], dim=1)
print(x)
print("x.shape:",x.shape)
print()
tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])
第1次运算:
tensor([[6]])
tensor([[7]])
x_mult_j: tensor([[42]])
x_mult: tensor([[42]])
第2次运算:
tensor([[8]])
tensor([[9]])
x_mult_j: tensor([[72]])
tensor([[72]])
tensor([[10]])
x_mult_j: tensor([[720]])
x_mult: tensor([[ 42, 720]])
第3次运算:
tensor([[11]])
tensor([[12]])
x_mult_j: tensor([[132]])
tensor([[132]])
tensor([[13]])
x_mult_j: tensor([[1716]])
tensor([[1716]])
tensor([[14]])
x_mult_j: tensor([[24024]])
x_mult: tensor([[ 42, 720, 24024]])
tensor([[ 1, 2, 3, 4, 5, 42, 720, 24024]])
x.shape: torch.Size([1, 8])
# x = x + self.biases[l].weight
# node affine transform
x = self.node_scale[l][None,:] * x + self.node_bias[l][None,:]
self.acts.append(x.detach())
return x
- 对拼接后的x进行缩放并且加上偏置常数
- 返回计算结果
我们理一下整个计算思路:
- 传入x后,首先检查x的形状,遍历KAN层进行计算:
- 再分别使用B样条KAN层和符号KAN层计算出x = x_numerical + x_symbolic
- 对x进行缩放处理,并加入偏置常数
- 乘法运算
- 同质乘法运算:对于dim_sum之外的维度,使用矩阵运算计算出x_mult
- 非同质乘法运算:根据mult_arity[l+1]列表一次计算出x_mult_j,拼接成x_mult
- 如进行了乘法运算,则将x_mult与x[:,:dim_sum]拼接
- 对x进行缩放处理,并加入偏置常数
- 返回x
2.5 训练方法 fit
通过对前向传播进行剖析,KAN网络并不像论文中展示的那么简单
- KAN层包含了B样条层和符号层两种,我们可以设置是否使用符号层,如使用的话,中间x计算结果为两者之和。
- KAN层节点过渡时引入了乘法操作,包括同质乘法和非同质乘法,在定义的基础维度上进行了扩展,进一步加强了网络的学习能力。
现在我们对MultKAN的fit方法的使用进行详解。
def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1.,start_grid_update_step=-1, stop_grid_update_step=50, batch=-1,
metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
'''
training
Args:
-----
dataset : dic
contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
opt : str
"LBFGS" or "Adam"
steps : int
training steps
log : int
logging frequency
lamb : float
overall penalty strength
lamb_l1 : float
l1 penalty strength
lamb_entropy : float
entropy penalty strength
lamb_coef : float
coefficient magnitude penalty strength
lamb_coefdiff : float
difference of nearby coefficits (smoothness) penalty strength
update_grid : bool
If True, update grid regularly before stop_grid_update_step
grid_update_num : int
the number of grid updates before stop_grid_update_step
start_grid_update_step : int
no grid updates before this training step
stop_grid_update_step : int
no grid updates after this training step
loss_fn : function
loss function
lr : float
learning rate
batch : int
batch size, if -1 then full.
save_fig_freq : int
save figure every (save_fig_freq) steps
singularity_avoiding : bool
indicate whether to avoid singularity for the symbolic part
y_th : float
singularity threshold (anything above the threshold is considered singular and is softened in some ways)
reg_metric : str
regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
metrics : a list of metrics (as functions)
the metrics to be computed in training
display_metrics : a list of functions
the metric to be displayed in tqdm progress bar
Returns:
--------
results : dic
results['train_loss'], 1D array of training losses (RMSE)
results['test_loss'], 1D array of test losses (RMSE)
results['reg'], 1D array of regularization
other metrics specified in metrics
Example
-------
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
# Most examples in toturals involve the fit() method. Please check them for useness.
'''
参数说明:
-
dataset (
dic
): 包含训练集和测试集的数据字典,通常包括输入数据(train_input
,test_input
)和标签数据(train_label
,test_label
)。 -
opt (
str
): 选择的优化器,可以是 "LBFGS"(L-BFGS)或 "Adam"。 -
steps (
int
): 训练的总步骤数。 -
log (
int
): 日志输出的频率,即每多少步骤输出一次日志。 -
lamb (
float
): 总体正则化强度,用于控制模型复杂度。 -
lamb_l1 (
float
): L1 正则化强度,用于惩罚模型参数的绝对值。 -
lamb_entropy (
float
): 用于惩罚模型熵的强度,有助于防止过拟合。 -
lamb_coef (
float
): 模型系数的大小惩罚强度。 -
lamb_coefdiff (
float
): 邻近系数之间的差异惩罚强度,用于增加模型的平滑性。 -
update_grid (
bool
): 如果为True
,则在训练步骤达到stop_grid_update_step
之前定期更新网格。 -
grid_update_num (
int
): 在stop_grid_update_step
之前更新网格的次数。 -
start_grid_update_step (
int
): 在这个步骤之前不进行网格更新。 -
stop_grid_update_step (
int
): 这个步骤之后不进行网格更新。 -
loss_fn (
function
): 自定义损失函数,用于计算模型的损失。 -
lr (
float
): 学习率,决定每次更新参数时的步长。 -
batch (
int
): 批处理大小,如果为-1
,则使用完整数据集。 -
save_fig_freq (
int
): 每多少步骤保存一次训练结果的图形。 -
singularity_avoiding (
bool
): 如果为True
,则在符号部分避免奇异点。 -
y_th (
float
): 奇异点阈值,高于此值的任何值都将被视为奇异点。 -
reg_metric (
str
): 用于计算正则化的度量标准,可以选择不同的选项如edge_forward_spline_n
等。 -
metrics (
list of functions
): 计算并返回的自定义度量列表。 -
display_metrics (
list of functions
): 在训练进度条中显示的度量列表。
返回值:
- results (
dic
): 包含训练过程中的关键信息的字典,包括:train_loss
: 训练集上的损失(通常为 RMSE)。test_loss
: 测试集上的损失(通常为 RMSE)。reg
: 正则化项的值。- 其他用户指定的度量。
测试1:
from kan import *
model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
model.plot()
checkpoint directory created: ./model
saving model version 0.0
| train_loss: 1.91e-02 | test_loss: 1.97e-02 | reg: 1.38e+01 | : 100%|█| 20/20 [00:07<00:00, 2.66it
saving model version 0.1
测试2:
from kan import *
model = KAN(width=[2,[5,3],3], mult_arity=3, grid=5, k=3, noise_scale=0.3, seed=2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
model.plot()
checkpoint directory created: ./model
saving model version 0.0
| train_loss: 2.44e-02 | test_loss: 2.81e-02 | reg: 2.67e+01 | : 100%|█| 20/20 [00:09<00:00, 2.15it
saving model version 0.1
这个图包含了3个乘法节点。
三、总结
今天内容主要包括MultKAN网络的初始化、正向传播方法实现、训练方法参数说明。MultKAN网络正向传播有两个特点:
- 传播时可以同时使用KANLayer和Symbolic_KANLayer,以叠加的形式计算中间结果
- KAN节点的连接既有加法连接也有乘法连接,我们可以自定义乘法运算的方式(同质或非同质)
在上文中,我用数据直观展示了mult的计算过程,实际上只是连续的列相乘,因此在我看来,MultKAN的mult节点运算还有一定的优化空间,除了改善单一控制变量self.mult_homo,将其扩展为列表,还可以用numpy库实现连续列相乘的算法,这些尝试我打算放在实际应用中进行。