DenseNet算法实战

news2024/9/24 13:20:40

DenseNet算法实战


文章目录

  • DenseNet算法实战
    • @[TOC](文章目录)
  • 前言
  • 一、设计理念
  • 二、网络结构
    • 1.DenseNet网络结构
    • 2. DenseBlock + Transition结构
    • 3. DenseBlock 非线性结构
  • 三、代码实现
    • 1. 导入相关的包
    • 2. DenseBlock 内部结构
    • 3. DenseBlock 模块
    • 4. Transition 层
    • 5. 最后实现DenseNet网络

前言

  • 主要介绍DenseNet模型,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接,通过特征图在channel上的连接来实现特征重用
  • 使用pytorch框架进行代码编写, 对应的tensorflow代码正在写中…

一、设计理念

  • 相比ResNet, DenseNet提出了一个更为激进的密集连接机制: 即互相连接所有的层,具体来说就是每个层都会接受其前面所有层作为其额外的输入。
  • 下图为ResNet网络的短路连接机制
    在这里插入图片描述
  • 下图为DenseNet网络的短路连接机制
    在这里插入图片描述
  • 而对于DenseNet,则是通过跨通道concat的形式来连接,会连接前面的所有层作为输入,这里要注意所有的层的输入都来源于前面所有层在channel维度concat,
    在这里插入图片描述

二、网络结构

1.DenseNet网络结构

在这里插入图片描述

2. DenseBlock + Transition结构

  • DenseNet网络中使用DenseBlock + Transition的结构,其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。而Transition层是两个相邻的DenseBlock,并且通过pooling使特征图大小降低。
  • 下图为DenseBlock + Transition结构
    在这里插入图片描述

3. DenseBlock 非线性结构

  • 在DenseBlock中, 各个层的特征图大小一致, 可以在channel维度上连接,DenseBlock基本结构是BN + ReLU +(33)Conv的结构,如下图所示
    在这里插入图片描述
    由于后面层的输入会非常大, DenseBlock内部可以采用bottleneck层来减少计算量, 主要是原有的结构增加1
    1的Conv, 即BN + ReLU + 11Conv + BN + ReLU +33Conv, 称为DenseBlock 结构
    在这里插入图片描述

三、代码实现

1. 导入相关的包

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings
from torchsummary import summary
import torch.nn.functional as F
from collections import OrderedDict

2. DenseBlock 内部结构

class _DenseLayer(nn.Sequential):

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()

        self.add_module("norm1", nn.BatchNorm2d(num_input_features))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate,
                                           kernel_size=1, stride=1, bias=False))

        self.add_module("norm2", nn.BatchNorm2d(bn_size*growth_rate))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))

        self.drop_rate = drop_rate

    def forward(self, x):

        new_feartures = super(_DenseLayer, self).forward(x)

        if self.drop_rate > 0:
            new_feartures = F.dropout(new_feartures, p=self.drop_rate, training=self.training)

        return torch.cat([x, new_feartures], 1)

3. DenseBlock 模块

class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()

        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i*growth_rate, growth_rate,
                                bn_size, drop_rate)

            self.add_module("denselayer%d" %(i+1), layer)

4. Transition 层

  • 主要是一个卷积层和一个池化层
class _Transition(nn.Sequential):
    def __init__(self, num_input_feature, num_output_features):
        super(_Transition, self).__init__()

        self.add_module("norm", nn.BatchNorm2d(num_input_feature))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features,
                                          kernel_size=1, stride=1, bias=False))

        self.add_module("pool", nn.AvgPool2d(2, stride=2))

5. 最后实现DenseNet网络

