本例使用的卫星图像数据集包含两个类别:airplane 和 lake,每类各 700 张,均为 256×256 的 RGB 三通道彩色卫星影像。下面搭建一个深度卷积神经网络,实现卫星影像的二分类。
数据集下载:百度网盘(原文链接缺失,请补充),提取码:cq47
1、查看tensorflow版本
import tensorflow as tf
# Report the installed TensorFlow version and all physical devices TF can see.
version_msg = 'Tensorflow Version:{}'.format(tf.__version__)
print(version_msg)
print(tf.config.list_physical_devices())
2、加载并显示训练数据
从文件夹中获取所有数据路径
import glob
import random
all_image_path = glob.glob('./data/air_lake_dataset/*/*.jpg')  # glob is terser than pathlib here
if not all_image_path:
    # Fail early with a clear message: an empty list would otherwise only
    # surface later as an empty dataset / zero training steps.
    raise FileNotFoundError(
        "No .jpg images found under './data/air_lake_dataset/*/'; "
        "check the dataset path and the current working directory."
    )
# Shuffle once so the later take/skip train/test split mixes both classes.
random.shuffle(all_image_path)
读取并处理图像
def load_and_preprocess_image(path, size=(256, 256)):
    """Read a JPEG file and return it as a float32 image tensor scaled to [0, 1].

    Args:
        path: Path to the image file (a Python str or a scalar string tensor,
            so the function can be passed directly to ``Dataset.map``).
        size: Target ``(height, width)`` passed to ``tf.image.resize``;
            defaults to the 256x256 used throughout this script.

    Returns:
        A float32 tensor of shape ``(height, width, 3)`` with values in [0, 1].
    """
    img_raw = tf.io.read_file(path)
    # channels=3 forces a 3-channel RGB result.
    img_tensor = tf.image.decode_jpeg(img_raw, channels=3)
    img_tensor = tf.image.resize(img_tensor, size)
    # Explicit cast keeps the float32 guarantee regardless of resize method.
    img_tensor = tf.cast(img_tensor, tf.float32)
    return img_tensor / 255.0
处理标签
import os

# Class-folder name -> integer label, and the inverse for display/prediction.
label_to_index = {'airplane': 0, 'lake': 1}
index_to_label = {v: k for k, v in label_to_index.items()}


def _label_of(path):
    """Return the integer label of an image, taken from its parent folder name.

    os.path handles both '/' and the backslash-joined paths glob returns on
    Windows, and does not depend on how deep the dataset root is. An unknown
    folder raises KeyError here instead of silently producing None.
    """
    class_name = os.path.basename(os.path.dirname(path))
    return label_to_index[class_name]


