PyTorch基于注意力的目标检测模型DETR

news2024/11/14 20:08:32

【图书推荐】《PyTorch深度学习与计算机视觉实践》-CSDN博客

目标检测是计算机视觉领域的一个重要任务,它的目标是在图像或视频中识别并定位出特定的对象。在这个过程中,需要确定对象的位置和类别,以及可能存在的多个实例。

DETR模型通过端到端的方式进行目标检测,即从原始图像直接检测出目标的位置和类别,而不需要进行区域提议或特征金字塔等步骤。

DETR模型的核心思想是将目标检测任务转换为一个序列到序列的问题。它将输入图像视为一个序列,并使用Transformer编码器将其转换为一种可被解码器理解的形式。具体来说,DETR模型使用CNN来提取图像特征,然后将其输入Transformer编码器中进行处理。再使用一个Transformer解码器来逐步解码出目标的位置和类别。完整的DETR的架构如图13-11所示。

图13-11  完整的DETR模型架构

下面借用在13.2节中实现的DETR目标检测模型进行讲解。完整的DETR模型代码如下:

import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):
    def __init__(self,num_classes = 92,hidden_dim=256,nheads=8,num_encoder_layers=6,num_decoder_layers=6):
        super().__init__()
        #创建ResNet-50的骨干(backbone)网
        with torch.no_grad():
            self.backbone = resnet50()
            #清除ResNet-50骨干网最后的全连接层
            del self.backbone.fc
        #创建转换层,1×1的卷积,主要起到改变通道大小的作用
        self.conv = nn.Conv2d(2048,hidden_dim,1)
        #利用PyTorch内嵌的类创建Transformer实例
        self.transformer = nn.Transformer(hidden_dim,nheads,num_encoder_layers,num_decoder_layers)
        #预测头,多出的类别用于预测non-empty slots
        self.linear_class = nn.Linear(hidden_dim,num_classes)
        self.linear_bbox = nn.Linear(hidden_dim,4)
        # 输出检测槽编码(object queries)
        self.query_pos = nn.Parameter(torch.rand(100,hidden_dim))
        #可学习的位置编码,用于指导输入图形的坐标
        self.row_embed = nn.Parameter(torch.rand(50,hidden_dim//2))
        self.col_embed = nn.Parameter(torch.rand(50,hidden_dim//2))
        self._reset_parameters()

    def forward(self,inputs):
        #将ResNet-50网络作为backbone
        x = self.backbone.conv1(inputs)       
        x = self.backbone.bn1(x)                
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)      
        x = self.backbone.layer1(x)             
        x = self.backbone.layer2(x)             
        x = self.backbone.layer3(x)             
        x = self.backbone.layer4(x)     	#将ResNet-50网络作为backbone

        #从2048维度转换到可被Transformer接受的256维特征平面
        h = self.conv(x)                                        
        #(1,2048,25,34)->(1,hidden_dim,25,34)
        # 构建位置编码
        B,C,H,W = h.shape
        #创建一个可训练的与输入向量同样维度的位置向量,与原始的DETR的不同之处在于这里的位置向量是可训练的
        pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H,1,1),self. row_embed[:H].unsqueeze(1).repeat(1,W,1),],dim=-1).flatten(0,1).unsqueeze(1)
		
	   #将图像特征与位置信息进行合并
        src = pos+0.1*h.flatten(2).permute(2,0,1)
        #创建查询函数
        tgt = self.query_pos.unsqueeze(1).repeat(1,B,1)
        #通过Transformer继续前向传播
        #参数1:(h*w,batch_size,256),参数2:(100,batch_size,hidden_dim)
        #输出:(hidden_dim,100)-->(100,hidden_dim)
        h = self.transformer(src,tgt).transpose(0,1)
        #将Transformer的输出投影到分类标签及边界框
        return {'pred_logits':self.linear_class(h),'pred_boxes': self.linear_bbox(h).sigmoid()}

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

从上面模型架构的实现代码上来看,整体DETR设计较为简单,可以分为3个主要部分:backbone、Transfomer和FFN。

1. backbone组件

backbone是DETR模型的第一部分,主要用于在图像上提取特征,生成特征图。这些特征图将作为输入传递给Transformer Encoder。backbone通常使用类似于ResNet或CNN模型来提取特征。

