【论文笔记】RS-Mamba for Large Remote Sensing Image Dense Prediction(附Code)

news2024/10/8 14:31:33

论文作者提出了RS-Mamba(RSM)用于高分辨率遥感图像遥感的密集预测任务。RSM设计用于模拟具有线性复杂性的遥感图像的全局特征,使其能够有效地处理大型VHR图像。它采用全向选择性扫描模块,从多个方向对图像进行全局建模,从多个方向捕捉大的空间特征。

论文链接:https://arxiv.org/abs/2404.02668

code链接:https://github.com/walking-shadow/Official_Remote_Sensing_Mamba

2D全向扫描机制是本研究的主要创新点。作者考虑到遥感影像地物多方向的特点,在VMamba2D双向扫描机制的基础上增加了斜向扫描机制。

 以下是作者针对该部分进行改进的代码:

def antidiagonal_gather(tensor):
    # 取出矩阵所有反斜向的元素并拼接
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (torch.arange(W, device=tensor.device) - shift) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 使用gather进行索引选择
    return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)

def diagonal_gather(tensor):
    # 取出矩阵所有反斜向的元素并拼接
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (shift + torch.arange(W, device=tensor.device)) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 使用gather进行索引选择
    return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)

def diagonal_scatter(tensor_flat, original_shape):
    # 把斜向元素拼接起来的一维向量还原为最初的矩阵形式
    B, C, H, W = original_shape
    shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (shift + torch.arange(W, device=tensor_flat.device)) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 创建一个空的张量来存储反向散布的结果
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # 将平铺的张量重新变形为[B, C, H, W],考虑到需要使用transpose将H和W调换
    tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    # 使用scatter_根据expanded_index将元素放回原位
    result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    return result_tensor

def antidiagonal_scatter(tensor_flat, original_shape):
    # 把反斜向元素拼接起来的一维向量还原为最初的矩阵形式
    B, C, H, W = original_shape
    shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (torch.arange(W, device=tensor_flat.device) - shift) % W  # 利用广播创建索引矩阵[H, W]
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 初始化一个与原始张量形状相同、元素全为0的张量
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # 将平铺的张量重新变形为[B, C, W, H],因为操作是沿最后一个维度收集的,需要调整形状并交换维度
    tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    # 使用scatter_将元素根据索引放回原位
    result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    return result_tensor

class CrossScan(torch.autograd.Function):
    # ZSJ 这里是把图像按照特定方向展平的地方,改变扫描方向可以在这里修改
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        # xs = x.new_empty((B, 4, C, H * W))
        xs = x.new_empty((B, 8, C, H * W))
        # 添加横向和竖向的扫描
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
    
        # 提供斜向和反斜向的扫描
        xs[:, 4] = diagonal_gather(x)
        xs[:, 5] = antidiagonal_gather(x)
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])

        return xs
    
    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        # 把横向和竖向的反向部分再反向回来,并和原来的横向和竖向相加
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb.view(B, -1, H, W)

        # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L)
        # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
        y_da = diagonal_scatter(y_da[:, 0], (B,C,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,C,H,W))

        y_res = y_rb + y_da
        # return y.view(B, -1, H, W)
        return y_res


class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)

        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y_rb = y_rb.view(B, -1, H, W)

        # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1)
        # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
        y_da = diagonal_scatter(y_da[:, 0], (B,D,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,D,H,W))

        y_res = y_rb + y_da
        return y_res.view(B, D, -1)
        # return y
    
    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        # xs = x.new_empty((B, 4, C, L))
        xs = x.new_empty((B, 8, C, L))

        # 横向和竖向扫描
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        # xs = xs.view(B, 4, C, H, W)

        # 提供斜向和反斜向的扫描
        xs[:, 4] = diagonal_gather(x.view(B,C,H,W))
        xs[:, 5] = antidiagonal_gather(x.view(B,C,H,W))
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])

        # return xs
        return xs.view(B, 8, C, H, W)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1613558.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PyTorch深度学习入门到精通指南AI写作一键生成