class DenseNet(nn.Module):

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
                 bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):
        super(DenseNet, self).__init__()

        self.features = nn.Sequential(OrderedDict([
                ("conv0", nn.Conv2d(3, num_init_features, 7, 2, 3, bias=False)),
                ("norm0", nn.BatchNorm2d(num_init_features)),
                ("relu0", nn.ReLU(inplace=True)),
                ("pool", nn.MaxPool2d(3, stride=2, padding=1))
                ]))

        # DenseBlock
        num_features =num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)

            self.features.add_module("denseblock%d" %(i+1), block)
            num_features += num_layers*growth_rate

            if i != len(block_config) - 1:
                transition = _Transition(num_features, int(num_features*compression_rate))
                self.features.add_module(("transition%d" %(i+1), transition))
                num_features = int(num_features * compression_rate)

        # final bn + relu
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))
        self.features.add_module("relu5", nn.ReLU(inplace=True))

        # classification layer
        self.classifier = nn.Linear(num_features, num_classes)

        # param initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal(m.weight)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)

            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)

        return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/732050.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

23款奔驰S400商务型加装原厂前排座椅通风系统,夏天必备的功能

通风座椅的主动通风功能可以迅速将座椅表面温度降至适宜程度,从而确保最佳座椅舒适性。该功能启用后,车内空气透过打孔皮饰座套被吸入座椅内部,持续时间为 8 分钟。然后,风扇会自动改变旋转方向,将更凉爽的环境空气从座…

TIA博途_封装FB或FC块时将未分配参数的管脚隐藏的具体方法示例

TIA博途_封装FB或FC块时将未分配参数的管脚隐藏的具体方法示例 如下图所示,在某个项目中添加一个模拟量平均值滤波FB块,FB块的输入输出接口如图中所示, FB块编写完成后,在OB1中调用该FB块,可以看到需要配置的相关管脚…

Melon库运用——数组篇

头文件片段 // mln_array.hstruct mln_array_attr {void *pool; // 自定义内存池结构指针array_pool_alloc_handler pool_alloc; // 自定义内存池分配函数指针array_pool_free_handler pool_free; // 自定义内存池释放函数指针array_free …

Linux编译器--gcc/g++的使用

1.gcc/g的作用 gcc/g就是将写好的c/c的代码经过预编译/编译/汇编/链接生成可执行程序的过程,这个过程就是编译器的作用。 PS:由于c支持c语言的语法,gcc和g的操作差不多,在这里只讲gcc的使用方法。 2.gcc如何完成 格式 gcc [选项] 要编译的文…

函数指针数组:更高效的代码实现方式——指针进阶(二)

目录 前言 一、函数指针 什么是函数指针 函数指针的使用 二、函数指针数组 什么是函数指针数组 函数指针数组的使用 三、指向函数指针数组的指针 总结 前言 当谈到C语言的高级特性时,函数指针和函数指针数组通常是最常见的话题之一。虽然这些概念可能会让初…

java面试题(24)

1、重写equals()方法的原则 1、对称性: 如果x.equals(y)返回是“true”,那么y.equals(x)也应该返回是 “true”。 2、自反性: x.equals(x)必须…

【动态规划】第N个泰波那契数

📭从这里开始,我们要开始学习动态规划辣。之后的动态规划有关的文章都是按照这个逻辑来写,首先来介绍一下基本逻辑。 🧀(1)题目解析:就是分析题目,读懂题目想让我们实现的功能 🧀(2)算法原理&…

linux 创建一个线程的基础开销探讨

测试代码 测试方法比较笨,每修改一次线程数,就重新编译一次,再运行。在程序运行过程中,查看到进程 pid,然后通过以下命令查看进程的运行状态信息输出到以线程数为名字的日志文件中,最后用 vimdiff 对比文件…

LVS负载均衡集群之LVS-DR部署