DETR将Resnet50作为backbone进行特征抽取,这样做的目的是可以直接使用PyTorch 2.0中提供的预训练模型和权重,从而节省了训练时间。

2. Transformer构成

Transformer是DETR模型的第二部分,它是由编码器和解码器构成,如图13-12所示。

编码器用于对backbone输出的特征图进行编码。这个编码过程主要是通过多头自注意力机制实现的。在DETR模型中,每个多头自注意力之前都使用了位置编码,这种位置编码方式可以帮助模型更好地理解图像中的空间信息。

图13-12  DETR中的Transformer组件

3. 分类器FFN

FFN一般使用两个全连接层作为分类器,其作用是对基于Transformer编码和查询后的特征向量进行分类计算,代码如下:

{'pred_logits':self.linear_class(h),'pred_boxes':self.linear_bbox(h).sigmoid()}

这里的self.linear_class和linear_bbox分别是对查询结果类别和位置的计算,分别用于预测分类和边界框回归。

以上就是对DETR模型的讲解。可以看到,DETR模型在架构设计上并没有太过于难懂的部分,可以认为是前面所学知识的集成。DETR在目标检测上的成功除了模型的设计外,还有一个重大创新就是开创性地提出了新的损失函数,目标检测中的损失函数通常由两部分组成:类别损失和边界框损失。对于类别损失,一般采用交叉熵损失函数,而在边界框损失方面,一般采用L1或L2损失函数。然而,DETR算法采用了不同的方式来计算类别损失和边界框损失。

DETR算法中的损失函数采用了基于二部图匹配的方式进行计算。具体来说,该算法首先将ground truth和预测的bounding box进行匹配,然后通过对比匹配结果和真实标签之间的差异来计算损失值。

