目录
- 背景
- 原理
- 技术选型
- 技术栈
- 构造封装
- 优化
- 模型选择
- 让模型加载更快
- 张量释放
- 让IndexedDB更快
- 将图片拆出单独建表
- 特征向量降维
- 后续规划
- 模型的下发更新
- 模型的增强学习
- 识别数据的上传和下发
背景
先定性,带AI识别的生鲜收银机早就上市了,目前学习的只能说是别人玩剩的,但是依然收获满满,算是第一个ai识别的应用吧,关键是技术栈全是前端的,还是有一定参考价值
原理
我们的果蔬AI识别系统参考了市面上一套成熟的方案:
识别过程:
学习过程:
技术选型
技术栈
- python
其实一开始是准备使用python的,因为性能方面更快,也更好写,开源的模型也多。但不得不面对一个问题:我们的硬件是一套安卓系统,如果用python,根据厂家提供的示例,要么将编写好的python程序转成c++,要么部署python服务,以RESTful接口的形式对外提供。第一个方案我们太不熟悉了,毕竟还有业务压力,pass。第二个方案中途尝试过,但考虑到网络损耗,觉得还是在本地部署模型是最优解
附上python部分的代码:
模型就是用的mobilenet_fv.h5
识别部分:没有使用数据库功能,本地提供了一个缓存机制,半成品
import tensorflow as tf
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input, decode_predictions
from PIL import Image
import numpy as np
import cv2
import time
import io
import collections
import hashlib
from heapq import nlargest
# Upper bound on the number of cached (features, label) entries.
CACHE_CAPACITY = 1000
# image id -> (features, label). An OrderedDict keeps insertion order so the
# oldest entry can be evicted first once CACHE_CAPACITY is reached (entries
# are not moved on lookup, so this is oldest-first rather than true LRU).
cache = collections.OrderedDict()

# ImageNet-pretrained MobileNet without its classification head; the
# global-average pooling layer yields one feature vector per input image.
model = MobileNet(include_top=False, weights='imagenet', pooling='avg')
def cosine_similarity(features1, features2):
    """Return the cosine similarity of two feature arrays.

    Both inputs are flattened first, so a ``(1, N)`` model output can be
    compared directly with an ``(N,)`` vector.

    If either vector has zero norm the similarity is undefined; ``0.0`` is
    returned instead of ``nan`` so ranking code never promotes such an entry.

    Raises:
        ValueError: if the flattened arrays have different lengths.
    """
    a = np.asarray(features1, dtype=np.float64).ravel()
    b = np.asarray(features2, dtype=np.float64).ravel()
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        return 0.0
    return float(np.dot(a, b) / denominator)
def query_cache(image_features):
    """Return the label of the best-matching cached entry, or ``None``.

    Every cached feature vector is scored against *image_features* with
    cosine similarity. The highest score wins; on ties, the earliest entry wins.

    That label is returned only if the score is at least 0.5. Otherwise, or
    when the cache is empty, ``None`` is returned.
    """
    top_score, top_label = -1, None
    for stored_features, stored_label in cache.values():
        score = cosine_similarity(image_features, stored_features)
        if score > top_score:
            top_score, top_label = score, stored_label
    return top_label if top_score >= 0.5 else None
def query_cache_top5(image_features, k=5):
    """Return the *k* (default 5) cached entries most similar to *image_features*.

    Each item is a ``(similarity, label)`` tuple. *similarity* is the cosine
    similarity formatted as a string with four decimals. Items are ordered
    from most to least similar. An empty cache yields ``[]``.
    """
    if not cache:
        return []
    scored = (
        (cosine_similarity(image_features, cached_features), label)
        for cached_features, label in cache.values()
    )
    # Rank on the numeric value. Ranking the formatted strings would order
    # negative similarities backwards ("-0.9000" > "-0.1000" as strings).
    best = nlargest(k, scored, key=lambda item: item[0])
    return [(f"{similarity:.4f}", label) for similarity, label in best]
def get_image_id_from_hash(img):
    """Return the hex MD5 digest of ``img.tobytes()``, used as the cache key.

    *img* only needs a ``tobytes()`` method (e.g. a PIL image).
    """
    hasher = hashlib.md5()
    hasher.update(img.tobytes())
    return hasher.hexdigest()
