深度学习基础知识-tf.keras实例:衣物图像多分类分类器

news2024/11/16 7:50:56

参考书籍:《Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, 2nd Edition (Aurelien Geron [Géron, Aurélien])》


在这里插入图片描述
本次使用的数据集是tf.keras.datasets.fashion_mnist,里面包含6w张图,涵盖10个分类。

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import pickle

'''
fashion_mnist = keras.datasets.fashion_mnist.load_data()
with open('fashion_mnist.pkl', 'wb') as f:
    pickle.dump(fashion_mnist, f)
'''
def load_data():
    with open('fashion_mnist.pkl', 'rb') as f:
        mnist = pickle.load(f)
    (X_train_full, y_train_full), (X_test, y_test) = mnist
    X_valid, X_train = X_train_full[:5000] / 255.0, X_train_full[5000:] / 255.0
    y_valid, y_train = y_train_full[:5000], y_train_full[5000:]
    X_test = X_test / 255.0
    return X_train, X_valid, X_test, y_train, y_valid, y_test

class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
 "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
#print(class_names[y_train[0]]) # coat
# 查看
some_image = X_train[0]
plt.imshow(some_image, cmap="binary")
plt.axis("off")
plt.show()

随便拿一张来看:
在这里插入图片描述
构建网络:

'''
model = keras.models.Sequential()
# 28x28 -> 1x784 也可以用InputLayer(input_shape=[28, 28])
model.add(keras.layers.Flatten(input_shape=[28, 28]))
# 其他激活函数:https://keras.io/api/layers/activations/
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
# output layer. 10个输出
model.add(keras.layers.Dense(10, activation="softmax"))
'''
# 也可以这么写:
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28]),
    keras.layers.Dense(300, activation="relu"),
    keras.layers.Dense(100, activation="relu"),
    keras.layers.Dense(10, activation="softmax")
])
# 下面235500 = 784 x 300 + 300, 前面表示每个input都要跑向300个节点,所以要给权重w。然后每个节点要加一个偏置b
# 30100 = 300 x 100 + 100
print(model.summary())
'''
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 300)               235500    
                                                                 
 dense_1 (Dense)             (None, 100)               30100     
                                                                 
 dense_2 (Dense)             (None, 10)                1010      
                                                                 
=================================================================
Total params: 266,610
Trainable params: 266,610
Non-trainable params: 0
_________________________________________________________________
'''

import pydot
keras.utils.plot_model(model, 'model.png')

在这里插入图片描述


pydot安装与plot_model报错的解决:

参考:https://blog.csdn.net/shangxiaqiusuo1/article/details/85283432

先下载 https://graphviz.gitlab.io/_pages/Download/windows/graphviz-2.38.msi
然后双击,安装到D:\Program Files (x86)\Graphviz2.38\

  1. 建立变量名GRAPHVIZ_DOT,值为D:\Program Files (x86)\Graphviz2.38\bin\dot.exe
  2. 在用户环境变量添加一个新的变量:建立变量名 GRAPHVIZ_INSTALL_DIR, 值为D:\Program Files (x86)\Graphviz2.38
  3. 在系统环境变量的PATH中添加Graphviz的bin目录路径,如D:\Program Files (x86)\Graphviz2.38\bin
pip install graphviz
pip install pydot
pip install pydot-ng

在python文件中输入import pydot,然后按住Ctrl+鼠标左键点击pydot,会进入pydot的源文件,然后找到 self.prog = ‘dot’ ,改成 self.prog = ‘dot.exe’

这样改完如果还不行,在python文件里添加:

import os
os.environ["PATH"] += os.pathsep + 'D:/Program Files (x86)/Graphviz2.38/bin/'

基本上这样就ok了。


# 用index或名字都可以access层
hidden1 = model.layers[1]
print(model.get_layer('dense') is hidden1)
weights, biases = hidden1.get_weights()
print(weights)
# bias最开始初始化为0
print(biases)

设置+训练模型:

# 使用这个loss是因为数据有10种离散的、互斥的标签
# optimizer=keras.optimizers.SGD(lr=xx)
# 这样可以设置学习率。default lr=0.01
model.compile(loss="sparse_categorical_crossentropy",
              optimizer="sgd",
              metrics=["accuracy"])
X_train, X_valid, X_test, y_train, y_valid, y_test = load_data() # 获取数据的函数,略
history = model.fit(X_train, y_train, epochs=30, validation_data=(X_valid, y_valid))
'''
Epoch 1/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.6992 - accuracy: 0.7704 - val_loss: 0.4999 - val_accuracy: 0.8378
Epoch 2/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.4876 - accuracy: 0.8311 - val_loss: 0.4659 - val_accuracy: 0.8364
Epoch 3/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.4440 - accuracy: 0.8449 - val_loss: 0.4394 - val_accuracy: 0.8480
...
Epoch 29/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.2330 - accuracy: 0.9160 - val_loss: 0.3047 - val_accuracy: 0.8906
Epoch 30/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.2291 - accuracy: 0.9167 - val_loss: 0.2969 - val_accuracy: 0.8938
'''

