VMamba模型

news2025/4/17 2:43:14

VMamba模型

  • 摘要
  • Abstract
  • 1. VMamba模型
    • 1.1 文献摘要
    • 1.2 研究背景
    • 1.3 状态空间模型(SSM)
    • 1.4 VMamba架构
    • 1.5 实验
      • 1.5.1 ImageNet-1K 上的图像分类
      • 1.5.2 COCO 上的物体检测
  • 总结
  • 2. pytorch练习

摘要

本周阅读了 VMamba: Visual State Space ModelVMamba 这篇文献,VMamba是一种通用视觉主干,具有基于 SSM 的块,用于高效的视觉表示学习。 VMamba 在降低注意力计算复杂性方面的有效性很大程度上归功于 S6 模型中存在的选择性扫描机制,也称为选择性 SSM。与允许在上下文中进行密集信息路由的传统注意力计算方法不同,S6 要求一维数组(例如文本序列)中的每个元素仅通过压缩隐藏状态来获取上下文知识,从而将二次复杂度降低为线性复杂度,同时实验结果证明了 VMamba 在各种视觉感知任务中的良好性能,凸显了与现有基准模型相比,其在输入缩放效率方面的显着优势。本文将详细介绍 VMamba

Abstract

This week read the paper VMamba: Visual State Space ModelVMamba, a generalized visual backbone with SSM-based blocks for efficient visual representation learning. VMamba’s effectiveness in reducing the complexity of attentional computation is largely attributed to the selective scanning mechanism, also known as selective SSM, present in the S6 model.Unlike traditional attentional computation methods that allow for dense information routing in context, S6 requires that each element in a one-dimensional array (e.g., a sequence of text) acquires contextual knowledge by compressing the hidden state only, thereby reducing the quadratic complexity to linear complexity, while experimental results demonstrate the good performance of VMamba in a variety of visual perception tasks, highlighting its significant advantage in input scaling efficiency over existing benchmark models. In this paper, we present a detailed description of VMamba

1. VMamba模型

文献出处:VMamba: Visual State Space Model

1.1 文献摘要

CNN和VIT一直以来都是视觉领域的骨干网络,虽然 ViT 最近因其卓越的拟合能力而比 CNN 获得了突出地位,但其可扩展性在很大程度上受到注意力计算的二次复杂度的限制。

作者在本文提出了 VMamba,目的是为了将计算复杂度降低到线性,同时保留 ViT 的优势特征,同时也引入了交叉扫描模块(CSM),以实现具有全局感受野的 2D 图像空间中的 1D 选择性扫描。

实验结果证明了 VMamba 在各种视觉感知任务中的良好性能,凸显了与现有基准模型相比,其在输入缩放效率方面的显着优势。

1.2 研究背景

最近,状态空间模型(SSM)在自然语言处理(NLP)任务中展示了具有线性复杂性的长序列建模的巨大潜力。

作者提出了 VMamba,这是一种通用视觉主干,具有基于 SSM 的块,用于高效的视觉表示学习。 VMamba 在降低注意力计算复杂性方面的有效性很大程度上归功于 S6 模型中存在的选择性扫描机制,也称为选择性 SSM。与允许在上下文中进行密集信息路由的传统注意力计算方法不同,S6 要求一维数组(例如文本序列)中的每个元素仅通过压缩隐藏状态来获取上下文知识,从而将二次复杂度降低为线性复杂度。

然而,由于视觉数据的二维性质,单个扫描过程很难同时捕获不同方向上的依赖性信息,从而导致感受野受到限制。 我们将此问题称为“方向敏感”问题,并建议通过新引入的交叉扫描模块(CSM)来解决它。 CSM 不是以单向模式(列向或行向)遍历图像特征图的空间域,而是采用四向扫描策略,即从左上角和右下角开始遍历整个特征 映射到相反的位置(如下图)。 该策略确保特征图中的每个元素集成来自不同方向的所有其他位置的信息,从而在不增加计算复杂度的情况下实现全局感受野。
在这里插入图片描述

1.3 状态空间模型(SSM)

SSM 可以被视为线性时不变 (LTI) 系统,它通过隐藏状态 h(t) ε CN 将输入刺激 u(t) ε RL 映射到输出响应 y(t) ε RL。 它们通常被表述为线性常微分方程 (ODE)
在这里插入图片描述
离散化 状态空间模型(SSM)作为连续时间模型,在集成到深度学习算法中时面临着巨大的挑战。为了克服这个障碍,离散化过程势在必行。
在这里插入图片描述
作者首先使用 CSM(扫描扩展)扫描图像。然后通过 S6 块单独处理四个结果特征,并将四个输出特征合并(扫描合并)以构建最终的 2D 特征图。