class MainDetect:
    """Recognize images by comparing their feature vectors with the module cache.

    ``predict_img`` runs the model loaded from ``models/mobilenet_fv.h5``.
    ``classify_image`` runs the module-level ImageNet MobileNet extractor.

    Both rank the features against the module-level ``cache``. Both also
    remember the id and features of the last image, so ``update_cache`` can
    store it under a label supplied afterwards.
    """

    def __init__(self):
        super().__init__()
        # md5 id and model output of the most recently processed image;
        # update_cache() stores them under the label supplied later.
        self.image_id = None
        self.image_features = None
        self.model = tf.keras.models.load_model("models/mobilenet_fv.h5")  # TODO: change the model file name
        # TODO: update the class names; this list is printed when model training starts.
        self.class_names = ['哈密瓜', '柠檬', '桂圆', '梨', '榴莲', '火龙果', '猕猴桃', '胡萝卜', '芒果', '苦瓜',
                            '草莓', '荔枝', '菠萝', '车厘子', '黄瓜']

    @staticmethod
    def _load_image(image_data):
        """Decode encoded image bytes into ``(image_id, batch)``.

        *image_id* is the md5 of the decoded pixels. *batch* is a
        ``(1, 224, 224, 3)`` uint8 RGB array.

        Converting to RGB first handles grayscale/palette images, which decode
        to 2-D arrays, and reduces RGBA/CMYK images to exactly three channels.
        """
        img = Image.open(io.BytesIO(image_data))
        image_id = get_image_id_from_hash(img)
        rgb = np.asarray(img.convert("RGB"))
        resized = cv2.resize(rgb, (224, 224))
        return image_id, resized.reshape(1, 224, 224, 3)

    def _run_model(self, net, batch):
        """Run *net* on *batch*, remember the features and rank them against the cache.

        Returns a dict with the top-5 matches (``result``), the raw model
        output (``outputs``) and the inference time (``time``, e.g. "12.34ms").
        """
        start_time = time.time()
        outputs = net.predict(batch, batch_size=1)
        elapsed_time = time.time() - start_time
        self.image_features = outputs
        result = query_cache_top5(outputs)
        return {"result": result, "outputs": outputs, "time": f"{elapsed_time * 1000:.2f}ms"}

    def predict_img(self, image_data):
        """Extract features with the custom .h5 model and match them against the cache."""
        self.image_id, batch = self._load_image(image_data)
        # NOTE(review): raw 0-255 pixels are fed to the custom model; confirm
        # this matches the preprocessing used when it was trained.
        return self._run_model(self.model, batch)

    def classify_image(self, image_data):
        """Extract ImageNet MobileNet features and match them against the cache."""
        self.image_id, batch = self._load_image(image_data)
        # The ImageNet MobileNet weights expect inputs scaled to [-1, 1] by
        # preprocess_input; raw 0-255 pixels give off-distribution features.
        return self._run_model(model, preprocess_input(batch.astype("float32")))

    def update_cache(self, label):
        """Store the most recently processed image in the cache under *label*.

        When the cache is full, the oldest inserted entry is evicted first.
        Re-labelling an image that is already cached overwrites its entry
        without evicting anything.

        Returns:
            True on success.

        Raises:
            ValueError: if no image has been processed yet, so there are no
                features to store.
        """
        if self.image_features is None:
            raise ValueError("no image has been predicted yet")
        if self.image_id in cache:
            cache.move_to_end(self.image_id)
        elif len(cache) >= CACHE_CAPACITY:
            cache.popitem(last=False)
        cache[self.image_id] = (self.image_features, label)
        if label not in self.class_names:
            self.class_names.append(label)
        return True

    def clear_cache(self):
        """Remove every cached entry; always returns True."""
        cache.clear()
        return True
服务部分:
from flask import Flask, request, jsonify
from flask_cors import CORS
from detect import MainDetect
import numpy as np
# One detector instance (and its models) shared by every request.
detector = MainDetect()

# Flask application with cross-origin requests allowed on all routes.
app = Flask(__name__)
CORS(app)
@app.route('/')
def home():
    """Landing endpoint; returns a fixed welcome string."""
    greeting = "Welcome to the Vegetable Recognize App!"
    return greeting
@app.route('/predict', methods=['POST'])
def predict():
    """Recognize an uploaded image.

    Expects a multipart form with a ``file`` field. Responds with JSON
    containing the top-5 cache matches, the inference time and the raw
    feature vector.

    Returns 400 when no file is supplied, and 500 with the error message when
    recognition fails.
    """
    upload = request.files.get('file')
    if upload is None:
        return 'No file part', 400
    if upload.filename == '':
        return 'No selected file', 400
    try:
        detection = detector.classify_image(upload.read())
        payload = {
            "top5": detection["result"],
            "time": detection["time"],
            "features": detection["outputs"].tolist(),
        }
        return jsonify(payload)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/study', methods=['GET'])
def study():
    """Store the most recently recognized image in the cache under ``?label=...``.

    Responses:
        200 ``"success"`` when the detector accepted the entry.
        400 when ``label`` is missing or empty (it would otherwise be cached
        as ``None``).
        500 with the error message when caching raises or does not report
        success.

    Every path returns a response; Flask rejects a view that returns None.
    """
    label = request.args.get('label')
    if not label:
        return jsonify({'error': 'missing label'}), 400
    try:
        stored = detector.update_cache(label)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    if stored is True:
        return jsonify("success")
    return jsonify({'error': 'failed to update cache'}), 500
@app.route('/clear', methods=['GET'])
def recognize():
    """Empty the detector's feature cache.

    Returns ``"success"``, or 500 with the error message. The function name
    (the Flask endpoint name) is kept for compatibility even though the route
    is /clear.
    """
    try:
        detector.clear_cache()
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    return jsonify("success")
if __name__ == "__main__":
    import os

    # The Werkzeug debugger lets anyone who can reach it execute arbitrary
    # code. Never enable it by default on a server bound to all interfaces;
    # opt in with FLASK_DEBUG=1 for local development only.
    debug = os.environ.get("FLASK_DEBUG") == "1"
    app.run(host='0.0.0.0', port=5000, debug=debug)
