【论文阅读】RS-Mamba for Large Remote Sensing Image Dense Prediction(附Code)

news2024/11/30 8:40:49

论文作者提出了RS-Mamba(RSM)用于高分辨率遥感图像遥感的密集预测任务。RSM设计用于模拟具有线性复杂性的遥感图像的全局特征,使其能够有效地处理大型VHR图像。它采用全向选择性扫描模块,从多个方向对图像进行全局建模,从多个方向捕捉大的空间特征。

论文链接:https://arxiv.org/abs/2404.02668

code链接:https://github.com/walking-shadow/Official_Remote_Sensing_Mamba

2D全向扫描机制是本研究的主要创新点。作者考虑到遥感影像地物多方向的特点,在VMamba2D双向扫描机制的基础上增加了斜向扫描机制。

 以下是作者针对该部分进行改进的代码:

def antidiagonal_gather(tensor):
    # 取出矩阵所有反斜向的元素并拼接
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (torch.arange(W, device=tensor.device) - shift) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 使用gather进行索引选择
    return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)

def diagonal_gather(tensor):
    # 取出矩阵所有反斜向的元素并拼接
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (shift + torch.arange(W, device=tensor.device)) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 使用gather进行索引选择
    return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)

def diagonal_scatter(tensor_flat, original_shape):
    # 把斜向元素拼接起来的一维向量还原为最初的矩阵形式
    B, C, H, W = original_shape
    shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (shift + torch.arange(W, device=tensor_flat.device)) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 创建一个空的张量来存储反向散布的结果
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # 将平铺的张量重新变形为[B, C, H, W],考虑到需要使用transpose将H和W调换
    tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    # 使用scatter_根据expanded_index将元素放回原位
    result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    return result_tensor

def antidiagonal_scatter(tensor_flat, original_shape):
    # 把反斜向元素拼接起来的一维向量还原为最初的矩阵形式
    B, C, H, W = original_shape
    shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (torch.arange(W, device=tensor_flat.device) - shift) % W  # 利用广播创建索引矩阵[H, W]
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 初始化一个与原始张量形状相同、元素全为0的张量
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # 将平铺的张量重新变形为[B, C, W, H],因为操作是沿最后一个维度收集的,需要调整形状并交换维度
    tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    # 使用scatter_将元素根据索引放回原位
    result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    return result_tensor

class CrossScan(torch.autograd.Function):
    # ZSJ 这里是把图像按照特定方向展平的地方,改变扫描方向可以在这里修改
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        # xs = x.new_empty((B, 4, C, H * W))
        xs = x.new_empty((B, 8, C, H * W))
        # 添加横向和竖向的扫描
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
    
        # 提供斜向和反斜向的扫描
        xs[:, 4] = diagonal_gather(x)
        xs[:, 5] = antidiagonal_gather(x)
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])

        return xs
    
    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        # 把横向和竖向的反向部分再反向回来,并和原来的横向和竖向相加
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb.view(B, -1, H, W)

        # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L)
        # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
        y_da = diagonal_scatter(y_da[:, 0], (B,C,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,C,H,W))

        y_res = y_rb + y_da
        # return y.view(B, -1, H, W)
        return y_res


class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)

        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y_rb = y_rb.view(B, -1, H, W)

        # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1)
        # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
        y_da = diagonal_scatter(y_da[:, 0], (B,D,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,D,H,W))

        y_res = y_rb + y_da
        return y_res.view(B, D, -1)
        # return y
    
    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        # xs = x.new_empty((B, 4, C, L))
        xs = x.new_empty((B, 8, C, L))

        # 横向和竖向扫描
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        # xs = xs.view(B, 4, C, H, W)

        # 提供斜向和反斜向的扫描
        xs[:, 4] = diagonal_gather(x.view(B,C,H,W))
        xs[:, 5] = antidiagonal_gather(x.view(B,C,H,W))
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])

        # return xs
        return xs.view(B, 8, C, H, W)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1602898.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java -- (part12)

一.权限修饰符 1.属性:用private ->封装思想 2.成员方法public ->便于调用 3.构造public ->便于new对象 二.final关键字 1.修饰类 a.格式 -- public final class 类名 b.特点:不能被继承 2.修饰方法 a.格式:修饰符 final 返回值类型 方法名(形参){} b.特点…

推荐两个植物miRNA数据库(miRbase和PNRD)

前记 植物miRNA数据库是储存和整理植物微小RNA(miRNA)相关信息的数据库。miRNA是一类长度为21-24个核苷酸的非编码小分子RNA,能够通过与靶基因的mRNA结合,调控基因表达。植物miRNA数据库通常包含以下内容: miRNA序列信…

ROS2 仿真学习02 Gazebo导入官方示例模型

1.下载模型 git clone https://gitee.com/bingda-robot/gazebo_models.git将gazebo_models拖到到.gazebo当中(如果没看到.gazebo文件请按住CTRLh) 2.添加模型到gazebo的Insert 这就将官方示例的模型都导入到Gazebo 了 随便试试一个模型

每日OJ题_完全背包②_力扣322. 零钱兑换

目录 力扣322. 零钱兑换 问题解析 解析代码 优化代码(滚动数组) 力扣322. 零钱兑换 322. 零钱兑换 难度 中等 给你一个整数数组 coins ,表示不同面额的硬币;以及一个整数 amount ,表示总金额。 计算并返回可以…

密码学 | 椭圆曲线密码学 ECC 入门(一)

目录 正文 1 公共密钥密码学的兴起 2 玩具版 RSA 算法 2.1 RSA 基本原理 2.2 RSA 举例说明 1 加密 2 解密 3 不是完美的陷门函数 ⚠️ 原文地址:A (Relatively Easy To Understand) Primer on Elliptic Curve Cryptography ⚠️ 写在前面&#xff1…

