KAN 学习 Day4 —— MultKAN 正向传播代码解读及测试

news2024/11/13 8:48:01

在KAN学习Day1——模型框架解析及HelloKAN中,我对KAN模型的基本原理进行了简单说明,并将作者团队给出的入门教程hellokan跑了一遍;

在KAN 学习 Day2 —— utils.py及spline.py 代码解读及测试中,我对项目的基本模块代码进行了解释,并以单元测试的形式深入理解模块功能,其中还发现了一个细小的错误。

在KAN 学习 Day3 —— KANLayer.py 与 Symbolic_KANLayer.py 代码解读及测试中,我对两种KAN层的实现进行了解读,它们分别是 “基于B样条曲线的KAN层” 和 “基于 eq?c*f%28a*x+b%29+d 的KAN层” 。(在下文中就称 B样条KAN层 和 符号KAN层)

今天我们开始对完整的KAN网络进行剖析,根据之前的经验,MultKAN类应该包括网络初始化、层之间网格参数传递、反向传播参数更新、网络剪枝、画图等等操作。

目录

一、kan目录

二、MultKAN.py 

2.1 类注释

​​​​2.2 构造函数 __init__

2.3 节点数计算

2.4 前向传播 forward

 0. 方法定义及注释

1. 初始化阶段

2. 前向传播循环

2.5 训练方法 fit

三、总结


一、kan目录

kan目录结构如下,包括了模型源码、检查点、实验以及assets等

e12295be65d94b3381e647242dc51eba.pngcc0f7d4a5a5148c995fcd44ad9bbbab6.png

 先了解一下这些文件/文件夹的大致信息:

  • kan\__init__.py:用于初始化Python包,方便使用时导入模块
  • kan\compiler.py:用于编译模型
  • kan\experiment.py:实验代码
  • kan\feynman.py:费曼函数,根据传入“name”的值确定函数,暂时没找到这个在哪里用到
  • kan\hypothesis.py:将函数进行线性分离,还包含一些画图函数
  •  kan\KANLayer.py:KAN层的实现,使用B样条曲线作为激活函数 
  • kan\LBFGS.py:这个文件名似乎昨天见过,训练时的opt参数。L-BFGS是一种用于无约束优化问题的算法,它是一种拟牛顿方法,特别适用于大型稀疏问题。
  • kan\MLP.py:作者自己实现了一个MLP,应该使来与KAN做对比的
  • kan\MultKAN.py:在KANLayer的基础上实现的KAN类的定义,提供了关于构建和配置这种网络的详细信息。
  • kan\spline.py:样条函数的实现
  •  kan\Symbolic_KANLayer.py:符号化的KAN层,使用四参线性函数作为激活函数 
  • kan\utils.py:通用模块
  • kan\.ipynb_checkpoints:看目录名,这个文件夹下存放的应该是检查点文件,但是似乎和模型的实现代码区别不大,没遇到过,还不知道有什么用。
  • kan\assets:这个目录下存放了两张图片,一张加号一张乘号,应该是对函数进行线性分离后,可视化时用的
  • kan\experiments:这个目录下是experiment1.ipynb,和昨天跑的hellokan差不多,今天再跑一下

二、MultKAN.py 

import torch
import torch.nn as nn
import numpy as np
from .KANLayer import KANLayer
#from .Symbolic_MultKANLayer import *
from .Symbolic_KANLayer import Symbolic_KANLayer
from .LBFGS import *
import os
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import copy
#from .MultKANLayer import MultKANLayer
import pandas as pd
from sympy.printing import latex
from sympy import *
import sympy
import yaml
from .spline import curve2coef
from .utils import SYMBOLIC_LIB
from .hypothesis import plot_tree

导入的这些依赖中,只有 LBFGS 和 plot_tree 我们还没介绍,这两部分内容我也没打算深入研究

  • LBFGS(Limited-memory BFGS)是一种优化算法,它主要用于求解无约束优化问题。
  • plot_tree则是画出网络的树状图

2.1 类注释

class MultKAN(nn.Module):
    '''
    KAN class
    
    Attributes:
    -----------
        grid : int
            the number of grid intervals
        k : int
            spline order
        act_fun : a list of KANLayers
        symbolic_fun: a list of Symbolic_KANLayer
        depth : int
            depth of KAN
        width : list
            number of neurons in each layer.
            Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
            With multiplication nodes, [2,[5,3],[5,1],3] means besides the [2,5,53] KAN, there are 3 (1) mul nodes in layer 1 (2). 
        mult_arity : int, or list of int lists
            multiplication arity for each multiplication node (the number of numbers to be multiplied)
        grid : int
            the number of grid intervals
        k : int
            the order of piecewise polynomial
        base_fun : fun
            residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
        symbolic_fun : a list of Symbolic_KANLayer
            Symbolic_KANLayers
        symbolic_enabled : bool
            If False, the symbolic front is not computed (to save time). Default: True.
        width_in : list
            The number of input neurons for each layer
        width_out : list
            The number of output neurons for each layer
        base_fun_name : str
            The base function b(x)
        grip_eps : float
            The parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)
        node_bias : a list of 1D torch.float
        node_scale : a list of 1D torch.float
        subnode_bias : a list of 1D torch.float
        subnode_scale : a list of 1D torch.float
        symbolic_enabled : bool
            when symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero)
        affine_trainable : bool
            indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
        sp_trainable : bool
            indicate whether the overall magnitude of splines is trainable
        sb_trainable : bool
            indicate whether the overall magnitude of base function is trainable
        save_act : bool
            indicate whether intermediate activations are saved in forward pass
        node_scores : None or list of 1D torch.float
            node attribution score
        edge_scores : None or list of 2D torch.float
            edge attribution score
        subnode_scores : None or list of 1D torch.float
            subnode attribution score
        cache_data : None or 2D torch.float
            cached input data
        acts : None or a list of 2D torch.float
            activations on nodes
        auto_save : bool
            indicate whether to automatically save a checkpoint once the model is modified
        state_id : int
            the state of the model (used to save checkpoint)
        ckpt_path : str
            the folder to store checkpoints
        round : int
            the number of times rewind() has been called
        device : str
    '''

这段代码定义了一个名为 MultKAN 的类,它是基于 nn.Module 构建的,这个类具有众多的属性,用于描述和控制其行为和特征:

  • grid:网格的间隔数(使用网格进行参数优化)
  • k:分段多项式的阶数,或者说B样条的控制点数
  • act_fun:B样条KAN层列表
  • symbolic_fun:符号KAN层列表。
  • depth:表示模型的深度。
  • width:描述了各层神经元的数量。
  • mult_arity:与乘法节点的乘法运算的元数有关。
  • base_fun:公式中的eq?b%28x%29
  •  symbolic_enabled:布尔值,是否使用符号KAN层 
  • width_in 和 width_out:分别表示各层的输入和输出神经元数量。
  • base_fun_name:基础函数的名称。
  • grip_eps:可能用于在均匀网格和自适应网格之间进行插值。
  • 各种与偏差、缩放、训练相关的属性,如 node_biasnode_scale 等,用于控制模型的训练和参数调整。
  • 各种与分数、缓存、自动保存、设备等相关的属性,用于模型的评估、数据存储、模型保存和硬件设置等方面。 

嘛,就是说这里好多注释又重复了......