- js
既然要本地化,那就老老实实用tensorflowjs吧
但是又面临一个问题:模型选用。这里其实走了一段弯路:一开始并没打算使用mobilenet或其他已经训练好的模型,而是想自己训练,后来发现数据集也很难搞,所以算了,还是用别人的吧;然后从网上下载了一个,性能太差,而且有60MB,太大了,要加载半天;最后还是老老实实用mobilenet,并把模型下载到本地,香得很呀。这里使用的是v1版本,因为v2实在下载不下来T.T
构造封装
直接展示成品吧:
识别的class:
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
import Storage, {StorageData} from './storage';
import {Tensor} from "@tensorflow/tfjs-core/dist/tensor";
/** Constructor options for XsyxRecognize; every field is optional. */
interface Options {
    /** Number of distinct labels returned by findTopNSimilar (default 5). */
    num?: number;
    /** Forwarded to Storage as `count` (default 10); Storage.readAll keeps at most this many records per label. */
    count?: number;
    /** Model type: 'mobilenet' (default) or any other value to use a single layers model. */
    type?: string;
    /** URL of the main model.json; ignored when NODE_ENV is 'development'. */
    modelUrl?: string;
    /** URL of the custom projection model.json; ignored when NODE_ENV is 'development'. */
    customModelUrl?: string;
    /** Side length images are scaled to (default 224 for 'mobilenet', otherwise 96). */
    maxSize?: number;
    /** Called after the database is open and the model(s) have loaded. */
    initCallback?: () => void;
    /** IndexedDB database name (default 'featureDatabase'). */
    dbName?: string;
    /** Object store for feature records (default 'feature'). */
    tableName?: string;
    /** Object store for image blobs (default 'img'). */
    tableImgName?: string;
    /** IndexedDB schema version (default 1). */
    version?: number;
}
/**
 * Browser-side fruit/vegetable recognizer.
 *
 * With `type === 'mobilenet'` (the default), an image is embedded by
 * MobileNet v1, projected by the custom layers model loaded from
 * `customModelUrl`, and L2-normalised. With any other type, a single layers
 * model loaded from `modelUrl` produces the feature vector directly.
 *
 * Features are ranked by cosine similarity against the samples stored in
 * IndexedDB via `Storage`.
 */
class XsyxRecognize {
    private model: mobilenet.MobileNet | tf.LayersModel | null = null;
    private customModel: tf.LayersModel | null = null;
    // Number of distinct labels findTopNSimilar returns.
    private readonly num: number;
    private type: string;
    private modelUrl: string;
    private readonly customModelUrl: string;
    // Side length used for the canvas / input tensor. Not readonly: falling
    // back to the backup model switches to that model's default size.
    private maxSize: number;
    // True when the caller passed maxSize explicitly; a fallback then keeps it.
    private readonly maxSizeFromOptions: boolean;
    private readonly initCallback?: () => void;
    private storage: Storage;
    private modelLoadReady: boolean = false;

    constructor(options: Options = {}) {
        this.num = options.num ?? 5;
        this.type = options.type ?? 'mobilenet';
        this.modelUrl = this.getModelUrl(options);
        this.customModelUrl = this.getModelUrl(options, true);
        this.maxSizeFromOptions = options.maxSize !== undefined && options.maxSize !== null;
        this.maxSize = options.maxSize ?? XsyxRecognize.defaultMaxSize(this.type);
        this.initCallback = options.initCallback;
        this.storage = new Storage({
            dbName: options.dbName ?? 'featureDatabase',
            tableName: options.tableName ?? 'feature',
            tableImgName: options.tableImgName ?? 'img',
            version: options.version ?? 1,
            count: options.count ?? 10
        });
        // Initialisation is fire-and-forget; log a failure instead of leaving
        // an unhandled promise rejection.
        this.init().catch((e) => {
            console.error('初始化失败:' + e);
        });
    }

    /** Default input side length for a model type: 224 for MobileNet, 96 otherwise. */
    private static defaultMaxSize(type: string): number {
        return type === 'mobilenet' ? 224 : 96;
    }

    /** Resolves the main (or, with isCustom, the projection) model URL for the current environment. */
    private getModelUrl(options: Options, isCustom: boolean = false): string {
        if (process.env.NODE_ENV === 'development') {
            return `http://127.0.0.1:8090/model/${isCustom ? 'custom' : 'mobilenet'}/model.json`;
        }
        if (isCustom) {
            return options.customModelUrl ?? `https://front-xps-cdn.xsyx.xyz/custom/hotstore/model/custom/model.json`;
        }
        return options.modelUrl ?? `https://front-xps-cdn.xsyx.xyz/custom/hotstore/model/${this.type === 'mobilenet' ? 'mobilenet' : 'init'}/model.json`;
    }

    /** Opens the database and loads the model(s) in parallel, then calls initCallback. */
    async init(): Promise<void> {
        await Promise.all([this.storage.openDB(), this.load()]);
        this.initCallback?.();
    }

