基于SuperPoint与SuperGlue实现图像配准

news2024/9/29 11:23:42

基于SuperPoint与SuperGlue实现图像配准,项目地址https://github.com/magicleap/SuperGluePretrainedNetwork,使用到了特殊算子grid_sample,在转onnx时要求opset_version为16及以上(即pytorch版本为1.9以上)。SuperPoint模型用于提取图像的特征点和特征点描述符(在进行图像配准时需要运行两个,实现对两个图片特征点的提取),SuperGlue模型用于对SuperPoint模型所提取的特征点和特征描述符进行匹配。

1、前置操作

为实现模型可以onnx部署,对项目中部分代码进行修改。主要是删除代码中对dict对象的使用,因为onnx不支持。

1.1 superpoint修改

代码在models/superpoint.py中,主要修改 SuperPoint模型的forward函数(代码在145行开始),不使用字典对象做参数(输入值和输出值),避免在onnx算子中不支持。同时对keypoints的实现函数进行多种尝试。其中,SuperPoint模型在训练时是只输出坐标点置信度(scores1)和坐标点的描述符(descriptors1),这里的坐标其实就是指特征点。但是,坐标信息仅体现在网格数据中且在进行点匹配时需要xy格式的坐标,为此将scores中置信度值大于阈值的点的坐标进行提取,故此得到了keypoints1(坐标点)。

    def forward(self, data):
        """ Compute keypoints, scores, descriptors for image """
        # Shared Encoder
        x = self.relu(self.conv1a(data))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))

        # Compute the dense keypoint scores
        cPa = self.relu(self.convPa(x))
        scores = self.convPb(cPa)
        scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
        b, _, h, w = scores.shape
        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
        scores = simple_nms(scores, self.config['nms_radius'])

        # Extract keypoints
        #keypoints = [ torch.nonzero(s > self.config['keypoint_threshold']) for s in scores]## nonzero->tensorRT not support
        #keypoints = [torch.vstack(torch.where(s > self.config['keypoint_threshold'])).T for s in scores]## vstack->onnx not support
        #keypoints = [torch.cat(torch.where(s > self.config['keypoint_threshold']),0).reshape(len(s.shape),-1).T for s in scores]# tensor.T->onnx not support
        #keypoints = [none_zero_index(s,self.config['keypoint_threshold']) for s in scores]# where->nonzero ->tensorRT not support
        keypoints = [torch.transpose(torch.cat(torch.where(s>self.config['keypoint_threshold']),0).reshape(len(s.shape),-1),1,0) for s in scores]# transpose->tensorRT not support
        scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]

        # Discard keypoints near the image borders
        keypoints, scores = list(zip(*[
            remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
            for k, s in zip(keypoints, scores)]))

        # Keep the k keypoints with highest score
        if self.config['max_keypoints'] >= 0:
            keypoints, scores = list(zip(*[
                top_k_keypoints(k, s, self.config['max_keypoints'])
                for k, s in zip(keypoints, scores)]))

        # Convert (h, w) to (x, y)
        keypoints = [torch.flip(k, [1]).float() for k in keypoints]

        # Compute the dense descriptors
        cDa = self.relu(self.convDa(x))
        descriptors = self.convDb(cDa)
        descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)

        # Extract descriptors
        descriptors = [sample_descriptors(k[None], d[None], 8)[0]
                       for k, d in zip(keypoints, descriptors)]

        # return {
        #     'keypoints': keypoints,
        #     'scores': scores,
        #     'descriptors': descriptors,
        # }
        return  keypoints[0].unsqueeze(0),scores[0].unsqueeze(0),descriptors[0].unsqueeze(0)

1.2 SuperGlue修改

代码中models/superglue.py中,主要修正由于字典对象在superpoint中被删除后的的影像。

1.2.1 normalize_keypoints函数调整

将原函数的参数image_shape替换为height和width

def normalize_keypoints(kpts,  height, width):
    """ Normalize keypoints locations based on image image_shape"""
    one = kpts.new_tensor(1)
    size = torch.stack([one*width, one*height])[None]
    center = size / 2
    scaling = size.max(1, keepdim=True).values * 0.7
    return (kpts - center[:, None, :]) / scaling[:, None, :]

1.2.2 forword函数修改

