DenseNet的基本思想

news2024/11/23 4:06:48

之前的文章介绍过残差网络的基本思想:残差网络的思想就是将网络学习的映射从X到Y转为学习从X到Y-X的差,然后把学习到的残差信息加到原来的输出上即可。即便在某些极端情况下,这个残差为0,那么网络就是一个X到Y的恒等映射。其示意图如下:

ResNet
ResNet

后来,就有学者想到,既然输入一个残差块的X和该残差块的输出 可以相加,那么为什么不能一起作为特征继续向后传递呢?所以,就有了DenseNet的基本思想,其示意图如下:

DenseNet【引用自参考1】
DenseNet【引用自参考1】

也就是说在网络前向传播的过程中,不仅每一层提取的特征图用做后面一层的输入,其自身也会被当做特征图输入到后面的网络中,比如上图中:

  • 经过卷积层后得到了 ,然后 都会被当做 的输入*,以此类推,进行前向传播。

这样子做的好处有:

  • 缓解了梯度消失问题;
  • 增强了特征在网络中的传播能力;
  • 特征重用;
  • 减少参数的数量。

具体的实现,我们来看下Pytroch的源码:

DenseLayer

class _DenseLayer(nn.Module):
    def __init__(
        self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
    ) -> None:
        super().__init__()
        self.norm1: nn.BatchNorm2d
        self.add_module("norm1", nn.BatchNorm2d(num_input_features))
        self.relu1: nn.ReLU
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.conv1: nn.Conv2d
        self.add_module(
            "conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
        )
        self.norm2: nn.BatchNorm2d
        self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
        self.relu2: nn.ReLU
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.conv2: nn.Conv2d
        self.add_module(
            "conv2", nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
    def any_requires_grad(self, input: List[Tensor]) -> bool:
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
        def closure(*inputs):
            return self.bn_function(inputs)

        return cp.checkpoint(closure, *input)

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: List[Tensor]) -> Tensor:  # noqa: F811
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features

DenseNet的主要构成是DenseBlock,而DenseBlock的基本构成就是DenseLayer(上面的源码),上面的代码中有一些是pytorch的高级用法,暂不展开讲(主要是比较菜),其主要的函数就是其中的bn_function。

DenseBlock

有了DenseLayer,我们看下DenseBlock:

class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)

代码其实也比较简单,就是初始化的时候就直接生成相应数量的DenseLayer,然后进行前向传播,整个DenseNet实现的关键点就是这个前向传播函数,注意几点:

  • 其features是一个list;
  • 再通过for循环,不断的对这个list添加元素(append函数);
  • 最后DenseBlock的返回值是将list中的所有元素在维度1(channel维度)进行拼接。

然后就可以用上面的DenseBlock进行组合,搭积木式的构建自己的DenseNet网络了,比如DenseNet121、DenseNet169等等。

Transition

为了减少参数量,DenseNet中还有Transition这个子模块,其代码如下:

class _Transition(nn.Sequential):
    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__()
        self.add_module("norm", nn.BatchNorm2d(num_input_features))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))

其作用一般是用来压缩通道数。

总结与思考

  1. DenseNet和ResNet的主要区别?
  2. 为什么DenseLayer中的bn_function的构成是norm》relu》conv的顺序?一般不都是conv》norm》relu?(有点迷,下次做实验)

进阶知识点

  1. add_module
  2. torch.jit

参考

【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】HUANG G, LIU Z, VAN DER MAATEN L, et al. Densely connected convolutional networks[C]//Proceedings of the IEEE conference on computer vision and pattern recognition.2017:4700-4708.
【3】Pytorch官方源码

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/59208.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java基础类型和运算符

文章目录变量与常量变量的命名规则常量final 关键字修饰的常量字面常量基本类型整型基本整型变量 int长整型 long短整型 short比特型 byte浮点数 float和double关于3*0.10.3三种特殊的double字符型 char布尔类型 boolean类型转换隐式类型提升强制类型转换运算符算数运算符基本四…

vue中打印插件vue-print-nb(二)-实例之两种方法——安包之设置一个id和绑定一个对象 下载print.js之ref设置锚点

vue中打印插件vue-print-nb(二)-实例之两种方法——安包之设置一个id和绑定一个对象 & 下载print.js之ref设置锚点 第一种方法 方式1、设置一个id ① 给要打印的部分设置一个 id ② 在打印按钮中添加 v-print"#id名" 1、安装vue-print-nb插件 npm install v…

Firefly RK3399 PC pro Android 10下载验证

一.Android 源码以及image 1.Android 10代码链接: 百度网盘 请输入提取码 密码:1234 下载后检查md5值,检查下载是否正确: fb41fcdc48b1cf90ecac4a5bb8fafc7a Firefly-RK3399_Android10.0_git_20211222.7z.001 82d665fb54fb412…

Flutter ー Authentication 认证

Flutter ー Authentication 认证 原文 https://medium.com/simbu/flutter-authentication-adb8df7cf673 前言 如果我相信我知道你是谁那我就能让你查看你的个人 应用 application 资料。 身份验证可能是应用程序必须处理的最大的交叉问题。 将它作为一个特性添加到 DigestableP…

HashMap JDK1.7与1.8的区别

结构 首先HashMap在1.7中是以数组链表的形式存在的, 而HashMap在1.8中则是以数组链表红黑树构成的, 当一个节点的链表长度超过8并且数组长度超过64时会将链表转换为红黑树, 初始化 初始容量大小介绍 说到数组就不得不提HashMap里面的成员变量DEFAULT_INITIAL_CAPACITY也就是…

Mysql进阶学习(八)DDL语言+数据类型和DTL语言