通过 SS2D 模块传递数据涉及三个步骤:交叉扫描、使用 S6 块进行选择性扫描以及交叉合并。 给定输入数据,SS2D 首先沿着四个不同的遍历路径(即交叉扫描)将图像块展开为序列,使用单独的 S6 块并行处理每个块序列,然后重塑并合并结果序列以形成输出图 (即交叉合并)。 通过采用互补的遍历路径,SS2D使图像中的每个像素能够有效地整合来自不同方向的所有其他像素的信息,从而促进全局感受野的建立。

1.4 VMamba架构

VMamba-Tiny 的架构概述如下图所示。 VMamba 首先使用 Stem 模块将输入图像划分为图块,从而生成空间维度为 H 4 × W 4 \frac{H}{4} \times \frac{W}{4} 4H×4W 的 2D 特征图。
在这里插入图片描述
随后,多个网络阶段,每个阶段由 VSS 块组成,前面是下采样层(第一阶段除外),用于创建分辨率为 H 8 × W 8 \frac{H}{8} \times \frac{W}{8} 8H×8W H 16 × W 16 \frac{H}{16} \times \frac{W}{16} 16H×16W H 32 × W 32 \frac{H}{32} \times \frac{W}{32} 32H×32W。 下采样操作是通过补丁合并进行的,VSS块的详细结构如下图所示:
在这里插入图片描述
普通 VSS 块的结构如下图所示,这两个块都可以看作具有跳跃连接的残差网络。 残差网络包含两个分支:一个用于使用 3 × 3 深度卷积层进行特征提取,另一个由线性映射和激活层组成,激活层计算乘性门控信号。 Mamba 和普通 VSS 模块之间的主要区别在于用 SS2D 模块替换了 S6 模块,这使得选择性扫描能够适应 2D 视觉数据。
在这里插入图片描述
尽管在长序列建模方面效率很高,但基于 SSM 的架构 [14] 在处理较小规模的输入时经常会遇到计算速度降低的情况,这可能会限制 VMamba 的实际用途。

如下图所示,普通 VMamba-Tiny 模型实现了 426 个图像/秒的吞吐量,包含 22.9M 个参数和 5.6G FLOP(如果选择性扫描操作可以实现,FLOP 将降至 4.5G) 由单个 for 循环实现)。 低吞吐量和高内存开销给VMamba的实际部署带来了挑战。 因此,为了提高其推理速度,人们付出了巨大的努力,主要集中在实现细节和架构设计方面的进步。
在这里插入图片描述
从VMamba V0到V2,我们先后在torch.autograd.Function中实现了CSM,然后在Triton中重新实现了它。 这些修改有助于将吞吐量从 426 增加到 467。然后,在 V3 中,我们调整了与选择性扫描操作相关的 CUDA 实现,以适应 float16 输入张量并生成具有 float32 数据类型的输出张量。 与处理 float32 数据类型张量的实现相比,此调整提高了性能,特别是在训练期间,同时与对输入和输出张量使用 float16 相比,还实现了更高的数值稳定性。 此外,在 V4 和 V5 中,我们用线性变换(即 torch.nn.function.linear)替代了选择性扫描中相对较慢的 einsum 操作。 我们还采用了(B,C,H,W)的张量布局来消除不必要的数据排列。 这些变化导致吞吐量增加了 49.5%(从 426 增加到 637),并且不影响其他指标,例如参数数量、FLOP 和 ImageNet-1K 上的分类性能。

1.5 实验

1.5.1 ImageNet-1K 上的图像分类

我们使用 ImageNet-1K 数据集评估 VMamba 在图像分类方面的性能。 遵循[31]中概述的评估协议,VMamba-T/S/B模型从头开始训练300个epoch,前20个epoch专门用于预热,批量大小为1024。训练过程使用AdamW 优化器[34],贝塔设置为(0.9,0.999),动量为0.9,余弦衰减学习率调度器,初始学习率为1×10−3,权重衰减为0.05。 还应用了标签平滑 (0.1) 和指数移动平均 (EMA) 等其他技术。 除此之外,没有采用进一步的培训技术。