《PyTorch深度学习与计算机视觉实践(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书 (jd.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1944497.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2.3 大模型硬件基础:AI芯片(上篇) —— 《带你自学大语言模型》系列

本系列目录 《带你自学大语言模型》系列部分目录及计划,完整版目录见:带你自学大语言模型系列 —— 前言 第一部分 走进大语言模型(科普向) 第一章 走进大语言模型 1.1 从图灵机到GPT,人工智能经历了什么&#xff1…

vue3 学习笔记17 -- 基于el-menu封装菜单

vue3 学习笔记17 – 基于el-menu封装菜单 前提条件:组件创建完成 配置路由 // src/router/index.ts import { createRouter, createWebHashHistory } from vue-router import type { RouteRecordRaw } from vue-router export const Layout () > import(/lay…

FlutterFlame游戏实践#16 | 生命游戏 - 编辑与交互

theme: cyanosis 本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究! Flutter\&Flame 游戏开发系列前言: 该系列是 [张风捷特烈] 的 Flame 游戏开发教程。Flutter 作为 全平台 的 原生级 渲…

ios 15-16手机绕过ssl验证(抓取app上的https包)

绕过ssl验证的基本流程 前提概要:为什么你的charles抓不了https包 ios 越狱ios rootful安装ios 越狱商店sileo安装substitute越狱商店安装SSL Kill Switch3 全流程坑点巨多,博主亲身踩坑,务必按着步骤来 准备工作 type b to c 的数据线苹果…

读论文《Hi-Net: Hybrid-fusion Network for Multi-modalMR Image Synthesis》

论文题目:Hi-Net:用于多模态磁共振图像合成的混合融合网络 论文地址:arxiv 项目地址:github 原项目可能在训练的时候汇报version的错,这是因为生成器和辨别器的优化有些逻辑错误,会改的话多加一个生成操作可以解决&…

数字信号处理基础知识(二)

在介绍完“离散时间序列”基本概念和性质后,实际上就已经踏入了“数字信号处理”这门学科的学习征程,这篇文章里主要去说明“线性时不变系统”的定义概念和探讨“周期采样”的注意细节,相信更加理解这些概念定义和底层逻辑,对于大…

python+vue3+onlyoffice在线文档系统实战20240723笔记,项目界面设计和初步开发

经过之前的学习,已经能够正常打开文档了。 目前为止,我们的代码能够实现: 打开文档编辑文档手动保存自动保存虽然功能依然比较少,但是我们已经基本实现了文档管理最核心的功能,而且我们有个非常大的优势,就是支持多人同时在线协同编辑。 现在我们要开发项目,我们得做基…

Golang | Leetcode Golang题解之第279题完全平方数

题目: 题解: // 判断是否为完全平方数 func isPerfectSquare(x int) bool {y : int(math.Sqrt(float64(x)))return y*y x }// 判断是否能表示为 4^k*(8m7) func checkAnswer4(x int) bool {for x%4 0 {x / 4}return x%8 7 }func numSquares(n int) i…

Python的注释怎么写

今天我们讲一下Python的注释怎么写,Python的注释的写法主要就是用""" (注释)"""和 #(注释(多半就是一行)) 来写 第一种: 使用""" &…

【linux】Shell脚本三剑客之sed命令的详细用法攻略

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全…

从零开始:神经网络(1)——什么是人工神经网络

声明:本文章是根据网上资料,加上自己整理和理解而成,仅为记录自己学习的点点滴滴。可能有错误,欢迎大家指正。 人工神经网络(Artificial Neural Network,简称ANN)是一种模仿生物神经网络结构和功…

【vue教程】三. 组件复用和通信(7 种方式)

目录 本章涵盖知识点回顾 组件开发与复用组件的创建和注册全局定义局部定义单文件组件(.vue 文件)组件的注册方式在实例中注册在 Vue 中注册 组件的 props定义 props传递 props 组件事件自定义事件的创建和触发父组件监听子组件事件父组件处理事件 Vue 实…

网格布局 HTML CSS grid layout demo

文章目录 页面效果代码 (HTML CSS)参考 页面效果 代码 (HTML CSS) <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"…

Golang | Leetcode Golang题解之第275题H指数II

题目&#xff1a; 题解&#xff1a; func hIndex(citations []int) int {n : len(citations)return n - sort.Search(n, func(x int) bool { return citations[x] > n-x }) }

你了解GD32 MCU上下电要求吗

你了解GD32 MCU的上下电要求吗&#xff1f;MCU的上下电对于系统的稳定运行非常重要。 以GD32F30X为例&#xff0c;上电/掉电复位波形如如下图所示。 上电过程中&#xff0c;VDD/VDDA电压上电爬坡&#xff0c;当电压高于VPOR&#xff08;上电复位电压&#xff09;MCU开始启动&a…

设计测试用例的具体方法

一.等价类 等价类分为: 1.有效等价类 [6~15] 2.无效等价类 :小于6位,大于15位(不在数据范围内) 组合规则: 有效等价类组合的时候,尽可能一条测试用例尽可能多的覆盖有效等价类 无效等价类组合的时候,一条测试点,之恶能覆盖一个无效等价类 二.边界值 1.上点,离点,内点 上…

科技引领水资源管理新篇章:深入剖析智慧水利解决方案,展现其在提升水资源利用效率、优化水环境管理方面的创新实践

本文关键词&#xff1a;智慧水利、智慧水利工程、智慧水利发展前景、智慧水利技术、智慧水利信息化系统、智慧水利解决方案、数字水利和智慧水利、数字水利工程、数字水利建设、数字水利概念、人水和协、智慧水库、智慧水库管理平台、智慧水库建设方案、智慧水库解决方案、智慧…

git clone超时的解决方法

问题描述&#xff1a;在克隆一个仓库的时候&#xff0c;报错如下 git clone https://github.com/TeamWiseFlow/wiseflow.git Cloning into wiseflow... fatal: unable to access https://github.com/TeamWiseFlow/wiseflow.git/: Failed to connect to github.com port 443 aft…

【PyTorch】图像二分类项目

【PyTorch】图像二分类项目 【PyTorch】图像二分类项目-部署 【PyTorch】图像多分类项目 【PyTorch】图像多分类项目部署 图像分类是计算机视觉中的一项重要任务。在此任务中&#xff0c;我们假设每张图像只包含一个主对象。在这里&#xff0c;我们的目标是对主要对象进行分类。…

C#开源、简单易用的Dapper扩展类库 - Dommel

项目特性 Dommel 使用 IDbConnection 接口上的扩展方法为 CRUD 操作提供了便捷的 API。 Dommel 能够根据你的 POCO 实体自动生成相应的 SQL 查询语句。这大大减少了手动编写 SQL 代码的工作量&#xff0c;并提高了代码的可读性和可维护性。 Dommel 支持 LINQ 表达式&#xff…