深度学习之经典网络-AlexNet详解

news2024/11/25 22:33:20

        AlexNet 是一种经典的卷积神经网络(CNN)架构,在 2012 年的 ImageNet 大规模视觉识别挑战赛(ILSVRC)中表现优异,将 CNN 引入深度学习的新时代。AlexNet 的设计在多方面改进了卷积神经网络的架构,使其能够在大型数据集上有效训练。以下是 AlexNet 的详解:

1. AlexNet 架构概述

        AlexNet 有 8 层权重层,包括 5 层卷积层和 3 层全连接层(FC 层),并引入了一些重要的创新,包括激活函数、Dropout 正则化和重叠池化。它通过增加网络的深度和宽度,结合 GPU 加速,极大提升了 CNN 的能力。

2. AlexNet 架构细节

        (1)输入层

  • 输入图像的尺寸为 227x227x3(RGB 3 通道图像)。
  • AlexNet 采用的是 ImageNet 数据集,其图像分辨率较高,因此需要更大的卷积核和池化核。 

        (2)卷积层(Conv Layers)

  • 第一层卷积层(Conv1):卷积核大小为 11x11,步长为 4,使用 96 个滤波器。输出的特征图尺寸为 55x55x96。经过 ReLU 激活函数处理。
  • 第二层卷积层(Conv2):卷积核大小为 5x5,步长为 1,使用 256 个滤波器。由于输入图像较大,为减小计算量,每次滑动 1 像素,并采用了最大池化。输出的特征图尺寸为 27x27x256。
  • 第三、四、五层卷积层(Conv3、Conv4、Conv5):分别采用 3x3 的卷积核,步长为 1,滤波器数分别为 384、384 和 256。

        (3)激活函数(ReLU)

  • AlexNet 是第一个在每一层卷积层之后使用 ReLU(Rectified Linear Unit)激活函数的网络。与 sigmoid 激活函数不同,ReLU 不会出现梯度消失问题,且能加快训练速度。

        (4)池化层(Pooling Layers)

  • 使用最大池化(Max Pooling),窗口大小为 3x3,步长为 2。
  • AlexNet 引入了“重叠池化”,即池化窗口的步长小于窗口的大小(3x3 池化窗口和 2 步长),使得池化层能够更好地提取空间信息。

        (5)全连接层(Fully Connected Layers)

  • AlexNet 的最后 3 层是全连接层。
  • FC6 层:输入是前一层展平后的特征图,输出为 4096 个节点。
  • FC7 层:与 FC6 类似,输出也为 4096 个节点。
  • FC8 层:为最终的输出层,节点数等于类别数(在 ImageNet 数据集中为 1000),通过 softmax 得到每个类别的概率。

        (6)Dropout 正则化

  • 在全连接层中,AlexNet 引入了 Dropout 正则化,将随机的神经元设为 0,以减少过拟合。Dropout 率为 0.5,即每个神经元有 50% 的概率不参与计算。

        (7)局部响应归一化(Local Response Normalization, LRN)

  • LRN 是一种正则化技术,通过对某一层激活值进行归一化操作,增加了模型的泛化能力。虽然 LRN 不再是现代 CNN 的标准,但在 AlexNet 中有效防止了某些神经元的权值变得过大。

3. AlexNet 的创新点

AlexNet 的创新之处主要体现在以下几点:

  • ReLU 激活函数的应用

    通过使用 ReLU,AlexNet 成功避免了 sigmoid 和 tanh 激活函数可能导致的梯度消失问题,从而加速了训练过程。
  • 重叠池化

    重叠池化减小了过拟合风险,使得网络能更好地进行特征提取和层次化表示。
  • Dropout 正则化

    Dropout 的引入在当时是一个非常重要的创新,它通过让神经元随机失活来防止过拟合。
  • 多 GPU 训练

    AlexNet 在 GPU 上进行了分布式训练,将不同的卷积层分配到两个 GPU 上,从而加速了计算。
  • 数据增强

    AlexNet 使用数据增强(如随机剪裁、镜像翻转和颜色扰动),进一步增加了训练数据的多样性,减少了过拟合风险。

