经典卷积神经网络-AlexNet

news2024/12/27 11:40:53

AlexNet

学习目标

  • 知道AlexNet网络结构
  • 能够利用AlexNet完成图像分类

 

2012年,AlexNet横空出世,该模型的名字源于论文第一作者的姓名Alex Krizhevsky 。AlexNet使用了8层卷积神经网络,以很大的优势赢得了ImageNet 2012图像识别挑战赛。它首次证明了学习到的特征可以超越手工设计的特征,从而一举打破计算机视觉研究的方向。

1.AlexNet的网络架构

AlexNet与LeNet的设计理念非常相似,但也有显著的区别,其网络架构如下图所示:

 

该网络的特点是:

  • AlexNet包含8层变换,有5层卷积和2层全连接隐藏层,以及1个全连接输出层

  • AlexNet第一层中的卷积核形状是11×1111×11。第二层中的卷积核形状减小到5×55×5,之后全采用3×33×3。所有的池化层窗口大小为3×33×3、步幅为2的最大池化。

  • AlexNet将sigmoid激活函数改成了ReLU激活函数,使计算更简单,网络更容易训练

  • AlexNet通过dropOut来控制全连接层的模型复杂度。

  • AlexNet引入了大量的图像增强,如翻转、裁剪和颜色变化,从而进一步扩大数据集来缓解过拟合。

在tf.keras中实现AlexNet模型:

# 构建AlexNet模型
net = tf.keras.models.Sequential([
    # 卷积层:96个卷积核,卷积核为11*11,步幅为4,激活函数relu
    tf.keras.layers.Conv2D(filters=96,kernel_size=11,strides=4,activation='relu'),
    # 池化:窗口大小为3*3、步幅为2
    tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
    # 卷积层:256个卷积核,卷积核为5*5,步幅为1,padding为same,激活函数relu
    tf.keras.layers.Conv2D(filters=256,kernel_size=5,padding='same',activation='relu'),
    # 池化:窗口大小为3*3、步幅为2
    tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
    # 卷积层:384个卷积核,卷积核为3*3,步幅为1,padding为same,激活函数relu
    tf.keras.layers.Conv2D(filters=384,kernel_size=3,padding='same',activation='relu'),
    # 卷积层:384个卷积核,卷积核为3*3,步幅为1,padding为same,激活函数relu
    tf.keras.layers.Conv2D(filters=384,kernel_size=3,padding='same',activation='relu'),
    # 卷积层:256个卷积核,卷积核为3*3,步幅为1,padding为same,激活函数relu
    tf.keras.layers.Conv2D(filters=256,kernel_size=3,padding='same',activation='relu'),
    # 池化:窗口大小为3*3、步幅为2
    tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
    # 伸展为1维向量
    tf.keras.layers.Flatten(),
    # 全连接层:4096个神经元,激活函数relu
    tf.keras.layers.Dense(4096,activation='relu'),
    # 随机失活
    tf.keras.layers.Dropout(0.5),
    # 全链接层:4096个神经元,激活函数relu
    tf.keras.layers.Dense(4096,activation='relu'),
    # 随机失活
    tf.keras.layers.Dropout(0.5),
    # 输出层:10个神经元,激活函数softmax
    tf.keras.layers.Dense(10,activation='softmax')
])

我们构造一个高和宽均为227的单通道数据样本来看一下模型的架构:

# 构造输入X,并将其送入到net网络中
X = tf.random.uniform((1,227,227,1)
y = net(X)
# 通过net.summay()查看网络的形状
net.summay()

网络架构如下:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (1, 55, 55, 96)           11712     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (1, 27, 27, 96)           0         
_________________________________________________________________
conv2d_1 (Conv2D)            (1, 27, 27, 256)          614656    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (1, 13, 13, 256)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (1, 13, 13, 384)          885120    
_________________________________________________________________
conv2d_3 (Conv2D)            (1, 13, 13, 384)          1327488   
_________________________________________________________________
conv2d_4 (Conv2D)            (1, 13, 13, 256)          884992    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (1, 6, 6, 256)            0         
_________________________________________________________________
flatten (Flatten)            (1, 9216)                 0         
_________________________________________________________________
dense (Dense)                (1, 4096)                 37752832  
_________________________________________________________________
dropout (Dropout)            (1, 4096)                 0         
_________________________________________________________________
dense_1 (Dense)              (1, 4096)                 16781312  
_________________________________________________________________
dropout_1 (Dropout)          (1, 4096)                 0         
_________________________________________________________________
dense_2 (Dense)              (1, 10)                   40970     
=================================================================
Total params: 58,299,082
Trainable params: 58,299,082
Non-trainable params: 0
_________________________________________________________________

2.手写数字势识别

AlexNet使用ImageNet数据集进行训练,但因为ImageNet数据集较大训练时间较长,我们仍用前面的MNIST数据集来演示AlexNet。读取数据的时将图像高和宽扩大到AlexNet使用的图像高和宽227。这个通过tf.image.resize_with_pad来实现。

2.1 数据读取

首先获取数据,并进行维度调整:

import numpy as np
# 获取手写数字数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 训练集数据维度的调整:N H W C
train_images = np.reshape(train_images,(train_images.shape[0],train_images.shape[1],train_images.shape[2],1))
# 测试集数据维度的调整:N H W C
test_images = np.reshape(test_images,(test_images.shape[0],test_images.shape[1],test_images.shape[2],1))

由于使用全部数据训练时间较长,我们定义两个方法获取部分数据,并将图像调整为227*227大小,进行模型训练:

# 定义两个方法随机抽取部分样本演示
# 获取训练集数据
def get_train(size):
    # 随机生成要抽样的样本的索引
    index = np.random.randint(0, np.shape(train_images)[0], size)
    # 将这些数据resize成227*227大小
    resized_images = tf.image.resize_with_pad(train_images[index],227,227,)
    # 返回抽取的
    return resized_images.numpy(), train_labels[index]
# 获取测试集数据 
def get_test(size):
    # 随机生成要抽样的样本的索引
    index = np.random.randint(0, np.shape(test_images)[0], size)
    # 将这些数据resize成227*227大小
    resized_images = tf.image.resize_with_pad(test_images[index],227,227,)
    # 返回抽样的测试样本
    return resized_images.numpy(), test_labels[index]

调用上述两个方法,获取参与模型训练和测试的数据集:

# 获取训练样本和测试样本
train_images,train_labels = get_train(256)
test_images,test_labels = get_test(128)

为了让大家更好的理解,我们将数据展示出来:

# 数据展示:将数据集的前九个数据集进行展示
for i in range(9):
    plt.subplot(3,3,i+1)
    # 以灰度图显示,不进行插值
    plt.imshow(train_images[i].astype(np.int8).squeeze(), cmap='gray', interpolation='none')
    # 设置图片的标题:对应的类别
    plt.title("数字{}".format(train_labels[i]))

结果为:

 

我们就使用上述创建的模型进行训练和评估。

2.2 模型编译

# 指定优化器,损失函数和评价指标
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0, nesterov=False)

net.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

2.3 模型训练

# 模型训练:指定训练数据,batchsize,epoch,验证集
net.fit(train_images,train_labels,batch_size=128,epochs=3,verbose=1,validation_split=0.1)

训练输出为:

Epoch 1/3
2/2 [==============================] - 3s 2s/step - loss: 2.3003 - accuracy: 0.0913 - val_loss: 2.3026 - val_accuracy: 0.0000e+00
Epoch 2/3
2/2 [==============================] - 3s 2s/step - loss: 2.3069 - accuracy: 0.0957 - val_loss: 2.3026 - val_accuracy: 0.0000e+00
Epoch 3/3
2/2 [==============================] - 4s 2s/step - loss: 2.3117 - accuracy: 0.0826 - val_loss: 2.3026 - val_accuracy: 0.0000e+00

2.4 模型评估

# 指定测试数据
net.evaluate(test_images,test_labels,verbose=1)

输出为:

4/4 [==============================] - 1s 168ms/step - loss: 2.3026 - accuracy: 0.0781
[2.3025851249694824, 0.078125]

如果我们使用整个数据集训练网络,并进行评估的结果:

[0.4866700246334076, 0.8395]

