【TensorFlow2 之014】在 TF 2.0 中实现 LeNet-5

news2024/12/26 0:07:06

一、说明

         在这篇文章中,我们将展示如何在 TensorFlow 中实现像 \(LeNet-5\) 这样的基础卷积神经网络。LeNet-5 架构由 Yann LeCun 于 1998 年发明,是第一个卷积神经网络。

 数据黑客变种rs    深度学习 机器学习 TensorFlow    2020 年 2 月 29 日  |  0

1.1 教程概述:

  1. 理论重述
  2. 在 TensorFlow 中的实现

1. 理论重述

\(LeNet-5 \) 的目标是识别手写数字。因此,它作为输入 \(32\times32\times1 \) 图像。它是灰度图像,因此通道数为 \(1 \)。下面我们可以看到该网络的架构。

LeNet5架构

LeNet-5架构

        I. 最初的 MNIST 现在被认为太简单了。在过去的二十年里,许多研究人员针对原始 MNIST 提出了成功的解决方案。您可以在此处查看直接比较结果。

        由于该网络主要是为 MNIST 数据集设计的,因此它的性能明显更好。通过微小的改变,它就可以在 Fashion MNIST 数据集上达到这种准确性。然而,在这篇文章中,我们将坚持网络的原始架构。

关于 LeNet-5 架构,您可以在此处阅读详细的理论文章。

        让我们总结一下 LeNet-5 架构的各层。

图层类型特征图尺寸内核大小跨步激活
图像132×32
卷积628×285×51正值
平均池化614×142×22
卷积1610×105×51正值
平均池化165×52×22
完全连接120正值
完全连接84正值
完全连接10软最大

二、TensorFlow中的实现

        交互式 Colab 笔记本可在以下链接找到  
        在 Google Colab 中运行

为了练习,您可以尝试通过将Fashion_mnist替换为mnist、cifar10或其他来更改数据集。


        让我们从导入所有必需的库开始。导入后,我们可以使用导入的模块来加载数据。load_data  () 函数将自动下载数据并将其拆分为训练集和测试集。

import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.layers import Dense, Flatten, Conv2D, AveragePooling2D

from tensorflow.keras import datasets
from tensorflow.keras.utils import to_categorical

#from __future__ import absolute_import, division, print_function, unicode_literals
# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 2us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 7s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 1s 0us/step

        我们可以检查新数据的形状,发现我们的图像是 28×28 像素,因此我们需要添加一个新轴,它将代表多个通道。此外,对标签进行 one-hot 编码和对输入图像进行归一化也很重要。

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_train[0].shape, 'image shape')

x_train shape: (60000, 28, 28)
60000 train samples
10000 test samples
(28, 28) image shape
 

# Add a new axis
x_train = x_train[:, :, :, np.newaxis]
x_test = x_test[:, :, :, np.newaxis]

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_train[0].shape, 'image shape')
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
(28, 28, 1) image shape
# Convert class vectors to binary class matrices.

num_classes = 10
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)
# Data normalization
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255


        现在,是时候开始使用 TensorFlow 2.0 来构建我们的卷积神经网络了。最简单的方法是使用 Sequential API。我们将其包装在一个名为LeNet的类中。输入是图像,输出是类概率向量。

在 tf.keras 中,顺序模型表示层的线性堆栈,在本例中它遵循网络架构。

