DenseNet分类网络改进(添加SPP)--亲测有效

news2025/1/21 11:29:33

最近,我在做多分类问题。在针对基模型的选择中,我使用了DenseNet作为基本模型。我在阅读论文时,遇到了一种改进方式:
在这里插入图片描述

如上图所示,在全连接层之前引入SPP模块:
在这里插入图片描述
代码如下:

SPP模块代码:

class SpatialPyramidPooling(nn.Module):
    def __init__(self, pool_sizes: List[int], in_channels: int):
        super(SpatialPyramidPooling, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.pool_layers = nn.ModuleList([
            nn.AdaptiveMaxPool2d(output_size=(size, size)) for size in pool_sizes
        ])

    def forward(self, x: Tensor) -> Tensor:
        pools = [pool_layer(x) for pool_layer in self.pool_layers]

        # Resize the output of each pool to have the same number of channels
        pools_resized = [F.adaptive_max_pool2d(pool, (1, 1)) for pool in pools]

        spp_out = torch.cat(pools_resized, dim=1)  # Concatenate the resized pools
        return spp_out

加入SPP代码后的DenseNet网络完整如下:

import re
from typing import List, Tuple, Any
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor

class _DenseLayer(nn.Module):
    def __init__(self, input_c: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False):
        super(_DenseLayer, self).__init__()

        self.add_module("norm1", nn.BatchNorm2d(input_c))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", nn.Conv2d(in_channels=input_c, out_channels=bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        concat_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concat_features)))
        return bottleneck_output

    @staticmethod
    def any_requires_grad(inputs: List[Tensor]) -> bool:
        for tensor in inputs:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused
    def call_checkpoint_bottleneck(self, inputs: List[Tensor]) -> Tensor:
        def closure(*inp):
            return self.bn_function(inp)

        return cp.checkpoint(closure, *inputs)

    def forward(self, inputs: Tensor) -> Tensor:
        if isinstance(inputs, Tensor):
            prev_features = [inputs]
        else:
            prev_features = inputs

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("memory efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output_with_cbam)))

        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)

        return new_features

class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(self, num_layers: int, input_c: int, bn_size: int, growth_rate: int, drop_rate: float,
                 memory_efficient: bool = False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(input_c + i * growth_rate,
                                growth_rate=growth_rate,
                                bn_size=bn_size,
                                drop_rate=drop_rate,
                                memory_efficient=memory_efficient)
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        features = [init_features]
        for _, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)



class _Transition(nn.Sequential):
    def __init__(self, input_c: int, output_c: int):
        super(_Transition, self).__init__()
        self.add_module("norm", nn.BatchNorm2d(input_c))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(input_c, output_c, kernel_size=1, stride=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))

class SpatialPyramidPooling(nn.Module):
    def __init__(self, pool_sizes: List[int], in_channels: int):
        super(SpatialPyramidPooling, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.pool_layers = nn.ModuleList([
            nn.AdaptiveMaxPool2d(output_size=(size, size)) for size in pool_sizes
        ])

    def forward(self, x: Tensor) -> Tensor:
        pools = [pool_layer(x) for pool_layer in self.pool_layers]

        # Resize the output of each pool to have the same number of channels
        pools_resized = [F.adaptive_max_pool2d(pool, (1, 1)) for pool in pools]

        spp_out = torch.cat(pools_resized, dim=1)  # Concatenate the resized pools
        return spp_out

class DenseNet(nn.Module):
    def __init__(self, growth_rate: int = 32, block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
                 num_init_features: int = 64, bn_size: int = 4, drop_rate: float = 0, num_classes: int = 1000,
                 memory_efficient: bool = False):
        super(DenseNet, self).__init__()

        # First conv+bn+relu+pool
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each dense block
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                input_c=num_features,
                                bn_size=bn_size,
                                growth_rate=growth_rate,
                                drop_rate=drop_rate,
                                memory_efficient=memory_efficient)
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features = num_features + num_layers * growth_rate

            if i != len(block_config) - 1:
                trans = _Transition(input_c=num_features,
                                    output_c=num_features // 2)
                self.features.add_module("transition%d" % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))

        # Spatial Pyramid Pooling (SPP) layer
        spp_pool_sizes = [1, 4, 16]  # You can adjust pool sizes as needed
        self.spp = SpatialPyramidPooling(spp_pool_sizes, in_channels=num_features)

        # FC layer
        self.classifier = nn.Linear(num_features + len(spp_pool_sizes) * num_features, num_classes)

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
    def forward(self, x: Tensor) -> Tensor:
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # Apply Spatial Pyramid Pooling
        spp_out = self.spp(out)
        # Adjust the number of channels in out to match spp_out
        out = F.adaptive_avg_pool2d(out, (1, 1))
        # Concatenate the original feature map with the SPP output along the channel dimension
        out = torch.cat([spp_out, out], dim=1)
        # Flatten the spatial dimensions of out
        out = torch.flatten(out, 1)
        # FC layer
        out = self.classifier(out)
        return out


