AOT源码解析4.4 -decoder生成预测mask并计算loss

news2024/11/16 10:26:25

3、生成ref_imgs的预测mask和loss

这一步在训练阶段调用

3.1 数据处理

在这里插入图片描述

图1,如图1所示,将enc_embs的最后一个比例的特征图和有ref_imgs相关的特征图得到的LSTT特征图相拼接作为输入

        curr_enc_embs = self.curr_enc_embs
        curr_lstt_embs = self.curr_lstt_output[0]

        pred_id_logits = self.AOT.decode_id_logits(curr_lstt_embs,
                                                   curr_enc_embs)

3.2 Decoder结构

在这里插入图片描述

图2, decoder的操作步骤如图,该解码器将enc_embs各个比例的特征图结合到一起

  • Decoder结构
class FPNSegmentationHead(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 decode_intermediate_input=True,
                 hidden_dim=256,
                 shortcut_dims=[24, 32, 96, 1280],
                 align_corners=True):
        super().__init__()
        self.align_corners = align_corners

        self.decode_intermediate_input = decode_intermediate_input

        self.conv_in = ConvGN(in_dim, hidden_dim, 1)

        self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3)
        self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3)
        self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3)

        self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1)
        self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1)
        self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1)

        self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1)

        self._init_weight()

    def forward(self, inputs, shortcuts):

        if self.decode_intermediate_input:
            x = torch.cat(inputs, dim=1)
        else:
            x = inputs[-1]

        x = F.relu_(self.conv_in(x))
        s1 = self.adapter_16x(shortcuts[-2])
        x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x))

        x = F.interpolate(x,
                          size=shortcuts[-3].size()[-2:],
                          mode="bilinear",
                          align_corners=self.align_corners)
        x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x))

        x = F.interpolate(x,
                          size=shortcuts[-4].size()[-2:],
                          mode="bilinear",
                          align_corners=self.align_corners)
        x = F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x))

        x = self.conv_out(x)

        return x

3.3 计算loss

在这里插入图片描述

  • 对Decoder输出的结果按照对象数量进行分隔
        pred_id_logits = self.pred_id_logits

        pred_id_logits = F.interpolate(pred_id_logits,
                                       size=gt_mask.size()[-2:],
                                       mode="bilinear",
                                       align_corners=self.align_corners)

        label_list = []
        logit_list = []
        for batch_idx, obj_num in enumerate(self.obj_nums):
            now_label = gt_mask[batch_idx].long()
            now_logit = pred_id_logits[batch_idx, :(obj_num + 1)].unsqueeze(0)
            label_list.append(now_label.long())
            logit_list.append(now_logit)
  • 计算loss

在深度学习中,尤其是在图像相关的任务(如图像分割)中,我们通常有大量的像素需要预测。在这种情况下,可能并不是所有的像素对最终的任务都同样重要。
例如,模型可能已经能够很好地预测图像的大部分区域,但是对于一些难以区分的区域(如物体边缘或小物体)预测得不够好。这些难以预测的区域可能正是模型需要关注的重点。

为了使模型更加关注这些难以预测的区域,可以采用一种称为“硬例挖掘”(hard example mining)的技术。这种方法的基本思想是,不是对所有的像素平均地计算损失,而是只关注那些损失最大的像素。

通过这种方式,模型的训练可以更加集中在那些难以正确预测的像素上,从而提高模型的整体性能。具体来说,“top k percent pixels” 指的是按照损失值从高到低排序后,选取前 k 百分比的像素。例如,如果 k 设置为 50%,那么在损失计算中,只会考虑损失最大的前 50% 的像素。

在代码中,这通常是通过以下步骤实现的:

  • 计算所有像素的损失。
  • 根据损失值对像素进行排序。
  • 选择损失值最高的前 k 百分比的像素。
  • 只计算这些选定像素的损失,并将它们加起来作为最终的损失。
