04 卷积神经网络搭建

news2024/10/2 12:19:45

一、数据集

MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的[参考]。

MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。
 

 1.1 准备数据

将数据集分为训练集、验证集和测试集 

#训练集有60000张图片,前5000张图片作为验证集,后55000作为训练集

1. `x_train_all` 和 `y_train_all`:
   - `x_train_all` 包含了完整的训练数据集的图像数据,这些图像用于训练深度学习模型。
   - `y_train_all` 包含了完整的训练数据集的标签,即与 `x_train_all` 中的图像相对应的类别标签。

2. `x_test` 和 `y_test`:
   - `x_test` 包含了测试数据集的图像数据,这些图像用于评估深度学习模型的性能。
   - `y_test` 包含了测试数据集的标签,即与 `x_test` 中的图像相对应的类别标签。

from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]

print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

1.2 标准化

# 标准化
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(
    x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)
x_valid_scaled = scaler.transform(
    x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)

x_test_scaled = scaler.transform(
    x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)

1.3 数据集

def make_dataset(data, target, epochs, batch_size, shuffle=True):
    dataset = tf.data.Dataset.from_tensor_slices((data, target))
    if shuffle:
        dataset = dataset.shuffle(10000)
    dataset = dataset.repeat(epochs).batch(batch_size).prefetch(50)
    return dataset

batch_size = 64
epochs = 20
train_dataset = make_dataset(x_train_scaled, y_train, epochs, batch_size)

1.4 搭建模型

model = keras.models.Sequential()
# 卷积
model.add(keras.layers.Conv2D(filters = 32, 
                              kernel_size = 3, 
                              padding = 'same',
                              activation='relu',
                              # batch_size, height, width, channels
                              input_shape=(28, 28, 1))) # (28, 28, 32)

model.add(keras.layers.Conv2D(filters = 32, 
                              kernel_size = 3, 
                              padding = 'same',
                              activation='relu'))
# 池化
model.add(keras.layers.MaxPool2D()) # (14, 14, 32)

model.add(keras.layers.Conv2D(filters = 64, 
                              kernel_size = 3, 
                              padding = 'same',
                              activation='relu')) # (14, 14, 64)
model.add(keras.layers.Conv2D(filters = 64, 
                              kernel_size = 3, 
                              padding = 'same',
                              activation='relu'))
# 池化
model.add(keras.layers.MaxPool2D()) # (7, 7, 64)
model.add(keras.layers.Conv2D(filters = 128, 
                              kernel_size = 3, 
                              padding = 'same',
                              activation='relu')) # (7, 7, 128)
model.add(keras.layers.Conv2D(filters = 128, 
                              kernel_size = 3, 
                              padding = 'same',
                              activation='relu')) # (7, 7, 128)
# 池化, 向下取整
model.add(keras.layers.MaxPooling2D()) # (3, 3, 128)

model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(512, activation='relu'))
model.add(keras.layers.Dense(256, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))

model.compile(loss='sparse_categorical_crossentropy',
             optimizer='adam',
             metrics=['accuracy'])

print(model.summary())