总结

  • 知道AlexNet的网络架构
  • 动手实现手写数字的识别

人工智能学习路线,记得收藏哦


零基础入门:

Python小白基础入门教程 Python入门到精通教程
零基础必备:全套Python教程_Python基础入门视频教程,零基础小白自学Python入门教程

python基础进阶:Python深入浅出进阶教程【敢信?】收藏=点赞十倍
Python实战Djongo项目:python企业级开发项目-手把手从0到1开发《美多商城》
mysql数据库:MySQL全套教程,MySQL从基础到黑马订单案例实战
机器学习算法:3天快速入门python机器学习
聚类算法:360°解读机器学习经典算法——聚类算法
数据挖掘:Python教程,4天快速入门Python数据挖掘,系统精讲+实战案例
Web服务器:Python高级语法进阶教程_python多任务及网络编程,从零搭建网站全套教程
180分钟爬虫入门:180分钟轻松获取疫情数据,Python爬虫入门课
Scrapy框架:Python爬虫基础,快速入门Scrapy爬虫框架
多线程:python多线程编程

人工智能入门:智能机器人软件开发教程基础,从helloworld到神经网络
人工智能深度学习:智能机器人软件开发教程基础,从helloworld到神经网络
图像与视觉处理:人工智能教程|零基础学习计算机视觉快速入门


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/144316.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为何瑞达利欧的《原则一》这么难读懂?

开始搞不懂,为何一个桥水基金创始人,一位投资人,却写了一本这样的书,书中的内容初看时觉得与他从事的投资事业几乎毫无关系? 《原则》其副标题为《生活和工作的原则》 乍看,此书黑色的封皮,让我…

【自学C++】C++ std命名空间

C std命名空间 C std命名空间教程 在 C 中 std 命名空间 是 C 中标准库类型对象的命名空间。我们常用的输入和输出 函数 都是定义在 std 命名空间中的,因此,我们需要使用输入和输出,必须要引入 std 命名空间。 要引用一个命名空间中的内容…

电脑自动删除文件怎么恢复?分享4种方法

电脑出现文件丢失的情况常有发生,但是出现电脑自动删除文件的情况是怎么回事呢?电脑自动删除的文件怎么恢复呢?本文将详细阐述电脑自动删除文件原因和文件恢复方法。一、电脑自动删除文件是什么原因1.可能不是删除而是电脑开机用户名更改后导…

Java真的不难(五十三)Docker的快速入门及使用

Docker的入门及使用 这篇文章将不全面介绍理论,Docker对于我们后端开发来说会用就行,能使用Docker去安装一些镜像运行,为简化配置节省时间和错误率,所以这篇文章实用性很高,可以直接上手! 一、什么是Docke…

生产制造业ERP管理系统财务管理解决方案

对于生产制造型企业来说,良好的资金运营管理机制是企业长期、稳定、健康发展的保证。因此,企业急需借助生产制造业ERP管理系统,不断加强企业财务管理,从而有效提升企业的经营效率,降低财务风险,缓解资金成本…

云渲染答疑:动画渲染价格一般多少?

云渲染是什么?云渲染就是通过互联网将用户本地需要渲染的文件上传到云端服务器中,再通过云端庞大的计算机集群资源进行运算操作,帮助用户在云端完成渲染工作后,用户再下载到本地的过程,整个过程操作十分简便。云渲染动…

【云原生进阶之容器】第二章Controller Manager原理2.5节--DeltaFIFO剖析

5 DeltaFIFO DeltaFIFO是K8s中用来存储处理数据的Queue,相较于传统的FIFO,它不仅仅存储了数据保证了先进先出,而且存储有K8s 资源对象的类型,它的作用是保证Reflector和Indexer之间对象同步。其是连接Reflector(生产者)和indexer(消费者)的重要通道。其核心处理流程如下: …

android 换肤框架搭建及使用 (3 完结篇)

本系列计划3篇: Android 换肤之资源(Resources)加载(一)setContentView() / LayoutInflater源码分析(二)换肤框架搭建(三) — 本篇 tips: 本篇只说实现思路,以及使用,具体细节请下载代码查看! 本篇实现效果: fragment换肤recyclerView换肤自定义view属性换肤打开打开打开动…

