从代码学习深度学习 - 使用块的网络(VGG)PyTorch版

news2025/4/1 1:41:18

文章目录

  • 前言
  • 一、VGG网络简介
    • 1.1 VGG的核心特点
    • 1.2 VGG的典型结构
    • 1.3 优点与局限性
    • 1.4 本文的实现目标
  • 二、搭建VGG网络
    • 2.1 数据准备
    • 2.2 定义VGG块
    • 2.3 构建VGG网络
    • 2.4 辅助工具
      • 2.4.1 计时器和累加器
      • 2.4.2 准确率计算
      • 2.4.3 可视化工具
    • 2.5 训练模型
    • 2.6 运行实验
  • 总结


前言

深度学习是近年来人工智能领域的重要突破,而卷积神经网络(CNN)作为其核心技术之一,在图像分类、目标检测等领域展现了强大的能力。VGG(Visual Geometry Group)网络是CNN中的经典模型之一,以其模块化的“块”设计和深层结构而闻名。本篇博客将通过PyTorch实现一个简化的VGG网络,并结合代码逐步解析其构建、训练和可视化过程,帮助读者从代码层面理解深度学习的基本原理和实践方法。我们将使用Fashion-MNIST数据集进行实验,展示如何从零开始搭建并训练一个VGG模型。

本文的目标读者是对深度学习有基本了解、希望通过代码实践加深理解的初学者或中级开发者。以下是博客的完整内容,包括代码实现和详细说明。


一、VGG网络简介

VGG网络(Visual Geometry Group Network)是由牛津大学视觉几何组在2014年提出的深度卷积神经网络(CNN)模型,因其在ImageNet图像分类竞赛中的优异表现而广为人知。VGG的设计理念是通过堆叠多个小卷积核(通常为3×3)和池化层,构建一个深层网络,从而提取图像中的复杂特征。与之前的模型(如AlexNet)相比,VGG显著增加了网络深度(常见版本包括VGG-16和VGG-19,分别有16层和19层),并采用统一的模块化结构,使其易于理解和实现。

1.1 VGG的核心特点

  1. 小卷积核:VGG使用3×3的小卷积核替代传统的大卷积核(如5×5或7×7)。两个3×3卷积核的堆叠可以达到5×5的感受野,而参数量更少,计算效率更高,同时增加了非线性(通过更多ReLU激活)。
  2. 模块化设计:网络由多个“块”(block)组成,每个块包含若干卷积层和一个最大池化层。这种设计使得网络结构清晰,便于扩展或调整。
  3. 深度增加:VGG通过加深网络层数(从11层到19层不等)提升性能,证明了深度对特征提取的重要性。
  4. 全连接层:在卷积层之后,VGG使用多个全连接层(通常为4096、4096和1000神经元)进行分类,输出对应ImageNet的1000个类别。

1.2 VGG的典型结构

以下是VGG-16的结构示意图,展示了其卷积块和全连接层的组织方式:

在这里插入图片描述

上图中:

  • 绿色方框表示卷积层(3×3卷积核,步幅1,padding=1),对应图中的“convolution+ReLU”部分(以立方体表示)。这些卷积层负责提取图像特征,padding=1确保特征图尺寸在卷积后保持不变。
  • 红色方框表示最大池化层(2×2,步幅2),对应图中的“max pooling”部分(以红色立方体表示)。池化层将特征图尺寸减半(例如从224×224到112×112),同时保留重要特征。
  • 蓝色部分为全连接层,最终输出分类结果,对应图中的“fully connected+ReLU”和“softmax”部分(以蓝色线条表示)。全连接层将卷积特征展平后进行分类,输出对应ImageNet的1000个类别。

VGG-16包含13个卷积层和3个全连接层,总计16层(池化层不计入层数)。每个卷积块的通道数逐渐增加(从64到512),而池化层将特征图尺寸逐步减半(从224×224到7×7)。

1.3 优点与局限性

优点

  • 结构简单,易于实现和理解。
  • 小卷积核和深层设计提高了特征提取能力。
  • 在多种视觉任务中表现出色,可作为预训练模型迁移学习。

局限性

  • 参数量巨大(VGG-16约有1.38亿个参数),训练和推理耗时。
  • 深层网络可能导致梯度消失问题(尽管ReLU和适当初始化缓解了部分问题)。
  • 对内存和计算资源要求较高,不适合资源受限的设备。

1.4 本文的实现目标

在本文中,我们将基于PyTorch实现一个简化的VGG网络,针对Fashion-MNIST数据集(28×28灰度图像,10个类别)进行调整。我们保留VGG的模块化思想,但适当减少层数和参数量,以适应较小规模的数据和计算资源。通过代码实践,读者可以深入理解VGG的设计原理及其在实际任务中的应用。

下一节将进入具体的代码实现部分,逐步搭建VGG网络并完成训练。

二、搭建VGG网络

2.1 数据准备

在开始构建VGG网络之前,我们需要准备训练和测试数据。这里使用Fashion-MNIST数据集,这是一个包含10类服装图像的灰度图像数据集,每个图像大小为28×28像素。以下是数据加载的代码:

import torchvision
import torchvision.transforms as transforms
from torch.utils import data
import multiprocessing

def get_dataloader_workers():
    """使用电脑支持的最大进程数来读取数据"""
    return multiprocessing.cpu_count()

def load_data_fashion_mnist(batch_size, resize=None):
    """
    下载Fashion-MNIST数据集,然后将其加载到内存中。
    
    参数:
        batch_size (int): 每个数据批次的大小。
        resize (int, 可选): 图像的目标尺寸。如果为 None,则不调整大小。
    
    返回:
        tuple: 包含训练 DataLoader 和测试 DataLoader 的元组。
    """
    # 定义变换管道
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    
    # 加载 Fashion-MNIST 训练和测试数据集
    mnist_train = torchvision.datasets.FashionMNIST(
        root="./data",
        train=True,
        transform=trans,
        download=True
    )
    mnist_test = torchvision.datasets.FashionMNIST(
        root="./data",
        train=False,
        transform=trans,
        download=True
    )
    
    # 返回 DataLoader 对象
    return (
        data.DataLoader(
            mnist_train,
            batch_size,
            shuffle=True,
            num_workers=get_dataloader_workers()
        ),
        data.DataLoader(
            mnist_test,
            batch_size,
            shuffle=False,
            num_workers=get_dataloader_workers()
        )
    )

这段代码定义了load_data_fashion_mnist函数,用于加载Fashion-MNIST数据集并将其封装成PyTorch的DataLoader对象。transforms.ToTensor()将图像转换为张量格式,batch_size控制每个批次的数据量,shuffle=True确保训练数据随机打乱以提高模型泛化能力。num_workers通过多进程加速数据加载。

2.2 定义VGG块

VGG网络的核心思想是将网络分解为多个“块”(block),每个块包含若干卷积层和一个池化层。以下是VGG块的实现:

import torch
from torch import nn

def vgg_block(num_convs, in_channels, out_channels):
    layers = []                          # 初始化一个空列表,用于存储网络层
    for _ in range(num_convs):           # 循环 num_convs 次,构建卷积层
        layers.append(nn.Conv2d(         # 添加一个二维卷积层
            in_channels,                 # 输入通道数
            out_channels,                # 输出通道数
            kernel_size=3,               # 卷积核大小为 3x3
            padding=1))                  # 填充大小为 1,保持特征图尺寸
        layers.append(nn.ReLU())         # 添加 ReLU 激活函数
        in_channels = out_channels       # 更新输入通道数为输出通道数,用于下一次卷积
    layers.append(nn.MaxPool2d(          # 添加一个最大池化层
        kernel_size=2,                   # 池化核大小为 2x2
        stride=2))                       # 步幅为 2,缩小特征图尺寸
    return nn.Sequential(*layers)        

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2324875.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows 安装多用户和其它一些问题 VMware Onedrive打不开

以下以win10家庭版为例,win11、专业版类似。 Onedrive相关问题参看我的其他文章: Windows如何同时登录两个OneDrive个人版账号_onedrive登录两个账号-CSDN博客 win10 win11 设置文件权限以解决Onedrive不能同步问题_onedrive没有同步权限-CSDN博客 O…

java基础自用笔记:异常、泛型、集合框架(List、Set、Map)、Stream流

异常 异常体系 编译时异常代表程序觉得你可能会出错。 运行时异常代表已经出错 异常基本处理 异常的作用 可以在可能出现的异常的地方用返回异常来代替return,这样提醒程序出现异常简洁清晰 自定义异常 最好用运行时异常,不会像编译时异常那样烦人&a…

第六届 蓝桥杯 嵌入式 省赛

参考 第六届蓝桥杯嵌入式省赛程序设计题解析(基于HAL库)_蓝桥杯嵌入式第六届真题-CSDN博客 一、分析功能 RTC 定时 1)时间初始化 2)定时上报电压时间 ADC测量 采集电位器的输出电压信号。 串行功能 1)传送要设置…

爱普生FC-135晶振5G手机的极端温度性能守护者

在5G时代,智能手机不仅需要高速率与低延迟,更需在严寒、酷暑、振动等复杂环境中保持稳定运行。作为 5G 手机的核心时钟源,爱普生32.768kHz晶振FC-135凭借其宽温适应性、高精度稳定性与微型化设计,成为5G手机核心时钟源的理想选择&…

如何备份你的 Postman 所有 Collection?

团队合作需要、备份,还是迁移到其他平台,我们都需要在 Postman 中将这些珍贵的集合数据导出。 如何从 Postman 中导出所有集合(Collection)教程

MinGW下编译ffmpeg源码时生成compile_commands.json

在前面的博文MinGW下编译nginx源码中,有介绍到使用compiledb工具在MinGW环境中生成compile_commands.json,以为compiledb是捕获的make时的输出,而nginx生成时控制台是有输出编译时的命令行信息的,笔者之前编译过ffmpeg的源码&…

【数据结构】树与森林