keras.models.Sequential()`是使用 Keras 构建神经网络模型的开始。`keras.models.Sequential()` 创建了一个 Sequential 模型对象,这是 Keras 中的一种常见模型类型。Sequential 模型是一个线性的、层叠的神经网络模型,适用于顺序层的堆叠,其中每一层都是依次添加到模型中

一旦创建了 Sequential 模型,你可以使用 `.add()` 方法来逐层添加神经网络层,从输入层到输出层。每个层都可以通过实例化 Keras 中的层类来创建,例如

`keras.layers.Dense` 用于全连接层,

`keras.layers.Conv2D` 用于卷积层等。

以下是一个简单的例子,演示如何使用 Sequential 模型创建一个简单的前馈神经网络

from tensorflow import keras

# 创建一个 Sequential 模型
model = keras.models.Sequential()

# 添加输入层和第一个隐藏层
model.add(keras.layers.Input(shape=(input_shape,)))  # 输入层,input_shape 根据你的数据维度定义
model.add(keras.layers.Dense(units=128, activation='relu'))  # 隐藏层1

# 添加第二个隐藏层
model.add(keras.layers.Dense(units=64, activation='relu'))  # 隐藏层2

# 添加输出层
model.add(keras.layers.Dense(units=num_classes, activation='softmax'))  # 输出层,num_classes 是输出类别的数量

# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

上面的代码创建了一个具有两个隐藏层和一个输出层的前馈神经网络模型,并使用了 ReLU 激活函数和 softmax 激活函数。这只是一个简单的示例,你可以根据你的任务和数据来构建更复杂的模型。一旦模型构建完成,你可以使用 `.fit()` 方法来训练模型。

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

是 Keras 中用于编译深度学习模型的代码,它设置了模型的损失函数、优化器和评估指标。让我为你解释每个参数的含义:

- `loss='sparse_categorical_crossentropy'`: 这里设置了模型的损失函数。`sparse_categorical_crossentropy` 是一种用于多类别分类问题的损失函数。它适用于目标变量是整数形式(类别标签)的情况,而不需要将目标变量进行独热编码(one-hot encoding)。模型的目标是最小化这个损失函数,从而使预测结果尽可能接近真实标签。

- `optimizer='adam'`: 这里设置了优化器,用于模型的参数更新。`'adam'` 是一种常用的优化算法,它基于梯度下降的方法,具有自适应学习率和动量的特性,通常在深度学习中表现良好。优化器的作用是最小化损失函数,从而调整模型的权重和参数以使模型更好地拟合数据。

- `metrics=['accuracy']`: 这里设置了评估指标,用于在模型训练期间监测模型性能。

`['accuracy']` 表示模型在训练期间将计算并输出准确度(accuracy),即正确分类的样本数与总样本数的比率。准确度通常用于分类问题的性能评估。

一旦模型编译完成,你可以使用 `model.fit()` 方法来训练模型,该方法会使用上述设置的损失函数、优化器和评估指标来进行训练。例如:

```python
model.fit(x_train, y_train, epochs=10, validation_data=(x_valid, y_valid))
```

1.5 train

eval_dataset = make_dataset(x_valid_scaled, y_valid, epochs=1, batch_size=32, shuffle=False)



history = model.fit(train_dataset, 
         steps_per_epoch=x_train_scaled.shape[0] // batch_size,
         epochs=10,
         validation_data=eval_dataset)

 

1.6 模型评估

model.evaluate(eval_dataset)

 

1.7 可视化

def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8, 5))
    plt.grid(True)
    plt.gca().set_ylim(0, 1)
    plt.show()
    
plot_learning_curves(history)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/983297.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue表格不显示列号123456

我在网上找了半天,都是如何添加列号123456的,没有找到不显示列号的参考,现在把这个解决了,特此记录一下。 没有加右边的就会显示,加上右边的就隐藏了

python+django协同过滤算法的音乐推荐系统研究vue

本系统提供给管理员对用户、音乐分类、歌手、热门歌曲等诸多功能进行管理。本系统对于用户输入的任何信息都进行了一定的验证,为管理员操作提高了效率,也使其数据安全性得到了保障。本音乐推荐研究以Django作为框架,B/S模式以及MySql作为后台…

QtCreator CMakeLists.txt添加模块(Modules)

find_package(QT NAMES Qt6 Qt5 REQUIRED COMPONENTS Widgets Sql) find_package(Qt${QT_VERSION_MAJOR} REQUIRED COMPONENTS Widgets Sql) target_link_libraries(HookeViscometer PRIVATE Qt${QT_VERSION_MAJOR}::Widgets Qt${QT_VERSION_MAJOR}::Sql) 蓝色部分为添加的Q…

13分钟聊聊并发包中常用同步组件并手写一个自定义同步组件

本篇文章通过AQS自己来实现一个同步组件,并从源码级别聊聊JUC并发包中的常用同步组件 本篇文章需要的前置知识就是AQS,阅读本篇文章大概需要13分钟 自定义同步组件 为了更容易理解其他同步组件,我们先来使用AQS自己来实现一个常用的可重入…

(数字图像处理MATLAB+Python)第十二章图像编码-第三、四节:有损编码和JPEG

文章目录 一:有损编码(1)预测编码A:概述B:DM编码C:最优预测器 (2)变换编码A:概述B:实现变换编码的主要问题 二:JPEG 一:有损编码 &am…

Kafka3.0.0版本——消费者(消费者总体工作流程图解)

一、消费者总体工作流程图解 角色划分:生产者、zookeeper、kafka集群、消费者、消费者组。如下图所示: 生产者发送消息给leader,followerr主动从leader同步数据,一个消费者可以消费某一个分区数据或者一个消费者可以消费多个分区数据。如下图…

9月6日上课内容 redis高可用

RDB 持久化 RDB持久化是指在指定的时间间隔内将内存中当前进程中的数据生成快照保存到硬盘(因此也称作快照持久化),用二进制压缩存储,保存的文件后缀是rdb;当Redis重新启动时,可以读取快照文件恢复数据。 1. 触发条件 RDB持久化…

解锁前端Vue3宝藏级资料 Vue3全面解析 第二章 Vue3 基础语法指令

本章主要介绍vue3中的基础指令使用方式和一些开发技巧。分为基础指令,逻辑指令,列表指令,事件,MVVM数据绑定与监听。本章中所有代码例子都是在使用Vite 创建的 vue项目中来完成的。 基础语法指令 2.1 基础指令2.1.1 设置变量2.1.2…

记一次生产环境服务卡死排查记录

接现场运维报告某java服务CPU狂飙,服务处于卡死无响应状态 询问现场运维什么场景造成的,答复是偶发现象,没有规律,和请求高峰期并没有关系。 因为服务是负载均衡的(A、B两台),临时处理让运维重…

【AIGC系列】Stable Diffusion 小白快速入门课程大纲

一、前言 本文是《Stable Diffusion 从入门到企业级应用实战》系列课程的前置学习引导部分,《Stable Diffusion新手完整学习地图课程》的课程大纲。该课程主要的培训对象是: 没有人工智能背景,想快速上手Stable Diffusion的初学者;想掌握St…

这些国外客户真直接

最近在某平台上遇到的客户,很大一部分都是非英语国家的客户,然而他们也有很多共性的习惯。 第一种:直接表达自己对这个产品感兴趣,然后接下来就没有下文了,而之所以可以看得懂,则是借助平台本身的翻译系统&…

三维数字沙盘电子沙盘虚拟现实模拟推演大数据人工智能开发教程第15课

三维数字沙盘电子沙盘虚拟现实模拟推演大数据人工智能开发教程第15课 现在不管什么GIS平台首先要解决的就是数据来源问题,因为没有数据的GIS就是一个空壳,下面我就目前一些主流的数据获取 方式了解做如下之我见(主要针对互联网上的一些卫星…

解锁前端Vue3宝藏级资料 第三章 Vue Router路由器的使用

Vue Router 是 Vue.js 的官方路由器。通过使用 Vue Router,你可以构建一个包含多个页面的应用程序。它可以样多个页面之间流畅地跳转,而无需每次移动到另一个页面时都要重新加载整个页面。Vue Router 路由是使用 Vue.js 构建单页应用项目的必备库。官网地…

IMAU鸿蒙北向开发-2023年9月6日学习日志

1. TextArea 基本使用 //TextArea 基本使用 Entry Component struct Index {State message: string Hello Worldbuild() {Column() {TextArea({placeholder: "请输入个人介绍",text: "个人介绍控制在200字以内。"}).margin({top: 100}).caretColor(Color…

【正版软件】Air Explorer - 一个程序访问您的所有云服务

前言:Air Explorer支持最好的云服务。 功能特点: 直接管理云中的文件 设置同一服务上的多个帐户 您可以在任何云服务或计算机之间同步文件夹 云文件浏览器易于使用 通过加入您的所有云服务来增加存储空间 应用程序适用于Windows/Mac Air Explorer…

Druid LogFilter输出可执行的SQL

配置 测试代码: DruidDataSource dataSource new DruidDataSource(); dataSource.setUrl("xxx"); dataSource.setUsername("xxx"); dataSource.setPassword("xxx"); dataSource.setFilters("slf4j"); dataSource.setVal…

2023年大数据平台数据安全厂商汇总

大数据时代,大数据平台数据安全至关重要,这关系着大家的切身安全。所以企业一定要慎重选择大数据平台数据安全厂商。这里给大家简单汇总一下,同时给大家推荐一下,仅供参考哦! 2023年大数据平台数据安全厂商汇总 1、…

Kafka3.0.0版本——消费者(消费方式)

目录 一、Kafka 消费方式1.1、pull(拉) 模式1.2、push (推)模式1.3、Kafka采用pull(拉) 模式缺点 一、Kafka 消费方式 1.1、pull(拉) 模式 consumer采用从broker中主动拉取数据。K…

【Python小项目之Tkinter应用】随机点名/抽奖工具大优化:实现背景图与其他组件自适应窗口大小变化并保持相对位置和比例不变

文章目录 前言一、需求分析与实现思路明确需求实现思路二、关键代码2.1 实现背景图随着窗口大小变化而变化2.2 更换place的参数三、完整代码四、总结4.1 意外收获前言 话不多说,直接看优化后的效果: 优化前: 是不是非常的哇塞,相比于之前只能固定窗口大小来运行,优化后…

基于SSM的营业厅宽带系统

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…