【扒代码】regression_head.py

news2024/9/30 15:25:24
import torch
from torch import nn

class UpsamplingLayer(nn.Module):
    # 初始化 UpsamplingLayer 类
    def __init__(self, in_channels, out_channels, leaky=True):
        super(UpsamplingLayer, self).__init__()  # 调用基类的初始化方法

        # 初始化一个序列模型,包含卷积层、激活函数和上采样操作
        self.layer = nn.Sequential(
            # 卷积层,用于特征图的卷积操作
            # in_channels 表示输入通道数,out_channels 表示输出通道数
            # kernel_size=3 表示卷积核大小为 3x3
            # padding=1 表示边缘填充,保持特征图尺寸不变
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # 根据参数 leaky 决定使用 LeakyReLU 激活函数还是 ReLU 激活函数
            # LeakyReLU 在正输入值上与 ReLU 相同,但在负输入值上允许一个小的梯度(由 leaky 参数控制)
            nn.LeakyReLU() if leaky else nn.ReLU(),
            # 上采样层,使用双线性插值方法放大特征图
            # scale_factor=2 表示将特征图的尺寸放大两倍
            nn.UpsamplingBilinear2d(scale_factor=2)
        )

    # 前向传播方法,将输入 x 通过定义好的层进行处理
    def forward(self, x):
        return self.layer(x)

功能解释

  • UpsamplingLayer 类接收三个参数:in_channels(输入通道数),out_channels(输出通道数),和 leaky(一个布尔值,决定是否使用 LeakyReLU 激活函数)。
  • 类初始化方法 __init__ 中,使用 nn.Sequential 创建了一个序列模型,它将按照顺序应用里面的层。
  • nn.Conv2d 是一个二维卷积层,用于在卷积神经网络中进行卷积操作。
  • nn.LeakyReLU 是一种激活函数,当输入为正时,它的行为与 nn.ReLU 相同,当输入为负时,它允许一个非零的梯度(由 leaky 参数控制)。
  • nn.UpsamplingBilinear2d 是一个上采样层,使用双线性插值方法来放大特征图的尺寸。
  • forward 方法定义了模型的前向传播,它接收输入 x,并通过 self.layer 中定义的层进行处理,然后返回处理后的结果。

整体而言,UpsamplingLayer 类实现了一个简单的上采样模块,它首先通过卷积层提取特征,然后应用激活函数,最后通过上采样层放大特征图,这在图像分割、特征细化等任务中非常有用。(这个上采样)

import torch
from torch import nn
from .upsamplinglayer import UpsamplingLayer  # 假设 UpsamplingLayer 在当前包中定义

class DensityMapRegressor(nn.Module):
    # 初始化 DensityMapRegressor 类
    def __init__(self, in_channels, reduction):
        super(DensityMapRegressor, self).__init__()  # 调用基类的初始化方法

        # 根据 reduction 参数的不同,构建不同的回归器结构
        if reduction == 8:
            self.regressor = nn.Sequential(
                # 上采样层,将输入通道数 in_channels 上采样到 128
                UpsamplingLayer(in_channels, 128),
                # 继续上采样到 64
                UpsamplingLayer(128, 64),
                # 再上采样到 32
                UpsamplingLayer(64, 32),
                # 最后通过一个 1x1 卷积层将通道数减少到 1,生成密度图
                nn.Conv2d(32, 1, kernel_size=1),
                # 使用 LeakyReLU 激活函数
                nn.LeakyReLU()
            )
        elif reduction == 16:
            self.regressor = nn.Sequential(
                # 与 reduction == 8 类似,但是最后多一个上采样步骤到 16
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                UpsamplingLayer(32, 16),
                nn.Conv2d(16, 1, kernel_size=1),
                nn.LeakyReLU()
            )

        # 初始化模型参数
        self.reset_parameters()

    # 前向传播方法,将输入 x 通过回归器处理
    def forward(self, x):
        return self.regressor(x)

    # 参数重置方法,使用特定的初始化方法初始化模型的权重和偏置
    def reset_parameters(self):
        for module in self.modules():  # 遍历模型中所有的模块
            if isinstance(module, nn.Conv2d):  # 如果模块是二维卷积层
                # 初始化权重为标准正态分布
                nn.init.normal_(module.weight, std=0.01)
                # 如果存在偏置项,则初始化为常数 0
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

功能解释

  • DensityMapRegressor 类用于生成对象计数的密度图,它根据输入的特征图和指定的 reduction 参数来构建一个回归器网络。
  • in_channels 参数指定了输入特征图的通道数。
  • reduction 参数控制了网络中上采样层的数量和最终生成的密度图的分辨率。
  • self.regressor 是一个序列模型,根据 reduction 参数的值,它将构建不同数量的上采样层,最后通过一个 1x1 卷积层输出通道数为 1 的密度图。
  • forward 方法定义了模型的前向传播逻辑,它接收输入 x,并通过 self.regressor 进行处理,返回处理后的密度图。
  • reset_parameters 方法用于初始化模型的参数,这里使用正态分布初始化权重,偏置初始化为常数 0。这是为了在训练开始前给模型一个合理的初始状态。

