【目标检测】理论篇(3)YOLOv5实现

news2025/1/31 11:30:19

Yolov5网络构架实现

import torch
import torch.nn as nn


class SiLU(nn.Module):
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)

def autopad(k, p=None):
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k] 
    return p

class Focus(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Focus, self).__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)

    def forward(self, x):
        # 320, 320, 12 => 320, 320, 64
        return self.conv(
            # 640, 640, 3 => 320, 320, 12
            torch.cat(
                [
                    x[..., ::2, ::2], 
                    x[..., 1::2, ::2], 
                    x[..., ::2, 1::2], 
                    x[..., 1::2, 1::2]
                ], 1
            )
        )

class Conv(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv   = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn     = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        self.act    = SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))

class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(C3, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat(
            (
                self.m(self.cv1(x)), 
                self.cv2(x)
            )
            , dim=1))

class SPP(nn.Module):
    # Spatial pyramid pooling layer used in YOLOv3-SPP
    def __init__(self, c1, c2, k=(5, 9, 13)):
        super(SPP, self).__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

    def forward(self, x):
        x = self.cv1(x)
        return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
        
class CSPDarknet(nn.Module):
    def __init__(self, base_channels, base_depth, phi, pretrained):
        super().__init__()
        #-----------------------------------------------#
        #   输入图片是640, 640, 3
        #   初始的基本通道base_channels是64
        #-----------------------------------------------#

        #-----------------------------------------------#
        #   利用focus网络结构进行特征提取
        #   640, 640, 3 -> 320, 320, 12 -> 320, 320, 64
        #-----------------------------------------------#
        self.stem       = Focus(3, base_channels, k=3)
        
        #-----------------------------------------------#
        #   完成卷积之后,320, 320, 64 -> 160, 160, 128
        #   完成CSPlayer之后,160, 160, 128 -> 160, 160, 128
        #-----------------------------------------------#
        self.dark2 = nn.Sequential(
            # 320, 320, 64 -> 160, 160, 128
            Conv(base_channels, base_channels * 2, 3, 2),
            # 160, 160, 128 -> 160, 160, 128
            C3(base_channels * 2, base_channels * 2, base_depth),
        )
        
        #-----------------------------------------------#
        #   完成卷积之后,160, 160, 128 -> 80, 80, 256
        #   完成CSPlayer之后,80, 80, 256 -> 80, 80, 256
        #                   在这里引出有效特征层80, 80, 256
        #                   进行加强特征提取网络FPN的构建
        #-----------------------------------------------#
        self.dark3 = nn.Sequential(
            Conv(base_channels * 2, base_channels * 4, 3, 2),
            C3(base_channels * 4, base_channels * 4, base_depth * 3),
        )

        #-----------------------------------------------#
        #   完成卷积之后,80, 80, 256 -> 40, 40, 512
        #   完成CSPlayer之后,40, 40, 512 -> 40, 40, 512
        #                   在这里引出有效特征层40, 40, 512
        #                   进行加强特征提取网络FPN的构建
        #-----------------------------------------------#
        self.dark4 = nn.Sequential(
            Conv(base_channels * 4, base_channels * 8, 3, 2),
            C3(base_channels * 8, base_channels * 8, base_depth * 3),
        )
        
        #-----------------------------------------------#
        #   完成卷积之后,40, 40, 512 -> 20, 20, 1024
        #   完成SPP之后,20, 20, 1024 -> 20, 20, 1024
        #   完成CSPlayer之后,20, 20, 1024 -> 20, 20, 1024
        #-----------------------------------------------#
        self.dark5 = nn.Sequential(
            Conv(base_channels * 8, base_channels * 16, 3, 2),
            SPP(base_channels * 16, base_channels * 16),
            C3(base_channels * 16, base_channels * 16, base_depth, shortcut=False),
        )
        if pretrained:
            url = {
                's' : 'https://github.com/bubbliiiing/yolov5-pytorch/releases/download/v1.0/cspdarknet_s_backbone.pth',
                'm' : 'https://github.com/bubbliiiing/yolov5-pytorch/releases/download/v1.0/cspdarknet_m_backbone.pth',
                'l' : 'https://github.com/bubbliiiing/yolov5-pytorch/releases/download/v1.0/cspdarknet_l_backbone.pth',
                'x' : 'https://github.com/bubbliiiing/yolov5-pytorch/releases/download/v1.0/cspdarknet_x_backbone.pth',
            }[phi]
            checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", model_dir="./model_data")
            self.load_state_dict(checkpoint, strict=False)
            print("Load weights from ", url.split('/')[-1])
            
    def forward(self, x):
        x = self.stem(x)
        x = self.dark2(x)
        #-----------------------------------------------#
        #   dark3的输出为80, 80, 256,是一个有效特征层
        #-----------------------------------------------#
        x = self.dark3(x)
        feat1 = x
        #-----------------------------------------------#
        #   dark4的输出为40, 40, 512,是一个有效特征层
        #-----------------------------------------------#
        x = self.dark4(x)
        feat2 = x
        #-----------------------------------------------#
        #   dark5的输出为20, 20, 1024,是一个有效特征层
        #-----------------------------------------------#
        x = self.dark5(x)
        feat3 = x
        return feat1, feat2, feat3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/964972.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实战项目 在线学院之集成springsecurity的配置以及执行流程

