yolov5增加AFPN-全新特征融合模块AFPN,效果完胜PAFPN

news2025/1/11 13:04:38

论文学习:AFPN: Asymptotic Feature Pyramid Network for Object Detection-全新特征融合模块AFPN,完胜PAFPN_athrunsunny的博客-CSDN博客

先上配置文件yolov5s-AFPN.yaml

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[4, 1, Conv, [64, 1, 1]], 
   [6, 1, Conv, [128, 1, 1]], 

   [[10, 11], 1, ASFF_2, [64, 0]], 
   [[10, 11], 1, ASFF_2, [128, 1]], 

   [-2, 1, C3, [64, False]], 
   [-2, 1, C3, [128, False]], 

   [9, 1, Conv, [256, 1, 1]],

   [[14, 15, 16], 1, ASFF_3, [64, 0]],
   [[14, 15, 16], 1, ASFF_3, [128, 1]],
   [[14, 15, 16], 1, ASFF_3, [256, 2]],

   [17, 1, C3, [64, False]],
   [18, 1, C3, [128, False]],
   [19, 1, C3, [256, False]],
   [[20, 21, 22], 1, Detect, [nc, anchors]]
]

在models/common.py增加

class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super(Upsample, self).__init__()

        self.upsample = nn.Sequential(
            Conv(in_channels, out_channels, 1),
            nn.Upsample(scale_factor=scale_factor, mode='bilinear')
        )

        # carafe
        # from mmcv.ops import CARAFEPack
        # self.upsample = nn.Sequential(
        #     BasicConv(in_channels, out_channels, 1),
        #     CARAFEPack(out_channels, scale_factor=scale_factor)
        # )

    def forward(self, x):
        x = self.upsample(x)

        return x

class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels,scale_factor=2):
        super(Downsample, self).__init__()

        self.downsample = nn.Sequential(
            Conv(in_channels, out_channels, scale_factor, scale_factor, 0)
        )

    def forward(self, x):
        x = self.downsample(x)

        return x

class ASFF_2(nn.Module):
    def __init__(self, inter_dim=512,level=0,channel=[64,128]):
        super(ASFF_2, self).__init__()

        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)

        self.weight_levels = nn.Conv2d(compress_c * 2, 2, kernel_size=1, stride=1, padding=0)

        self.conv = Conv(self.inter_dim, self.inter_dim, 3, 1)
        self.upsample = Upsample(channel[1],channel[0])
        self.downsample = Downsample(channel[0],channel[1])
        self.level = level


    def forward(self, x):
        input1, input2 = x
        if self.level == 0:
            input2 = self.upsample(input2)
        elif self.level == 1:
            input1 = self.downsample(input1)

        level_1_weight_v = self.weight_level_1(input1)
        level_2_weight_v = self.weight_level_2(input2)

        levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = input1 * levels_weight[:, 0:1, :, :] + \
                            input2 * levels_weight[:, 1:2, :, :]

        out = self.conv(fused_out_reduced)

        return out


class ASFF_3(nn.Module):
    def __init__(self, inter_dim=512,level=0,channel=[64,128,256]):
        super(ASFF_3, self).__init__()

        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_3 = Conv(self.inter_dim, compress_c, 1, 1)

        self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)

        self.conv = Conv(self.inter_dim, self.inter_dim, 3, 1)


        self.level = level
        if self.level == 0:
            self.upsample4x = Upsample(channel[2],channel[0], scale_factor=4)
            self.upsample2x = Upsample(channel[1], channel[0], scale_factor=2)
        elif self.level == 1:
            self.upsample2x1 = Upsample(channel[2], channel[1], scale_factor=2)
            self.downsample2x1 = Downsample(channel[0],channel[1], scale_factor=2)
        elif self.level == 2:
            self.downsample2x = Downsample(channel[1], channel[2], scale_factor=2)
            self.downsample4x = Downsample(channel[0], channel[2], scale_factor=4)

    def forward(self, x):
        input1, input2, input3 = x
        if self.level == 0:
            input2 = self.upsample2x(input2)
            input3= self.upsample4x(input3)
        elif self.level == 1:
            input3 = self.upsample2x1(input3)
            input1 = self.downsample2x1(input1)
        elif self.level == 2:
            input1 = self.downsample4x(input1)
            input2 = self.downsample2x(input2)
        level_1_weight_v = self.weight_level_1(input1)
        level_2_weight_v = self.weight_level_2(input2)
        level_3_weight_v = self.weight_level_3(input3)

        levels_weight_v = torch.cat((level_1_weight_v, level_2_weight_v, level_3_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = input1 * levels_weight[:, 0:1, :, :] + \
                            input2 * levels_weight[:, 1:2, :, :] + \
                            input3 * levels_weight[:, 2:, :, :]

        out = self.conv(fused_out_reduced)

        return out

 在models/yolo.py中修改:

def parse_model(d, ch):  # model_dict, input_channels(3)
    # Parse a YOLOv5 model.yaml dictionary
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
                Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        elif m in {ASFF_2, ASFF_3}:
            c2 = args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)
            args[0] = c2
            args.append([ch[x] for x in f])
        else:
            c2 = ch[f]
        a = [*args]
        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)

 在yolo.py中配置--cfg为yolov5s-AFPN.yaml,点击运行,可见下图:

        论文中提到使用AFPN的效果要比PAN的好,暂时还没有验证,先肝代码,这是最初版,后续会优化。可以看最上面的图,参数确实是少了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/722698.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nginx安装与介绍

