每日Attention学习22——Inverted Residual RWKV

news2025/2/11 8:59:34
模块出处

[arXiv 25] [link] [code] RWKV-UNet: Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation


模块名称

Inverted Residual RWKV (IR-RWKV)


模块作用

用于vision的RWKV结构


模块结构

在这里插入图片描述


模块代码

注:cpp扩展请参考作者原仓库

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from timm.layers.activations import *
from functools import partial
from timm.layers import DropPath, create_act_layer, LayerType
from typing import Callable, Dict, Optional, Type
from torch.utils.cpp_extension import load


T_MAX = 1024
inplace = True
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])


def get_norm(norm_layer='in_1d'):
	eps = 1e-6
	norm_dict = {
		'none': nn.Identity,
		'in_1d': partial(nn.InstanceNorm1d, eps=eps),
		'in_2d': partial(nn.InstanceNorm2d, eps=eps),
		'in_3d': partial(nn.InstanceNorm3d, eps=eps),
		'bn_1d': partial(nn.BatchNorm1d, eps=eps),
		'bn_2d': partial(nn.BatchNorm2d, eps=eps),
		# 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),
		'bn_3d': partial(nn.BatchNorm3d, eps=eps),
		'gn': partial(nn.GroupNorm, eps=eps),
		'ln_1d': partial(nn.LayerNorm, eps=eps),
		# 'ln_2d': partial(LayerNorm2d, eps=eps),
	}
	return norm_dict[norm_layer]


def get_act(act_layer='relu'):
	act_dict = {
		'none': nn.Identity,
		'sigmoid': Sigmoid,
		'swish': Swish,
		'mish': Mish,
		'hsigmoid': HardSigmoid,
		'hswish': HardSwish,
		'hmish': HardMish,
		'tanh': Tanh,
		'relu': nn.ReLU,
		'relu6': nn.ReLU6,
		'prelu': PReLU,
		'gelu': GELU,
		'silu': nn.SiLU
	}
	return act_dict[act_layer]


class ConvNormAct(nn.Module):
	def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
				 skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
		super(ConvNormAct, self).__init__()
		self.has_skip = skip and dim_in == dim_out
		padding = math.ceil((kernel_size - stride) / 2)
		self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
		self.norm = get_norm(norm_layer)(dim_out)
		self.act = get_act(act_layer)(inplace=inplace)
		self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
	
	def forward(self, x):
		shortcut = x
		x = self.conv(x)
		x = self.norm(x)
		x = self.act(x)
		if self.has_skip:
			x = self.drop_path(x) + shortcut
		return x
      

class SE(nn.Module):
    def __init__(
            self,
            in_chs: int,
            rd_ratio: float = 0.25,
            rd_channels: Optional[int] = None,
            act_layer: LayerType = nn.ReLU,
            gate_layer: LayerType = nn.Sigmoid,
            force_act_layer: Optional[LayerType] = None,
            rd_round_fn: Optional[Callable] = None,
    ):
        super(SE, self).__init__()
        if rd_channels is None:
            rd_round_fn = rd_round_fn or round
            rd_channels = rd_round_fn(in_chs * rd_ratio)
        act_layer = force_act_layer or act_layer
        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
        self.act1 = create_act_layer(act_layer, inplace=True)
        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        return x * self.gate(x_se)
    

def q_shift(input, shift_pixel=1, gamma=1/4, patch_resolution=None):
    assert gamma <= 1/4
    B, N, C = input.shape
    input = input.transpose(1, 2).reshape(B, C, patch_resolution[0], patch_resolution[1])
    B, C, H, W = input.shape
    output = torch.zeros_like(input)
    output[:, 0:int(C*gamma), :, shift_pixel:W] = input[:, 0:int(C*gamma), :, 0:W-shift_pixel]
    output[:, int(C*gamma):int(C*gamma*2), :, 0:W-shift_pixel] = input[:, int(C*gamma):int(C*gamma*2), :, shift_pixel:W]
    output[:, int(C*gamma*2):int(C*gamma*3), shift_pixel:H, :] = input[:, int(C*gamma*2):int(C*gamma*3), 0:H-shift_pixel, :]
    output[:, int(C*gamma*3):int(C*gamma*4), 0:H-shift_pixel, :] = input[:, int(C*gamma*3):int(C*gamma*4), shift_pixel:H, :]
    output[:, int(C*gamma*4):, ...] = input[:, int(C*gamma*4):, ...]
    return output.flatten(2).transpose(1, 2)


def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())