    /**
     * Loads the model(s) once.
     *
     * If MobileNet or the projection model fails to load, falls back to the
     * backup 'init' layers model. Unless maxSize was given explicitly, it also
     * switches to that model's input size.
     */
    async load(): Promise<void> {
        if (this.model !== null) {
            console.log('模型加载中,请勿重复加载');
            return;
        }
        if (this.type === 'mobilenet') {
            try {
                const [baseModel, projectionModel] = await Promise.all([
                    mobilenet.load({
                        version: 1,
                        alpha: 1.0,
                        modelUrl: this.modelUrl
                    }),
                    tf.loadLayersModel(this.customModelUrl)
                ]);
                this.model = baseModel;
                this.customModel = projectionModel;
            } catch (e) {
                console.log('预制模型加载失败,开始加载备用本地模型,原因:' + e);
                this.type = 'custom';
                this.modelUrl = 'https://front-xps-cdn.xsyx.xyz/custom/hotstore/model/init/model.json';
                if (!this.maxSizeFromOptions) {
                    this.maxSize = XsyxRecognize.defaultMaxSize(this.type);
                }
                // The recursive call marks the model ready and logs on success.
                await this.load();
                return;
            }
        } else {
            this.model = await tf.loadLayersModel(this.modelUrl);
        }
        this.modelLoadReady = true;
        console.log('模型已加载');
    }

    /** Recognizes a File, Blob or canvas; resolves to `{features, top5}`. */
    async predict(obj: File | Blob | HTMLCanvasElement): Promise<any> {
        // File is checked before Blob because File extends Blob.
        if (obj instanceof File) {
            return this.predictFile(obj);
        }
        if (obj instanceof Blob) {
            return this.predictBlob(obj);
        }
        if (obj instanceof HTMLCanvasElement) {
            return this.predictCanvas(obj);
        }
        throw new Error('Unsupported object type');
    }

    /**
     * Decodes an image file/blob, scales it so its longer side equals maxSize,
     * and recognizes it.
     *
     * Rejects if the file cannot be read or decoded.
     */
    async predictFile(file: File | Blob): Promise<any> {
        return new Promise((resolve, reject) => {
            const reader = new FileReader();
            // Without this handler a read failure would leave the promise pending forever.
            reader.onerror = () => {
                reject(reader.error ?? '文件读取失败');
            };
            reader.onload = () => {
                const img = new Image();
                img.onerror = () => {
                    reject('图片加载失败');
                };
                img.onload = async () => {
                    const canvas = document.createElement('canvas');
                    const ctx = canvas.getContext('2d');
                    if (!ctx) {
                        reject('无法创建canvas上下文');
                        return;
                    }
                    const scale = Math.min(this.maxSize / img.width, this.maxSize / img.height);
                    // Canvas sizes are integers; keep at least 1px so the tensor is never empty.
                    canvas.width = Math.max(1, Math.floor(img.width * scale));
                    canvas.height = Math.max(1, Math.floor(img.height * scale));
                    ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
                    try {
                        resolve(await this.predictCanvas(canvas));
                    } catch (e) {
                        reject(e);
                    }
                };
                img.src = reader.result as string;
            };
            reader.readAsDataURL(file);
        });
    }

    async predictBlob(blob: Blob): Promise<any> {
        return await this.predictFile(blob);
    }

    /**
     * Extracts the feature vector of a canvas and ranks it against the stored samples.
     *
     * Resolves to `{features, top5}`; rejects with '识别错误:...' on failure.
     */
    async predictCanvas(canvas: HTMLCanvasElement): Promise<any> {
        let prediction: Tensor | undefined;
        let flattened: Tensor | undefined;
        try {
            if (!this.model || !this.modelLoadReady) {
                throw new Error('模型尚未加载完成');
            }
            console.time('predict');
            prediction = tf.tidy(() => this.extractFeatures(canvas));
            console.timeEnd('predict');
            flattened = prediction.flatten();
            const features = (await flattened.array()) as number[];
            const result = await this.findTopNSimilar(features);
            return {
                features,
                top5: result
            };
        } catch (e) {
            throw new Error('识别错误:' + e);
        } finally {
            // tf.tidy keeps the tensor it returns alive; release it and its
            // flattened copy here, or every recognition leaks tensor memory.
            prediction?.dispose();
            flattened?.dispose();
        }
    }

    /** Runs the loaded model(s) on a canvas. Must be called inside tf.tidy so intermediates are freed. */
    private extractFeatures(canvas: HTMLCanvasElement): Tensor {
        const imgTensor = tf.browser.fromPixels(canvas);
        if (this.type === 'mobilenet') {
            // `true` asks MobileNet for the embedding instead of class logits.
            const embedding = (this.model as mobilenet.MobileNet).infer(imgTensor, true);
            // Dimensionality reduction with the custom projection model.
            const reduced = (this.customModel as tf.LayersModel).predict(embedding) as Tensor;
            // L2-normalise; without it the similarities come out uniformly low.
            return tf.div(reduced, tf.norm(reduced));
        }
        const input = imgTensor.resizeBilinear([this.maxSize, this.maxSize])
            .toFloat()
            .div(tf.scalar(255.0))
            .expandDims();
        return (this.model as tf.LayersModel).predict(input) as Tensor;
    }

    /**
     * Cosine similarity of two vectors.
     *
     * Returns 0 for vectors of different length (e.g. stored by another model)
     * or with zero magnitude. A NaN result would break sorting.
     */
    cosineSimilarity(vecA: number[], vecB: number[]): number {
        if (vecA.length !== vecB.length || vecA.length === 0) {
            return 0;
        }
        let dotProduct = 0;
        let sumSqA = 0;
        let sumSqB = 0;
        for (let i = 0; i < vecA.length; i++) {
            dotProduct += vecA[i] * vecB[i];
            sumSqA += vecA[i] * vecA[i];
            sumSqB += vecB[i] * vecB[i];
        }
        const denominator = Math.sqrt(sumSqA) * Math.sqrt(sumSqB);
        return denominator === 0 ? 0 : dotProduct / denominator;
    }

