ICCV 2023 | Dynamic Snake Convolution (plug-and-play code and a test case included)


Paper link:

https://arxiv.org/abs/2307.08388

Code link:

https://github.com/YaoleiQi/DSCNet

The full code is below (the source file also includes a test case at the bottom), and it works as a plug-and-play module. In short, DSConv predicts per-position offsets with a small convolution, accumulates them outward from the kernel center into "snaked" sampling coordinates, bilinearly samples the input feature map at those coordinates, and finally applies a thin (kernel_size x 1 or 1 x kernel_size) convolution over the straightened samples.

import os
import torch
import numpy as np
from torch import nn
import warnings

warnings.filterwarnings("ignore")

"""
This code mainly implements the deformation process of DSConv
"""


class DSConv(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, extend_scope, morph,
                 if_offset, device):
        """
        The Dynamic Snake Convolution
        :param in_ch: input channel
        :param out_ch: output channel
        :param kernel_size: the size of kernel
        :param extend_scope: the range to expand (default 1 for this method)
        :param morph: the morphology of the convolution kernel is mainly divided into two types
                        along the x-axis (0) and the y-axis (1) (see the paper for details)
        :param if_offset: whether deformation is required, if it is False, it is the standard convolution kernel
        :param device: set on gpu
        """
        super(DSConv, self).__init__()
        # use the <offset_conv> to learn the deformable offset
        self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
        self.bn = nn.BatchNorm2d(2 * kernel_size)
        self.kernel_size = kernel_size

        # two types of the DSConv (along x-axis and y-axis)
        self.dsc_conv_x = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(kernel_size, 1),
            stride=(kernel_size, 1),
            padding=0,
        )
        self.dsc_conv_y = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(1, kernel_size),
            stride=(1, kernel_size),
            padding=0,
        )

        self.gn = nn.GroupNorm(out_ch // 4, out_ch)
        self.relu = nn.ReLU(inplace=True)

        self.extend_scope = extend_scope
        self.morph = morph
        self.if_offset = if_offset
        self.device = device

    def forward(self, f):
        offset = self.offset_conv(f)
        offset = self.bn(offset)
        # We need a range of deformation between -1 and 1 to mimic the snake's swing
        offset = torch.tanh(offset)
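        # offset now holds 2*K per-pixel offsets bounded to (-1, 1): K for y and K for x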
        input_shape = f.shape
        dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph,
                  self.device)
        deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
        if self.morph == 0:
            x = self.dsc_conv_x(deformed_feature)
            x = self.gn(x)
            x = self.relu(x)
            return x
        else:
            x = self.dsc_conv_y(deformed_feature)
            x = self.gn(x)
            x = self.relu(x)
            return x


# Core code. For ease of understanding, the input and output dimensions are noted next to the code.
class DSC(object):

    def __init__(self, input_shape, kernel_size, extend_scope, morph, device):
        self.num_points = kernel_size
        self.width = input_shape[2]
        self.height = input_shape[3]
        self.morph = morph
        self.device = device
        self.extend_scope = extend_scope  # offset (-1 ~ 1) * extend_scope

        # define feature map shape
        """
        B: Batch size  C: Channel  W: Width  H: Height
        """
        self.num_batch = input_shape[0]
        self.num_channels = input_shape[1]

    """
    input: offset [B,2*K,W,H]  K: Kernel size (2*K: 2D image, deformation contains <x_offset> and <y_offset>)
    output_x: [B,1,W,K*H]   coordinate map
    output_y: [B,1,K*W,H]   coordinate map
    """

    def _coordinate_map_3D(self, offset, if_offset):
        # split the 2*K offset channels into K y-offsets and K x-offsets
        y_offset, x_offset = torch.split(offset, self.num_points, dim=1)

        y_center = torch.arange(0, self.width).repeat([self.height])
        y_center = y_center.reshape(self.height, self.width)
        y_center = y_center.permute(1, 0)
        y_center = y_center.reshape([-1, self.width, self.height])
        y_center = y_center.repeat([self.num_points, 1, 1]).float()
        y_center = y_center.unsqueeze(0)

        x_center = torch.arange(0, self.height).repeat([self.width])
        x_center = x_center.reshape(self.width, self.height)
        x_center = x_center.permute(0, 1)
        x_center = x_center.reshape([-1, self.width, self.height])
        x_center = x_center.repeat([self.num_points, 1, 1]).float()
        x_center = x_center.unsqueeze(0)

        if self.morph == 0:
            """
            Initialize the kernel and flatten the kernel
                y: only need 0
                x: -num_points//2 ~ num_points//2 (Determined by the kernel size)
                !!! The related PPT will be submitted later, and the PPT will contain the whole changes of each step
            """
            y = torch.linspace(0, 0, 1)
            x = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )

            # (newer PyTorch versions warn unless torch.meshgrid is given an explicit
            # indexing argument; that warning is suppressed at the top of the file)
            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)

            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)  # [B*K*K, W,H]

            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)  # [B*K*K, W,H]

            y_new = y_center + y_grid
            x_new = x_center + x_grid

            y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(self.device)
            x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(self.device)

            y_offset_new = y_offset.detach().clone()

            if if_offset:
                y_offset = y_offset.permute(1, 0, 2, 3)
                y_offset_new = y_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)

                # The center position stays fixed while the remaining positions swing around it.
                # The key idea is that the offset is accumulated iteratively outward from the center.
                y_offset_new[center] = 0
                for index in range(1, center):
                    y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
                    y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
                y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(self.device)
                y_new = y_new.add(y_offset_new.mul(self.extend_scope))

            y_new = y_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            return y_new, x_new

        else:
            """
            Initialize the kernel and flatten the kernel
                y: -num_points//2 ~ num_points//2 (Determined by the kernel size)
                x: only need 0
            """
            y = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )
            x = torch.linspace(0, 0, 1)

            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)

            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)

            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)

            y_new = y_center + y_grid
            x_new = x_center + x_grid

            y_new = y_new.repeat(self.num_batch, 1, 1, 1)
            x_new = x_new.repeat(self.num_batch, 1, 1, 1)

            y_new = y_new.to(self.device)
            x_new = x_new.to(self.device)
            x_offset_new = x_offset.detach().clone()

            if if_offset:
                x_offset = x_offset.permute(1, 0, 2, 3)
                x_offset_new = x_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)
                x_offset_new[center] = 0
                for index in range(1, center):
                    x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
                    x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
                x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(self.device)
                x_new = x_new.add(x_offset_new.mul(self.extend_scope))

            y_new = y_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            return y_new, x_new

    """
    input: input feature map [N,C,D,W,H];coordinate map [N,K*D,K*W,K*H] 
    output: [N,1,K*D,K*W,K*H]  deformed feature map
    """

    def _bilinear_interpolate_3D(self, input_feature, y, x):
        y = y.reshape([-1]).float()
        x = x.reshape([-1]).float()

        zero = torch.zeros([]).int()
        max_y = self.width - 1
        max_x = self.height - 1

        # find the 4 neighbouring grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1

        # clip out coordinates exceeding feature map volume
        y0 = torch.clamp(y0, zero, max_y)
        y1 = torch.clamp(y1, zero, max_y)
        x0 = torch.clamp(x0, zero, max_x)
        x1 = torch.clamp(x1, zero, max_x)

        input_feature_flat = input_feature.flatten()
        input_feature_flat = input_feature_flat.reshape(
            self.num_batch, self.num_channels, self.width, self.height)
        input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
        input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)
        dimension = self.height * self.width

        base = torch.arange(self.num_batch) * dimension
        base = base.reshape([-1, 1]).float()

        repeat = torch.ones([self.num_points * self.width * self.height
                             ]).unsqueeze(0)
        repeat = repeat.float()

        base = torch.matmul(base, repeat)
        base = base.reshape([-1])

        base = base.to(self.device)

        base_y0 = base + y0 * self.height
        base_y1 = base + y1 * self.height

        # indices of the top row of the 2x2 neighbourhood
        index_a0 = base_y0 - base + x0
        index_c0 = base_y0 - base + x1

        # indices of the bottom row of the 2x2 neighbourhood
        index_a1 = base_y1 - base + x0
        index_c1 = base_y1 - base + x1

        # gather the values at the 4 neighbouring locations
        value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(self.device)
        value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(self.device)
        value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(self.device)
        value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(self.device)

        # recompute the neighbour coordinates (clamped one step wider) for the interpolation weights
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1

        # clip out coordinates exceeding feature map volume
        y0 = torch.clamp(y0, zero, max_y + 1)
        y1 = torch.clamp(y1, zero, max_y + 1)
        x0 = torch.clamp(x0, zero, max_x + 1)
        x1 = torch.clamp(x1, zero, max_x + 1)

        x0_float = x0.float()
        x1_float = x1.float()
        y0_float = y0.float()
        y1_float = y1.float()

        # standard bilinear weights: each neighbour is weighted by the area of the
        # rectangle between (y, x) and the diagonally opposite corner of its cell
        vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(self.device)
        vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(self.device)
        vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(self.device)
        vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(self.device)

        outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
                   value_c1 * vol_c1)

        if self.morph == 0:
            outputs = outputs.reshape([
                self.num_batch,
                self.num_points * self.width,
                1 * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        else:
            outputs = outputs.reshape([
                self.num_batch,
                1 * self.width,
                self.num_points * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        return outputs

    def deform_conv(self, input, offset, if_offset):
        y, x = self._coordinate_map_3D(offset, if_offset)
        deformed_feature = self._bilinear_interpolate_3D(input, y, x)
        return deformed_feature


# Code for testing the DSConv
if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    A = np.random.rand(4, 5, 6, 7)
    # A = np.ones(shape=(3, 2, 2, 3), dtype=np.float32)
    # print(A)
    A = A.astype(dtype=np.float32)
    A = torch.from_numpy(A)
    # print(A.shape)
    conv0 = DSConv(
        in_ch=5,
        out_ch=10,
        kernel_size=15,
        extend_scope=1,
        morph=0,
        if_offset=True,
        device=device)
    if torch.cuda.is_available():
        A = A.to(device)
        conv0 = conv0.to(device)
    out = conv0(A)
    print(out.shape)
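
A minimal usage sketch follows: it runs the two DSConv morphs (x-axis and y-axis) in parallel and fuses them with a 1x1 convolution, which is one common way to drop the module into an existing network. The wrapper name DSConvBlock, the concatenate-then-fuse design, and the commented-out shape check are illustrative assumptions for this post, not part of the official DSCNet repository.

# Hypothetical wrapper (not from the DSCNet repo): runs both DSConv morphs on the
# same input and fuses the two outputs with a 1x1 convolution.
class DSConvBlock(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, extend_scope, device):
        super(DSConvBlock, self).__init__()
        self.conv_x = DSConv(in_ch, out_ch, kernel_size, extend_scope,
                             morph=0, if_offset=True, device=device)
        self.conv_y = DSConv(in_ch, out_ch, kernel_size, extend_scope,
                             morph=1, if_offset=True, device=device)
        self.fuse = nn.Conv2d(2 * out_ch, out_ch, kernel_size=1)

    def forward(self, x):
        # both branches return [B, out_ch, W, H]; concatenate along channels and fuse
        out = torch.cat([self.conv_x(x), self.conv_y(x)], dim=1)
        return self.fuse(out)

# Example (reusing A and device from the test above):
# block = DSConvBlock(in_ch=5, out_ch=10, kernel_size=15, extend_scope=1, device=device).to(device)
# print(block(A).shape)  # expected: torch.Size([4, 10, 6, 7])

One thing to watch: the GroupNorm inside DSConv uses out_ch // 4 groups, so out_ch has to be divisible by out_ch // 4 (e.g. 8, 10, or 16 all work).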

Thanks for your support! If you would like a follow-up tutorial on integrating this module into different models, leave a comment below.