将代码中SuperGlue的forward函数使用以下代码替换。主要是修改了传入参数,将先前的字典进行了解包,让一个参数变成了6个;并对函数的返回值进行了修改,同时固定死了图像的size为640*640
SuperGlue模型是根据输入的两组keypoints、scores、descriptors数据,输出两组match_indices, match_mscores信息。第一组用于描述A->B的对应关系,第二组用于描述B->A的对应关系。

    def forward(self, data_descriptors0, data_descriptors1, data_keypoints0, data_keypoints1, data_scores0, data_scores1):
        """Run SuperGlue on a pair of keypoints and descriptors"""
        #, height:int, width:int
        height, width=640,640
        desc0, desc1 = data_descriptors0, data_descriptors1
        kpts0, kpts1 = data_keypoints0, data_keypoints1

        if kpts0.shape[1] == 0 or kpts1.shape[1] == 0:  # no keypoints
            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
            return kpts0.new_full(shape0, -1, dtype=torch.int),kpts1.new_full(shape1, -1, dtype=torch.int),kpts0.new_zeros(shape0),kpts1.new_zeros(shape1)

        # Keypoint normalization.
        kpts0 = normalize_keypoints(kpts0, height, width)
        kpts1 = normalize_keypoints(kpts1, height, width)

        # Keypoint MLP encoder.
        desc0 = desc0 + self.kenc(kpts0, data_scores0)
        desc1 = desc1 + self.kenc(kpts1, data_scores1)

        # Multi-layer Transformer network.
        desc0, desc1 = self.gnn(desc0, desc1)

        # Final MLP projection.
        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)

        # Compute matching descriptor distance.
        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
        scores = scores / self.config['descriptor_dim']**.5

        # Run the optimal transport.
        scores = log_optimal_transport(
            scores, self.bin_score,
            iters=self.config['sinkhorn_iterations'])

        # Get the matches with score above "match_threshold".
        max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
        indices0, indices1 = max0.indices, max1.indices
        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
        zero = scores.new_tensor(0)
        mscores0 = torch.where(mutual0, max0.values.exp(), zero)
        mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
        valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
        valid1 = mutual1 & valid0.gather(1, indices1)
        indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
        # return {
        #     'matches0': indices0, # use -1 for invalid match
        #     'matches1': indices1, # use -1 for invalid match
        #     'matching_scores0': mscores0,
        #     'matching_scores1': mscores1,
        # }
        return indices0,  indices1,  mscores0,  mscores1

1.3 集成调用

在进行图像配准时,使用superpoint模型和superglue模型的数据处理流程都是固定,为简化代码,故此将其封装为一个模型,代码保存为SPSP.py。

import torch
from superglue import SuperGlue
from superpoint import SuperPoint
import torch
import torch.nn as nn
import torch.nn.functional as F
class SPSG(nn.Module):#
    def __init__(self):
        super(SPSG, self).__init__()
        self.sp_model = SuperPoint({'max_keypoints':700})
        self.sg_model = SuperGlue({'weights': 'outdoor'})
    def forward(self,x1,x2):
        keypoints1,scores1,descriptors1=self.sp_model(x1)
        keypoints2,scores2,descriptors2=self.sp_model(x2)
        #print(scores1.shape,keypoints1.shape,descriptors1.shape)
        #example=(descriptors1.unsqueeze(0),descriptors2.unsqueeze(0),keypoints1.unsqueeze(0),keypoints2.unsqueeze(0),scores1.unsqueeze(0),scores2.unsqueeze(0))
        example=(descriptors1,descriptors2,keypoints1,keypoints2,scores1,scores2)
        indices0,  indices1,  mscores0,  mscores1=self.sg_model(*example)
        #return indices0,  indices1,  mscores0,  mscores1
        matches = indices0[0]
        
        valid = torch.nonzero(matches > -1).squeeze().detach()
        mkpts0 = keypoints1[0].index_select(0, valid);
        mkpts1 = keypoints2[0].index_select(0, matches.index_select(0, valid));
        confidence = mscores0[0].index_select(0, valid);
        return mkpts0, mkpts1, confidence

1.4 图像处理库

进行图像读取、图像显示操作的代码被封装为imgutils库,具体可以查阅https://hpg123.blog.csdn.net/article/details/124824892

2、实现图像配准

2.1 获取匹配点

通过以下步骤,即可获取两个图像的特征点,及特征点匹配度