​​​​2.2 构造函数 __init__

    def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
        '''
        initalize a KAN model
        
        Args:
        -----
            width : list of int
                Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
                With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
            grid : int
                number of grid intervals. Default: 3.
            k : int
                order of piecewise polynomial. Default: 3.
            mult_arity : int, or list of int lists
                multiplication arity for each multiplication node (the number of numbers to be multiplied)
            noise_scale : float
                initial injected noise to spline.
            base_fun : str
                the residual function b(x). Default: 'silu'
            symbolic_enabled : bool
                compute (True) or skip (False) symbolic computations (for efficiency). By default: True. 
            affine_trainable : bool
                affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
            grid_eps : float
                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
            grid_range : list/np.array of shape (2,))
                setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default updata_grid=True)
            sp_trainable : bool
                If true, scale_sp is trainable. Default: True.
            sb_trainable : bool
                If true, scale_base is trainable. Default: True.
            device : str
                device
            seed : int
                random seed
            save_act : bool
                indicate whether intermediate activations are saved in forward pass
            sparse_init : bool
                sparse initialization (True) or normal dense initialization. Default: False.
            auto_save : bool
                indicate whether to automatically save a checkpoint once the model is modified
            state_id : int
                the state of the model (used to save checkpoint)
            ckpt_path : str
                the folder to store checkpoints. Default: './model'
            round : int
                the number of times rewind() has been called
            device : str
            
        Returns:
        --------
            self
            
        Example
        -------
        >>> from kan import *
        >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
        checkpoint directory created: ./model
        saving model version 0.0
        '''

这段代码是 MultKAN 类的构造函数 __init__ 的定义。构造函数用于初始化一个 MultKAN 模型实例,并为其设置各种参数。

参数说明:

  • width:一个整数列表,指定了每一层的神经元数量。如果没有乘法节点,列表中的每个元素代表相应层的神经元数量;如果有乘法节点,列表中的元素是一个包含神经元数量和乘法节点数量的元组。
  • grid:网格间隔的数量,默认为3。
  • k:分段多项式的阶数,默认为3。
  • mult_arity:每个乘法节点的乘法运算的元数,可以是单个整数或整数列表。
  • noise_scale:注入到样条函数中的初始噪声的缩放比例。
  • scale_base_mu 和 scale_base_sigma:基础函数的缩放参数的均值和标准差。
  • base_fun:残差函数 b(x) 的类型,默认为 'silu'。
  • symbolic_enabled:是否启用符号计算,默认为True。
  • affine_trainable:是否更新仿射参数,包括节点缩放、节点偏差、子节点缩放和子节点偏差。
  • grid_eps:用于在均匀网格和自适应网格之间进行插值的参数。
  • grid_range:设置网格范围的列表或NumPy数组。
  • sp_trainable:如果为真,则spline的缩放是可训练的。
  • sb_trainable:如果为真,则基础函数的缩放是可训练的。
  • device:指定设备,如 'cpu' 或 'cuda'。
  • seed:随机种子,用于初始化权重。
  • save_act:指示是否在正向传递中保存中间激活。
  • sparse_init:是否进行稀疏初始化。
  • auto_save:指示是否在修改模型后自动保存检查点。
  • state_id:模型的当前状态,用于保存检查点。
  • ckpt_path:存储检查点的文件夹路径。
  • roundrewind() 被调用次数。
  • device:设备类型。

代码说明:

        super(MultKAN, self).__init__()
  • 调用父类 MultKAN 的初始化方法,用于设置一些基本的属性或执行一些初始化操作。
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
  • 这三行代码设置了随机数种子,确保每次运行代码时生成的随机数序列相同,这对于测试和调试非常有用。据说将seed设置为3407会将模型的性能提升1%

        ### initializeing the numerical front ###

        self.act_fun = []
        self.depth = len(width) - 1
  • 这里初始化了激活函数列表 self.act_fun 和模型的深度 self.depth,深度是通过宽度列表的长度减一得到的。
        for i in range(len(width)):
            if type(width[i]) == int:
                width[i] = [width[i],0]
            
        self.width = width
  • 遍历宽度列表,如果宽度为整数,则将其转换为列表形式,形式为 [宽度, 0]。
  • 将宽度列表赋值给 self.width 属性。
  • 注意到,注释中的width属性是有两种形式的,这几行代码使其都转化为了第二种形式,即如果有乘法节点,列表中的元素是一个包含神经元数量和乘法节点数量的元组。
        # if mult_arity is just a scalar, we extend it to a list of lists
        # e.g, mult_arity = [[2,3],[4]] means that in the first hidden layer, 2 mult ops have arity 2 and 3, respectively;
        # in the second hidden layer, 1 mult op has arity 4.
        if isinstance(mult_arity, int):
            self.mult_homo = True # when homo is True, parallelization is possible
        else:
            self.mult_homo = False # when home if False, for loop is required. 
        self.mult_arity = mult_arity
  • 如果 mult_arity 是一个标量(即单个数字),代码将把它扩展为一个列表的列表。这样做通常是为了将单一的参数应用到多个乘法操作上。
  • 例如,如果 mult_arity = [[2,3],[4]],这意味着在第一个隐藏层中有两个乘法操作,它们的参数分别是 2 和 3;在第二个隐藏层中有一个乘法操作,其参数是 4。
  • 这里检查 mult_arity 是否是一个整数。如果是,那么所有乘法操作的参数都是相同的,这意味着它们是同质的。在这种情况下,可以将这些操作并行化,以提高计算效率。因此,将 self.mult_homo 设置为 True
  • 如果 mult_arity 不是一个整数,那么它可能是一个列表的列表,其中包含不同层级的不同参数。在这种情况下,不能并行化乘法操作,因为每个操作的参数可能不同。因此,将 self.mult_homo 设置为 False,这意味着可能需要使用循环来处理每个操作。
  • 最后,将处理后的 mult_arity 参数赋值给 self.mult_arity,这样模型就可以使用这个参数来定义其乘法操作了。
        width_in = self.width_in
        width_out = self.width_out

