✨ 博客主页:小小马车夫的主页
✨ 所属专栏:Tensorflow
文章目录
- 前言
- 一、环境
- 二、fashion_mnist数据集介绍
- 三、fashion_mnist数据集下载和展示
- 四、数据预处理
- 五、构建模型和训练模型
- 六、模型预测
- 总结
前言
前面介绍mnist手写数字集训练,本文对fashion_mnist
数据集训练和预测进行简要介绍。
一、环境
MacOS: 13.0
Python: 3.9.13
Tensorflow: 2.11.0
二、fashion_mnist数据集介绍
fashion_mnist数据集和mnist数据集类似,都是28x28的灰度图片,区分是fashion_mnist数据集是服装图片,具体分类如下图:
分类 | 英文描述 | 中文描述 |
---|---|---|
0 | t-shirt | T恤 |
1 | trouser | 牛仔裤 |
2 | pullover | 套衫 |
3 | dress | 裙子 |
4 | coat | 外套 |
5 | sandal | 凉鞋 |
6 | shirt | 衬衫 |
7 | sneaker | 运动鞋 |
8 | bag | 包 |
9 | ankle boot | 短靴 |
三、fashion_mnist数据集下载和展示
运用tensorflow下载fashion_mnist数据集与mnist类似,代码如下:
import tensorflow as tf
from tensorflow import keras
import numpy as np
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(train_images.shape, train_labels.shape)
print(test_images.shape, test_labels.shape)
输出:
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
可以看到训练集是60000张28x28的灰度图,测试集是10000张28x28的灰度图。
一些样例展示如下:
四、数据预处理
数据预处理主要是对图片归一化处理,如下:
train_images=train_images / 255.
test_images = test_images / 255.
五、构建模型和训练模型
模型构建
model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape=(28, 28)))
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
model.add(keras.layers.Dense(10, activation=tf.nn.softmax))
model.summary()
模型结构如下:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) (None, 784) 0
dense (Dense) (None, 128) 100480
dense_1 (Dense) (None, 10) 1290
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
模型训练
class MyCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
#loss小于0.25就停止训练
if logs.get('loss') < 0.25:
self.model.stop_training = True
callbacks = MyCallback()
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.sparse_categorical_crossentropy, metrics=['acc'])
h = model.fit(train_images, train_labels, batch_size=32, epochs=15, validation_data=(test_images_scaled, test_labels), callbacks=[callbacks])
查看结果
Epoch 1/15
1875/1875 [==============================] - 11s 5ms/step - loss: 0.5031 - acc: 0.8239 - val_loss: 0.4201 - val_acc: 0.8499
Epoch 2/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3774 - acc: 0.8648 - val_loss: 0.4333 - val_acc: 0.8482
Epoch 3/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3371 - acc: 0.8773 - val_loss: 0.3662 - val_acc: 0.8667
Epoch 4/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.3145 - acc: 0.8845 - val_loss: 0.3697 - val_acc: 0.8667
Epoch 5/15
1875/1875 [==============================] - 10s 5ms/step - loss: 0.2929 - acc: 0.8921 - val_loss: 0.3404 - val_acc: 0.8794
Epoch 6/15
1875/1875 [==============================] - 10s 5ms/step - loss: 0.2805 - acc: 0.8958 - val_loss: 0.3453 - val_acc: 0.8793
Epoch 7/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2683 - acc: 0.9009 - val_loss: 0.3452 - val_acc: 0.8778
Epoch 8/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2566 - acc: 0.9032 - val_loss: 0.3370 - val_acc: 0.8820
Epoch 9/15
1875/1875 [==============================] - 9s 5ms/step - loss: 0.2480 - acc: 0.9065 - val_loss: 0.3482 - val_acc: 0.8789
用图标显示损失曲线和准确率曲线
loss_list = h.history['loss']
acc_list = h.history['acc']
test_loss_list = h.history['val_loss']
test_acc_list = h.history['val_acc']
plt.rcParams['font.sans-serif'] = ['Songti SC']
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(20, 10))
plt.subplot(221)
plt.ylabel('loss')
plt.plot(loss_list, color='blue', marker='.', label='train_loss')
plt.plot(test_loss_list, color='red', marker='.', label='val_loss')
plt.legend(loc='upper left')
plt.title('损失曲线', fontsize=16)
plt.subplot(222)
plt.ylabel('acc')
plt.plot(acc_list, color='blue', marker='.', label='train_acc')
plt.plot(test_acc_list, color='red', marker='.', label='val_acc')
plt.legend(loc='upper left')
plt.title('准确率曲线', fontsize=16)
plt.show()
输出:
六、模型预测
选一个图像进行预测:
image = tf.cast(test_images[1], tf.float32)
image = tf.reshape(image, [1, 28, 28])
np.argmax(model.predict(image))
print(test_labels[1])
plt.imshow(test_images[1])
plt.show()
输出:
1/1 [==============================] - 0s 408ms/step
2
总结
本文主要介绍了tensorflow fashion_mnist的下载、训练、预测,模型用的全连接网络。
如果觉得有些帮助或觉得文章还不错,请关注一下博主,你的关注是我持续写作的动力。另外,如果有什么问题,可以在评论区留言,或者私信博主,博主看到后会第一时间进行回复。
【间歇性的努力和蒙混过日子,都是对之前努力的清零】
欢迎转载,转载请注明出处:https://blog.csdn.net/xxm524/article/details/128160073