深度学习_AlexNet_2

news2024/11/18 1:48:27

目标

  • 知道AlexNet网络结构
  • 能够利用AlexNet完成图像分类

2012年,AlexNet横空出世,该模型的名字源于论文第一作者的姓名Alex Krizhevsky 。AlexNet使用了8层卷积神经网络,以很大的优势赢得了ImageNet 2012图像识别挑战赛。它首次证明了学习到的特征可以超越手工设计的特征,从而一举打破计算机视觉研究的方向。

1.AlexNet的网络架构

AlexNet与LeNet的设计理念非常相似,但也有显著的区别,其网络架构如下图所示:

该网络的特点是:

  • AlexNet包含8层变换,有5层卷积和2层全连接隐藏层,以及1个全连接输出层

  • AlexNet第一层中的卷积核形状是11×1111×11。第二层中的卷积核形状减小到5×55×5,之后全采用3×33×3。所有的池化层窗口大小为3×33×3、步幅为2的最大池化。

  • AlexNet将sigmoid激活函数改成了ReLU激活函数,使计算更简单,网络更容易训练

  • AlexNet通过dropOut来控制全连接层的模型复杂度。

  • AlexNet引入了大量的图像增强,如翻转、裁剪和颜色变化,从而进一步扩大数据集来缓解过拟合。

在tf.keras中实现AlexNet模型:

# 构建AlexNet模型
net = tf.keras.models.Sequential([
    # 卷积层:96个卷积核,卷积核为11*11,步幅为4,激活函数relu
    tf.keras.layers.Conv2D(filters=96,kernel_size=11,strides=4,activation='relu'),
    # 池化:窗口大小为3*3、步幅为2
    tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
    # 卷积层:256个卷积核,卷积核为5*5,步幅为1,padding为same,激活函数relu
    tf.keras.layers.Conv2D(filters=256,kernel_size=5,padding='same',activation='relu'),
    # 池化:窗口大小为3*3、步幅为2
    tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
    # 卷积层:384个卷积核,卷积核为3*3,步幅为1,padding为same,激活函数relu
    tf.keras.layers.Conv2D(filters=384,kernel_size=3,padding='same',activation='relu'),
    # 卷积层:384个卷积核,卷积核为3*3,步幅为1,padding为same,激活函数relu
    tf.keras.layers.Conv2D(filters=384,kernel_size=3,padding='same',activation='relu'),
    # 卷积层:256个卷积核,卷积核为3*3,步幅为1,padding为same,激活函数relu
    tf.keras.layers.Conv2D(filters=256,kernel_size=3,padding='same',activation='relu'),
    # 池化:窗口大小为3*3、步幅为2
    tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
    # 伸展为1维向量
    tf.keras.layers.Flatten(),
    # 全连接层:4096个神经元,激活函数relu
    tf.keras.layers.Dense(4096,activation='relu'),
    # 随机失活
    tf.keras.layers.Dropout(0.5),
    # 全链接层:4096个神经元,激活函数relu
    tf.keras.layers.Dense(4096,activation='relu'),
    # 随机失活
    tf.keras.layers.Dropout(0.5),
    # 输出层:10个神经元,激活函数softmax
    tf.keras.layers.Dense(10,activation='softmax')
])

我们构造一个高和宽均为227的单通道数据样本来看一下模型的架构:

