【TensorFlow2 之015】 在 TF 2.0 中实现 AlexNet

news2025/3/2 2:33:09

一、说明

       在这篇文章中,我们将展示如何在 TensorFlow 2.0 中实现基本的卷积神经网络 \(AlexNet\)。AlexNet 架构由 Alex Krizhevsky 设计,并与 Ilya Sutskever 和 Geoffrey Hinton 一起发布。并获得Image Net2012竞赛中冠军。

教程概述:

  1. 理论回顾
  2. 在 TensorFlow 2.0 中的实现

二 理论回顾

        现实生活中的计算机视觉问题需要大量高质量数据进行训练。过去,人们使用 CIFAR 和 NORB 数据集作为计算机视觉问题的基准数据集。然而,ImageNet竞赛改变了这一点。该数据集需要比以前更复杂的网络才能获得良好的结果。

        AlexNet 是 2012 年取得最佳结果的一种网络架构。它的 Top-5 错误率为 15.3%。第二好的成绩远远落后(26.2%)。

        该架构有大约 6000 万个参数,由以下层组成。

图层类型特征图尺寸内核大小跨步激活
图像1227×227
卷积9655×5511×114ReLU
最大池化9627×273×32
卷积25627×275×51ReLU
最大池化25613×133×32
卷积第384章13×133×31ReLU
卷积第384章13×133×31ReLU
卷积25613×133×31ReLU
最大池化2566×63×32
完全连接4096ReLU
完全连接4096ReLU
完全连接1000软最大

        在我们的例子中,我们将仅在 ImageNet 数据集中的两个类上训练模型,因此我们的最后一个全连接层将只有两个具有 Softmax 激活函数的神经元。

        有一些变化使得 AlexNet 与当时的其他网络不同。让我们看看是什么改变了历史!

2.1  重叠的池化层

        标准池化层汇总同一内核图中相邻神经元组的输出。传统上,相邻池单元总结的邻域不重叠。重叠池化层与标准池化层类似,只是计算 Max 的相邻窗口彼此重叠。

重叠池化与非重叠池化

2.2 ReLU 非线性

        评估神经元输出的传统方法是使用 sigmoid 或 tanh 激活函数。这两个函数固定在最小值和最大值之间,因此它们是饱和非线性的。然而,在 AlexNet 中,使用了修正线性单位函数,或者简称为 \(ReLU\)。该函数的阈值为\(0\)。这是一个非饱和激活函数。

        \(ReLU\) 函数需要更少的计算并允许更快的学习,这对在大型数据集上训练的大型模型的性能有很大影响。

2.3  局部响应标准化

        局部响应归一化 (LRN) 首次在 AlexNet 架构中引入,其中选择的激活函数是 \(ReLU\)。使用 LRN 的原因是为了鼓励 侧向抑制。 这是指神经元减少其邻居活动的能力。当我们使用 ReLU 激活函数处理神经元时,这非常有用。具有 \(ReLU\) 激活函数的神经元具有无界激活,我们需要 LRN 对其进行标准化。

三. TensorFlow 2.0中的实现

        交互式 Colab 笔记本可在以下链接找到

        让我们从导入所有必需的库开始

# Load the TensorBoard notebook extension
%load_ext tensorboard
import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout

        导入后,我们需要准备数据。在这里,我们将仅使用 ImageNet 数据集的一小部分。使用以下代码,您可以下载所有图像并将它们存储在文件夹中。

import cv2
import urllib
import requests
import PIL.Image
import numpy as np
from bs4 import BeautifulSoup

#ship synset
page = requests.get("http://www.image-net.org/api/text/imagenet.synset.geturls?wnid=n04194289")
soup = BeautifulSoup(page.content, 'html.parser')
#bicycle synset
bikes_page = requests.get("http://www.image-net.org/api/text/imagenet.synset.geturls?wnid=n02834778")
bikes_soup = BeautifulSoup(bikes_page.content, 'html.parser')

str_soup=str(soup)
split_urls=str_soup.split('\r\n')

bikes_str_soup=str(bikes_soup)
bikes_split_urls=bikes_str_soup.split('\r\n')

!mkdir /content/train
!mkdir /content/train/ships
!mkdir /content/train/bikes
!mkdir /content/validation
!mkdir /content/validation/ships
!mkdir /content/validation/bikes

img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)

def url_to_image(url):
    resp = urllib.request.urlopen(url)
    image = np.asarray(bytearray(resp.read()), dtype="uint8")
    image = cv2.imdecode(image, cv2.IMREAD_COLOR)
    return image

