前言: 使用在teachable machine训练的h5格式模型
tensorflow使用篇
1. 使用teachable machine训练模型
地址: 传送门, 需要梯子翻一下
训练后, 导出的时候可以选择三种类型
导出模型文件 converted_keras.zip (py版)
解压后得到
2. py项目中使用模型
根据你当时使用teachable machine的时间, 选择py项目中TensorFlow的版本
我现在使用的是必须是2.3.0版本及以上才行, 然后我直接升级到了2.10.0
如果版本不匹配会报错如下
ValueError: (‘Unrecognized keyword arguments:’, dict_keys([‘ragged’]))
解决的方法就是升级TensorFlow版本
pip install tensorflow==2.10.0 --upgrade
目录结构如下
app.py
# -*- coding: utf-8 -*-
import flask as fk
from flask import jsonify, request
import tensorflow as tf
from PIL import Image
import numpy as np
app = fk.Flask(__name__)
# 加载标签映射
class_label_map = {}
with open('labels.txt', 'r', encoding='utf-8') as f:
for line in f.readlines():
index, label = line.strip().split()
class_label_map[int(index)] = label
print(class_label_map)
# 加载模型
global model
model = tf.keras.models.load_model('keras_model.h5')
print('模型加载成功')
# 图片预处理方法
def preprocess_image(image_path):
img = Image.open(image_path)
# 调整大小、归一化等操作,具体取决于模型要求
img_resized = img.resize((224, 224))
img_array = np.array(img_resized) / 255.0 # 将像素值归一化到[0, 1]区间
img_array = np.expand_dims(img_array, axis=0) # 添加批量维度(batch size = 1)
return img_array
# 预测方法
def load_model():
# 准备输入数据
input_data = preprocess_image("danka.jpg")
# input_data = preprocess_image("duolianka.jpg")
# 预测
predictions = model.predict(input_data)
# 获取预测结果
predicted_class_index = np.argmax(predictions[0])
# 获取预测的类名
predicted_class_name = class_label_map[predicted_class_index]
print(f"Predicted class: {predicted_class_name}")
return predicted_class_name
# 测试预测
@app.route('/api/hello', methods=['GET'])
def get_data():
return load_model()
# 假设我们要提供一个获取用户信息的API
@app.route('/api/user/<int:user_id>', methods=['GET'])
def get_user_info(user_id):
# 这里模拟从数据库或其他服务获取用户信息
user_data = {'id': user_id, 'name': 'John Doe', 'email': 'john.doe@example.com'}
# 假设用户不存在,返回404
# 返回JSON格式的用户信息
return jsonify(user_data)
# 定义一个接收POST请求的路由,假设该接口用于创建新用户
@app.route('/api/users', methods=['POST'])
def create_user():
# 从请求体中获取JSON格式的数据
data = request.get_json()
# 检查必要的字段是否存在
if not all(key in data for key in ('username', 'email', 'password')):
return jsonify({"error": "Missing required fields"}), 400
# 这里仅做示例,实际开发中应将数据保存至数据库等
new_user = {
'username': data['username'],
'email': data['email'],
'password': data['password']
}
# 模拟用户创建成功
resultMap = {"message": "User created successfully", "user": new_user}
# 返回201状态码表示已创建资源
return jsonify(resultMap), 201
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)
可以直接使用postman请求无参get, 可以得到卡类型