修改YOLOv5的模型结构第二弹

news2024/11/24 7:50:09
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

上节说到了通过修改YOLOv5的common.py来修改模型的结构,修改的是模块的内部结构,具体某些模块组织顺序等,如插入一个新的模型,则需要在yolo.py文件中修改

yolo.py文件

yolo.py中有几个主要的组成部分

parse_model函数

主要负责读取传入的–cfg配置指定的模型配置文件。例如经常使用的yolov5s.yaml,通过配置文件创建实际的模块对象,并将模块拼接起来。

Detect类

主要用来构建Detect层,将输入的feature map通过一个卷积操作和公式计算得到想要的shape,为后面计算损失或者NMS做准备

Model类

这个类实现的是整个模型的搭建。YOLOv5的作者在其中还加入了很多功能,例如:特征可视化、打印模型信息、TTA推理增强、融合Conv+Bn加速推理、模型搭载NMS功能、autoshape函数等等。

修改模型

任务

参考C3模块,创建一个C2模块,并插入到模型的第二、三层之间
C2模型
总模型结构
可以发现C2模块就是上一篇文章中,对模型C3模块的修改结果。

步骤

由于要插入一个新的模块,首先就是要在commm.py里仿造C3模块,新增一个可以创建C2模块的方法

class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
    
