前言:
Tensorflow.js 官方提供了很多常用模型库,涵盖了平时开发中大部分场景的模型。例如,前面提到的图片识别,除此之外还有人体姿态识别,目标物体识别,语音文字等识别。其中一些可能是 Python 转换而来,但都是开发人员用海量数据或资源训练的,个人觉得准确度能满足大部分功能开发要求。这里要介绍的是目标物体识别模型 ——CooSSD。
目标检测在机器视觉中已经很常见了,就是模型可以对图片或者视频中的物体进行识别,并预测其最大概率的名称和展示概率值。以下就先以 Github 上 Coo-SSD 图片目标检测为例,最后再弄一个视频的目标实时识别。
demo 运行:
tensorflow.js 提供的例子是通过 yarn,由于我本地环境原因,就以 npm 和 parcel 运行其效果。先本地创建项目文件夹,然后再分别创建 index.html, script.js, package.json 和添加几张图片。
1. 依赖包安装
(1). package.json 配置,安装 tfjs-backend-cpu, tfjs-backend-webgl 和模型
{
"name": "tfjs-coco-ssd-demo",
"version": "1.0.0",
"description": "",
"main": "index.js",
"dependencies": {
"@tensorflow-models/coco-ssd": "^2.2.2",
"@tensorflow/tfjs-backend-cpu": "^3.3.0",
"@tensorflow/tfjs-backend-webgl": "^3.3.0",
"@tensorflow/tfjs-converter": "^3.3.0",
"@tensorflow/tfjs-core": "^3.3.0",
"stats.js": "^0.17.0"
},
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"author": "",
"license": "ISC",
"browserslist": [
"last 1 Chrome version"
]
}
(2). 命令切换到项目目录,运行 npm install
2. 代码:
(1). index.html
<h1>TensorFlow.js Object Detection</h1><selectid='base_model'>
<optionvalue="lite_mobilenet_v2">SSD Lite Mobilenet V2</option>
<optionvalue="mobilenet_v1">SSD Mobilenet v1</option>
<optionvalue="mobilenet_v2">SSD Mobilenet v2</option></select><buttontype="button"id="run">Run</button><buttontype="button"id="toggle">Toggle Image</button><div><imgid="image" /><canvasid="canvas"width="600"height="399"></canvas></div><scriptsrc="script.js"></script>
(2). script.js
import'@tensorflow/tfjs-backend-cpu';
import'@tensorflow/tfjs-backend-webgl';
import * as cocoSsd from'@tensorflow-models/coco-ssd';
import imageURL from'./image3.jpg';
import image2URL from'./image5.jpg';
let modelPromise;
window.onload = () => modelPromise = cocoSsd.load();
const button = document.getElementById('toggle');
button.onclick = () => {
image.src = image.src.endsWith(imageURL) ? image2URL : imageURL;
};
const select = document.getElementById('base_model');
select.onchange = async (event) => {
const model = await modelPromise;
model.dispose();
modelPromise = cocoSsd.load(
{base: event.srcElement.options[event.srcElement.selectedIndex].value});
};
const image = document.getElementById('image');
image.src = imageURL;
const runButton = document.getElementById('run');
runButton.onclick = async () => {
const model = await modelPromise;
console.log('model loaded');
console.time('predict1');
const result = await model.detect(image);
console.timeEnd('predict1');
const c = document.getElementById('canvas');
const context = c.getContext('2d');
context.drawImage(image, 0, 0);
context.font = '10px Arial';
console.log(result);
console.log('number of detections: ', result.length);
for (let i = 0; i < result.length; i++) {
context.beginPath();
context.rect(...result[i].bbox);
context.lineWidth = 1;
context.strokeStyle = 'green';
context.fillStyle = 'green';
context.stroke();
context.fillText(
result[i].score.toFixed(3) + ' ' + result[i].class, result[i].bbox[0],
result[i].bbox[1] > 10 ? result[i].bbox[1] - 5 : 10);
}
};
(3). 切换到项目目录,运行 parcel index.html
3. 运行效果
检测视频目标:
经过上面 demo 的图片检测发现,用于对某资源 (图片,视频) 进行检测的函数是 detect ()。查看该函数所处 Coco-SSD 文件发现,detect 函数接收三个参数,第一个参数可以是 tensorflow 张量,也可以分别是 DOM 里的图片,视频,画布等 HTML 元素,第二第三个参数分别用于过滤返回结果的最大识别目标数和最小概率目标,而返回自然就是一个 box, 按概率值降序排列。
1. 实现流程:
(1). 给视频标签添加播放监听
(2). 页面渲染完成加载 Coco-SSD 模型
(3). 模型加载成功轮询识别视频 (video 标签)
(4). 监听到视频播放停止关闭轮询检测
2. 编码:
(1). html 部分
<style>#big-box {
position: relative;
}
#img-box {
position: absolute;
top: 0px;
left: 0px;
}
#img-boxdiv {
position: absolute;
/*border: 2px solid #f00;*/pointer-events: none;
}
#img-boxdiv.className {
position: absolute;
top: 0;
/* background: #f00; */color: #fff;
}
#myPlayer {
max-width: 600px;
width: 100%;
}
</style><divid="showBox">等待模型加载...</div><br><divid="big-box"><videoid="myPlayer"muted="true"autoplaysrc="persons.mp4"controls=""playsinline=""webkit-playsinline=""></video><divid="img-box"></div></div><scriptsrc="persons.js"></script>
(2). js 部分
import'@tensorflow/tfjs-backend-cpu';
import'@tensorflow/tfjs-backend-webgl';
import * as cocoSsd from'@tensorflow-models/coco-ssd';
var myModel = null;
var V = null;
var requestAnimationFrameIndex = null;
var myPlayer = document.getElementById("myPlayer");
var videoHeight = 0;
var videoWidth = 0;
var clientHeight = 0;
var clientWidth = 0;
var modelLoad = false;
var videoLoad = false;
window.onload = function () {
myPlayer.addEventListener("canplay", function () {
videoHeight = myPlayer.videoHeight;
videoWidth = myPlayer.videoWidth;
clientHeight = myPlayer.clientHeight;
clientWidth = myPlayer.clientWidth;
V = this;
videoLoad = true;
})
loadModel();
}
functionloadModel() {
if (modelLoad) {
return;
}
cocoSsd.load().then(model => {
var showBox = document.getElementById("showBox");
showBox.innerHTML = "载入成功";
myModel = model;
detectImage();
modelLoad = true;
});
}
functiondetectImage() {
var showBox = document.getElementById("showBox");
// 分类名var classList = [];
// 分类颜色框var classColorMap = ["red", "green", "blue", "white"];
// 颜色角标var colorCursor = 0;
showBox.innerHTML = "检测中...";
if (videoLoad) {
myModel.detect(V).then(predictions => {
showBox.innerHTML = "检测结束";
const $imgbox = document.getElementById('img-box');
$imgbox.innerHTML = ""
predictions.forEach(box => {
if (classList.indexOf(box.class) != -1) {
classList.push(box.class);
}
console.log(box);
var borderColor = classColorMap[colorCursor%4];
// console.log(colorCursor);// console.log(borderColor);const $div = document.createElement('div')
//$div.className = 'rect';
$div.style.border = "2px solid "+borderColor;
var heightScale = (clientHeight / videoHeight);
var widthScale = (clientWidth / videoWidth)
var transformTop = box.bbox[1] * heightScale;
var transformLeft = box.bbox[0] * widthScale;
var transformWidth = box.bbox[2] * widthScale;
var transformHeight = box.bbox[3] * heightScale;
var score = box.score.toFixed(3);
$div.style.top = transformTop + 'px'
$div.style.left = transformLeft + 'px'
$div.style.width = transformWidth + 'px'
$div.style.height = transformHeight + 'px'
$div.innerHTML = `<span class='className'>${box.class}${score}</span>`
$imgbox.appendChild($div)
colorCursor++;
})
setTimeout(function () {
detectImage();
}, 10);
});
}
}
3. 演示效果