这里设置了30次循环,其实未必达到最优,也基本不会过拟合。

如果训练集是有偏的,比如某些类overrepresented,某些类underrepresented,那么在fit()前应该设置class_weight,给underrepresented类以更大的权重,overrepresented类以更小的权重。如果有的case需要格外注意,比如某些cases是专家标注,另外一些是普通标注的,那么可以用per-instance weights,即设置sample_weight。如果class_weight和sample_weight都设置了,keras会把它们相乘。另外,也可以为验证集单独设置sample weights。

另外,history.history是个字典,里面包含loss,accuracy,val_loss, val_accuracy(每个都是epochs个数据),所以可以画图:
(就是把上面打印的信息以图的方式反映出来)

import pandas as pd
import matplotlib.pyplot as plt

print(history.history)
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
# 'gca'代表Get Current Axis
plt.gca().set_ylim(0, 1) 
plt.show()

在这里插入图片描述
为什么在前几个epoch上,validation的结果看起来比train好?
因为validation error是在每个epoch结束时计算的,而training error是一个running mean,即在每个epoch运行时计算的,所以training的图像应该移半个epoch。此时前几个epoch的图像应该是比较近似的,甚至overlap。

等调完超参数(如learning rate, layer num, batch_size…),评估一下模型:

print(model.evaluate(X_test, y_test))
'''
loss, accuracy
[0.34293538331985474, 0.8896999955177307]
'''

注意:如果loss很大可能是X_test没有归一化。

保存模型:其他保存模型的方法:https://blog.csdn.net/qq_22841387/article/details/130194553

import joblib
joblib.dump(model, "my_model.joblib")

# 导入使用load,导入后可以使用新数据继续训练 model.fit()
model = joblib.load("my_model.joblib")

预测

X_new = X_test[:3]
y_proba = model.predict(X_new)
print(y_proba.round(2))
'''
[[0.   0.   0.   0.   0.   0.   0.   0.   0.   1.  ]
 [0.   0.   0.99 0.   0.01 0.   0.   0.   0.   0.  ]
 [0.   1.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
'''

import numpy as np
y_pred = model.predict(X_new)
# tf 2.6前可以用model.predict_classes(X_new), 2.6开始删除了该函数
labels = np.argmax(y_pred, axis=1)
print(labels) # [9 2 1]
# 显示分类名称
print(np.array(class_names)[labels]) # ['Ankle boot' 'Pullover' 'Trouser']

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/612239.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux中的ACL以及加固

ACL访问控制 // ACL:Access Control List 访问控制列表 // -p :以原始格式显示 ACL [rootzbx ~]# getfacl -p /root/ // 查看ACL策略 # file: /root/ # owner: root # group: root user::r-x group::r-x other::--- 设置ACL策略 // -m : 修改文件的ACL // -b : 表示删除所有的…

RedisGraph的图存储模型

