1 卷积神经网络(CNN)简介
在使用PyTorch构建GAN生成对抗网络一文中,我们使用GAN构建了一个可以生成人脸图像的模型。但尽管是较为简单的模型,仍占用了1G左右的GPU内存,因此需要探索更加节约资源的方式。
卷积神经网络(Convolutional Neural Network,简称CNN)是一种深度学习模型,主要应用于图像处理、语音识别等领域。它的主要思想是通过卷积操作对输入图像的特征进行提取,再通过多层网络对特征进行分类和判断。
CNN的网络结构通常由卷积层、池化层和全连接层组成。卷积层的作用是对输入图像的特征进行提取,池化层的作用是减少数据的维度,以提高计算效率;全连接层则用于对特征进行分类和判断。
CNN可以通过训练学习到输入图像的特征表示,从而可以在未知图像上进行分类、识别等任务。它已经成为计算机视觉领域的重要技术,在诸多应用中取得了良好的效果。
2 从普通BP到CNN的网路结构转变
以前面建立好的手写数字分类器为例,(使用PyTorch构建神经网络构建手写数字分类器)在模型结构定义中,需要对神经网络层做出相应的修改:
self.model = nn.Sequential(
# expand 1 to 10 filters
nn.Conv2d(1, 10, kernel_size=5, stride=2),
nn.LeakyReLU(0.02),
nn.BatchNorm2d(10),
# 10 filters to 10 filters
nn.Conv2d(10, 10, kernel_size=3, stride=2),
nn.LeakyReLU(0.02),
nn.BatchNorm2d(10),
View(250),
nn.Linear(250, 10),
nn.Sigmoid()
)
更新后的神经网络架构如下:
- 第一个卷积层:把1个通道的输入图像扩展为10个通道,使用5x5的卷积核,步长为2。
- 第二个卷积层:10个通道的输入图像不变,使用3x3的卷积核,步长为2。
- 第一个全连接层:把250个节点的一维向量映射到10个节点。
其中用到的函数的含义:
4. Conv2d:对由一个或多个输入平面组成的输入信号进行二维卷积。第1个参数是输入参数,对于黑白图像,输入的通道数即为1。第2个参数是输出通道的数量。在上面的代码中,我们创建了10个卷积核,从而生成10个特征图。kernel_size
函数代表了卷积核的大小,使用的是5×5的卷积核。stride
是卷积核移动时的大小。该数值小于卷积核大小时,说明卷积核所覆盖的区域有重叠。
5. LeakyReLU:非线性激活函数,常用于生成对抗网络。
6. BatchNorm2d:批量归一化,用于提高网络的稳定性和收敛速度。
7. View:将多维张量展平为一维向量。(自定义函数,详见完整代码)
8. Sigmoid:S形函数,用于二分类问题的输出。
对于一个28*28像素的图片,第一步卷积之后将会生成一个12*12像素的图片(计算方式:共走了 28 − 5 2 \frac{28-5}{2} 228−5 步)。第二步卷积之后将会生成一个5*5像素的图片(计算方式:共走了 12 − 3 2 \frac{12-3}{2} 212−3 步)。
3 从普通BP到CNN的辅助修改
在网络结构中用到了View函数,在上面的参考博文中并未涉及这部分代码,因此把这给你功能进行补充。(与人脸识别篇代码中的View完全相同)
class View(nn.Module):
def __init__(self, shape):
super().__init__()
self.shape = shape,
def forward(self, x):
return x.view(*self.shape)
此外,修改后的CNN网络结构,其传入的图片应将其修改为4D数据。因此在模型训练时,将传入的数据进行变形。
start_time = time.perf_counter() # 计时开始
C = Classifier()
epochs = 3
for i in range(epochs):
print('training epoch', i+1, 'of', epochs)
for label, image_data_tensor, target_tensor in mnist_dataset:
C.train(image_data_tensor.view(1, 1, 28, 28), target_tensor)
注:上面两个VIEW并不相同,一个是我们自行定义用于分类器类使用的函数,一个是torch的自带功能。
除此之外代码均可保持不变,这部分的原始代码可在此找到到或文末留言申请。
4 模型评估
在训练初期,可以看到模型的损失呈现迅速下降。下面使用测试集对模型准确率进行评价:
使用一张图片来查看模型的生成。此处我们分别选择了一张数字0和数字6,可以发现与BP模型相比,CNN模型对结果变得更有信心了。