LeNet网络简介

news2024/9/28 12:40:45

1.  背景

主要介绍LeNet网络预测在CIFAR-10图像数据集上的训练及预测。

2. CIFAR-10图像数据集简介

        CIFAR-10是一个包含了6W张32*32像素的三通道彩色图像数据集,图像划分为10大类,每个类别包含了6K张图像。其中训练集5W张,测试集1W张。

数据加载及预处理:

def load_and_proc_data():
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    print('X_train shape', X_train.shape)
    # X_train shape (50000, 32, 32, 3)
    print(X_train.shape[0], 'train samples')
    print(X_test.shape[0], 'test samples')

    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train /= 255
    X_test /= 255

    # 将类向量转换成二值类别矩阵
    y_train = np_utils.to_categorical(y_train, NB_CLASSES)
    y_test = np_utils.to_categorical(y_test, NB_CLASSES)
    return X_train, X_test, y_train, y_test

3. LeNet网络模型定义

3.1 单层卷积网络

from keras.models import Sequential
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Activation, Flatten, Dense, Dropout
from keras.datasets import cifar10
from keras.utils import np_utils
from keras.optimizers import RMSprop

class LeNet:
    @staticmethod
    def build(input_shape, classes):
        model = Sequential()
        model.add(Conv2D(32, kernel_size=3, padding='same', input_shape=input_shape))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        model.add(Dropout(0.25))

        model.add(Flatten())
        model.add(Dense(512))
        model.add(Activation('relu'))
        model.add(Dropout(0.5))

        model.add(Dense(classes))
        model.add(Activation('softmax'))
        model.summary()  # 概要汇总网络
        return model

3.2 模型结构及相关参数

  3.3 增加模型深度(多层卷积)

