Build a simple convolutional neural network (CNN) with TensorFlow to classify images from the CIFAR-10 dataset.
Use an autoencoder as a feature extractor: first reduce the dimensionality of the image data with the autoencoder, mapping each image from the high-dimensional pixel space into a lower-dimensional feature space, then feed the extracted features into a classifier network.
Compare this against the performance of a CNN trained directly on the raw images, without autoencoder feature extraction; an explicit side-by-side evaluation sketch appears at the end of this section.
# Import the required libraries
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten
# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Preprocess the data: scale pixel values to [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0
# Build the autoencoder
input_img = Input(shape=(32, 32, 3))
# Encoder: one conv + pooling block; output feature maps have shape (16, 16, 16)
encoded = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
encoded = MaxPooling2D((2, 2), padding='same')(encoded)
# Decoder: upsample back to the original 32x32x3 resolution
decoded = Conv2D(16, (3, 3), activation='relu', padding='same')(encoded)
decoded = tf.keras.layers.UpSampling2D((2, 2))(decoded)
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(decoded)
autoencoder = Model(input_img, decoded)
# Compile the autoencoder
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
# Train the autoencoder to reconstruct its own input
autoencoder.fit(x_train, x_train,
                epochs=10,
                batch_size=128,
                validation_data=(x_test, x_test))
# Extract the encoder part of the autoencoder
encoder = Model(input_img, encoded)
# Use the encoder to extract features
x_train_encoded = encoder.predict(x_train)
x_test_encoded = encoder.predict(x_test)
# Build the classifier on top of the encoded features (flatten + dense layers)
input_features = Input(shape=x_train_encoded.shape[1:])
flatten = Flatten()(input_features)
dense1 = Dense(128, activation='relu')(flatten)
output = Dense(10, activation='softmax')(dense1)
classifier = Model(input_features, output)
# Compile the classifier
classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Train the classifier on the encoded features
classifier.fit(x_train_encoded, y_train,
               epochs=10,
               batch_size=128,
               validation_data=(x_test_encoded, y_test))
# Predict the class of the 1000th test sample, i.e. index 999 (assuming x_test is the test data)
prediction = classifier.predict(x_test_encoded[999:1000])
predicted_class = tf.argmax(prediction, axis=1).numpy()[0]
Epoch 1/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 7s 16ms/step - loss: 0.6075 - val_loss: 0.5594
Epoch 2/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 6s 16ms/step - loss: 0.5575 - val_loss: 0.5567
Epoch 3/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - loss: 0.5555 - val_loss: 0.5559
Epoch 4/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - loss: 0.5547 - val_loss: 0.5551
Epoch 5/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - loss: 0.5544 - val_loss: 0.5547
Epoch 6/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - loss: 0.5538 - val_loss: 0.5534
Epoch 7/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 6s 16ms/step - loss: 0.5520 - val_loss: 0.5527
Epoch 8/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - loss: 0.5519 - val_loss: 0.5525
Epoch 9/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - loss: 0.5517 - val_loss: 0.5523
Epoch 10/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - loss: 0.5506 - val_loss: 0.5522
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 821us/step
Epoch 1/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.2917 - loss: 1.9827 - val_accuracy: 0.4199 - val_loss: 1.6372
Epoch 2/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.4255 - loss: 1.6262 - val_accuracy: 0.4405 - val_loss: 1.5675
Epoch 3/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.4576 - loss: 1.5361 - val_accuracy: 0.4749 - val_loss: 1.5088
Epoch 4/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.4828 - loss: 1.4765 - val_accuracy: 0.5027 - val_loss: 1.4356
Epoch 5/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.4963 - loss: 1.4287 - val_accuracy: 0.5055 - val_loss: 1.4277
Epoch 6/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.5055 - loss: 1.3981 - val_accuracy: 0.5067 - val_loss: 1.4073
Epoch 7/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.5203 - loss: 1.3623 - val_accuracy: 0.5194 - val_loss: 1.3617
Epoch 8/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.5338 - loss: 1.3217 - val_accuracy: 0.5246 - val_loss: 1.3555
Epoch 9/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.5352 - loss: 1.3199 - val_accuracy: 0.5352 - val_loss: 1.3252
Epoch 10/10
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.5379 - loss: 1.3143 - val_accuracy: 0.5144 - val_loss: 1.3900
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 49ms/step
predicted_class
8
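For readability, the predicted index can be mapped back to a human-readable label. The list below follows the standard CIFAR-10 label order, so index 8 corresponds to 'ship'; a minimal sketch, reusing predicted_class from above:

# Standard CIFAR-10 label order; index 8 is 'ship'
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
print(class_names[predicted_class])  # prints 'ship' for a prediction of 8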
For comparison, the following example code uses Python with the TensorFlow library to build a convolutional neural network (CNN) that classifies the CIFAR-10 dataset directly from the raw images, and then retrieves the predicted class of the 1000th test sample:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np
# Load the CIFAR-10 dataset
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
# Normalize pixel values to the [0, 1] range
train_images, test_images = train_images / 255.0, test_images / 255.0
# Build a simple CNN model
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))  # Output layer: 10 logits, one per CIFAR-10 class
# Compile the model; from_logits=True because the output layer emits raw logits
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
# Train the model
model.fit(train_images, train_labels, epochs=10,
          validation_data=(test_images, test_labels))
# Run predictions on the test set
predictions = model.predict(test_images)
# Take the index of the largest output as the predicted class
predicted_classes = np.argmax(predictions, axis=1)
# Print the predicted class of the 1000th test sample (index 999)
print(predicted_classes[999])
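Note that the final Dense layer returns raw logits, which is why the loss is configured with from_logits=True. The argmax above is unaffected, but if calibrated class probabilities are wanted, a softmax has to be applied explicitly; a minimal sketch, reusing predictions from above:

# Convert logits to probabilities; softmax does not change the argmax
probabilities = tf.nn.softmax(predictions).numpy()
print(probabilities[999].max())  # confidence assigned to the predicted class of sample 999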
Epoch 1/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 6ms/step - accuracy: 0.3424 - loss: 1.7848 - val_accuracy: 0.5325 - val_loss: 1.2917
Epoch 2/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 6ms/step - accuracy: 0.5668 - loss: 1.2166 - val_accuracy: 0.6098 - val_loss: 1.1129
Epoch 3/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 10s 6ms/step - accuracy: 0.6391 - loss: 1.0275 - val_accuracy: 0.6399 - val_loss: 1.0059
Epoch 4/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 10s 6ms/step - accuracy: 0.6676 - loss: 0.9352 - val_accuracy: 0.6690 - val_loss: 0.9239
Epoch 5/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 6ms/step - accuracy: 0.7065 - loss: 0.8364 - val_accuracy: 0.6935 - val_loss: 0.8983
Epoch 6/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 6ms/step - accuracy: 0.7272 - loss: 0.7781 - val_accuracy: 0.6914 - val_loss: 0.8845
Epoch 7/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 6ms/step - accuracy: 0.7507 - loss: 0.7195 - val_accuracy: 0.6843 - val_loss: 0.9197
Epoch 8/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 6ms/step - accuracy: 0.7609 - loss: 0.6772 - val_accuracy: 0.7024 - val_loss: 0.8741
Epoch 9/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 6ms/step - accuracy: 0.7773 - loss: 0.6337 - val_accuracy: 0.7055 - val_loss: 0.8704
Epoch 10/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 6ms/step - accuracy: 0.7888 - loss: 0.5993 - val_accuracy: 0.7074 - val_loss: 0.8714
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step
8
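Comparing the two runs addresses the third task above: after 10 epochs, the classifier built on autoencoder features reaches roughly 51% validation accuracy, while the CNN trained end to end on raw images reaches roughly 71%. The gap is expected here: the encoder is trained purely for reconstruction and compresses each image into a single 16×16×16 feature map, discarding class-discriminative detail that the end-to-end CNN can learn to keep. The comparison can be made explicit with evaluate; a minimal sketch, assuming both models and the encoded test features from the snippets above are still in scope:

# Evaluate both pipelines on the held-out test set and report accuracy side by side
ae_loss, ae_acc = classifier.evaluate(x_test_encoded, y_test, verbose=0)
cnn_loss, cnn_acc = model.evaluate(test_images, test_labels, verbose=0)
print(f'Autoencoder features + dense classifier: {ae_acc:.4f}')
print(f'End-to-end CNN:                          {cnn_acc:.4f}')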