传知代码-KAN卷积:医学图像分割新前沿

news2024/11/24 11:39:08

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

概述

在本文中深入探讨KAN卷积在医学图像分割领域的创新应用,特别是通过引入Tokenized KAN Block(Tok Kan)这一突破性设计,将深度学习中的图像分割技术推向了新的高度。KAN作为一种能够替代传统MLP(多层感知机)的网络结构,以其独特的优势在多个领域展现出强大的潜力。而在医学图像分割这一复杂且关键的领域,KAN卷积更是凭借其高效处理图像特征的能力,成为了研究的热点。本文将U-Net结构中的卷积部分替换成了KAN卷积,将MLP部分用KANLinear取代,同时融入了类似Vision Transformer(VIT)的移位思想,使得模型在捕捉图像全局信息的同时,也能精准定位局部细节。该创新型能够支撑起一篇论文

核心创新点

  1. 将KAN卷积引入分割网络中

Kolmogorov–Arnold Networks(KAN)通常不是直接指代一种具体的卷积神经网络架构,但在这里我们可以理解为一种特殊的卷积或特征提取机制,可能基于Kolmogorov-Arnold表示定理(也称为超位置定理),该定理提供了多变量函数可以通过一系列一元函数和固定二元函数的组合来表示的理论基础。因此,KAN卷积可能意味着一种高度非线性的、能够捕捉复杂依赖关系的卷积操作。将其引入U-Net中,可以显著提升模型对图像特征的提取能力,特别是那些需要高级抽象和复杂交互的特征。

  1. KANLinear替换MLP

将传统的MLP层替换为KANLinear层,可能意味着这一层结合了KAN的某些特性(如非线性处理能力或复杂的函数逼近能力)和线性变换的简洁性。这种替换可能使模型在保持高效计算的同时,能够更灵活地处理特征之间的复杂关系,进一步增强模型的特征表示能力,同时能够减少模型复杂度。KANLinear层可能通过其独特的机制,更好地整合和转换来自不同层级的特征,从而有助于模型在全局和局部信息之间做出更精准的权衡。

  1. 融入移位思想

虽然Vision Transformer(VIT)本身并不直接包含“移位”操作,但其自注意力机制能够捕捉图像中的全局依赖关系,这一点与移位操作在促进信息流动和增强全局感受野方面的作用相似。在U-Net中融入类似VIT的思想,可能意味着引入了一种能够跨越空间位置直接交互特征的机制(如自注意力模块),或者通过某种形式的特征重排(类似于移位但更灵活)来增强模型的全局理解能力。这种设计使得模型在保持对局部细节敏感的同时,也能够有效地整合全局信息,从而在处理复杂图像任务时展现出更高的性能。本文中设计了沿着两个方向的移位。

模块介绍

KAN

KAN模型由数学定理Kolmogorov–Arnold启发得出,该定理由前苏联的两位数学家Vladimir Arnold和Andrey Kolmogorov提出。定理表明,任何多元连续函数都可以表示为单变量连续函数的两层嵌套叠加(一个单一变量的连续函数和一系列连续的双变量函数的组合)。这为多维函数的分解提供了理论基础,也是KAN模型设计的核心思想。KAN模型具体解读可以参考这篇
在这里插入图片描述

UNext模块

UNext模块一种基于卷积多层感知器(MLP)的医学图像分割网络,旨在解决现有模型如UNet和Transformer版本在计算复杂度、参数量以及推理速度上的不足。本文是在此基础上将卷积用KAN卷积取代,将MLP用KAN模块代替。具体有关UNext模块的介绍可以看这篇解读。
在这里插入图片描述

本文主要结构

本文受启发与上述的两种模块,将UNext模块中的上下采样中的卷积提取特征阶段用KAN卷积取代,将Tok MLP阶段用Tok Kan取代,从而增加模型提取特征的能力和渐少模型的参数量,同时提高模型非线性表达的能力。
整体架构如下:
在这里插入图片描述

主要代码

KAN卷积

class KANConvs(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, enable_standalone_scale_spline=True):
        """
        定义KAN卷积层,类似于nn.Conv2d,但包括KAN样条插值权重
        :param in_channels: 输入通道数
        :param out_channels: 输出通道数
        :param kernel_size: 卷积核大小
        :param stride: 步长
        :param padding: 填充
        :param dilation: 扩张
        :param groups: 组卷积
        :param bias: 偏置项
        :param enable_standalone_scale_spline: 是否启用独立缩放样条插值
        """
        super(KANConvs, self).__init__()

        # 基本的卷积层初始化
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias

        # 标准卷积层参数
        self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        # 样条插值权重
        self.spline_weight = Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))

        # 是否启用独立缩放样条插值
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        if enable_standalone_scale_spline:
            self.spline_scaler = Parameter(torch.ones(out_channels, 1))
            

        # 初始化权重
        self.reset_parameters()