一 后端操作配置 1.0 工程结构 1.1 在common下创建spring_security模块 1.2 pom文件中依赖的注入 1.3 在service_acl模块服务中引入spring-security权限认证模块 1.3.1 service_acl引入spring-security 1.3.2 在service_acl编写查询数据库信息 定义userDetailServiceImpl 查…

网站搭建最简化的引导操作 | 云服务器的购买选用 | 域名的选用 | 网站的上线和备案。

本文章面向对象为网站搭建的初次操作者,主要是一些自主使用的网站,为小白做为引导的教程。 一, 网站搭建的流程 1,服务器的租赁 2,购买域名 3,对域名进行备案 4,网站内部的搭建,上线…

读研对抗孤独感——快乐的意义在于和别人泛泛的交往

【对抗孤独需要肤浅朋友,不要什么精神之交】 评论区的精华总结几点: 1 降低道德感 2 降低交友门槛 3 交往的时候把别人看成自己的镜子就可以了 一些思考 其实 共鸣不是一开始就有的,而是似乎从一点点小的话题开始,每个人懂你心…

Mybatis1.1 环境准备

1.1 环境准备 数据库表(tb_brand)及数据准备实体类 Brand编写测试用例安装 MyBatisX 插件 数据库表(tb_brand)及数据准备 -- 删除tb_brand表 drop table if exists tb_brand; -- 创建tb_brand表 create table tb_brand (-- id 主…

设计模式-4--原型模式(Prototype Pattern)

一、什么是原型模式 原型模式(Prototype Pattern)是一种创建型设计模式,它的主要目的是通过复制现有对象来创建新的对象,而无需显式地使用构造函数或工厂方法。这种模式允许我们创建一个可定制的原型对象,然后通过复制…

React原理 - React New Component Lifecycle

目录 扩展学习资料 React New Component Lifecycle【新生命周期】 React 组件新生命周期详解 React组件老生命周期 v15.x 为什么Fiber Reconciler要有新的生命周期函数呢? 新的组件生命周期 getDerivedStateFromProps 挂载阶段 更新阶段 卸载阶段 异常捕…

Backtrader中文文档专栏订阅须知

由于CSDN新出的扣费机制过于强盗,您通过平台支付的金额只有50%是我的。针对本专栏,如果您有付费意愿,建议通过我的淘宝小店完成支付。(我已调整本专栏在CSDN上的售价) 之所以如此安排,主要是考虑到文档凝聚…

什么是jvm

一、初识JVM(虚拟机) JVM是Java Virtual Machine(Java虚拟机)的缩写,JVM是一种用于计算设备的规范,它是一个虚构出来的计算机,是通过在实际的计算机上仿真模拟各种计算机功能来实现的。 引入Jav…

Java中的动态代理(JDK Proxy VS CGLib)

前言 动态代理可以说是Java基础中一个比较重要的内容,这块内容关系到Spring框架中的AOP实现原理,所以特别写了一篇作为个人对这块知识的总结。这部分内容主要包括:JDK Proxy和CGLib的基本介绍、二者的实现原理、代码示例等。 什么是动态代理…

【python爬虫】9.带着小饼干登录(cookies)

文章目录 前言项目:发表博客评论post请求 cookies及其用法session及其用法存储cookies读取cookies复习 前言 第1-8关我们学习的是爬虫最为基础的知识,从第9关开始,我们正式打开爬虫的进阶之门,学习爬虫更多的精进知识。 在前面几…

PY32F003F18点灯

延时函数学习完之后,可以学习PY32F003F18的GPIO输出功能。 1、Debug引脚默认被置于复用功能上拉或下拉模式:PA14默认为SWCLK: 置于下拉模式PA13默认为SWDIO: 置于上拉模式PF4默认为Boot:Boot引脚默认置于输入下拉模式 2、GPIO输出状态&#…

Prometheus+grafana安装配置

Prometheus安装配置 Prometheus下载地址 官方地址:Download | Prometheus 可根据系统版本下载想要的安装包,复制链接地址 wget https://github.com/prometheus/prometheus/releases/download/v2.33.3/prometheus-2.33.3.linux-amd64.tar.gzwg 解压pr…

protues仿真时有时候串口虚拟中端不弹窗的问题

在使用proteus的时候,有时候你会发现点击调试开始运行后,串口虚拟终端没有自动弹窗的问题,其实照成这种现象的原因是你在使用的过程中移动了器件位置或者是对整个视窗使用鼠标滚动进行缩放了,如果要重新弹窗则需要进行以下操作: …

手撕 视觉slam14讲 ch13 代码(2)基本类的抽象

在正式写系统之前,我们在上一篇分析了基本的3个类:帧、2D特征点、3D地图点,这次我们开始代码实现这些基本数据结构: 1.帧类 常见的SLAM系统中的帧(Frame)需要包含以下信息:id,位姿…

容器技术Linux Namespaces和Cgroups

对操作系统了解多少,仅仅敲个命令吗 操作系统虚拟化(容器技术)的发展历程 1979 年,UNIX 的第 7 个版本引入了 Chroot 特性。Chroot 现在被认为是第一个操作系统虚拟化(Operating system level virtualization&#x…

CSS使两个不同的div居中对齐的三种解决方案

在CSS中,有多种方法可以让两个不同的div居中对齐,包括相对定位和绝对定位。以下是两种常见的方法: 方法一:使用Flexbox Flexbox是一个用于创建灵活布局的CSS3模块。使用Flexbox,可以很容易地对元素进行居中对齐。 H…

C++基础语法——多态

1.什么是多态? 多态是面向对象编程中的一个概念,它允许不同的对象对同一个消息作出不同的响应。简单来说,多态是指同一种操作或方法可以在不同的对象上产生不同的行为。这种灵活性使得代码更加可扩展和可维护。在多态中,对象的类型…

弯道超车必做好题集锦三(C语言编程题)

目录 前言: 1.单词倒排 方法1:scanf匹配特定字符法 方法2: 双指针法 2.统计每个月兔子的总数 方法1:斐波那契数列 方法2:斐波那契的递归 3.珠玑妙算 方法:遍历 4.寻找奇数(单身狗&#…

【图解算法数据结构】分治算法篇 + Java代码实现

文章目录 一、重建二叉树二、数值的整数次方三、打印从 1 到最大的 n 位数四、二叉搜索树的后序遍历序列五、数组中的逆序对 一、重建二叉树 public class Solution {int[] preorder;HashMap<Integer, Integer> dic new HashMap<>();public TreeNode buildTree(in…

可视化流程设计平台有啥优势?

在流程化办公发展趋势逐渐明朗的今天&#xff0c;运用什么样的平台可以帮助广大用户朋友实现这一目标&#xff1f;可视化流程设计平台是轻量级、更灵活、易操作、效率高的平台&#xff0c;可以快速定制客户专属的框架平台&#xff0c;为每一位客户朋友做好数据管理&#xff0c;…