def densenet121(**kwargs: Any) -> DenseNet:
    # Top-1 error: 25.35%
    # 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth'
    return DenseNet(growth_rate=32,
                    block_config=(6, 12, 24, 16),
                    num_init_features=64,
                    **kwargs)
def load_state_dict(model: nn.Module, weights_path: str) -> None:
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = torch.load(weights_path)

    num_classes = model.classifier.out_features
    load_fc = num_classes == 1000

    for key in list(state_dict.keys()):
        if load_fc is False:
            if "classifier" in key:
                del state_dict[key]

        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict, strict=load_fc)
    print("successfully load pretrain-weights.")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1290306.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VMware 虚拟机 电脑重启后 NAT 模式连不上网络问题修复

问题描述: 昨天 VMware 安装centos7虚拟机,网络模式配置的是NAT模式,配置好后,当时能连上外网,今天电脑重启后,发现连不上外网了 检查下各个配置,都没变动,突然就连不上了 网上查了…

Flink Flink数据写入Kafka

一、环境准备 flink 1.14写入Kafka&#xff0c;首先在pom.xml文件中导入相关依赖 <properties><project.build.sourceEncoding>UTF-8</project.build.sourceEncoding><flink.version>1.14.6</flink.version><spark.version>2.4.3</spa…

MySQL联合查询、最左匹配、范围查询导致失效

服务器版本 客户端&#xff1a;navicat premium16.0.11 联合索引 假设有如下表 联合索引就是同时把多列设成索引&#xff0c;如(empno&#xff0c;ename)在查询的时候就会先按照empno进行查询&#xff0c;再按照ename进行查询其中empno是全局有序&#xff0c;ename是局部有…

java数字千分位格式转换

