AlexNet网络复现

news2024/12/24 18:50:27

1. 引言

在现代计算机视觉领域,深度学习已经成为了一个核心技术,其影响力远超过了传统的图像处理方法。但深度学习,特别是卷积神经网络(CNN)在计算机视觉的主导地位并不是从一开始就有的。在2012年之前,计算机视觉的许多任务都是由一系列手工设计的特征和浅层的机器学习模型完成的。

2012年,一个特殊的网络结构名为AlexNet在ImageNet Large Scale Visual Recognition Challenge(ILSVRC)上取得了出色的成果,这一结果震惊了整个计算机视觉和机器学习社区。AlexNet不仅在分类精度上大幅领先,更重要的是,它开启了一个全新的时代——深度学习的时代。

2. AlexNet背景与重要性

在深度学习成为主流之前,计算机视觉任务主要依赖于手工设计的特征,如SIFT、HOG等,与浅层机器学习模型相结合,如SVM。这些方法虽然在某些任务上有所成功,但总体上受限于其手工设计的特征提取和有限的模型容量。

为了推进计算机视觉的发展,ImageNet项目应运而生,这是一个包含数百万张标注图像的大型数据库。基于此,ImageNet Large Scale Visual Recognition Challenge(ILSVRC)被创建出来,旨在激励研究人员开发更好的图像分类方法。ILSVRC迅速成为了计算机视觉领域的标杆比赛。

2012年,由Alex Krizhevsky、Ilya Sutskever和Geoffrey Hinton合作设计的AlexNet在ILSVRC中大放异彩,它的错误率比第二名低了10%以上,这在当时是一个令人震惊的进步。它使用的深度卷积神经网络架构和其他创新技术,标志着深度学习在计算机视觉领域的崛起。

深度结构: 与之前的模型相比,AlexNet具有更深的网络结构,这使得它能够学习到更复杂的特征表示。
GPU计算: AlexNet的训练利用了GPU并行计算的优势,从而大大加速了深度网络的训练速度。
创新性技术: 如ReLU激活函数、Dropout等技术,都首次在这样的大规模图像任务中显示了其效果和价值。
启发后续研究: AlexNet的成功激励了更多的研究人员转向深度学习,导致了后续的VGG、GoogLeNet、ResNet等一系列网络的诞生。

3. 网络结构详解

在这里插入图片描述

3.1. 卷积层

卷积层是CNN中的核心部分,它通过卷积操作提取输入图像的特征。AlexNet包含多个卷积层,这些卷积层的过滤器数量和大小各异,以捕捉不同级别的特征。

滤波器 :AlexNet使用了大小为11x11、5x5和3x3的滤波器。
步长与填充:初始的卷积层使用了较大的步长(如步长为4的11x11滤波器),这有助于减少网络的计算复杂性。

3.2. 激活函数: ReLU

ReLU(Rectified Linear Unit)在AlexNet中首次在大规模网络中获得了广泛应用,因为它帮助网络更快地收敛并减轻了梯度消失的问题。

特性:ReLU的定义为f(x) = max(0, x),它是非线性的,但计算简单。
优势:相较于Sigmoid或Tanh激活函数,ReLU可以加速SGD的收敛速度。

3.3. 池化

池化层在CNN中用于降低特征的空间维度,从而减少计算量。同时,它还能增加特征的平移不变性。

类型:AlexNet主要使用最大池化。
池化窗口与步长:在AlexNet中,池化窗口为3x3,步长为2。

3.4 全连接层

AlexNet包含3个全连接层,它们用于将前面的特征图汇集到一起,为分类做最后的决策。

神经元数量:前两个全连接层包含4096个神经元,而最后一个全连接层(输出层)根据类别数量决定(在ImageNet挑战中为1000个类别)。

3.5 正则化:Dropout

Dropout是一种正则化技巧,它在训练期间随机“丢弃”神经元,从而防止网络过拟合。

位置:AlexNet在前两个全连接层之后应用了Dropout。
丢弃率:训练期间,每个神经元被丢弃的概率为0.5。

4. 主要特点与创新

4.1 深度结构

相较于其它前期的网络模型,AlexNet有着更深的层次结构,包括五个卷积层,接着是三个全连接层。这种深度结构允许网络学习更丰富和复杂的特征表示。

4.2 ReLU激活函数

之前的神经网络主要采用sigmoid或tanh作为激活函数。AlexNet采用ReLU作为其激活函数,这一简单的变动大大加速了网络的训练,并提高了模型的表现。

4.3 GPU并行计算