Nginx概述 Nginx (“engine x”) 是一个高性能的 HTTP 和反向代理服务器&#xff0c;特点是占有内存少&#xff0c;并发能力强&#xff0c;事实上 Nginx 的并发能力确实在同类型的网页服务器中表现较好&#xff0c;中国大陆使用 Nginx 网站用户有&#xff1a;百度、京东、新浪…

Windows11最新版官网制作系统盘

在百度里搜索"windows11官网下载"&#xff0c;然后选择微软官网的链接&#xff1a; 下載 Windows 11https://www.microsoft.com/zh-tw/software-download/windows11 然后就可以制作U盘windows11官网系统的安装U盘了。

进行APP广告变现之前,媒体需要关注哪些APP运营的信息指标

在进行广告变现之前&#xff0c;媒体商务或运营人员首先要知道自家 APP 的一些基本体量信息及基本用户使用情况信息。唯有充分而全面的掌握并罗列出这些基础 APP 运营指标&#xff0c;才能便于媒体通过自家真实流量规模、实力等来预估广告位价值&#xff0c;或更好的像广告需求…

设计模式--------创建型模式

创建型模式 用于描述“怎样创建对象”&#xff0c;它的主要特点是“将对象的创建与使用分离”。GoF&#xff08;四人组&#xff09;书中提供了单例、原型、工厂方法、抽象工厂、建造者等 5 种创建型模式。 1.单例设计模式 单例模式&#xff08;Singleton Pattern&#xff09…

如何截取视频中的一段视频?简单的截取方法分享

如果我们只需要处理视频中的某一部分&#xff0c;就可以将这一部分的内容截取下来&#xff0c;可以省去处理整个视频文件的时间和精力。此外&#xff0c;截取视频也可以让更加方便地分享和保存视频内容。此外&#xff0c;如果我们只需要分享视频中的一部分给他人观看&#xff0…

二、MongoDB 安装集

一、MongoDB—Docker mongoNoSQL Manager for MongoDB&#xff1a; L1、L2 1. 创建容器 docker search mongo docker pull mongodocker run -d --namemongo_1 -p 27017:27017 \-v /root/mongo/configdb:/data/configdb/ \-v /root/mongo/db/:/data/db/ \[镜像ID] --auth2. 登…

【软件测试】盘一盘工作中遇到的 Redis 异常测试

目录 前言&#xff1a; 一、更新 Key 异常 二、Key的删除和丢失 三、KEY 过期策略不当造成内存泄漏 四、查询Redis异常时处理 五、redis 穿透、击穿、雪崩 六、Redis死锁 七、Redis持久化 八、缓存与数据库双写时的数据一致性 前言&#xff1a; 在软件测试过程中&…

第八章、【Linux】文件与文件系统的压缩,打包与备份

8.1 压缩文件的用途与技术 8.2 Linux 系统常见的压缩指令 列几个常见的压缩文件扩展名&#xff1a; 8.2.1 gzip, zcat/zmore/zless/zgrep gzip 可以说是应用度最广的压缩指令了&#xff01;目前 gzip 可以解开 compress, zip 与 gzip 等软件所压缩的文件。 当你使用 gzip 进…