调用了两个方法,获得了KAN层真正的输入输出节点数。

        self.base_fun_name = base_fun
        if base_fun == 'silu':
            base_fun = torch.nn.SiLU()
        elif base_fun == 'identity':
            base_fun = torch.nn.Identity()
        elif base_fun == 'zero':
            base_fun = lambda x: x*0.
  • 将传入的 base_fun 参数赋值给实例变量 self.base_fun_name。这意味着 base_fun 是一个字符串,它表示想要使用的基础函数的名称。
  • 如果 base_fun_name 是字符串 'silu',那么代码将创建一个 torch.nn.SiLU() 对象。SiLU(Sigmoid-weighted Linear Unit)是一个激活函数,通常用于神经网络中。
  • 如果 base_fun_name 是字符串 'identity',那么代码将创建一个 torch.nn.Identity() 对象。Identity 函数是一个恒等函数,它直接返回其输入值,通常用作默认激活函数或不改变输入的层。
  • 如果 base_fun_name 是字符串 'zero',那么代码将创建一个匿名函数(lambda 函数),这个函数将任何输入 x 乘以 0,从而输出 0。这可能表示一个“关闭”激活状态的函数,不激活任何神经元。
        self.grid_eps = grid_eps
        self.grid_range = grid_range
  • 将网格相关的参数赋值给 self.grid_eps 和 self.grid_range
  • grid_eps:控制网格细化策略的浮点数,默认为0.02。当 grid_eps = 1 时,网格是均匀的;当 grid_eps = 0 时,它使用样本的百分位数进行分区。0 < grid_eps < 1 插值在两种极端之间。
        for l in range(self.depth):
            # splines
            sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid, k=k, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
            self.act_fun.append(sp_batch)
  • 这段代码在循环中为每一层创建一个 KANLayer 实例,并把这些实例添加到一个列表中,以便后续可以用于KAN网络模型。
        self.node_bias = []
        self.node_scale = []
        self.subnode_bias = []
        self.subnode_scale = []
  •  初始化用于节点和子节点的偏差和缩放参数的列表。
        globals()['self.node_bias_0'] = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)
        exec('self.node_bias_0' + " = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)")
  • globals() 返回当前全局命名空间中的所有全局变量。
  • self.node_bias_0 是类的一个属性,这个属性在类的定义中尚未明确定义(即,它不是类的内部成员,而是通过全局命名空间访问的)。
  • torch.nn.Parameter(torch.zeros(3,1)) 创建了一个PyTorch的参数(Parameter对象),该对象是张量,用于在神经网络中存储权重,并且支持梯度计算。
  • .requires_grad_(False) 设置了该参数对象不进行梯度计算,即不会追踪其在计算图中的操作,这对于不需要计算梯度的参数(如偏置项)来说是合理的。
  • exec() 是Python的内置函数,用于执行字符串形式的Python代码。在这里,它被用来动态地创建或更新类属性。
  • 'self.node_bias_0' 是一个字符串,表示类中要创建或更新的属性名。
  • torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False) 是一个字符串表达式,创建了一个新的PyTorch参数并设置了其梯度计算为False。

这种做法在某些情况下非常有用,比如在定义神经网络模型时,需要动态地为特定的参数创建属性,或者在模型中为某些不需要梯度计算的参数(如偏置项)创建独立的属性。

但是!我没找到这两行代码有啥用处。

        for l in range(self.depth):
            exec(f'self.node_bias_{l} = torch.nn.Parameter(torch.zeros(width_in[l+1])).requires_grad_(affine_trainable)')
            exec(f'self.node_scale_{l} = torch.nn.Parameter(torch.ones(width_in[l+1])).requires_grad_(affine_trainable)')
            exec(f'self.subnode_bias_{l} = torch.nn.Parameter(torch.zeros(width_out[l+1])).requires_grad_(affine_trainable)')
            exec(f'self.subnode_scale_{l} = torch.nn.Parameter(torch.ones(width_out[l+1])).requires_grad_(affine_trainable)')
            exec(f'self.node_bias.append(self.node_bias_{l})')
            exec(f'self.node_scale.append(self.node_scale_{l})')
            exec(f'self.subnode_bias.append(self.subnode_bias_{l})')
            exec(f'self.subnode_scale.append(self.subnode_scale_{l})')
  • 通过循环,它为模型的每一层创建了节点偏置、节点缩放、子节点偏置和子节点缩放参数,并将这些参数存储在类的属性中,以便后续使用。
  • 通过 affine_trainable 参数来控制哪些参数是可训练的。
        self.act_fun = nn.ModuleList(self.act_fun)

        self.grid = grid
        self.k = k
        self.base_fun = base_fun

这几个基础的设置就不解释了。

        ### initializing the symbolic front ###
        self.symbolic_fun = []
        for l in range(self.depth):
            sb_batch = Symbolic_KANLayer(in_dim=width_in[l], out_dim=width_out[l+1])
            self.symbolic_fun.append(sb_batch)

刚刚创建了B样条KAN层,现在创建符号KAN层。

        self.symbolic_fun = nn.ModuleList(self.symbolic_fun)
        self.symbolic_enabled = symbolic_enabled
        self.affine_trainable = affine_trainable
        self.sp_trainable = sp_trainable
        self.sb_trainable = sb_trainable
  • 将符号层加入列表
  • 设置符号层是否可用
  • 设置符号层线性函数的四个参数是否可训练
  • 设置激活函数中的参数 eq?w_%7Bs%7D 是否可训练,分为了sp和sb两种,sp为B样条KAN层的,sb为符号KAN层的
        self.save_act = save_act
            
        self.node_scores = None
        self.edge_scores = None
        self.subnode_scores = None
        
        self.cache_data = None
        self.acts = None
        
        self.auto_save = auto_save
        self.state_id = 0
        self.ckpt_path = ckpt_path
        self.round = round

一些中间结果的存储变量和保存操作设置,保存的具体操作如下:

        if auto_save:
            if first_init:
                if not os.path.exists(ckpt_path):
                    # Create the directory
                    os.makedirs(ckpt_path)
                print(f"checkpoint directory created: {ckpt_path}")
                print('saving model version 0.0')

                history_path = self.ckpt_path+'/history.txt'
                with open(history_path, 'w') as file:
                    file.write(f'### Round {self.round} ###' + '\n')
                    file.write('init => 0.0' + '\n')
                self.saveckpt(path=self.ckpt_path+'/'+'0.0')
            else:
                self.state_id = state_id

我们在hellokan中就见识过,模型在训练过程中会保存中间数据、状态和历史信息等内容

        self.input_id = torch.arange(self.width_in[0],)
  • 给输入节点编号 ,从0开始
        self.device = device
        self.to(device)
    def to(self, device):
        '''
        move the model to device
        
        Args:
        -----
            device : str or device

        Returns:
        --------
            self
            
        Example
        -------
        >>> from kan import *
        >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
        >>> model.to(device)
        '''
        super(MultKAN, self).to(device)
        self.device = device
        
        for kanlayer in self.act_fun:
            kanlayer.to(device)
            
        for symbolic_kanlayer in self.symbolic_fun:
            symbolic_kanlayer.to(device)
            
        return self
  •  选择计算设备

测试:

from kan import *
torch.set_default_dtype(torch.float64)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

model = KAN(width=[2,[5,3],[5,1],3], mult_arity=[0,[2,3,4],[2],0],grid=3, k=3, seed=42, device=device)
model.input_id

cuda
checkpoint directory created: ./model
saving model version 0.0

tensor([0, 1])

2.3 节点数计算

2.3.1 width_in 

    @property
    def width_in(self):
        '''
        The number of input nodes for each layer
        '''
        width = self.width
        width_in = [width[l][0]+width[l][1] for l in range(len(width))]
        return width_in

这段代码定义了一个属性 width_in ,它的作用是计算并返回模型每一层的输入节点数量。

  •  首先,获取了模型的宽度信息 width 。
  • 然后,通过列表推导式计算每一层输入节点的数量,计算方式是将每一层的总和维度 width[l][0] 和乘法操作维度 width[l][1] 相加。
  • 最后,返回计算得到的输入节点数量列表。

所以每层节点数=设置的节点数+乘法操作次数

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.width_in

[[2, 0], [5, 3], [5, 1], [3, 0]]

 [2, 8, 6, 3]

2.3.2 width_out

    @property
    def width_out(self):
        '''
        The number of output subnodes for each layer
        '''
        width = self.width
        if self.mult_homo == True:
            width_out = [width[l][0]+self.mult_arity*width[l][1] for l in range(len(width))]
        else:
            width_out = [width[l][0]+int(np.sum(self.mult_arity[l])) for l in range(len(width))]
        return width_out