    /** Returns the `num` best-matching distinct labels, each with its highest similarity, best first. */
    async findTopNSimilar(inputFeatures: number[]): Promise<{ label: string; similarity: number }[]> {
        console.time('read');
        const featureDatabase = await this.storage.getAll();
        console.timeEnd('read');
        console.time('calc');
        const scored = featureDatabase.map(({features, label}) => ({
            label: label as string,
            similarity: features ? this.cosineSimilarity(inputFeatures, features) : 0
        }));
        scored.sort((a, b) => b.similarity - a.similarity);
        // Keep only the best-scoring sample of each label.
        const seenLabels = new Set<string>();
        const topNUnique: { label: string; similarity: number }[] = [];
        for (const item of scored) {
            if (seenLabels.has(item.label)) {
                continue;
            }
            seenLabels.add(item.label);
            topNUnique.push(item);
            if (topNUnique.length === this.num) {
                break;
            }
        }
        console.timeEnd('calc');
        return topNUnique;
    }

    /** Stores a labelled sample (features plus optional image). */
    async study(data: StorageData): Promise<string> {
        await this.storage.addData(data);
        return '学习完成';
    }

    /** Inserts or replaces a sample (and its image record). */
    async update(newData: any): Promise<string> {
        await this.storage.update(newData);
        return '更新完成';
    }

    /**
     * Builds a two-layer dense projection model, compiles it, and saves it via
     * `downloads://<modelName>`.
     *
     * Note: no fit() is called, so the saved weights are the initial (random)
     * ones. Keep the downloaded model and load it for every later recognition,
     * otherwise features from different runs are not comparable.
     */
    async trainModel(options: {
        units?: number,
        activation?: undefined,
        inputShape?: number[]
        optimizer?: string
        loss?: string
        modelName?: string
    } = {}) {
        const {
            units = 256,
            activation = 'relu',
            inputShape = [1024],
            optimizer = 'adam',
            loss = 'categoricalCrossentropy',
            modelName = 'my-model'
        } = options;
        const model = tf.sequential();
        // First projection layer: `units`-dimensional output (256 by default).
        model.add(tf.layers.dense({
            units,
            activation,
            inputShape
        }));
        // Second linear layer with the same output size.
        model.add(tf.layers.dense({units}));
        model.compile({optimizer, loss});
        await model.save(`downloads://${modelName}`);
    }

    async getByName(indexName: string, val: any): Promise<any> {
        return await this.storage.getByName(indexName, val);
    }

    closeDB(): void {
        this.storage.closeDB();
    }

    async remove(val: any, tableName?: string): Promise<string> {
        await this.storage.remove(val, tableName);
        return '清除完成';
    }

    async createTable(tableName: string, keyPath: string, autoIncrement: boolean, indexName?: string): Promise<void> {
        await this.storage.createTable(tableName, keyPath, autoIncrement, indexName);
    }

    async delTable(): Promise<string> {
        await this.storage.delTable();
        return '清除完成';
    }

    async clearTable(tableName: string): Promise<string> {
        await this.storage.clearTable(tableName);
        return '清除完成';
    }

    async deleteDB(name: string): Promise<string> {
        await this.storage.deleteDB(name);
        return '清除完成';
    }
}
export default XsyxRecognize;
数据的增删改查:
/** Options for the Storage wrapper. */
interface StorageOptions {
    // Database name passed to indexedDB.open.
    dbName: string;
    // Object store holding feature records.
    tableName: string;
    // Object store holding image records (key path 'imgId').
    tableImgName: string;
    // Schema version passed to indexedDB.open.
    version: number;
    // Key path of the feature store (default 'id').
    keyPath?: string;
    // Per-label cap applied by readAll(); a falsy value disables the cap.
    count: number;
    // Extra indexes created on the feature store when it is first created.
    indexArr?: Index[];
    // Whether feature keys auto-increment (default true).
    autoIncrement?: boolean;
    // Hook invoked at the end of onupgradeneeded with the database and the image store.
    onupgradeneeded?: (db: IDBDatabase, table: IDBObjectStore) => void;
}
/** Description of an extra index passed to IDBObjectStore.createIndex. */
interface Index {
    // Index name.
    name: string;
    // Record property the index is built on.
    keyPath: string;
    // Forwarded as the `unique` createIndex option.
    unique?: boolean;
    // Forwarded as the `multiEntry` createIndex option.
    multiEntry?: boolean;
}
/** A stored sample: a feature record joined with its image record. */
export interface StorageData {
    // Feature vector of the sample.
    features?: number[],
    // Label the sample was stored under.
    label?: string,
    // Key of the feature record.
    id?: IDBValidKey,
    // Key of the image record (addData sets it to the feature record's key).
    imgId?: IDBValidKey,
    // Image of the sample.
    img?: Blob
}
/**
 * Promise-based wrapper around one IndexedDB database with two object stores.
 *
 * - `tableName` holds feature records `{id, label, features}`, indexed by `label`.
 * - `tableImgName` holds image records `{imgId, img}`; `imgId` is the key of
 *   the matching feature record.
 *
 * Keeping the blobs in their own store keeps feature reads small.
 */