解决第三方图片403问题

第三方平台怎么处理图片资源保护的? 服务端一般使用 Referer 请求头识别访问来源,然后处理资源访问。 Referer 是什么东西? 扩展参考: http://www.ruanyifeng.com/blog/2019/06/http-referer.html Referer是 HTTP 请求头的一部分,当浏览器向 Web 服务…

HTML实现舔狗日记

演示 css html, body {background: radial-gradient(#181818, #000000);margin: 0;padding: 0;border: 0;-ms-overflow-style: none;}::-webkit-scrollbar {width: 0.5em;height: 0.5em;background-color: #c7c7c7;}/*定义滚动条轨道 内阴影圆角*/::-webkit-scrollbar-track {…

不会写代码?也不懂技术?3分钟搭建电商cps系统搞副业

大家好,我是小悟 唠唠家常 以前见面聊天,大家都习惯性会问“你吃饭了吗”,现在大家一出口就是“你阳了吗”。2023年元旦过去了,你还阳着么?不出意外的话就会出意外,小悟也已经中招过了,在家躺…

【Linux】tcpdump命令详解

1、列出本机所有的网卡接口 tcpdump -D2、捕获特定网口的数据包 tcpdump -i bond0.1083、捕获具体数量的数据包 tcpdump -c 5 -i eth04、捕获的数据包保存到指定的文件 tcpdump -w 0001.pcap -i eth05、捕获的数据包显示IP而不

E4402B频谱分析仪

18320918653 E4402B E4402B|Agilent|3G|频谱分析仪|安捷伦|9kHz至3GHz 品牌:安捷伦 Agilent 惠普 HP 测量速度:28次更新/秒 测量精度:1dB 可选用的10Hz分辨事宽滤波器 机箱可容纳6插槽选件卡 97dB三阶动态范围 能在现场使用的坚固&a…

(1分钟速览)SLAM问题中一般方程和超定方程的求解

今天在学习的过程中偶然看到了一个博客,总结Axb的,那么我也写一篇。首先就是判断A的秩和(A|b)的秩之间的关系,然后通过这个关系来进行进一步地判断。编辑切换为居中添加图片注释,不超过 140 字(可选)求解方…

RabbitMQ通配符模式

🍁博客主页:👉不会压弯的小飞侠 ✨欢迎关注:👉点赞👍收藏⭐留言✒ ✨系列专栏:👉Linux专栏 🔥欢迎大佬指正,一起学习!一起加油! 目录&…

Jenkins安装方式之war包及相关环境配置

持续创作,加速成长!这是我参与「掘金日新计划 10 月更文挑战」的第4天,点击查看活动详情 最近总有小伙伴发私信问我jenkins如何以war形式运行?以及运行后如何添加相关的环境配置,这里我就给大家贴出我的解决方案&…

Bandit算法学习[网站优化]04——UCB(Upper Confidence Bound) 算法

Bandit算法学习[网站优化]04——UCB(Upper Confidence Bound) 算法 参考资料 White J. Bandit algorithms for website optimization[M]. " O’Reilly Media, Inc.", 2013.https://github.com/johnmyleswhite/BanditsBookUCB算法原理及其在星际争霸比赛中的应用Aue…

Springboot 接口为null的值不返回对应的key

偶然听到两个应届生一段对话,一个后端,一个前端 。 前端: 大哥,你没有值就不要返回那个key行不行? 后端: 什么我看看。 后端: 这是本来返回值实体有的,不是必填,所以n…

Lua 元表及常见元方法

一、什么是元表 Lua 中的 table 使用起来有点像c中的 map 或者 unordered_map ,都是通过对应的key 获取对应的value。如果访问了表中不存在的key时,就会触发Lua的一种机制,Lua也正是凭借这个机制可以用来模拟类似“继承”的行为,…

低代码能够为企业带来什么?

目录 1、为企业快速开发应用赋能 2、低成本使用数字化工具 3、满足企业定制化需求 大数据时代的快速发展下,传统的应用开发技术手段渐渐地无法满足企业的高需求。并且,企业想在应用开发的基础上同时实现个性化定制,而传统的技术条件所需要…