整体而言,DensityMapRegressor 类实现了一个用于生成密度图的回归网络,它通过一系列上采样层逐步放大特征图的尺寸,并最终生成一个通道数为 1 的密度图,这个密度图可以用于表示图像中对象的分布密度。

什么是上采样?

上采样(Upsampling)是深度学习和计算机视觉中常用的一种技术,用于增加数据的空间分辨率,即增加图像的高度和宽度。上采样通常在特征图(feature maps)经过一系列卷积层后应用,以便恢复图像的空间尺寸或为后续的网络层提供合适尺寸的输入。

上采样的常见方法包括:

  1. 最近邻插值(Nearest Neighbor Interpolation)

    • 这是最简单的上采样方法,通过选择距离最近的像素点的值来填充新像素点。
  2. 双线性插值(Bilinear Interpolation)

    • 这种方法考虑了新像素点周围四个最近像素点的值,并通过线性方式进行插值。
  3. 双三次插值(Bicubic Interpolation)

    • 类似于双线性插值,但使用了更高阶的多项式来提供平滑的插值效果。
  4. 转置卷积(Transposed Convolution)

    • 也称为反卷积,通过卷积操作来增加图像的空间尺寸,同时学习如何填充新像素点的值。
  5. 像素 Shuffle(Pixel Shuffle)

    • 通过重新排列像素来增加图像的分辨率,通常与子像素卷积一起使用。

在这个例子中,新像素点的值是通过考虑周围现有像素点的值并进行加权平均得到的。

上采样在许多深度学习任务中都非常有用,例如在语义分割任务中恢复图像分辨率,或者在生成对抗网络(GANs)中生成高分辨率的图像。通过上采样,模型能够生成更精细的特征表示,有助于提高任务的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1991484.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode 7, 703, 287

文章目录 7. 整数反转题目链接标签思路反转操作反转的数的范围 代码 703. 数据流中的第 K 大元素题目链接标签思路代码 287. 寻找重复数题目链接标签思路代码 7. 整数反转 题目链接 7. 整数反转 标签 数学 思路 反转操作 反转实际上很简单,假设要反转数字 n…

数据结构之Map与Set(上)

找往期文章包括但不限于本期文章中不懂的知识点: 个人主页:我要学编程(ಥ_ಥ)-CSDN博客 所属专栏:数据结构(Java版) 目录 二叉搜索树 Map和Set的介绍与使用 Map的常用方法及其示例 Set的常用方法及其示例 哈希表…

客户管理系统平台(CRM系统)是什么?它的核心主要解决哪些问题?

客户管理系统平台CRM是什么?客户关系管理系统CRM的核心主要解决哪些问题? CRM系统不仅仅是一套软件,更是一种策略,一种管理理念和一种企业发展方向。它通过整合客户数据、优化业务流程、提升客户体验,帮助企业在激烈的…

K8s第三节:k8s1.23.1升级为k8s1.30.0

上回书说到我们使用了kubeadm安装了k8s1.23.1,但是在k8s1.24之前还是使用docker作为容器运行时,所以这一节我打算将我安装的k8s集群升级为1.30.0版本; 1、修改containerd 配置 因为我们安装的docker自带containerd,所以我们不需要重新安装con…

蓝凌EKP二次开发资料大全 完整蓝凌二次开发资料 蓝凌 EKP开发实战教程 蓝凌OA二次开发资料大全 蓝凌OA java开发快速入门

蓝凌EKP二次开发资料大全 完整蓝凌二次开发资料 蓝凌 EKP开发实战教程 蓝凌OA二次开发资料大全 记得两年前花了非常贵的费用去现场学习的资料,把这些开发技术文档分享出来,希望通过这个资料, 为大家学习开发大大减少时间。期待大家能快速上…

《UE5_C++多人TPS完整教程》学习笔记32 ——《P33 动画蓝图(Animation Blueprint)》