from imgutils import *
import torch
from SPSG import SPSG
model=SPSG()#.to('cuda')
tensor2a,img2a=read_img_as_tensor("b1.jpg",(640,640),device='cpu')
tensor2b,img2b=read_img_as_tensor("b4.jpg",(640,640),device='cpu')
mkpts0, mkpts1, confidence=model(tensor2a,tensor2b)
myimshows( [img2a,img2b],size=12)

代码执行输出如下所示:

2.2 匹配点绘图

以下代码可以将两个图像中匹配度高于阈值的点进行绘制连接

import cv2 as cv
pt_num = mkpts0.shape[0]
im_dst,im_res=img2a,img2b
img = np.zeros((max(im_dst.shape[0], im_res.shape[0]), im_dst.shape[1]+im_res.shape[1]+10,3))
img[:,:im_res.shape[0],]=im_dst
img[:,-im_res.shape[0]:]=im_res
img=img.astype(np.uint8)
match_threshold=0.6
for i in range(0, pt_num):
    if (confidence[i] > match_threshold):
        pt0 = mkpts0[i].to('cpu').numpy().astype(np.int)
        pt1 = mkpts1[i].to('cpu').numpy().astype(np.int)
        #cv.circle(img, (pt0[0], pt0[1]), 1, (0, 0, 255), 2)
        #cv.circle(img, (pt1[0], pt1[1]+650), (0, 0, 255), 2)
        cv.line(img, pt0, (pt1[0]+im_res.shape[0], pt1[1]), (0, 255, 0), 1)
myimshow( img,size=12)

2.3 截取重叠区

先调用getGoodMatchPoint函数根据阈值筛选匹配度高的特征点,然后计算和透视变化矩阵H,最后提取重叠区域

import cv2
def getGoodMatchPoint(mkpts0, mkpts1, confidence,  match_threshold:float=0.5):
    n = min(mkpts0.size(0), mkpts1.size(0))
    srcImage1_matchedKPs, srcImage2_matchedKPs=[],[]

    if (match_threshold > 1 or match_threshold < 0):
        print("match_threshold error!")

    for i in range(n):
        kp0 = mkpts0[i]
        kp1 = mkpts1[i]
    
        pt0=(kp0[0].item(),kp0[1].item());
        pt1=(kp1[0].item(),kp1[1].item());
        c = confidence[i].item();
        if (c > match_threshold):
            srcImage1_matchedKPs.append(pt0);
            srcImage2_matchedKPs.append(pt1);
    
    return np.array(srcImage1_matchedKPs),np.array(srcImage2_matchedKPs)
pts_src, pts_dst=getGoodMatchPoint(mkpts0, mkpts1, confidence)

h1, status = cv2.findHomography(pts_src, pts_dst, cv.RANSAC, 8)
im_out1 = cv2.warpPerspective(im_dst, h1, (im_dst.shape[1],im_dst.shape[0]))
im_out2 = cv2.warpPerspective(im_res, h1, (im_dst.shape[1],im_dst.shape[0]),16)
#这里 im_dst和im_out1是严格配准的状态
myimshowsCL([im_dst,im_out1,im_res,im_out2],rows=2,cols=2, size=6)

2.4 模型导出

使用以下代码即可实现模型导出

input_names = ["input1","input2"]
output_names = ['mkpts0', 'mkpts1', 'confidence']
dummy_input=(tensor2a,tensor2b)
example_outputs=model(tensor2a,tensor2b)
ONNX_name="model.onnx"
torch.onnx.export(model.eval(), dummy_input, ONNX_name, verbose=True, input_names=input_names,opset_version=16,
                  dynamic_axes={
                        'confidence': {0: 'point_num',},
                        'mkpts0': {0: 'batch_size',},
                        'mkpts1': {0: 'batch_size',}
                       },
                   output_names=output_names)#,example_outputs=example_outputs

3、单独使用superpoint

可以单独使用SuperPoint模型提取图像的特征点

from imgutils import *
import torch
from superpoint import SuperPoint
import cv2

config={'max_keypoints': 400,'keypoint_threshold':0.1}
sp_model=SuperPoint(config).to('cuda')
sp_model=sp_model.eval()