class ConvLayer(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ConvLayer, self).__init__()
        self.conv = nn.Sequential(
            KANConvs(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            KANConvs(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)

KANLinear层

class KANBlock(nn.Module):
    def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, no_kan=False):
        super().__init__()

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim)

        self.layer = KANLayer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, no_kan=no_kan)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

数据集

本文使用的是BUSI(Breast Ultrasound Image)是一个包含乳腺超声图像的分类和分割数据集。该数据集包括了 2018 年收集的乳腺超声波图像,涵盖了 25 至 75 岁的 600 名女性患者。数据集由 780 张图像组成,每张图像的平均大小为 500*500 像素。这些图像被划分为三类:正常、良性和恶性。而在良性和恶性乳腺超声图像中,还包含了对应胸部肿瘤的详细分割标注,为深入研究和精准诊断提供了关键信息。这份数据集不仅为乳腺癌研究提供了丰富的图像资源和宝贵支持。
在这里插入图片描述

结果展示

在BUSI数据集中的结果展示如下:

MethodsIoUDice
U-Net57.2271.91
Att-Unet55.1870.22
U-Net++57.4172.11
U-NeXt59.0673.08
Rolling-UNet61.0074.67
U-Mamba61.8175.55
OURS63.4577.05

分割图
在这里插入图片描述

运行过程
附件下载文件,readme中有详细步骤。数据集在readme提供的链接中。
环境配置

  - pip:
    - addict==2.4.0
    - dataclasses==0.8
    - mmcv-full==1.2.7
    - numpy==1.19.5
    - opencv-python==4.5.1.48
    - perceptual==0.1
    - pillow==8.4.0
    - scikit-image==0.17.2
    - scipy==1.5.4
    - tifffile==2020.9.3
    - timm==0.3.2
    - torch==2.4.0
    - torchvision==0.8.2
    - typing-extensions==4.0.0
    - yapf==0.31.0

运行结果:
在这里插入图片描述

本文在附件的readme中提供了除了BUSI之外的数据集,可以自行尝试增加工作量。

参考文献

  1. UNeXt: MLP-based Rapid Medical Image Segmentation Network
  2. KAN: Kolmogorov–Arnold Networks

源码下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2142841.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代理导致的git错误

问题: 今天在clone时出现如下错误: fatal: unable to access https://github.com/NirDiamant/RAG_Techniques.git/: Failed to connect to 127.0.0.1 port 10089 after 2065 ms: Couldnt connect to server真是让人感到奇怪!就在前天&#…

Typora安装,使用,图片加载全流程

文章目录 前言:安装:破解:使用typora:关于CSDN加载不出图片:创建OSS:设置PicGo: 前言: ​ Typora是一款非常流行的Markdown编辑器,简单来说就是可以方便我们写博客。拿我…

Linux基础---07文件传输及解决yum安装失效的方法

Linux文件传输地图如下,先选取你所需的场景,若你是需要Linux和Linux之间传输文件就查看SCP工具即可。 一.下载网站文件 前提是有网: 检查网络是否畅通命令:ping www.baidu.com,若有持续的返回值就说明网络畅通。Ctr…

如何建立一个Webservice WSDL的简单例子(完整例子)

一:根据对方给的wsdl 的接口地址创建Web 的逻辑端口 1:例如这个用C#写的Web 2.我们需要在SAP里建立一个Service Consumers 的服务记得后缀要加?wsdl 2:然后就会生成对应方法的出参 入参 返回的消息根据接口方法来判断 二:如何通…

day21JS-axios数据通信

1. 什么是axios axios 是一个基于Promise 用于浏览器和 nodejs 的 HTTP 客户端,简单的理解就是ajax的封装,只不过它是Promise的实现版本。 特性: 从浏览器中创建 XMLHttpRequests从 node.js 创建 http 请求支持 Promise API拦截请求和响应转…

基于Java的固定资产管理系统

基于Java的固定资产管理系统是一个实用的应用程序,用于跟踪和管理公司的资产。这种系统可以包括资产的采购日期、位置、状态、折旧等信息。下面是一个简单的固定资产管理系统的设计概述,以及一些关键功能模块的实现思路。 系统设计概览 用户管理&…

2-97 基于matlab的小波变换模量最大值 (WTMM) 方法进行图像边缘检测

基于matlab的小波变换模量最大值 (WTMM) 方法进行图像边缘检测。利用小波基函数的局部化和振荡特性来检测图像中的边缘,沿每个像素的梯度方向搜索局部最大值,保留局部最大值,抑制其他系数,实现边缘检测。程…