这段代码定义了一个属性 width_out,其目的是计算并返回模型每一层的输出子节点数量。

  • 首先,获取了模型的宽度信息 width。然后根据 self.mult_homo 的值来决定计算输出节点数量的方式。
  • 如果 self.mult_homo 为 True,则使用列表推导式计算每一层的输出节点数量。计算方式是将每一层的总和维度 width[l][0] 与乘法操作维度 width[l][1] 的结果乘以 mult_arity 的值相加。mult_arity 是一个数组,表示每一层的乘法操作的幅度。
  • 如果 self.mult_homo 为 False,则使用列表推导式计算每一层的输出节点数量。计算方式是将每一层的总和维度 width[l][0] 与 mult_arity[l] 的元素之和相加。mult_arity[l] 是一个数组,表示每一层的乘法操作的幅度。
  • 最后,返回计算得到的输出节点数量列表。

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.width_out

[[2, 0], [5, 3], [5, 1], [3, 0]]

[2, 14, 7, 3]

2.3.3 n_sum

    @property
    def n_sum(self):
        '''
        The number of addition nodes for each layer
        '''
        width = self.width
        n_sum = [width[l][0] for l in range(1,len(width)-1)]
        return n_sum

这段代码定义了一个属性 n_sum ,用于计算并返回除了第一层和最后一层之外,每一层的总和维度 width[l][0] 所组成的列表。

首先,获取了模型的宽度信息 width 。然后通过列表推导式,从第二层到倒数第二层,提取出每一层的 width[l][0] ,并将这些值组成一个新的列表 n_sum ,最后返回这个列表。

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.n_sum

[[2, 0], [5, 3], [5, 1], [3, 0]]
[5, 5]

2.3.4 n_mult

    @property
    def n_mult(self):
        '''
        The number of multiplication nodes for each layer
        '''
        width = self.width
        n_mult = [width[l][1] for l in range(1,len(width)-1)]
        return n_mult

这段代码定义了一个属性 n_mult ,用于计算并返回除了第一层和最后一层之外,每一层的乘法节点数量。这里 width 是一个包含多层宽度信息的数据结构,每一层的信息以列表的形式存储,其中 width[l][1] 表示第 l 层的乘法节点数量。

通过列表推导式,代码遍历从第二层到倒数第二层的所有层,提取每一层的乘法节点数量,并将这些数量组成一个新的列表 n_mult 。最后,这个列表被返回给调用者。

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.n_mult

[[2, 0], [5, 3], [5, 1], [3, 0]]

[3, 1]

2.3.5 feature_score

    @property
    def feature_score(self):
        '''
        attribution scores for inputs
        '''
        self.attribute()
        if self.node_scores == None:
            return None
        else:
            return self.node_scores[0]

这段代码定义了一个名为 feature_score 的属性。其功能是计算输入的归因分数。

首先调用了 self.attribute() 方法。然后判断 self.node_scores 是否为 None ,如果是,则直接返回 None ;如果不是,则返回 self.node_scores 中的第一个元素。

这意味着只有在 self.node_scores 不为空的情况下,才会返回其第一个元素作为特征分数。

2.4 前向传播 forward

这个前向传播有点诡异,总感觉跟论文中的对不上,这次我们一边解释一边测试!

先来个简单的:width=[2,5,5,3],mult_arity = 2

from kan import *
torch.set_default_dtype(torch.float64)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

model = KAN(width=[2,5,5,3], mult_arity=2,grid=3, k=3, seed=42, device=device)
model.input_id

cuda
checkpoint directory created: ./model
saving model version 0.0
tensor([0, 1])

 测试数据:

x = torch.tensor([[1,2],
                  [3,4],
                  [5,6],
                  [7,8],
                  [9,10]]).float()
x = x.to(device)

 0. 方法定义及注释

    def forward(self, x, singularity_avoiding=False, y_th=10.):
        '''
        forward pass
        
        Args:
        -----
            x : 2D torch.tensor
                inputs
            singularity_avoiding : bool
                whether to avoid singularity for the symbolic branch
            y_th : float
                the threshold for singularity

        Returns:
        --------
            None
            
        Example1
        --------
        >>> from kan import *
        >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
        >>> x = torch.rand(100,2)
        >>> model(x).shape
        
        Example2
        --------
        >>> from kan import *
        >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
        >>> x = torch.tensor([[1],[-0.01]])
        >>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)
        >>> print(model(x))
        >>> print(model(x, singularity_avoiding=True))
        >>> print(model(x, singularity_avoiding=True, y_th=1.))
        '''

参数说明:

  • x: 2D torch.tensor,输入数据。
  • singularity_avoiding: bool,默认为 False。如果为 True,则在符号分支中避免奇异点。
  • y_th: float,默认为 10.。用于判断是否避免奇异点的阈值。

返回值:

  • None:方法执行后不返回任何值

1. 初始化阶段

        x = x[:,self.input_id.long()]
        assert x.shape[1] == self.width_in[0]

        # cache data
        self.cache_data = x
        
        self.acts = []  # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
        self.acts_premult = []
        self.spline_preacts = []
        self.spline_postsplines = []
        self.spline_postacts = []
        self.acts_scale = []
        self.acts_scale_spline = []
        self.subnode_actscale = []
        self.edge_actscale = []
        # self.neurons_scale = []

        self.acts.append(x)  # acts shape: (batch, width[l])
  • 数据选择与验证

    • 选择输入数据 x 的特定列,并验证其形状是否符合模型的输入宽度要求。
    • 缓存输入数据 x
  • 初始化变量

    • 初始化用于存储不同层激活、尺度因子等的列表。

2. 前向传播循环

        for l in range(self.depth):
  • 循环遍历模型中的每一层,其中 self.depth 是模型的层数。
            x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
            #print(preacts, postacts_numerical, postspline)
  • 使用第 l 层的激活函数 act_fun[l] 对输入 x 进行处理。
  • 这里的激活函数是B样条KAN层的激活函数,详情见KANLayer
  • 处理结果包括数值分支的输出 x_numerical、预激活输出 preacts、后激活输出 postacts_numerical 和样条函数的输出 postspline。(对应的是y, preacts, postacts, postspline)
            if self.symbolic_enabled == True:
                x_symbolic, postacts_symbolic = self.symbolic_fun[l](x, singularity_avoiding=singularity_avoiding, y_th=y_th)
            else:
                x_symbolic = 0.
                postacts_symbolic = 0.
  • 可使用符号KAN层时,同样进行计算
            x = x_numerical + x_symbolic

这里要注意了,作者将两种层的计算结果相加了!也就是把B样条和线性函数同时叠加使用!

            # subnode affine transform
            x = self.subnode_scale[l][None,:] * x + self.subnode_bias[l][None,:]
  •  对激活函数的计算结果进行缩放,并增加偏置常数b

对以上这一部分内容做测试:

x = x[:,model.input_id.long()]
assert x.shape[1] == model.width_in[0]