下表总结了 VMamba 与 ImageNet-1K 上基准骨干模型的比较结果。很明显,在相似的 FLOP 下,VMamba-T 的性能达到 82.5%,超过 RegNetY-4G 2.5%,超过 DeiT-S 2.5%。 2.7%,Swin-T 1.2%。 值得注意的是,VMamba 的这些性能优势在小型和基本规模模型中始终存在。 具体来说,VMamba-S 的 top-1 准确率达到 83.6%,比 RegNetY-8G 提高 1.9%,比 Swin-S 提高 0.6%。 同时,VMamba-B 的 top-1 准确率达到 83.9%,超过 RegNetY-16G 1.0%,超过 DeiT-B 0.6%。 在计算效率方面,虽然现有的基于 SSM 的视觉模型通常仅在大规模输入 [68](例如 1024 × 1024)下才表现出明显更好的吞吐量,但 VMamba-T 即使在输入分辨率为 224 × 224。这种性能更好,或者至少与最先进的方法相当,并且这种优势在 VMamba-S 和 VMamba-B 中仍然存在。 值得注意的是,随着输入大小从 224 × 224 扩展到 1024 × 1024,VMamba 相对于现有方法的优势变得更加明显,如表 4 所示。后续章节将对此主题进行进一步讨论。
在这里插入图片描述

1.5.2 COCO 上的物体检测

我们使用 MSCOCO 2017 数据集评估 VMamba 在对象检测方面的性能。 我们的训练框架是使用 MMDetection 库构建的,并且我们遵循 Swin中使用的超参数和 Mask-RCNN 检测器。 具体来说,我们采用 AdamW 优化器并对 12 和 36 epoch 的预训练分类模型(在 ImageNet-1K 上)进行微调。 VMamba-T/S/B 的丢弃路径率分别设置为 0.2%/0.3%/0.5%。 学习率初始化为 1×10−4,并在第 9 和 11 epoch 减少 10×。 我们实现了批量大小为 16 的多尺度训练和随机翻转,这与目标检测评估的既定实践一致。

VMamba 在 COCO 上的框/掩模平均精度 (AP) 方面保持优势,无论采用何种训练计划(12 或 36 epoch)。 具体来说,通过 12 epoch 的微调计划,VMamba-T/S/B 模型实现了 47.4%/48.7%/49.2% 的目标检测 mAP,超过了 Swin-T/S/B 4.7%/3.9%/2.3 % mAP 和 ConvNeXt-T/S/B 分别提高 3.2%/3.3%/2.2% mAP。 在相同配置下,VMambaT/S/B 的实例分割 mIoU 为 42.7%/43.7%/43.9%,比 Swin-T/S/B 高出 3.4%/2.8%/1.6% mIoU,而 ConvNeXt-T/S/ B 分别为 2.6%/1.9%/1.3% mIoU。 此外,VMamba 在多尺度训练的 36 epoch 微调方案下仍然具有优势,如表 2 所示。与 Swin [32]、ConvNeXt [33]、PVTv2 [55] 和 ViT 等同行相比 [12](使用适配器),VMamba-T/S 表现出卓越的性能,在对象检测上分别实现了 48.9%/49.9% mAP,在实例分割上分别实现了 43.7%/44.2% mIoU。 这些结果强调了 VMamba 在具有密集预测的下游任务中实现有希望的性能的潜力。
在这里插入图片描述

总结

本文介绍了 VMamba,这是一种多功能主干网络,专为使用状态空间模型 (SSM) 进行高效视觉表示学习而设计。 VMamba 的主要目标是将选择性 SSM 的优点(包括全局感受野、输入相关的加权参数和线性计算复杂性)融入视觉数据处理中。 具体来说,我们提出交叉扫描模块(CSM)来弥合一维选择性扫描和二维视觉数据之间的差距,并通过数学推导和定性可视化说明其与注意力机制的关系及其在实现全局感受野方面的有效性 。 此外,我们通过改进技术实现和架构设计,显着提高了 VMamba 的推理速度。 VMamba 系列(包括 VMamba-T/S/B 模型)的有效性已通过大量实验和消融研究得到证明,超越了流行的 CNN 和视觉 Transformer 的性能。 此外,VMamba 随着输入分辨率的提高而表现出卓越的可扩展性,在保持线性计算复杂性的同时表现出最小的性能下降。

下周我将具体通过pytorch实现这个网络架构,加油~

2. pytorch练习

数据集处理

import os
from shutil import copy, rmtree
import random