class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0

        half_mode = (w.dtype == torch.half)
        bf_mode = (w.dtype == torch.bfloat16)
        ctx.save_for_backward(w, u, k, v)
        w = w.float().contiguous()
        u = u.float().contiguous()
        k = k.float().contiguous()
        v = v.float().contiguous()
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if half_mode:
            y = y.half()
        elif bf_mode:
            y = y.bfloat16()
        return y

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda').contiguous()
        gu = torch.zeros((B, C), device='cuda').contiguous()
        gk = torch.zeros((B, T, C), device='cuda').contiguous()
        gv = torch.zeros((B, T, C), device='cuda').contiguous()
        half_mode = (w.dtype == torch.half)
        bf_mode = (w.dtype == torch.bfloat16)
        wkv_cuda.backward(B, T, C,
                          w.float().contiguous(),
                          u.float().contiguous(),
                          k.float().contiguous(),
                          v.float().contiguous(),
                          gy.float().contiguous(),
                          gw, gu, gk, gv)
        if half_mode:
            gw = torch.sum(gw.half(), dim=0)
            gu = torch.sum(gu.half(), dim=0)
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif bf_mode:
            gw = torch.sum(gw.bfloat16(), dim=0)
            gu = torch.sum(gu.bfloat16(), dim=0)
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
        else:
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            return (None, None, None, gw, gu, gk, gv)
        

class VRWKV_SpatialMix(nn.Module):
    def __init__(self, n_embd, channel_gamma=1/4, shift_pixel=1):
        super().__init__()
        self.n_embd = n_embd
        attn_sz = n_embd
        self._init_weights()
        self.shift_pixel = shift_pixel
        if shift_pixel > 0:
            self.channel_gamma = channel_gamma
        else:
            self.spatial_mix_k = None
            self.spatial_mix_v = None
            self.spatial_mix_r = None

        self.key = nn.Linear(n_embd, attn_sz, bias=False)
        self.value = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(n_embd, attn_sz, bias=False)
        self.key_norm = nn.LayerNorm(n_embd)
        self.output = nn.Linear(attn_sz, n_embd, bias=False)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def _init_weights(self):
        self.spatial_decay = nn.Parameter(torch.zeros(self.n_embd))
        self.spatial_first = nn.Parameter(torch.zeros(self.n_embd))
        self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
        self.spatial_mix_v = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
        self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
    def jit_func(self, x, patch_resolution):
        # Mix x with the previous timestep to produce xk, xv, xr
        B, T, C = x.size()
        # Use xk, xv, xr to produce k, v, r
        if self.shift_pixel > 0:
            xx = q_shift(x, self.shift_pixel, self.channel_gamma, patch_resolution)
            xk = x * self.spatial_mix_k + xx * (1 - self.spatial_mix_k)
            xv = x * self.spatial_mix_v + xx * (1 - self.spatial_mix_v)
            xr = x * self.spatial_mix_r + xx * (1 - self.spatial_mix_r)
        else:
            xk = x
            xv = x
            xr = x
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)
        return sr, k, v

    def forward(self, x, patch_resolution=None):
        B, T, C = x.size()
        sr, k, v = self.jit_func(x, patch_resolution)
        x = RUN_CUDA(B, T, C, self.spatial_decay / T, self.spatial_first / T, k, v)
        x = self.key_norm(x)
        x = sr * x
        x = self.output(x)
        return x
    

class iR_RWKV(nn.Module):
    def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0, norm_layer='bn_2d',
                 act_layer='relu', dw_ks=3, stride=1, dilation=1, se_ratio=0.0,
                 attn_s=True, drop_path=0., drop=0.,img_size=224, channel_gamma=1/4, shift_pixel=1):
        super().__init__()
        self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()
        dim_mid = int(dim_in * exp_ratio)
        self.ln1 = nn.LayerNorm(dim_mid)
        self.conv = ConvNormAct(dim_in, dim_mid, kernel_size=1)
        self.has_skip = (dim_in == dim_out and stride == 1) and has_skip
        if attn_s==True:
                self.att = VRWKV_SpatialMix(dim_mid, channel_gamma, shift_pixel)
        self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()
        self.proj_drop = nn.Dropout(drop)
        self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none', inplace=inplace)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
        self.attn_s=attn_s
        self.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride, dilation=dilation, groups=dim_mid, norm_layer='bn_2d', act_layer='silu', inplace=inplace)
        
    def forward(self, x):
        shortcut = x
        x = self.norm(x)
        x = self.conv(x)
        if self.attn_s:
            B, hidden, H, W = x.size()
            patch_resolution = (H,  W)
            x = x.view(B, hidden, -1)  # (B, hidden, H*W) = (B, C, N)
            x = x.permute(0, 2, 1)
            x = x + self.drop_path(self.ln1(self.att(x, patch_resolution)))
            B, n_patch, hidden = x.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidde
            h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
            x = x.permute(0, 2, 1)
            x = x.contiguous().view(B, hidden, h, w)
        x = x + self.se(self.conv_local(x)) if self.has_skip else self.se(self.conv_local(x))
        x = self.proj_drop(x)
        x = self.proj(x)
        x = (shortcut + self.drop_path(x)) if self.has_skip else x
        return x


