Keras-3-实例1-二分类问题

news2024/11/22 18:03:57

1. 二分类问题

1.1 IMDB 数据集加载

IMDB 包含5w条严重两极分化的评论,数据集被分为 2.5w 训练数据 和 2.5w 测试数据,训练集和测试集中的正面和负面评论占比都是50%

from keras.datasets import imdb

(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)
## num_words=10000, 保留训练数据中前10000个最常出现的单词,舍弃低频词;
## train_data 和 test_data 都是评论构成的列表,而评论又由单词索引组成;
## train_labels 和 test_labels 都是 0 和 1 构成的列表;
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
17464789/17464789 [==============================] - 6s 0us/step
print(train_data[0]) ## 将一句话转化为由单词表索引构成的一句话,实现句子转化为向量;
print(test_labels[0])
[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]
0

1.2 数据集预处理

将整数序列转化为张量

## 将 整数序列 转化为 张量,输入网络
## 用的方法是 one-hot 编码

import numpy as np

def vectorize_sequences(sequences, dimension=10000):
    ## 单词表中有 10000 个高频词,所以 dimension=10000;
    results = np.zeros((len(sequences), dimension)) ## 初始化一个 len(sequences) * 10000 的零矩阵;
    for i, word in enumerate(sequences):
        results[i, word] = 1.
    return results

x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

## 将标签向量化
y_train = np.asarray(train_labels).astype("float32")
y_test = np.asarray(test_labels).astype("float32")
## 整数序列 转化为 张量后的结果
print(x_train.shape)
print(x_train[0])
print(y_test[0])
(25000, 10000)
[0. 1. 1. ... 0. 0. 0.]
0.0

1.3 构建网络

带有 relu 激活函数的 全连接层 (Dense)的简单堆叠。

Dense堆叠有几个关键问题:

  1. 网络有多少层?

  2. 每层有多少个隐藏单元

一个隐藏单元 (hidden unit)是该层表示空间的一个维度,比如隐藏单元为16,意思就是将输入数据投影到16维的表示空间中。
隐藏单元越多(即 更高维的表示空间),网络越能够学到更加复杂的表示,但相应的计算代价也变大。

## 定义模型

from keras import models
from keras import layers

model = models.Sequential() ## 构建线性堆叠的网络
model.add(layers.Dense(16, activation="relu", input_shape=(10000, )))
model.add(layers.Dense(16, activation="relu")) ## 16: 该层隐藏单元的个数(也是表示空间的维度);relu: 激活函数(将所有的负值归零)
model.add(layers.Dense(1, activation="sigmoid")) ## sigmoid: 激活函数(将任意值压缩到 [0,1])


## 编译模型
model.compile(optimizer="rmsprop",
              loss="binary_crossentropy",
              metrics=["accuracy"]) ## 损失函数为 二元交叉熵,优化器为 rmsprop,模型的评估指标用的是 accuracy

Metal device set to: Apple M1

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB

1.4 训练模型

## 留出验证集 (原始训练数据集中留出 10000 个样本作为验证集,剩下的作为训练集)
x_val = x_train[:10000] ## 验证集数据
partial_x_train = x_train[10000:] ## 训练集数据

y_val = y_train[:10000] ## 验证集标签
partial_y_train = y_train[10000:] ## 训练集标签


## 训练模型
history = model.fit(partial_x_train,
                    partial_y_train,
                    epochs=20, ## 迭代次数
                    batch_size=512, ## 批量大小
                    validation_data=(x_val, y_val)) ## 模型在验证集上的损失和精度

## model.fit() 返回一个 History 对象,该对象有一个 history 成员,它是一个 字典,包含训练过程中的所有数据
Epoch 1/20


2023-06-06 21:55:20.911277: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


