目录
- 残差块
- ResNet模型
- 手写数字识别
残差块
左图所示残差块的实现如下：主分支由两个 3×3 卷积层（各接一个批量归一化层）组成；当 use_1x1conv=True 时，捷径分支用 1×1 卷积调整通道数和分辨率；两路相加后再经过 ReLU 激活。
import tensorflow as tf
from tensorflow.keras import layers,activations
# Residual block
class Residul(tf.keras.Model):
    """Basic ResNet residual block: two 3x3 convolutions plus a shortcut.

    Main path: conv3x3(stride=strides) -> BN -> ReLU -> conv3x3(stride 1) -> BN.
    Shortcut: the input unchanged, or a 1x1 convolution with the same stride
    when ``use_1x1conv`` is True. The two branches are added and passed
    through a final ReLU.

    The identity shortcut only works when the input already has
    ``num_channels`` channels and ``strides == 1``. Otherwise pass
    ``use_1x1conv=True`` so the shortcut is reshaped to match the main path.

    Args:
        num_channels: Number of output channels of every convolution.
        use_1x1conv: If True, project the shortcut through a 1x1 convolution
            instead of using the input unchanged.
        strides: Stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution (2 downsamples the feature map).
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1):
        super(Residul, self).__init__()  # call the base-class initializer
        self.conv1 = layers.Conv2D(num_channels, kernel_size=3, padding='same', strides=strides)
        # The second conv must preserve the spatial size (padding='same',
        # stride 1). Downsampling already happened in conv1. Any further shape
        # change (a second stride, or 'valid' padding shrinking H/W by 2) would
        # make `y + x` in call() fail with a shape mismatch.
        self.conv2 = layers.Conv2D(num_channels, kernel_size=3, padding='same')
        if use_1x1conv:
            # 1x1 projection matching the main path's channels and stride.
            self.conv3 = layers.Conv2D(num_channels, kernel_size=1, strides=strides)
        else:
            self.conv3 = None
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()

    def call(self, x):
        """Apply the block to ``x`` and return ReLU(main_path(x) + shortcut(x))."""
        y = activations.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        # Explicit None check instead of relying on a layer object's truthiness.
        if self.conv3 is not None:
            x = self.conv3(x)
        return activations.relu(y + x)
ResNet模型
# Residual stage: a sequence of Residul blocks with the same channel count
class ResnetBlock(tf.keras.layers.Layer):
    """One ResNet stage made of ``num_res`` consecutive residual blocks.

    When this is not the first stage (``first_block`` is False), the stage's
    first block uses a 1x1 shortcut convolution and stride 2. Every other
    block is a plain ``Residul(num_channels)`` (stride 1, identity shortcut).

    Args:
        num_channels: Output channels of every block in the stage.
        num_res: Number of residual blocks in the stage.
        first_block: True for the network's first stage, whose blocks all
            use stride 1 and identity shortcuts.
    """

    def __init__(self, num_channels, num_res, first_block=False):
        super().__init__()
        # Blocks are created in order; only index 0 of a non-first stage
        # gets the downsampling configuration.
        self.listLayers = [
            Residul(num_channels, use_1x1conv=True, strides=2)
            if idx == 0 and not first_block
            else Residul(num_channels)
            for idx in range(num_res)
        ]

    def call(self, x):
        """Feed ``x`` through every block of the stage in sequence."""
        out = x
        for block in self.listLayers:
            out = block(out)
        return out
# Build the ResNet network
class ResNet(tf.keras.Model):
    """ResNet classifier: stem, four residual stages, global pooling, softmax head.

    Layout:
      * stem: 7x7 conv (64 channels, stride 2) -> BN -> ReLU -> 3x3 max pool
        (stride 2);
      * four ``ResnetBlock`` stages with 64, 128, 256 and 512 channels. Only
        the first stage is built with ``first_block=True``;
      * global average pooling followed by a dense softmax layer.

    Args:
        num_blocks: Indexable with at least four ints, the number of residual
            blocks in each stage; e.g. ``[2, 2, 2, 2]`` gives the ResNet-18
            layout.
        num_classes: Number of units of the final softmax layer. Defaults to
            10, the value previously hard-coded.
    """

    def __init__(self, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        # Stem layers
        self.conv = layers.Conv2D(64, kernel_size=7, strides=2, padding='same')
        self.bn = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.mp = layers.MaxPool2D(pool_size=3, strides=2, padding='same')
        # Residual stages (attribute names kept stable for weight/checkpoint naming)
        self.res_block1 = ResnetBlock(64, num_blocks[0], first_block=True)
        self.res_block2 = ResnetBlock(128, num_blocks[1])
        self.res_block3 = ResnetBlock(256, num_blocks[2])
        self.res_block4 = ResnetBlock(512, num_blocks[3])
        # Global average pooling
        self.gap = layers.GlobalAveragePooling2D()
        # Fully connected classification head
        self.fc = layers.Dense(units=num_classes, activation=tf.keras.activations.softmax)

    def call(self, x):
        """Run the forward pass and return per-class softmax probabilities."""
        # Stem
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.mp(x)
        # Residual stages
        x = self.res_block1(x)
        x = self.res_block2(x)
        x = self.res_block3(x)
        x = self.res_block4(x)
        # Output head
        x = self.gap(x)
        x = self.fc(x)
        return x
# Smoke test: build the ResNet-18 layout (two residual blocks per stage) and
# run one random single-channel 224x224 input through it. The forward pass
# creates the weights, which summary() needs.
mynet = ResNet(num_blocks=[2, 2, 2, 2])
X = tf.random.uniform(shape=(1, 224, 224, 1))
y = mynet(X)  # expected shape (1, 10): softmax output of the dense head
mynet.summary()
手写数字识别
和 AlexNet 一样。