# 构造输入X,并将其送入到net网络中
X = tf.random.uniform((1,227,227,1)
y = net(X)
# 通过net.summay()查看网络的形状
net.summay()

网络架构如下:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (1, 55, 55, 96)           11712     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (1, 27, 27, 96)           0         
_________________________________________________________________
conv2d_1 (Conv2D)            (1, 27, 27, 256)          614656    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (1, 13, 13, 256)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (1, 13, 13, 384)          885120    
_________________________________________________________________
conv2d_3 (Conv2D)            (1, 13, 13, 384)          1327488   
_________________________________________________________________
conv2d_4 (Conv2D)            (1, 13, 13, 256)          884992    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (1, 6, 6, 256)            0         
_________________________________________________________________
flatten (Flatten)            (1, 9216)                 0         
_________________________________________________________________
dense (Dense)                (1, 4096)                 37752832  
_________________________________________________________________
dropout (Dropout)            (1, 4096)                 0         
_________________________________________________________________
dense_1 (Dense)              (1, 4096)                 16781312  
_________________________________________________________________
dropout_1 (Dropout)          (1, 4096)                 0         
_________________________________________________________________
dense_2 (Dense)              (1, 10)                   40970     
=================================================================
Total params: 58,299,082
Trainable params: 58,299,082
Non-trainable params: 0
_________________________________________________________________

2.手写数字势识别

AlexNet使用ImageNet数据集进行训练,但因为ImageNet数据集较大训练时间较长,我们仍用前面的MNIST数据集来演示AlexNet。读取数据的时将图像高和宽扩大到AlexNet使用的图像高和宽227。这个通过tf.image.resize_with_pad来实现。

2.1 数据读取

首先获取数据,并进行维度调整:

import numpy as np
# 获取手写数字数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 训练集数据维度的调整:N H W C
train_images = np.reshape(train_images,(train_images.shape[0],train_images.shape[1],train_images.shape[2],1))
# 测试集数据维度的调整:N H W C
test_images = np.reshape(test_images,(test_images.shape[0],test_images.shape[1],test_images.shape[2],1))

由于使用全部数据训练时间较长,我们定义两个方法获取部分数据,并将图像调整为227*227大小,进行模型训练:

# 定义两个方法随机抽取部分样本演示 # 获取训练集数据 def get_train(size): # 随机生成要抽样的样本的索引 index = np.random.randint(0, np.shape(train_images)[0], size) # 将这些数据resize成227*227大小 resized_images = tf.image.resize_with_pad(train_images[index],227,227,) # 返回抽取的 return resized_images.numpy(), train_labels[index] # 获取测试集数据 def get_test(size): # 随机生成要抽样的样本的索引 index = np.random.randint(0, np.shape(test_images)[0], size) # 将这些数据resize成227*227大小 resized_images = tf.image.resize_with_pad(test_images[index],227,227,) # 返回抽样的测试样本 return resized_images.numpy(), test_labels[index]

调用上述两个方法,获取参与模型训练和测试的数据集:

# 获取训练样本和测试样本
train_images,train_labels = get_train(256)
test_images,test_labels = get_test(128)

为了让大家更好的理解,我们将数据展示出来:

# 数据展示:将数据集的前九个数据集进行展示
for i in range(9):
    plt.subplot(3,3,i+1)
    # 以灰度图显示,不进行插值
    plt.imshow(train_images[i].astype(np.int8).squeeze(), cmap='gray', interpolation='none')
    # 设置图片的标题:对应的类别
    plt.title("数字{}".format(train_labels[i]))

结果为:

我们就使用上述创建的模型进行训练和评估。

2.2 模型编译

# 指定优化器,损失函数和评价指标
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0, nesterov=False)

net.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

2.3 模型训练

# 模型训练:指定训练数据,batchsize,epoch,验证集
net.fit(train_images,train_labels,batch_size=128,epochs=3,verbose=1,validation_split=0.1)

训练输出为:

Epoch 1/3 2/2 [==============================] - 3s 2s/step - loss: 2.3003 - accuracy: 0.0913 - val_loss: 2.3026 - val_accuracy: 0.0000e+00 Epoch 2/3 2/2 [==============================] - 3s 2s/step - loss: 2.3069 - accuracy: 0.0957 - val_loss: 2.3026 - val_accuracy: 0.0000e+00 Epoch 3/3 2/2 [==============================] - 4s 2s/step - loss: 2.3117 - accuracy: 0.0826 - val_loss: 2.3026 - val_accuracy: 0.0000e+00

2.4 模型评估

# 指定测试数据
net.evaluate(test_images,test_labels,verbose=1)

输出为:

4/4 [==============================] - 1s 168ms/step - loss: 2.3026 - accuracy: 0.0781 [2.3025851249694824, 0.078125]

如果我们使用整个数据集训练网络,并进行评估的结果:

[0.4866700246334076, 0.8395]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1509457.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

安卓studio安装(从安装到配置到helloworld)

