专栏介绍
- ✨本文收录于【深度学习】:《PyTorch入门到项目实战》专栏,此专栏主要记录如何使用
PyTorch
实现深度学习算法及其项目实战,目前pytorch基础计算已经更新完,正在更新CNN,接下来会陆续更新RNN、CV、NLP、搜推广项目实战,尽量坚持每周持续更新,欢迎大家订阅!- 🌸个人主页:JOJO数据科学
- 📝个人介绍:某985统计硕士在读
- 💌如果文章对你有帮助,欢迎✌
关注
、👍点赞
、✌收藏
、👍订阅
专栏- 参考资料:《动手学深度学习》
文章目录
- 一、引言
- 二、1×1卷积网络
- 1️⃣介绍
- 2️⃣计算逻辑
- 3️⃣主要作用
- 二、NiN架构
- 1️⃣NiN块
- 2️⃣全局平均池化层
- 三、 Pytorch代码实现
- 1️⃣定义NiN块
- 2️⃣定义NiN网络
- 3️⃣加载数据集
- 4️⃣初始化模型
- 5️⃣模型训练与评估
- 四、总结
一、引言
我们之前介绍了LeNet,AlexNet,VGG。在我们用卷积层提取特征后,全连接层的参数如下:
可以看出,全连接层的参数很大,很占内存。因此,如果可以不使用全连接层,或者说减少全连接层的个数,可以减少参数,减少过拟合。下面我们来讨论这一章要介绍的内容NiN
二、1×1卷积网络
1️⃣介绍
在架构内容设计方面,其中一个比较有帮助的想法是使用1×1卷积。如下所示
也许你会好奇,1×1的卷积能做什么呢?不就是乘以数字么?似乎没有什么用,我们来具体看看它如何工作的。假设一个1×1卷积,这里是数字2,输入一张6×6×1的图片,然后对它做卷积,卷积层大小为1×1×1,结果相当于把这个图片乘以2,所以前三个单元格分别是2、4、6等等。
用1×1的过滤器进行卷积,似乎用处不大,只是对输入矩阵乘以某个数字。但这仅仅是对于6×6×1的一个通道的图片来说,1×1卷积效果不佳
如果是一张6×6×32的图片,那么使用1×1过滤器进行卷积效果更好。具体来说,1×1卷积所实现的功能是遍历这36个单元格,计算左图中32个数字和过滤器中32个数字的元素积之和,然后应用ReLU非线性函数。
2️⃣计算逻辑
上述1×1×32过滤器中的32可以这样理解,一个神经元的输入是32个数字(输入图片中32个通道中的数字),即相同高度和宽度上某一切片上的32个数字,这32个数字具有不同通道,乘以32个权重
(将过滤器中的32个数理解为权重)。所以1×1卷积可以从根本上理解为对这32个不同的位置都应用一个全连接层。和传统的CNN在卷积层之后接全连接层相比,全连接层会将特征图展平为一个向量,并进行线性变换。而用1×1卷积核替代全连接层,将空间上的每个像素点作为一组特征进行卷积操作,从而保留了空间结构信息,避免展平为向量,提高了网络的表达能力。
此外,当有多个卷积层时,我们可以更改输出通道数,如下图所示。
输入图片为4×4×3
- 第一个1×1卷积是增加通道数(通道从3→6)
原始图像 (4×4×3) → Conv 1 (6个1×1 ×3 kernel) → Conv1 输出图像 (4×4×6) - 第二个1×1卷积是减少通道数(通道从6→2)。
Conv1输出图片(4×4×6)→Conv 2(2个1×1×6 kernel)→Conv2输出图片(4×4×2)
3️⃣主要作用
- 通道数调整:通过1×1卷积,将卷积核的通道数设置为所需的输出通道数,就可以实现通道数的调整。这样就能够控制特征图的维度,使其适应后续层的输入要求。
- 特征融合:1×1卷积通过调整卷积核的通道数,将不同通道的特征图相加,从而实现特征的融合。
- 非线性映射:尽管1×1卷积没有类似3×3或5×5卷积核的局部感知视野,但它仍然引入了非线性映射。由于卷积操作中存在激活函数,1×1卷积能够对特征图进行非线性变换,并增强网络的表达能力。
下面我们来看一下NiN架构
二、NiN架构
NiN(Network in Network)
由Min Lin等人在2013年提出。它的设计目标是通过引入多层感知机结构(MLPConv)来提高卷积神经网络(CNN)的表达能力。
NiN框架的核心思想是在卷积层内嵌套一个小型MLP网络,用于增强特征表达能力。与传统的CNN不同,NiN框架在每个卷积层中使用1×1的卷积核,这样可以引入非线性变换和参数共享,从而提高特征的非线性表示能力。具体而言,NiN框架包含了以下几个关键组件:
1️⃣NiN块
一个NiN块由1个卷积层和2个1×1卷积层构成。其中,第一个卷积层负责提取空间特征,第2个1×1卷积层将通道数降低,第3个1×1卷积层则将通道数增加。这样的设计可以增加网络的非线性表示能力,并且通过1×1卷积层调整通道数可以灵活控制特征图的维度。
2️⃣全局平均池化层
在NiN网络的最后,通过全局平均池化层将特征图的空间维度降为1×1,得到一个通道数等于类别数的特征图。然后,通过Softmax
函数进行分类。
主要结构如下
下面我们用Pytorch来实现基于NiN架构对Fashion-MNIST数据集识别
三、 Pytorch代码实现
1️⃣定义NiN块
这里和原始的nin块有两个1×1卷积不同,我这里只使用了1个1×1卷积,因为数据集比较小,所以使用1个1×1卷积层效果更好,并且也大大节省了训练时间。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())
2️⃣定义NiN网络
net = nn.Sequential(
nin_block(1, 96, kernel_size=11, strides=4, padding=0),
nn.MaxPool2d(3, stride=2),
nin_block(96, 256, kernel_size=5, strides=1, padding=2),
nn.MaxPool2d(3, stride=2),
nin_block(256, 384, kernel_size=3, strides=1, padding=1),
nn.MaxPool2d(3, stride=2),
nn.Dropout(0.5),
# 标签类别数是10
nin_block(384, 10, kernel_size=3, strides=1, padding=1),
nn.AdaptiveAvgPool2d((1, 1)),
# 将四维的输出转成二维的输出,其形状为(批量大小,10)
nn.Flatten())
3️⃣加载数据集
# 加载Fashion-MNIST数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224,224)),
transforms.Normalize((0.5,), (0.5,))
])
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
4️⃣初始化模型
# Xavier初始化:
def init_weights(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d: #对全连接层和卷积层初始化
nn.init.xavier_uniform_(m.weight)
net.apply(init_weights)
# 检查是否有可用的GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = net.to(device)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
5️⃣模型训练与评估
# 训练模型
num_epochs = 10
train_losses = []
test_losses = []
for epoch in range(num_epochs):
train_loss = 0.0
test_loss = 0.0
correct = 0
total = 0
# 训练模型
model.train()
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
# 测试模型
model.eval()
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
test_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
avg_train_loss = train_loss / len(trainloader)
avg_test_loss = test_loss / len(testloader)
train_losses.append(avg_train_loss)
test_losses.append(avg_test_loss)
print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}, Acc: {correct/total*100:.2f}%")
# 绘制测试误差和训练误差曲线
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Testing Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
Epoch [1/10], Train Loss: 2.2827, Test Loss: 2.0189, Acc: 33.98%
Epoch [2/10], Train Loss: 1.7984, Test Loss: 1.2083, Acc: 59.43%
Epoch [3/10], Train Loss: 1.0804, Test Loss: 0.9443, Acc: 65.87%
Epoch [4/10], Train Loss: 1.0075, Test Loss: 0.8990, Acc: 67.74%
Epoch [5/10], Train Loss: 0.8120, Test Loss: 0.8054, Acc: 69.70%
Epoch [6/10], Train Loss: 0.7379, Test Loss: 0.7040, Acc: 73.27%
Epoch [7/10], Train Loss: 0.4918, Test Loss: 0.5636, Acc: 79.59%
Epoch [8/10], Train Loss: 0.4344, Test Loss: 0.4079, Acc: 84.71%
Epoch [9/10], Train Loss: 0.4012, Test Loss: 0.3962, Acc: 85.51%
Epoch [10/10], Train Loss: 0.3833, Test Loss: 0.3757, Acc: 85.74%
从结果来看,和AlexNet相比,精确度还要低一些,可能是我们的数据集太小,把batch_size调大一点可能效果会好一些。
四、总结
NiN
框架的主要优点是:
- 提高了表达能力:引入了MLP结构,增强了网络的非线性表示能力,有助于更好地捕捉复杂的特征。
- 减少参数:使用1×1卷积核和全局平均池化层,减少了网络中的参数数量,降低了过拟合的风险。
- 提高计算效率:由于减少了参数数量,NiN框架相对于传统的CNN具有更高的计算效率。
🔎总的来说,NiN框架在许多计算机视觉任务中取得了很好的性能,成为CNN架构设计中的重要思路之一,后续我们要介绍的GoogleNet借用了这种思想。
本章的介绍到此介绍,如果文章对你有帮助,请多多点赞、收藏、评论、订阅支持!!【深度学习】:《PyTorch入门到项目实战》