一文入门生成式AI(理解ChatGPT的原理)

一、什么是生成式AI? 以ChatGPT为代表的生成式AI,是对已有的数据和知识进行向量化的归纳,总结出数据的联合概率。从而在生成内容时,根据用户需求,结合关联字词的概率,生成新的内容。 可以这么联想&#x…

(9) protobuf 与cmake

文章目录 概要整体架构流程代码优化 概要 protobuf的安装与用cmake编译protobuf程序 整体架构流程 安装protobuf3 21.12,最新版本好像要用到一个新的库,有点麻烦。 https://github.com/protocolbuffers/protobuf/releases/tag/v21.12 all和cpp后缀的包…

python植物大战僵尸项目源码【免费】

植物大战僵尸是一款经典的塔防游戏,玩家通过种植各种植物来抵御僵尸的进攻。 源码下载地址: 植物大战僵尸项目源码 提取码: 8muq

【前端】prop传值的用法

prop配置项的作用是让组件接收外部传过来的值。 组件标签上传值给vue组件对象 <script> export default {name: HelloWorld,data(){return{ }},props:[name,age] #简单接收 } </script>方式2:利用对象方式设置数据类型进行类型限制 props:{name:String…

kubernetes中pause容器的作用与源码详解

概述 摘要&#xff1a;上一篇文章我们介绍了kubernetes是如何通过pause容器来构建一个pod。本文我们对pause容器做一个总结&#xff0c;并再此基础上次深入浅出&#xff0c;从pause容器的源码详细了解pause容器的实现原理。 正文 pause容器是什么 在 Kubernetes 中&#xff…

echarts 实现中国geo地图自定义贴图实例

文章目录 1. 实现效果2. 设置自定义图片&#xff0c;做好准备3. echarts 实现贴图 1. 实现效果 实现自定义背景图&#xff0c;给echarts地图贴背景图效果&#xff0c; 先准备两张背景图片&#xff0c;一张是默认的&#xff0c;一张是鼠标悬浮替换的&#xff0c;如下两张图 2.…

基于EchoMimic加速版,可编辑标志点控制实现逼真音频驱动的肖像动画

EchoMimic 是蚂蚁集团终端技术部门开发的一项技术,旨在通过音频驱动生成逼真的肖像动画。对于那些初次接触这项技术的用户,本教程将带你逐步了解如何设置开发环境、获取项目代码、安装依赖,并最终成功运行示例生成自己的肖像动画。 文章目录 项目代码安装依赖业务拓展参数调…

webpack的热更新原理

Webpack热更新&#xff08; Hot Module Replacement&#xff0c;简称 HMR&#xff09;&#xff0c;无需完全刷新整个页面的同时&#xff0c;更新所有类型的模块&#xff0c;是 Webpack 提供的最有用的功能之一。 保留在完全重新加载页面期间丢失的应用程序状态。只更新变更内容…

【bug】通过lora方式微调sdxl inpainting踩坑

报错内容 ValueError: Attempting to unscale FP16 gradients. 报错位置 if accelerator.sync_gradients:params_to_clip (itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)if args.train_text_encoderelse unet_lora_parameters…

温湿度传感器SHT20的功能介绍和时序分析

目录 概述 1 认识SHT20 1.1 SHT20介绍 1.2 SHT20属性 1.3 接口介绍 1.4 SHT20的相关命令 1.5 转换时间 2 寄存器操作 2.1 复位操作 2.2 User Register 2.3 CRC Checksum 3 温湿度计算 3.1 相对湿度转换 3.2 温度换算 3.3 转换公式的C语言实现 概述 本文主要介绍…

ChatGLM-6B部署到本地电脑

引言 ChatGLM-6B是由清华大学开源的双语对话大模型&#xff0c;该模型有62亿参数&#xff0c;但在经过量化后模型体积大幅下降&#xff0c;因此不同于其他需要部署到服务器上的大模型&#xff0c;该模型可以部署到本地电脑&#xff0c;那么接下来我们来看看如何部署该模型。 …

OpenAI API key not working in my React App

题意&#xff1a;OpenAI API 密钥在我的 React 应用中不起作用 问题背景&#xff1a; I am trying to create a chatbot in my react app, and Im not able to generate an LLM powered response. Ive been studying documentation and checking out tutorials but am unable …

【CMake】使用CMake在Visual Studio中配置glad和glfw

下载glad和glfw g l a d glad glad下载&#xff1a;glad下载 这个是 g i t h u b github github上的资源&#xff0c;进不去的话就开开魔法。 g l f w glfw glfw下载&#xff1a;glfw下载 下载CMake C M a k e CMake CMake下载&#xff1a; CMake下载 根据自己的平台选择&…