30/30 [==============================] - 3s 47ms/step - loss: 0.5334 - accuracy: 0.7871 - val_loss: 0.4422 - val_accuracy: 0.8205
Epoch 2/20
30/30 [==============================] - 1s 19ms/step - loss: 0.3317 - accuracy: 0.8979 - val_loss: 0.3197 - val_accuracy: 0.8862
Epoch 3/20
30/30 [==============================] - 1s 17ms/step - loss: 0.2383 - accuracy: 0.9252 - val_loss: 0.2873 - val_accuracy: 0.8876
Epoch 4/20
30/30 [==============================] - 0s 17ms/step - loss: 0.1843 - accuracy: 0.9424 - val_loss: 0.2761 - val_accuracy: 0.8904
Epoch 5/20
30/30 [==============================] - 0s 16ms/step - loss: 0.1475 - accuracy: 0.9540 - val_loss: 0.2824 - val_accuracy: 0.8876
Epoch 6/20
30/30 [==============================] - 0s 16ms/step - loss: 0.1253 - accuracy: 0.9621 - val_loss: 0.2898 - val_accuracy: 0.8868
Epoch 7/20
30/30 [==============================] - 0s 17ms/step - loss: 0.1007 - accuracy: 0.9719 - val_loss: 0.3063 - val_accuracy: 0.8840
Epoch 8/20
30/30 [==============================] - 0s 17ms/step - loss: 0.0826 - accuracy: 0.9777 - val_loss: 0.3309 - val_accuracy: 0.8778
Epoch 9/20
30/30 [==============================] - 0s 16ms/step - loss: 0.0692 - accuracy: 0.9819 - val_loss: 0.3497 - val_accuracy: 0.8786
Epoch 10/20
30/30 [==============================] - 0s 16ms/step - loss: 0.0544 - accuracy: 0.9868 - val_loss: 0.3707 - val_accuracy: 0.8780
Epoch 11/20
30/30 [==============================] - 0s 16ms/step - loss: 0.0446 - accuracy: 0.9899 - val_loss: 0.4029 - val_accuracy: 0.8761
Epoch 12/20
30/30 [==============================] - 1s 17ms/step - loss: 0.0338 - accuracy: 0.9935 - val_loss: 0.4364 - val_accuracy: 0.8742
Epoch 13/20
30/30 [==============================] - 1s 18ms/step - loss: 0.0315 - accuracy: 0.9932 - val_loss: 0.4550 - val_accuracy: 0.8749
Epoch 14/20
30/30 [==============================] - 1s 19ms/step - loss: 0.0181 - accuracy: 0.9978 - val_loss: 0.4940 - val_accuracy: 0.8726
Epoch 15/20
30/30 [==============================] - 1s 18ms/step - loss: 0.0173 - accuracy: 0.9979 - val_loss: 0.5231 - val_accuracy: 0.8727
Epoch 16/20
30/30 [==============================] - 1s 18ms/step - loss: 0.0137 - accuracy: 0.9977 - val_loss: 0.5800 - val_accuracy: 0.8648
Epoch 17/20
30/30 [==============================] - 1s 18ms/step - loss: 0.0076 - accuracy: 0.9997 - val_loss: 0.6507 - val_accuracy: 0.8583
Epoch 18/20
30/30 [==============================] - 0s 16ms/step - loss: 0.0092 - accuracy: 0.9991 - val_loss: 0.6157 - val_accuracy: 0.8694
Epoch 19/20
30/30 [==============================] - 0s 16ms/step - loss: 0.0041 - accuracy: 0.9999 - val_loss: 0.6636 - val_accuracy: 0.8658
Epoch 20/20
30/30 [==============================] - 0s 16ms/step - loss: 0.0047 - accuracy: 0.9996 - val_loss: 0.6849 - val_accuracy: 0.8677
## history 中包含 训练过程和验证过程 中监控的指标(损失和精度)。
history_dict = history.history
history_dict.keys()
dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])

1.5 可视化 监控指标

## 训练损失和验证损失
import matplotlib.pyplot as plt

loss_values = history_dict["loss"]
val_loss_values = history_dict["val_loss"]

epochs = range(1, len(loss_values)+1)

