2023-简单点-yolox-pytorch代码解析(一)-nets/darknet.py

news2024/11/23 13:18:57

yolox-pytorch: nets/darknet.py

  • yolox网络结构
  • yolox-pytorch目录
  • 今天解析注释net/darknet.py
  • Focus
  • BaseConv
  • DWConv
  • SPPBottleneck
  • Darknet
  • 未完待续。。。

yolox网络结构

这里是引用

yolox-pytorch目录

在这里插入图片描述

今天解析注释net/darknet.py

#!/usr/bin/env python3  # 指定使用python3来执行此脚本  
# -*- coding:utf-8 -*-  # 声明脚本使用的编码是utf-8  
# Copyright (c) Megvii, Inc. and its affiliates.  # 版权信息,标注了归属公司为Megvii, Inc.及其附属公司  
  
# 导入torch库,torch是PyTorch的主体库,提供了张量计算和神经网络构建等基础功能  
import torch  
  
# 从torch库中导入nn模块,nn是PyTorch的神经网络模块,提供了各种神经网络层和损失函数等  
from torch import nn  
  ##########################################################################################
# 定义一个名为SiLU的类,继承自nn.Module,这个类代表Sigmoid线性单元,一种非线性激活函数  
class SiLU(nn.Module):  
    # 定义一个静态方法forward,此方法用于定义SiLU类的运算过程,参数x是输入数据  
    @staticmethod  
    def forward(x):   
        # 返回输入数据x乘以sigmoid函数的结果,实现SiLU运算  
        return x * torch.sigmoid(x)  
##########################################################################################  
# 定义一个函数get_activation,用于获取不同类型的激活函数模块  
def get_activation(name="silu", inplace=True):  
    # 根据传入的name参数判断要返回的激活函数类型,默认为SiLU类型  
    if name == "silu":   
        # 如果name为"silu",则创建一个SiLU类型的模块并赋值给module变量  
        module = SiLU()  
    elif name == "relu":   
        # 如果name为"relu",则创建一个ReLU类型的模块,inplace参数表示原地操作,是否会修改输入数据,默认为True  
        module = nn.ReLU(inplace=inplace)  
    elif name == "lrelu":   
        # 如果name为"lrelu",则创建一个LeakyReLU类型的模块,0.1表示负斜率,inplace参数表示原地操作,是否会修改输入数据,默认为True  
        module = nn.LeakyReLU(0.1, inplace=inplace)  
    else:   
        # 如果name不是上述任何一种类型,则抛出一个AttributeError异常,提示"Unsupported act type: {}".format(name)错误信息  
        raise AttributeError("Unsupported act type: {}".format(name))  
    # 返回获取到的激活函数模块  
    return module  
##########################################################################################  
# 定义一个名为Focus的类,继承自nn.Module,这个类代表焦点模块,用于对输入数据进行空间上的重新构造  
class Focus(nn.Module):  
    # 定义一个构造函数__init__,此方法用于初始化Focus类的实例对象  
    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):  
        # 调用父类的构造函数,进行基本的初始化操作  
        super().__init__()  
        # 定义变量pad为(ksize - 1) // 2,表示卷积操作的填充大小  
        pad = (ksize - 1) // 2  
        # 创建一个BaseConv类型的模块,参数为in_channels * 4(输入通道数变为原来的四倍),out_channels(输出通道数),ksize(卷积核大小),stride(步长),pad(填充大小),act(激活函数类型默认为SiLU)  
        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)  
    # 定义一个forward方法,此方法用于定义Focus类的运算过程,参数x是输入数据  
    def forward(self, x):   
        # 从输入数据x中获取四个位置的patch并进行拼接,形成一个新的数据x并返回给conv进行卷积操作
        #考虑2 x 2的四个方格,左上,左下,右上和右下为起始点,各一个元素采样  
        patch_top_left = x[..., ::2, ::2]  # 从左上角开始每隔一个像素取一个像素点形成左上角patch  
        patch_bot_left = x[..., 1::2, ::2]  # 从左下角开始每隔一个像素取一个像素点形成左下角patch  
        patch_top_right = x[..., ::2, 1::2]  # 从右上角开始每隔一个像素取一个像素点形成右上角patch  
        patch_bot_right = x[..., 1::2, 1::2]  # 从右下角开始每隔一个像素取一个像素点形成右下角patch  
        x = torch.cat((patch_top_left, patch_bot_left, patch_top_right,patch_bot_right), dim=1)  # 将四个patch在通道维度上进行拼接  
        # 返回经过conv卷积操作后的结果  
        return self.conv(x)  