for l in range(model.depth):
            
    x_numerical, preacts, postacts_numerical, postspline = model.act_fun[l](x)
    #print(preacts, postacts_numerical, postspline)
            
    if model.symbolic_enabled == True:
        x_symbolic, postacts_symbolic = model.symbolic_fun[l](x, singularity_avoiding=False, y_th=10)
    else:
        x_symbolic = 0.
        postacts_symbolic = 0.

    x = x_numerical + x_symbolic
    x = model.subnode_scale[l][None,:] * x + model.subnode_bias[l][None,:]

    print(x)
    print(x.shape)

 tensor([[ 1.2935, -0.7047, -1.1071,  0.1673,  0.7162],
        [ 3.6752, -2.1692, -2.5181,  0.0475,  2.4007],
        [ 5.9759, -3.6323, -3.8994, -0.1110,  4.0748],
        [ 8.2045, -5.0443, -5.2469, -0.2554,  5.6880],
        [10.4148, -6.4431, -6.5866, -0.3956,  7.2852]], device='cuda:0',
       grad_fn=<AddBackward0>)
torch.Size([5, 5])
tensor([[-0.1832,  0.2447,  0.2546,  0.0981, -0.0997],
        [-0.8389,  1.2821,  1.2113,  0.4335, -0.0337],
        [-1.3981,  2.3204,  2.1626,  0.7829, -0.0287],
        [-1.9293,  3.2605,  3.0434,  1.1072, -0.0466],
        [-2.4603,  4.1674,  3.9076,  1.4242, -0.0795]], device='cuda:0',
       grad_fn=<AddBackward0>)
torch.Size([5, 5])
tensor([[ 0.0064,  0.0481, -0.1441],
        [ 0.4862,  0.2570, -0.7088],
        [ 1.0812,  0.6035, -1.4443],
        [ 1.6532,  0.8613, -2.0778],
        [ 2.1898,  1.0638, -2.6628]], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([5, 3])

 所有中间结果的形状都没有问题。

            if self.save_act:
                # save subnode_scale
                self.subnode_actscale.append(torch.std(x, dim=0).detach())

            if self.save_act:
                postacts = postacts_numerical + postacts_symbolic

                # self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))
                #grid_reshape = self.act_fun[l].grid.reshape(self.width_out[l + 1], self.width_in[l], -1)
                input_range = torch.std(preacts, dim=0) + 0.1
                output_range_spline = torch.std(postacts_numerical, dim=0) # for training, only penalize the spline part
                output_range = torch.std(postacts, dim=0) # for visualization, include the contribution from both spline + symbolic
                # save edge_scale
                self.edge_actscale.append(output_range)
                
                self.acts_scale.append((output_range / input_range).detach())
                self.acts_scale_spline.append(output_range_spline / input_range)
                self.spline_preacts.append(preacts.detach())
                self.spline_postacts.append(postacts.detach())
                self.spline_postsplines.append(postspline.detach())

                self.acts_premult.append(x.detach())

 如果启用了保存激活函数尺度因子的选项,则计算并保存以下内容:

  • 子节点尺度因子(标准差)。
  • 边尺度因子(输出范围)。
  • 激活函数输出的尺度因子(输出范围与输入范围的比例)。
  • 样条部分的尺度因子。
  • 预激活输出、后激活输出和样条函数输出的副本。

但是我很好奇,这不是self.spline_postacts嘛,但是存的是postacts = postacts_numerical + postacts_symbolic,保存样条的激活输出为什么不只保存postacts_numerical。

还有就是都是判断 save_act,为啥用两个if。有时候就挺不能理解的

接下来介绍的这个东西,非常重要!它在基础节点的基础上引入了乘法操作。并且分为同质和非同质两种。

            # multiplication
            dim_sum = self.width[l+1][0]
            dim_mult = self.width[l+1][1]
  • 获取下一次节点数以及乘法操作次数
  • self.width[l+1][0]是下一层的节点数
  • self.width[l+1][1]是乘法操作次数

对于上面的例子,有

x.shape: torch.Size([5, 5])
dim_sum: 5
dim_mult: 0
x.shape: torch.Size([5, 5])
dim_sum: 5
dim_mult: 0
x.shape: torch.Size([5, 3])
dim_sum: 3
dim_mult: 0

            if self.mult_homo == True:
                for i in range(self.mult_arity-1):
                    if i == 0:
                        x_mult = x[:,dim_sum::self.mult_arity] * x[:,dim_sum+1::self.mult_arity]
                    else:
                        x_mult = x_mult * x[:,dim_sum+i+1::self.mult_arity]
  • 当本层乘法参数都相同,则进行矩阵运算,即处理同质(homogeneous)乘法操作:
    • 在第一次循环(i == 0)中:
      • x_mult = x[:,dim_sum::self.mult_arity] * x[:,dim_sum+1::self.mult_arity]:对 x 的特定部分进行逐元素乘法。这里 x[:,dim_sum::self.mult_arity] 表示从 dim_sum 开始,每隔 self.mult_arity 个元素取一个元素,形成一个新的张量。同理 x[:,dim_sum+1::self.mult_arity] 表示从 dim_sum+1 开始取元素。这两个张量逐元素相乘得到 x_mult
    • 在后续的循环中(i != 0):
      • x_mult = x_mult * x[:,dim_sum+i+1::self.mult_arity]:将上一次乘法的结果与 x 的另一部分相乘。

对于我们width=[2,5,5,3],mult_arity = 2这个例子,有model.mult_homo == True,但结果如下:

tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])
tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])
tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])

 由于dim_mult = 0,所以不进行乘法运算,代码中表现为dim_sum超出index,所以dim_sum::model.mult_arity都为0,自然乘积也为0。

测试升级:

设置width=[2,[5,2],[5,3],3], mult_arity=3,这是一个同质运算,由第一层向第二层传递时,会做乘法运算,次数为mult_arity-1=2,而乘法运算结果维度为dim_mult,然后与原始的dim_sum维度拼接,参数设置 width=[2,[5,1],[5,3],3], mult_arity=3,拼接操作:

            if self.width[l+1][1] > 0:
                x = torch.cat([x[:,:dim_sum], x_mult], dim=1)
  •  将x中未参与乘法计算的部分与乘法计算结果进行拼接,恢复原始张量形状 

x.shape: torch.Size([5, 8])
x_mult.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x.shape: torch.Size([5, 6])

 

x.shape: torch.Size([5, 14])
x_mult.shape: torch.Size([5, 3])
x_mult.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 8])

 

x.shape: torch.Size([5, 3])
x_mult.shape: torch.Size([5, 0])
x_mult.shape: torch.Size([5, 0])
x.shape: torch.Size([5, 3])

 测试再次升级:

我用数据展示第二层的计算:

x = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14]])
dim_sum = 5
dim_mult = 3
mult_arity = 3
if model.mult_homo == True:
    for i in range(2):
        print(f"第{i+1}次乘法:")
        if i == 0:
            print(x[:,dim_sum::mult_arity])
            print(x[:,dim_sum+1::mult_arity])
            x_mult = x[:,dim_sum::mult_arity] * x[:,dim_sum+1::mult_arity]
            
        else:
            print(x_mult)
            print(x[:,dim_sum+i+1::mult_arity])
            x_mult = x_mult * x[:,dim_sum+i+1::mult_arity]

        print(x_mult)

if dim_mult > 0:
    x = torch.cat([x[:,:dim_sum], x_mult], dim=1)