Mysql进阶学习(八)DDL语言与DTL语言DDL语言1、简介:1.1、库的管理1.1.1、库的创建1.1.2、库的修改1.1.3、库的删除1.2、表的管理1.2.1.表的创建 ★1.2.2.表的修改1.2.3.表的删除1.2.4.表的复制测试案例1. 创建表dept12. 将表departments中的数…

SpringBoot_整合Thymeleaff模板引擎

Thymeleaf模板引擎的主要目标是将优雅的自然模板带到开发工作流程中,并将HTML在浏览器中正确显示,并且可以作为静态原型,让开发团队能更容易地协作。Thymeleaf能够处理HTML,XML,JavaScript,CSS甚至纯文本。…

Qt扫盲-Qt Designer 设计师使用总结

Designer 设计师使用总结一、顶部菜单栏1. 常用的菜单内容2. 快捷工具栏说明二、左侧控件栏1. 组件分类2. 筛选三、中间绘图区1. 左侧控件区拖放控件到中间2. 中间区域布局3. 属性修改四、右侧属性栏1. 对象查看器2. 属性编辑器3.组织结构2. 属性设置五、美化专栏1.单个设置层叠…

微服务框架 SpringCloud微服务架构 12 DockerCompose 12.2 部署微服务集群

微服务框架 【SpringCloudRabbitMQDockerRedis搜索分布式,系统详解springcloud微服务技术栈课程|黑马程序员Java微服务】 SpringCloud微服务架构 文章目录微服务框架SpringCloud微服务架构12 DockerCompose12.2 部署微服务集群12.2.1 直接开干12 DockerCompose 1…

大数据:Hive简介及核心概念

一、简介 Hive 是一个构建在 Hadoop 之上的数据仓库,它可以将结构化的数据文件映射成表,并提供类 SQL 查询功能,用于查询的 SQL 语句会被转化为 MapReduce 作业,然后提交到 Hadoop 上运行。 特点: 简单、容易上手 (…

做短视频不知道靠什么变现,分享三个自我商业定位的方法,适用普通人

如果说你还停留在我也不知道我可以靠什么赚钱这样的一个状态当中。那我给你三个自我商业定位的方法。篇幅较长,点赞收藏慢慢看哦 首先第一个方法,从工作上或者专业的事情上找变现的方法。 那么你们需要了解一个概念叫做知识的诅咒。什么意思呢&#xf…

【论文整理1】On the Continuity of Rotation Representations in Neural Networks

1.前置知识 1.1 Gram-Schmidt正交化 【参考阅读】Gram-Schmidt过程 看完这篇应该基本能理解,但是他对于公式的讲解有一个地方讲解得不是很清楚! 即为什么分母是平方形式呢? 1.2 差集 定义:差集是一种集合运算,记A&#xff0…

Java并发编程—CompletableFuture的介绍和使用

在博主上一篇博客介绍中,Java并发编程—java异步Future的迭代过程_小魏快起床的博客-CSDN博客,这里面给大家分析了Future的使用过程和一些存在的问题,那么针对里面出现的阻塞问题,博主将在这一篇文章给大家介绍清楚 &#x1f34f…

MyBatis框架简介

MyBatis是一个开源的数据持久层框架,内部封装了通过JDBC访问数据库的操作,支持普通的SQL查询、存储过程和高级映射。作为持久层框架,主要思想是将程序中的大量的SQL语句分离出来,配置在相应的配置文件中,这样可以在不修…

Java—数据类型

文章目录数据类型八大基本数据类型Java中有了基本数据类型,为什么还要包装类型String字符串类型函数字符串类的length()方式是否能够得到字符串内有多少个字符?不可变字符串String为什么要设计成不可变的?boolean类型占多少位?为什…

【springboot进阶】使用aop + 注解方式,简单实现spring cache功能

目录 一、实现思路 二、定义缓存注解 三、aop 切面处理 四、使用方式 五、灵活的运用 六、总结 前几天有同学看了 SpringBoot整合RedisTemplate配置多个redis库 这篇文章,提问spring cache 能不能也动态配置多个redis库。介于笔者没怎么接触过,所以…

【Java开发】 Spring 08 :访问 Web 资源( 借助 RestTemplate or WebClient )

web 资源就是运行在服务器上的资源,比如放到 web 下的页面 js 文件、图片、css等,web资源分为静态web资源和动态web资源两类,接下来访问的就是动态资源(页面返回的数据是动态的,由后端程序产生)&#xff0…

Android 使用元数据

Android 使用元数据 前提介绍Metadata 有时候为安全起见,某个参数要给某个活动专用,并不希望其他活动也能获取该参数,也就是要使用第三方SDK时。Activity提供了元数据(Metadata)的概念,元数据是一种描述其…

C++类和对象(二)构造函数、析构函数、拷贝构造函数

目录 1.类的6个默认成员函数 2. 构造函数 2.1 概念 2.2 特性 3.析构函数 3.1 概念 3.2 特性 4. 拷贝构造函数 4.1 概念 4.2 特征 1.类的6个默认成员函数 如果一个类中什么成员都没有,简称为空类。 空类中真的什么都没有吗?并不是,…

【菜菜的sklearn课堂笔记】聚类算法Kmeans-聚类算法的模型评估指标

视频作者:菜菜TsaiTsai 链接:【技术干货】菜菜的机器学习sklearn【全85集】Python进阶_哔哩哔哩_bilibili 可以只看轮廓系数和卡林斯基-哈拉巴斯指数 不同于分类模型和回归,聚类算法的模型评估不是一件简单的事。在分类中,有直接结…