n_of_training_images=100
for progress in range(n_of_training_images):
    if not split_urls[progress] == None:
        try:
            I = url_to_image(split_urls[progress])
            if (len(I.shape))==3:
                save_path = '/content/train/ships/img'+str(progress)+'.jpg'
                cv2.imwrite(save_path,I)
        except:
            None

for progress in range(n_of_training_images):
    if not bikes_split_urls[progress] == None:
        try:
            I = url_to_image(bikes_split_urls[progress])
            if (len(I.shape))==3:
                save_path = '/content/train/bikes/img'+str(progress)+'.jpg'
                cv2.imwrite(save_path,I)
        except:
            None


for progress in range(50):
    if not split_urls[progress] == None:
        try:
            I = url_to_image(split_urls[n_of_training_images+progress])
            if (len(I.shape))==3:
                save_path = '/content/validation/ships/img'+str(progress)+'.jpg'
                cv2.imwrite(save_path,I)
        except:
            None


for progress in range(50):
    if not bikes_split_urls[progress] == None:
        try:
            I = url_to_image(bikes_split_urls[n_of_training_images+progress])
            if (len(I.shape))==3:
                save_path = '/content/validation/bikes/img'+str(progress)+'.jpg'
                cv2.imwrite(save_path,I)
        except:
            None

        现在我们可以创建一个网络。原始 AlexNet 的最后一层有 1000 个神经元,但这里我们只使用一个。这是因为我们只将图像用于两个类。为了构建我们的卷积神经网络,我们将使用 Sequential API。

num_classes = 2

# AlexNet model
class AlexNet(Sequential):
    def __init__(self, input_shape, num_classes):
        super().__init__()

        self.add(Conv2D(96, kernel_size=(11,11), strides= 4,
                        padding= 'valid', activation= 'relu',
                        input_shape= input_shape,
                        kernel_initializer= 'he_normal'))
        self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),
                              padding= 'valid', data_format= None))

        self.add(Conv2D(256, kernel_size=(5,5), strides= 1,
                        padding= 'same', activation= 'relu',
                        kernel_initializer= 'he_normal'))
        self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),
                              padding= 'valid', data_format= None)) 

        self.add(Conv2D(384, kernel_size=(3,3), strides= 1,
                        padding= 'same', activation= 'relu',
                        kernel_initializer= 'he_normal'))

        self.add(Conv2D(384, kernel_size=(3,3), strides= 1,
                        padding= 'same', activation= 'relu',
                        kernel_initializer= 'he_normal'))

        self.add(Conv2D(256, kernel_size=(3,3), strides= 1,
                        padding= 'same', activation= 'relu',
                        kernel_initializer= 'he_normal'))

        self.add(MaxPooling2D(pool_size=(3,3), strides= (2,2),
                              padding= 'valid', data_format= None))

        self.add(Flatten())
        self.add(Dense(4096, activation= 'relu'))
        self.add(Dense(4096, activation= 'relu'))
        self.add(Dense(1000, activation= 'relu'))
        self.add(Dense(num_classes, activation= 'softmax'))

        self.compile(optimizer= tf.keras.optimizers.Adam(0.001),
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])

model = AlexNet((227, 227, 3), num_classes)

        创建模型后,我们定义一些重要的参数以供以后使用。此外,让我们创建图像数据生成器。\(AlexNet\)的参数非常多,有6000万个,这是一个巨大的数字。如果没有足够的数据,这将很可能导致过度拟合。因此,在这里,我们将利用数据增强技术,您可以在此处找到更多相关信息。

        出于同样的原因,AlexNet 中使用了 dropout 层。该技术包括以预定概率“关闭”神经元。这迫使每个神经元具有更强大的特征,可以与其他神经元一起使用。我们不会在这里使用 dropout 层,因为我们不会使用整个数据集。

# some training parameters
EPOCHS = 100
BATCH_SIZE = 32
image_height = 227
image_width = 227
train_dir = "train"
valid_dir = "validation"
model_dir = "my_model.h5"

train_datagen = ImageDataGenerator(
                  rescale=1./255,
                  rotation_range=10,
                  width_shift_range=0.1,
                  height_shift_range=0.1,
                  shear_range=0.1,
                  zoom_range=0.1)

train_generator = train_datagen.flow_from_directory(train_dir,
                                                    target_size=(image_height, image_width),
                                                    color_mode="rgb",
                                                    batch_size=BATCH_SIZE,
                                                    seed=1,
                                                    shuffle=True,
                                                    class_mode="categorical")