if __name__ == '__main__':
    x = torch.randn([1, 64, 11, 11]).cuda()
    ir_rwkv = iR_RWKV(dim_in=64, dim_out=64).cuda()
    out = ir_rwkv(x)
    print(out.shape)  # [1, 64, 11, 11]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2295523.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习之数学基础:线性代数、微积分、概率论 | PyTorch 深度学习实战

前一篇文章&#xff0c;使用线性回归模型逼近目标模型 | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started 本篇文章内容来自于 强化学习必修课&#xff1a;引领人工智能新时代【梗直哥瞿炜】 线性代数、微积分、概率论 …

UNI-MOL: A UNIVERSAL 3D MOLECULAR REPRESENTATION LEARNING FRAMEWORK

UNI-MOL: A UNIVERSAL 3D MOLECULAR REPRESENTATION LEARNING FRAMEWORK Neurips23 推荐指数&#xff1a;#paper/⭐⭐⭐#​&#xff08;工作量不小) 动机 在大多数分子表征学习方法中&#xff0c;分子被视为 1D 顺序标记或2D 拓扑图&#xff0c;这限制了它们为下游任务整合…

SQL Server查询计划操作符(7.3)——查询计划相关操作符(6)

7.3. 查询计划相关操作符 48)Key Lookup:该操作符对一个有簇索引的表进行书签查找。参数列包含簇索引的名字和用于查找簇索引中数据行的簇键。该操作符总是伴随一个Nested Loops操作符。如果其参数列中出现WITH PREFETCH子句,则查询处理器已决定使用异步预取(预读,read-ah…

C语言【基础篇】之数组——解锁多维与动态数组的编程奥秘

数组 &#x1f680;前言&#x1f99c;数组的由来与用途&#x1f31f;一维数组详解&#x1f58a;️二维数组进阶&#x1f4af;动态数组原理&#x1f914;常见误区扫盲&#x1f4bb;学习路径建议✍️总结 &#x1f680;前言 大家好&#xff01;我是 EnigmaCoder。本文收录于我的专…

掌握API和控制点(从Java到JNI接口)_38 JNI从C调用Java函数 01

1. Why? 将控制点下移到下C/C层 对古典视角的反思 App接近User&#xff0c;所以App在整体架构里&#xff0c;是主导者&#xff0c;拥有控制权。所以&#xff0c; App是架构的控制点所在。Java函数调用C/C层函数&#xff0c;是合理的。 但是EIT造形告诉我们&#xff1a; App…

windows蓝牙驱动开发-蓝牙 LE 邻近感应配置文件

邻近感应检测是蓝牙低功耗 (LE) 的常见用途。 本部分提供了创建可用于开发 UWP 设备应用的邻近感应配置文件的设备实现的指南。 在开发此应用之前&#xff0c;应熟悉蓝牙 LE 函数和蓝牙 LE 邻近感应配置文件规范。 示例服务声明 蓝牙低功耗引入了一个新的物理层&#xff0c;…

免费windows pdf编辑工具Epdf

Epdf&#xff08;完全免费&#xff09; 作者&#xff1a;不染心 时间&#xff1a;2025/2/6 Github: https://github.com/dog-tired/Epdf Epdf Epdf 是一款使用 Rust 编写的 PDF 编辑器&#xff0c;目前仍在开发中。它提供了一系列实用的命令行选项&#xff0c;方便用户对 PDF …

C++:类和对象初识

C&#xff1a;类和对象初识 前言类的引入与定义引入定义类的两种定义方法1. 声明和定义全部放在类体中2. 声明和定义分离式 类的成员变量命名规则 类的访问限定符及封装访问限定符封装 类的作用域与实例化类的作用域类实例化实例化方式&#xff1a; 类对象模型类对象的大小存储…

伪分布式Spark3.4.4安装

参考&#xff1a;Spark2.1.0入门&#xff1a;Spark的安装和使用_厦大数据库实验室博客 我的版本&#xff1a; hadoop 3.1.3 hbase 2.2.2 java openjdk version "1.8.0_432" 问了chatgpt,建议下载Spark3.4.4&#xff0c;不适合下载Spark 2.1.0: step1 Spark下载…

kafka服务端之控制器

文章目录 概述控制器的选举与故障恢复控制器的选举故障恢复 优雅关闭分区leader的选举 概述 在Kafka集群中会有一个或多个broker&#xff0c;其中有一个broker会被选举为控制器&#xff08;Kafka Controler&#xff09;&#xff0c;它负责管理整个集群中所有分区和副本的状态。…

【R语言】数据分析

一、描述性统计量 借助R语言内置的airquality数据集进行简单地演示&#xff1a; 1、集中趋势&#xff1a;均值和中位数 head(airquality) # 求集中趋势 mean(airquality$Ozone, na.rmT) # 求均值 median(airquality$Ozone, na.rmT) # 求中位数 2、众数 众数&#xff08;mod…

传输层协议 UDP 与 TCP

&#x1f308; 个人主页&#xff1a;Zfox_ &#x1f525; 系列专栏&#xff1a;Linux 目录 一&#xff1a;&#x1f525; 前置复盘&#x1f98b; 传输层&#x1f98b; 再谈端口号&#x1f98b; 端口号范围划分&#x1f98b; 认识知名端口号 (Well-Know Port Number) 二&#xf…

Java/Kotlin双语革命性ORM框架Jimmer(一)——介绍与简单使用

概览 Jimmer是一个Java/Kotlin双语框架 包含一个革命性的ORM 以此ORM为基础打造了一套综合性方案解决方案&#xff0c;包括 DTO语言 更全面更强大的缓存机制&#xff0c;以及高度自动化的缓存一致性 更强大客户端文档和代码生成能力&#xff0c;包括Jimmer独创的远程异常 …

剪辑学习整理

文章目录 1. 剪辑介绍 1. 剪辑介绍 剪辑可以干什么&#xff1f;剪辑分为哪些种类&#xff1f; https://www.bilibili.com/video/BV15r421p7aF/?spm_id_from333.337.search-card.all.click&vd_source5534adbd427e3b01c725714cd93961af 学完剪辑之后如何找工作or兼职&#…

IDEA查看项目依赖包及其版本

一.IDEA将现有项目转换为Maven项目 在IntelliJ IDEA中,将现有项目转换为Maven项目是一个常见的需求,可以通过几种不同的方法来实现。Maven是一个强大的构建工具,它可以帮助自动化项目的构建过程,管理依赖关系,以及其他许多方面。 添加Maven支持 如果你的项目还没有pom.xm…

centos虚拟机迁移没有ip的问题

故事背景&#xff0c;我们的centos虚拟机本来是好好的&#xff0c;但是拷贝到其他电脑上就不能分配ip&#xff0c;我个人觉得这个vmware他们软件应该搞定这个啊&#xff0c;因为这个问题是每次都会出现的。 网络选桥接 网络启动失败 service network restart Restarting netw…

Java 大视界 -- Java 大数据在智能供应链中的应用与优化(76)

&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎来到 青云交的博客&#xff01;能与诸位在此相逢&#xff0c;我倍感荣幸。在这飞速更迭的时代&#xff0c;我们都渴望一方心灵净土&#xff0c;而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识&#xff0c;也…

赛博算命之 ”梅花易数“ 的 “JAVA“ 实现 ——从玄学到科学的探索

hello~朋友们&#xff01;好久不见&#xff01; 今天给大家带来赛博算命第三期——梅花易数的java实现 赛博算命系列文章&#xff1a; 周易六十四卦 掐指一算——小六壬 更多优质文章&#xff1a;个人主页 JAVA系列&#xff1a;JAVA 大佬们互三哦~互三必回&#xff01;&#xf…

即梦(Dreamina)技术浅析(六):多模态生成模型

多模态生成模型是即梦(Dreamina)的核心技术之一,旨在结合文本和图像信息,生成更符合用户需求的视觉内容。多模态生成模型通过整合不同类型的数据(如文本和图像),能够实现更丰富、更精准的生成效果。 1. 基本原理 1.1 多模态生成模型概述 多模态生成模型的目标是结合不…

递增三元组(蓝桥杯18F)

暴力求解&#xff1a; #include<iostream> using namespace std; int main() {int N;cin >> N;int* A new int[N];int* B new int[N];int* C new int[N];for (int i 0; i < N;i) {cin >> A[i];}for (int i 0; i < N; i) {cin >> B[i];}for…