Daily Attention Learning 23: KAN-Block

Module Source

[SPL 25] [link] [code] KAN See In the Dark


Module Name

Kolmogorov-Arnold Network Block (KAN-Block)


Module Purpose

A KAN (Kolmogorov-Arnold Network) structure adapted for vision models.
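
For context (my own summary from the code below, not wording from the paper): a KAN layer replaces each fixed scalar weight of a linear layer with a learnable univariate function. Each KANLinear below computes roughly y = F.linear(SiLU(x), base_weight) + F.linear(b_splines(x), spline_weight), i.e. a plain linear branch plus a learnable B-spline branch, and KAN-Block interleaves three such layers with depthwise convolutions so the per-token KAN layers still exchange spatial information.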


Module Structure

(Figure: KAN-Block structure diagram; image not reproduced here.)


Module Code
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Swish(nn.Module):
    """Swish activation, x * sigmoid(x) (equivalent to SiLU); defined here but unused below."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        h = (grid_range[1] - grid_range[0]) / grid_size
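        # Build grid_size + 2*spline_order + 1 uniformly spaced knots: spline_order
        # extra knots pad each side of grid_range so boundary bases are well defined.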
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)
        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )
        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
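            # Fit the spline weights to this small random noise sampled at the
            # interior grid points, via the least-squares fit in curve2coeff.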
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
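        # The line above gives degree-0 (indicator) bases; the Cox-de Boor recursion
        # below raises them to degree `spline_order` B-spline bases.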
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
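        # Spline weights with the optional per-edge standalone scaler applied.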
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.dim() == 2 and x.size(1) == self.in_features
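        # Two additive branches: a plain linear layer on SiLU(x), plus a linear
        # map of the B-spline bases (the learnable spline part of each edge).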
        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]
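        # grid_adaptive places knots at evenly spaced quantiles of the sorted inputs.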

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )
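        # Blend the uniform grid with the quantile-based grid above; grid_eps=1
        # gives a purely uniform grid, grid_eps=0 a purely data-driven one.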

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
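        # Refit the spline coefficients so the layer's outputs on x are preserved
        # under the new knot placement.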
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a crude approximation of the original L1 regularization described in
        the paper: the exact version needs absolutes and entropy of the expanded
        (batch, in_features, out_features) intermediate tensor, which is hidden
        behind F.linear in a memory-efficient implementation.

        The L1 term is therefore computed as the mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to
        the sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class DW_bn_relu(nn.Module):
    def __init__(self, dim=768):
        super(DW_bn_relu, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        self.bn = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU()

    def forward(self, x, H, W):
        # (B, N, C) token sequence -> (B, C, H, W) feature map for the depthwise conv
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = self.bn(x)
        x = self.relu(x)
        # flatten back to a (B, N, C) token sequence
        x = x.flatten(2).transpose(1, 2)
        return x


class KANBlock(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5, version=4):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features
        grid_size = 5
        spline_order = 3
        scale_noise = 0.1
        scale_base = 1.0
        scale_spline = 1.0
        base_activation = torch.nn.SiLU
        grid_eps = 0.02
        grid_range = [-1, 1]
        self.fc1 = KANLinear(
                    in_features,
                    hidden_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
        self.fc2 = KANLinear(
                    hidden_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
        self.fc3 = KANLinear(
                    hidden_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )   
        self.dwconv_1 = DW_bn_relu(hidden_features)
        self.dwconv_2 = DW_bn_relu(hidden_features)
        self.dwconv_3 = DW_bn_relu(hidden_features)
        self.drop = nn.Dropout(drop)
        self.shift_size = shift_size
        self.pad = shift_size // 2
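        # Note: drop, shift_size and pad are kept from the reference implementation
        # but are not used in this block's forward pass.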

    def forward(self, x, H, W):
        # NOTE: the reshapes below reuse C from the input, so this block assumes
        # in_features == hidden_features == out_features (true for the defaults).
        B, N, C = x.shape
        x = self.fc1(x.reshape(B * N, C))   # KANLinear acts per token
        x = x.reshape(B, N, C).contiguous()
        x = self.dwconv_1(x, H, W)          # depthwise conv mixes spatial context
        x = self.fc2(x.reshape(B * N, C))
        x = x.reshape(B, N, C).contiguous()
        x = self.dwconv_2(x, H, W)
        x = self.fc3(x.reshape(B * N, C))
        x = x.reshape(B, N, C).contiguous()
        x = self.dwconv_3(x, H, W)
        return x
    

if __name__ == '__main__':
    x = torch.randn([1, 22*22, 128])
    kan = KANBlock(in_features=128)
    out = kan(x, H=22, W=22)
    print(out.shape)  # [1, 22*22, 128]
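
A small extra sanity check (my own snippet, assuming the KANLinear class above; not part of the original post): inside grid_range, the B-spline bases form a partition of unity, so b_splines should sum to about 1 per input, and KANLinear is shape-compatible with nn.Linear:

layer = KANLinear(in_features=4, out_features=8)
xs = torch.rand(16, 4) * 2 - 1        # samples inside grid_range [-1, 1]
print(layer.b_splines(xs).sum(-1))    # ~1.0 everywhere (partition of unity)
print(layer(xs).shape)                # torch.Size([16, 8]), same as nn.Linear(4, 8)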