labels = [_label_of(img) for img in all_image_path]
3、显示卫星影像
import matplotlib.pyplot as plt
def plot_images_lables(all_image_path, labels, start_idx, num=5):
    """Show ``num`` consecutive images with their class names in one row.

    Args:
        all_image_path: Sequence of image file paths.
        labels: Integer labels aligned index-by-index with ``all_image_path``.
        start_idx: Index of the first image to display.
        num: Number of images to display (default 5).
    """
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    # Slice by `num` (not a fixed count) so exactly the requested images are
    # loaded; near the end of the list fewer images are shown instead of
    # raising IndexError.
    paths = all_image_path[start_idx:start_idx + num]
    images = [load_and_preprocess_image(p) for p in paths]
    for i, image in enumerate(images):
        ax = plt.subplot(1, num, 1 + i)
        ax.imshow(image)
        title = 'label=' + index_to_label.get(labels[start_idx + i])
        ax.set_title(title, fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()


plot_images_lables(all_image_path, labels, 0, 5)
4、使用tf.data.Dataset制作训练/测试数据
制作 Dataset
# Build a dataset of (decoded image, integer label) pairs, element by element.
label_ds = tf.data.Dataset.from_tensor_slices(labels)
img_ds = tf.data.Dataset.from_tensor_slices(all_image_path).map(
    load_and_preprocess_image
)
img_label_ds = tf.data.Dataset.zip((img_ds, label_ds))
训练集、测试集划分
# Hold out 20% of the samples for testing; the image paths were shuffled
# before the dataset was built, so a take/skip split mixes both classes.
n_total = len(labels)
test_count = int(n_total * 0.2)
train_count = n_total - test_count
test_ds = img_label_ds.take(test_count)   # first 20%
train_ds = img_label_ds.skip(test_count)  # remaining 80%
分批次加载数据
BATCH_SIZE = 16
# Shuffle *before* repeat so every sample is seen exactly once per pass
# (clear epoch boundaries); repeat-then-shuffle mixes neighbouring passes.
# repeat() is needed because fit() uses explicit steps_per_epoch /
# validation_steps; prefetch overlaps input preprocessing with training.
train_ds = (
    train_ds.shuffle(100)
    .repeat()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
test_ds = test_ds.repeat().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
5、CNN模型构建
from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2D
def _conv_stack():
    """Return the layer list of the VGG-style binary classifier, in order."""
    stages = (64, 128, 256, 512)
    layers = [Input(shape=(256, 256, 3))]
    for depth, n_filters in enumerate(stages):
        # Two 3x3 'same'-padded conv layers per stage; more filters in deeper
        # stages increase the model's fitting capacity.
        for _ in range(2):
            layers.append(Conv2D(filters=n_filters, kernel_size=(3, 3),
                                 activation='relu', padding='same'))
        if depth < len(stages) - 1:
            # Max-pooling (2x2 by default) enlarges the receptive field;
            # dropout guards against overfitting. The last stage has neither.
            layers += [MaxPool2D(), Dropout(0.2)]
    # Global average pooling collapses the spatial dims before the dense head;
    # a single sigmoid unit outputs the probability of label 1.
    layers += [
        GlobalAvgPool2D(),
        Dense(1024, activation='relu'),
        Dense(256, activation='relu'),
        Dense(1, activation='sigmoid'),
    ]
    return layers


model = tf.keras.Sequential(_conv_stack())
model.summary()
6、模型编译与训练
# Whole batches per epoch for training and validation; the datasets are
# repeated, so Keras needs explicit step counts.
steps_per_epoch = train_count // BATCH_SIZE
val_step = test_count // BATCH_SIZE

# The output layer already applies sigmoid, so the loss receives
# probabilities rather than logits.
bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.0001),
    loss=bce,
    metrics=['acc'],
)

H = model.fit(
    train_ds,
    epochs=10,
    steps_per_epoch=steps_per_epoch,
    validation_data=test_ds,
    validation_steps=val_step,
    verbose=1,
)
7、模型评估(绘制训练/验证集的 loss 与 accuracy 曲线)
import matplotlib.pyplot as plt
fig = plt.gcf()
fig.set_size_inches(12, 4)
# Left panel: loss curves; right panel: accuracy curves (train vs. validation).
panels = (('loss', 'val_loss'), ('acc', 'val_acc'))
for position, (train_key, val_key) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(H.epoch, H.history[train_key], label=train_key)
    plt.plot(H.epoch, H.history[val_key], label=val_key)
    plt.legend()
    plt.title(train_key)
plt.show()
8、模型预测
def pred_img(img_path, threshold=0.5):
    """Predict the class name of a single image file.

    Args:
        img_path: Path to the image file.
        threshold: Probability above which the image is assigned label 1
            (``'lake'`` in ``index_to_label``); defaults to 0.5.

    Returns:
        The predicted class name looked up in ``index_to_label``.
    """
    img = load_and_preprocess_image(img_path)
    # Add the batch dimension the model expects: (H, W, 3) -> (1, H, W, 3).
    img = tf.expand_dims(img, axis=0)
    prob = model.predict(img)
    # One image and one sigmoid output -> prob[0][0]; convert the thresholded
    # NumPy bool to a plain int for the dict lookup.
    label_idx = int(prob[0][0] > threshold)
    return index_to_label.get(label_idx)
import os

img_path = './data/air_lake_dataset/airplane/airplane_240.jpg'
pred = pred_img(img_path)
img_tensor = load_and_preprocess_image(img_path)
plt.imshow(img_tensor)
# The true class is the image's parent folder name. os.path works for any
# path depth and for Windows separators, unlike split('/') at a fixed index.
true_label = os.path.basename(os.path.dirname(img_path)).strip()
title = 'label=' + true_label + ', pred=' + pred
plt.title(title)
plt.show()