print(x)
print("x.shape:",x.shape)
print()

 第1次乘法:
tensor([[ 6,  9, 12]])
tensor([[ 7, 10, 13]])
tensor([[ 42,  90, 156]])


第2次乘法:
tensor([[ 42,  90, 156]])
tensor([[ 8, 11, 14]])
tensor([[ 336,  990, 2184]])


tensor([[   1,    2,    3,    4,    5,  336,  990, 2184]])
x.shape: torch.Size([1, 8])

 这下就完全理解它的乘法是如何运算的了。同质运算使用了矩阵运算以加快运算速度,这建立在mult_arity为常数的情况下,而当mult_arity的元素为列表时,只能进行遍历运算,如下:

            else:
                for j in range(dim_mult):
                    acml_id = dim_sum + np.sum(self.mult_arity[l+1][:j])
                    for i in range(self.mult_arity[l+1][j]-1):
                        if i == 0:
                            x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]
                        else:
                            x_mult_j = x_mult_j * x[:,[acml_id+i+1]]
                            
                    if j == 0:
                        x_mult = x_mult_j
                    else:
                        x_mult = torch.cat([x_mult, x_mult_j], dim=1)
  • 当本层乘法参数不相同,则进行遍历运算,即处理非同质(non-homogeneous)乘法操作:
    • for j in range(dim_mult):循环遍历 dim_mult 次,dim_mult 表示乘法操作的次数。
    • 在每次循环中,计算 acml_id
      • acml_id = dim_sum + np.sum(self.mult_arity[l+1][:j]):计算当前乘法操作的起始索引。
    • 然后对每个乘法操作:
      • for i in range(self.mult_arity[l+1][j]-1)::循环遍历当前维度的乘法操作次数。
      • 在第一次循环(i == 0)中:
        • x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]:对 x 的特定部分进行逐元素乘法。
      • 在后续的循环中(i != 0):
        • x_mult_j = x_mult_j * x[:,[acml_id+i+1]]:将上一次乘法的结果与 x 的另一部分相乘。
    • 如果是第一个乘法操作(j == 0):
      • x_mult = x_mult_j:将第一个乘法操作的结果赋值给 x_mult
    • 如果不是第一个乘法操作:
      • x_mult = torch.cat([x_mult, x_mult_j], dim=1):将当前乘法操作的结果与之前的结果在最后一个维度上连接。

测试:

参数设置 width=[2,[5,1],[5,3],3], mult_arity=[[0],[2],[2,3,4],[0]]

x = torch.tensor([[1,2],
                  [3,4],
                  [5,6],
                  [7,8],
                  [9,10]]).float()
x = x.to(device)

x = x[:,model.input_id.long()]
assert x.shape[1] == model.width_in[0]

for l in range(model.depth):
            
    x_numerical, preacts, postacts_numerical, postspline = model.act_fun[l](x)
    #print(preacts, postacts_numerical, postspline)
            
    if model.symbolic_enabled == True:
        x_symbolic, postacts_symbolic = model.symbolic_fun[l](x, singularity_avoiding=False, y_th=10)
    else:
        x_symbolic = 0.
        postacts_symbolic = 0.

    x = x_numerical + x_symbolic
    x = model.subnode_scale[l][None,:] * x + model.subnode_bias[l][None,:]

    #print(x)
    print("x.shape:",x.shape)

    # multiplication
    dim_sum = model.width[l+1][0]
    dim_mult = model.width[l+1][1]
    #print("dim_sum:",dim_sum)
    #print("dim_mult:",dim_mult)
    if model.mult_homo == False:
        for j in range(dim_mult):
            acml_id = dim_sum + np.sum(model.mult_arity[l+1][:j])
            for i in range(model.mult_arity[l+1][j]-1):
                if i == 0:
                    x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]
                else:
                    x_mult_j = x_mult_j * x[:,[acml_id+i+1]]
                print("x_mult_j.shape:",x_mult_j.shape )
                            
            if j == 0:
                x_mult = x_mult_j
            else:
                x_mult = torch.cat([x_mult, x_mult_j], dim=1)
            print("x_mult.shape:",x_mult.shape)

    if model.width[l+1][1] > 0:
        x = torch.cat([x[:,:dim_sum], x_mult], dim=1)

x.shape: torch.Size([5, 7])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x.shape: torch.Size([5, 6])

 

x.shape: torch.Size([5, 14])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 2])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 8])

 

x.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 3])

 来逐一分析:

  • 第一层到第二层计算,经过B样条KAN层和符号KAN层,x形状为[batch_size,dim_sum+sum(mult_arity[l+1])],其中sum(mult_arity[l+1])=dim_mult*mult_arity[l+1],因为mult_arity[l+1]只有一个元素。然后进行了一次乘法运算,并将结果拼接在x[:,:dim_sum]后面
  • 第二层到第三次计算:经过B样条KAN层和符号KAN层,x形状为[batch_size,dim_sum+sum(mult_arity[l+1])],其中sum(mult_arity[l+1])=np.sum(model.mult_arity[l+1][:j]),对于mult_arity[l+1]列表中的每一个元素,都执行其数值减一的乘法运算,运算结果x_mult_j的形状为[batch_size,1],最终获得的x_mult都是由x_mult_j拼接来的,最后将x_mult拼接在x[:,:dim_sum]后面。
  • 第三层到第四层同理。

使用数据展示第二层的计算:

x = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14]])
dim_sum = 5
dim_mult = 3
mult_arity = [2,3,4]

print(x)

if model.mult_homo == False:
    for j in range(dim_mult):
        print(f"第{j+1}次运算:")
        acml_id = dim_sum + np.sum(mult_arity[:j])
        for i in range(mult_arity[j]-1):
            if i == 0:
                print(x[:,[acml_id]])
                print(x[:,[acml_id+1]])
                x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]
            else:
                print(x_mult_j)
                print(x[:,[acml_id+i+1]])
                x_mult_j = x_mult_j * x[:,[acml_id+i+1]]
            print("x_mult_j:",x_mult_j)
        if j == 0:
                x_mult = x_mult_j
        else:
            x_mult = torch.cat([x_mult, x_mult_j], dim=1)
        print("x_mult:",x_mult)

if dim_mult > 0:
    x = torch.cat([x[:,:dim_sum], x_mult], dim=1)

print(x)
print("x.shape:",x.shape)
print()

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]])
第1次运算:
tensor([[6]])
tensor([[7]])
x_mult_j: tensor([[42]])
x_mult: tensor([[42]])


第2次运算:
tensor([[8]])
tensor([[9]])
x_mult_j: tensor([[72]])
tensor([[72]])
tensor([[10]])
x_mult_j: tensor([[720]])
x_mult: tensor([[ 42, 720]])


第3次运算:
tensor([[11]])
tensor([[12]])
x_mult_j: tensor([[132]])
tensor([[132]])
tensor([[13]])
x_mult_j: tensor([[1716]])
tensor([[1716]])
tensor([[14]])
x_mult_j: tensor([[24024]])
x_mult: tensor([[   42,   720, 24024]])


tensor([[    1,     2,     3,     4,     5,    42,   720, 24024]])
x.shape: torch.Size([1, 8])

            # x = x + self.biases[l].weight
            # node affine transform
            x = self.node_scale[l][None,:] * x + self.node_bias[l][None,:]
            
            self.acts.append(x.detach())
        
        return x
  • 对拼接后的x进行缩放并且加上偏置常数 
  • 返回计算结果