def mk_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在,则先删除原文件夹在重新创建
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    # 保证随机可复现
    random.seed(0)

    # 将数据集中10%的数据划分到验证集中
    split_rate = 0.1

    # 指向你解压后的flower_photos文件夹
    cwd = os.getcwd()
    data_root = os.path.join(cwd, "CUB_200_2011")
    origin_CUB_path = os.path.join(data_root, "images")
    assert os.path.exists(origin_CUB_path), "path '{}' does not exist.".format(origin_CUB_path)

    CUB_class = [cla for cla in os.listdir(origin_CUB_path)
                    if os.path.isdir(os.path.join(origin_CUB_path, cla))]

    # 建立保存训练集的文件夹
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in CUB_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(train_root, cla))

    # 建立保存验证集的文件夹
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in CUB_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in CUB_class:
        cla_path = os.path.join(origin_CUB_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # 随机采样验证集的索引
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # 将分配至验证集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # 将分配至训练集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()

参数设置

import argparse

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def get_args():
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                            help='models architecture: default: resnet18)') # arch是需要加载的预训练模型名
    parser.add_argument("--optimizer", default="SGD", type=str, help='["SGD", "Adam", "AdamW"]')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=120, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=16, type=int,
                        metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    # optimizer
    parser.add_argument('--lr', '--learning-rate', default=0.005, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')

    # center loss
    parser.add_argument('--parts', default=32, type=int,
                        metavar='N', help='number of parts (default: 32)')
    parser.add_argument('--alpha', default=0.95, type=float,
                        metavar='N', help='weight for BAP loss')

    # scheduler
    parser.add_argument('--decay-step', default=20, type=int, metavar='N',
                        help='learning rate decay step')
    parser.add_argument('--gamma', default=0.5, type=float, metavar='M',
                        help='gamma')
    parser.add_argument('-p', '--print-freq', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate models on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained models')
    # parser.add_argument('--world-size', default=-1, type=int,
    #                     help='number of nodes for distributed training')
    # parser.add_argument('--rank', default=-1, type=int,
    #                     help='node rank for distributed training')
    # parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
    #                     help='url used to set up distributed training')
    # parser.add_argument('--dist-backend', default='nccl', type=str,
    #                     help='distributed backend')
    parser.add_argument('--seed', default=1, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--gpu', default=1, type=int,
                        help='GPU id to use.')
    # parser.add_argument('--multiprocessing-distributed', action='store_true',
    #                     help='Use multi-processing distributed training to launch '
    #                          'N processes per node, which has N GPUs. This is the '
    #                          'fastest way to use PyTorch for either single node or '
    #                          'multi node data parallel training')
    # parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    # training
    parser.add_argument('--dataset', type=str, default='CUB',choices=['CUB','Cars','Aircraft'],
                        help='dataset for FGVC')
    parser.add_argument('--name', type=str, default='test_case')
    parser.add_argument('--lr_step', type=int, default=30)  # lr_step
    parser.add_argument('--resize-size', type=int, default=512, help='validation resize size')
    parser.add_argument('--crop-size', type=int, default=448, help='validation crop size')
    parser.add_argument('--VAL-CROP', type=str2bool, nargs='?', const=True, default=True,
                        help='Evaluation method'
                             'If True, Evaluate on 256x256 resized and center cropped 224x224 map'
                             'If False, Evaluate on directly 224x224 resized map')
    # CAM
    parser.add_argument('--cam-thr', type=float, default=0.2, help='cam threshold value(default=0.15)')

    # Random Erasing
    parser.add_argument('--p', default=0.5, type=float, help='Random Erasing probability')
    parser.add_argument('--sh', default=0.4, type=float, help='max erasing area')
    parser.add_argument('--r1', default=0.3, type=float, help='aspect of erasing area')


    args = parser.parse_args()
    return args


Res2Net模型


import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F
__all__ = ['Res2Net', 'res2net50']


model_urls = {
    'res2net50_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_4s-06e79181.pth',
    'res2net50_48w_2s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_48w_2s-afed724a.pth',
    'res2net50_14w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_14w_8s-6527dddc.pth',
    'res2net50_26w_6s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_6s-19041792.pth',
    'res2net50_26w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_8s-2c7c9f12.pth',
    'res2net101_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_26w_4s-02a759a1.pth',
}


class Bottle2neck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'):
        """ 构造函数
        参数:
            inplanes: 输入通道维度
            planes: 输出通道维度
            stride: 卷积步长。替代池化层。
            downsample: 当stride = 1时为None
            baseWidth: conv3x3的基本宽度
            scale: 尺度数量。
            type: 'normal': 正常设置。 'stage': 新阶段的第一个块。
        """
        super(Bottle2neck, self).__init__()

        # 计算卷积核的宽度
        width = int(math.floor(planes * (baseWidth / 64.0)))
        # 第一个1x1卷积层
        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)

        # 计算重复次数
        if scale == 1:
            self.nums = 1
        else:
            self.nums = scale - 1

        # 如果是新阶段的第一个块,则使用平均池化层进行下采样
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)

        # 定义重复的卷积层和BN层
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        # 创建了两个 nn.ModuleList 对象 self.convs 和 self.bns,用于存储多个卷积层和批量归一化层。
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        # 最后一个1x1卷积层
        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        # 激活函数
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x):
        residual = x

        # 第一个1x1卷积层的计算
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # 将输出按照宽度进行分割
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            # 如果是第一个块或者是新阶段的第一个块,则直接取分割后的部分
            if i == 0 or self.stype == 'stage':
                sp = spx[i]
            else:
                # 否则,累加之前的部分
                sp = sp + spx[i]
            # 对部分进行卷积、BN和ReLU操作
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                # 将处理后的部分拼接起来
                out = torch.cat((out, sp), 1)
        # 如果尺度不为1且为正常设置,将最后一个部分拼接到一起
        if self.scale != 1 and self.stype == 'normal':
            out = torch.cat((out, spx[self.nums]), 1)
        # 如果尺度不为1且为新阶段的第一个块,则对最后一个部分进行平均池化并拼接
        elif self.scale != 1 and self.stype == 'stage':
            out = torch.cat((out, self.pool(spx[self.nums])), 1)

        # 最后一个1x1卷积层的计算
        out = self.conv3(out)
        out = self.bn3(out)

        # 如果存在下采样,则对输入进行下采样
        if self.downsample is not None:
            residual = self.downsample(x)

        # 残差连接并进行ReLU激活
        out += residual
        out = self.relu(out)

        return out


