文章目录
- 前言
- 一、VGG网络简介
-
- 1.1 VGG的核心特点
- 1.2 VGG的典型结构
- 1.3 优点与局限性
- 1.4 本文的实现目标
- 二、搭建VGG网络
-
- 2.1 数据准备
- 2.2 定义VGG块
- 2.3 构建VGG网络
- 2.4 辅助工具
-
- 2.4.1 计时器和累加器
- 2.4.2 准确率计算
- 2.4.3 可视化工具
- 2.5 训练模型
- 2.6 运行实验
- 总结
前言
深度学习是近年来人工智能领域的重要突破,而卷积神经网络(CNN)作为其核心技术之一,在图像分类、目标检测等领域展现了强大的能力。VGG(Visual Geometry Group)网络是CNN中的经典模型之一,以其模块化的“块”设计和深层结构而闻名。本篇博客将通过PyTorch实现一个简化的VGG网络,并结合代码逐步解析其构建、训练和可视化过程,帮助读者从代码层面理解深度学习的基本原理和实践方法。我们将使用Fashion-MNIST数据集进行实验,展示如何从零开始搭建并训练一个VGG模型。
本文的目标读者是对深度学习有基本了解、希望通过代码实践加深理解的初学者或中级开发者。以下是博客的完整内容,包括代码实现和详细说明。
一、VGG网络简介
VGG网络(Visual Geometry Group Network)是由牛津大学视觉几何组在2014年提出的深度卷积神经网络(CNN)模型,因其在ImageNet图像分类竞赛中的优异表现而广为人知。VGG的设计理念是通过堆叠多个小卷积核(通常为3×3)和池化层,构建一个深层网络,从而提取图像中的复杂特征。与之前的模型(如AlexNet)相比,VGG显著增加了网络深度(常见版本包括VGG-16和VGG-19,分别有16层和19层),并采用统一的模块化结构,使其易于理解和实现。
1.1 VGG的核心特点
- 小卷积核:VGG使用3×3的小卷积核替代传统的大卷积核(如5×5或7×7)。两个3×3卷积核的堆叠可以达到5×5的感受野,而参数量更少,计算效率更高,同时增加了非线性(通过更多ReLU激活)。
- 模块化设计:网络由多个“块”(block)组成,每个块包含若干卷积层和一个最大池化层。这种设计使得网络结构清晰,便于扩展或调整。
- 深度增加:VGG通过加深网络层数(从11层到19层不等)提升性能,证明了深度对特征提取的重要性。
- 全连接层:在卷积层之后,VGG使用多个全连接层(通常为4096、4096和1000神经元)进行分类,输出对应ImageNet的1000个类别。
1.2 VGG的典型结构
以下是VGG-16的结构示意图,展示了其卷积块和全连接层的组织方式:
上图中:
- 绿色方框表示卷积层(3×3卷积核,步幅1,padding=1),对应图中的“convolution+ReLU”部分(以立方体表示)。这些卷积层负责提取图像特征,padding=1确保特征图尺寸在卷积后保持不变。
- 红色方框表示最大池化层(2×2,步幅2),对应图中的“max pooling”部分(以红色立方体表示)。池化层将特征图尺寸减半(例如从224×224到112×112),同时保留重要特征。
- 蓝色部分为全连接层,最终输出分类结果,对应图中的“fully connected+ReLU”和“softmax”部分(以蓝色线条表示)。全连接层将卷积特征展平后进行分类,输出对应ImageNet的1000个类别。
VGG-16包含13个卷积层和3个全连接层,总计16层(池化层不计入层数)。每个卷积块的通道数逐渐增加(从64到512),而池化层将特征图尺寸逐步减半(从224×224到7×7)。
1.3 优点与局限性
优点:
- 结构简单,易于实现和理解。
- 小卷积核和深层设计提高了特征提取能力。
- 在多种视觉任务中表现出色,可作为预训练模型迁移学习。
局限性:
- 参数量巨大(VGG-16约有1.38亿个参数),训练和推理耗时。
- 深层网络可能导致梯度消失问题(尽管ReLU和适当初始化缓解了部分问题)。
- 对内存和计算资源要求较高,不适合资源受限的设备。
1.4 本文的实现目标
在本文中,我们将基于PyTorch实现一个简化的VGG网络,针对Fashion-MNIST数据集(28×28灰度图像,10个类别)进行调整。我们保留VGG的模块化思想,但适当减少层数和参数量,以适应较小规模的数据和计算资源。通过代码实践,读者可以深入理解VGG的设计原理及其在实际任务中的应用。
下一节将进入具体的代码实现部分,逐步搭建VGG网络并完成训练。
二、搭建VGG网络
2.1 数据准备
在开始构建VGG网络之前,我们需要准备训练和测试数据。这里使用Fashion-MNIST数据集,这是一个包含10类服装图像的灰度图像数据集,每个图像大小为28×28像素。以下是数据加载的代码:
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
import multiprocessing
def get_dataloader_workers():
"""使用电脑支持的最大进程数来读取数据"""
return multiprocessing.cpu_count()
def load_data_fashion_mnist(batch_size, resize=None):
"""
下载Fashion-MNIST数据集,然后将其加载到内存中。
参数:
batch_size (int): 每个数据批次的大小。
resize (int, 可选): 图像的目标尺寸。如果为 None,则不调整大小。
返回:
tuple: 包含训练 DataLoader 和测试 DataLoader 的元组。
"""
# 定义变换管道
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
# 加载 Fashion-MNIST 训练和测试数据集
mnist_train = torchvision.datasets.FashionMNIST(
root="./data",
train=True,
transform=trans,
download=True
)
mnist_test = torchvision.datasets.FashionMNIST(
root="./data",
train=False,
transform=trans,
download=True
)
# 返回 DataLoader 对象
return (
data.DataLoader(
mnist_train,
batch_size,
shuffle=True,
num_workers=get_dataloader_workers()
),
data.DataLoader(
mnist_test,
batch_size,
shuffle=False,
num_workers=get_dataloader_workers()
)
)
这段代码定义了load_data_fashion_mnist
函数,用于加载Fashion-MNIST数据集并将其封装成PyTorch的DataLoader
对象。transforms.ToTensor()
将图像转换为张量格式,batch_size
控制每个批次的数据量,shuffle=True
确保训练数据随机打乱以提高模型泛化能力。num_workers
通过多进程加速数据加载。
2.2 定义VGG块
VGG网络的核心思想是将网络分解为多个“块”(block),每个块包含若干卷积层和一个池化层。以下是VGG块的实现:
import torch
from torch import nn
def vgg_block(num_convs, in_channels, out_channels):
layers = [] # 初始化一个空列表,用于存储网络层
for _ in range(num_convs): # 循环 num_convs 次,构建卷积层
layers.append(nn.Conv2d( # 添加一个二维卷积层
in_channels, # 输入通道数
out_channels, # 输出通道数
kernel_size=3, # 卷积核大小为 3x3
padding=1)) # 填充大小为 1,保持特征图尺寸
layers.append(nn.ReLU()) # 添加 ReLU 激活函数
in_channels = out_channels # 更新输入通道数为输出通道数,用于下一次卷积
layers.append(nn.MaxPool2d( # 添加一个最大池化层
kernel_size=2, # 池化核大小为 2x2
stride=2)) # 步幅为 2,缩小特征图尺寸
return nn.Sequential(*layers)