class Storage {
    private readonly dbName: string;
    private readonly tableName: string;
    private readonly tableImgName: string;
    private readonly version: number;
    private readonly keyPath: string;
    // Per-label cap applied by readAll(); a falsy value disables the cap.
    private readonly count: number;
    private readonly autoIncrement: boolean;
    private indexArr: Index[];
    // Open connection; null before openDB() succeeds and after closeDB().
    private db: IDBDatabase | null = null;
    private readonly onupgradeneeded?: (db: IDBDatabase, table: IDBObjectStore) => void;

    constructor(options: StorageOptions) {
        const {
            dbName,
            tableName,
            tableImgName,
            version,
            keyPath,
            indexArr,
            autoIncrement,
            onupgradeneeded,
            count
        } = options;
        this.dbName = dbName;
        this.tableName = tableName;
        this.tableImgName = tableImgName;
        this.version = version;
        this.count = count;
        this.keyPath = keyPath ?? 'id';
        this.autoIncrement = autoIncrement ?? true;
        this.indexArr = indexArr ?? [];
        this.onupgradeneeded = onupgradeneeded;
    }

    /** The open connection; throws a descriptive error if openDB() has not completed or the DB was closed. */
    private get database(): IDBDatabase {
        if (!this.db) {
            throw new Error('数据库尚未打开,请先调用 openDB');
        }
        return this.db;
    }

    /** Opens (creating/upgrading if needed) the database and resolves with the connection. */
    openDB(): Promise<IDBDatabase> {
        return new Promise((resolve, reject) => {
            //@ts-ignore
            const indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB;
            const request: IDBOpenDBRequest = indexedDB.open(this.dbName, this.version);
            request.onsuccess = () => {
                this.db = request.result;
                console.log("数据库连接成功");
                resolve(request.result);
            };
            request.onerror = () => {
                console.log("数据库打开报错");
                reject(request.error);
            };
            request.onupgradeneeded = () => {
                console.log("数据库更新");
                const db = request.result;
                this.db = db;
                // During an upgrade the running versionchange transaction is the only
                // usable one; calling db.transaction() here throws InvalidStateError.
                const upgradeTx = request.transaction as IDBTransaction;
                if (!db.objectStoreNames.contains(this.tableName)) {
                    const featureTable = db.createObjectStore(this.tableName, {keyPath: this.keyPath, autoIncrement: this.autoIncrement});
                    featureTable.createIndex('label', 'label', {unique: false});
                    this.indexArr.forEach((v) => {
                        featureTable.createIndex(v.name, v.keyPath, {unique: v.unique, multiEntry: v.multiEntry});
                    });
                }
                let imgTable: IDBObjectStore;
                if (!db.objectStoreNames.contains(this.tableImgName)) {
                    imgTable = db.createObjectStore(this.tableImgName, {keyPath: 'imgId'});
                    imgTable.createIndex('imgId', 'imgId', {unique: true});
                } else {
                    imgTable = upgradeTx.objectStore(this.tableImgName);
                }
                this.onupgradeneeded?.(db, imgTable);
            };
        });
    }

    /**
     * Creates an object store (with an optional index on its key path).
     *
     * IndexedDB only allows createObjectStore inside a versionchange
     * (onupgradeneeded) transaction. At any other time the returned promise
     * rejects with the InvalidStateError thrown by the browser.
     */
    createTable(tableName: string, keyPath: string, autoIncrement: boolean, indexName?: string): Promise<IDBObjectStore> {
        return new Promise((resolve) => {
            const newTable = this.database.createObjectStore(tableName, {keyPath, autoIncrement});
            if (indexName) {
                newTable.createIndex(indexName, keyPath, {unique: false});
            }
            resolve(newTable);
        });
    }

    /** Reads records via the 'label' index, keeping at most `count` per label (all if count is falsy). */
    readAll(tableName: string = this.tableName): Promise<StorageData[]> {
        return new Promise((resolve, reject) => {
            const objectStore = this.database.transaction(tableName, 'readonly').objectStore(tableName);
            const byLabel = new Map<string, StorageData[]>();
            const cursorRequest = objectStore.index('label').openCursor();
            cursorRequest.onsuccess = () => {
                const cursor = cursorRequest.result;
                if (!cursor) {
                    const results: StorageData[] = [];
                    byLabel.forEach((group) => {
                        results.push(...group);
                    });
                    resolve(results);
                    return;
                }
                const label = cursor.value.label;
                let group = byLabel.get(label);
                if (!group) {
                    group = [];
                    byLabel.set(label, group);
                }
                if (!this.count || group.length < this.count) {
                    group.push(cursor.value);
                }
                cursor.continue();
            };
            cursorRequest.onerror = () => {
                console.error('Cursor request error:', cursorRequest.error);
                reject(cursorRequest.error);
            };
        });
    }

    /** Reads every record of a table with a single getAll() request. */
    getAll(tableName: string = this.tableName): Promise<StorageData[]> {
        return new Promise((resolve, reject) => {
            const request = this.database.transaction(tableName, 'readonly').objectStore(tableName).getAll();
            request.onsuccess = () => {
                resolve(request.result);
            };
            request.onerror = (event) => {
                reject(event);
            };
        });
    }