class Res2Net(nn.Module):

    def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):
        # 初始化Res2Net模型
        self.inplanes = 64  # 设置输入通道数为64
        self.baseWidth = baseWidth
        self.scale = scale
        super(Res2Net, self).__init__()  # 调用父类的构造函数

        # 定义网络的第一层:7x7的卷积层,输入通道数为3,输出通道数为64,步长为2,填充为3
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Batch Normalization层,对每个channel的数据进行标准化
        self.bn1 = nn.BatchNorm2d(64)
        # 激活函数ReLU
        self.relu = nn.ReLU(inplace=True)
        # 最大池化层,窗口大小为3x3,步长为2,填充为1
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 定义4个Res2Net的阶段(stage)
        self.layer1 = self._make_layer(block, 64, layers[0])  # 第一个阶段,输出通道数为64
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  # 第二个阶段,输出通道数为128,步长为2
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  # 第三个阶段,输出通道数为256,步长为2
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # 第四个阶段,输出通道数为512,步长为2

        # 全局平均池化层,将每个通道的特征图变成一个数
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # 全连接层,将512维的特征向量映射到num_classes维的向量,用于分类
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # 初始化网络参数
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # 使用kaiming正态分布初始化卷积层参数
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                # 将Batch Normalization层的权重初始化为1,偏置初始化为0
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        # 构建Res2Net的一个阶段(stage),包含多个block
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 如果输入输出通道数不一致,或者步长不为1,需要添加下采样层
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # 构建阶段的每个block
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                            stype='stage', baseWidth=self.baseWidth, scale=self.scale))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale))

        return nn.Sequential(*layers)

    def forward(self, x):
        # 定义前向传播过程
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def res2net50(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 model.
    Res2Net-50 refers to the Res2Net-50_26w_4s.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s']))
    return model

def res2net50_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s']))
    return model

def res2net101_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_26w_4s']))
    return model

def res2net50_26w_6s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 6, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_6s']))
    return model

def res2net50_26w_8s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 8, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_8s']))
    return model

def res2net50_48w_2s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_48w_2s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 48, scale = 2, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_48w_2s']))
    return model

def res2net50_14w_8s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50_14w_8s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 14, scale = 8, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net50_14w_8s']))
    return model



if __name__ == '__main__':
    images = torch.rand(1, 3, 224, 224).cuda(0)
    model = res2net50_48w_2s(pretrained=False)
    model = model.cuda(0)
    print(model(images).size())
    print(model)

训练代码

# coding:utf-8 允许中文注释
import numpy as np
import os

import torchvision

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from option import get_args
from model import resnet50
from util import AverageMeter, accuracy, save_checkpoint, load_model_checkpoint
from res2net import res2net50_48w_2s

def init_seeds(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if seed == 0:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


best_acc1 = 0.
def repeat_channels(x):
    # 这个函数将输入的 PIL 图像 x 复制到三个通道,模拟 RGB 图像
    return x.repeat(3, 1, 1)

def main():
    print("Start...")
    global best_acc1
    args = get_args()
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    init_seeds(seed=0) # set random seed

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))  # 训练所用的GUP ID

    # directory for save
    args.log_folder = os.path.join('log', 'res2net50_48w_2s')
    if not os.path.exists(args.log_folder):
        os.makedirs(args.log_folder)



    if args.dataset == "CUB" and args.arch == "resnet50":
        channels = 2048
        num_classes = 200
        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, '/data/tgf/resnet/Data/')
        # train_dir = '/data/tgf/resnet/Data/trian'
        # valid_dir = '/data/tgf/resnet/Data/test'
    elif args.dataset == 'Cars' and args.arch == "resnet50":
        channels = 2048
        num_classes = 196
        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, '/tgf/resnet/CUB_200_2011/dataset')
        # train_dir = '/learn_pytorch/resnet/Data/trian'
        # valid_dir = '/learn_pytorch/resnet/Data/test'
    elif args.dataset == "Aircraft" and args.arch == "resnet50":
        channels = 2048
        num_classes = 100
        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, '/data/tgf/resnet/Data')
        # train_dir = '/learn_pytorch/resnet/Data/trian'
        # valid_dir = '/learn_pytorch/resnet/Data/test'
    else:
        raise Exception("No dataset named {}".format(args.dataset))

    # Model
    print("=> creating model '{}'".format(args.arch))
    print("num_classes ", num_classes)
    model = res2net50_48w_2s(pretrained=True)
    # model_weight_path = "./resnet50_pre.pth"
    # assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    # model.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # change fc layer structure
    in_channel = model.fc.in_features
    model.fc = nn.Linear(in_channel, num_classes)
    model = model.cuda()

    cudnn.benchmark = True

    # Loading training/validation dataset
    train_transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.RandomCrop((448, 448)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Lambda(repeat_channels),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    test_transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop((448, 448)),  # RandomCrop for train and CenterCrop for test
        transforms.ToTensor(),
        # transforms.Lambda(repeat_channels),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])



    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=train_transform)
    print("train_dataset为:",train_dataset)
    valid_dataset = datasets.ImageFolder(root=os.path.join(image_path, "test"), transform=test_transform)
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.workers, pin_memory=True)
    # print("train_loader为:",train_loader)
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=args.batch_size,
                              shuffle=False, num_workers=args.workers, pin_memory=True)

    print("using {} images for training, {} images for validation.".format(len(train_dataset), len(valid_dataset)))

    # define loss function (criterion), optimizer, and learning rate scheduler
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, nesterov=True, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=args.gamma)

    # optionally resume from a checkpoint
    if args.resume:
        model, optimizer = load_model_checkpoint(model, optimizer, args)

    def train(train_loader, model, criterion, optimizer, epoch, args):
        # AverageMeter for Performance
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Switch to train mode
        model.train()

        # lr = next(iter(optimizer.param_groups))['lr']
        train_bar = tqdm(train_loader)  # 训练集进度条
        for batch_idx, (inputs, targets) in enumerate(train_bar):
            idx = batch_idx
            inputs, targets = Variable(inputs).cuda(), Variable(targets).cuda()
            # inputs, targets = Variable(inputs), Variable(targets)

            # compute output
            outputs = model(inputs)  # 前向传播
            loss = criterion(outputs, targets)

            # # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()  # !!
            optimizer.step()

            # print info
            description = "[Train:{0:3d}/{1:3d}] Top1-cls: {2:6.2f}, Top5-cls: {3:6.2f}, Loss: {4:7.4f},". \
                format(epoch + 1, args.epochs, top1.avg, top5.avg, losses.avg)
            train_bar.set_description(desc=description)

        return top1.avg, losses.avg

    best_acc_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        lr = next(iter(optimizer.param_groups))['lr']
        # ————————————————Train————————————————#
        train_acc1, train_losses = train(train_loader, model, criterion, optimizer, epoch, args)
        scheduler.step()  # 放到每个epoch训练完之后

        # tensorboard
        with SummaryWriter(log_dir=os.path.join(args.log_folder, 'no_seed/train'), comment='train') as writer:
            writer.add_scalar('Train/learning_rate', lr, epoch)
            writer.add_scalar('Train/train_acc1', train_acc1, epoch)
            writer.add_scalar('Train/train_loss', train_losses, epoch)
            writer.flush()
            writer.close()

        # ————————————————Test————————————————#
        val_acc1, val_losses = validate(valid_loader, model, criterion, epoch, args)  # Test!!!

        # tensorboard
        with SummaryWriter(log_dir=os.path.join(args.log_folder, 'no_seed/val'), comment='test') as writer:
            writer.add_scalar('Test/val_acc1', val_acc1, epoch)
            writer.add_scalar('Test/val_loss', val_losses, epoch)
            writer.flush()
            writer.close()

        is_best = val_acc1 > best_acc1  # True / False
        best_acc1 = max(val_acc1, best_acc1)
        # save_checkpoint({
        #     'epoch': epoch + 1,
        #     'arch': args.arch,
        #     'state_dict': model.state_dict(),
        #     'best_acc1': best_acc1,
        #     'optimizer': optimizer.state_dict(),
        #     # 'scheduler': scheduler.state_dict()
        # }, is_best, args.log_folder)

        if is_best:
            best_acc_epoch = epoch + 1
            savepath = "/data/tgf/resnet/log/resnet50_in_CUB/best.pth"
            torch.save(model, savepath)

        print("Until %d epochs, Best Acc@1 %.3f in the %d-th epoch" % (epoch + 1, best_acc1, best_acc_epoch))

    with open(os.path.join(args.log_folder, 'result.txt'), 'w') as file:
        file.write("best_acc1 {}".format(best_acc1))
    file.close()