模型特性

  • 所有卷积层都使用ReLU作为非线性映射函数,使模型收敛速度更快
  • 在多个GPU上进行模型的训练,不但可以提高模型的训练速度,还能提升数据的使用规模
  • 使用LRN对局部的特征进行归一化,结果作为ReLU激活函数的输入能有效降低错误率
  • 重叠最大池化(overlapping max pooling),即池化范围z与步长s存在关系z>s,避免平均池化(average pooling)的平均效应
  • 使用随机丢弃技术(dropout)选择性地忽略训练中的单个神经元,避免模型的过拟合

4. AlexNet 的优势和局限性

  • 优势
    • AlexNet 通过加深网络层数和增加神经元数量,提高了模型的表现力。
    • 使用 GPU 进行加速计算,使得大规模数据集上的训练成为可能。
    • Dropout、重叠池化和数据增强等技术有效地降低了过拟合风险。
  • 局限性
    • AlexNet 参数数量较多,导致计算资源需求较大。
    • 在深度增加的同时,过大的全连接层会导致大量参数和计算。
    • LRN 归一化的效果有限,现代模型往往使用批归一化(Batch Normalization)来取代。

5. AlexNet 的影响

         VGGNet、GoogLeNet 和 ResNet 等网络都在 AlexNet 的基础上进行了改进和扩展。

6.代码示例

        PyTorch 中的 AlexNet 实现:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义 AlexNet 模型结构
class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):  # 默认输出1000类,可根据任务调整
        super(AlexNet, self).__init__()

        # 特征提取部分,包括卷积和池化层
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),  # 第1个卷积层,输出96个特征图
            nn.ReLU(inplace=True),  # 激活函数,ReLU
            nn.MaxPool2d(kernel_size=3, stride=2),  # 第1个最大池化层

            nn.Conv2d(96, 256, kernel_size=5, padding=2),  # 第2个卷积层
            nn.ReLU(inplace=True),  # 激活函数
            nn.MaxPool2d(kernel_size=3, stride=2),  # 第2个最大池化层

            nn.Conv2d(256, 384, kernel_size=3, padding=1),  # 第3个卷积层
            nn.ReLU(inplace=True),  # 激活函数

            nn.Conv2d(384, 384, kernel_size=3, padding=1),  # 第4个卷积层
            nn.ReLU(inplace=True),  # 激活函数

            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # 第5个卷积层
            nn.ReLU(inplace=True),  # 激活函数
            nn.MaxPool2d(kernel_size=3, stride=2)  # 第3个最大池化层
        )

        # 分类部分,包含全连接层和 Dropout 层
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),  # Dropout层,防止过拟合
            nn.Linear(256 * 6 * 6, 4096),  # 全连接层,输入尺寸为 256*6*6,输出4096
            nn.ReLU(inplace=True),  # 激活函数

            nn.Dropout(0.5),  # 第二个Dropout层
            nn.Linear(4096, 4096),  # 第二个全连接层
            nn.ReLU(inplace=True),  # 激活函数

            nn.Linear(4096, num_classes)  # 最后一个全连接层,输出类别数
        )

    # 定义前向传播过程
    def forward(self, x):
        x = self.features(x)  # 经过特征提取层
        x = x.view(x.size(0), 256 * 6 * 6)  # 展平特征图用于输入全连接层
        x = self.classifier(x)  # 经过分类层
        return x

# 数据预处理,定义图像转换操作
transform = transforms.Compose([
    transforms.Resize(256),  # 调整图像大小到256
    transforms.CenterCrop(227),  # 中心裁剪为227x227大小,符合AlexNet输入要求
    transforms.ToTensor(),  # 转换为Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化
])

# 加载 CIFAR10 数据集,训练集和测试集分别创建 DataLoader
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # 设置批次大小和打乱数据

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 初始化模型、损失函数和优化器
model = AlexNet(num_classes=10)  # CIFAR10 任务设置10个输出类别
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 检查是否有 GPU 加速
model.to(device)  # 将模型移动到设备上

criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,用于分类任务
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # 随机梯度下降优化器

# 训练模型
def train(model, device, train_loader, criterion, optimizer, epoch):
    model.train()  # 设置模型为训练模式
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)  # 将数据移动到设备上
        optimizer.zero_grad()  # 清空梯度
        output = model(data)  # 前向传播
        loss = criterion(output, target)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新模型参数
        if batch_idx % 100 == 0:  # 每100个批次打印一次训练状态
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}]\tLoss: {loss.item():.6f}')