    /** Reads every record of a table by iterating a cursor. */
    readAllByCursor(tableName: string = this.tableName): Promise<StorageData[]> {
        return new Promise((resolve, reject) => {
            const request = this.database.transaction(tableName, 'readonly').objectStore(tableName).openCursor();
            const allData: StorageData[] = [];
            request.onsuccess = () => {
                const cursor = request.result;
                if (cursor) {
                    allData.push(cursor.value);
                    cursor.continue();
                } else {
                    resolve(allData);
                }
            };
            request.onerror = (event) => {
                reject(event);
            };
        });
    }

    /**
     * Adds a feature record and its image record in one transaction.
     *
     * Resolves with the image record's key (equal to the feature record's
     * key) once the transaction commits. Rejects if `label` is absent or
     * either write fails.
     */
    addData(data: StorageData): Promise<IDBValidKey> {
        return new Promise((resolve, reject) => {
            if (!("label" in data)) {
                reject(new Error('缺少label字段'));
                return;
            }
            const transaction = this.database.transaction([this.tableName, this.tableImgName], 'readwrite');
            let imgKey: IDBValidKey;
            const request = transaction.objectStore(this.tableName).add({label: data.label, features: data.features});
            request.onsuccess = () => {
                // The image record reuses the feature record's key so the two can be joined.
                const imgRequest = transaction.objectStore(this.tableImgName).add({imgId: request.result, img: data.img});
                imgRequest.onsuccess = () => {
                    imgKey = imgRequest.result;
                };
            };
            transaction.oncomplete = () => {
                console.log("数据写入成功");
                resolve(imgKey);
            };
            // Request errors bubble up to the transaction, so this covers both writes.
            transaction.onerror = (event) => {
                console.log("数据写入失败");
                reject(event);
            };
            transaction.onabort = () => {
                reject(transaction.error);
            };
        });
    }

    /**
     * Inserts or replaces a feature record and its image record.
     *
     * Without `id` a new feature key is generated. Without `imgId` the image
     * is stored under the feature record's key.
     *
     * Resolves with the feature record's key after the transaction commits.
     */
    update(newData: StorageData): Promise<IDBValidKey> {
        return new Promise((resolve, reject) => {
            const transaction = this.database.transaction([this.tableName, this.tableImgName], 'readwrite');
            const {id, label, features, img, imgId} = newData;
            let featureKey: IDBValidKey;
            const request = transaction.objectStore(this.tableName).put(id ? {id, label, features} : {label, features});
            request.onsuccess = () => {
                featureKey = request.result;
                transaction.objectStore(this.tableImgName).put({
                    imgId: imgId ? imgId : request.result,
                    img
                });
            };
            transaction.oncomplete = () => {
                console.log("数据更新成功");
                resolve(featureKey);
            };
            transaction.onerror = (event) => {
                console.log("数据更新失败");
                reject(event);
            };
            transaction.onabort = () => {
                reject(transaction.error);
            };
        });
    }

    /**
     * Returns all feature records whose index `indexName` matches `val`, each
     * with its image attached.
     *
     * Rejects if the table or index does not exist, or an image record is missing.
     */
    getByName(indexName: string, val: IDBValidKey): Promise<StorageData[]> {
        return new Promise((resolve, reject) => {
            const db = this.database;
            if (!db.objectStoreNames.contains(this.tableName)) {
                reject('表不存在');
                return;
            }
            const store = db.transaction([this.tableName], 'readonly').objectStore(this.tableName);
            if (!store.indexNames.contains(indexName)) {
                reject('索引不存在');
                return;
            }
            const request = store.index(indexName).getAll(val);
            request.onsuccess = () => {
                const records = request.result as StorageData[];
                // The image record's key equals the feature record's id.
                Promise.all(records.map(async (item) => {
                    const imgObj = await this.readData(item.id as IDBValidKey, this.tableImgName) as StorageData;
                    item.img = imgObj.img;
                    item.imgId = imgObj.imgId;
                })).then(() => resolve(records), reject);
            };
            request.onerror = () => {
                reject(request.error);
            };
        });
    }

    /** Reads one record by primary key; rejects with '未获得数据记录' if it does not exist. */
    readData(val: IDBValidKey, tableName: string = this.tableName): Promise<StorageData | undefined> {
        return new Promise((resolve, reject) => {
            const request = this.database.transaction(tableName, 'readonly').objectStore(tableName).get(val);
            request.onerror = (e) => {
                console.log('事务失败');
                reject(e);
            };
            request.onsuccess = () => {
                if (request.result) {
                    resolve(request.result);
                } else {
                    console.log('未获得数据记录');
                    reject('未获得数据记录');
                }
            };
        });
    }

    /**
     * Serialises every feature record, joined with its image record, to JSON.
     *
     * NOTE: JSON.stringify turns a Blob into `{}`, so image contents are not
     * preserved by this export; only the feature data round-trips.
     */
    async exportData(): Promise<string> {
        const data = await this.readAllByCursor();
        for (const item of data) {
            const imgObj = await this.readData(item.id as IDBValidKey, this.tableImgName) as StorageData;
            item.img = imgObj.img;
            item.imgId = imgObj.imgId;
        }
        return JSON.stringify(data);
    }

    /**
     * Imports records via update().
     *
     * @param data records to import
     * @param isClear clear the feature and image tables first
     * @returns '导入完成' once every record is written; rejects on the first failure
     */
    async importData(data: StorageData[], isClear: boolean = false): Promise<'导入完成'> {
        if (isClear) {
            // Clearing the feature table also clears the linked image table.
            await this.clearTable(this.tableName);
        }
        for (const item of data) {
            await this.update(item);
        }
        return '导入完成';
    }