##########################################################################################  
# 定义一个名为BaseConv的类,继承自nn.Module,这个类代表基础卷积模块,用于构建卷积神经网络的基础模块  
class BaseConv(nn.Module):  
    # 定义一个构造函数__init__,此方法用于初始化BaseConv类的实例对象  
    def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):  
        # 调用父类的构造函数,进行基本的初始化操作  
        super().__init__()  
        # 定义变量pad为(ksize - 1) // 2,表示卷积操作的填充大小  
        pad = (ksize - 1) // 2  
        # 创建一个Conv2d类型的卷积层,参数依次为输入通道数,输出通道数,卷积核大小,步长,填充大小,分组数,是否使用偏置  
        self.conv = nn.Conv2d(in_channels, out_channels, ksize, stride, pad, groups=groups, bias=bias)
        # 定义一个批归一化层,用于加速训练和提高模型稳定性。  
        self.bn   = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)   
        # 如果传入的act参数不为空字符串,则根据act参数创建一个激活函数模块并赋值给self.act变量  
        if act is not None:  
            self.act = get_activation(act, inplace=True)  
        else:  
            self.act = nn.Identity()  # 如果act参数为空字符串,则使用恒等映射作为激活函数  
    # 定义一个forward方法,此方法用于定义BaseConv类的运算过程,参数x是输入数据  
    def forward(self, x):   
        # 返回经过conv卷积操作和act激活函数处理后的结果  
        return self.act(self.conv(x))
#########################################################################################

# 定义一个名为DWConv的类,继承自nn.Module,这是一个深度可分离卷积模块。  
class DWConv(nn.Module):  
    # 初始化函数,用于设置该模块所需的各种参数和子模块。  
    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):  
        # 调用父类的初始化函数。  
        super().__init__()  
        # 定义一个深度卷积层(逐通道卷积),其输出通道数与输入通道数相同,并且使用给定的核大小和步长。  
        # 注意这里的groups参数设置为in_channels,意味着每个输入通道都有独立的卷积核。  
        self.dconv = BaseConv(in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act,)  
        # 定义一个逐点卷积层(1x1卷积),用于改变通道数。这里的groups参数设置为1,表示是普通的卷积操作。  
        self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)  
  
    # 定义前向传播函数,输入数据x会首先经过深度卷积层处理,然后再经过逐点卷积层处理。  
    def forward(self, x):  
        x = self.dconv(x)  # 数据经过深度卷积层处理。  
        return self.pconv(x)  # 经过逐点卷积层处理后输出结果。
##########################################################################################
# 定义一个名为SPPBottleneck的类,继承自nn.Module,这是一个包含空间金字塔池化(Spatial Pyramid Pooling, SPP)的瓶颈模块。  
class SPPBottleneck(nn.Module):  
    # 初始化函数,用于设置该模块所需的各种参数和子模块。  
    def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):  
        # 调用父类的初始化函数。  
        super().__init__()  
        # 计算隐藏层的通道数,为输入通道数的一半。  
        hidden_channels = in_channels // 2  
        # 定义第一个卷积层,用于降低通道数。这里使用1x1的卷积核。  
        self.conv1      = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)  
        # 定义一个包含多个最大池化层的模块列表,用于生成不同尺度的空间金字塔特征。  
        self.m          = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes])  
        # 计算第二个卷积层的输入通道数,为隐藏层通道数与空间金字塔层数的和。  
        conv2_channels  = hidden_channels * (len(kernel_sizes) + 1)  
        # 定义第二个卷积层,用于增加通道数并整合空间金字塔特征。这里使用1x1的卷积核。  
        self.conv2      = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)  
  
    # 定义前向传播函数,输入数据x会首先经过第一个卷积层处理,然后生成空间金字塔特征并拼接在一起,最后经过第二个卷积层处理并输出结果。  
    def forward(self, x):  
        x = self.conv1(x)  # 数据经过第一个卷积层处理。  
        x = torch.cat([x] + [m(x) for m in self.m], dim=1)  # 生成空间金字塔特征并将它们拼接在一起。  
        x = self.conv2(x)  # 经过第二个卷积层处理后输出结果。  
        return x