首先,这篇文章是基于笔尖AI写作进行文章创作的,喜欢的宝子,也可以去体验下,解放双手,上班直接摸鱼~ 按照惯例,先介绍下这款笔尖AI写作,宝子也可以直接下滑跳过看正文~ 笔尖Ai写作:…

虚拟机VMware安装与Ubuntu

1.虚拟机安装 链接:百度网盘 请输入提取码 提取码:2fr6 CG54H-D8D0H-H8DHY-C6X7X-N2KG6 2.Ubuntu下载 Download Ubuntu Desktop | Ubuntu 3.设置 如后续要下一些软件越大越好

关于Jetson空间不足的解决问题(sd卡挂载和conda更改环境安装路径)

文章目录 问题描述挂载sd卡到指定目录查看conda路径更改环境路径指定路径安装conda虚拟环境 问题描述 因为在做毕设的时候,用到了Jetson,发现这个空间太小了,如果下conda的包根本不够用,所以就想挂载sd卡,然后把环境安…

cesium sampleHeightMostDetailed 取高度

//通过经纬度异步拾取模型的高度,当模型还没下载,并不在屏幕范围内时,先下载模型,再拾取高度let c3 Cesium.Cartesian3.fromDegrees(120.134766, 30.188376, 0);let position Cesium.Cartographic.fromCartesian(c3);let promis…

spring版本介绍

Spring Framework 是一个广泛使用的 Java 平台,用于构建企业级应用程序。它提供了一个全面的编程和配置模型,支持现代 Java 应用程序的最佳实践,如依赖注入、面向切面编程以及基于注解的编程模型。自从 Spring 1.0 发布以来,已经经…

构建代理IP池并自动测试可用性的爬虫实现

目录 前言 一、认识代理IP 1. 隐藏真实IP地址 2. 提高爬虫效率 二、爬取代理IP 三、测试代理IP可用性 1. 发起HTTP请求 2. 超时检测 3. 循环请求 四、构建代理IP池 五、总结 前言 随着互联网的发展,网络爬虫在数据采集、搜索引擎、信息监控等领域发挥着…

Electron+Vue3整合-开发时整合-全部ts开发 + 一条命令启动vue3和electron两个服务

说明 本文介绍一下 Electron Vue3 的整合的中级操作。实现的效果是 : 1、一个正常的Vue3项目; 2、整合加入 Electron 框架 :开发时只执行一条命令,启动 vue 项目 后 再启动 electron;electron 的开发使用 typescript…

经典机器学习算法——决策树

优质博文:IT-BLOG-CN 树模型是机器学习中最常用的一类模型,包括随机森林、AdaBoost、GBDT(XGBoost和Lightgbm)等,基本原理都是通过集成弱学习器的即式来进一步提升准确度。这里的弱学习器包括线性模型和决策树模型&…

DC30V36V60V100V转9V、12V/1.5A方案 车灯驱动芯片IC H5028L ,高性价比,皮实耐抗