tensor2a,img2a=read_img_as_tensor(r"D:\SuperGluePretrainedNetwork-master\assets\freiburg_sequence\1341847986.762616.png",(640,640),device='cuda')
tensor2b,img2b=read_img_as_tensor(r"D:\SuperGluePretrainedNetwork-master\assets\freiburg_sequence\1341847987.758741.png",(640,640),device='cuda')

keypoints1,scores1,descriptors1=sp_model(tensor2a)
keypoints2,scores2,descriptors2=sp_model(tensor2b)

yanse=(0,0,255)
points=keypoints1[0].int().cpu().numpy()
for i in range(len(points)):
    X,Y=points[i]
    cv2.circle(img2a,(X,Y),3,yanse,2)

points=keypoints2[0].int().cpu().numpy()
for i in range(len(points)):
    X,Y=points[i]
    cv2.circle(img2b,(X,Y),3,yanse,2)
myimshows([img2a,img2b])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/353647.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计讯物联污染源自动监控系统,坚守“绿水青山就是金山银山”

近年来&#xff0c;“绿水青山就是金山银山”的理念在全国各地落地生根&#xff0c;各大城市积极构建环境监测体系&#xff0c;旨在让生态文明成色更足&#xff0c;绿色发展底色更亮。计讯物联污染源自动监控系统作为生态环境部门监督企业排污的“火眼金睛”&#xff0c;充分运…

apifox持续集成+java+企微机器人+xxljob定时推送

总览&#xff1a; apifox做接口测试后&#xff0c;把用例合并组装成测试套件&#xff0c;然后apifox-cli通过终端命令实现把套件执行后&#xff0c;输出本地文件的测试报告html或json。本地解析后拿到有用的解决通过定时执行推送到企微群里。 然后把html一起推到群里。 这个…

【Spark分布式内存计算框架——Spark SQL】8. Shuffle 分区数目、Dataset(上)

4.4 Shuffle 分区数目 运行上述程序时&#xff0c;查看WEB UI监控页面发现&#xff0c;某个Stage中有200个Task任务&#xff0c;也就是说RDD有200分区Partition。 原因&#xff1a;在SparkSQL中当Job中产生Shuffle时&#xff0c;默认的分区数&#xff08;spark.sql.shuffle.p…

基于STM32采用CS创世 SD NAND(贴片SD卡)完成FATFS文件系统移植与测试

一、前言 在STM32项目开发中&#xff0c;经常会用到存储芯片存储数据。 比如&#xff1a;关机时保存机器运行过程中的状态数据&#xff0c;上电再从存储芯片里读取数据恢复&#xff1b;在存储芯片里也会存放很多资源文件。比如&#xff0c;开机音乐&#xff0c;界面上的菜单图…

Selenium + python自动化测试环境搭建

selenium 是一个web的自动化测试工具&#xff0c;不少学习功能自动化的同学开始首选selenium &#xff0c;相因为它相比QTP有诸多有点&#xff1a; 免费&#xff0c;也不用再为破解QTP而大伤脑筋 小巧&#xff0c;对于不同的语言它只是一个包而已&#xff0c;而QTP需要下载安…

JSON字符串解析