##########################################################################################
# 定义一个名为Darknet的类,继承自nn.Module,这个类代表Darknet网络模型,用于构建Darknet系列的目标检测模型  
class Darknet(nn.Module):  
    # 定义一个构造函数__init__,此方法用于初始化Darknet类的实例对象  
    def __init__(self, in_channels, out_channels, **kwargs):  
        # 调用父类的构造函数,进行基本的初始化操作  
        super().__init__()  
        # 根据传入的参数创建多个BaseConv类型的模块并依次添加到self.layers列表中  
        self.layers = nn.ModuleList([BaseConv(in_channels, out_channels, **kwargs) for _ in range(len(kwargs["out_channels"]))])  
        # 将最后一个BaseConv模块的输出通道数赋值给self.out_channels变量  
        self.out_channels = kwargs["out_channels"][-1]  
  
    # 定义一个forward方法,此方法用于定义Darknet类的运算过程,参数x是输入数据  
    def forward(self, x):  
        # 依次将输入数据x传入self.layers列表中的每个BaseConv模块进行处理,并将结果赋值给x  
        for layer in self.layers:  
            x = layer(x)  
        # 返回经过所有BaseConv模块处理后的结果x  
        return x  
##########################################################################################  
# 定义一个名为CSPDarknet的类,继承自nn.Module,这个类代表CSPDarknet网络模型,用于构建CSPDarknet系列的目标检测模型  
class CSPDarknet(nn.Module):  
    # 定义一个构造函数__init__,此方法用于初始化CSPDarknet类的实例对象  
    def __init__(self, in_channels, out_channels, **kwargs):  
        # 调用父类的构造函数,进行基本的初始化操作  
        super().__init__()  
        # 根据传入的参数创建多个BaseConv类型的模块和多个Darknet类型的模块,并依次添加到self.conv和self.blocks列表中  
        self.conv = BaseConv(in_channels, out_channels, **kwargs)  
        self.blocks = nn.ModuleList([Darknet(out_channels, out_channels * 2, **kwargs) for _ in range(len(kwargs["out_channels"]))])  
        # 将最后一个Darknet模块的输出通道数赋值给self.out_channels变量  
        self.out_channels = kwargs["out_channels"][-1] * 2  
  
    # 定义一个forward方法,此方法用于定义CSPDarknet类的运算过程,参数x是输入数据  
    def forward(self, x):  
        # 将输入数据x传入self.conv模块进行处理,并将结果赋值给x  
        x = self.conv(x)  
        # 将处理后的结果x分别传入self.blocks列表中的每个Darknet模块进行处理,并将结果拼接在一起形成一个新的结果x  
        x = torch.cat([block(x) for block in self.blocks], dim=1)  
        # 返回经过所有Darknet模块处理后的结果x  
        return x

Focus

在这里插入图片描述

BaseConv

简而言之:卷积 + 激活【是否】

BaseConv是一个基础卷积类,它继承了PyTorch中的nn.Module类,并实现了卷积、批归一化(Batch Normalization)和激活函数等核心操作。BaseConv类的主要参数包括:

  • in_channels:输入通道数。
  • out_channels:输出通道数。
  • ksize:卷积核大小。
  • stride:步长。
  • groups:分组卷积中的组数,默认为1。
  • bias:是否使用bias(偏差),默认为False。
  • act:激活函数类型,默认为"silu"。