DC24V、30V、36V、60V、100V转9V、12V/1.5A方案,以及车灯驱动芯片IC,这通常涉及到电源转换和驱动电路的设计。这些方案的目标是将一个较高的直流电压(如24V、30V、36V、60V或100V)转换为较低但稳定的直流电压(如9V或12…

EigenLayer生态全解析:再质押与AVS崛起的序章

基于以太坊网络的再质押协议EigenLayer提出了利用为以太坊网络验证而质押的ETH来与其他协议共享安全性和资本效率,同时为协议参与者提供额外利息。在AVS、再质押、积分系统等概念的推动下,逐渐形成一个庞大的生态系统,从2024年初到现在EigenL…

[前端]NVM管理器安装、nodejs、npm、yarn配置

NVM管理器安装、nodejs、npm、yarn配置 NVM管理器安装 nvm(Node.js version manager) 是一个命令行应用,可以协助您快速地 更新、安装、使用、卸载 本机的全局 node.js 版本。 nvm下载地址:https://github.com/coreybutler/nvm-windows/releases 1.全部…

分类预测 | Matlab实现CNN-LSTM-SAM-Attention卷积长短期记忆神经网络融合空间注意力机制的数据分类预测

分类预测 | Matlab实现CNN-LSTM-SAM-Attention卷积长短期记忆神经网络融合空间注意力机制的数据分类预测 目录 分类预测 | Matlab实现CNN-LSTM-SAM-Attention卷积长短期记忆神经网络融合空间注意力机制的数据分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述 1.Mat…

模电期末复习(四)功率放大电路

功率放大电路 4.1 功率放大电路的主要特点4.1.1 对放大电路的要求4.1.2 放大电路中三极管的工作状态4.1.3 放大电路的分析方法 4.2 互补对称式功率放大电路4.2.1 电路的组成和工作原理4.2.2 互补对称电路主要参数的估算 4.3 采用复合管的互补对称式放大电路4.3.1 复合管的接法及…

【计算机毕业设计】理发店管理系统产品功能说明——后附源码

🎉**欢迎来到我的技术世界!**🎉 📘 博主小档案: 一名来自世界500强的资深程序媛,毕业于国内知名985高校。 🔧 技术专长: 在深度学习任务中展现出卓越的能力,包括但不限于…

绝地求生:PUBG巅峰在线人数再次突破70W:荣都、杜卡迪功不可没!

根据黑盒游戏人数显示,进入2024年后,PUBG在线人数稳定在60W左右。 绝地求生自去年世界赛结束以来,一直处于不愠不火的状态,外挂横行加上没有新游戏元素加入,日活人数仅剩余30~40W。 荣都、杜卡迪上线 而20…

JavaSE基础篇-2

一、数组操作 【先写几个练习】 public class Demo01Array {public static void main(String[] args) {//1.创建Random对象以及数组Random rd new Random();int[] arr new int[10];//2.定义一个变量,统计个数 countint count 0;//3.循环随机循环存for (int i 0; i < ar…

【LInux学习】Linux项目自动化构建工具-make/Makefile

文章目录 &#x1f302;背景&#x1f302;make/Makefile的使用&#x1f302;make/Makefile原理&#x1f302;项目清理&#x1f302;make/Makefile的语法补充 &#x1f302;背景 会不会写makefile&#xff0c;从一个侧面说明了一个人是否具备完成大型工程的能力一个工程中的源文…

C++模版初阶----函数模版、类模版

C模版初阶 1. 泛型编程2. 函数模板2.1 函数模板概念2.2函数模板格式2.3 函数模板的原理2.4 函数模板的实例化2.5 函数模版的匹配原则 3. 类模板3.1 类模板的定义格式3.2 类模板的实例化 总结 1. 泛型编程 泛型编程 : 编写与类型无关的通用代码&#xff0c;是代码复用的一种手段…

模电期末复习(三)放大电路的频率响应

放大电路的频率响应 3.1 频率响应的一般概念3.1.1 幅频特性和相频特性3.1.2 下限频率、上限频率和通频带3.1.3 频率失真3.1.4 波特图3.1.5高通电路和低通电路 3.2 三极管的频率参数3.2.1 共射截止频率3.2.2 特征频率3.2.3 共基截止频率 3.3 单管共射放大电路的频率响应3.3.1 三…

Chrome 侧边栏开发示例

前言 最近做项目&#xff0c;需要开发浏览器扩展&#xff0c;但是考虑页面布局兼容性问题&#xff0c;使用了Chrome114开始的侧边栏&#xff0c;浏览器自带的能力毕竟不会出现兼容性问题&#xff0c;不过Chrome123开始&#xff0c;侧边栏居然又可以选择固定右侧扩展栏了&#…