由于其深度结构,AlexNet的计算需求远超过当时的CPU能力。为了解决这个问题,设计者利用了两个GPU进行并行计算。这不仅大大加速了训练速度,而且开启了后续深度学习模型利用GPU进行训练的趋势。

4.4 局部响应归一化 (LRN)

虽然后续的研究表明LRN可能不是必要的,但在AlexNet中,作者介绍了局部响应归一化作为一种规范化技术,它在某种程度上模拟了生物神经元的侧抑制机制,有助于增强模型的泛化能力。

4.5 Dropout

为了防止这样一个大型网络过拟合,AlexNet引入了Dropout技术。通过随机关闭一部分神经元,Dropout可以在训练过程中有效地模拟集成学习,从而增强模型的泛化性。

4.6 大数据和数据增强

AlexNet在ImageNet上训练,该数据集包含超过1500万的高分辨率图像和1000个类别。此外,为了进一步扩充数据并提高模型的鲁棒性,设计者还采用了多种数据增强技术,如图像旋转、裁剪和翻转。

4.7 叠加的卷积层

与之前的网络设计不同,AlexNet在没有池化的情况下叠加了多个卷积层,这允许模型捕捉更为复杂的特征组合。

5. 实践:搭建AlexNet

5.1 model

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()

        # 特征提取层
        self.features = nn.Sequential(
            # 第一卷积层
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 第二卷积层
            nn.Conv2d(48, 128, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 第三、四、五卷积层
            nn.Conv2d(128, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # 分类层
        self.classifier = nn.Sequential(
            # Dropout层可以减少过拟合
            nn.Dropout(p=0.5),
            # 全连接层
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # 通过特征提取层
        x = self.features(x)
        # 展平特征图
        x = torch.flatten(x, start_dim=1)
        # 通过分类层
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        # 初始化权重
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

5.2 train

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet

# 配置参数
BATCH_SIZE = 32
EPOCHS = 10
LR = 0.0002
SAVE_PATH = './AlexNet.pth'


def load_data(data_root):
    """
    加载数据集
    """
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    train_dataset = datasets.ImageFolder(root=os.path.join(data_root, "train"), transform=data_transform["train"])
    validate_dataset = datasets.ImageFolder(root=os.path.join(data_root, "val"), transform=data_transform["val"])

    nw = min([os.cpu_count(), BATCH_SIZE if BATCH_SIZE > 1 else 0, 8])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=nw)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4, shuffle=False, num_workers=nw)

    return train_loader, validate_loader


def save_class_indices(dataset, save_path='class_indices.json'):
    """
    保存类别和对应的编码到json文件中
    """
    flower_list = dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    with open(save_path, 'w') as f:
        json.dump(cla_dict, f, indent=4)


def train_one_epoch(net, data_loader, optimizer, loss_function, device):
    """
    训练一个epoch
    """
    net.train()
    running_loss = 0.0
    for images, labels in tqdm(data_loader, file=sys.stdout):
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(data_loader)


def validate(net, data_loader, device):
    """
    验证模型
    """
    net.eval()
    acc = 0.0
    with torch.no_grad():
        for images, labels in tqdm(data_loader, file=sys.stdout):
            outputs = net(images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, labels.to(device)).sum().item()
    return acc / len(data_loader.dataset)


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device.")

    data_root = os.path.abspath(os.path.join(os.getcwd(), "./.."))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    assert os.path.exists(image_path), f"{image_path} path does not exist."

    train_loader, validate_loader = load_data(image_path)
    save_class_indices(train_loader.dataset)

    print(
        f"Using {len(train_loader.dataset)} images for training, {len(validate_loader.dataset)} images for validation.")

    net = AlexNet(num_classes=5, init_weights=True).to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=LR)

    best_acc = 0.0
    for epoch in range(EPOCHS):
        train_loss = train_one_epoch(net, train_loader, optimizer, loss_function, device)
        val_acc = validate(net, validate_loader, device)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), SAVE_PATH)

        print(f"Epoch {epoch + 1}/{EPOCHS} - Train loss: {train_loss:.4f} - Val Accuracy: {val_acc:.4f}")

    print('Finished Training')


if __name__ == '__main__':
    main()

5.3 predict

import os
import json
import argparse

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet

# 定义命令行参数解析函数
def parse_args():
    parser = argparse.ArgumentParser(description="预测输入图片的分类")
    parser.add_argument("img_path", help="待预测图片的路径")
    parser.add_argument("--model_path", default="./AlexNet.pth", help="已训练的AlexNet模型的路径")
    parser.add_argument("--class_indices", default="./class_indices.json", help="类别索引的json文件路径")
    return parser.parse_args()