我们理一下整个计算思路:

  1. 传入x后,首先检查x的形状,遍历KAN层进行计算:
    1. 再分别使用B样条KAN层和符号KAN层计算出x = x_numerical + x_symbolic
    2. 对x进行缩放处理,并加入偏置常数
    3. 乘法运算
      1. 同质乘法运算:对于dim_sum之外的维度,使用矩阵运算计算出x_mult
      2. 非同质乘法运算:根据mult_arity[l+1]列表一次计算出x_mult_j,拼接成x_mult
    4. 如进行了乘法运算,则将x_mult与x[:,:dim_sum]拼接
    5. 对x进行缩放处理,并加入偏置常数
    6. 返回x

2.5 训练方法 fit

通过对前向传播进行剖析,KAN网络并不像论文中展示的那么简单

  • KAN层包含了B样条层和符号层两种,我们可以设置是否使用符号层,如使用的话,中间x计算结果为两者之和。
  • KAN层节点过渡时引入了乘法操作,包括同质乘法和非同质乘法,在定义的基础维度上进行了扩展,进一步加强了网络的学习能力。

现在我们对MultKAN的fit方法的使用进行详解。

    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1.,start_grid_update_step=-1, stop_grid_update_step=50, batch=-1,
              metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
        '''
        training

        Args:
        -----
            dataset : dic
                contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
            opt : str
                "LBFGS" or "Adam"
            steps : int
                training steps
            log : int
                logging frequency
            lamb : float
                overall penalty strength
            lamb_l1 : float
                l1 penalty strength
            lamb_entropy : float
                entropy penalty strength
            lamb_coef : float
                coefficient magnitude penalty strength
            lamb_coefdiff : float
                difference of nearby coefficits (smoothness) penalty strength
            update_grid : bool
                If True, update grid regularly before stop_grid_update_step
            grid_update_num : int
                the number of grid updates before stop_grid_update_step
            start_grid_update_step : int
                no grid updates before this training step
            stop_grid_update_step : int
                no grid updates after this training step
            loss_fn : function
                loss function
            lr : float
                learning rate
            batch : int
                batch size, if -1 then full.
            save_fig_freq : int
                save figure every (save_fig_freq) steps
            singularity_avoiding : bool
                indicate whether to avoid singularity for the symbolic part
            y_th : float
                singularity threshold (anything above the threshold is considered singular and is softened in some ways)
            reg_metric : str
                regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
            metrics : a list of metrics (as functions)
                the metrics to be computed in training
            display_metrics : a list of functions
                the metric to be displayed in tqdm progress bar

        Returns:
        --------
            results : dic
                results['train_loss'], 1D array of training losses (RMSE)
                results['test_loss'], 1D array of test losses (RMSE)
                results['reg'], 1D array of regularization
                other metrics specified in metrics

        Example
        -------
        >>> from kan import *
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
        >>> model.plot()
        # Most examples in toturals involve the fit() method. Please check them for useness.
        '''

参数说明:

  1. dataset (dic): 包含训练集和测试集的数据字典,通常包括输入数据(train_inputtest_input)和标签数据(train_labeltest_label)。

  2. opt (str): 选择的优化器,可以是 "LBFGS"(L-BFGS)或 "Adam"。

  3. steps (int): 训练的总步骤数。

  4. log (int): 日志输出的频率,即每多少步骤输出一次日志。

  5. lamb (float): 总体正则化强度,用于控制模型复杂度。

  6. lamb_l1 (float): L1 正则化强度,用于惩罚模型参数的绝对值。

  7. lamb_entropy (float): 用于惩罚模型熵的强度,有助于防止过拟合。

  8. lamb_coef (float): 模型系数的大小惩罚强度。

  9. lamb_coefdiff (float): 邻近系数之间的差异惩罚强度,用于增加模型的平滑性。

  10. update_grid (bool): 如果为 True,则在训练步骤达到 stop_grid_update_step 之前定期更新网格。

  11. grid_update_num (int): 在 stop_grid_update_step 之前更新网格的次数。

  12. start_grid_update_step (int): 在这个步骤之前不进行网格更新。

  13. stop_grid_update_step (int): 这个步骤之后不进行网格更新。

  14. loss_fn (function): 自定义损失函数,用于计算模型的损失。

  15. lr (float): 学习率,决定每次更新参数时的步长。

  16. batch (int): 批处理大小,如果为 -1,则使用完整数据集。

  17. save_fig_freq (int): 每多少步骤保存一次训练结果的图形。

  18. singularity_avoiding (bool): 如果为 True,则在符号部分避免奇异点。

  19. y_th (float): 奇异点阈值,高于此值的任何值都将被视为奇异点。

  20. reg_metric (str): 用于计算正则化的度量标准,可以选择不同的选项如 edge_forward_spline_n 等。

  21. metrics (list of functions): 计算并返回的自定义度量列表。

  22. display_metrics (list of functions): 在训练进度条中显示的度量列表。

返回值:

  • results (dic): 包含训练过程中的关键信息的字典,包括:
    • train_loss: 训练集上的损失(通常为 RMSE)。
    • test_loss: 测试集上的损失(通常为 RMSE)。
    • reg: 正则化项的值。
    • 其他用户指定的度量。

测试1:

from kan import *
model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
model.plot()