本文为B站系列教学视频 《UE5_C多人TPS完整教程》 —— 《P33 动画蓝图(Animation Blueprint)》 的学习笔记,该系列教学视频为 Udemy 课程 《Unreal Engine 5 C Multiplayer Shooter》 的中文字幕翻译版,UP主(也是译者…

Python实战:类

一、圆的面积、周长 class Circle:# 初始化一个类参数:rdef __init__(self,r):self.r r# 计算面积的方法def get_area(self):return 3.14*pow(self.r,2)# 计算周长的方法def get_perimeter(self):return 2*3.14*self.r#创建对象 r eval(input(请输入圆的半径&…

Vue 2 和 Vue 3 生命周期钩子

Vue 2 和 Vue 3 生命周期钩子 在 Vue.js 开发中,了解生命周期钩子对于编写有效的组件至关重要。Vue 2 和 Vue 3 在生命周期钩子上大致相同,但 Vue 3 的 Composition API 引入了一种新的方式来处理它们。这里我会概述两者的生命周期钩子,并指…

2024年8月7日(mysql主从 )

回顾 主服务器 [rootmaster_mysql ~]# yum -y install rsync [rootmaster_mysql ~]# tar -xf mysql-8.0.33-linux-glibc2.12-x86_64.tar [rootmaster_mysql ~]# tar -xf mysql-8.0.33-linux-glibc2.12-x86_64.tar.xz [rootmaster_mysql ~]# cp -r mysql-8.0.33-linux-glibc2.…

QT找不到编辑框

问题展示: 解决办法:ALT0 然后我的变成了这种: 解决办法:文件系统改变成项目:

DNTR——F

文章目录 AbstractIntroductionContribution Related WorkAdvancements in Feature Pyramid Networks (FPNs)Coarse-to-Fine Image Partitioning in Drone Imagery DetectionDevelopments in Loss Function Approaches for Tiny Object DetectionR-CNN for Small Object Detect…

大炼模型进入尾声,“失眠”的欧洲和日本能否扳回一局?

大数据产业创新服务媒体 ——聚焦数据 改变商业 2022年末,ChatGPT-3.5的惊艳亮相,瞬间引爆了全球范围内的生成式AI(GenAI)热潮。 这场现代版的"淘金热"迅速在科技领域蔓延,尤其是在全球两大科技强国——中国…

简单分享下python打包手机app的apk

Python 把python程序打包成apk的完整步骤 1. 引言 在移动应用市场蓬勃发展的今天,开发人员常常需要将自己的Python程序打包成APK文件,以便在Android设备上运行。本文将详细介绍将Python程序打包成APK的完整步骤。 2. 准备工作 在开始打包前&#xff0c…

全网最详解LVS(Linux virual server)

目录 一、LVS(Linux virual server)是什么? 二、集群和分布式简介 2.1、集群Cluster 2.2、分布式 2.3、集群和分布式 三、LVS运行原理 3.1、LVS基本概念 3.2、LVS集群的类型 3.2.1 nat模式 3.2.2 DR模式 3.2.3、LVS工作模式总结 …

RSYSLOG收到华为防火墙日志差8小时的解决方法

RSYSLOG收到华为防火墙日志差8小时 这个问题其实不关Rsyslog配置的事,只要修改华为墙的配置就好 处理方法: info-center loghost 172.18.6.91 language Chinese local-time 在华为web界面添加ip是不会添加local-time这个参数的, 需要在命令…

sqli-labs第二关详解

首先让id1,正常显示,接着尝试and 11和and 12 and 11正常,and 12不正常 所以可以判断是数字型注入,使用order by 判断列数,发现有三个字段 使用union语句,找出能显示信息的地方 接下来就是找出数据库名称和版…

Leetcode75-7 除自身以外数组的乘积

没做出来 本来的思路是遍历一遍得到所有乘积和然后除就行 但是题目不能用除法 答案的思路 for(int i0;i<n;i) //最终每个元素其左右乘积进行相乘得出结果{res[i]*left; //乘以其左边的乘积left*nums[i];res[n-1-i]*right; //乘以其右边的乘积right*nums[n-1-i]…

搭建 Web 群集Haproxy

案例概述 Haproxy 是目前比较流行的一种群集调度工具&#xff0c;同类群集调度工具有很多&#xff0c;如 LVS 和Nginx。相比较而言&#xff0c;LVS 性能最好&#xff0c;但是搭建相对复杂;Nginx 的upstream模块支持群集功能&#xff0c;但是对群集节点健康检查功能不强&#xf…

海量数据处理商用短链接生成器平台 - 10

第二十一章 短链服务冗余双写-链路测试和异常消息处理实战 第1集 冗余双写MQ架构-消费者配置自动创建队列和集群测试 简介&#xff1a; 冗余双写MQ架构-MQ消费者配置自动创建队列 controller-service层开发配置文件配置MQ ##----------rabbit配置-------------- spring.rab…

古彝文——唯一存活的世界六大古文字

关注我们 - 数字罗塞塔计划 - 早在五千年前&#xff0c;彝族的先祖就发明了十月太阳历&#xff0c;成为中华文明的重要创造者之一&#xff1b;同时&#xff0c;彝族的先祖也创制了古彝文&#xff0c;开创了独具特色的彝族文化。古彝文也被称为古夷文、传统彝文&#xff0c;是相…