# 1、查看 TensorFlow 版本 (1. Check the TensorFlow version and visible devices)
import tensorflow as tf

# Report which TensorFlow build is installed and which physical devices
# (CPU/GPU/...) it can see, so the training run's hardware is on record.
print(f'Tensorflow Version:{tf.__version__}')
print(tf.config.list_physical_devices())
# 2、加载 Fashion-MNIST 数据并预处理 (2. Load and preprocess the Fashion-MNIST data)
import numpy as np

# Load Fashion-MNIST as (images, labels) pairs for the training and test splits.
#   train_images: (60000, 28, 28)   train_labels: (60000,)
#   test_images:  (10000, 28, 28)   test_labels:  (10000,)
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.fashion_mnist.load_data()
)

# Scale pixel intensities from the 0-255 integer range to float32 in [0, 1];
# unscaled inputs make gradient-based training slower and less stable.
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

# Conv2D expects a trailing channel axis: (count, height, width) ->
# (count, height, width, channels) = (N, 28, 28, 1). Apply it to the test
# split as well so both splits match the model's input shape.
train_images = np.expand_dims(train_images, -1)
test_images = np.expand_dims(test_images, -1)
# 3、构建 CNN 模型 (3. Build the CNN model)
# Take every layer class from tf.keras so the layers and the tf.keras.Sequential
# container below come from the same Keras implementation; mixing the standalone
# `keras` package with `tf.keras` can fail when they resolve to different Keras
# installs. The same names as before are bound so later code keeps working.
Input = tf.keras.layers.Input
Dense = tf.keras.layers.Dense
Dropout = tf.keras.layers.Dropout
Conv2D = tf.keras.layers.Conv2D
MaxPool2D = tf.keras.layers.MaxPool2D
GlobalAvgPool2D = tf.keras.layers.GlobalAvgPool2D

model = tf.keras.Sequential()
model.add(Input(shape=(28, 28, 1)))  # same as train_images.shape[1:]

# Stage 1: two 3x3 convolutions with 64 filters (more filters -> more capacity
# to fit the data); 'same' padding keeps the 28x28 spatial size.
model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu', padding='same'))
model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu', padding='same'))
model.add(MaxPool2D())  # default 2x2 pooling; widens each unit's field of view
model.add(Dropout(0.2))  # regularization against overfitting

# Stage 2: same pattern with 128 filters.
model.add(Conv2D(filters=128, kernel_size=(3, 3), activation='relu', padding='same'))
model.add(Conv2D(filters=128, kernel_size=(3, 3), activation='relu', padding='same'))
model.add(MaxPool2D())  # default 2x2 pooling
model.add(Dropout(0.2))  # regularization against overfitting

# Final convolution uses Keras' default 'valid' padding (no padding argument).
model.add(Conv2D(filters=256, kernel_size=(3, 3), activation='relu'))
model.add(GlobalAvgPool2D())  # global average pooling: one value per channel
model.add(Dense(10, activation='softmax'))  # probabilities over the 10 classes
model.summary()
# 4、模型配置与训练 (4. Compile and train the model)
# Sparse categorical cross-entropy consumes the integer class labels directly
# (no one-hot encoding); the 'acc' metric is what the plots read back from
# H.history['acc'] / H.history['val_acc'].
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['acc'],
)

# Hold out 20% of the training arrays as a validation set, evaluated after
# each epoch; H keeps the per-epoch loss/metric history.
H = model.fit(
    x=train_images,
    y=train_labels,
    batch_size=64,
    epochs=10,
    validation_split=0.2,
    verbose=1,
)
# 5、损失函数和准确率分析 (5. Analyze the loss and accuracy curves)
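# 根据损失函数和准确率曲线，判断模型是否过拟合或欠拟合，并据此不断调整网络结构，使模型达到最优。
# (Use the loss and accuracy curves to judge whether the model overfits or underfits,
#  then keep adjusting the network architecture until it performs best.)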
import matplotlib.pyplot as plt

# Plot training vs. validation curves side by side (left: loss, right: accuracy).
# Validation loss rising while training loss keeps falling points to overfitting;
# both staying poor points to underfitting.
fig = plt.gcf()
fig.set_size_inches(12, 4)
for position, metric in enumerate(('loss', 'acc'), start=1):
    plt.subplot(1, 2, position)
    plt.plot(H.epoch, H.history[metric], label=metric)
    plt.plot(H.epoch, H.history['val_' + metric], label='val_' + metric)
    plt.xlabel('epoch')
    plt.legend()
    plt.title(metric)
plt.tight_layout()
# Without this the figure is never displayed when the file runs as a plain script.
plt.show()