# LeNet-5 model
class LeNet(Sequential):
    def __init__(self, input_shape, nb_classes):
        super().__init__()

        self.add(Conv2D(6, kernel_size=(5, 5), strides=(1, 1), activation='tanh', input_shape=input_shape, padding="same"))
        self.add(AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
        self.add(Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))
        self.add(AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
        self.add(Flatten())
        self.add(Dense(120, activation='tanh'))
        self.add(Dense(84, activation='tanh'))
        self.add(Dense(nb_classes, activation='softmax'))

        self.compile(optimizer='adam',
                    loss=categorical_crossentropy,
                    metrics=['accuracy'])
model = LeNet(x_train[0].shape, num_classes)
model.summary()
Model: "le_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 28, 28, 6)         156       
_________________________________________________________________
average_pooling2d (AveragePo (None, 14, 14, 6)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 10, 10, 16)        2416      
_________________________________________________________________
average_pooling2d_1 (Average (None, 5, 5, 16)          0         
_________________________________________________________________
flatten (Flatten)            (None, 400)               0         
_________________________________________________________________
dense (Dense)                (None, 120)               48120     
_________________________________________________________________
dense_1 (Dense)              (None, 84)                10164     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                850       
=================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0

        创建模型后,我们需要训练它的参数以使其强大。让我们训练给定数量的 epoch 模型。我们可以在TensorBoard中看到训练的进度。

# Place the logs in a timestamped subdirectory
# This allows to easy select different training runs
# In order not to overwrite some data, it is useful to have a name with a timestamp
log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Specify the callback object
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# tf.keras.callback.TensorBoard ensures that logs are created and stored
# We need to pass callback object to the fit method
# The way to do this is by passing the list of callback objects, which is in our case just one
model.fit(x_train, y=y_train, 
          epochs=20, 
          validation_data=(x_test, y_test), 
          callbacks=[tensorboard_callback],
          verbose=0)

%tensorboard --logdir logs/fit

        仅 20 个 epoch 就已经不错了。

        之后,我们可以做出一些预测并将其可视化。使用predict_classes()函数,我们可以从网络中获取准确的类别而不是概率值。它与使用numpy.argmax()相同。我们将用红色显示错误的预测,用蓝色表示正确的预测。

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

prediction_values = model.predict_classes(x_test)

# set up the figure
fig = plt.figure(figsize=(15, 7))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)

# plot the images: each image is 28x28 pixels
for i in range(50):
    ax = fig.add_subplot(5, 10, i + 1, xticks=[], yticks=[])
    ax.imshow(x_test[i,:].reshape((28,28)),cmap=plt.cm.gray_r, interpolation='nearest')
  
    if prediction_values[i] == np.argmax(y_test[i]):
        # label the image with the blue text
        ax.text(0, 7, class_names[prediction_values[i]], color='blue')
    else:
        # label the image with the red text
        ax.text(0, 7, class_names[prediction_values[i]], color='red')
时尚迷斯特

三、总结

        所以,在这里我们学习了如何在 Tensorflow 2.0 中开发和训练 LeNet-5。在下 一篇文章中 ,我们将继续实现流行的卷积神经网络,并学习如何在 TensorFlow 2.0 中实现AlexNet 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1084475.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript(上)

1.JavaScript概述 JavaScript 是一种客户端脚本语言。运行在客户端浏览器中,每一个浏览器都具备解析 JavaScript 的引擎 脚本语言:不需要编译,就可以被浏览器直接解析执行了 核心功能就是增强用户和 HTML 页面的交互过程,让页面…

400电话怎么办理

400电话,又称为全国统一客服热线电话,是一种企业或机构为了提供更便捷的客户服务而开通的电话号码。通过拨打400电话,客户可以免费或按照本地市话费率与企业或机构进行沟通,解决问题或获得相关服务。下面将介绍400电话的办理流程和…

企业集中式日志管理解决方案

集中式日志记录解决方案收集日志并统一来自各种网络设备(如服务器、防火墙、路由器、工作站)、应用程序(如IIS、Apache、DHCP)、入侵检测系统等的数据。该解决方案在中央控制台中显示日志,使其易于访问,日志…

【若依】定时任务问题:关闭了定时任务,但是依然在跑,且同一时刻跑了多条记录,为什么?

文章目录 问题1描述:原因:办法: 问题2描述:原因:办法: 问题1 描述: 定时任务关闭了, 但是服务器定时任务依然在跑 原因: 若依自带定时任务有缓存,且缓存是服务器内存&#xff0c…

布隆过滤器的优缺点及哈希切割问题

文章目录 1.布隆过滤器优点2.布隆过滤器缺陷3.哈希切割 1.布隆过滤器优点 增加和查询元素的时间复杂度为:O(K)(K为哈希函数的个数,一般较小),与数据量大小无关哈希函数相互之间没有关系,方便硬件并行运算布隆过滤器不需要存储元素本身&#…

Stable Diffusion XL搭建

本文参考:Stable Diffusion XL1.0正式发布了,赶紧来尝鲜吧-云海天教程 Stable Diffision最新模型SDXL 1.0使用全教程 - 知乎 1、SDXL与SD的区别 (1)分辨率得到了提升 原先使用SD生成图片,一般都是生成512*512&…

软件测试工程师简历项目经验怎么写?--1000个已成功入职的软件测试工程师简历范文模板(含真实简历)

一、前言:浅谈面试 ​ 面试是我们进入一个公司的门槛,通过了面试才能进入公司,你的面试结果和你的薪资是息息相关的。那如何才能顺利的通过面试,得到公司的认可呢?面试软件测试要注意哪些问题呢?下面和笔者一起来看看吧。这里…

【牛客面试必刷TOP101】Day11.BM63 跳台阶和 BM67 不同路径的数目(一)

作者简介:大家好,我是未央; 博客首页:未央.303 系列专栏:牛客面试必刷TOP101 每日一句:人的一生,可以有所作为的时机只有一次,那就是现在!!!&…

SpringSecurity + jwt + vue2 实现权限管理 , 前端Cookie.set() 设置jwt token无效问题(已解决)

问题描述 今天也是日常写程序的一天 , 还是那个熟悉的IDEA , 还是那个熟悉的Chrome浏览器 , 还是那个熟悉的网站 , 当我准备登录系统进行登录的时候 , 发现会直接重定向到登录页 , 后端也没有报错 , 前端也没有报错 , 于是我得脸上又多了一张痛苦面具 , 紧接着在前端疯狂debug…

WPF中prism模块化

1、参照(wpf中prism框架切换页面-CSDN博客)文中配置MainView和MainViewModel 2、模块其实就是引用类库,新建两个类库ModuleA ModuleB,修改输出类型为类库,并配置以下文件: ModuleA ModuleAProfile ModuleB Module…

用位运算实现加减乘除法

我们知道计算机只认识0和1,而计算机在计算加减乘除的是也不是我们理解的直接预算,而是通过逻辑运算来实现的,也就是与、非、或、异或,下面就通过这些逻辑运算符来实现加减乘除法 加法:比如11用二进制表示就是00000001…

什么是可持续发展的葡萄酒?

在过去的几年里,消费者越来越意识到他们的日常生活选择对我们的星球和周围环境的潜在影响。我们可以看到使用更少塑料、浪费更少水、食物里程更短的产品越来越受欢迎。这些需求如何转化为葡萄酒世界?这种产品通常要走1000英里才能到达你的杯子。 来自云…

写进简历的软件测试项目实战经验(包含电商、银行、app等)

前言: 今天给大家带来几个软件测试项目的实战总结及经验,适合想自学、转行或者面试的朋友,可以写进简历里的那种哦。 1、项目名称: 家电购 项目描述: “家电购”商城系统是基于 web 浏览器的电子商务系统,通过互联…

3、在docker 容器中安装tomcat

1、在服务器上查找tomcat镜像,查看前5条 docker search tomcat --limit 5​​​​​​​ 2、拉取镜像到本地 拉取官方的tomcat到本地 docker pull tomcat:9.0.34-jdk8 3、查看本地镜像 docker images |grep tomcat 4、启动tomcat 服务 使用默认配置 docker ru…

你不知道的测试小技巧——postman接口测试导入导出操作详解

postman中的集合脚本,环境变量、全局变量全部都可以导出,然后分享给团队成员,导出后的脚本可以通过newman生成测试报告。另外还可以将浏览器,抓包工具,接口文档(swagger)中的数据包导入到postman中,并且会自…

知识付费H5页面+后端+全功能制作源码系统

罗峰今天给大家要分享的是知识付费H5页面制作的源码系统,H5也是一种响应式界面,能自动兼容所有的打开设备屏幕,使得页面在不同尺寸的手机、平板等设备上打开时,界面也会自动兼容适应。这也是大部分用户选择H5页面的原因&#xff0…

centos7下 编译coreboot生成真机可用的bios固件, 并在真机上演示 (下篇)

本文章应该是全网最详细的了, 真机版的coreboot bios固件演示了, 希望对你有帮助 centos7下 编译coreboot生成真机可用的bios固件, 并在真机上演示 (上篇)-CSDN博客 文章为上下两大篇 上篇: 文章主要是 一些东西和资料以及步骤 并 编译出可以用于真机的 bios固件 coreboot.r…

漏电断路器

漏电断路器又名漏保。 一、漏电断路器的作用 1、具有空气开关的功能,三相空气开关对任意一相出现过载或短路,均会跳闸。 2、漏电时,L1和L3进线端子之间有220V的电压差,分励脱钩器就可以工作,引起跳闸。 注意&#…

idea 打包 java 项目 报错类似 No valid Maven installation found - 在maven打包前,修改打包名(jar包)

目录 一、idea 打包 java 项目 报错类似 No valid Maven installation found二、在maven打包前,修改打包名参考链接 一、idea 打包 java 项目 报错类似 No valid Maven installation found 解决措施:一定要能看到maven的版本才行,配置到盖层…

剑指offer(C++)-JZ66:构建乘积数组(算法-其他)

作者:翟天保Steven 版权声明:著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处 题目描述: 给定一个数组 A[0,1,...,n-1] ,请构建一个数组 B[0,1,...,n-1] ,其中 B 的元素 B[i]A[0]*A[1]*...*A[i-1]…