class CrossEntropyLoss(nn.Module):
    def __init__(self,
                 top_k_percent_pixels=None,
                 hard_example_mining_step=100000):
        super(CrossEntropyLoss, self).__init__()
        self.top_k_percent_pixels = top_k_percent_pixels
        if top_k_percent_pixels is not None:
            assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1)
        self.hard_example_mining_step = hard_example_mining_step + 1e-5
        if self.top_k_percent_pixels is None:
            self.celoss = nn.CrossEntropyLoss(ignore_index=255,
                                              reduction='mean')
        else:
            self.celoss = nn.CrossEntropyLoss(ignore_index=255,
                                              reduction='none')


    def forward(self, dic_tmp, y, step):
        total_loss = []
        for i in range(len(dic_tmp)):
            pred_logits = dic_tmp[i]
            gts = y[i]
            if self.top_k_percent_pixels is None:
                final_loss = self.celoss(pred_logits, gts)
            else:
                # Only compute the loss for top k percent pixels.
                # First, compute the loss for all pixels. Note we do not put the loss
                # to loss_collection and set reduction = None to keep the shape.
                num_pixels = float(pred_logits.size(2) * pred_logits.size(3))
                pred_logits = pred_logits.view(
                    -1, pred_logits.size(1),
                    pred_logits.size(2) * pred_logits.size(3))
                gts = gts.view(-1, gts.size(1) * gts.size(2))
                pixel_losses = self.celoss(pred_logits, gts)
                if self.hard_example_mining_step == 0:
                    top_k_pixels = int(self.top_k_percent_pixels * num_pixels)
                else:
                    ratio = min(1.0,
                                step / float(self.hard_example_mining_step))
                    top_k_pixels = int((ratio * self.top_k_percent_pixels +
                                        (1.0 - ratio)) * num_pixels)
                top_k_loss, top_k_indices = torch.topk(pixel_losses,
                                                       k=top_k_pixels,
                                                       dim=1)

                final_loss = torch.mean(top_k_loss)
            final_loss = final_loss.unsqueeze(0)
            total_loss.append(final_loss)
        total_loss = torch.cat(total_loss, dim=0)
        return total_loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2169059.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

卷轴模式商城APP开发搭建全流程解析

卷轴模式商城APP的开发搭建是一个综合性强、涉及多个关键步骤和技术环节的过程。本文将详细介绍从需求分析到最终发布的各个阶段&#xff0c;旨在为开发者renxb001提供一个清晰的开发指导方案。 一、需求分析 目标用户群体&#xff1a;首先&#xff0c;明确APP的目标用户&…

openKylin--安装 .net6.0

编辑profile文件 cd .. //切换到根目录 cd /etc //切换到etc目录 vim profile //b编辑profile文件 1. 按→键移动到文件末尾 2. 按Insert键进入编辑模式 3. 按Enter另起一行开始编辑 export DOTNET_ROOT/home/dotnetexport PATH$PATH:/home/dotnet 可以通过右键--粘贴 的…

基于skopt的贝叶斯优化基础实例学习实践

贝叶斯方法是非常基础且重要的方法&#xff0c;在前文中断断续续也有所介绍&#xff0c;感兴趣的话可以自行移步阅读即可&#xff1a; 《数学之美番外篇&#xff1a;平凡而又神奇的贝叶斯方法》 《贝叶斯深度学习——基于PyMC3的变分推理》 《模型优化调参利器贝叶斯优化bay…

Brave编译指南2024 MacOS篇-引言与准备工作(一)

引言 随着互联网隐私和安全问题日益突出,用户对安全浏览器的需求不断增加。Brave浏览器作为一款注重隐私保护和性能优化的开源浏览器,吸引了越来越多开发者的关注。本系列文章将详细介绍如何在MacOS环境下编译Brave浏览器,为有兴趣深入了解和定制Brave的开发者提供指导。 1. …

【智能控制】16章 基于Hopfield网络的路径优化,TSP问题

目录 15.6 基于Hopfield网络的路径优化 15.6.1 TSP问题 15.6.2 求解TSP问题的Hopfield神经网络设计 15.6 基于Hopfield网络的路径优化 15.6.1 TSP问题 旅行商问题&#xff08;Traveling Salesman Problem&#xff0c;简称TSP&#xff09;可描述为&#xff1a;已知N个城市之…

CloudMusic:免费听歌

本文所涉及所有资源均在 传知代码平台可获取。 目录 概述 演示效果 视频演示 图片展示 核心逻辑 获取歌曲图片 提取搜索结果 使用方式 部署方式 Docker部署1 构建镜像 Web站点部署2 附件下载 概述 CloudMusic是一款全网歌曲免费听的web项目&#xff0c;无需任何数据库&#x…

如何隐藏Windows10「安全删除硬件」里的USB无线网卡

本方法参照了原文《如何隐藏Windows10「安全删除硬件」里的USB无线网卡》里面的方法&#xff0c;但是文章中的描述我的实际情况不太一样&#xff0c;于是我针对自己的实际情况进行了调整&#xff0c;经过测试可以成功隐藏Windows10「安全删除硬件」里的USB无线网卡。 先说一下…

QT学习笔记之文件操作

你千万不要跟任何人谈起任何事。你只要一谈起&#xff0c;就会想念起每一个人来。 在ui界面添加一个LineEdit(lEt)、QPushButton(btn)、QWidget widget.cpp #include "widget.h" #include "ui_widget.h" #include <QFile> #include <QFileDialo…

