# Initialize the KAN and create the dataset (初始化KAN和创建数据集)
from kan import *
# Create a KAN with layer widths [2, 5, 1]: 2 inputs, 5 hidden neurons,
# 1 output; cubic splines (k=3) on 3 grid intervals (grid=3).
model = KAN(width=[2, 5, 1], grid=3, k=3, seed=1)


def f(x):
    """Target function f(x, y) = exp(sin(pi*x) + y^2).

    ``x`` is indexed as ``x[:, [0]]`` and ``x[:, [1]]``, so it is expected to be
    a 2-D tensor whose first two columns hold the inputs x and y. The list
    indexing keeps the column dimension, so the result has one column.
    """
    return torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)


# Build a dataset from f over its 2 input variables.
dataset = create_dataset(f, n_var=2)
# Print the shapes explicitly so they are shown outside a notebook too.
print((dataset['train_input'].shape, dataset['train_label'].shape))
# Output: (torch.Size([1000, 2]), torch.Size([1000, 1]))
# Plot the KAN after initialization (画图初始化后的KAN)
# Visualize the freshly initialized (untrained) KAN.
# Run one forward pass over the training inputs first; presumably plot()
# draws from activations recorded during that pass — confirm in pykan docs.
# The return value is deliberately discarded.
_ = model(dataset['train_input'])
# NOTE(review): beta=100 is passed through to plot(); its exact effect on the
# rendering is defined by pykan, not here.
model.plot(beta=100)
# Train the KAN with sparsity regularization
# Train the model on the dataset with the LBFGS optimizer for 20 steps.
# lamb=0.01 sets the strength of the sparsity regularization.
# (Fixes the original `lamb=0.01=`, a SyntaxError that broke the whole file.)
model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01)