DenseNet分类网络改进--亲测有效

news2025/2/25 18:17:26

最近,我在做多分类问题。在针对基模型的选择中,我使用了DenseNet作为基本模型。我在阅读论文时,遇到了一种改进方式:
在这里插入图片描述

如上图所示,在全连接层之前引入SPP模块:
在这里插入图片描述
代码如下:

SPP模块代码:

class SpatialPyramidPooling(nn.Module):
    def __init__(self, pool_sizes: List[int], in_channels: int):
        super(SpatialPyramidPooling, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.pool_layers = nn.ModuleList([
            nn.AdaptiveMaxPool2d(output_size=(size, size)) for size in pool_sizes
        ])

    def forward(self, x: Tensor) -> Tensor:
        pools = [pool_layer(x) for pool_layer in self.pool_layers]

        # Resize the output of each pool to have the same number of channels
        pools_resized = [F.adaptive_max_pool2d(pool, (1, 1)) for pool in pools]

        spp_out = torch.cat(pools_resized, dim=1)  # Concatenate the resized pools
        return spp_out

加入SPP代码后的DenseNet网络完整如下:

import re
from typing import List, Tuple, Any
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor

class _DenseLayer(nn.Module):
    def __init__(self, input_c: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False):
        super(_DenseLayer, self).__init__()

        self.add_module("norm1", nn.BatchNorm2d(input_c))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", nn.Conv2d(in_channels=input_c, out_channels=bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        concat_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concat_features)))
        return bottleneck_output

    @staticmethod
    def any_requires_grad(inputs: List[Tensor]) -> bool:
        for tensor in inputs:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused
    def call_checkpoint_bottleneck(self, inputs: List[Tensor]) -> Tensor:
        def closure(*inp):
            return self.bn_function(inp)

        return cp.checkpoint(closure, *inputs)

    def forward(self, inputs: Tensor) -> Tensor:
        if isinstance(inputs, Tensor):
            prev_features = [inputs]
        else:
            prev_features = inputs

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("memory efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output_with_cbam)))

        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)

        return new_features

class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(self, num_layers: int, input_c: int, bn_size: int, growth_rate: int, drop_rate: float,
                 memory_efficient: bool = False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(input_c + i * growth_rate,
                                growth_rate=growth_rate,
                                bn_size=bn_size,
                                drop_rate=drop_rate,
                                memory_efficient=memory_efficient)
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        features = [init_features]
        for _, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)



class _Transition(nn.Sequential):
    def __init__(self, input_c: int, output_c: int):
        super(_Transition, self).__init__()
        self.add_module("norm", nn.BatchNorm2d(input_c))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(input_c, output_c, kernel_size=1, stride=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))

class SpatialPyramidPooling(nn.Module):
    def __init__(self, pool_sizes: List[int], in_channels: int):
        super(SpatialPyramidPooling, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.pool_layers = nn.ModuleList([
            nn.AdaptiveMaxPool2d(output_size=(size, size)) for size in pool_sizes
        ])

    def forward(self, x: Tensor) -> Tensor:
        pools = [pool_layer(x) for pool_layer in self.pool_layers]

        # Resize the output of each pool to have the same number of channels
        pools_resized = [F.adaptive_max_pool2d(pool, (1, 1)) for pool in pools]

        spp_out = torch.cat(pools_resized, dim=1)  # Concatenate the resized pools
        return spp_out

class DenseNet(nn.Module):
    def __init__(self, growth_rate: int = 32, block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
                 num_init_features: int = 64, bn_size: int = 4, drop_rate: float = 0, num_classes: int = 1000,
                 memory_efficient: bool = False):
        super(DenseNet, self).__init__()

        # First conv+bn+relu+pool
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each dense block
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                input_c=num_features,
                                bn_size=bn_size,
                                growth_rate=growth_rate,
                                drop_rate=drop_rate,
                                memory_efficient=memory_efficient)
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features = num_features + num_layers * growth_rate

            if i != len(block_config) - 1:
                trans = _Transition(input_c=num_features,
                                    output_c=num_features // 2)
                self.features.add_module("transition%d" % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))

        # Spatial Pyramid Pooling (SPP) layer
        spp_pool_sizes = [1, 4, 16]  # You can adjust pool sizes as needed
        self.spp = SpatialPyramidPooling(spp_pool_sizes, in_channels=num_features)

        # FC layer
        self.classifier = nn.Linear(num_features + len(spp_pool_sizes) * num_features, num_classes)

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
    def forward(self, x: Tensor) -> Tensor:
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # Apply Spatial Pyramid Pooling
        spp_out = self.spp(out)
        # Adjust the number of channels in out to match spp_out
        out = F.adaptive_avg_pool2d(out, (1, 1))
        # Concatenate the original feature map with the SPP output along the channel dimension
        out = torch.cat([spp_out, out], dim=1)
        # Flatten the spatial dimensions of out
        out = torch.flatten(out, 1)
        # FC layer
        out = self.classifier(out)
        return out