checkpoint directory created: ./model
saving model version 0.0
| train_loss: 1.91e-02 | test_loss: 1.97e-02 | reg: 1.38e+01 | : 100%|█| 20/20 [00:07<00:00,  2.66it
saving model version 0.1

7bdaad0cace146f7b795d0de96bbae94.png

 测试2:

from kan import *
model = KAN(width=[2,[5,3],3], mult_arity=3, grid=5, k=3, noise_scale=0.3, seed=2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
model.plot()

checkpoint directory created: ./model
saving model version 0.0
| train_loss: 2.44e-02 | test_loss: 2.81e-02 | reg: 2.67e+01 | : 100%|█| 20/20 [00:09<00:00,  2.15it
saving model version 0.1

b03c0e119ad04e4eb44dc123370a436c.png

这个图包含了3个乘法节点。

三、总结

今天内容主要包括MultKAN网络的初始化、正向传播方法实现、训练方法参数说明。MultKAN网络正向传播有两个特点:

  1. 传播时可以同时使用KANLayer和Symbolic_KANLayer,以叠加的形式计算中间结果
  2. KAN节点的连接既有加法连接也有乘法连接,我们可以自定义乘法运算的方式(同质或非同质)

在上文中,我用数据直观展示了mult的计算过程,实际上只是连续的列相乘,因此在我看来,MultKAN的mult节点运算还有一定的优化空间,除了改善单一控制变量self.mult_homo,将其扩展为列表,还可以用numpy库实现连续列相乘的算法,这些尝试我打算放在实际应用中进行。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2118576.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ubuntu上安装libdc1394-22-dev出现无法定位安装包的解决办法

一、libdc1394-22-dev介绍 libdc1394-22-dev 是一个开发库&#xff0c;用于与IEEE 1394 (FireWire)摄像头进行交互。具体来说&#xff0c;它是 libdc1394 的开发版本&#xff0c;提供了开发者头文件和链接库&#xff0c;方便在应用程序中集成对基于 IEEE 1394 标准的数码相机的…

【开源大模型生态5】解放大脑

AI能力的进化&#xff0c;如上图&#xff0c;分为4个阶段。 第一阶段&#xff1a;逻辑推理为主 在人工智能的早期发展阶段&#xff0c;研究者们将重心放在了构建能够进行逻辑推理的系统上。这些系统通常基于规则&#xff0c;通过定义一系列“如果...那么...”的规则来模拟人类…

此mac无法连接Applie媒体服务,因为“”出现问题。

出现问题&#xff1a; 这是因为mac登陆过别人的appId下载过软件&#xff0c;但是没有完全退出登陆 解决 打开偏好设置&#xff0c;点击头像&#xff0c;点击媒体与已购项目&#xff0c;能看到弹框内AppleID登陆的应用&#xff0c;打开对应的那个应用&#xff0c;我这里是音…

Linux文本内容管理命令

head与tail head----显示文件前10行 tail----显示文件后10行 查看前n行命令&#xff1a; head -n 文件路径 查看后n行命令&#xff1a; tail -n 文件路径 管道符&#xff1a; | ----将前一输出的结果作为后一命令的输入 查看第三行内容&#xff1a;head -3 文件路…

损坏SD数据恢复的8种有效方法

SD卡被用于许多不同的产品来存储重要数据&#xff0c;如图片和重要的商业文件。如果您的SD卡坏了&#xff0c;您需要SD数据恢复来获取您的信息。通过从损坏的SD卡中取回数据&#xff0c;您可以确保重要文件不会永远丢失&#xff0c;这对于工作或个人原因是非常重要的。 有许多…

如何在产品创新中实践TRIZ方法?

TRIZ&#xff08;发明问题解决理论&#xff09;作为一种强大的创新方法论&#xff0c;自其诞生以来&#xff0c;便以其系统性、科学性和实用性&#xff0c;在全球范围内被广泛应用于产品创新、技术升级及难题解决等领域。本文&#xff0c;深圳天行健企业管理咨询公司旨在分享如…

合作文章|基于FFPE样本研究腹水微生物群与HCC继发腹水、PVTT之间的相互作用

文章题目&#xff1a;Ascitic microbiota alteration is associated with portal vein tumor thrombosis occurrence and prognosis in hepatocellular carcinoma 发表期刊&#xff1a;mBio 影响因子&#xff1a;6.4 研究背景 肝细胞癌(HCC)是最常见的恶性肿瘤之一&#xf…

D45XT160-ASEMI新能源专用D45XT160

编辑&#xff1a;ll D45XT160-ASEMI新能源专用D45XT160 型号&#xff1a;D45XT160 品牌&#xff1a;ASEMI 封装&#xff1a;DXT-5 安装方式&#xff1a;直插 批号&#xff1a;2024 现货&#xff1a;50000 正向电流&#xff08;Id&#xff09;&#xff1a;45A 反向耐压…

环球团队迅速崛起,把握最新市场趋势引领未来

近日&#xff0c;一个名为“环球团队”的股票投资团队在业内迅速崭露头角&#xff0c;备受瞩目。该团队由多位在金融证券领域经验丰富、见解独到的专家组成&#xff0c;很快赢得了投资者的信任和支持。他们凭借精准的市场分析和高效的投资策略&#xff0c;多次成功抓住市场机遇…

ceph-radosgw 手动安装教程以及安装问题解决办法

一、环境 操作系统版本&#xff1a;Ubuntu20.04 x86_64 ceph版本&#xff1a;ceph version 15.2.17 (8a82819d84cf884bd39c17e3236e0632ac146dc4) octopus (stable) radosgw版本&#xff1a;15.2.17 二、ceph-radosgw 安装步骤 ceph官方英文版教程&#xff0c;写了个大概步骤…

echarts 饼图中间文字颜色小写设置

想要实现的效果如下&#xff1a; 只要在formatter里这样写就可以啦&#xff0c;rich里面写你需要的样式即可 var option {color: [#3d6dfe, #27b3ff, #2fffc1, #ff892f, #fcff2f],tooltip: {trigger: item},legend: {type: scroll,itemWidth: 12,itemHeight: 10,itemGap: 25,…

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动&#xff0c;这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高…

交换机的这些接口,网工真得清楚

号主&#xff1a;老杨丨11年资深网络工程师&#xff0c;更多网工提升干货&#xff0c;请关注公众号&#xff1a;网络工程师俱乐部 下午好&#xff0c;我的网工朋友。 交换机作为网络的核心设备之一&#xff0c;在实现高效的数据传输和网络管理方面扮演着非常重要的角色。 然而&…

echarts 多个3D柱状图

图片样式&#xff1a; 代码实现&#xff1a; <template><div :class"className" :style"{height:height,width:width}" /> </template><script> require("echarts/theme/sakura"); // echarts themeexport default {pro…

【专题】2024飞行汽车技术全景报告合集PDF分享(附原数据表)

原文链接&#xff1a; https://tecdat.cn/?p37628 6月16日&#xff0c;小鹏汇天旅航者X2在北京大兴国际机场临空经济区完成首飞&#xff0c;这也是小鹏汇天的产品在京津冀地区进行的首次飞行。小鹏汇天方面还表示&#xff0c;公司准备量产&#xff0c;并计划今年四季度开启预…

Leetcode122. 买卖股票 状态机dp C++实现

Leetcode 122. 买卖股票的最佳时机 问题&#xff1a;给你一个整数数组 prices &#xff0c;其中 prices [ i ] 表示某支股票第 i 天的价格。 在每一天&#xff0c;你可以决定是否购买和/或出售股票。你在任何时候 最多 只能持有 一股 股票。你也可以先购买&#xff0c;然后在…

AOC商用显示器赋能绿色教育,助推教育信息化发展

摘要&#xff1a;助推教育发展&#xff0c;打造健康教学&#xff01; 作为提高国家创新能力及综合素养的基础&#xff0c;教育水平的高低往往决定着人才培养的数量和质量&#xff0c;决定着国家科技发展水平的高低&#xff0c;甚至于决定着民族国家的成败兴衰。从长远规划来看…

跨平台数据库管理软件SQLynx

什么是 SQLynx &#xff1f; SQLynx 是一个原生基于 Web 的 SQL 编辑器&#xff0c;支持企业的桌面和 Web 数据库管理。它最初被称为 SQL Studio&#xff0c;后来改名为 SQLynx。SQLynx 支持所有流行的数据库&#xff0c;如 MySQL、MariaDB、PostgreSQL、SQLite、Hive、Impala、…

Java File类与字节输入输出流详解

File类&#xff1a; 1.首先创建一下file的对象&#xff1a; 里面可以写相对路径或者绝对路径 File file new File("CCC.java"); 也可以使用其他构造方法 //String path "D:\\ch06"; //String fileName "1.txt"; File file new File(path…

超声波自动气象站

超声波自动气象站的功能优势可以包括以下几个方面&#xff1a; 高精度测量&#xff1a;超声波自动气象站采用超声波技术进行测量&#xff0c;可以实现高精度的测量结果&#xff0c;能够准确地测量气温、湿度、风速、风向等气象参数。 高可靠性&#xff1a;超声波自动气象站采用…