YOLO5的修改

news2024/12/23 9:34:14

在传统的yolov5网络中并不存在注意力机制,但是源代码中存在相关简略的代码:

    def __init__(self, c, num_heads):
        """
        Initializes a transformer layer, sans LayerNorm for performance, with multihead attention and linear layers.

        See  as described in https://arxiv.org/abs/2010.11929.
        """
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        """Performs forward pass using MultiheadAttention and two linear transformations with residual connections."""
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        x = self.fc2(self.fc1(x)) + x
        return x

实现了一个带有多头注意力机制和线性变换的Transformer层,并使用残差连接来增强模型的学习能力。

我们将代码修改为:

class TransformerLayer(nn.Module):
    """Transformer layer with multihead attention and linear layers, optimized by removing LayerNorm.

    Args:
        c (int): The dimension of the input embeddings.
        num_heads (int): The number of heads in the multiheadattention models.

    Returns:
        torch.Tensor: The output tensor after transformation.
    """

    def __init__(self, c, num_heads):
        """
        Initializes a transformer layer, sans LayerNorm for performance, with multihead attention and linear layers.

        Args:
            c (int): The dimension of the input embeddings.
            num_heads (int): The number of heads in the multiheadattention models.

        Raises:
            ValueError: If `c` or `num_heads` is not a positive integer.
        """
        super().__init__()
        if not isinstance(c, int) or c <= 0:
            raise ValueError("c must be a positive integer")
        if not isinstance(num_heads, int) or num_heads <= 0:
            raise ValueError("num_heads must be a positive integer")

        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

        # Initialize weights
        nn.init.xavier_uniform_(self.q.weight)
        nn.init.xavier_uniform_(self.k.weight)
        nn.init.xavier_uniform_(self.v.weight)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        """Performs forward pass using MultiheadAttention and two linear transformations with residual connections and activation functions.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after transformation.
        """
        try:
            attn_output, _ = self.ma(self.q(x), self.k(x), self.v(x))
            x = attn_output + x
            x = self.fc2(self.relu(self.fc1(x))) + x
            x = self.dropout(x)
            return x
            print("dropout=0.5")
        except Exception as e:
            print(f"Error during forward pass: {e}")
            raise

改进后的代码有以下优势:

  • 对参数进行有效化检查(显然不需要,为了从字数呗)。
  • 在两个线性层之间使用ReLU激活函数,引入非线性层,使得模型能够更好的拟合复杂的函数关系。
  • 引入dropout层,减少过拟合。(不过这个参数可以调少一点)。
  • 对线性层的各个权重进行Xavier均匀初始化(详细公式),加快收敛速度。

我在head的num方面取数量为8,即8头注意力机制。包含6个transformer层。

在不修改的时候运行的结果如下:

如果不对transformer层进行改进,而只是添加了该层进行训练,结果如下:

不巧的是,效果有显著的下降。

改进之后训练的结果如下:

调一下drop试试:(此时dropout=0.1)

应该是代码错了,修改下代码重来,代码修改如下:

import torch
import torch.nn as nn

class TransformerLayer(nn.Module):
    """Transformer layer with multihead attention and linear layers, optimized by removing LayerNorm."""

    def __init__(self, c, num_heads, dropout_rate=0.1):
        """
        Initializes a transformer layer, with multihead attention and linear layers.

        Args:
            c (int): The number of input/output channels.
            num_heads (int): The number of attention heads.
            dropout_rate (float): The dropout rate for regularization.
        """
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)
        self.dropout = nn.Dropout(dropout_rate)  # Dropout layer
        self.activation = nn.GELU()  # Activation function

    def forward(self, x):
        """Performs forward pass using MultiheadAttention and two linear transformations with residual connections."""
        # Multihead attention
        attn_output = self.ma(self.q(x), self.k(x), self.v(x))[0]
        x = attn_output + x  # Residual connection
        x = self.dropout(x)  # Apply dropout

        # Feedforward network
        x = self.fc1(x)
        x = self.activation(x)  # Activation function
        x = self.fc2(x) + attn_output  # Residual connection
        return x

省去了xavier均匀初始化,在第一个线性层后面添加了GELU激活函数,增强非线性的表达能力。每个主要操作后面都添加了残差连接,有助于梯度流动。(希望能有用吧,求求了!!!)

改进的结果如下:(算了,还是在xavier均匀化的基础上改进吧)。

 在每一个layer之后加一个规范化之后帮助稳定训练过程,使用 kaiming_uniform_ 方法进行权重初始化,结果如下:(换一种思路)

激活函数换一下试试:

只有在召回率方面比较有优势。但是在其他方面有所下降,所以要换一种模块的更新。

 另一个方向的修改就是在原本的模块中加入一层注意力机制,但是不幸运的是,效果却是下降了,注意力机制层如下:

然而结果却是各个指标都有所弱化,效果明显不好,有可能是因为训练的轮数不够导致的,(可以弥补学习率的调整不周到),所以我放大了训练的轮数到1000,重新进行训练,训练后的结果相比100轮的yolov5s有所进步(稍微一点,但是在分类上看,yolov5s表现的很强劲):

相比100轮的SE有大的进步,说明在100轮时模型还未达到收敛,不过不能由此证明SE有明显的优势,我将yolov5s也训练了1000轮对比一下:

可见SE还是不太适合作为一个可行的修改选项的。(也可能是网络结果配置的不合适)。

由于yolo的显著优势就是能快速收敛,所以对于模型的速度要求很高,这里我们就固定取100轮,在使用SE的基础上,我们将所有C3层改为C2f层。

简要介绍一下C3层,C3层使用残差网络链接能解决网络中的梯度消失问题

也就是三个卷积加一个瓶颈层(包含卷积核残差块)。输出是将输入先放到第一个卷积中,然后将得到的结果放到瓶颈层中,同时与原始输入经过第二个卷积层得到的结果相拼接。

模型并不复杂,同时实现多尺度的特征融合。

C2f相对复杂,将输入层经过多次分割,更好的融合了不同层次的特征。(但是仅有理论不能确定其在yolov5中是否更加有效,在此我将源代码中的C3全部替换变成C2f(包含主干部分和头部分),结果效果仍然是变差的)),当然也有可能是我选择添加的位置不对。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2203636.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

prometheus client_java实现进程的CPU、内存、IO、流量的可观测

文章目录 1、获取进程信息的方法1.1、通过读取/proc目录获取进程相关信息1.2、通过Linux命令获取进程信息1.2.1、top&#xff08;CPU/内存&#xff09;命令1.2.2、iotop&#xff08;磁盘IO&#xff09;命令1.2.3、nethogs&#xff08;流量&#xff09;命令 2、使用prometheus c…

tableau除了图表好看,在业务中真有用吗?

tableau之前的市值接近150亿美金&#xff0c;被saleforce以157亿美金收购&#xff0c;这个市值和现在的蔚来汽车差不多。 如果tableau仅仅是个show的可视化工具&#xff0c;必然不会有这么高的市值&#xff0c;资本市场的眼睛是雪亮的。 很多人觉得tableau做图表好看&#xff…

分布式常见面试题总结

文章目录 1 什么是 UUID 算法&#xff1f;2 什么是雪花算法&#xff1f;&#x1f525;3 说说什么是幂等性&#xff1f;&#x1f525;4 怎么保证接口幂等性&#xff1f;&#x1f525;5 paxos算法6 Raft 算法7 CAP理论和 BASE 理论7.1 CAP 理论&#x1f525;7.2 为什么无法同时保…

Echarts合集更更更之树图

实现效果 写在最后&#x1f352; 源码&#xff0c;关注&#x1f365;苏苏的bug&#xff0c;&#x1f361;苏苏的github&#xff0c;&#x1f36a;苏苏的码云

DGL库之HGTConv的使用

DGL库之HGTConv的使用 论文地址和异构图构建教程HGTConv语法格式HGTConv的使用 论文地址和异构图构建教程 论文地址&#xff1a;https://arxiv.org/pdf/2003.01332 异构图构建教程&#xff1a;异构图构建 异构图转同构图&#xff1a;异构图转同构图 HGTConv语法格式 dgl.nn.…

极客兔兔Gee-Cache Day7

protobuf配置&#xff1a; 从 Protobuf Releases 下载最先版本的发布包安装。解压后将解压路径下的 bin 目录 加入到环境变量即可。 如果能正常显示版本&#xff0c;则表示安装成功。 $ protoc --version libprotoc 3.11.2在Golang中使用protobuf&#xff0c;还需要protoc-g…

【单链表的模拟实现Java】

【单链表的模拟实现Java】 1. 了解单链表的功能2. 模拟实现单链表的功能2.1 单链表的创建2.2 链表的头插2.3 链表的尾插2.3 链表的长度2.4 链表的打印2.5 在指定位置插入2.6 查找2.7 删除第一个出现的节点2.8 删除出现的所有节点2.9 清空链表 3. 正确使用模拟单链表 1. 了解单链…

重头开始嵌入式第四十八天(Linux内核驱动 linux启动流程)

目录 什么是操作系统&#xff1f; 一、管理硬件资源 二、提供用户接口 三、管理软件资源 什么是操作系统内核&#xff1f; 一、主要功能 1. 进程管理&#xff1a; 2. 内存管理&#xff1a; 3. 设备管理&#xff1a; 4. 文件系统管理&#xff1a; 二、特点 什么是驱动…