def densenet121(**kwargs: Any) -> DenseNet:
    # Top-1 error: 25.35%
    # 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth'
    return DenseNet(growth_rate=32,
                    block_config=(6, 12, 24, 16),
                    num_init_features=64,
                    **kwargs)
def load_state_dict(model: nn.Module, weights_path: str) -> None:
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = torch.load(weights_path)

    num_classes = model.classifier.out_features
    load_fc = num_classes == 1000

    for key in list(state_dict.keys()):
        if load_fc is False:
            if "classifier" in key:
                del state_dict[key]

        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict, strict=load_fc)
    print("successfully load pretrain-weights.")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1285617.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

系统运维工具KSysAK——让运维回归简单

系统运维工具KSysAK——让运维回归简单 1.基本信息 1.1概述 系统异常定位分析工具KSysAK是云峦操作系统研发及运维人员总结开发及运维经验,设计和研发的多个运维工具的集合,可以覆盖系统的日常监控、线上问题诊断和系统故障修复等常见运维场景。 工具…

Java---异常

文章目录 1. 异常概述2. try...catch3. Throwable成员方法4. 编译时异常和运行时异常区别5. 异常处理之throws6. 自定义异常7. throws和throw的区别 1. 异常概述 1. 异常:就是程序中出现了不正常的情况。 2. Error:严重问题,不需要处理。Exce…

揭秘MQTT:为何它是物联网的首选协议?

文章目录 MQTT 协议简介概览MQTT 与其他协议对比MQTT vs HTTPMQTT vs XMPP 为什么 MQTT 是适用于物联网的最佳协议?轻量高效,节省带宽可靠的消息传递海量连接支持安全的双向通信在线状态感知 MQTT 5.0 与 3.1.1MQTT 服务器MQTT 客户端 MQTT 协议简介 概…

亿胜盈科AT8236 刷式直流电机驱动器

AT8236是一款刷式直流电机驱动器,能够以高达6A的峰值电流双向控制电机。利用电流衰减模式,可通过对输入信号进行脉宽调制(PWM) 来控制电机转速,同时具备低功耗休眠模式。 AT8236集成同步整流功能,可显著降低系统功耗要求。内部保…

【鸿蒙应用开发】开发环境搭建及IDE安装使用

1.下载安装包 安装包下载地址: 点击跳转下载页面 可以根据自己的操作系统选择对应版本下载。 本文以Windows安装为例,Mac安装方式相同 2. 安装 下载好后,打开安装包,进入安装界面: 点击Next,进入安…

傻傻分不清楚的分区、分库、分表

一、分区 MySQL 在 5.1 时添加了对 分区(即水平分区) 的支持。MySQL 的物理数据存储在表空间文件(.ibdata1和.ibd)中,分区 的意思是指将同一表中不同行的记录分配到不同的物理文件中。有几个分区就有几个 .idb 文件。…

CPP-SCNUOJ-Problem P23. 计数排序(使用C/C++)

Problem P23. 计数排序(使用C/C) 计下标从 1 开始。有n 个取值范围在 [1,m] 的整数ai 。请将它们升序排序,设排序后数组为b 。为避免输出过长,请输出: 输入 输出 输出一个整数代表计算结果 样例 标准输入 10 3 1 …

智能液压传动综合实验台比例阀放大器

智能型液压传动实验台具有开发测试分析系统,通过对流量、压力、功率、转速、扭矩、位移、时间、温度--计算机人机画面 -- 计算机智能数据采集、分析、处理、--自动生产报表、曲线等一系列智能化动作后,完成各类常规的液压回路、马达、各类阀泵的动静态测试等实验.通…

C语言枚举详解,typedef简介(能看懂文字就能明白系列)

