如何改进YOLOv5主干网络

news2025/2/4 0:57:35

D:\yolov5-master\models目录下新建mobilevit.py文件夹

代码内容:

import torch
import torch.nn as nn
from einops import rearrange
def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn # mg
    
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)# mg
        ) if project_out else nn.Identity()
    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)
class UserDefined(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))
    
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
class MobileViT(nn.Module):
    def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size
        self.mv01 = IRBlock(channel, channel) 
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv2 = conv_1x1_bn(channel, dim)
        self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
    def forward(self, x):
        y = x.clone()
        x = self.conv1(x)
        x = self.conv2(x)
        z = x.clone()
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
        x = self.conv3(x)
        x = torch.cat((x, z), 1)
        x = self.conv4(x)
        x = x + y
        x = self.mv01(x)
        return x
class IRBlock(nn.Module):
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup
        if expansion == 1: # 构建没有扩展层的卷积块
            self.conv = nn.Sequential(
                # 深度可分离卷积(Depthwise Convolution)
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # “线性”逐点卷积 (Pointwise-Linear Convolution)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:  # 构建包含扩展层的卷积块
            self.conv = nn.Sequential(
                # 逐点卷积 (Pointwise Convolution)
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # 深度可分离卷积 (Depthwise Convolution)
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # “线性”逐点卷积 (Pointwise-Linear Convolution)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

代码解析见:

【DeepLearning-1】 注意力机制(Attention Mechanism)-CSDN博客 

【DeepLearning-2】预归一化(Pre-Normalization)策略_prenorm 代码-CSDN博客

【DeepLearning-3】前馈(feed-forward)神经网络层-CSDN博客

【DeepLearning-5】基于Transformer架构的自定义神经网络类-CSDN博客

【DeepLearning-6】实现倒置残差块(Inverted Residual Block)-CSDN博客

【DeepLearning-7】 CNN 和Transformer的混合神经网络结构-CSDN博客

D:\yolov5-master\models\yolo.py文件中加入MobileViT模型

【DeepLearning-10】yolo.py文件关键代码parse_model(d, ch)函数-CSDN博客

YOLOv5模型网络结构中加入MobileViT模块

 【DeepLearning-9】YOLOv5模型网络结构中加入MobileViT模块-CSDN博客

命令窗代码训练网络 

python train.py --cfg D:\yolov5-master\models\yolov5l.yaml

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1427386.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python库Bleach:保护应用免受XSS攻击

Python库Bleach:保护应用免受XSS攻击 在当今的网络环境中,跨站脚本攻击(XSS)是一种常见而严重的安全威胁。为了保护我们的应用程序免受XSS攻击,我们可以使用Python库Bleach。本文将介绍Bleach库的基本概念、功能和用法…

物联网浏览器(IoTBrowser)-Modbus协议集成和测试

Modbus协议在应用中一般用来与PLC或者其他硬件设备通讯,Modbus集成到IoTBrowser使用串口插件模式开发,不同的是采用命令函数,具体可以参考前面几篇文章。目前示例实现了Modbus-Rtu和Modbus-Tcp两种,通过js可以与Modbus进行通讯控制…

其他发现:开源数据可视化分析工具DataEase介绍文档

一、 简介 DataEase 是开源的数据可视化分析工具,帮助用户快速分析数据并洞察业务趋势,从而实现业务的改进与优化。DataEase 支持丰富的数据源连接,能够通过拖拉拽方式快速制作图表,并可以方便地与他人分享。 二、 优势 1、 开…

计算机视觉实战项目4(单目测距与测速+摔倒检测+目标检测+目标跟踪+姿态识别+车道线识别+车牌识别+无人机检测+A_路径规划+行人车辆计数+动物识别等)

基于YOLOv5的无人机视频检测与计数系统 摘要: 无人机技术的快速发展和广泛应用给社会带来了巨大的便利,但也带来了一系列的安全隐患。为了实现对无人机的有效管理和监控,本文提出了一种基于YOLOv5的无人机视频检测与计数系统。该系统通过使用…

UE4 C++ 数据表

//添加使用DataTable需要的头文件 #include "Engine/DataTable.h"//基于结构体变量类型,创建数据表DataTable类型 USTRUCT(BlueprintType) struct FMyDataTableStruct : public FTableRowBase //把结构体变量公开到数据表类型 {GENERATED_BODY() //必须添…

前后端分离,RSA加密传输方案

1.原理 RSA是一种非对称加密算法。通过生成密钥对,用公钥加密,用私钥解密。对于前后端分离的项目,让前端获取到公钥对敏感数据加密,发送到后端,后端用私钥对加密后的数据进行解密即可。 2.实现 RSA工具类&#xff1…

MQ回顾之rabbitmq速通

rabbitMQ相对来说功能比较完善,吞吐量会低一点。 持续更新…… 安装 docker 测试选择docker安装 官方安装操作 1、docker pull rabbitmq:latest 2、docker run -d --hostname my-rabbit --name some-rabbit -p 15672:15672 -p 5672:5672 rabbitmq 3、docker…

C/C++ C++入门

个人主页:仍有未知等待探索-CSDN博客 专题分栏:C_仍有未知等待探索的博客-CSDN博客 目录 一、C关键字 二、命名空间 1、区别 1. C语言 ​编辑 2. C 2、命名空间定义 3、命名空间的使用 三、C输入&输出 四、缺省参数 五、函数重载 六、引用 …

80.如何评估一台服务器能承受的最大TCP连接数

文章目录 一、一个服务端进程最多能支持多少条 TCP 连接?二、一台服务器最大最多能支持多少条 TCP 连接?三、总结 一个服务端进程最大能支持多少条 TCP 连接? 一台服务器最大能支持多少条 TCP 连接? 很多朋友可能第一反应就是端…

【RT-DETR有效改进】利用YOLO-MS的MSBlock模块改进ResNet中的Bottleneck(RT-DETR深度改进)

👑欢迎大家订阅本专栏,一起学习RT-DETR👑 一、本文介绍 本文给大家带来的改进机制是利用YOLO-MS提出的一种针对于实时目标检测的MSBlock模块(其其实不能算是Conv但是其应该是一整个模块),我们将其用于替换我们ResNet中Basic组合出一种新的结构,来替换我们网络中的…

Spring:JDBCTemplate 的源码分析

一:JdbcTemplate的简介 JdbcTemplate 是 Spring Template设置模式中的一员。类似的还有 TransactionTemplate、 MongoTemplate 等。通过 JdbcTemplate 我们可以使得 Spring 访问数据库的过程简单化。 二:执行SQL语句的方法 1:在JdbcTempla…

智能末世战争之机器人的反击

在遥远的未来,地球陷入了一场空前的战争。这场战争不同于以往的任何战争,因为这是由人工智能和机器人主导的战争。在战争爆发之前,人类一直依赖AI和机器人来提高生产效率和生活质量。然而,随着AI技术的飞速发展,机器人…

H5 简约四色新科技风引导页源码

H5 简约四色新科技风引导页源码 源码介绍:一款四色切换自适应现代科技风动态背景的引导页源码,源码有主站按钮,分站按钮2个,QQ联系站长按钮一个。 下载地址: https://www.changyouzuhao.cn/11990.html

flinkcdc 3.0 尝鲜

本文会将从环境搭建到demo来全流程体验flinkcdc 3.0 包含了如下内容 flink1.18 standalone搭建doris 1fe1be 搭建整库数据同步测试各同步场景从检查点重启同步任务 环境搭建 flink环境(Standalone模式) 下载flink 1.18.0 链接 : https://archive.apache.org/dist/flink/flink…

【大数据】专业融合型人才迎来发展良机-国家数据局正式揭牌

⭐简单说两句⭐ 作者:后端小知识 CSDN个人主页:后端小知识 🔎GZH:后端小知识 🎉欢迎关注🔎点赞👍收藏⭐️留言📝 摘要: 新华社北京10月26日电 《中国证券报》26日刊发文章…

shell - 正则表达式和grep命令和sed命令

一.正则表达式概述 1.正则表达式定义 1.1 定义 使用字符串描述、匹配一系列符合某个规则的字符串 1.2 了解 普通字符: 大小写字母、数字、标点符号及一些其它符号元字符: 在正则表达式中具有特殊意义的专用字符 1.3 层次分类 基础正则表达式扩展正…

git修改密码后mac使用sourceTree出现Authentication failed错误

1、退出sourceTree 2、在钥匙串中删除git对应站点Access Key 3、执行命令:git config --system --unset credential.helper 4、重新启动sourceTree,这时会弹出输入密码框,重新输入密码即可

react 之 UseMemo

useMemo 看个场景 下面我们的本来的用意是想基于count的变化计算斐波那契数列之和,但是当我们修改num状态的时候,斐波那契求和函数也会被执行,显然是一种浪费 // useMemo // 作用:在组件渲染时缓存计算的结果import { useState …

树——二叉搜索树

二叉搜索树 概述 随着计算机算力的提升和对数据结构的深入研究,二叉搜索树也不断被优化和扩展,例如AVL树、红黑树等。 特性 二叉搜索树(也称二叉排序树)是符合下面特征的二叉树: 树节点增加 key 属性,用来…

Git介绍与常用命令总结

Git介绍与其常用命令总结 1、Git介绍2、Git的使用3、Git常用命令3.1 初始化仓库3.2 克隆仓库3.3 配置用户信息3.4 提交代码(Commit)3.5 推送代码(Push)3.6 拉取代码(Pull)3.7 分支(Branch)3.8 远程仓库(Remote)3.9 撤销回退本地改动3.10 更新本地仓库与远程仓库 1、Git介绍 Gi…