- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊 | 接辅导、项目定制
一、前期准备
1. 导入数据
# Import the required libraries
import pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Model
from matplotlib.ticker import MultipleLocator
from datetime import datetime
# Load the data
data_dir = './data/48-data/'
data_dir = pathlib.Path(data_dir)
data_paths = list(data_dir.glob('*'))


def list_class_names(paths):
    """Return the sorted names of the directories among *paths*.

    Each sub-directory of the data root is one class. Using ``Path.name``
    (instead of splitting the string on a hard-coded Windows ``\\``
    separator) works on every OS and at any directory depth. Plain files
    are skipped, and the result is sorted alphanumerically.

    Args:
        paths: Iterable of ``pathlib.Path`` objects (children of the data root).

    Returns:
        A list of directory names, sorted.
    """
    return sorted(path.name for path in paths if path.is_dir())


classeNames = list_class_names(data_paths)
# Every file one level below a class directory is counted as an image.
image_count = len(list(data_dir.glob('*/*')))
print("Total number of images:", image_count)

二、数据预处理
1. 加载数据
# Data loading and preprocessing
batch_size = 16
img_height = 336
img_width = 336


def load_split(subset):
    """Load one side of the 80/20 train/validation split of ``data_dir``.

    Both subsets are requested with the same ``seed`` so Keras partitions
    the files identically for each call.

    Args:
        subset: ``"training"`` or ``"validation"``.

    Returns:
        A batched ``tf.data.Dataset`` of (image, label) pairs.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset=subset,
        seed=12,
        image_size=(img_height, img_width),
        batch_size=batch_size,
    )


train_ds = load_split("training")
val_ds = load_split("validation")

# Class names are inferred by Keras from the sub-directory names.
class_names = train_ds.class_names
print(class_names)
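['Angelina Jolie', 'Brad Pitt', 'Denzel Washington', 'Hugh Jackman', 'Jennifer Lawrence', 'Johnny Depp', 'Kate Winslet', 'Leonardo DiCaprio', 'Megan Fox', 'Natalie Portman', 'Nicole Kidman', 'Robert Downey Jr', 'Sandra Bullock', 'Scarlett Johansson', 'Tom Cruise', 'Tom Hanks', 'Will Smith']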
2. 检查数据
# Check the shape of the data: only the first batch is inspected.
for image_batch, labels_batch in train_ds.take(1):
    print(image_batch.shape)
    print(labels_batch.shape)

3. 配置数据集
AUTOTUNE = tf.data.AUTOTUNE


def train_preprocessing(image, label):
    """Scale pixel values by 1/255 and pass the label through unchanged.

    Used for both the training and the validation pipelines.

    Args:
        image: Image (or batch of images) as produced by
            ``image_dataset_from_directory``.
        label: The matching label(s); returned as-is.

    Returns:
        ``(image / 255.0, label)``
    """
    # NOTE(review): the Keras VGG16 docs specify
    # tf.keras.applications.vgg16.preprocess_input for inputs to the
    # ImageNet weights; this pipeline uses a plain 1/255 rescale instead.
    # Confirm this is intended.
    return (image / 255.0, label)


# Rescale *before* caching so the division is computed once, not every
# epoch. Let tf.data run the map in parallel. No .batch() here:
# image_dataset_from_directory already batched the data (batch_size).
train_ds = (
    train_ds.map(train_preprocessing, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(1000)  # dataset is already batched, so this shuffles batch order
    .prefetch(buffer_size=AUTOTUNE)
)
# The validation set is only evaluated, so shuffling it has no effect on the
# metrics; it is left in a fixed order.
val_ds = (
    val_ds.map(train_preprocessing, num_parallel_calls=AUTOTUNE)
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)
4. 数据可视化
plt.rcParams['font.family'] = 'SimHei'  # SimHei font so the Chinese title renders
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly with this font
plt.figure(figsize=(10, 8))  # figure is 10 inches wide and 8 inches tall
plt.suptitle("数据展示")
for images, labels in train_ds.take(1):
    # Show up to 15 images; the final batch of a dataset may hold fewer.
    for i in range(min(15, len(images))):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        # Show the image
        plt.imshow(images[i])
        # Show its label. Labels are 0-based indices into class_names, so no
        # "- 1" here (that shifted every name and wrapped label 0 to the last class).
        plt.xlabel(class_names[int(labels[i])])
plt.show()

三、训练模型
1. 构建模型
def create_model(optimizer='adam'):
    """Build and compile a VGG16 transfer-learning classifier.

    An ImageNet-pretrained VGG16 convolutional base (with global average
    pooling, no top) is frozen. On top of it sits a trainable head
    Dense(170, relu) -> BatchNormalization -> Dropout(0.5) -> Dense(softmax)
    with one output unit per entry in ``class_names``.

    Args:
        optimizer: Anything ``Model.compile`` accepts as an optimizer, e.g.
            the name ``'adam'`` or an optimizer instance.

    Returns:
        The compiled ``Model``: sparse categorical cross-entropy loss,
        accuracy metric.
    """
    # Load the pretrained base. Keras expects input_shape as
    # (height, width, channels).
    vgg16_base_model = tf.keras.applications.vgg16.VGG16(
        weights='imagenet',
        include_top=False,
        input_shape=(img_height, img_width, 3),
        pooling='avg',
    )
    # Freeze the convolutional base so only the new head is trained.
    for layer in vgg16_base_model.layers:
        layer.trainable = False

    X = vgg16_base_model.output
    X = Dense(170, activation='relu')(X)
    X = BatchNormalization()(X)
    X = Dropout(0.5)(X)
    output = Dense(len(class_names), activation='softmax')(X)

    vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)
    # Sparse loss: labels are integer class indices (the default label_mode
    # of image_dataset_from_directory).
    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model
# Two otherwise identical models that differ only in their optimizer.
adam_optimizer = tf.keras.optimizers.Adam()
model1 = create_model(optimizer=adam_optimizer)
sgd_optimizer = tf.keras.optimizers.SGD()
model2 = create_model(optimizer=sgd_optimizer)
model2.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
58889256/58889256 [==============================] - 5s 0us/step
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 336, 336, 3)] 0
block1_conv1 (Conv2D) (None, 336, 336, 64) 1792
block1_conv2 (Conv2D) (None, 336, 336, 64) 36928
block1_pool (MaxPooling2D) (None, 168, 168, 64) 0
block2_conv1 (Conv2D) (None, 168, 168, 128) 73856
block2_conv2 (Conv2D) (None, 168, 168, 128) 147584
block2_pool (MaxPooling2D) (None, 84, 84, 128) 0
block3_conv1 (Conv2D) (None, 84, 84, 256) 295168
block3_conv2 (Conv2D) (None, 84, 84, 256) 590080
block3_conv3 (Conv2D) (None, 84, 84, 256) 590080
block3_pool (MaxPooling2D) (None, 42, 42, 256) 0
block4_conv1 (Conv2D) (None, 42, 42, 512) 1180160
block4_conv2 (Conv2D) (None, 42, 42, 512) 2359808
block4_conv3 (Conv2D) (None, 42, 42, 512) 2359808
block4_pool (MaxPooling2D) (None, 21, 21, 512) 0
block5_conv1 (Conv2D) (None, 21, 21, 512) 2359808
block5_conv2 (Conv2D) (None, 21, 21, 512) 2359808
block5_conv3 (Conv2D) (None, 21, 21, 512) 2359808
block5_pool (MaxPooling2D) (None, 10, 10, 512) 0
global_average_pooling2d_1 (None, 512) 0
(GlobalAveragePooling2D)
dense_2 (Dense) (None, 170) 87210
batch_normalization_1 (Bat (None, 170) 680
chNormalization)
dropout_1 (Dropout) (None, 170) 0
dense_3 (Dense) (None, 17) 2907
=================================================================
Total params: 14805485 (56.48 MB)
Trainable params: 90457 (353.35 KB)
Non-trainable params: 14715028 (56.13 MB)
_________________________________________________________________
2. 训练模型
# Train both models with identical settings so only the optimizer differs.
NO_EPOCHS = 50
fit_kwargs = dict(epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model1 = model1.fit(train_ds, **fit_kwargs)
history_model2 = model2.fit(train_ds, **fit_kwargs)
Epoch 1/50
90/90 [==============================] - 1202s 13s/step - loss: 0.3176 - accuracy: 0.8965 - val_loss: 7.7180 - val_accuracy: 0.0583
Epoch 2/50
90/90 [==============================] - 1090s 12s/step - loss: 0.2925 - accuracy: 0.9167 - val_loss: 7.4216 - val_accuracy: 0.0472
Epoch 3/50
90/90 [==============================] - 1296s 14s/step - loss: 0.3077 - accuracy: 0.9125 - val_loss: 8.2351 - val_accuracy: 0.0583
Epoch 4/50
90/90 [==============================] - 1302s 14s/step - loss: 0.2624 - accuracy: 0.9326 - val_loss: 8.9317 - val_accuracy: 0.0583
Epoch 5/50
90/90 [==============================] - 1040s 12s/step - loss: 0.2837 - accuracy: 0.9174 - val_loss: 9.0407 - val_accuracy: 0.0583
Epoch 6/50
90/90 [==============================] - 961s 11s/step - loss: 0.2769 - accuracy: 0.9139 - val_loss: 8.2484 - val_accuracy: 0.0583
Epoch 7/50
90/90 [==============================] - 950s 11s/step - loss: 0.2749 - accuracy: 0.9160 - val_loss: 8.8199 - val_accuracy: 0.0444
Epoch 8/50
90/90 [==============================] - 934s 10s/step - loss: 0.2525 - accuracy: 0.9292 - val_loss: 8.1721 - val_accuracy: 0.0722
Epoch 9/50
90/90 [==============================] - 1260s 14s/step - loss: 0.2306 - accuracy: 0.9361 - val_loss: 8.6387 - val_accuracy: 0.0583
Epoch 10/50
90/90 [==============================] - 1429s 16s/step - loss: 0.2448 - accuracy: 0.9208 - val_loss: 9.7182 - val_accuracy: 0.0583
Epoch 11/50
90/90 [==============================] - 1044s 12s/step - loss: 0.2269 - accuracy: 0.9299 - val_loss: 10.4608 - val_accuracy: 0.0583
Epoch 12/50
90/90 [==============================] - 1352s 15s/step - loss: 0.2121 - accuracy: 0.9333 - val_loss: 9.2537 - val_accuracy: 0.0472
Epoch 13/50
90/90 [==============================] - 1969s 22s/step - loss: 0.2014 - accuracy: 0.9368 - val_loss: 9.2780 - val_accuracy: 0.0722
Epoch 14/50
90/90 [==============================] - 1372s 15s/step - loss: 0.1803 - accuracy: 0.9486 - val_loss: 9.4223 - val_accuracy: 0.0583
Epoch 15/50
90/90 [==============================] - 1460s 16s/step - loss: 0.1795 - accuracy: 0.9535 - val_loss: 8.9366 - val_accuracy: 0.0583
Epoch 16/50
90/90 [==============================] - 1409s 16s/step - loss: 0.2325 - accuracy: 0.9215 - val_loss: 10.3105 - val_accuracy: 0.0472
Epoch 17/50
90/90 [==============================] - 1353s 15s/step - loss: 0.2212 - accuracy: 0.9271 - val_loss: 9.2342 - val_accuracy: 0.0583
Epoch 18/50
90/90 [==============================] - 1201s 13s/step - loss: 0.1793 - accuracy: 0.9500 - val_loss: 9.9170 - val_accuracy: 0.0472
Epoch 19/50
90/90 [==============================] - 929s 10s/step - loss: 0.1930 - accuracy: 0.9354 - val_loss: 9.9911 - val_accuracy: 0.0583
Epoch 20/50
90/90 [==============================] - 13115s 147s/step - loss: 0.2122 - accuracy: 0.9333 - val_loss: 9.5141 - val_accuracy: 0.0750
Epoch 21/50
90/90 [==============================] - 849s 9s/step - loss: 0.2142 - accuracy: 0.9319 - val_loss: 9.9998 - val_accuracy: 0.0472
Epoch 22/50
90/90 [==============================] - 806s 9s/step - loss: 0.1790 - accuracy: 0.9417 - val_loss: 9.0953 - val_accuracy: 0.0583
Epoch 23/50
90/90 [==============================] - 953s 11s/step - loss: 0.1722 - accuracy: 0.9486 - val_loss: 10.1111 - val_accuracy: 0.0583
Epoch 24/50
90/90 [==============================] - 1117s 12s/step - loss: 0.1824 - accuracy: 0.9368 - val_loss: 11.0077 - val_accuracy: 0.0472
Epoch 25/50
90/90 [==============================] - 1111s 12s/step - loss: 0.1613 - accuracy: 0.9514 - val_loss: 11.9721 - val_accuracy: 0.0472
Epoch 26/50
90/90 [==============================] - 1148s 13s/step - loss: 0.1641 - accuracy: 0.9556 - val_loss: 12.8058 - val_accuracy: 0.0472
Epoch 27/50
90/90 [==============================] - 1227s 14s/step - loss: 0.1286 - accuracy: 0.9590 - val_loss: 10.5750 - val_accuracy: 0.0472
Epoch 28/50
90/90 [==============================] - 1191s 13s/step - loss: 0.1791 - accuracy: 0.9493 - val_loss: 12.0891 - val_accuracy: 0.0472
Epoch 29/50
90/90 [==============================] - 1191s 13s/step - loss: 0.1629 - accuracy: 0.9493 - val_loss: 11.8981 - val_accuracy: 0.0472
Epoch 30/50
90/90 [==============================] - 1234s 14s/step - loss: 0.1545 - accuracy: 0.9479 - val_loss: 10.4402 - val_accuracy: 0.0472
Epoch 31/50
90/90 [==============================] - 956s 11s/step - loss: 0.1687 - accuracy: 0.9507 - val_loss: 8.6383 - val_accuracy: 0.0472
Epoch 32/50
90/90 [==============================] - 896s 10s/step - loss: 0.1470 - accuracy: 0.9528 - val_loss: 12.8927 - val_accuracy: 0.0472
Epoch 33/50
90/90 [==============================] - 901s 10s/step - loss: 0.1373 - accuracy: 0.9556 - val_loss: 10.4122 - val_accuracy: 0.0472
Epoch 34/50
90/90 [==============================] - 899s 10s/step - loss: 0.1428 - accuracy: 0.9521 - val_loss: 11.1399 - val_accuracy: 0.0750
Epoch 35/50
90/90 [==============================] - 878s 10s/step - loss: 0.1343 - accuracy: 0.9583 - val_loss: 12.0714 - val_accuracy: 0.0583
Epoch 36/50
90/90 [==============================] - 886s 10s/step - loss: 0.1432 - accuracy: 0.9535 - val_loss: 12.5365 - val_accuracy: 0.0583
Epoch 37/50
90/90 [==============================] - 863s 10s/step - loss: 0.1337 - accuracy: 0.9569 - val_loss: 10.0840 - val_accuracy: 0.0583
Epoch 38/50
90/90 [==============================] - 889s 10s/step - loss: 0.1632 - accuracy: 0.9514 - val_loss: 9.1576 - val_accuracy: 0.0722
Epoch 39/50
90/90 [==============================] - 881s 10s/step - loss: 0.1418 - accuracy: 0.9549 - val_loss: 14.8210 - val_accuracy: 0.0583
Epoch 40/50
90/90 [==============================] - 890s 10s/step - loss: 0.1690 - accuracy: 0.9514 - val_loss: 11.0727 - val_accuracy: 0.0472
Epoch 41/50
90/90 [==============================] - 870s 10s/step - loss: 0.1260 - accuracy: 0.9701 - val_loss: 10.9087 - val_accuracy: 0.0583
Epoch 42/50
90/90 [==============================] - 868s 10s/step - loss: 0.1620 - accuracy: 0.9417 - val_loss: 18.5777 - val_accuracy: 0.0583
Epoch 43/50
90/90 [==============================] - 885s 10s/step - loss: 0.1554 - accuracy: 0.9444 - val_loss: 16.1502 - val_accuracy: 0.0583
Epoch 44/50
90/90 [==============================] - 861s 10s/step - loss: 0.1444 - accuracy: 0.9472 - val_loss: 11.4246 - val_accuracy: 0.0583
Epoch 45/50
90/90 [==============================] - 891s 10s/step - loss: 0.1707 - accuracy: 0.9479 - val_loss: 9.7772 - val_accuracy: 0.0472
Epoch 46/50
90/90 [==============================] - 871s 10s/step - loss: 0.1733 - accuracy: 0.9368 - val_loss: 11.6579 - val_accuracy: 0.0472
Epoch 47/50
90/90 [==============================] - 867s 10s/step - loss: 0.1455 - accuracy: 0.9521 - val_loss: 10.5239 - val_accuracy: 0.0722
Epoch 48/50
90/90 [==============================] - 886s 10s/step - loss: 0.1527 - accuracy: 0.9472 - val_loss: 12.6337 - val_accuracy: 0.0583
Epoch 49/50
90/90 [==============================] - 894s 10s/step - loss: 0.1689 - accuracy: 0.9451 - val_loss: 13.6906 - val_accuracy: 0.0583
Epoch 50/50
90/90 [==============================] - 882s 10s/step - loss: 0.1434 - accuracy: 0.9458 - val_loss: 11.2179 - val_accuracy: 0.0583
Epoch 1/50
90/90 [==============================] - 914s 10s/step - loss: 3.0652 - accuracy: 0.1132 - val_loss: 2.8820 - val_accuracy: 0.0417
Epoch 2/50
90/90 [==============================] - 855s 10s/step - loss: 2.4852 - accuracy: 0.2215 - val_loss: 2.9252 - val_accuracy: 0.0444
Epoch 3/50
90/90 [==============================] - 856s 10s/step - loss: 2.2494 - accuracy: 0.2639 - val_loss: 3.0725 - val_accuracy: 0.0417
Epoch 4/50
90/90 [==============================] - 865s 10s/step - loss: 2.0995 - accuracy: 0.3368 - val_loss: 3.3332 - val_accuracy: 0.0417
Epoch 5/50
90/90 [==============================] - 859s 10s/step - loss: 1.9039 - accuracy: 0.3833 - val_loss: 3.5608 - val_accuracy: 0.0444
Epoch 6/50
90/90 [==============================] - 871s 10s/step - loss: 1.7996 - accuracy: 0.4236 - val_loss: 4.3610 - val_accuracy: 0.0417
Epoch 7/50
90/90 [==============================] - 868s 10s/step - loss: 1.6905 - accuracy: 0.4313 - val_loss: 4.8573 - val_accuracy: 0.0417
Epoch 8/50
90/90 [==============================] - 875s 10s/step - loss: 1.6161 - accuracy: 0.4750 - val_loss: 5.4109 - val_accuracy: 0.0417
Epoch 9/50
90/90 [==============================] - 855s 10s/step - loss: 1.5523 - accuracy: 0.4889 - val_loss: 5.2799 - val_accuracy: 0.0417
Epoch 10/50
90/90 [==============================] - 855s 10s/step - loss: 1.4717 - accuracy: 0.5312 - val_loss: 5.2821 - val_accuracy: 0.0417
Epoch 11/50
90/90 [==============================] - 888s 10s/step - loss: 1.4668 - accuracy: 0.5257 - val_loss: 5.5069 - val_accuracy: 0.0417
Epoch 12/50
90/90 [==============================] - 890s 10s/step - loss: 1.3670 - accuracy: 0.5639 - val_loss: 5.6636 - val_accuracy: 0.0417
Epoch 13/50
90/90 [==============================] - 861s 10s/step - loss: 1.3412 - accuracy: 0.5618 - val_loss: 5.5362 - val_accuracy: 0.0417
Epoch 14/50
90/90 [==============================] - 885s 10s/step - loss: 1.2694 - accuracy: 0.5931 - val_loss: 5.9473 - val_accuracy: 0.0417
Epoch 15/50
90/90 [==============================] - 882s 10s/step - loss: 1.2464 - accuracy: 0.6062 - val_loss: 6.1568 - val_accuracy: 0.0417
Epoch 16/50
90/90 [==============================] - 890s 10s/step - loss: 1.1958 - accuracy: 0.6306 - val_loss: 5.9811 - val_accuracy: 0.0417
Epoch 17/50
90/90 [==============================] - 881s 10s/step - loss: 1.1817 - accuracy: 0.6257 - val_loss: 5.8977 - val_accuracy: 0.0417
Epoch 18/50
90/90 [==============================] - 885s 10s/step - loss: 1.1527 - accuracy: 0.6354 - val_loss: 5.9559 - val_accuracy: 0.0472
Epoch 19/50
90/90 [==============================] - 870s 10s/step - loss: 1.0981 - accuracy: 0.6507 - val_loss: 6.1796 - val_accuracy: 0.0417
Epoch 20/50
90/90 [==============================] - 873s 10s/step - loss: 1.0697 - accuracy: 0.6667 - val_loss: 5.8840 - val_accuracy: 0.0417
Epoch 21/50
90/90 [==============================] - 901s 10s/step - loss: 1.0661 - accuracy: 0.6646 - val_loss: 6.1797 - val_accuracy: 0.0472
Epoch 22/50
90/90 [==============================] - 879s 10s/step - loss: 0.9922 - accuracy: 0.6903 - val_loss: 6.2074 - val_accuracy: 0.0417
Epoch 23/50
90/90 [==============================] - 876s 10s/step - loss: 0.9992 - accuracy: 0.6806 - val_loss: 5.4473 - val_accuracy: 0.0417
Epoch 24/50
90/90 [==============================] - 905s 10s/step - loss: 0.9279 - accuracy: 0.7069 - val_loss: 5.5743 - val_accuracy: 0.0417
Epoch 25/50
90/90 [==============================] - 894s 10s/step - loss: 0.9319 - accuracy: 0.7118 - val_loss: 6.1316 - val_accuracy: 0.0472
Epoch 26/50
90/90 [==============================] - 927s 10s/step - loss: 0.8869 - accuracy: 0.7222 - val_loss: 6.0186 - val_accuracy: 0.0472
Epoch 27/50
90/90 [==============================] - 893s 10s/step - loss: 0.9086 - accuracy: 0.7118 - val_loss: 6.8811 - val_accuracy: 0.0417
Epoch 28/50
90/90 [==============================] - 877s 10s/step - loss: 0.8965 - accuracy: 0.7118 - val_loss: 6.9371 - val_accuracy: 0.0472
Epoch 29/50
90/90 [==============================] - 912s 10s/step - loss: 0.9026 - accuracy: 0.7194 - val_loss: 6.2633 - val_accuracy: 0.0417
Epoch 30/50
90/90 [==============================] - 906s 10s/step - loss: 0.8067 - accuracy: 0.7535 - val_loss: 6.3067 - val_accuracy: 0.0472
Epoch 31/50
90/90 [==============================] - 900s 10s/step - loss: 0.7955 - accuracy: 0.7556 - val_loss: 6.1450 - val_accuracy: 0.0472
Epoch 32/50
90/90 [==============================] - 918s 10s/step - loss: 0.7941 - accuracy: 0.7486 - val_loss: 6.2223 - val_accuracy: 0.0472
Epoch 33/50
90/90 [==============================] - 1473s 16s/step - loss: 0.7692 - accuracy: 0.7667 - val_loss: 6.2006 - val_accuracy: 0.0528
Epoch 34/50
90/90 [==============================] - 1436s 16s/step - loss: 0.7648 - accuracy: 0.7514 - val_loss: 6.1662 - val_accuracy: 0.0472
Epoch 35/50
90/90 [==============================] - 1386s 15s/step - loss: 0.7358 - accuracy: 0.7722 - val_loss: 6.1199 - val_accuracy: 0.0417
Epoch 36/50
90/90 [==============================] - 1033s 11s/step - loss: 0.7337 - accuracy: 0.7604 - val_loss: 6.4092 - val_accuracy: 0.0472
Epoch 37/50
90/90 [==============================] - 897s 10s/step - loss: 0.7166 - accuracy: 0.7743 - val_loss: 7.1209 - val_accuracy: 0.0472
Epoch 38/50
90/90 [==============================] - 897s 10s/step - loss: 0.6971 - accuracy: 0.7910 - val_loss: 6.5154 - val_accuracy: 0.0417
Epoch 39/50
90/90 [==============================] - 874s 10s/step - loss: 0.6958 - accuracy: 0.7833 - val_loss: 6.9477 - val_accuracy: 0.0472
Epoch 40/50
90/90 [==============================] - 1045s 12s/step - loss: 0.6516 - accuracy: 0.8049 - val_loss: 6.6442 - val_accuracy: 0.0472
Epoch 41/50
90/90 [==============================] - 1187s 13s/step - loss: 0.6481 - accuracy: 0.7903 - val_loss: 6.5062 - val_accuracy: 0.0472
Epoch 42/50
90/90 [==============================] - 975s 11s/step - loss: 0.6312 - accuracy: 0.8021 - val_loss: 6.6628 - val_accuracy: 0.0583
Epoch 43/50
90/90 [==============================] - 887s 10s/step - loss: 0.6247 - accuracy: 0.8042 - val_loss: 6.5811 - val_accuracy: 0.0417
Epoch 44/50
90/90 [==============================] - 898s 10s/step - loss: 0.6188 - accuracy: 0.7951 - val_loss: 6.3517 - val_accuracy: 0.0583
Epoch 45/50
90/90 [==============================] - 894s 10s/step - loss: 0.6151 - accuracy: 0.8139 - val_loss: 7.5465 - val_accuracy: 0.0583
Epoch 46/50
90/90 [==============================] - 911s 10s/step - loss: 0.5698 - accuracy: 0.8271 - val_loss: 7.7967 - val_accuracy: 0.0583
Epoch 47/50
90/90 [==============================] - 904s 10s/step - loss: 0.5727 - accuracy: 0.8188 - val_loss: 7.2678 - val_accuracy: 0.0417
Epoch 48/50
90/90 [==============================] - 887s 10s/step - loss: 0.5595 - accuracy: 0.8167 - val_loss: 7.5204 - val_accuracy: 0.0583
Epoch 49/50
90/90 [==============================] - 874s 10s/step - loss: 0.5318 - accuracy: 0.8299 - val_loss: 7.6148 - val_accuracy: 0.0583
Epoch 50/50
90/90 [==============================] - 1299s 15s/step - loss: 0.5296 - accuracy: 0.8313 - val_loss: 6.7918 - val_accuracy: 0.0417
四、模型评估
1. Loss与Accuracy图
plt.rcParams['savefig.dpi'] = 300  # DPI used when saving figures
plt.rcParams['figure.dpi'] = 300  # DPI used when displaying figures
# Timestamp shown under the accuracy panel (the course check-in requires a
# timestamp; screenshots without it are not accepted).
current_time = datetime.now()

hist1 = history_model1.history  # Adam run
hist2 = history_model2.history  # SGD run
acc1, acc2 = hist1['accuracy'], hist2['accuracy']
val_acc1, val_acc2 = hist1['val_accuracy'], hist2['val_accuracy']
loss1, loss2 = hist1['loss'], hist2['loss']
val_loss1, val_loss2 = hist1['val_loss'], hist2['val_loss']
epochs_range = range(len(acc1))

# One entry per panel: (subplot index, (curve, legend label) pairs,
# legend position, panel title).
panels = (
    (1,
     ((acc1, 'Training Accuracy-Adam'),
      (acc2, 'Training Accuracy-SGD'),
      (val_acc1, 'Validation Accuracy-Adam'),
      (val_acc2, 'Validation Accuracy-SGD')),
     'lower right',
     'Training and Validation Accuracy'),
    (2,
     ((loss1, 'Training Loss-Adam'),
      (loss2, 'Training Loss-SGD'),
      (val_loss1, 'Validation Loss-Adam'),
      (val_loss2, 'Validation Loss-SGD')),
     'upper right',
     'Training and Validation Loss'),
)

plt.figure(figsize=(16, 4))
for panel_index, curves, legend_loc, panel_title in panels:
    plt.subplot(1, 2, panel_index)
    for values, curve_label in curves:
        plt.plot(epochs_range, values, label=curve_label)
    plt.legend(loc=legend_loc)
    plt.title(panel_title)
    if panel_index == 1:
        plt.xlabel(current_time)
    # One x-axis tick per epoch.
    ax = plt.gca()
    ax.xaxis.set_major_locator(MultipleLocator(1))
plt.show()

2. 评估模型
def test_accuracy_report(model, dataset=None):
    """Evaluate *model* and print its loss and accuracy.

    Args:
        model: A compiled Keras model whose ``evaluate`` returns
            ``[loss, accuracy]`` (as models built by ``create_model`` do).
        dataset: Data to evaluate on. When omitted, the global ``val_ds``
            is used (the previous, hard-wired behaviour).

    Returns:
        The ``[loss, accuracy]`` list from ``model.evaluate``, so callers
        can use the numbers rather than only read the printout.
    """
    if dataset is None:
        dataset = val_ds
    score = model.evaluate(dataset, verbose=0)
    print('Loss function: %s, accuracy:' % score[0], score[1])
    return score
# Report validation loss/accuracy for both optimizers (model1 = Adam,
# model2 = SGD) so the comparison is available as numbers, not only as curves.
for trained_model in (model1, model2):
    test_accuracy_report(trained_model)
Loss function: 6.791763782501221, accuracy: 0.0416666679084301