java数字千分位格式转换 public static void main(String[] args) {System.out.println(thousandsSeparator("123123131"));}public static String thousandsSeparator(String value) {if (isNotNull(value)) {String[] arr value.split("");for (int i …

Kotlin Flow 操作符

前言 Kotlin 拥有函数式编程的能力&#xff0c;使用Kotlin开发&#xff0c;可以简化开发代码&#xff0c;层次清晰&#xff0c;利于阅读。 然而Kotlin拥有操作符很多&#xff0c;其中就包括了flow。Kotlin Flow 如此受欢迎大部分归功于其丰富、简洁的操作符&#xff0c;巧妙使…

top K问题(C语言)

目录 前言 top K问题 模拟数据 建堆 验证&#xff08;简单了解即可&#xff09; 最终代码 调试部分 前言 在大小堆的实现&#xff08;C语言&#xff09;中我们讨论了堆的实际意义&#xff0c;在看了就会的堆排序&#xff08;C语言&#xff09;中我们完成了堆排序&#…

MYSQL全语法速查(含示例)

文章目录 1.从简单的查询开始查找所有记录(SELECT *)查找记录中的所有登录名(SELECT)查找登录名为admin的密码(WHERE)查找电话号码非空的记录(IS NOT NULL)查找所在城市为北京或者用户名字是李四的记录(OR)查找所在城市为北京并且用户名字是张三的记录(AND)查找用户名字是李四或…

6个实用又好用的交互原型工具!

在 UI/UX 设计中&#xff0c;原型设计是至关重要的一步。正如用户体验中的其它环节一样&#xff0c;有无数的交互原型工具可以帮助你完成原型设计。市场上有太多的交互原型工具&#xff0c;如果你不知道选择哪一种&#xff0c;那么我们将为你介绍 6 个实用又好用的交互原型工具…

C++ day56 两个字符串的删除操作 编辑距离

题目1&#xff1a;583 两个字符串的删除操作 题目链接&#xff1a;两个字符串的删除操作 对题目的理解 返回使两个单词word1和word2相同的最少删除多少个元素&#xff0c;两个单词至少包含一个字母&#xff0c;且仅包含小写字母 思路1&#xff1a;这道题与昨天的不同子序列…

Dagger2使用

在android引入Dagger2库 //引入Dagger2implementation("com.google.dagger:dagger:2.48.1")annotationProcessor ("com.google.dagger:dagger-compiler:2.48.1") 构造器注入 创建一个类 public class Car {//在构造器上面添加dagger的Inject即可Injectp…

现代雷达车载应用——第2章 汽车雷达系统原理 2.1节

经典著作&#xff0c;值得一读&#xff0c;英文原版下载链接【免费】ModernRadarforAutomotiveApplications资源-CSDN文库。 2.1 基本雷达功能 雷达系统通过天线或天线阵列向空间辐射电磁能量。辐射的电磁能量“照亮”周围的目标。“被照亮”的目标拦截一些辐射能量&#xff0…

二十章总结

20.1线程介绍 世间有很多工作都是可以同时完成的。例如&#xff0c;人体可以同时进行呼吸、血液循环、思考问题等活动&#xff1b;用户既可以使用计算机听歌&#xff0c;也可以使用它打印文件。同样&#xff0c;计算机完全可以将多种活动同时进行&#xff0c;这种思想放在 Jav…

配置主机与外网时间服务器同步时间

目录 NTP服务简介 时间管理命令 第一步&#xff1a;更改当前主机所在地的时间 方法一&#xff1a;使用tzselect命令查询需要的时区 1、使用tzselect命令查询需要的时区 2、添加变量到 ~/.bash_profile 文件中&#xff0c;即追加类似的内容&#xff1a; 3、重新连接一个…

Flink(九)【时间语义与水位线】

前言 2023-12-02-20:05&#xff0c;终于写完啦&#xff0c;最近状态不错。刚写完又收到了她的消息哈哈哈哈&#xff0c;开心。 再去全力打拼一次&#xff0c;奋战一场&#xff0c;就算最后打了败仗也无所谓&#xff0c;至少你留下了足迹。 《解忧杂货店》 1、时间语义 …

团队git操作流程

项目的开发要求&#xff1a;&#xff08;1&#xff09;项目组厉员代码提交不少于20次 &#xff08;2&#xff09;项目组厉员每天提交不少于20次 &#xff08;3&#xff09;企业项目开发代码的每天的提交一般提交3-5次 &#xff08;4&#xff09;代码仓库的管理 git的基础操作流…

【CSP】202305-1_重复局面Python实现

文章目录 [toc]试题编号试题名称时间限制内存限制题目背景问题描述输入格式输出格式样例输入样例输出样例说明子任务提示Python实现 试题编号 202305-1 试题名称 重复局面 时间限制 1.0s 内存限制 512.0MB 题目背景 国际象棋在对局时&#xff0c;同一局面连续或间断出现3次或3…

会声会影2024软件还包含了视频教学以及模板素材

会声会影2024中文版是一款加拿大公司Corel发布的视频编软件。会声会影2024官方版支持视频合并、剪辑、屏幕录制、光盘制作、添加特效、字幕和配音等功能&#xff0c;用户可以快速上手。会声会影2024软件还包含了视频教学以及模板素材&#xff0c;让用户剪辑视频更加的轻松。 会…

【Linux 进度条小程序】缓冲区+回车换行

文章目录 回车与换行缓冲区举个栗子fflush函数倒计时小程序进度条小程序 回车与换行 回车和换行是不同的两个概念 回车&#xff1a;\r 使光标回到本行行首。 换行&#xff1a;\n使光标下移一格。 一般我们的键盘上的Enter键是回加换行键 在c语言中 \n 表示回车换行 效果和Ent…

iNet Network Scanner for Mac:简洁高效的WiFi网络扫描软件

随着无线网络的普及&#xff0c;WiFi网络已经成为我们日常生活中必不可少的一部分。无线网络的稳定性和速度对我们的工作和娱乐体验至关重要。因此&#xff0c;一款功能强大、简洁高效的WiFi网络扫描软件非常重要。今天&#xff0c;我们向大家推荐一款优秀的Mac平台WiFi网络扫描…

AI生成工具助力业界,黎万强投资。/微信支付和支付宝境外银行卡全面支持 | 魔法半周报

我有魔法✨为你劈开信息大海❗ 高效获取AIGC的热门事件&#x1f525;&#xff0c;更新AIGC的最新动态&#xff0c;生成相应的魔法简报&#xff0c;节省阅读时间&#x1f47b; &#x1f525;资讯预览 AI应用一年半&#xff0c;IBM HR部省12000小时 华为沉默中&#xff0c;何时…