# 测试模型
def test(model, device, test_loader, criterion):
    model.eval()  # 设置模型为评估模式
    test_loss = 0  # 初始化测试损失
    correct = 0  # 初始化正确分类的数量
    with torch.no_grad():  # 禁用梯度计算
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)  # 前向传播
            test_loss += criterion(output, target).item()  # 累积测试损失
            pred = output.argmax(dim=1, keepdim=True)  # 获取预测的最大概率类别
            correct += pred.eq(target.view_as(pred)).sum().item()  # 统计正确分类的数量

    test_loss /= len(test_loader.dataset)  # 计算平均损失
    accuracy = 100. * correct / len(test_loader.dataset)  # 计算准确率
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')

# 主循环:训练和测试
num_epochs = 10  # 定义训练轮数
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, criterion, optimizer, epoch)  # 调用训练函数
    test(model, device, test_loader, criterion)  # 调用测试函数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2232268.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android亮屏Job的功耗优化方案

摘要: Job运行时会带来持锁的现象,目前灭屏放电Job的锁托管已经有doze和绿盟标准监管,但是亮屏时仍旧存在过长的持锁现象,故为了优化功耗和不影响用户体验下,新增亮屏放电下如果满足冻结和已运行过一次Job,则进行job限制,当非冻结时恢复的策略 1.现象: (gms_schedu…

Linux版更新流程

一.下载更新包 下载地址:https://www.nvisual.com/%e4%b8%8b%e8%bd%bd/ 二.更新包组成 更新包由三部分组成: 前端更新包:压缩的ZIP文件,例如:dist-2.2.26-20231227.zip (2.2.26是版本号 20231227是发布日期)后端更…

Java环境下配置环境(jar包)并连接mysql数据库

目录 jar包下载 配置 简单连接数据库 一、注册驱动(jdk6以后会自动注册) 二、连接对应的数据库 以前学习数据库就只是操作数据库,根本不知道该怎么和软件交互,将存储的数据读到软件中去,最近学习了Java连接数据库…

鸿蒙网络编程系列42-仓颉版域名解析示例

1. 域名解析简介 域名解析是网络开发中经常使用的功能之一,特别是对于当前版本的鸿蒙API,使用TCP或者UDP等网络协议通讯时,只能使用确定的IP地址进行绑定或者发送消息,还不支持直接使用域名,所以,通过域名…

第15课 算法(下)

掌握冒泡排序、选择排序、插入排序、顺序查找、对分查找的的基本原理,并能使用这些算法编写简单的Python程序。 一、冒泡排序 1、冒泡排序的概念 冒泡排序是最简单的排序算法,是在一列数据中把较大(或较小)的数据逐次向右推移的…

Netty 强大的 ByteBuf

Netty 强大的 ByteBuf Netty ByteBuf功能可以类比NIO 中 ByteBuffer,那为什么不直接使用NIO 中ByteBuffer? 主要是易用性和扩展性一些方面,有点可以肯定,Netty 基于NIO实现的,底层肯定用了ByteBuffer 。 jdk Buffer API 复杂性…

从安装到实战:Spring Boot与kafka终极整合指南

docker环境下部署kafka 前置条件 Apache Kafka 自 2.8.0 版本开始引入了不依赖 Zookeeper 的“Kafka Raft Metadata Mode”,本文章依然使用Zookeeper 作为集群管理的插件。 #拉去zookeeper镜像docker pull wurstmeister/zookeeper#运行zookeeper容器docker run -…

【Kettle的安装与使用】使用Kettle实现mysql和hive的数据传输(使用Kettle将mysql数据导入hive、将hive数据导入mysql)

文章目录 一、安装1、解压2、修改字符集3、启动 二、实战1、将hive数据导入mysql2、将mysql数据导入到hive 一、安装 Kettle的安装包在文章结尾 1、解压 在windows中解压到一个非中文路径下 2、修改字符集 修改 spoon.bat 文件 "-Dfile.encodingUTF-8"3、启动…

如何看待AI技术的应用前景?

文章目录 如何看待AI技术的应用前景引言AI技术的现状1. AI的定义与分类2. 当前AI技术的应用领域 AI技术的应用前景1. 经济效益2. 社会影响3. 技术进步 AI技术应用面临的挑战1. 数据隐私与安全2. 可解释性与信任3. 技能短缺与就业影响 AI技术的未来发展方向1. 人工智能的伦理与法…

PyQt5实战——UTF-8编码器UI页面设计以及按钮连接(五)

个人博客:苏三有春的博客 系类往期文章: PyQt5实战——多脚本集合包,前言与环境配置(一) PyQt5实战——多脚本集合包,UI以及工程布局(二) PyQt5实战——多脚本集合包,程序…

快速入门CSS

欢迎关注个人主页:逸狼 创造不易,可以点点赞吗 如有错误,欢迎指出~ 目录 CSS css的三种引入方式 css书写规范 选择器分类 标签选择器 class选择器 id选择器 复合选择器 通配符选择器 color颜色设置 border边框设置 width/heigth 内/外边距 C…

【基础】os模块

前言 1、os是operation system(操作系统)的缩写;os模块就是python对操作系统操作接口的封装。os模块提供了多数操作系统的功能接口函数。(OS模块提供了与操作系统进行交互的函数) 2、操作系统属于Python的标准实用程…

Linux---cp命令

Linux cp 命令 | 菜鸟教程 (runoob.com) 命令作用: cp命令主要用于复制文件或目录 语法: cp [options] source dest cp [选项] 源文件 目标文件 source:要复制的文件或目录的名称 dest:复制后的文件或目录的名称 注意:用户使用该指令复制目录时&…

MyBatis-Plus快速入门:从安装到第一个Demo

一、前言 在现代 Java 应用程序中,数据访问层的效率与简洁性至关重要。MyBatis-Plus 作为 MyBatis 的增强工具,旨在简化常见的数据操作,提升开发效率。它提供了丰富的功能,如自动生成 SQL、条件构造器和简单易用的 CRUD 操作&…

【android12】【AHandler】【3.AHandler原理篇AHandler类方法全解】

AHandler系列 【android12】【AHandler】【1.AHandler异步无回复消息原理篇】-CSDN博客 【android12】【AHandler】【2.AHandler异步回复消息原理篇】-CSDN博客 其他系列 本人系列文章-CSDN博客 1.简介 前面两篇我们主要介绍了有回复和无回复的消息的使用方法和源码解析&a…

美发系统——职员绩效和提成——调试过程

一、学会通过现象看本质 首先,通过现象看本质能够让技术研究者更深入地理解问题。在面对技术故障或挑战时,表面的现象往往只是冰山一角,如果只关注表象,可能会采取治标不治本的解决方法。而洞察本质则可以找到问题的根源&#xf…

记一次:Clickhouse同步mysql数据库

ClickHouse可以通过使用MaterializeMySQL引擎来实现与MySQL的数据同步。 前言:因为数据量比较大,既然要分库,为何不让clickhouse同步一下mysql数据库呢? 零、前期准备--mysql的查询和配置 1 查询mysql的配置状态 查询以下语句…

教程:使用 InterBase Express 访问数据库(二)

1. 添加数据模块(IBX 通用教程) 本节将创建一个数据模块(TDataModule),这是一种包含应用程序使用的非可视组件的表单。 以下是完全配置好的 TDataModule 的视图: 创建 TDataModule 后,您可以在其他表单中使用这个数据模块。 2. 添加 TDataModule 要将数据模块添加到…

Matlab实现海马优化算法(SHO)求解路径规划问题

目录 1.内容介绍 2.部分代码 3.实验结果 4.内容获取 1内容介绍 海马优化算法(SHO)是一种受自然界海马行为启发的优化算法,它通过模拟海马在寻找食物和配偶时的探索、跟踪和聚集行为来搜索最优解。SHO因其高效的全局搜索能力和局部搜索能力而…

002-Kotlin界面开发之Kotlin旋风之旅

Kotlin旋风之旅 Compose Desktop中哪些Kotlin知识是必须的? 在学习Compose Desktop中,以下Kotlin知识是必须的: 基础语法:包括变量声明、数据类型、条件语句、循环等。面向对象编程:类与对象、继承、接口、抽象类等。…