    /** Closes the connection (no-op if it was never opened). */
    closeDB(): void {
        this.db?.close();
        this.db = null;
    }

    /**
     * Deletes the record with primary key `val` from `tableName`.
     *
     * Only that table is touched; an image record linked to a deleted feature
     * record is left in place.
     */
    remove(val: IDBValidKey, tableName: string = this.tableName): Promise<void> {
        return new Promise((resolve, reject) => {
            const request = this.database.transaction(tableName, 'readwrite').objectStore(tableName).delete(val);
            request.onsuccess = () => {
                resolve();
                console.log('数据删除成功');
            };
            request.onerror = () => {
                reject(request.error);
            };
        });
    }

    /**
     * Deletes the feature and image object stores.
     *
     * IndexedDB only allows deleteObjectStore inside a versionchange
     * transaction. At any other time this rejects with the browser's
     * InvalidStateError message.
     */
    delTable(): Promise<unknown> {
        return new Promise((resolve, reject) => {
            const db = this.database;
            if (!db.objectStoreNames.contains(this.tableName)) {
                resolve(`Object store ${this.tableName} does not exist`);
                return;
            }
            try {
                db.deleteObjectStore(this.tableName);
                db.deleteObjectStore(this.tableImgName);
                console.log(`删除表成功`);
                resolve(`Object store ${this.tableName} deleted successfully`);
            } catch (error) {
                reject(`Failed to delete object store ${this.tableName}: ${error}`);
            }
        });
    }

    /**
     * Empties a table without deleting it.
     *
     * The feature and image tables are linked by key, so clearing either one
     * clears both. Any other table is cleared on its own.
     */
    clearTable(tableName: string = this.tableName): Promise<'清空完成'> {
        return new Promise((resolve, reject) => {
            const isLinked = tableName === this.tableName || tableName === this.tableImgName;
            const tables = isLinked ? [this.tableName, this.tableImgName] : [tableName];
            const transaction = this.database.transaction(tables, 'readwrite');
            tables.forEach((name) => {
                transaction.objectStore(name).clear();
            });
            transaction.oncomplete = () => {
                resolve('清空完成');
            };
            transaction.onerror = (event) => {
                reject(event);
            };
        });
    }

    /** Closes the connection and deletes the database named `name` (default: this database). */
    deleteDB(name: string = this.dbName): Promise<void> {
        return new Promise((resolve, reject) => {
            this.closeDB();
            if (window.indexedDB) {
                const DBDelRequest = window.indexedDB.deleteDatabase(name);
                DBDelRequest.onerror = (event) => {
                    console.log("删除失败");
                    reject((event.target as IDBRequest).error);
                };
                DBDelRequest.onsuccess = () => {
                    console.log("删除成功");
                    resolve();
                };
            } else {
                reject(new Error('IndexedDB not supported'));
            }
        });
    }
}
export default Storage;
优化
模型选择
正如前面所说,直接mobilenet起手,特征值提取又快又准
让模型加载更快
因为是安卓机,性能上比windows差很多,所以每次下载+加载模型这两步真的很久,所以我们能做的就是把模型作为静态文件一起放在项目里,省去下载那部分时间:
vue项目就是public文件下,你懂的
张量释放
一定要记得把识别过程中的张量都释放,不然会越用越卡
让IndexedDB更快
原本在计算相似度时,是用游标(cursor)逐条查询所有特征值的,但后来发现使用getAll更快,所以推荐大家使用getAll方法,因为游标需要一条一条地读取
将图片拆出单独建表
整个识别过程其实主要是三步,识别+读取+计算,
识别和计算的速度其实已经很快了,主要耗时在读取数据上,为了优化IndexedDB的读取速度,减小数据量势在必行
所以把一开始存在一起的图片blob文件进行了拆分,通过特征值表的id进行关联
特征向量降维
同样是为了减小数据体积,所以准备将原本mobilenet给出的1024维度的特征向量进行降维到256,数据量骤减,读取速度更快,代价是识别速度会慢一点,但相比读取速度的提升,这些代价是十分小的
- 先基于mobilenet训练自己的降维模型
注意:由于运行在浏览器内,调用model.save时如果没有对应的上传接口,建议直接下载(downloads://),否则会一直报找不到路径的错误
-
基于mobilenet的特征向量使用自己的模型进行降维
-
一定要归一化处理,不然算出来的近似值都很低
-
自己训练的模型要保存,然后下次识别前要先加载,不然每次识别的结果都不一样,这个坑替你们踩了
后续规划
开发加调试差不多四天吧,上线后效果还不错,比原厂家提供的识别速度快了3倍多,基本达到了windows系统下的识别效果
模型的下发更新
基于此次封装的插件,后续为了应对各种业务场景,需要有一个模型的更新回滚机制,来保证远程就能处理实际问题
模型的增强学习
后续也可以在服务器上将搜集上来的已识别并打好标了的图像进行增强学习,提升准确度的同时也为之后的实时检测打下基础
识别数据的上传和下发
目前数据是存在本机,但后续一定是存服务器上的,这样不仅可以实现数据在各收银机的共用,还能为模型的增强学习提供数据集