目录 依赖 方法 示例 判断JSON是否合格 依赖 方法 JSON.parseObject() JSON.parseArray() 示例 Data public class OrderVo {public String name;public Integer price;public Integer count; } JSON数据 { "name": "苹果手机", "pric…

BIT.8_Linux 多线程

目录Linux线程概念什么是线程线程的优点线程的缺点线程异常线程用途Linux进程VS线程进程和线程总结Linux线程控制POSIX线程库创建线程线程ID及进程地址空间布局进程和线程ID区别内核层面&#xff1a;pid & tgid线程终止线程等待__thread 和 pthread_self()分离线程Linux线程…

《爆肝整理》保姆级系列教程python接口自动化(十七)--Json 数据处理---一次爬坑记(详解)

简介 有些 post 的请求参数是 json 格式的&#xff0c;这个前面发送post 请求里面提到过&#xff0c;需要导入 json模块处理。现在企业公司一般常见的接口因为json数据容易处理&#xff0c;所以绝大多数返回数据也是 json 格式的&#xff0c;我们在做判断时候&#xff0c;往往只…

Guava常用工具类总结

-“Null的含糊语义让人很不舒服。Null很少可以明确地表示某种语义&#xff0c;例如&#xff0c;Map.get(key)返回Null时&#xff0c;可能表示map中的值是null&#xff0c;亦或map中没有key对应的值。Null可以表示失败、成功或几乎任何情况。使用Null以外的特定值&#xff0c;会…

每日学术速递2.17

CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 Subjects: cs.LG 1.Decoupled Model Schedule for Deep Learning Training 标题&#xff1a;深度学习训练的解耦模型时间表 作者&#xff1a;Hongzheng Chen, Cody Hao Yu, Shuai Zheng, Zhen Zhang,…

快速识别台式机的内存条

拿上一根内存条&#xff0c;让一个喜欢IT的识别一下&#xff0c;很多人不一定能说出点内容。 这很正常&#xff0c;IT细分领域太多了&#xff0c;很多搞IT的包括写代码的人可能都没有接触内存条。 硬件的集成度随着硬件技术的提升越来越高&#xff0c;成本也下来了&#xff0c;…

支付宝支付详细流程

1、二维码的生成二维码生成坐标 <!-- zxing生成二维码 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.3.3</version></dependency><dependency><groupId>co…

nvm 控制 node版本

nvm 官网 https://nvm.uihtm.com/ 1、卸掉nodejs&#xff0c;根据官网操作 2、如果之前安装过的nodejs,且安装的目录改变了&#xff0c;需重新配置系统环境 第一步&#xff1a;打开此电脑 > 右键属性 > 高级系统设置 > 环境变量 第二步&#xff1a; 在系统变量中选中…

新手健身准备哪些物品,健身必备蓝牙运动耳机分享

第一次运动健身应该准备什么&#xff1f;运动耳机是一款必备的装备&#xff0c;可以让我们坚持运动的动力源泉&#xff0c;在健身当中远离枯燥乏味&#xff0c;有音乐的加持下健身能够让我们更具动力&#xff0c;有哪些值得入手的蓝牙运动耳机分享呢&#xff1f;看看下面这写分…

Java反射概述

2 反射 2.1 反射概述 Java反射机制&#xff1a;是指在运行时去获取一个类的变量和方法信息。然后通过获取到的信息来创建对象,调用方法的一种机制。由于这种动态性,可以极大的增强程序的灵活性,程序不用在编译期就完成确定,在运行期仍然可以扩展 2.2 反射获取Class类的对象 …

企业的知识文档管理系统需要注重什么?安全和共享能力很重要!

编者按&#xff1a;本文指出了企业的文档管理系统比较注重的能力&#xff0c;并从知识共享和文档安全两方面介绍了老厂商天翎是如何在这块实践的。关键词&#xff1a;知识共享&#xff0c;知识安全&#xff0c;标签分类&#xff0c;智能检索&#xff0c;资料分享&#xff0c;在…

element ui 下拉菜单组件 结合springboot 实现省市区简易三级联动 动态查询 并修改地点的省市区

目录 前言&#xff1a; 一.数据库表结构&#xff1a; 二.下拉菜单组件 2.1 效果展示 2.2下拉菜单的组件代码&#xff1a; 前言&#xff1a; 本篇博客&#xff0c;通过官网的学习&#xff0c;实现下拉菜单动态数据的传递与点击事件&#xff0c;如果只是按部就班的按照官网来…

29岁从事功能测试被辞,面试2个月都找不到工作吗?

最近一个28岁老同学联系我&#xff0c;因为被公司辞退&#xff0c;找我倾诉&#xff0c;于是写下此文。 他是14年二本毕业&#xff0c;在我的印象里人特别懒&#xff0c;不爱学习&#xff0c;专业不好&#xff0c;毕业前因为都没找到合适工作&#xff0c;直接去创业了&#xf…

03:入门篇 - CTK Plugin Framework 基本原理

作者: 一去、二三里 个人微信号: iwaleon 微信公众号: 高效程序员 CTK Plugin Framework 技术是面向 C++ 的动态模型系统。该系统允许插件之间的松散耦合,并且提供了设计良好的方式来进行功能和数据的交互。此外,它没有预先对插件施加限制,这样就可以很容易地将插件的相关…

研报精选230217

目录 【行业230217毕马威】奢侈品行业新气象【行业230217国信证券】医药生物行业2023年2月投资策略&#xff1a;持续关注疫后复苏和创新两大主线【行业230217国金证券】航空锻造&#xff1a;稳定格局筑专业化壁垒&#xff0c;顺势而为拓产业链深度【个股230217西南证券_招商轮船…