经典卷积神经网络 - AlexNet

news2024/12/27 3:46:14

在这里插入图片描述在这里插入图片描述
AlexNet是由Alex Krizhevsky、Ilya Sutskever和Geoffrey Hinton在2012年ImageNet图像分类竞赛中提出的一种经典的卷积神经网络。当时,AlexNet在 ImageNet 大规模视觉识别竞赛中取得了优异的成绩,把深度学习模型在比赛中的正确率提升到一个前所未有的高度。因此,它的出现对深度学习发展具有里程碑式的意义。

基本结构

AlexNet输入为RGB三通道的224 × 224 × 3大小的图像(也可填充为227 × 227 × 3 )。AlexNet 共包含5 个卷积层(包含3个池化)和 3 个全连接层。其中,每个卷积层都包含卷积核、偏置项、ReLU激活函数和局部响应归一化(LRN)模块。第1、2、5个卷积层后面都跟着一个最大池化层,后三个层为全连接层。最终输出层为softmax,将网络输出转化为概率值,用于预测图像的类别。

由于ImageNet数据集太大,本文以MNIST数据集进行代替,修改网络参数,输入通道为1,输出结果为10个。

代码实现

model.py

import torch
from torch import nn

class AlexNet(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = nn.Sequential(
            nn.Conv2d(1,96,kernel_size=11,stride=4,padding=1),nn.ReLU(),
            nn.MaxPool2d(kernel_size=3,stride=2),
            nn.Conv2d(96,256,kernel_size=5,padding=2),nn.ReLU(),
            nn.MaxPool2d(kernel_size=3,stride=2),
            nn.Conv2d(256,384,kernel_size=3,padding=1),nn.ReLU(),
            nn.Conv2d(384,384,kernel_size=3,padding=1),nn.ReLU(),
            nn.Conv2d(384,256,kernel_size=3,padding=1),nn.ReLU(),
            nn.MaxPool2d(kernel_size=3,stride=2),
            nn.Flatten(),
            nn.Linear(6400,4096),nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096,4096),nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096,10)
        )

    def forward(self,x):
        return self.model(x)

# 验证网络正确性
if __name__ == '__main__':
    net = AlexNet()
    my_input = torch.ones((64,1,28,28))
    my_output = net(my_input)
    print(my_output.shape)

train.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import AlexNet

# 扫描数据次数
epochs = 10
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0


# 定义图像转换
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.MNIST(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.MNIST(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))

# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = AlexNet()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)

writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):
    print("-------------------第 {} 轮训练开始-------------------".format(epoch))
    net.train()
    for data in train_dataloader:
        train_step = train_step + 1
        images,targets = data
        images = images.to(device)
        targets = targets.to(device)
        outputs = net(images)
        loss_out = loss(outputs,targets)
        optimizer.zero_grad()
        loss_out.backward()
        optimizer.step()

        if train_step%100==0:
            writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)
            print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))

    # 测试
    net.eval()
    total_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            test_step = test_step + 1
            images, targets = data
            images = images.to(device)
            targets = targets.to(device)
            outputs = net(images)
            loss_out = loss(outputs, targets)
            total_loss = total_loss + loss_out
            accuracy = (targets == torch.argmax(outputs,dim=1)).sum()
            total_accuracy = total_accuracy + accuracy
        # 计算精确率
        print(total_accuracy)
        accuracy_rate = total_accuracy / test_size

        print("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))
        print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))
        writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)
        writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)
    torch.save(net,"./model/net_{}.pth".format(epoch+1))
    print("模型net_{}.pth已保存".format(epoch+1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1126944.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于springboot实现广场舞团平台系统项目【项目源码+论文说明】计算机毕业设计

基于SPRINGBOOT实现广场舞团平台系统演示 摘要 随着信息技术和网络技术的飞速发展,人类已进入全新信息化时代,传统管理技术已无法高效,便捷地管理信息。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生&am…

38.红黑树(王道第7章查找补充知识)

目录 一. 红黑树的定义 二. 红黑树的性质 三. 红黑树的插入 四. 红黑树的删除(略) 一. 红黑树的定义 红黑树是二叉排序树-左子树结点值≤根结点值≤右子树结点值。 与普通BST相比,有以下要求: ①每个结点或是红色,或是黑色的②根节点是…

探索C++赋值运算符重载的内部机制:手把手教你精通

W...Y的主页 😊 代码仓库分享💕 🍔前言: 前一篇博客中我们已经了解并学习了初始化和清理模块中的构造函数与析构函数,还有拷贝复制中的拷贝复制函数,它们都是类与对象中重要的成员,今天我们要…

构建实时视频聊天应用:使用WebRTC和Netty的完整指南

构建实时视频聊天应用:使用WebRTC和Netty的完整指南 使用WebRTC和Netty构建实时视频聊天应用准备工作步骤1:创建Netty服务器步骤2:创建WebRTC前端应用步骤3:处理WebRTC连接步骤4:处理Netty服务器端步骤5:运…

光流法动目标检测

目录 前言 一、效果展示 二、光流法介绍 三、代码展示 总结 前言 动目标检测是计算机视觉领域的一个热门研究方向。传统的方法主要基于背景建模,但这些方法对于光照变化、遮挡和噪声敏感。因此,研究人员一直在寻找更加鲁棒和有效的技术来解决这一问题。…

如何性能测试中进行业务验证?

在性能测试过程中,验证HTTP code和响应业务code码是比较基础的,但是在一些业务中,这些参数并不能保证接口正常响应了,很可能返回了错误信息,所以这个时候对接口进行业务验证就尤其重要。下面分享一个对某个资源进行业务…

CentOS 7设置固定IP地址

当我们安装了一个虚拟机或者装了一个系统的时候,经常会遇到需要设置固定ip的情况,本文就以Centos 7为例,讲述如何修改固定IP地址。 1、用ifconfig命令查看使用的网卡 如上图所示,我们就会看到我们目前使用的网卡名称 2、编辑网卡…

nginx创建站点“nginx: [emerg] host not found in upstream”错误

nginx配置语法上没有错误的,只是系统无法解析这个域名,所以报错. 解决办法就是添加dns到/etc/resolv.conf 或者是/etc/hosts,让其能够解析到IP。具体步骤如下: vim /etc/hosts 修改hosts文件,在hosts文件里面加上一句 127.0.0.1 localhost.localdomain x…

TiDB x 北京银行丨新一代分布式数据库的探索与实践

北京银行作为中国最大的城商行,坚持以数字化转型统领发展模式、业务结构、客户结构、营运能力、管理方式的五大转型,分布式数据库建设是北京银行数字化转型的重要组成部分。 在新时代、新监管、新业态、新模式的数字化转型背景下,监管要求的…

刚刚腾讯云发布了2023年双11优惠活动!终于等到你

终于等到你,想买台腾讯云服务器,等啊等,终于等来了2023年腾讯云双十一优惠活动,还好没让我失望,2核4G5M带宽的轻量应用服务器三年566,省钱了: txybk.com/go/1111 哈哈哈哈哈。 2023腾讯云双11优…

《低代码指南》——如何用维格表搭建CRM

信息 手机上就能随时随地记录客户信息更智能地进行部门协作、沟通让每一项客户沟通都有迹可循一个表格实现客户全生命周期管理企业如何在激烈的市场竞争中崭露头角,拥有自己的立足之地,CRM 系统必然是一大助力。但传统 CRM 系统功能太多太复杂,不够灵活,内部推广、维护又很…

Linux常用命令——clear命令

在线Linux命令查询工具 clear 清除当前屏幕终端上的任何信息 补充说明 clear命令用于清除当前屏幕终端上的任何信息。 语法 clear实例 直接输入clear命令当前终端上的任何信息就可被清除。 在线Linux命令查询工具

Python下载安装

本文以Windows下安装python3.6为例 一、进入Python的官网,链接: python官网 二、选择下载,选择Windows 三、选择自己需要版本的python进行下载 四、选择所下载的exe文件,选择Upgrade Now 五、等待下载 六、安装成功

基于蝗虫算法的无人机航迹规划-附代码

基于蝗虫算法的无人机航迹规划 文章目录 基于蝗虫算法的无人机航迹规划1.蝗虫搜索算法2.无人机飞行环境建模3.无人机航迹规划建模4.实验结果4.1地图创建4.2 航迹规划 5.参考文献6.Matlab代码 摘要:本文主要介绍利用蝗虫算法来优化无人机航迹规划。 1.蝗虫搜索算法 …

Java实现添加文字水印、图片水印

目录 前言 一、获取原图片对象信息 1、读取本地图片 2、读取网络图片 二、处理水印 三、添加水印 四、获取目标图片 五、完整工具类 六、结果展示 前言 现在很多人都喜欢在各种平台上分享自己的照片吧,不管是一些制作出来的媒体图片还是精致的人像图片&…

人大金仓三大兼容:MySQL迁移无忧

近日,MySQL 5.7停服事件引发广泛关注。MySQL目前已经成为中国用户使用非常广泛的数据库,其中5.7版本的用户比重又是最高的。随着信息技术应用创新深入各行各业,国产数据库对MySQL的平滑替换成为大势所趋。 作为数据库领域国家队,人…

Jmeter并发压测数据库的TPC值

Apache JMeter 视频讲解演示:https://www.bilibili.com/video/BV1Dh4y1J7NW/ Apache组织开发的基于Java的压力测试工具,常常用来模拟高并发压测场景 下载网址:https://jmeter.apache.org/download_jmeter.cgi 下载二进制包,解…

【深度学习 | Transformer】释放注意力的力量:探索深度学习中的 变形金刚,一文带你读通各个模块 —— 总结篇(三)

🤵‍♂️ 个人主页: AI_magician 📡主页地址: 作者简介:CSDN内容合伙人,全栈领域优质创作者。 👨‍💻景愿:旨在于能和更多的热爱计算机的伙伴一起成长!!&…

【linux系统】服务器安装Pycharm

文章目录 安装pycharm步骤1. 进入pycharm官网2. 上传到服务器3. 安装过程 摘要:pycharm是Python语言的图形化开发工具。因为如果在Linux环境下的Python shell 中直接进行编程,其无法保存与修改,在大型项目当中这是很不方便的,而py…

【废话文学】各种概念混搭

我认为他一定是在主体意识中出现了一种异常的反馈 这种反馈打破了既定的习惯性模式 于是思维意识出现了层阶梯式的神话 我认为通过XXX同志这个主体意识上的问题 要看出他自身的轨迹而带有意念性 这个悲剧带有鲜明的主观色彩和思辨色彩 而不要只听着在对他人生哲学上的虚无上的研…