系列文章目录 C语言基础专栏 笔记详解 🌟 个人主页:古德猫宁- 🌈 信念如阳光,照亮前行的每一步 文章目录 系列文章目录🌈 *信念如阳光,照亮前行的每一步* 前言一、枚举类型的声明枚举常量三、枚举类型的优…

情怀零食店溢价严重,网友:情怀就是智商税,贵可以不买!

小时候的零食,是每个人心中无法抹去的甜蜜记忆。在广东,那些5毛钱的零食更是让无数人回味无穷。但近年来,这些情怀零食店的价格乱象却让不少人大呼“离谱”。 有市民反映,在一家主打怀旧主题的零食店内,三样商品竟然要…

ros2+UBUNTU读取STM32发送过来的数据(C++)

ATTENTION:一般ros2上位机访问STM32不是使用串口,即使树莓派有串口,我也不会用的,因为那还要去学习其他的语言,一般就是ros2---------ubs转串口-------STM32串口。 这个USB转串口,我们已经安装了CH340驱动了&#xff…

mall电商项目(学习记录2)

运行mall-admin Java项目 需要安装Redis,需要安装mysql,同时需要运行其项目提供的mall.sql 运行mall-admin后端程序 安装完Redis、mysql、HeidiSQL(用于执行mall.sql,界面化操作高效直观)、IntelliJ IDEA 运行mall-…

写给初学者的 HarmonyOS 教程 -- 页面路由(router)

页面路由(router)是指在应用程序中实现不同页面之间的跳转和数据传递。 HarmonyOS 提供了 Router 模块,通过不同的 url 地址,可以方便地进行页面路由,轻松地访问不同的页面。 类似这样的效果: 页面跳转是…

MISRA C 2012 标准浅析

MISRA(The Motor Industry Software Reliability Association),汽车工业软件可靠性联会; 1994年,英国成立。致力于协助汽车厂商开发安全可靠的软件的跨国协会,其成员包括:AB汽车电子、罗孚汽车、宾利汽车、福特汽车、捷…

从0到1的跨境电商创业经验分享!个人如何做跨境电商创业?

近年来,跨境电商成为了一种非常流行的创业方式,都知道国内贸易不好做,许多卖家都想通过跨境电商创业,但他们不知道具体的过程,今天龙哥我就分享一下我自己在跨境电商创业总结出来的经验,帮助你在跨境电商领…

Apollo新版本Beta自动驾驶技术沙龙参会体验有感—百度自动驾驶开源框架

在繁忙的都市生活中,我们时常对未来的科技发展充满了好奇和期待。而近日,我有幸参加了一场引领科技潮流的线下技术沙龙,主题便是探索自动驾驶的魅力——一个让我们身临其境感受创新、了解技术巨擘的机会。 在12月2日我有幸参加了Apollo新版本…

基于Linux的网络防火墙设计方法

摘要 随着Internet的迅速发展,网络越来越成为了人们日常生活不可或缺的一部分,而随之引出的网络安全问题也越来越突出,成为人们不得不关注的问题。 为了在一个不安全的网际环境中构造出一个相对安全的环境,保证子网环境下的计算机…

LeetCode | 110. 平衡二叉树

LeetCode | 110. 平衡二叉树 OJ链接 首先计算出二叉树的高度然后计算当前节点的左右子树的高度,然后判断当前节点的左右子树高度差是否超过 1,最后递归地检查左右子树是否也是平衡的。 //计算二叉树的高度 int height(struct TreeNode* root) {if(root…

国标GB28181视频监控EasyCVR内网环境部署无法启动怎么办?

安防视频监控系统EasyCVR平台可拓展性强、视频能力灵活、部署轻快,可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等,以及支持厂家私有协议与SDK接入,包括海康Ehome、海大宇等设备的SDK等,能对外分发RTMP、RTSP、HTTP-FLV、…

【Spring Cloud Alibaba】1.4 Nacos服务注册流程和原理解析

文章目录 1.前言2. 服务注册的基本流程3. 服务注册的核心代码分析3.1. NacosNamingServiceNamingProxy 服务端通信的核心类NamingClientProxy nacos 2.x 版本服务端通信核心接口 3.2 NamingGrpcClientProxy 详解RpcClient类RpcClient类核心方法 start 3.3 NamingHttpClientProx…