【测试开发学习历程】python常用的模块(下)

目录 8、MySQL数据库的操作-pymysql 8.1 连接并操作数据库 9、ini文件的操作-configparser 9.1 模块-configparser 9.2 读取ini文件中的内容 9.3 获取指定建的值 10 json文件操作-json 10.1 json文件的格式或者json数据的格式 10.2 json.load/json.loads 10.3 json.du…

OpenHarmony南向开发案例【智慧中控面板(基于 Bearpi-Micro)】

1 开发环境搭建 【从0开始搭建开发环境】【快速搭建开发环境】 参考鸿蒙开发指导文档:gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或复制转到。 【注意】:快速上手教程第六步出拉取代码时需要修改代码仓库地址 在MobaXterm中输入…

ES-全文搜索

模糊查询: 写数据通过id路由到master分片 查询数据到一个节点,该节点会作为一个调度节点判断负载等情况将请求转发到真正节点(一般し轮询)

【学习笔记十九】EWM Yard Management概述及后台配置

一、EWM Yard堆场管理业务概述 1.Yard Management基本概念 YARD管理针对的是库房以外的区域,可以理解为入大门开始到库门之前的这部分的区域 堆场结构 像在仓库中一样,将相应仓位映射为堆场仓位,可将其分组到堆场分区。场地中可能具有以下结构: 停车位(Park):在堆场中存…

linux(ub)-redis环境部署

1.下载redis包 wget http://download.redis.io/releases/redis-7.0.5.tar.gz 2.解压缩: tar -zxvf redis-7.0.5.tar.gz 3.安装gcc:sudo apt-get install gcc 4. 编译:cd redis-7.0.5 make make make install 5. cd /usr/local/bin/ 6. mkdir …

JAVA基础08- 继承,重写,super以及this

目录 继承(extends) 定义 说明 作用 方法的重写 定义 重写关键点 方法重写与重载的区别 练习 练习1(方法继承与重写的简单练习) 练习2(方法继承与重写的进阶练习) This的使用 定义 作用以及注…

动态代理,XML,Dom4j

文章目录 动态代理概述特点代码实现实现的关键步骤优点 XML概述作用编写第一个XML文件组成声明元素(标签、标记)属性注释转义字符[实体字符字符区(了解) 约束DTD约束Schema约束名称空间 Dom4jXML解析解析方式和解析器解析方式解析器Snipaste_2024-04-17_21-22-44.png<br /&g…

企业linux-堡垒机与跳板机测试案例-6140字详谈

在开始今天内容前&#xff0c;小编先把专栏前面学的Linux命令&#xff08;部分&#xff09;做了思维导图帮助各位平时的学习&#xff1a; 场景&#xff1a; 运维人员管理三台机器&#xff0c;通过远程连接工具连接上三台机器&#xff0c;也知道这三台机器root密码&#xff0c…

Xshell无法输入命令输入命令卡顿

Xshell是一款功能强大的终端模拟软件&#xff0c;可以让用户通过SSH、Telnet、Rlogin、SFTP等协议远程连接到Linux、Unix、Windows等服务器。然而&#xff0c;在使用Xshell的过程中&#xff0c;我们可能会遇到一些问题。比如输入不了命令&#xff0c;或者输入命令很卡。这些问题…

C++笔记:异常

文章目录 C 运行时错误处理机制及其不足之处C 异常概念异常的使用异常的抛出和匹配原则在函数调用链中异常栈展开匹配原则异常的重新抛出举例演示说明例子一&#xff1a;串联举例演示大部分原则例子二&#xff1a;模拟服务器开发中常用的异常继承体系例子三&#xff1a;异常的重…

千锤百炼之每日算法(一)

目录 题外话 正题 第一题 第一题思路 第一题代码详解 第二题 第二题思路 动态规划 解法一 解法一代码详解 解法二 第三题 第三题思路 第三题代码详解 小结 题外话 从今天开始,每天都会更新算法题,一天就三道题 大家最好采用码形结合的方式,也就是代码和图形结合…

最新最全的Jmeter接口测试必会技能:jmeter对图片验证码的处理

jmeter对图片验证码的处理 在web端的登录接口经常会有图片验证码的输入&#xff0c;而且每次登录时图片验证码都是随机的&#xff1b;当通过jmeter做接口登录的时候要对图片验证码进行识别出图片中的字段&#xff0c;然后再登录接口中使用&#xff1b; 通过jmeter对图片验证码…

git出现错误 fail to push some refs to “xxx“

问题产生原因&#xff1a;根据测试猜测造成这一错误的原因是在码云的远程仓库上删除了一个文件,本地没有pull下来,直接进行了commit,commit到本地仓库后,如果在pull下来,也是无法提交的 问题解决办法: 使用 git pull --rebase,拉取远程仓库,并将本地仓库新的提交作为最顶层的提…

C++ 并发编程指南(11)原子操作 | 11.5、内存模型

文章目录 一、C 内存模型1、为什么需要内存模型&#xff1f; 前言 C 11标准中最重要的特性之一&#xff0c;是大多数程序员都不会关注的东西。它并不是新的语法特性&#xff0c;也不是新的类库功能&#xff0c;而是新的多线程感知内存模型。本文介绍的内存模型是指多线程编程方…

(算法版)基于二值图像数字矩阵的距离变换算法

Hi&#xff0c;大家好&#xff0c;我是半亩花海。本项目展示了欧氏距离、城市街区距离和棋盘距离变换的实现方法。通过定义一个距离变换类&#xff0c;对输入图像进行距离变换操作&#xff0c;并生成对应的距离矩阵。在示例中&#xff0c;展示了在一个480x480的全黑背景图像上设…