def validate(val_loader, model, criterion, epoch, args):
    # AverageMeter for Performance
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    # DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        val_bar = tqdm(val_loader)
        for batch_idx, (inputs, targets) in enumerate(val_bar):
            idx = batch_idx
            inputs, targets = Variable(inputs).cuda(), Variable(targets).cuda()

            # Compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))

            # print info
            description = "[Valid:{0:3d}/{1:3d}] Top1-cls: {2:6.2f}, Top5-cls: {3:6.2f}, Loss: {4:7.4f}, ". \
                format(epoch + 1, args.epochs, top1.avg, top5.avg, losses.avg)
            val_bar.set_description(desc=description)

    return top1.avg, losses.avg


if __name__ == '__main__':
    main()

实验结果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1670912.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux内核下RAS(Reliability, Availability and Serviceability)功能分析记录

1 简介 Reliability, Availability and Serviceability (RAS) — The Linux Kernel documentation 在服务器 和 卫星等领域,对设备的稳定性要求很高,需要及时的发现并处理软/硬件上的错误。RAS功能可以用来及时的发现硬件上的错误。 RAS功能需要硬件的…

网页版五子棋的自动化测试

目录 前言 一、主要技术 二、测试环境的准备部署 三、测试用例 四、执行测试 4.1、公共类设计 创建浏览器驱动对象 测试套件 释放驱动类 4.2、功能测试 登录页面 注册页面 游戏大厅页面 游戏房间页面 测试套件结果 4.3、界面测试 登录页面 注册页面 游戏大…