在BaseConv类中,主要实现了以下三个方法:

  • __init__:构造函数,用于初始化BaseConv对象。在构造函数中,会创建一个nn.Conv2d对象(卷积层),一个nn.BatchNorm2d对象(批归一化层)和一个激活函数对象。
  • forward:前向传播函数。输入数据x首先经过卷积层和批归一化层,然后通过激活函数进行激活,最终输出结果。
  • get_activation:获取激活函数。该函数用于获取指定名称的激活函数对象。

总体来说,BaseConv是一个简单但实用的卷积类,可以作为构建其他复杂卷积网络的基础组件。

DWConv

在这里插入图片描述

SPPBottleneck

在这里插入图片描述

Darknet

未完待续。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1266210.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用 Nginx Ingress 快速实现 URL 重写

什么是URL重写 URL重写(URL rewriting)是一种在Web服务器上修改或转换请求URL的过程。它通常涉及使用服务器配置或规则来更改传入的URL,以便在不改变实际请求资源的情况下,实现不同的行为,如重定向、路径映射、参数处…

三大录屏软件推荐,让你轻松录制屏幕

录屏软件的应用变得越来越广泛,无论是记录屏幕上的内容以方便日后查阅,还是与他人分享操作过程,录屏软件都发挥着重要作用。然而,市面上的录屏软件种类繁多,质量参差不齐。那有没有好用的录屏软件推荐呢?在…

U4_2:图论之MST/Prim/Kruskal

文章目录 一、最小生成树-MST生成MST策略一些定义 思路彩蛋 二、普里姆算法(Prim算法)思路算法流程数据存储分析 伪代码时间复杂度分析 三、克鲁斯卡尔算法(Kruskal算法)分析算法流程并查集-Find-set 伪代码时间复杂度分析 一、最…

基于FactoryBean、实例工厂、静态工厂创建Spring中的复杂对象

😉😉 学习交流群: ✅✅1:这是孙哥suns给大家的福利! ✨✨2:我们免费分享Netty、Dubbo、k8s、Mybatis、Spring...应用和源码级别的视频资料 🥭🥭3:QQ群:583783…

同旺科技 USB 转 RS-485 适配器

内附链接 1、USB 转 RS-485 适配器 基础版主要特性有:(非隔离) ● 支持USB 2.0/3.0接口,并兼容USB 1.1接口; ● 支持USB总线供电; ● 支持Windows系统驱动,包含WIN10 / WIN11系统32 / 64位…

基于Java SSM框架+Vue实现汉服文化平台网站项目【项目源码+论文说明】计算机毕业设计

基于java的SSM框架Vue实现汉服文化平台系统演示 摘要 本论文主要论述了如何使用JAVA语言开发一个汉服文化平台网站 ,本系统将严格按照软件开发流程进行各个阶段的工作,采用B/S架构,面向对象编程思想进行项目开发。在引言中,作者将…

Redis高可用集群架构

高可用集群架构 哨兵模式缺点 主从切换阶段, redis服务不可用,高可用不太友好只有单个主节点对外服务,不能支持高并发单节点如果设置内存过大,导致持久化文件很大,影响数据恢复,主从同步性能 高可用集群…

Java第二十章

一.创建线程 1.继承Thread类 Thread类是java.lang 包中的一个类,从这个类中实例化的对象代表线程,程序员启动一个新线程需要建立Thread实例。Thread类中常用的两个构造方法如下: public Thread()://创建一个新的线程对象。 public Thread(String threa…

王者荣耀游戏制作

1.创建所需要的包 2.创建怪物类 bear package beast;import wangzherogyao.GameFrame;public class Bear extends Beast {public Bear(int x, int y, GameFrame gameFrame) {super(x, y, gameFrame);setImg("img/bear.jpg");width 85;height 112;setDis(65);}} b…