目录 树的存储方法 双亲表示法 孩子表示法 孩子兄弟表示法 树、森林与二叉树的转换 树转换成二叉树 森林转换成二叉树 二叉树转换成森林 树与森林的遍历 树的遍历 森林的遍历 树的存储方法 双亲表示法 这种存储结构采用一组连续空间来存储每个结点,同时…

跟着StatQuest学知识08-RNN与LSTM

一、RNN (一)简介 整个过程权重和偏置共享。 (二)梯度爆炸问题 在这个例子中w2大于1,会出现梯度爆炸问题。 当我们循环的次数越来越多的时候,这个巨大的数字会进入某些梯度,步长就会大幅增加&…

【SpringCloud】Eureka的使用

3. Eureka 3.1 Eureka 介绍 Eureka主要分为两个部分: EurekaServer: 作为注册中心Server端,向微服务应用程序提供服务注册,发现,健康检查等能力。 EurekaClient: 服务提供者,服务启动时,会向 EurekaS…

初识MySQL · 数据类型

目录 前言: 数值类型 文本、二进制数据类型 时间类型 String类型 前言: 对于MySQL来说,是一门编程语言,可能定义不是那么的严格,但是对于MySQL来说也是拥有自己的数据类型的,比如tinyint,…

QT图片轮播器(QT实操学习2)

1.项目架构 1.UI界面 2.widget.h​ #ifndef WIDGET_H #define WIDGET_H#include <QWidget>#define TIMEOUT 1 * 1000 QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : public QWidget {Q_OBJECTpublic:Widget(QWidget *parent n…

深度解析衡石科技HENGSHI SENSE嵌入式分析能力:如何实现3天快速集成

嵌入式分析成为现代SaaS的核心竞争力 在当今SaaS市场竞争中&#xff0c;数据分析能力已成为产品差异化的关键因素。根据Bessemer Venture Partners的最新调研&#xff0c;拥有深度嵌入式分析功能的SaaS产品&#xff0c;其客户留存率比行业平均水平高出23%&#xff0c;ARR增长速…

杂草YOLO系列数据集4000张

一份开源数据集——杂草YOLO数据集&#xff0c;该数据集适用于农业智能化、植物识别等计算机视觉应用场景。 数据集详情 ​训练集&#xff1a;3,664张高清标注图像​测试集&#xff1a;180张多样性场景样本​验证集&#xff1a;359张严格筛选数据 下载链接 杂草YOLO数据集分…

Vue 2 探秘:visible 和 append-to-body 是谁的小秘密?

&#x1f680; Vue 2 探秘&#xff1a;visible 和 append-to-body 是谁的小秘密&#xff1f;&#x1f914; 父组件&#xff1a;identify-list.vue子组件&#xff1a;fake-clue-list.vue 嘿&#xff0c;各位前端探险家&#xff01;&#x1f44b; 今天我们要在 Vue 2 的代码丛林…

机器学习的一百个概念(1)单位归一化

前言 本文隶属于专栏《机器学习的一百个概念》&#xff0c;该专栏为笔者原创&#xff0c;引用请注明来源&#xff0c;不足和错误之处请在评论区帮忙指出&#xff0c;谢谢&#xff01; 本专栏目录结构和参考文献请见[《机器学习的一百个概念》 ima 知识库 知识库广场搜索&…

SpringCould微服务架构之Docker(5)

Docker的基本操作&#xff1a; 镜像相关命令&#xff1a; 1.镜像名称一般分两部分组成&#xff1a;[repository]:[tag]。 2. 在没有指定tag时&#xff0c;默认是latest&#xff0c;代表着最新版本的镜像。 镜像命令的案例&#xff1a; 镜像操作常用的命令&#xff1a; dock…

SpringAI与JBoltAI深度对比:从工具集到企业级AI开发范式的跃迁

一、Java生态下大模型开发的困境与需求 技术公司的能力断层 多数企业缺乏将Java与大模型结合的标准开发范式&#xff0c;停留在碎片化工具使用阶段。 大模型应用需要全生命周期管理能力&#xff0c;而不仅仅是API调用。 工具集的局限性 SpringAI作为工具集的定位&#xff1…

Python中multiprocessing的使用详解

1.实现多进程 代码实现&#xff1a; from multiprocessing import Process import datetime import timedef task01(name):current_timedatetime.datetime.now()start_timecurrent_time.strftime(%Y-%m-%d %H:%M:%S). "{:03d}".format(current_time.microsecond //…

强化学习与神经网络结合(以 DQN 展开)

目录 基于 PyTorch 实现简单 DQN double DQN dueling DQN Noisy DQN&#xff1a;通过噪声层实现探索&#xff0c;替代 ε- 贪心策略 Rainbow_DQN如何计算连续型的Actions 强化学习中&#xff0c;智能体&#xff08;Agent&#xff09;通过与环境交互学习最优策略。当状态空间或动…

飞书电子表格自建应用

背景 coze官方的插件不支持更多的飞书电子表格操作&#xff0c;因为需要自建应用 飞书创建文件夹 创建应用 开发者后台 - 飞书开放平台 添加机器人 添加权限 创建群 添加刚刚创建的机器人到群里 文件夹邀请群 创建好后&#xff0c;就可以拿到id和key 参考教程&#xff1a; 创…