class C2(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_1 = c2 // 2
        c_2 = c2 - c1  # hidden channels
        self.cv1 = Conv(c1, c_2, 1, 1)
        self.cv2 = Conv(c1, c_1, 1, 1)
        self.m = nn.Sequential(*(Bottleneck(c_2, c_2, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)

然后修改yolo.py中构建模型的部分,增加C2模块

def parse_model(d, ch):  # model_dict, input_channels(3)
    # Parse a YOLOv5 model.yaml dictionary
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
                Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                BottleneckCSP, C3, C2, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3,C2, C3TR, C3Ghost, C3x}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)

最后修改模型的配置文件,将yolov5s.yaml另存为yolov5s_aug.yaml并修改

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 3, C2, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

如此改造就完成了,可以使用训练集训练一下当前的yolov5s_aug模型
训练的结果如下:
修改后的模型

对比没有修改前的模型训练结果

修改前模型
对比发现增加这层模块没有对整个模型带来好的结果,每种分类的检测结果都变得更差了,上篇文章也分析过,C2模块由于缺少了最后的卷积重排,会导致特征的表达变差。

根据经验,cat操作后面最好是经过全连接层或者卷积再提取一下特征,直接使用的效果比较差。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1220611.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023年【陕西省安全员B证】考试题库及陕西省安全员B证找解析

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 陕西省安全员B证考试题库是安全生产模拟考试一点通生成的&#xff0c;陕西省安全员B证证模拟考试题库是根据陕西省安全员B证最新版教材汇编出陕西省安全员B证仿真模拟考试。2023年【陕西省安全员B证】考试题库及陕西省…

分布式事务seata的使用

分布式事务介绍 在微服务架构中&#xff0c;完成某一个业务功能可能需要横跨多个服务&#xff0c;操作多个数据库。这就涉及到到了分布式事务&#xff0c;需要操作的资源位于多个资源服务器上&#xff0c;而应用需要保证对于多个资源服务器的数据操作&#xff0c;要么全部成功&…

深度学习_14_单层|多层感知机及代码实现

单层感知机&#xff1a; 功能&#xff1a; 能完成二分类问题 问题&#xff1a; 模型训练采用X*W b训练出模型&#xff0c;对数据只进行了一层处理&#xff0c;也就是说训练出来的模型仅是个线性模型&#xff0c;它无法解决XOR问题&#xff0c;所以模型在训练效果上&#xf…

cookie机制 + java 案例

目录 为什么会有cookie?? cookie从哪里来的&#xff1f;&#xff1f; cookie到哪里去&#xff1f;&#xff1f; cookie有啥用&#xff1f;&#xff1f; session HttpServletRequest类中的相关方法 简单的实现cookie登录功能 实现登录页面 实现servlet逻辑 实现生成主…

【Spring】依赖注入方式,DI的方式

这里写目录标题 1. setter注入在一个类中注入引用类型在一个类中注入简单类型 2. 构造器注入在一个类中注入引用类型在一个类中注入简单类型 3. 依赖注入方式选择4. 依赖自动装配按类型注入按名称注入 5. 集合注入 1. setter注入 在一个类中注入引用类型 回顾一下之前setter注…

Python基础:输入输出详解-输出字符串格式化

Python中的输入和输出是编程中非常重要的方面。 1. 输入输出简单介绍 1.1 输入方式 Python中的输入可以通过input()函数从键盘键入&#xff0c;也可以通过命令行参数或读取文件的方式获得数据来源。 1&#xff09;input()示例 基本的input()函数&#xff0c;会将用户在终端&…

力扣栈与队列--总结篇

前言 八道题&#xff0c;没想到用了五天。当然需要时间的沉淀&#xff0c;但是一天不能啥也不干啊&#xff01; 内容 首先得熟悉特点和基本操作。 栈与队列在计算机底层中非常重要&#xff0c;这就是为什么要学好数据结构。 可视化的软件例如APP、网站之类的&#xff0c;都…

Kotlin原理+协程基本使用

协程概念 协程是Coroutine的中文简称&#xff0c;co表示协同、协作&#xff0c;routine表示程序。协程可以理解为多个互相协作的程序。协程是轻量级的线程&#xff0c;它的轻量体现在启动和切换&#xff0c;协程的启动不需要申请额外的堆栈空间&#xff1b;协程的切换发生在用…

AVL树的底层实现

文章目录 什么是AVL树&#xff1f;平衡因子Node节点插入新节点插入较高左子树的左侧新节点插入较高左子树的右侧新节点插入较高右子树的左侧新节点插入较高右子树的右侧 验证是否为平衡树二叉树的高度AVL的性能 什么是AVL树&#xff1f; AVL树又称平衡二叉搜索树&#xff0c;相…

腾讯云服务器便宜吗?腾讯云服务器怎么买便宜?附优惠链接

首先&#xff0c;咱们来看一下大家最关心的一个问题&#xff1a;“腾讯云服务器便宜吗&#xff1f;”我的答案是&#xff1a;“YES&#xff01;它真的很便宜&#xff01;”比如&#xff0c;轻量2核2G3M服务器&#xff0c;1年只需要88元&#xff0c;是不是很划算&#xff1f;再比…

实例解释遇到前端报错时如何排查问题

前端页面报错&#xff1a; 1、页面报错500&#xff0c;首先我们可以知道是服务端的问题&#xff0c;需要去看下服务端的报错信息&#xff1a; 2、首先我们查看下前端是否给后端传了id: 我们可以看到接口是把ID返回了&#xff0c;就需要再看下p_id是什么情况了。 3、我们再次请…

基于ssm+vue的程序设计课程可视化教学系统设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

双11背后的中国云厂商:新“标准化”,和调整后的新韧性

降价并不代表一味的压缩自身利润空间&#xff0c;云厂商已经开始向具有更高利润空间的PaaS、SaaS产品腾挪&#xff0c;核心产品在总包占比越来越高。 作者|斗斗 编辑|皮爷 出品|产业家 今年云厂商&#xff0c;全面拥抱双11。 作为中国最大的云计算服务提供商&#xff0…

关于苏州立讯公司国产替代案例(使用我公司H82409S网络变压器和E1152E01A-YG网口连接器产品)

关于苏州立讯公司国产替代案例&#xff08;使用我们公司的H82409S网络变压器和E1152E01A-YG网口连接器产品&#xff09; 苏州立讯公司是一家专注于通信设备制造的企业&#xff0c;他们在其产品中选择了我们公司的H82409S网络变压器和E1152E01A-YG网口连接器&#xff0c;以实现…

在线协作工具都有哪些?推荐这10款

如今&#xff0c;互联网的快速发展不仅改变了我们的生活方式&#xff0c;也改变了我们的工作方式。 特别是对于一些与产品设计相关的公司或团体&#xff0c;网络不仅为其设计提供了稳定的资源和灵感&#xff0c;而且为成员之间的沟通和合作提供了更大的便利。 如果您也需要为…

ke11介绍本地,会话存储

代码顺序: 1.设置input,捕获input如果有多个用属性选择符例如 input[typefile]点击事件.向我们的本地存储设置键值对 2.在点击事件外面设置本地存储表示初始化的值.点击上面的事件才能修改我们想修改的值 会话(session)浏览a数据可以写到本地硬盘,关闭页面数据就没了 本地(…

JPA整合Sqlite解决Dialect报错问题, 最新版Hibernate6

前言 我个人项目中&#xff0c;不想使用太重的数据库&#xff0c;而内嵌数据库中SQLite又是最受欢迎的&#xff0c; 因此决定采用这个数据库。 可是JPA并不支持Sqlite&#xff0c;这篇文章就是记录如何解决这个问题的。 原因 JPA屏蔽了底层的各个数据库差异&#xff0c; 但是…

使用Python的time库来格式化时间

目录 一、引言 二、time库简介 三、使用time库来格式化时间 四、使用time库进行时间戳转换 五、使用time库获取当前时间 六、使用time库进行延时操作 七、使用time库计算时间差 八、使用time库获取系统时间 九、使用time库的其他功能 总结 一、引言 在Python中&…

鸿蒙OS应用开发初体验

什么是HarmonyOS&#xff1f; HarmonyOS&#xff08;鸿蒙操作系统&#xff09;是华为公司开发的一款基于微内核的分布式操作系统。它是一个面向物联网&#xff08;IoT&#xff09;时代的全场景操作系统&#xff0c;旨在为各种类型的设备提供统一的操作系统平台和开发框架。Har…

自媒体剪辑必备,6个音效素材网站,你值得拥有。

这6个剪辑必备的音效素材网站一定要收藏好了&#xff0c;有了这几个网站能让你在剪辑的时候事半功倍&#xff0c;还不用担心版权问题。话不多说&#xff0c;直接上干货。 1、菜鸟图库 https://www.sucai999.com/audio.html?vNTYwNDUx 菜鸟图库是一个综合性素材网站&#xff…