valid_datagen = ImageDataGenerator(rescale=1.0/255.0)
valid_generator = valid_datagen.flow_from_directory(valid_dir,
                                                    target_size=(image_height, image_width),
                                                    color_mode="rgb",
                                                    batch_size=BATCH_SIZE,
                                                    seed=7,
                                                    shuffle=True,
                                                    class_mode="categorical"
                                                    )
train_num = train_generator.samples
valid_num = valid_generator.samples

        现在我们可以设置TensorBoard并开始训练我们的模型。这样我们就可以实时跟踪模型性能。

log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
callback_list = [tensorboard_callback]

# start training
model.fit(train_generator,
                    epochs=EPOCHS,
                    steps_per_epoch=train_num // BATCH_SIZE,
                    validation_data=valid_generator,
                    validation_steps=valid_num // BATCH_SIZE,
                    callbacks=callback_list,
                    verbose=0)

# save the whole model
model.save(model_dir)

%tensorboard --logdir logs/fit

        让我们使用我们的模型进行一些预测并将其可视化。

class_names = ['bike', 'ship']

x_valid, label_batch  = next(iter(valid_generator))

prediction_values = model.predict_classes(x_valid)

# set up the figure
fig = plt.figure(figsize=(10, 6))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)

# plot the images: each image is 227x227 pixels
for i in range(8):
    ax = fig.add_subplot(2, 4, i + 1, xticks=[], yticks=[])
    ax.imshow(x_valid[i,:],cmap=plt.cm.gray_r, interpolation='nearest')
  
    if prediction_values[i] == np.argmax(label_batch[i]):
        # label the image with the blue text
        ax.text(3, 17, class_names[prediction_values[i]], color='blue', fontsize=14)
    else:
        # label the image with the red text
        ax.text(3, 17, class_names[prediction_values[i]], color='red', fontsize=14)

 

四、概括

        在这篇文章中,我们展示了如何在 TensorFlow 2.0 中实现 \(AlexNet\)。我们只使用了 ImageNet 数据集的一部分,这就是为什么我们没有得到最好的结果。为了获得更高的准确性,需要更多的数据和更长的训练时间。

参考资料:

 数据黑客变种rs    深度学习 机器学习 TensorFlow    2020 年 2 月 29 日  |  0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1086842.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring framework Day13:注解结合Java配置类

前言 前面我们管理 bean 都是在 xml 文件中去管理,本次我们将介绍如何在 Java 配置类中去管理 bean。 注解结合 Java 配置类是一种常见的 Spring 注入 Bean 的方式。通常情况下,开发人员会使用 Java Config 来定义应用程序的配置信息,而在 …

三维地下管线建模软件MagicPipe3D V3.1.3发布

经纬管网建模系统MagicPipe3D V3.1.3持续更新,内容如下: (1)新增管线流向配置,建模生成带流向箭头管道模型; (2)新增建模完成后可以直接载入3DTiles或obj模型功能; &a…

GoLang开发使用gin框架搭建web程序

目录 1.SDK安装 ​2.编辑器下载 3.编辑器准备 4.使用 4.1常见请求方式 1.SDK安装 保证装了Golang的sdk(官网下载windows.zip->解压,安装,配置bin的环境变量) 2.编辑器下载 Download GoLand: A Go IDE with extended support for JavaScript, Ty…

postman 密码rsa加密登录-1获取公钥

fiddler抓包看到:请求系统地址会自动跳转到sso接口,查看200状态的接口返回的html里存在一个encrypt的信息,咨询开发这个就是返回的公钥。 在postman的tests里对该返回进行处理,获取公钥并设为环境变量 //获取公钥 var pubKey re…

Rancher 使用指南

Rancher 使用指南 Rancher 是什么?Rancher 与 OpenShift / Kubesphere 主要区别对比RancherOpenShiftKubesphere 对比 Rancher 和 OpenShift Rancher 安装 Rancher 是什么? 企业级Kubernetes管理平台 Rancher 是供采用容器的团队使用的完整软件堆栈。它解决了管理多个Kuber…

RT-Thread 内核移植(学习)