1 overview 在RedisGraph的整体架构中,非常简略的概括了RedisGraph的图存储模型: RedisGraph使用DataBlock来存储node和edge的属性。RedisGraph使用稀疏矩阵来表示图,稀疏矩阵的存储格式为按行压缩的稀疏矩阵(Compressed Sparse…

同浏览器下多窗口进行跨源通信、同源通信

同浏览器下多窗口进行跨源通信、同源通信 多页面通信运用到了“发布订阅”的设计模式,一个页面发布指令,其他页面进行订阅并进行相应的行为操作! 一、跨源通信 window.postMessage() window.postMessage() 方法可以安全地实现跨源通信。通常…

Qt6之调用Windows下vc生成的动态链接库dll

Qt是跨平台工具,显然能和windows的动态库一起使用。 在Windows操作系统上,库以文件的形式存在,并且可以分为动态链接库(DLL) 和静态链接库两种。动态链接库文控以.dll为后缀名,静态链接库文控以.lib为后缀名。不管是动态链接库还是…

独立站卖家如何应对PayPal风险?3大策略教你安全收款!

PayPal是全球风险控制做得最好的第三方在线支付平台,PayPal付款是钱直接到卖家PayPal账户。但随着外贸交易的日益发展,恶意买家的问题也越来越多。如何防范风险,保证收款安全,成为独立站卖家们所关注的问题。下面为大家分享三种策…

背包DP-入门篇

目录 01背包: 完全背包: 多重背包: 分组背包: 01背包: [NOIP2005 普及组] 采药 - 洛谷https://www.luogu.com.cn/problem/P1048 01背包背景 在一个小山上,有个n个黄金和一个容量为w的背包,…

【Python】深度理解Class类、Object类、Type元类的概念和关系

深度理解Class类、Object类、Type元类 1.Class类、Object类、Type元类的表面关系2.Class、Object、Type解释3.关系详解4.那么如何看待object、type在Python面对对象概念中的一席之地呢?5.那么object、type扮演了什么样的角色呢?他们对class又分别做了什么…

【计算机组成】Cache与CPU的直接映射、全相联映射与组相联映射

一.Cache与CPU需要映射的原因 CPU准备访问内存时,会先问问cache存储器有没有已经提前准备好了数据,如果没有则再找内存要: 如果Cache刚好命中,则直接从Cache中读取数据: 如果Cache没有命中(Cache失效&#…

时序数据库InfluxDB快速入门使用

推荐博客: Influxdb中文文档 linux安装influxdb Influxdb安装、启动influxdb控制台、常用命令、Influx命令使用、Influx-sql使用举例、Influxdb的数据格式、Influxdb客户端工具 1.安装 1、influxdb数据库官网的下载链接: https://portal.influxdata.c…

如何利用MES系统进行生产防呆防错?

一、认识MES系统的防呆防错功能 首先,我们要清楚了解,什么是MES系统的防呆防错。MES系统防呆防错是指利用MES系统来避免生产过程中的错误和缺陷,保障生产排程和生产过程顺利进行的过程。MES系统防呆防错包括以下方面: 1. 自动识别…

relation-graph关系图谱组件2.0版本遇到的问题

前提:之前已经写过一篇1.1版本的问题,这里就不过多讲了(如果想要解决火狐低版本兼容,看那个就行) 这次主要讲的是和1.X版本的区别和一些其它问题 区别 参数名不同:以前的links>lines (虽然现在links也…

遇见未来,降低职场焦虑——中国人民大学与加拿大女王大学金融硕士来助力

身在职场的你有感到一丝丝的焦虑吗?偶尔的小焦虑可以作为我们工作中的动力,时刻提醒我们保持奋进。预见未来才能遇见未来,随着社会经济不断发展,没有什么是一成不变的。处于职场上升期的我们更要懂得未雨绸缪,增加自身…

ClickHouse集群搭建总结

简介 ClickHouse是俄罗斯最大的搜素引擎Yandex于2016年开源的列式数据库管理系统,使用C 语言编写, 主要应用于OLAP场景。 使用理由 在大数据量的情况下,能以很低的延迟返回查询结果。 笔者注: 在单机亿级数据量的场景下可以达到毫秒级的查询…

SpringCloudAlibaba 微服务生态

一 微服务架构 1.1 微服务 微服务其实是一种架构风格,我们在开发一个应用的时候这个应用应该是由一组小型服务组成,每个小型服务都运行在自己的进程内;小服务之间通过HTTP的方式进行互联互通。 1.2 微服务架构的常见问题 一旦采用微服务系…

ChatGPT 之后,B 端产品设计会迎来颠覆式革命吗?| Liga妙谈

近日,脑机接口公司 Neuralink 宣布,其植入式脑机接口设备首次人体临床研究已被准许启动。遥想当年,我们还嘲讽罗老师「动嘴做 PPT」,谁曾想不久后我们可能连嘴都不用动🙊。 脑机接口何时会引爆人机交互革命尚未可知&a…

简述三观;

文章目录 三观世界观人生观价值观三观不合怎么看三观不正: 教养育儿教育心智不成熟的表现 三观 指人生观,世界观和价值观; https://wenku.baidu.com/view/102a655fd4bbfd0a79563c1ec5da50e2534dd1d8.html?fraladdin664466&ind1&_wkts_1685949448098&…

深入理解API网关Kong:动态负载均衡配置

深入理解API网关Kong:动态负载均衡配置 背景 在 NGINX 中,负载均衡的配置主要在 upstream 指令中进行。upstream 指令用于定义一个服务器群组和负载均衡方法。客户端请求在这个服务器群组中进行分发。 NGINX 提供了以下几种负载均衡方法: …

python接口自动化 —— 什么是接口、接口优势、类型(详解)

简介 经常听别人说接口测试,接口测试自动化,但是你对接口,有多少了解和认识,知道什么是接口吗?它是用来做什么的,测试时候要注意什么?坦白的说,笔者之前也不是很清楚。接下来先看一下…

从简历被拒到收割 8 个高薪 offer,我用了 3 个月...

半年前我一个小老弟从外包离职了,本以为有两年经验进个一般的公司没有问题的,结果人家一看是外包出来的,面试问的问题也不是很懂,简历被拒了好几次。还好这个小老弟没有气馁,在论坛博客和里面的大佬虚心学习&#xff0…

地震勘探基础(八)之地震动校正

地震动校正 在地震资料数字处理过程中,速度分析,动校正和水平叠加三个处理内容是相互关联的。水平叠加是为了提高地震资料的信噪比,要想得到好的叠加效果,必须做好动校正。而做好动校正,需要进行准确的速度分析。只有…