安卓studio安装 2024.3.11官网的版本(有些翻墙步骤下载东西也解决了) 这次写的略有草率,后面会更新布局的,因为截图量太大了,有需要的小伙伴可以试着接受一下哈哈哈哈 !(https://gitee.com/jiuzheyangbawjf/img/raw/ma…

【webrtc】m122:BitrateProber 源码阅读与分析

pacing controller 需要 bitrate prober Pacing模块中存在一个BitrateProber prober_的成员变量,专门用来处理带宽探测 大神的分析也是基于最新版本webrtc的:ProbeController每次可能会生成多个探测源数据ProbeClusterConfig,其中每个源数据ProbeClusterConfig对应一个探测簇…

软件杯 垃圾邮件(短信)分类算法实现 机器学习 深度学习

文章目录 0 前言2 垃圾短信/邮件 分类算法 原理2.1 常用的分类器 - 贝叶斯分类器 3 数据集介绍4 数据预处理5 特征提取6 训练分类器7 综合测试结果8 其他模型方法9 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 垃圾邮件(短信)分类算…

C++Qt学习——Qt信号槽

信号和槽是Qt编程的基础,他们的存在使得在Qt中处理界面各个组件的交互操作变得更加直观简单。信号(SUGNAL):也就是发送者发送的函数信号,例如PushButtun最常见的信号就是鼠标单击的时候发射的click()信号槽&#xff08…

华为配置ISP选路实现报文按运营商转发

CLI举例:配置ISP选路实现报文按运营商转发 介绍通过配置ISP选路实现报文按运营商转发的配置举例。 组网需求 如图1所示,FW作为安全网关部署在网络出口,企业分别从ISP1和ISP2租用一条链路。 企业希望访问Server 1的报文从ISP1链路转发&#…

Python机器学习预测+回归全家桶,新增TCN,BiTCN,TCN-GRU,BiTCN-BiGRU等组合模型预测...

截止到本期,一共发了4篇关于机器学习预测全家桶Python代码的文章。参考往期文章如下: 1.机器学习预测全家桶-Python,一次性搞定多/单特征输入,多/单步预测!最强模板! 2.机器学习预测全家桶-Python&#xff…

el-table中 el-popover 性能优化

场景:在 el-table 中使用 el-popover ,出现了 loading 加载卡顿的问题,接口返回的数据的时间大概是 140ms ,所以不是接口慢的原因;通过对表中结构的逐步排查,发现是表中的 某一行 所影响的;并且 其中含有 e…

qt 汉字输出 中文输出 显示乱码 qDebug() 乱码 解决

要正确显示汉字,必须要先了解计算机文字编码相关知识,参考:unicode ucs2 utf16 utf8 ansi GBK GB2312 互转 及 渲染_ucs2编码转换-CSDN博客 1、汉字输出到 应用程序输出面板 qt 自定义的输出类qDebug() 、QDebug对象、QMessageLogger默认输…

单例模式及线程安全的实践

🌟 欢迎来到 我的博客! 🌈 💡 探索未知, 分享知识 !💫 本文目录 引言基本的单例模式长啥样?怎样才能线程安全?**懒汉模式** ( 双 重 检 查 ) 🎉总结🎉 引言 单例模式是个…

WebPack自动吐出脚本

window.c c; window.res ""; window.flag false;c function (r) {if (flag) {window.res window.res "${r.toString()}" ":" (e[r] "") ",";}return window.c(r); }代码改进了一下,可以过滤掉重复的方…

酷开科技发力研发酷开系统,让家庭娱乐生活更加丰富多彩

在这个快节奏的社会,家庭娱乐已成为我们日常生活中不可或缺的一部分,为了给家庭带来更多欢笑与感动,酷开科技发力研发出拥有丰富内容和技术的智能电视操作系统——酷开系统,它集合了电影、电视剧、综艺、游戏、音乐等海量内容&…

腾讯云和阿里云4核8G云服务器多少钱一年和1个月费用对比

4核8G云服务器多少钱一年?阿里云ECS服务器u1价格955.58元一年,腾讯云轻量4核8G12M带宽价格是646元15个月,阿腾云atengyun.com整理4核8G云服务器价格表,包括一年费用和1个月收费明细: 云服务器4核8G配置收费价格 阿里…

6.S081的Lab学习——Lab1: Xv6 and Unix utilities

文章目录 前言一、启动xv6(难度:Easy)解析: 二、sleep(难度:Easy)解析: 三、pingpong(难度:Easy)解析: 四、Primes(素数,难度:Moderate/Hard)解析&#xff1a…

pymysql连不上mysql的原因

我试了两种解决办法。可以参考一下 第一种:查看有没有打开mysql服务 第二种:刷新 MySQL 用户权限 password改成自己的密码 GRANT ALL PRIVILEGES ON *.* TO root% IDENTIFIED BY password WITH GRANT OPTION;FLUSH PRIVILEGES; 第三种:检…

CSS3的一些常用语句以及解释

margin和padding position static 该关键字指定元素使用正常的布局行为,即元素在文档常规流中当前的布局位置。此时 top, right, bottom, left 和 z-index 属性无效。 relative 该关键字下,元素先放置在未添加定位时的位置,再在不改变页面…

C# 入门

教程: .NET | 构建。测试。部署。 (microsoft.com) C# 文档 - 入门、教程、参考。 | Microsoft Learn C# 数据类型 | 菜鸟教程 (runoob.com) IDE: Visual Studio: 面向软件开发人员和 Teams 的 IDE 和代码编辑器 (microsoft.com) Rider&#xff1a…

Net8 ABP VNext集成FreeSql、SqlSugar

ABP可以快速搭建开发架构,但是内置的是EFCore,国内中小企业使用FreeSql与SqlSugar还是较多,为新手提供使用提供参考 ABP、FreeSql、SqlSugar参考地址: ABP Framework | Open source web application framework for ASP.NET Core…

Buildroot 之二 详解构建架构、流程、external tree、示例

构建系统 Buildroot 中的构建系统使用的是从 Linux Kernel(4.17-rc2) 中移植的 Kconfig(配置) + Makefile & Kbuild(编译)这套构建系统,移植后的源码位于 support/kconfig/ 目录下。Buildroot 本身是一个构建系统,与直接编译源码不同,因此,它对这套系统进行了比较…

【恒源智享云】conda虚拟环境的操作指令

conda虚拟环境的操作指令 由于虚拟环境经常会用到,但是我总忘记,所以写个博客,留作自用。 在恒源智享云上,可以直接在终端界面输入指令,例如: 查看已经存在的虚拟环境列表 conda env list查看当前虚拟…

SpringDataRedis笔记

spring:application:name: springdataredisredis:host: 120.0.0.1port: 6379password: 123456lettuce:pool:#最大连接数 默认就是8max-active: 8#最大空闲连接 默认就是8max-idle: 8#最小空闲连接 默认是0min-idle: 0#连接等待时间 默认-1无限等待max-wait: 100RedisTemplate默…