# 加载和预处理图片
def load_image(img_path, transform):
    img = Image.open(img_path)
    img = transform(img)
    return torch.unsqueeze(img, dim=0)

# 加载模型
def load_model(model_path, device, num_classes=5):
    model = AlexNet(num_classes=num_classes).to(device)
    model.load_state_dict(torch.load(model_path))
    return model

# 使用模型进行预测
def predict_image(img, model, class_indict, device):
    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        probabilities = torch.softmax(output, dim=0)
        predicted_class = torch.argmax(probabilities).numpy()
    return predicted_class, probabilities

def main():
    args = parse_args()  # 解析命令行参数

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # 定义图片预处理操作
    transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    img = load_image(args.img_path, transform)  # 加载图片

    # 从json文件中读取类别索引
    with open(args.class_indices, "r") as f:
        class_indict = json.load(f)

    model = load_model(args.model_path, device)  # 加载模型

    # 使用模型预测图片
    predicted_class, probabilities = predict_image(img, model, class_indict, device)

    print("预测类别: {}   概率: {:.3}".format(class_indict[str(predicted_class)],
                                         probabilities[predicted_class].numpy()))

    # 打印所有类别的预测概率
    for i in range(len(probabilities)):
        print("类别: {:10}   概率: {:.3}".format(class_indict[str(i)],
                                             probabilities[i].numpy()))
    plt.imshow(Image.open(args.img_path))  # 显示图片
    plt.title("预测结果: {}".format(class_indict[str(predicted_class)]))
    plt.show()

if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1060871.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

二、互联网技术——网络协议

文章目录 一、OSI与TCP/IP参考模型二、TCP/IP参考模型各层功能三、TCP/IP参考模型与对应协议四、常用协议与功能五、常用协议端口 一、OSI与TCP/IP参考模型 二、TCP/IP参考模型各层功能 三、TCP/IP参考模型与对应协议 例题:TCP/IP模型包含四个层次,由上至…

正点原子嵌入式linux驱动开发——U-boot使用

在学会U-boot的移植以及其启动过程之前,先体验一下U-boot会更有助于学习的认知。STM32MP157开发板光盘资料里面已经提供了一个正点原子团队已经移植好的U-Boot,本章我们就直接编译这个移植好的U-Boot,然后烧写到EMMC里面启动,启动…

华为云云耀云服务器L实例评测|部署在线影音媒体系统 Jellyfin

华为云云耀云服务器L实例评测|部署在线影音媒体系统 Jellyfin 一、云耀云服务器L实例介绍1.1 云服务器介绍1.2 产品规格1.3 应用场景1.4 支持镜像 二、云耀云服务器L实例配置2.1 重置密码2.2 服务器连接2.3 安全组配置 三、部署 Jellyfin3.1 Jellyfin 介绍3.2 Docke…

VD6283TX环境光传感器驱动开发(4)----移植闪烁频率代码

VD6283TX环境光传感器驱动开发----4.移植闪烁频率代码 闪烁定义视频教学样品申请源码下载开发板设置开发板选择IIC配置串口配置开启X-CUBE-ALS软件包时钟树配置ADC使用定时器触发采样KEIL配置FFT代码配置app_x-cube-als.c需要添加函数 闪烁定义 光学闪烁被定义为人造光源的脉动…

全志ARM926 Melis2.0系统的开发指引③

全志ARM926 Melis2.0系统的开发指引③ 编写目的6. 存储系统简介6.1.概要描述6.2.文件系统接口6.2.1. 文件系统支持6.2.2. 文件系统接口函数 6.3. Flash 分区6.3.1.如何配置可配分区的大小 6.4.存储介质开发6.4.1. NOR Flash6.4.1.1.添加新 Nor Flash6.4.1.2.Nor Flash 保存用户…

Llama2-Chinese项目:6-模型评测

测试问题筛选自AtomBulb[1],共95个测试问题,包含:通用知识、语言理解、创作能力、逻辑推理、代码编程、工作技能、使用工具、人格特征八个大的类别。 1.测试中的Prompt   例如对于问题"列出5种可以改善睡眠质量的方法"&#xff…

DP读书:《openEuler操作系统》(四)鲲鹏处理器

鲲鹏处理器 一、处理器概述1.Soc2.Chip3.DIE4.Cluster5.Core 二、体系架构1.计算子系统2.存储子系统3.其他子系统 三、CPU编程模型1.中断与异常2.异常级别a.基本概念b.异常级别切换 下面为整理的内容:鲲鹏处理器 架构与编程(一)处理器与服务器…