5.13网络编程

只要在一个电脑中的两个进程之间可以通过网络进行通信那么拥有公网ip的两个计算机的通信是一样的。但是一个局域网中的两台电脑上的虚拟机是不能进行通信的,因为这两个虚拟机在电脑中又有各自的局域网所以通信很难实现。 socket套接字是一种用于网络间进行通信的方…

基于微信小程序+JAVA Springboot 实现的【网上商城小程序】app+后台管理系统 (内附设计LW + PPT+ 源码+ 演示视频 下载)

项目名称 项目名称: 基于微信小程序的网上商城 项目技术栈 该项目采用了以下核心技术栈: 后端框架/库: Java, SSM框架数据库: MySQL前端技术: 微信开发者工具,微信小程序框架 项目展示 5.1 管理员服务…

链表+环-链表是否有环的判断

链表是否有环的判断 在数据结构中,链表是一种常见的数据结构,它允许我们在不需要预先知道数据总量的情况下进行数据的动态存储。然而,由于链表的特性,有时我们可能会遇到链表中出现环的情况,即链表的某个节点指向了链…

每日互动(个推)与您相约2024 AI+研发数字峰会(AiDD)上海站

伴随着人工智能在众多行业领域的广泛应用及其带来的颠覆性变革,软件的开发模式、方式和实践也将发生巨大的变化。 5月17-18日,2024 AI研发数字峰会(AiDD)上海站即将重磅开幕。峰会设置了15个主题论坛,策划60精彩议题内…

道格拉斯普克算法(DP)的点云轮廓线简化

1、背景介绍 由于点云无法精确刻画目标对象边缘信息,因此常规提取的边缘点直接相连所生成的轮廓线,锯齿现象显著,与真实情况相差甚远(图b所示)。 道格拉斯-普克(Douglas-Peuker)抽稀算法是用来对…

java 使用hh或者HH异常

故障描述 使用了HH或者hh使用时间format、DatetimeFormat注解时序列化失败 故障原因 当使用hh的时候,小时只能是1-24 使用KK的时候,小时只能是0-23 比如:凌晨0:30,使用hh就是0:30 am, kk就是12:30 24小时制的话,使…

01-02-4

1、中级阶段-day1作业 使用的代码 #include<stdio.h> typedef struct Student {int num;char name[20];char sex; }Stu; int main() {Stu s;scanf("%d%s %c", &s.num, s.name, &s.sex);//读取字符串时&#xff0c;scanf()的占位符用%s即可&#xff0c…

