目标检测tricks

news2025/2/26 15:54:18
A. Stochastic Weight Averaging (SWA)
1. 基本思想

SWA 的核心思想是通过对训练过程中不同时间点的模型参数进行加权平均,从而获得一个更好的模型。具体来说,SWA 在训练过程的后期阶段对多个不同的模型快照(snapshots)进行平均,而不是只使用最终的模型参数。

2. 为什么有效?

在深度学习训练中,尤其是在使用随机梯度下降(SGD)及其变体时,模型参数会在局部最优解附近波动。这些波动通常反映了损失函数的不同局部极小值或鞍点。通过平均这些波动中的参数,可以平滑这些波动,并找到一个更稳定的解决方案,从而提高模型的泛化能力。

SWA 的优点

提高泛化能力:通过平均多个模型的参数,可以减少单个模型可能存在的过拟合问题,从而提高模型的泛化能力。
增加稳定性:由于 SWA 平滑了参数的波动,使得最终模型更加稳定,减少了对初始条件和随机性的敏感性。
简单易用:相比于其他复杂的正则化方法(如 Dropout、DropConnect 等),SWA 实现起来非常简单,只需要在训练的后期阶段进行简单的参数平均即可。

SWA 的局限性

内存需求:如果需要存储多个模型快照,则可能会增加内存需求。
计算成本:在 SWA 阶段,虽然学习率较低,但仍然需要进行额外的前向和后向传播计算。
适用范围:SWA 主要适用于那些在训练后期参数波动较大的模型。对于一些已经非常稳定的模型,SWA 可能不会带来显著的改进。

SWA 的工作流程

1. 标准训练阶段
首先,模型按照标准的优化算法(如 SGD 或 Adam)进行训练。在这个阶段,模型参数会逐渐收敛到某个局部最优解。

for epoch in range(total_epochs):
    for batch_i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

2. SWA 启动阶段
当训练进入某一特定阶段(通常是训练的最后 25%),开始应用 SWA。此时,模型参数会被周期性地保存下来,并用于计算平均值。

swa_model = AveragedModel(model)
swa_start = int(0.75 * total_epochs)

for epoch in range(total_epochs):
    if epoch >= swa_start:
        swa_model.update_parameters(model)

3. SWA 调度器
为了更好地控制学习率,在 SWA 阶段通常使用一个固定的学习率调度器(如 SWALR)。这个调度器确保在 SWA 阶段学习率保持在一个较低且固定的值。

swa_scheduler = SWALR(optimizer, swa_lr=0.05)

4. BN 层更新
在 SWA 结束后,批归一化(Batch Normalization, BN)层的统计量需要更新。这是因为 BN 层依赖于训练数据的均值和方差统计量,而在 SWA 过程中这些统计量没有被更新。因此,需要通过重新遍历训练数据集来更新这些统计量。

torch.optim.swa_utils.update_bn(train_loader, swa_model, device='cuda')

运用在YoloV8中的代码

import torch
from torch.optim.swa_utils import AveragedModel, SWALR
from ultralytics import YOLO
from torch.utils.data import DataLoader
from torchvision import transforms
from yolov8_dataset import YOLOv8Dataset  # 假设你有一个自定义的数据集类