倒计时 5 天,您有一份 2023 IoTDB 用户大会参会指南请注意查收!

叮叮!距离 2023 IoTDB 用户大会在北京与大家见面还有 5 天! 这场筹备已久的盛会,汇集了超 20 位大咖嘉宾带来的精彩议题,届时来自美国国家工程院、清华大学软件学院的产业大拿,与能源电力、钢铁冶炼、城轨运输、智能制…

如何使用ArcGIS Pro制作一张北极俯视地图

地图的表现形式有很多种,经常我们看到的地图是以大西洋为中心的地图,还有以太平洋为中心的地图,今天要给大家介绍的地图是从北极上方俯视看的地图,这里给大家讲解一下制作方法,希望能对你有所帮助。 修改坐标系 制作…

AntDB数据库:从海量数据处理,到5G计费商用核心

AntDB数据库自2008年研发面世以来,首先被应用于运营商的核心系统,满足运营商海量数据处理的需求。随着数字科技的不断发展,AntDB也在不断地更新迭代,逐渐地为更多行业与客户提供更全面的服务。5G时代来临,AntDB抓住发展…

webGL开发虚拟实验室

开发虚拟实验室是一个具有挑战性但也非常有趣和有价值的任务。通过 WebGL,你可以创建交互式、沉浸式的虚拟实验室,使用户能够进行实验和学习。以下是一些步骤和关键考虑因素,帮助你开始开发虚拟实验室,希望对大家有所帮助。北京木…

你知道显卡型号上的数字是什么意思吗?数字越大就越好吗?

大家好,欢迎来到我们的显卡探秘之旅!今天,我们将一探究竟——显卡型号上的数字到底是啥意思?是不是数字越大,显卡就越NB?别急,跟着小编一起揭开这个神秘的数字面纱! Q1 显卡的基本概…

技巧-PyCharm中Debug和Run对训练的影响和实验测试

简介 在训练深度学习模型时,使用PyCharm的Debug模式和Run模式对训练模型的耗时会有一些区别。 Debug模式:Debug模式在训练模型时,会对每一行代码进行监视,这使得CPU的利用率相对较高。由于需要逐步执行、断点调试、查看变量值等操…

链接共享平台LinkStack

什么是 LinkStack ? LinkStack 是一个独特的平台,为在线管理和共享链接提供了高效的解决方案。平台提供了一个类似于 Linktree 的网站,它可以让用户克服社交媒体平台上只能添加一个链接的限制。借助 LinkStack,用户可以轻松链接到…

无需提前更新数据源,一键形态选股直接出票——股票量化分析工具QTYX-V2.7.3...

功能概述 我们的股票量化系统QTYX在实战中不断迭代升级!!! 星球学员中的大佬们给QTYX提供了很多实战应用方面的建议,志同道合的一群人一起来优化完善这个系统,日益强大的QTYX同时也能更好地帮助各位在市场中提高战绩! 这个需求是来自于星球学…

“华为不造车 只帮车企造好车“ 那么华为到底造不造车

大家好,我是极智视界,欢迎关注我的公众号,获取我的更多前沿科技分享 邀您加入我的知识星球「极智视界」,星球内有超多好玩的项目实战源码和资源下载,链接:https://t.zsxq.com/0aiNxERDq "华为不造车&a…

W2311283-可燃气体监测仪怎么监测燃气管道

可燃气体监测仪怎么有效监测燃气管道 燃气管道遍布于城市地下各处,作为城市生命线的一部分,一旦燃气管网出现泄露问题便是牵一发而动全身,城市的整体安全也会受到威胁。但是如何才能科学管理和监测燃气管网呢? 燃气管网监测系统便…

Vue3-ElementPlus按需导入

1.安装 pnpm add element-plus 2.配置按需导入: 官方文档:快速开始 | Element Plus 按照官网按需导入中的自动导入步骤来进行 pnpm add -D unplugin-vue-components unplugin-auto-import 观察Vite代码与原vite文件的差别,将原vite文件中没…