重大升级 | OpenSCA SaaS全面接入供应链安全情报!

结合社区用户反馈及研发小伙伴的积极探索&#xff0c; OpenSCA 项目组再次发力&#xff0c;SaaS版本重大升级啦&#xff01; 用户的需求是OpenSCA前进的动力&#xff0c;欢迎更多感兴趣的朋友们积极试用和反馈~ 更 新 内 容 1.全面接入云脉XSBOM供应链安全情报 2.强大的资产…

【异常】SpringBoot整合RabbitMQ-发送消息报错

错误信息 reply-code406, reply-textPRECONDITION_FAILED - inequivalent arg ‘x-message-ttl’ for queue ‘hello-queue’ in vhost ‘/lq’: received none but current is the value ‘10000’ of type ‘signedint’, class-id50, method-id10 错误原因 hello-queue这…

省公派访学|社科老师赴世界名校牛津大学开展研究

F老师已获某省公派出国访学半年的资助&#xff0c;希望落实的学校尽量知名。但因为F老师只是硕士毕业而无博士学位&#xff0c;专业方向又是社科类&#xff0c;所以申请到世界知名高校有一定难度。经过努力&#xff0c;最终我们获得了世界顶尖高校-英国牛津大学的访问学者邀请函…

C++常见十种排序方式

目录 前言 1、选择排序 介绍 参考代码 2、冒泡排序 介绍 参考代码 3、插入排序 介绍 参考代码 4、希尔排序 介绍 参考代码 5、快速排序 介绍 参考代码 6、并归排序 介绍 参考代码 7、堆排序 介绍 参考代码 8、基数排序 介绍 参考代码 9、计数排序 介绍 参考代…

用户需求甄别和筛选的6大标准

产品经理日常经常接收到大量的需求&#xff0c;并不是所有的需求都需要开发&#xff0c;需要进行甄别和筛选&#xff0c;这样有利于确保项目的成功、优化资源利用以及提高产品质量。 那么针对这些用户需求进行甄别或筛选的评判标准是什么&#xff1f;需求筛选可以说是初步的需求…

设计模式-工厂模式设计与详解

一、设计模式介绍 设计模式是我们开发中常常需要面对的核心概念&#xff0c;它们是解决特定问题的模板或者说是经验的总结。这些模式被设计出来是为了让软件设计更加清晰、代码更加可维护且能应对未来的变化。良好的设计模式不仅能解决重复代码的问题&#xff0c;还能使团队中…

关于修改ant-design-vue的table组件背景色遇到闪动的问题

项目中需要修改表格的背景色为以下样式 修改完之后发现表格行还有个hover的背景色&#xff0c;于是再次重置样式如下 .ant-table-tbody > tr {&:hover {td {// background: red !important;background: transparent !important;}}}这样重置之后&#xff0c;hover的样式…

【中级软件设计师】上午题16-算法(应试考试简略版)

上午题16-算法 1 回溯法1.1 n皇后问题 2 分治法3 动态规划3.1 0-1背包问题3.2 最长公共子序列3.3 矩阵连乘 4 贪心算法5 分支限界法总结 1 回溯法 深度优先方法搜索 1.1 n皇后问题 2 分治法 一般来说&#xff0c;分治算法在每一层递归上都有3个步骤 &#xff08;1&#xff…

【C++】详解STL的适配器容器之一:优先级队列 priority_queue

目录 堆算法 概述 向下调整建堆 向上调整建堆 建堆算法 仿函数 概述 使用介绍 emtpy size top push pop 模拟实现 仿函数 框架 向下调整算法 向上调整算法 pop push empty top 要理解优先级队列&#xff0c;需要有如下知识 STL容器之一的vector&#xf…

嵌入式:基于STM32的RFID访问控制系统

在商业和住宅建筑中&#xff0c;访问控制系统是确保安全的关键组件。使用射频识别&#xff08;RFID&#xff09;技术&#xff0c;我们可以创建一个安全、方便的门禁系统。本教程将详细说明如何使用STM32微控制器实现RFID基础的门禁系统&#xff0c;该系统能够控制电子锁并记录访…

品鉴中的品鉴笔记:如何记录和分享自己的品鉴心得

品鉴云仓酒庄雷盛红酒的过程&#xff0c;不仅是品尝美酒&#xff0c;更是一次与葡萄酒深度对话的旅程。为了更好地记录和分享自己的品鉴心得&#xff0c;养成写品鉴笔记的习惯是十分必要的。 首先&#xff0c;选择一个适合的记录工具。可以是传统的笔记本&#xff0c;也可以是…