plt.plot(epochs, loss_values, "bo", label="Training loss") ## "bo" 表示蓝色圆点
plt.plot(epochs, val_loss_values, "b", label="Validation loss") ## "bo" 表示蓝色实线
plt.title("Training and validation loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()

plt.show()

在这里插入图片描述

## 练精度和验证精度
acc_values = history_dict["accuracy"]
val_acc_values = history_dict["val_accuracy"]

plt.plot(epochs, acc_values, "bo", label="Training accuracy") ## "bo" 表示蓝色圆点
plt.plot(epochs, val_acc_values, "b", label="Validation accuracy") ## "bo" 表示蓝色实线
plt.title("Training and validation accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()

plt.show()

在这里插入图片描述

根据可视化结果可以看出:训练损失每轮都在降低,训练精度每轮都在升高。但验证损失和验证精度似乎在第4轮达到最佳值,所以训练次数增加,模型可能出现过拟合的问题。

为了防止过拟合,可以在第3轮之后停止训练。

1.6 从头开始重新训练模型

## 构建模型
model = models.Sequential()
model.add(layers.Dense(16, activation="relu", input_shape=(10000, )))
model.add(layers.Dense(16, activation="relu"))
model.add(layers.Dense(1, activation="sigmoid"))

## 编译模型
model.compile(optimizer="rmsprop",
              loss="binary_crossentropy",
              metrics=["accuracy"])

## 训练模型
model.fit(x_train, y_train, epochs=4, batch_size=512) ## 只训练4次
results = model.evaluate(x_test, y_test) ## 模型在测试集上进行评估

print(results)
Epoch 1/4
49/49 [==============================] - 1s 15ms/step - loss: 0.4761 - accuracy: 0.8302
Epoch 2/4
49/49 [==============================] - 1s 12ms/step - loss: 0.2805 - accuracy: 0.9044
Epoch 3/4
49/49 [==============================] - 1s 11ms/step - loss: 0.2104 - accuracy: 0.9257
Epoch 4/4
49/49 [==============================] - 1s 11ms/step - loss: 0.1748 - accuracy: 0.9380
782/782 [==============================] - 4s 5ms/step - loss: 0.2906 - accuracy: 0.8833
[0.29057276248931885, 0.8833200335502625]

1.7 使用训练好的模型在新数据集上生成预测结果

model.predict(x_test)
782/782 [==============================] - 3s 3ms/step





array([[0.2139433],
       [0.9987571],
       [0.7920793],
       ...,
       [0.0869531],
       [0.0685806],
       [0.5763908]], dtype=float32)

1.8 小结

  1. 需要对原始数据进行预处理,将单词序列转化为张量 (word embedding);
  2. 中间层每层都要用 激活函数;
  3. 对于二分类问题,最后一层应该只有一个 隐藏单元,并只用 Sigmoid 激活最后一层,使得输出值是 0~1 之间的标量,表示概率值;
  4. 对于二分类问题的 sigmoid 输出,应该用 二元交叉熵 (binary_corssentropy) 作为 损失函数;
  5. 模型训练次数增多,可能会出现过拟合的问题,所以要根据相关指标确定最佳训练次数;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/617399.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UE5 Chaos破碎系统学习1

在UE5中,Chaos破碎系统被直接进行了整合,本篇文章就来讲讲chaos的基础使用。 1.基础破碎 1.首先选中需要进行破碎的模型,例如这里选择一个Box,然后切换至Fracture Mode(破碎模式): 2.点击右侧…

JAVA实现打字练习软件

转眼已经学了一学期的java了,老师让我们根据所学知识点写一个打字练习软件的综合练习。一开始我也不是很有思路,我找了一下发现csdn上关于这个小项目的代码也不算很多,所以我最后自己在csdn查了一些资料,写了这么一个简略版本的打…

【C++】——list的介绍及模拟实现

文章目录 1. 前言2. list的介绍3. list的常用接口3.1 list的构造函数3.2 iterator的使用3.3 list的空间管理3.4 list的结点访问3.5 list的增删查改 4. list迭代器失效的问题5. list模拟实现6. list与vector的对比7. 结尾 1. 前言 我们之前已经学习了string和vector&#xff0c…

Remix IDE已支持Sui Move在线开发

网页版Remix IDE与WELLDONE Code插件结合,让您无需本地设置或安装即可开始构建Sui应用程序。 不熟悉Sui的构建者可能想在正式配置开发环境之前,浅尝一下构建Sui应用程序。Remix IDE与WELLDONE Code插件组合,即可帮助构建者实现从浏览器窗口开…

JavaScript函数的增强知识

函数属性和arguments以及剩余参数 函数属性name与length ◼ 我们知道JavaScript中函数也是一个对象,那么对象中就可以有属性和方法。 ◼ 属性name:一个函数的名词我们可以通过name来访问; // 自定义属性foo.message "Hello Foo"…

Nginx 之 Tomcat 负载均衡、动静分离

一.详细安装及操作实例(Nginx 七层代理) 首先至少准备三台服务器 Nginx 服务器:192.168.247.131:80 Tomcat服务器1:192.168.247.133:80 Tomcat服务器2:192.168.247.134:8080 192.168.247.134:80811.部署Nginx 负载均…

微信自动回复怎么设置呢?

友友们 你们是否有以下这些烦恼 1、每天要手动点击“添加”按钮多次以通过大量好友? 2、你是否经常需要在多个微信帐号之间来回切换? 3、你的回复速度慢,导致客户流失率高? 4、为了及时回复,你总是需要带着多部手机出门&…

二十一、C++11(中)

文章目录 一、左值&右值(一)基本概念1.左值是什么2.右值是什么 (二)左值引用和右值引用1.左值引用2.右值引用 二、右值引用使用场景和意义(一)引入(二)左值引用的使用场景&#…

Linux编译器(gcc/g++)调试器gdb项目自动化构建工具(make/Makefile)版本管理git

Linux编译器-gcc/g&&调试器gdb&&项目自动化构建工具-make/Makefile&&版本管理git 🔆gcc/g的使用可执行文件的"生产"过程gcc如何完成预处理编译汇编链接 函数库函数库一般分为静态库和动态库两种静态C/C库的安装 gcc选项gcc选项记…

WPF 学习:如何照着MaterialDesign的Demo学习

文章目录 往期回顾对应视频资源如何照着wpf项目学习找到你想要抄的页面查找对应源码演示示例如何认清页面元素抄袭实战 项目地址总结 往期回顾 WPF Debug运行是 实时可视化树无效,无法查看代码 WPF MaterialDesign 初学项目实战(0):github …

【Java】线程池的概念及使用、ThreadPoolExecutor的构造方法

什么是线程池为什么用线程池JDK提供的线程池工厂模式如何使用 自定义线程池ThreadPoolExecutor类的构造方法工作原理拒绝策略 线程池的使用 什么是线程池 在之前JDBC编程中,通过DataSource获取Connection的时候就已经用到了池的概念。这里的池指的是数据库连接池。…

Vue电商项目--uuid游客身份获取购物车数据

uuid游客身份获取购物车数据 获取购物车列表 请求地址 /api/cart/cartList 请求方式 GET 参数类型 参数名称 类型 是否必选 描述 无 无 无 无 返回示例 成功: { "code": 200, "message": "成功", "…

马尔萨斯 ( Malthus)人口指数增长模型Logistic 模型

3.要求与任务 从 1790 — 1990 年间美国每隔 10 年的人口记录如下表所示: 用以上数据检验马尔萨斯 ( Malthus)人口指数增长模型,根据检验结果进一步讨论马尔萨斯 人口模型的改进,并利用至少两种模型来预测美国2010 年的人口数量。 提示 1 &…

自学黑客(网络安全),一般人我还是劝你算了吧

作为从16年接触网络安全的小白,谈谈零基础如何入门网络安全,有不对的地方,请多多指教。 这些年最后悔的事情莫过于没有把自己学习的东西积累下来形成一个知识体系。 后续我也会陆续的整理网络安全的相关学习资料及文章,与大家一…

数据结构与算法练习(三)二叉树

文章目录 1、树2、二叉树3、满二叉树4、完全二叉树5、二叉树的遍历(前序、中序、后序)二叉树删除节点或树 6、顺序存储二叉树顺序存储二叉树遍历(前序、中序、后序) 7、线索化二叉树中序线索二叉树前序线索二叉树后序线索二叉树 1…

Matlab 之 Curve Fitting APP 使用笔记

文章目录 Part.I IntroductionPart.II 使用笔记Chap.I 拟合函数Chap.II 注意事项 Part.I Introduction 曲线或曲面拟合获取拟合参数。本篇博文主要记录一下 Matlab 拟合 APP Curve Fitting 的使用方法。 Part.II 使用笔记 这个APP用来做拟合的,包括二维数据的线拟…

常见的样本统计量及其数字特征

常见的样本统计量及其数字特征 下图来自《统计学图鉴》 样本统计量有什么作用? 因为总体特征包含有总体均值、总体方差等特征,我们在用样本推断总体时,其实就是用样本特征去估计总体特征,例如:样本均值这个统计量的期…

案例33:基于Springboot名城小区物业管理系统开题报告设计

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

Spark RDD统计每日新增用户

文章目录 一,提出任务二,实现思路三,准备工作1、在本地创建用户文件2、将用户文件上传到HDFS指定位置 四,完成任务1、在Spark Shell里完成任务(1)读取文件,得到RDD(2)倒排…

为什么要对实体类进行序列化并且要生成序列化ID?

一、为什么要对实体类进行序列化且要生成序列化ID 在Java开发中,实体类将会被用来与其他对象进行交互。Java语言是面向对象的,所以实体类包含了很多信息和方法。序列化是Java中一种将对象转换为字节流的机制,使得对象可以在网络上传输和存储。…