目录 一、lVS-DR集群概述 二、LVS-DR数据包流向分析 四、LVS-DR特性 五、DR模式 LVS负载均衡群集部 5.0配置虚拟 IP 地址(VIP 192.168.14.180) 5.1.配置负载调度器(192.168.14.101) 5.2部署共享存储(NFS服务器:192.168.14.10…

7-3打怪升级(25分)【Floyd、dijkstra】【2021 RoboCom 世界机器人开发者大赛-本科组(初赛)】

考点:Floyd,dijkstra变式(记录路径,多优先级) 7-3 打怪升级 (25分) 很多游戏都有打怪升级的环节,玩家需要打败一系列怪兽去赢取成就和徽章。这里我们考虑一种简单的打怪升级游戏,游戏规则是&am…

数据在计算机中的存储——【C语言】

在前面的博客中,我们已经学习了C语言的数据类型,先让我们回顾一下C语言中有哪些数据类型。 目录 C语言的基本内置类型 类型的基本归类 整型在内存中的存储 原码、反码、补码 存储中的大小端 练习 浮点型在内存中的存储 浮点数的存储规则 对引例问…

【算法与数据结构】20、LeetCode有效的括号

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引,可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析:括号匹配是使用栈解决的经典问题。做这道题首先要分析什么时候括号不匹配。1、右括号多余 ( { [ ] } )…

动态规划之96 不同的二叉搜索树(第7道)

题目: 给你一个整数 n ,求恰由 n 个节点组成且节点值从 1 到 n 互不相同的 二叉搜索树 有多少种?返回满足题意的二叉搜索树的种数。 示例: 递推关系的推导: n3时,如上图所示。 当1为头结点的时候&#x…

C#学习之路-常量

C# 常量 常量是固定值,程序执行期间不会改变。常量可以是任何基本数据类型,比如整数常量、浮点常量、字符常量或者字符串常量,还有枚举常量。 常量可以被当作常规的变量,只是它们的值在定义后不能被修改。 整数常量 整数常量可…

Mybatis-Plus查询

Mybatis-Plus Mybatis-Plus条件查询的书写方法 1.条件查询 直接new QueryQuery<>创建对象&#xff0c;然后再wrappee.eq(“数据库列表”,“匹配值”)创建条件就可以。 其中&#xff0c;基本查询&#xff1a;eq表示相等&#xff1b;gt表示大于&#xff1b;lt表示小于&a…

[Vue3]学习笔记-provide 与 inject

作用&#xff1a;实现祖与后代组件间通信 套路&#xff1a;父组件有一个 provide 选项来提供数据&#xff0c;后代组件有一个 inject 选项来开始使用这些数据 具体写法&#xff1a; 祖组件中&#xff1a; setup(){......let car reactive({name:奔驰,price:40万})provide(…

Leetcode刷题(Week1)——宽(深)度优先遍历专题

刷题时间&#xff1a; 2019/04/04 – 2019/04/07 主播&#xff1a;yxc(闫雪灿) 视频链接&#xff1a; https://www.bilibili.com/video/av32546525?fromsearch&seid14001345623296049881 题号题目链接127Word Ladderhttps://leetcode.com/problems/word-ladder/131Palind…

Integration Objects OPC 所有产品Crack

OPC产品 OPC UA 升级到 OPC UA 以提高互操作性和安全性。 OPC 隧道 无需 DCOM 即可实现安全可靠的连接。 OPC 数据归档 将 OPC 数据存储到标准数据库或 CSV 文件中。 OPC 服务器 将任何通信协议转换为OPC标准。 OPC 客户端 读取、写入和传输您的 OPC 数据。 OPC 服务器工具…

四十五、时间/空间复杂度分析

算法主要内容 一、时间复杂度分析1、由数据范围反推算法复杂度以及算法内容2、如何分析代码复杂度&#xff08;1&#xff09;看循环&#xff08;2&#xff09;看递归&#xff08;3&#xff09;一些看似为O(n^2)&#xff0c;但实际为O(n)&#xff08;4&#xff09;数据结构&…

HPM6750系列--第五篇 使用Segger Embedded Studio for RISC-V开发环境

一、目的 之前的博文中《HPM6750系列--第四篇 搭建Visual Studio Code开发调试环境》我们介绍了如何使用visual studio code进行开发调试&#xff0c;但是用起来总缺少点感觉&#xff0c;那么有没有更加友好一些的IDE用来开发呢&#xff1f; 本篇主要介绍如何使用Embedded Stud…