node.js从入门到快速开发一个简易的web服务器

浏览器中JavaScript学习路径: JavaScript基础语法浏览器内置API(DOMBOM)第三方库(jQuery,art-template等) Node.js的学习路径 JavaScript基础语法Node.js内置API模块(fs、path、http等)第三方API模块(express、mysql等) Node.js安装 通过Node.js 来运行Javascript 代码&am…

坝上草原与闪电湖多伦湖自驾行程记录与攻略

本文介绍河北坝上草原、内蒙古多伦湖2天2夜自驾自由行&#xff08;坝上草原1日、多伦湖1日&#xff09;的每日详细行程、游览心得、避坑经历等。 2024年09月中秋节期间&#xff0c;我们一行4人从北京出发&#xff0c;自驾前往河北省与内蒙古自治区等2地&#xff0c;进行了一共为…

几个可以给pdf加密的方法,pdf加密详细教程。

几个可以给pdf加密的方法&#xff0c;pdf加密详细教程。在信息快速传播的今天&#xff0c;PDF文件已经成为重要的文档格式&#xff0c;被广泛应用于工作、学习和个人事务中。然而&#xff0c;随着数字内容的增加&#xff0c;数据安全和隐私保护的问题愈发凸显。无论是商业机密、…

高级算法设计与分析 学习笔记9 跳表

单链表的样子我们很熟悉了&#xff1a; 怎么加快查找&#xff1f;&#xff1a; 查找的具体方法&#xff1a; 超过了就回头下去。 这条“快速路”最好是几个节点呢&#xff1f;&#xff1a; 假如我们弄好多层跳表呢&#xff1f;&#xff1a; 给弄成2叉树了&#xff01; 如何插入…

设计模式、系统设计 record part01

技术路线&#xff1a; 工程师》设计师》分析师》架构师 管理路线&#xff1a; 项目经理》技术经理 工程师&#xff1a; 编程技术、测试技术 设计师&#xff1a; 工程师设计技术 分析师&#xff1a; 设计师分析技术 架构师&#xff1a; 分析师架构技术 项目经理&#xff1a; 时间…

发掘3D文件格式的无限潜力:打造沉浸式虚拟世界

在当今数字化时代&#xff0c;3D技术的应用范围日益广泛&#xff0c;涵盖电影后期制作、产品原型设计、虚拟现实&#xff08;VR&#xff09;、增强现实&#xff08;AR&#xff09;、游戏等众多领域。而3D文件格式作为3D技术的核心组成部分&#xff0c;对于实现3D数据和模型的存…

【linux进程】进程状态僵尸进程孤儿进程

目录 一&#xff0c;linux下的特定进程状态1. R状态 vs S状态2. T状态 vs t 状态3. D状态 vs S状态 二&#xff0c;OS中的进程状态1. 运行状态2. 阻塞状态3. 挂起状态 三&#xff0c;僵尸进程和孤儿进程1. 僵尸状态和僵尸进程2. 孤儿进程 一&#xff0c;linux下的特定进程状态 …

kafka分区和副本的关系?

概念来一波 比如一个topic的消息存放在两个分区中&#xff0c;分区1和分区2.每个分区都有自己的一个副本。即比如分区1有副本1/副本2/副本3&#xff0c;分区2也有分区2的副本1/副本2/副本3。一个节点上的一个topic的可以由多个分区存放&#xff0c;但是每个分区的leader副本会尽…

丹摩智算平台部署 Llama 3.1:实践与体验

文章目录 前言部署前的准备创建实例 部署与配置 Llama 3.1使用心得总结 前言 在最近的开发工作中&#xff0c;我有机会体验了丹摩智算平台&#xff0c;部署并使用了 Llama 3.1 模型。在人工智能和大模型领域&#xff0c;Meta 推出的 Llama 3.1 已经成为了目前最受瞩目的开源模…

manim中文字和目标的对齐方法的使用

为什么要文字对齐 &#xff1f; 对齐原则在现实生活中无处不在&#xff0c;比如&#xff1a;书籍、货架、地铁座位等等&#xff1b;对齐的目的其实就是在规整文案信息&#xff0c;对齐有利于信息传达以及视觉规范&#xff0c;当我们做文字编排工作时&#xff0c;要根据构图形…

【计算机网络 - 基础问题】每日 3 题(二十六)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/fYaBd &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 C 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏&…

基于springboot+vue 旅游网站的设计与实现

基于springbootvue 旅游网站的设计与实现 摘 要 互联网发展至今&#xff0c;无论是其理论还是技术都已经成熟&#xff0c;而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播&#xff0c;搭配信息管理工具可以很好地为人们提供服务。针对信息管理混乱&#xff0c…