WebGoat JAVA反序列化漏洞源码分析

目录 InsecureDeserializationTask.java 代码分析 反序列化漏洞知识补充 VulnerableTaskHolder类分析 poc 编写 WebGoat 靶场地址&#xff1a;GitHub - WebGoat/WebGoat: WebGoat is a deliberately insecure application 这里就不介绍怎么搭建了&#xff0c;可以参考其他…

yq 工具

文章目录 yq命令快速 Recipes查找数组中的项目查找并更新数组中的项目深度修剪一棵树对数组中的项目进行多次或复杂的更新按字段对数组进行排序 OperatorsOmitOmit keys from mapOmit indices from array DeleteDelete entry in mapDelete nested entry in mapDelete entry in …

【重学 MySQL】六十三、唯一约束的使用

【重学 MySQL】六十三、唯一约束的使用 创建表时定义唯一约束示例 在已存在的表上添加唯一约束示例 删除唯一约束示例 复合唯一约束案例背景创建表并添加复合唯一约束插入数据测试总结 特点注意事项 在 MySQL 中&#xff0c;唯一约束&#xff08;UNIQUE Constraint&#xff09;…

butterfly主题留言板 报错记录 未解决

新建留言板&#xff0c;在博客根目录执行下面的命令 hexo new page messageboard 在博客/source/messageboard的文件夹下找到index.md文件并修改 --- title: 留言板 date: 2018-01-05 00:00:00 type: messageboard ---找到butterfly主题下的_config.yml文件 把留言板的注释…

基于springboot+小程序的智慧物流管理系统(物流1)

&#x1f449;文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1、项目介绍 基于springboot小程序的智慧物流管理系统实现了管理员、司机及用户。 1、管理员实现了司机管理、用户管理、车辆管理、商品管理、物流信息管理、基础数据管理、论坛管理、公告信息管理等。…

帮助自闭症孩子融入社会,寄宿学校是明智选择

在广州这座充满活力与温情的城市&#xff0c;有一群特殊的孩子&#xff0c;他们被称为“星星的孩子”——自闭症儿童。自闭症&#xff0c;这个让人既陌生又熟悉的名词&#xff0c;背后承载的是无数家庭的辛酸与希望。对于自闭症儿童来说&#xff0c;融入社会、与人交流、理解世…

【Linux第一弹】- 基本指令

&#x1f308; 个人主页&#xff1a;白子寰 &#x1f525; 分类专栏&#xff1a;重生之我在学Linux&#xff0c;C打怪之路&#xff0c;python从入门到精通&#xff0c;数据结构&#xff0c;C语言&#xff0c;C语言题集&#x1f448; 希望得到您的订阅和支持~ &#x1f4a1; 坚持…

blender 记一下lattice

这个工具能够辅助你捏形状 这里演示如何操作BOX shift A分别创建俩对象一个BOX 一个就是lattice对象 然后在BOX的修改器内 创建一个叫做lattice的修改器 然后指定object为刚刚创建的lattice对象 这样就算绑定好了 接下来 进入lattice的编辑模式下 你选取一个点进行运动&#…

量化交易与基础投资工具介绍

&#x1f31f;作者简介&#xff1a;热爱数据分析&#xff0c;学习Python、Stata、SPSS等统计语言的小高同学~&#x1f34a;个人主页&#xff1a;小高要坚强的博客&#x1f353;当前专栏&#xff1a;Python之机器学习《Python之量化交易》Python之机器学习&#x1f34e;本文内容…

谈谈留学生毕业论文如何分析问卷采访数据

留学生毕业论文在设计好采访问题并且顺利进行了采访之后&#xff0c;我们便需要将得到的采访答案进行必要的分析&#xff0c;从而得出一些结论。我们可以通过这些结论回答研究问题&#xff0c;或者提出进一步的思考等等。那么我们应当如何分析采访数据呢&#xff1f;以下有若干…

python3开头如何设置utf-8

编码格式1&#xff1a; 在源文件第一行或者第二行定义&#xff1a; # coding<encoding name> 例如&#xff1a; # codingutf-8 编码格式2&#xff1a;&#xff08;这种最流行&#xff09; 格式如下&#xff1a; #!/usr/bin/python # -*- coding: <encoding name>…

信息安全工程师(43)入侵检测概述

一、定义与目的 入侵检测&#xff08;Intrusion Detection&#xff09;是指通过对行为、安全日志、审计数据或其他网络上可以获得的信息进行操作&#xff0c;检测到对系统的闯入或闯入的企图。其主要目的是确保网络安全和信息安全&#xff0c;保护个人和机构的敏感数据免受未经…