DenseNet in Practice
Preface
This article introduces the DenseNet model. Its basic idea is similar to ResNet's, but instead of shortcut additions it builds dense connections between each layer and all preceding layers, reusing features by concatenating feature maps along the channel dimension. The code is written with the PyTorch framework; the corresponding TensorFlow code is still being written…
I. Design Philosophy
Compared with ResNet, DenseNet proposes a far more aggressive dense connection mechanism: all layers are interconnected, i.e. each layer takes the outputs of all preceding layers as additional input. The figures below show the shortcut connection mechanism of ResNet and the dense connection mechanism of DenseNet.
(Figure: ResNet shortcut connection mechanism)
(Figure: DenseNet dense connection mechanism)
Whereas ResNet combines features by element-wise addition, DenseNet connects layers through channel-wise concat: each layer takes all preceding layers as input, and note that this input is formed by concatenating the feature maps of all preceding layers along the channel dimension.
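To make the contrast precise, the two connection rules can be written as in the DenseNet paper, where $H_\ell$ denotes the composite nonlinear function of layer $\ell$ (BN + ReLU + Conv) and $[\cdot]$ denotes channel-wise concatenation:

$$x_\ell = H_\ell(x_{\ell-1}) + x_{\ell-1} \quad \text{(ResNet)}$$

$$x_\ell = H_\ell([x_0, x_1, \ldots, x_{\ell-1}]) \quad \text{(DenseNet)}$$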
II. Network Structure
1. DenseNet network structure
2. DenseBlock + Transition structure
The DenseNet network is built from a DenseBlock + Transition structure. A DenseBlock is a module containing many layers; all layers in a block produce feature maps of the same spatial size and are connected densely. A Transition layer sits between two adjacent DenseBlocks and reduces the feature-map size through pooling. The figure below shows the DenseBlock + Transition structure.
(Figure: DenseBlock + Transition structure)
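One detail worth noting, grounded in the implementation below: besides downsampling, the Transition layer also compresses the channel count. If a DenseBlock outputs $m$ feature maps, the Transition reduces them to $\lfloor \theta m \rfloor$ for a compression factor $0 < \theta \le 1$ (called compression in the paper); the code below uses compression_rate = 0.5, halving the channels between blocks.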
3. The nonlinear structure inside a DenseBlock
Within a DenseBlock, all layers produce feature maps of the same spatial size, so they can be concatenated along the channel dimension. The basic layer is a BN + ReLU + 3×3 Conv structure, as shown below. Since the input of later layers grows very large, a bottleneck layer can be used inside the DenseBlock to reduce computation: a 1×1 Conv is added in front of the original structure, giving BN + ReLU + 1×1 Conv + BN + ReLU + 3×3 Conv, known as the DenseNet-B structure.
III. Code Implementation
1. Import the required packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings
from torchsummary import summary
from collections import OrderedDict
2. DenseBlock internal structure
class _DenseLayer(nn.Sequential):
    """Bottleneck layer: BN + ReLU + 1x1 Conv + BN + ReLU + 3x3 Conv."""
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module("norm1", nn.BatchNorm2d(num_input_features))
        self.add_module("relu1", nn.ReLU(inplace=True))
        # 1x1 bottleneck conv: reduces the input to bn_size * growth_rate channels
        self.add_module("conv1", nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module("relu2", nn.ReLU(inplace=True))
        # 3x3 conv: produces growth_rate new feature maps
        self.add_module("conv2", nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        # Dense connectivity: concatenate input and new features on the channel dim
        return torch.cat([x, new_features], 1)
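A quick sanity check of the layer (the shapes below are illustrative assumptions, not values from the article): the output should carry the input channels plus growth_rate newly produced channels.

layer = _DenseLayer(num_input_features=64, growth_rate=32, bn_size=4, drop_rate=0)
x = torch.randn(1, 64, 56, 56)  # assumed toy input: batch 1, 64 channels, 56x56
print(layer(x).shape)  # torch.Size([1, 96, 56, 56]): 64 input + 32 new channels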
3. The DenseBlock module
class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            # Layer i sees the block input plus the i * growth_rate channels
            # produced by the layers before it
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
                                bn_size, drop_rate)
            self.add_module("denselayer%d" % (i + 1), layer)
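Again with illustrative values, a block of 6 such layers grows a 64-channel input by 6 × 32 channels:

block = _DenseBlock(num_layers=6, num_input_features=64, bn_size=4,
                    growth_rate=32, drop_rate=0)
# 64 + 6 * 32 = 256 output channels; spatial size is unchanged inside a block
print(block(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 256, 56, 56])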
4. The Transition layer
class _Transition(nn.Sequential):
    def __init__(self, num_input_feature, num_output_features):
        super(_Transition, self).__init__()
        self.add_module("norm", nn.BatchNorm2d(num_input_feature))
        self.add_module("relu", nn.ReLU(inplace=True))
        # 1x1 conv compresses the channel count between two DenseBlocks
        self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        # 2x2 average pooling halves the spatial resolution
        self.add_module("pool", nn.AvgPool2d(2, stride=2))
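Continuing the illustrative shapes, a Transition that halves the channel count (compression factor 0.5) also halves the spatial resolution:

trans = _Transition(num_input_feature=256, num_output_features=128)
print(trans(torch.randn(1, 256, 56, 56)).shape)  # torch.Size([1, 128, 28, 28])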
5. Implementing the full DenseNet network
class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
                 bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):
        super(DenseNet, self).__init__()
        # Stem: 7x7 conv (stride 2) + BN + ReLU + 3x3 max pool (stride 2)
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, num_init_features, 7, 2, 3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(3, stride=2, padding=1)),
        ]))
        # Stack DenseBlocks, inserting a Transition between adjacent blocks
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features += num_layers * growth_rate
            if i != len(block_config) - 1:
                transition = _Transition(num_features, int(num_features * compression_rate))
                self.features.add_module("transition%d" % (i + 1), transition)
                num_features = int(num_features * compression_rate)
        # Final BN + ReLU before global pooling and the classifier
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))
        self.features.add_module("relu5", nn.ReLU(inplace=True))
        self.classifier = nn.Linear(num_features, num_classes)
        # Weight initialization: Kaiming for convs, constants for BN and biases
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        # 7x7 global average pooling assumes a 224x224 input (final maps are 7x7)
        out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out
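As a final smoke test (input shape and class count are assumptions for illustration), the default configuration of growth rate 32 with blocks of 6/12/24/16 layers corresponds to DenseNet-121:

model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000)
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 1000])
# summary(model, (3, 224, 224))  # optional: layer-by-layer summary via torchsummary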