重要提示:由于树莓派相对孱弱的性能,直接在其上训练模型可能花(lang4)费非常长的时间。本文仅作为示例性的可行性参考,请酌情考虑实验平台。
著名的Tensorflow框架也可以运行在树莓派上。理论还没吃透,但使用Sequential模型体验图片分类的代码已经可以跑通。以下就是一个不求甚解版的笔记:
首先安装Tensorflow:
sudo apt install libatlas-base-dev
python3 -m pip install tensorflow -i https://pypi.tuna.tsinghua.edu.cn/simple
python3 -m pip install tensorflow-io -i https://pypi.tuna.tsinghua.edu.cn/simple
然后,预先准备好Demo中用到的花卉分类图片包(当然也可以在代码运行过程中现下,只是我个人不太推荐):
cd
wget https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
然后就可以写python代码了。首先是制作训练并制作模型的代码(出于快速体验以及纯ssh可操作的目的已略去了所有训练过程的可视化评估的代码。具体可参见文末的参考资料):
# import all libs
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
# here is the pre-downloaded demo package
dataset_url = "file:///home/ki/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*.jpg')))
print("Total Files: " + str(image_count))
# Training split
train_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
image_size=(180, 180),
batch_size=32)
# Testing or Validation split
val_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=123,
image_size=(180,180),
batch_size=32)
# Print class names
class_names = train_ds.class_names
print(class_names)
# Create model
num_classes = len(class_names)
model = Sequential([
layers.Rescaling(1./255, input_shape=(180,180, 3)),
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(num_classes)
])
# Compiling the model
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True),
metrics=['accuracy'])
model.summary()
# learn patterns by providing training and test/validation dataset
epochs=9
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs
)
# save the model
model.save(r'/home/ki/flowers.hd5')
因为我的树莓派账号是ki,所以凡是涉及/home/ki的路径都要改为你自己的home目录。另外model.fit过程我这边的代码迭代了9次,这将非常非常耗时,但后面使用自定义图片做预测时正确率可以非常高。
以下是使用自定义图片做分类预测的代码:
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
"""flower_photo/
daisy/
dandelion/
roses/
sunflowers/
tulips/"""
import pathlib
batch_size = 32
img_height = 180
img_width = 180
"""
dataset_url = "file:///home/ki/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)
"""
class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
# Load Saved Model
pre_model = tf.keras.models.load_model("/home/ki/flowers.hd5")
sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)
img = keras.preprocessing.image.load_img(
sunflower_path, target_size=(img_height, img_width)
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # Create a batch
predictions = pre_model.predict(img_array)
score = tf.nn.softmax(predictions[0])
print(score)
print(
"This image most likely belongs to {} with a {:.2f} percent confidence."
.format(class_names[np.argmax(score)], 100 * np.max(score))
)
同样要注意根据你的实际情况修改模型保存路径。class_names的内容及次序要严格按照之前训练时获得的。如果不确定,可以将上面那段注释的语句释放看看打印的内容。
我这边的识别结果:model.fit迭代数为3的时候识别出玫瑰(错误),而迭代数为9时,识别为向日葵/太阳花(正确)。
参考资料:
Image Recognition using TensorFlow - GeeksforGeeks
Tensorflow---Tensorflow的五种保存模型的方式介绍_tensorflow保存模型_水哥很水的博客-CSDN博客
【Tensorflow】使用Tensorflow自定义模型和训练_tensorflow导入自定义模型_沐兮Krystal的博客-CSDN博客