# 数据集和数据加载器
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = YOLOv8Dataset(root='path/to/dataset', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# 加载预训练的 YOLOv8 模型
model = YOLO('yolov8n.yaml')  # 根据需要选择合适的模型配置

# 初始化优化器和学习率调度器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# 初始化 SWA 模型和调度器
swa_model = AveragedModel(model.model)
swa_start = int(0.75 * 100)  # 在训练的最后 25% 开始 SWA
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

total_epochs = 100

for epoch in range(total_epochs):
    model.train()
    for batch_i, (imgs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        loss, outputs = model(imgs, targets)
        loss.backward()
        optimizer.step()

    if epoch >= swa_start:
        swa_model.update_parameters(model.model)
        swa_scheduler.step()
    else:
        lr_scheduler.step()

# 在训练结束后,应用 SWA 最终步骤
torch.optim.swa_utils.update_bn(train_loader, swa_model, device='cuda')

# 保存 SWA 模型
torch.save(swa_model.state_dict(), 'yolov8_swa.pth')

AveragedModel 类简化实现

import copy

class AveragedModel:
    def __init__(self, model):
        self.n_averaged = 0  # 记录已经累加的模型快照数量
        self.module = copy.deepcopy(model)  # 复制原始模型结构和参数

    def update_parameters(self, model):
        self.n_averaged += 1
        for p_swa, p_model in zip(self.module.parameters(), model.parameters()):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 1:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().mul_(1.0 - 1.0 / self.n_averaged).add_(p_model_, alpha=1.0 / self.n_averaged)
Stochastic Weight Averaging (SWA) 参数更新公式

在 Stochastic Weight Averaging (SWA) 中,参数更新的过程涉及到将当前模型的参数逐步累加到一个平均模型中。以下是具体的公式和解释。

第一次调用 update_parameters

当第一次调用 update_parameters 方法时,直接将 p_model 赋值给 p_swa

p swa = p model p_{\text{swa}} = p_{\text{model}} pswa=pmodel

这意味着 p_swa 直接被设置为当前模型的参数。

后续调用 update_parameters

从第二次调用开始,使用以下公式更新 p_swa

p swa = ( 1 − 1 n ) ⋅ p swa + 1 n ⋅ p model p_{\text{swa}} = \left(1 - \frac{1}{n}\right) \cdot p_{\text{swa}} + \frac{1}{n} \cdot p_{\text{model}} pswa=(1n1)pswa+n1pmodel

其中 $ n $ 是已经累加的模型快照数量。

第二次调用 update_parameters 的例子

n = 2 n = 2 n=2 时,公式变为:

p swa = 1 2 ⋅ p swa + 1 2 ⋅ p model p_{\text{swa}} = \frac{1}{2} \cdot p_{\text{swa}} + \frac{1}{2} \cdot p_{\text{model}} pswa=21pswa+21pmodel

这表示 p_swa 被更新为其当前值的一半加上 p_model 的一半。

更多调用的例子

随着训练的继续,每次调用 update_parameters 都会更新 p_swa,逐渐平滑模型参数的的波动。例如:

  • n = 5 n = 5 n=5 时:
    p swa = 4 5 ⋅ p swa + 1 5 ⋅ p model p_{\text{swa}} = \frac{4}{5} \cdot p_{\text{swa}} + \frac{1}{5} \cdot p_{\text{model}} pswa=54pswa+51pmodel

  • n = 6 n = 6 n=6 时:
    p swa = 5 6 ⋅ p swa + 1 6 ⋅ p model p_{\text{swa}} = \frac{5}{6} \cdot p_{\text{swa}} + \frac{1}{6} \cdot p_{\text{model}} pswa=65pswa+61pmodel

B. SAHI (Slice and Hyper Inference)
SAHI 的核心思想

图像切片:将大图像切分成多个小块(slices),以便更好地捕捉小目标。
超推理(Hyper Inference):对每个切片进行独立的推理,并合并结果以获得最终的检测结果。
自动切片选择:根据目标的大小和分布,自适应地选择合适的切片策略。

工作流程

以下是 SAHI 的典型工作流程:

加载图像:读取待检测的大图像。
图像切片:将图像切分成多个小块(slices),每个切片的大小和重叠度可以根据需求调整。
模型推理:对每个切片分别进行目标检测模型的推理。
结果合并:将所有切片的检测结果合并成一个完整的检测结果,并去除重复检测框。
后处理:对合并后的检测结果进行进一步处理,如非极大值抑制(NMS)等。

以Yolov5为例的插入sahi代码

from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.yolov5 import Yolov5TestConstants

# 加载预训练的 YOLOv5 模型
detection_model = AutoDetectionModel.from_pretrained(
    model_type='yolov5',
    model_path=Yolov5TestConstants.YOLOV5N_MODEL_LOCAL_PATH,  # 替换为你的模型路径
    confidence_threshold=0.3,
    device="cuda"  # 或 "cpu"
)

# 图像路径
image_path = "path/to/your/image.jpg"

# 使用 SAHI 进行小目标检测
result = get_sliced_prediction(
    image=image_path,
    detection_model=detection_model,
    # slice_height 和 slice_width:每个切片的高度和宽度。可以根据图像大小和目标的尺度进行调整
    slice_height=512,
    slice_width=512, 
    # 切片之间的重叠比例。增加重叠可以减少目标被切分的风险,但会增加计算量。
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2
)

# 打印检测结果
print(result)
C. SeNet (Squeeze-and-Excitation Networks) 通道注意力
SeNet 工作流程图解:

输入特征图

输入是一个三维张量,包含高度 H H H、宽度 W W W 和通道数 C C C

Input:  U ∈ R H × W × C \text{Input: } U \in \mathbb{R}^{H \times W \times C} Input: URH×W×C

Squeeze(挤压):全局平均池化 (Global Average Pooling, GAP)

对每个通道进行全局平均池化操作,将空间维度压缩为一个标量。输出是一个向量 z ∈ R C z \in \mathbb{R}^C zRC,表示每个通道的全局信息。

z c = 1 H × W ∑ i = 1 H ∑ j = 1 W U ( i , j , c ) z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} U(i,j,c) zc=H×W1i=1Hj=1WU(i,j,c)

其中, z c z_c zc 是第 c c c 个通道的全局平均值。

Excitation(激励):两个全连接层

使用两个全连接层来生成每个通道的权重。第一个全连接层将特征维度从 C C C 降到 C r \frac{C}{r} rC,其中 r r r 是降维比率。第二个全连接层将特征维度从 C r \frac{C}{r} rC 升回到 C C C。中间使用 ReLU 激活函数和 Sigmoid 激活函数分别处理两个全连接层的输出。

z ^ = ReLU ( W 1 ⋅ z ) s = σ ( W 2 ⋅ z ^ ) \begin{aligned} \hat{z} &= \text{ReLU}(W_1 \cdot z) \\ s &= \sigma(W_2 \cdot \hat{z}) \end{aligned} z^s=ReLU(W1z)=σ(W2z^)

其中, W 1 ∈ R C r × C W_1 \in \mathbb{R}^{\frac{C}{r} \times C} W1RrC×C W 2 ∈ R C × C r W_2 \in \mathbb{R}^{C \times \frac{C}{r}} W2RC×rC 分别是第一和第二全连接层的权重矩阵, σ \sigma σ 表示Sigmoid激活函数。

Reweight(重加权):逐元素相乘

将生成的通道权重 s ∈ R C s \in \mathbb{R}^C sRC 应用到原始特征图 U U U 上,通过逐元素相乘的方式重新加权特征图。输出是一个与输入特征图形状相同的张量 U ′ ∈ R H × W × C U' \in \mathbb{R}^{H \times W \times C} URH×W×C

U ′ ( i , j , c ) = s c × U ( i , j , c ) U'(i,j,c) = s_c \times U(i,j,c) U(i,j,c)=sc×U(i,j,c)

以上就是SENet的工作流程描述,包括了公式和步骤。

Spatial Attention Mechanism 空间注意力机制

SAM工作流程

1. 输入特征图

输入是一个三维张量,包含高度 H H H、宽度 W W W 和通道数 C C C

Input:  F ∈ R H × W × C \text{Input: } F \in \mathbb{R}^{H \times W \times C} Input: FRH×W×C

2. 生成注意力图

通过卷积层和其他操作生成注意力图 A A A

(1). 第一卷积层
Z 1 = C o n v 1 × 1 ( F ) Z_1 = Conv_{1 \times 1}(F) Z1=Conv1×1(F)

(2). 第二卷积层及激活函数
A = σ ( C o n v 1 × 1 ( Z 1 ) ) A = \sigma(Conv_{1 \times 1}(Z_1)) A=σ(Conv1×1(Z1))
其中 σ \sigma σ 是 Sigmoid 函数:

σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1

3. 归一化处理

对生成的注意力图进行归一化处理,使其值在合理范围内。

  • Sigmoid 归一化

    A = σ ( Z 1 ) = 1 1 + e − Z 1 A = \sigma(Z_1) = \frac{1}{1 + e^{-Z_1}} A=σ(Z1)=1+eZ11

  • Softmax 归一化(可选):

    A = Softmax ( Z 1 ) = e Z 1 ∑ e Z 1 A = \text{Softmax}(Z_1) = \frac{e^{Z_1}}{\sum e^{Z_1}} A=Softmax(Z1)=eZ1eZ1

4. 加权融合

将生成的注意力图与原始输入特征图逐元素相乘,从而实现对重要区域的强调和不重要区域的抑制。

F ′ = A ⊙ F F' = A \odot F F=AF

其中 ⊙ \odot 表示逐元素相乘操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2306445.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JNA基础使用,调用C++返回结构体

C端 test.h文件 #pragma oncestruct RespInfo {char* path;char* content;int statusCode; };extern "C" { DLL_EXPORT void readInfo(char* path, RespInfo* respInfo); }test.cpp文件 #include "test.h"void readInfo(char* path, RespInfo* respInfo…

解锁养生密码,拥抱健康生活

在快节奏的现代生活中,养生不再是一种选择,而是我们保持活力、提升生活质量的关键。它不是什么高深莫测的学问,而是一系列融入日常的简单习惯,每一个习惯都在为我们的健康加分。 早晨,当第一缕阳光洒进窗户&#xff0c…

OpenCV(6):图像边缘检测

图像边缘检测是计算机视觉和图像处理中的一项基本任务,它用于识别图像中亮度变化明显的区域,这些区域通常对应于物体的边界。是 OpenCV 中常用的边缘检测函数及其说明: 函数算法说明适用场景cv2.Canny()Canny 边缘检测多阶段算法,检测效果较…

spark的一些指令

一,复制和移动 1、复制文件 格式:cp 源文件 目标文件 示例:把file1.txt 复制一份得到file2.txt 。那么对应的命令就是:cp file1.txt file2.txt 2、复制目录 格式:cp -r 源文件 目标文件夹 示例:把目…

OpenHarmony全球化子系统

OpenHarmony全球化子系统 简介系统架构目录相关仓 简介 当OpenHarmony系统/应用在全球不同区域使用时,系统/应用需要满足不同市场用户关于语言、文化习俗的需求。全球化子系统提供支持多语言、多文化的能力,包括: 资源管理能力 根据设备类…

创建私人阿里云docker镜像仓库

步骤1、登录阿里云 阿里云创建私人镜像仓库地址:容器镜像服务 步骤2、创建个人实例 步骤:【实例列表】 》【创建个人实例】 》【设置Registry登录密码】 步骤3、创建命名空间 步骤:【个人实例】》【命名空间】》【创建命名空间】 注意&am…

【LLM】本地部署LLM大语言模型+可视化交互聊天,附常见本地部署硬件要求(以Ollama+OpenWebUI部署DeepSeekR1为例)

【LLM】本地部署LLM大语言模型可视化交互聊天,附常见本地部署硬件要求(以OllamaOpenWebUI部署DeepSeekR1为例) 文章目录 1、本地部署LLM(以Ollama为例)2、本地LLM交互界面(以OpenWebUI为例)3、本…

LLM之论文阅读——Context Size对RAG的影响

前言 RAG 系统已经在多个行业中得到广泛应用,尤其是在企业内部文档查询等场景中。尽管 RAG 系统的应用日益广泛,关于其最佳配置的研究却相对缺乏,特别是在上下文大小、基础 LLM 选择以及检索方法等方面。 论文原文: On the Influence of Co…

2025-02-25 学习记录--C/C++-用C语言实现删除字符串中的子串

用C语言实现删除字符串中的子串 在C语言中&#xff0c;你可以使用strstr函数来查找子串&#xff0c;然后用memmove或strcpy来覆盖或删除找到的子串。 一、举例 &#x1f430; #include <stdio.h> // 包含标准输入输出库&#xff0c;用于使用 printf 函数 #include <s…

【Linux】Ubuntu服务器的安装和配置管理

ℹ️大家好&#xff0c;我是练小杰&#xff0c;今天周二了&#xff0c;哪吒的票房已经到了138亿了&#xff0c;饺子导演好样的&#xff01;&#xff01;每个人的成功都不是必然的&#xff0c;坚信自己现在做的事是可以的&#xff01;&#xff01;&#x1f606; 本文是有关Ubunt…

2.3做logstash实验

收集apache日志输出到es 在真实服务器安装logstash&#xff0c;httpd systemctl start httpd echo 666 > /var/www/html/index.html cat /usr/local/logstash/vendor/bundle/jruby/2.3.0/gems/logstash-patterns-core-4.1.2/patterns/httpd #系统内置变量 cd /usr/local/…

pandas读取数据

pandas读取数据 导入需要的包 import pandas as pd import numpy as np import warnings import oswarnings.filterwarnings(ignore)读取纯文本文件 pd.read_csv 使用默认的标题行、逗号分隔符 import pandas as pd fpath "./datas/ml-latest-small/ratings.csv" 使…

ReentrantLock 用法与源码剖析笔记

&#x1f4d2; ReentrantLock 用法与源码剖析笔记 &#x1f680; 一、ReentrantLock 核心特性 &#x1f504; 可重入性&#xff1a;同一线程可重复获取锁&#xff08;最大递归次数为 Integer.MAX_VALUE&#xff09;&#x1f527; 公平性&#xff1a;支持公平锁&#xff08;按等…

java进阶专栏的学习指南

学习指南 java类和对象java内部类和常用类javaIO流 java类和对象 类和对象 java内部类和常用类 java内部类精讲Object类包装类的认识String类、BigDecimal类初探Date类、Calendar类、SimpleDateFormat类的认识java Random类、File类、System类初识 javaIO流 java IO流【…

架构思维:架构的演进之路

文章目录 引言为什么架构思维如此重要架构师的特点软件架构的知识体系如何提升架构思维大型互联网系统架构的演进之路一、大型互联网系统的特点二、系统处理能力提升的两种途径三、大型互联网系统架构演化过程四、总结 引言 在软件开发行业中&#xff0c;有很多技术人可能会问…

vue3:vue3项目安装并引入Element-plus

一、安装Element-plus 1、安装语句位置 安装 | Element Plushttps://element-plus.org/zh-CN/guide/installation.html根据所需进行安装&#xff0c;这里使用npm包 2、找到项目位置 找到项目位置&#xff0c;在路径上输入cmd回车打开“运行”窗口 输入安装语句回车完成安装 …

java.2.25

1. 注释 ​ 注释是对代码的解释和说明文字。 Java中的注释分为三种&#xff1a; 单行注释&#xff1a; // 这是单行注释文字多行注释&#xff1a; /* 这是多行注释文字 这是多行注释文字 这是多行注释文字 */ 注意&#xff1a;多行注释不能嵌套使用。文档注释&#xff1a;…

VScode 开发

目录 安装 VS Code 创建一个 Python 代码文件 安装 VS Code VSCode&#xff08;全称&#xff1a;Visual Studio Code&#xff09;是一款由微软开发且跨平台的免费源代码编辑器&#xff0c;VSCode 开发环境非常简单易用。 VSCode 安装也很简单&#xff0c;打开官网 Visual S…

A Large Recurrent Action Model: xLSTM Enables Fast Inference for Robotics Tasks

奥地利林茨约翰开普勒大学机器学习研究所 ELLIS 小组&#xff0c;LIT 人工智能实验室奥地利林茨 NXAI 有限公司谷歌 DeepMind米拉 - 魁北克人工智能研究所 摘要 近年来&#xff0c;强化学习&#xff08;Reinforcement Learning, RL&#xff09;领域出现了一种趋势&#xff0c;…