Error: 系统错误,错误码:80051,source size 2649KB exceed max limit 2MB [202306

小程序主包体积过大&#xff0c;预览到手机失败 把这个勾选一下&#xff0c;上限调整到4MB

Android性能优化问题方案的总结~

虽然总说“英雄不问出处”&#xff0c;但大厂卡学历是默认的“潜规则”。不过最近一个老弟&#xff0c;让我挺振奋的&#xff01;人家完全靠实力上岸。他就属于死磕型的&#xff0c;是我近2年见过的少有的Android性能优化高手。 要说他也挺聪明&#xff0c;贼会选领域。你出去随…

力扣 -- 740. 删除并获得点数

题目链接&#xff1a;740. 删除并获得点数 - 力扣&#xff08;LeetCode&#xff09; 下面是用动态规划的思想解决这道题的过程&#xff0c;相信各位小伙伴都能看懂并且掌握这道经典的动规题目滴。 参考代码&#xff1a; class Solution { public:int deleteAndEarn(vector<…

【MySQL】 MySQL安装

文章目录 1. MySQL安装配置内置环境关闭MySQL卸载MySQL确认环境是否干净 配置MySQL yum源yum源 的安装注意事项检测是否安装成功 MySQL的启动MySQL的登录登录方案一 获取临时密码登录方案二 免密码登录 MySQL的配置文件 1. MySQL安装 配置内置环境 输入 ps axj | grep mysql …

SpringBoot教学补充资料3-Maven安装

Maven下载地址&#xff1a;https://maven.apache.org/download.cgi 下载后进行解压&#xff0c;记住解压路径。 mvn -v

大文件下载

背景 google chrome下载大文件的时候&#xff0c;没有断点续传的功能&#xff0c;会因为网络不稳定多次下载失败。 google drive大文件 安装google drive客户端 点开别人的goole drive链接&#xff0c;保存在自己的文件夹下 本人的goole drive获取file_id # 链接格式 h…

mapbox icon-allow-overlap 和 icon-ignore-placement的区别

icon-allow-overlap 和 icon-ignore-placement的区别 官网解释&#xff1a;If true, the icon will be visible even if it collides with other previously drawn symbols. 翻译&#xff1a;如果该属性为true那么他会显示即便会冲突和在它之前已经添加的图层。 官网解释&am…

NSS [鹏城杯 2022]简单包含

NSS [鹏城杯 2022]简单包含 看代码觉得不就直接post flagdata://text/plain,<?php system(ls);?>行了。 但是事实上没有什么软用。 用了php://伪协议之后发现有WAF。 可以读源码 解码得到 <?php$path $_POST["flag"];if (strlen(file_get_contents(ph…

UE5 MetaHuman SDK插件的使用【一、编辑器创建音波与蓝图创建获取音波,语音与嘴唇口型的同步】

目录 打开插件 创建音频 编辑器这直接创建音频&#xff1a; 蓝图中创建和获取音频&#xff1a; 唇语&#xff1a; 声音与嘴唇同步&#xff1a; 方法一【效果不是很好】&#xff1a; 方法二【效果很好&#xff0c;但有一段时间延迟在处理】&#xff1a; 逻辑&#xff1…

Mysql数据库,Navicat上给表创建索引一直等待

问题背景&#xff1a; 对查询语句进行索引优化&#xff0c;针对以下表添加联合索引&#xff0c;语句如下&#xff1a; ALTER TABLE hzz_patrol_period_config add index IDX_PERIOD_CONFIG_YEAR_TYPE_VAL (EFFECT_YEAR,CHECK_CYCLE_TYPE,CHECK_CYCLE_VAL); Navicat上执行一直…

linux_driver_day02

作业1 题目&#xff1a; 映射 GPIOE、GPIOF、RCC 虚拟地址&#xff0c;实现控制三盏LED的驱动 效果&#xff1a; 代码&#xff1a; mycdev.h #ifndef _MYCDEV_H_ #define _MYCDEV_H_#include <linux/fs.h> #include <linux/init.h> #include <linux/io.h&g…

如何使用ChatGPT写好简历?如何使用ChatGPT优化简历?21个写简历的ChatGPT的Prompts!

你是一位求职者&#xff0c;即将要参加一场面试&#xff0c;你的工作经历是[2年国企会计经验]&#xff0c;教育背景是[国内211本科毕业&#xff0c;会计学专业]&#xff0c;请基于上述内容生成一份简历&#xff0c;要求加上自我评价。 根据这份工作描述写一份[TITLE]的简历。[…