class LeNet:
    @staticmethod
    def build(input_shape, classes):
        model = Sequential()
        model.add(Conv2D(32, kernel_size=3, padding='same', input_shape=input_shape))
        model.add(Activation('relu'))
        model.add(Conv2D(32, kernel_size=3, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        model.add(Dropout(0.25))

        model.add(Conv2D(64, kernel_size=3, padding='same'))
        model.add(Activation('relu'))
        model.add(Conv2D(64, kernel_size=3, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        model.add(Dropout(0.25))
        
        model.add(Flatten())
        model.add(Dense(512))
        model.add(Activation('relu'))
        model.add(Dropout(0.5))

        model.add(Dense(classes))
        model.add(Activation('softmax'))
        model.summary()  # 概要汇总网络
        return model

4. 模型训练及预测

def model_train(X_train, y_train):
    OPTIMIZER = RMSprop()
    model = LeNet.build(input_shape=INPUT_SHAPE, classes=NB_CLASSES)
    model.compile(loss='categorical_crossentropy', optimizer=OPTIMIZER, metrics=['accuracy'])
    history = model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=NB_EPOCH, verbose=1, validation_split=VALIDATION_SPLIT)
    # plot_picture(history)
    return model

def model_evaluate(model, X_test, y_test):
    score = model.evaluate(X_test, y_test, batch_size=BATCH_SIZE, verbose=1)
    print('Test score: ', score[0])
    print('Test acc: ', score[1])

5. 打印准确率和损失函数

import matplotlib.pyplot as plt

def plot_picture(history):
    print(history.history.keys())
    # -----------acc---------------
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('model acc')
    plt.ylabel('acc')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()

    # -----------loss---------------
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()

 6. 模型保存

def model_save(model):
    # 保存网络结构
    model_json = model.to_json()
    with open('cifar10_architecture.json', 'w') as f:
        f.write(model_json)
    # 保存网络权重
    model.save_weights('cifar10_weights.h5', overwrite=True)

7. 主函数

NB_EPOCH = 50
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.2
IMG_ROWS, IMG_COLS = 32, 32
IMG_CHANNELS = 3
INPUT_SHAPE = (IMG_ROWS, IMG_COLS, IMG_CHANNELS)  # 注意顺序
NB_CLASSES = 10

if __name__ == '__main__':
    X_train, X_test, y_train, y_test = load_and_proc_data()
    model = model_train(X_train, y_train)
    # model_save(model)
    model_evaluate(model, X_test, y_test)

模型输出

Test score:  1.3542113304138184
Test acc:  0.6733999848365784

8. 模型加载及在线推理

模型训练好以后,从模型文件加载模型,并进行预测。

import numpy as np
from keras.models import model_from_json
from keras.optimizers import SGD
from skimage.transform import resize
import imageio

def input_data_proc():
    img_names = ['cat.png', 'dog.png']
    img_list = []
    for img_name in img_names:
        img = imageio.imread(img_name)
        img = resize(img, output_shape=(32, 32, 3)).astype('float32')
        print('size: ', img.shape)
        img_list.append(img)
    img_list = np.array(img_list) / 255
    return img_list

def model_predict(model, optim, img_list):
    model.compile(loss='categorical_crossentropy', optimizer=optim, metrics=['accuracy'])
    preds = model.predict(img_list)
    preds = np.argmax(preds, axis=1)
    print(preds)

if __name__ == '__main__':
        model_json = 'cifar10_architecture.json'
        model_weight = 'cifar10_weights.h5'
        model = model_from_json(open(model_json).read())
        model.load_weights(model_weight)

        optim = SGD()
        img_list = input_data_proc()
        model_predict(model, optim, img_list)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/438617.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

window环境rabbitMq安装

RabbitMQ是一个开源的遵循 AMQP协议实现的基于 Erlang语言编写,即需要先安装部署Erlang环境再安装RabbitMQ环境*需加注意的是,可根据两者版本号的对应表,安装相应版本的Erlang和RabbitMQ。 一、安装准备工具 版本查看地址:Rabbi…

mysql的启动关闭原理和实战、及常见的错误排查

前言 MySQL是一个关系型数据库管理系统,由瑞典MySQL AB 公司开发,属于 Oracle 旗下产品。MySQL是最流行的关系型数据库管理系统之一,在 WEB 应用方面,MySQL是最好的 RDBMS (Relational Database Management System,关系…

日本政府官宣:投资42亿日元,量子计算要上“云”

引《日经新闻》报道,日本政府宣布将投资4.2亿日元(约合2.18亿人民币)来支持量子计算领域的发展。这笔资金将被用于扩大云计算平台上的共享量子计算能力,为企业提供更加高效的量子计算服务。该计划将由东京大学领导,支持…

【LeetCode: 1187. 使数组严格递增 | 暴力递归=>记忆化搜索=>动态规划 】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

STL——list、stack与queue

📖作者介绍:22级树莓人(计算机专业),热爱编程<目前在c++阶段>——目标Windows,MySQL,Qt,数据结构与算法,Linux,多线程&…

Springboot 整合 Mybatis

创建SpringBoot项目 首先在IDEA中创建一个SpringBoot项目,注意Java Version 然后Packaging为Jar包形式,Type改为Maven形式。 在上图的下一步中可以选择相关依赖,也可以在项目里面的pom文件中自己添加相关依赖,然后进行import也可…

在外包搞了7年,废了.....

我以自身的经验告诫大家,不要去外包,原因: 无法深入理解项目:由于外包公司通常只负责项目的某一个部分或某一个阶段,软件测试人员无法对整个项目进行深入了解,可能会影响到测试的全面性和准确性。 对测试要…

RB-PEG-NHS;NHS-PEG-Rhodamine罗丹明聚乙二醇琥珀酰亚胺 红色荧光染料罗丹明B功能化聚乙二醇

RB-PEG-NHS,罗丹明-聚乙二醇-活性脂 中文名称:罗丹明-聚乙二醇-活性脂 英文名称:RB-PEG-NHS 性状:固体或者粘稠液体,取决于分子量大小。 溶剂:溶于大部分有机溶剂,溶于水。 分子量:400、60…

【深度学习】RNN、LSTM、GRU

【深度学习】RNN、LSTM、GRU RNNLSTMGRU结语 RNN 和普通神经网络一样,RNN有输入层、输出层和隐含层,不一样的是RNN在不同的时间 t t t会有不同的状态,其中 t − 1 t-1 t−1时刻隐含层的输出会作用到 t t t时刻的隐含层。 RNN因为加入了时间…

强大的图像查看器:EdgeView mac中文

EdgeView mac中文版是mac上一款强大的图像查看软件,可以处理一些最流行的图像文件格式,同时还提供对导航杂志或漫画书的支持。EdgeView能够打开著名的图像文件格式主要包括JPG,GIF,PSD在内的多种格式文件,支持Retina显…

深度解析JavaScript自动化测试工具Cypress的工作运行原理

目录 引言 什么是Cypress? Cypress的工作原理 Cypress运行原理 Cypress和其他自动化测试工具有什么不一样? Cypress的缺点 【自动化测试工程师学习路线】 引言 在当今的软件开发中,自动化测试工具已成为不可或缺的一部分,…

负载均衡式在线OJ

目录 项目介绍所用技术与开发环境所用技术开发环境 项目各种安装升级 gcc安装 jsoncpp安装 cpp-httplib安装boost库安装与测试 ctemplate 项目宏观结构总体文件目录comm : 公共模块compile_run_server:编译和运行compiler.hpp编译runner.hpp 运行compiler_runner.hp…

ChatGPT 速通手册——开源社区的进展

开源社区的进展 在 ChatGPT 以外,谷歌、脸书等互联网巨头,也都发布过千亿级参数的大语言模型,但在交谈问答方面表现相对 ChatGPT 来说都显得一般。根据科学人员推测,很重要的一部分原因是缺失了RLHF(Reinforcement Learning with…

Banana Pi CM4 计算机模组评测(VS 树莓派计算模块 CM4)

如果您正在寻找一款可靠的单板计算机来提升您的下一个项目,但找不到满足您需求的 Raspberry Pi,让我们看看我是否可以提供帮助。在这篇详细的评论中,我将向您介绍 Banana Pi CM4,这是一款适用于各种任务的多功能且功能强大的解决方…

【OpenCV 例程 300篇】257.OpenCV 生成随机矩阵

『youcans 的 OpenCV 例程300篇 - 总目录』 【youcans 的 OpenCV 例程 300篇】257. OpenCV 生成随机矩阵 3.2 OpenCV 创建随机图像 OpenCV 中提供了 cv.randn 和 cv.randu 函数生成随机数矩阵,也可以用于创建随机图像。 函数 cv.randn 生成的矩阵服从正态分布&…

【caddy】 caddy反向代理api服务 聚合go-zero微服务 放过nginx让caddy来快速实现吧

帮助go-zero开发者聚合api 相关视频一、go-zero 微服务整体架构1、微服务的基本架构2、go-zero 微服务的 apiauthrpc.api 文件routes.go 文件 二、本地开发的痛点1、本地多个端口开启的服务2、apifox、postman 三、caddy1、mac下caddy安装2、配置我们自己的caddyfile1&#xff…

SpringBootWeb入门-HTTP协议

一、SpringBootWeb-快速入门 建好springboot工程之后,只留下这几个文件。 这个是springboot的父工程,其实就是继承 二、HTTP协议-概述 •HTTP-概述 三、HTPP协议-请求协议 四、HTTP协议-响应协议 一、状态码大类 状态码分类说明1xx响应中——临时状态码…

buuctf -2

目录 你竟然赶我走 大白 N种方法解决 [ACTF2020 新生赛]Include 1 php://filter的一些学习 [ACTF2020 新生赛]Exec [强网杯 2019]随便注 你竟然赶我走 1.下载文件,得到一张图片 2.放进010分析,在文件尾得到flag 大白 1.根据题目提示&#xff0…

Python安装模块总失败?一次教你学会镜像安装

人生苦短,我用python 安装模块总是不成功? 这次一次性讲清楚~ 还是安装报错指路:点击此处跳转文末名片获取 为什么会出现安装模块失败? 首先我们要知道 其实大部分我们在用的模块, 都是歪果仁开发的, 然而我们在输入 “pip install 模块名” 的时候,…

「 JVM 」常见的垃圾收集器Garbage collector(GC)

「 JVM 」常见的垃圾收集器Garbage collector(GC) 参考&鸣谢 【JVM系统学习之路】常见垃圾回收器 山间木匠 Java 的七种垃圾收集器 | Linux 中国 Jayashree Huttanagoudar 带你走近Java虚拟机到底有哪些经典的垃圾收集器 码上遇见你 文章目录 「 JV…