内核移植 内核移植就是指将RT-Thread内核在不同的芯片架构、不同的板卡上运行起来,能够具备线程管理和调度,内存管理,线程间同步和通信、定时器管理等功能。 移植可分为CPU架构移植和BSP(Board support package,板级…

催交费通知单套打单纸设置说明

2.0系统打印催交费通知单设置尺寸操作展示如下,仅供参考。具体如下: 一、Win7系统 1.找到设备和打印机,选中对应打印机后点击上方打印服务器属性; 2.创建一个宽14cm,高14cm的表单; 二、win10系统 1.找到打印机,点管理,选择打印首选项;

Unity关键词语音识别

一、背景 最近使用unity开发语音交互内容的时候,遇到了这样的需求,就是需要使用语音关键字来唤醒应用程序,然后再和程序做交互,有点像智能音箱的意思。具体的技术方案方面,也找了一些第三方的服务,比如百度…

当涉及到API接口数据分析时,主要可以从以下几个方面展开

当涉及到API接口数据分析时,主要可以从以下几个方面展开: 请求分析:可以统计每个API接口的请求次数、请求成功率、失败率等基础指标。这些指标可以帮助你了解API接口的使用情况,比如哪个API接口被调用的次数最多,哪个…

2023年09月 C/C++(四级)真题解析#中国电子学会#全国青少年软件编程等级考试

C/C编程(1~8级)全部真题・点这里 Python编程(1~6级)全部真题・点这里 第1题:酒鬼 Santo刚刚与房东打赌赢得了一间在New Clondike 的大客厅。今天,他来到这个大客厅欣赏他的奖品。房东摆出了一行瓶子在酒吧上…

《向量数据库指南》——向量数据库与 ANN 算法库的区别

向量数据库与 ANN 算法库的区别 我们经常听到一个这样的错误观念——向量数据库只是在 ANN(approximate nearest neighbor,近似最近邻)算法上封装了一层。但这种说法大错特错。 向量数据库可以处理大规模数据,而 ANN 算法库只能处理小型的数据集 从本质上来看,以 Milvus 为…

Adobe Premiere Pro 和 After Effects 安装出错的解决路径

在有点年头的电脑上安装Premiere Pro 和 After Effects 遇到了前所未有的的麻烦,请了某宝上的小哥进行远程安装,两个软件倒是可以用了,但Win11系统无法正常关机,用了几天系统除了关机时会蓝屏几十秒,其他没有发现毛病&…

centos 7 lamp owncloud

OwnCloud是一款开源的云存储软件,基于PHP的自建网盘。基本上是私人使用,没有用户注册功能,但是有用户添加功能,你可以无限制地添加用户,OwnCloud支持多个平台(windows,MAC,Android&a…

计算机网络 | 物理层

计算机网络 | 物理层 计算机网络 | 物理层基本概念数据通信基本知识(一)一个数据通信流程的例子数据通信相关术语三种通信方式数据传输方式串行传输和并行传输同步传输和异步传输 小结 数据通信基本知识(二)码元(Symbo…

【Java 进阶篇】JavaScript 一元运算符详解

在JavaScript中,一元运算符是一类操作符,它们作用于单一操作数(一个值)。这些运算符执行各种操作,包括递增、递减、类型转换等。本文将详细介绍JavaScript中的一元运算符,解释它们的用途,提供示…

MySQL MVCC详细介绍

MVCC概念 MVCC(Multi-Version Concurrency Control) 多版本并发控制,是一种并发控制机制,用于处理数据库中的并发读写操作,它通过在每个事务中创建数据的快照,实现了读写操作的隔离性,从而避免了读写冲突和数据不一致的问题。 M…

VUE echarts 柱状图、折线图 双Y轴 显示

weekData: [“1周”,“2周”,“3周”,“4周”,“5周”,“6周”,“7周”,“8周”,“9周”,“10周”], //柱状图横轴 jdslData: [150, 220, 430, 360, 450, 680, 100, 450, 680, 200], // 折线图的数据 cyslData: [100, 200, 400, 300, 500, 500, 500, 450, 480, 400], // 柱状图…

基于VScode 使用plantUML 插件设计状态机

本文主要记录本人初次在VScode上使用PlantUML设计 本文只讲述操作的实际方法,假设java已安装成功 。 1. 在VScode下安装如下插件 2. 验证环境是否正常 新建一个文件夹并在目录下面新建文件test.plantuml 其内容如下所示: startuml hello world skinparam Style …

ubuntu|23 安装Gnome主题

ubuntu23 安装主题 进入网站选择需要的主题 https://www.opendesktop.org/s/Gnome/p/1357889 1 资源下载 经常加载不出来, 这里直接进入github下载源码 下载zip 2 安装主题 根据文档提示, 执行install.sh就能安装 3 切换主题 安装 tweak工具 sudo …

Win10玩游戏老是弹回桌面的解决方法

在Win10电脑中,用户不仅可以办公,也可以畅玩各种各样的游戏。但是,有时候用户在玩游戏的时候,遇到了游戏老是自己弹回桌面的问题,这样是非常影响游戏体验的,却不清楚具体的解决方法。下面小编给大家带来了简…