全志ARM926 Melis2.0系统的开发指引④

全志ARM926 Melis2.0系统的开发指引④ 编写目的7. 固件打包脚本7.1.概要描述7.2.术语定义7.2.1. makefile7.2.2. image.bat 7.3.工具介绍7.4.打包步骤7.4.1. makefile 部分7.4.2. image.bat 部分 7.5.问题与解决方案7.5.1. 固件由那些文件构成7.5.2. melis100.fex 文件包含什么…

(二)正点原子STM32MP135移植——TF-A移植

目录 一、TF-A概述 二、编译官方代码 2.1 解压源码 2.2 打补丁 2.3 编译准备 (1)修改Makfile.sdk (2)设置环境变量 (3)编译 三、移植 3.1 复制官方文件 3.2 修改电源 3.3 修改TF卡和emmc 3.4 添…

Monkey基本使用及介绍

1 简介.. 1 1.1 Monkey是干什么的.. 1 1.2 我们为什么要用monkey. 1 1.3 试行monkey的计划.. 2 2 monkey使用.. 4 2.1 基本常识.. 4 2.2 基本使用.. 6 2.2.1 通过adb 来启动monkey. 6 2.2.2 一些命令选项.. 7 2.2.3 一些测试例子.. 7 2.2.4 执行注意事项.. 9 2.2.5侦…

pandas read_json时ValueError: Expected object or value的解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

数据结构 1.2 算法

算法的基本概念 算法的定义 算法是对特定问题求解步骤的一种描述,它是指定的有限序列,其中的每条指令表示一个或多个操作。 例、 算法的特性 (5个) 1.有穷性 一个算法总在执行有穷步之后结束,且每一步都可以在有穷…

Redis作为缓存,mysql的数据如何与redis进行同步?

Redis作为缓存,mysql的数据如何与redis进行同步? 一定要设置前提,先介绍业务背景 延时双删 双写一致性:当修改了数据库的数据也要同时更新缓存的数据,缓存和数据库的数据要保持一致 读操作:缓存命中,直接返回;缓存未…

位移贴图和法线贴图的区别

位移贴图和法线贴图都是用于增强模型表面细节和真实感的纹理贴图技术,但是它们之间也存在着差异。 1、什么是位移贴图 位移贴图:位移贴图通过在模型顶点上定义位移值来改变模型表面的形状。该贴图包含了每个像素的高度值信息,使得模型的细节…

Nginx与Spring Boot的错误模拟实践:探索502和504错误的原因

文章目录 前言502和504区别---都是Nginx返回的access.log和error.log介绍SpringBoot结合Nginx实战502 and 504准备工作Nginx配置host配置SpringBoot 502模拟access.logerror.log 504模拟access.logerror.log 500模拟access.logerror.log 总结 前言 刚工作那会,最常…

基于Java Web 的购物网站

本系统采用基于JAVA语言实现、架构模式选择B/S架构,Tomcat7.0及以上作为运行服务器支持,基于JAVA等主要技术和框架设计,idea作为开发环境,数据库采用MYSQL5.7以上。 开发环境: JDK版本:JDK1.8 服务器&…

【前后缀技巧】2022牛客多校3 A

登录—专业IT笔试面试备考平台_牛客网 题意: 思路: 这种是典中典中典,对于gcd,背包问题都是一样的处理方式 预处理出前缀lca和后缀lca,枚举哪个消失即可,可以统计方案数 Code: #include &l…

karmada v1.7.0安装指导

前言 安装心得 经过多种方式操作,发现二进制方法安装太复杂,证书生成及其手工操作太多了,没有安装成功;helm方式的安装,v1.7.0的chart包执行安装会报错,手工修复了报错并修改了镜像地址,还是各…

在Ubuntu 20.04搭建最小实验环境

sudo apt-get -y install --no-install-recommends wget gnupg ca-certificates安装导入GPG公钥所需的依赖包。 sudo wget -O - https://openresty.org/package/pubkey.gpg | sudo apt-key add -导入GPG密钥。 sudo apt-get -y install --no-install-recommends software-p…

【APUE】文件系统 — 类 du 命令功能实现

一、du命令解析 Summarize disk usage of the set of FILEs, recursively for directories. du 命令用于输出文件所占用的磁盘空间 默认情况下,它会输出当前目录下(包括该目录的所有子目录下)的所有文件的大小总和,以 1024B 为单…