对于一些源码可以参考我上一篇博客:学习yolo+Java+opencv简单案例(一)-CSDN博客
这篇文章主要演示的是使用面向对象优雅的实现图像识别:
也有接口演示,包括将Onnx对象放入Bean中程序跑起来就初始化一次(重点)
在文章的最后附上我的代码地址。
目录
一、整体架构
二、pom.xml
三、Java代码
1、model包
2、output包
3、utils包
4、实现类
(1)初始化模型
(2)读取图像
(3)执行模型推理
(4)处理并保存图像
四、运行测试
五、接口改造
1、pom.xml:
2、编写config
3、编写controller
4、编写service和impl
5、接口测试
测试yolov8模型:
测试yolov7模型:
一、整体架构
src 路径下的 model 包中是各个模型的实现类,它们都继承自抽象类 Onnx,由该抽象类负责模型的初始化(加载模型、创建推理会话);output 包中的类实现了 Output 接口,负责封装并返回模型推理的结果;utils 是工具包,提供 argmax、NMS 等通用处理函数。
二、pom.xml
yolo-study:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.bluefoxyu</groupId>
<artifactId>yolo-study</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>pom</packaging>
<modules>
<module>predict-test</module>
<module>CameraDetection</module>
<module>yolo-common</module>
<module>CameraDetectionWarn</module>
<module>PlateDetection</module>
<module>dp</module>
</modules>
<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.16.1</version>
</dependency>
<dependency>
<groupId>org.openpnp</groupId>
<artifactId>opencv</artifactId>
<version>4.7.0-0</version>
</dependency>
</dependencies>
</project>
三、Java代码
1、model包
这里实际只用了yolov7模型
Onnx:
package com.bluefoxyu.model.domain;
import ai.onnxruntime.*;
import com.bluefoxyu.output.Output;
import com.bluefoxyu.utils.ImageUtil;
import org.opencv.core.*;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.io.File;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.text.SimpleDateFormat;
import java.util.*;
/**
 * Base class for ONNX Runtime detection models (YOLOv5 / YOLOv7 / YOLOv8, ...).
 *
 * <p>The constructor loads the model once; {@link #run(Mat)} letterboxes the image, runs
 * inference and hands the raw output to the model-specific {@link #postprocess}.
 * Getters/setters are intentionally omitted; add them as needed.
 *
 * <p>Thread-safety: {@link #preprocess} records the letterbox parameters ({@link #ratio},
 * {@link #dw}, {@link #dh}) in instance fields that {@code postprocess} reads back, so
 * {@link #run(Mat)} is {@code synchronized}. Without that, one instance shared by concurrent
 * callers (e.g. a Spring singleton bean) could map boxes back with another request's scale.
 */
public abstract class Onnx {
    protected OrtEnvironment environment;
    protected OrtSession session;
    /** Class labels; index i is the name of class id i. */
    protected String[] labels;
    /** One random colour per label, used when drawing boxes. */
    protected double[][] colors;
    boolean gpu = false;
    /** Model input shape (N, C, H, W); letterbox uses index 2 as height and 3 as width. */
    long[] input_shape = {1, 3, 640, 640};
    int stride = 32;
    public float confThreshold = 0.45F;
    public OnnxJavaType inputType;
    /** Tensor built by the most recent {@link #preprocess} call (closed again by {@link #run}). */
    OnnxTensor inputTensor;
    public float nmsThreshold = 0.45F;
    /** Scale factor applied by the last {@link #letterbox} call. */
    public double ratio;
    /** Horizontal padding per side added by the last {@link #letterbox} call. */
    public double dw;
    /** Vertical padding per side added by the last {@link #letterbox} call. */
    public double dh;

    /**
     * Loads the model, reads its input element type and assigns a random drawing colour to
     * every label.
     *
     * @param labels     class labels of the model
     * @param model_path path to the .onnx model file
     * @param gpu        whether to run on CUDA device 0
     * @throws OrtException if the ONNX Runtime session cannot be created
     */
    public Onnx(String[] labels, String model_path, boolean gpu) throws OrtException {
        nu.pattern.OpenCV.loadLocally();
        this.labels = labels;
        this.gpu = gpu;
        environment = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        if (gpu) {
            sessionOptions.addCUDA(0);
        }
        session = environment.createSession(model_path, sessionOptions);
        // The input element type (e.g. UINT8 or FLOAT) decides how preprocess builds the tensor.
        Map<String, NodeInfo> inputMetaMap = session.getInputInfo();
        NodeInfo inputMeta = inputMetaMap.get(session.getInputNames().iterator().next());
        this.inputType = ((TensorInfo) inputMeta.getInfo()).type;
        System.out.println(inputMeta.toString());
        colors = new double[labels.length][3];
        Random random = new Random();
        for (int i = 0; i < colors.length; i++) {
            colors[i] = new double[]{random.nextDouble() * 256, random.nextDouble() * 256, random.nextDouble() * 256};
        }
    }

    /**
     * Runs preprocessing, inference and postprocessing on one image.
     *
     * <p>{@code img} is letterboxed and colour-converted in place, so pass a copy if the
     * original is still needed. The input tensor and the native inference result are closed
     * before this method returns, so {@code postprocess} must copy out what it needs (the
     * bundled models read values via {@code getValue()}, which returns Java arrays).
     *
     * @param img image to detect on (modified in place)
     * @return detections produced by {@link #postprocess}
     * @throws OrtException if inference fails
     */
    public synchronized List<Output> run(Mat img) throws OrtException {
        Map<String, OnnxTensor> inputContainer = this.preprocess(img);
        try (OrtSession.Result result = this.session.run(inputContainer)) {
            return this.postprocess(result, img);
        } finally {
            // Input tensors wrap native memory; release them on every call.
            for (OnnxTensor tensor : inputContainer.values()) {
                tensor.close();
            }
        }
    }

    /**
     * Converts the raw model output into detections in original-image coordinates.
     *
     * @param result raw inference result (closed by {@link #run} after this returns)
     * @param img    the preprocessed image
     * @return decoded detections
     * @throws OrtException if reading the output fails
     */
    public abstract List<Output> postprocess(OrtSession.Result result, Mat img) throws OrtException;

    /**
     * Draws each detection's box and label onto {@code img}, then saves the image as
     * {@code ./dp/video/output-<yyyy-MM-dd-HH_mm_ss>.png}. Subclasses may override this.
     *
     * @param outputs detections to draw
     * @param img     image to draw on (modified in place)
     * @return the same {@code img} instance
     */
    public Mat drawprocess(List<Output> outputs, Mat img) {
        for (Output output : outputs) {
            System.err.println(output.toString());
            Point topLeft = new Point(output.getLocation().get(0).get("x"), output.getLocation().get(0).get("y"));
            Point bottomRight = new Point(output.getLocation().get(2).get("x"), output.getLocation().get(2).get("y"));
            Scalar color = new Scalar(colors[output.getClsId()]);
            Imgproc.rectangle(img, topLeft, bottomRight, color, 2);
            // Further text or data (e.g. IoT device readings) could be overlaid here as well.
            Imgproc.putText(img, labels[output.getClsId()], topLeft, Imgproc.FONT_HERSHEY_SIMPLEX, 0.7, color, 2);
        }
        // Directory for the annotated images
        String outputDir = "./dp/video";
        File directory = new File(outputDir);
        if (!directory.exists()) {
            directory.mkdirs(); // also creates any missing parent directories
            System.out.println("没有video目录,创建目录成功");
        }
        String timeStamp = new SimpleDateFormat("yyyy-MM-dd-HH_mm_ss").format(new Date());
        String outputPath = outputDir + "/output-" + timeStamp + ".png";
        System.err.println("----------------------推理成功的图像保存在项目的video目录下:" + outputPath + ",可以打开查看效果!---------------------------------");
        Imgcodecs.imwrite(outputPath, img);
        return img;
    }

    /**
     * Default preprocessing: letterbox, BGR→RGB, HWC→CHW, and wrap as a tensor of the model's
     * input type (UINT8 as-is, otherwise float scaled to [0, 1]). Override if the input shape
     * differs.
     *
     * @param img image (modified in place)
     * @return map from the model's first input name to the input tensor
     * @throws OrtException if the tensor cannot be created
     */
    public Map<String, OnnxTensor> preprocess(Mat img) throws OrtException {
        img = this.letterbox(img);
        Imgproc.cvtColor(img, img, Imgproc.COLOR_BGR2RGB);
        Map<String, OnnxTensor> container = new HashMap<>();
        int elementCount = (int) (input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]);
        if (this.inputType.equals(OnnxJavaType.UINT8)) {
            byte[] whc = new byte[elementCount];
            img.get(0, 0, whc);
            byte[] chw = ImageUtil.whc2cwh(whc);
            ByteBuffer inputBuffer = ByteBuffer.wrap(chw);
            inputTensor = OnnxTensor.createTensor(this.environment, inputBuffer, input_shape, this.inputType);
        } else {
            img.convertTo(img, CvType.CV_32FC1, 1. / 255);
            float[] whc = new float[elementCount];
            img.get(0, 0, whc);
            float[] chw = ImageUtil.whc2cwh(whc);
            FloatBuffer inputBuffer = FloatBuffer.wrap(chw);
            inputTensor = OnnxTensor.createTensor(this.environment, inputBuffer, input_shape);
        }
        container.put(this.session.getInputInfo().keySet().iterator().next(), inputTensor);
        return container;
    }

    /**
     * Resizes {@code im} in place to fit the model input without distortion and pads the rest
     * with grey (114). Stores the scale in {@link #ratio} and the per-side padding in
     * {@link #dw} / {@link #dh} so postprocessing can map boxes back.
     *
     * @param im image to resize (modified in place)
     * @return the same {@code im} instance
     */
    public Mat letterbox(Mat im) {
        int[] shape = {im.rows(), im.cols()};
        // Largest scale at which the image still fits into H x W (input_shape[2] x input_shape[3]).
        double r = Math.min((double) input_shape[2] / shape[0], (double) input_shape[3] / shape[1]);
        Size newUnpad = new Size(Math.round(shape[1] * r), Math.round(shape[0] * r));
        // Width padding comes from the input WIDTH (index 3), height padding from the HEIGHT (index 2).
        double dw = (double) input_shape[3] - newUnpad.width;
        double dh = (double) input_shape[2] - newUnpad.height;
        dw /= 2;
        dh /= 2;
        if (shape[1] != newUnpad.width || shape[0] != newUnpad.height) {
            Imgproc.resize(im, im, newUnpad, 0, 0, Imgproc.INTER_LINEAR);
        }
        // +/-0.1 splits an odd total padding between the two sides.
        int top = (int) Math.round(dh - 0.1), bottom = (int) Math.round(dh + 0.1);
        int left = (int) Math.round(dw - 0.1), right = (int) Math.round(dw + 0.1);
        Core.copyMakeBorder(im, im, top, bottom, left, right, Core.BORDER_CONSTANT, new Scalar(new double[]{114, 114, 114}));
        this.ratio = r;
        this.dh = dh;
        this.dw = dw;
        return im;
    }

    /**
     * Releases the native ONNX Runtime session; the shared {@link OrtEnvironment} stays open.
     * Spring's inferred destroy method for {@code @Bean} definitions invokes a public no-arg
     * {@code close()}, so model beans release their session on context shutdown.
     *
     * @throws OrtException if the session fails to close
     */
    public void close() throws OrtException {
        session.close();
    }
}
PaddleDetection:
package com.bluefoxyu.model;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.bluefoxyu.model.domain.Onnx;
import com.bluefoxyu.output.Output;
import org.opencv.core.Mat;
import java.util.List;
/**
 * PaddlePaddle object-detection model (decoding not implemented yet).
 */
public class PaddleDetection extends Onnx {
    /**
     * Loads the model.
     *
     * @param labels     class labels of the model
     * @param model_path path to the .onnx model file
     * @param gpu        whether to run on the GPU
     * @throws OrtException if the session cannot be created
     */
    public PaddleDetection(String[] labels, String model_path, boolean gpu) throws OrtException {
        super(labels, model_path, gpu);
    }

    /**
     * Output decoding is not implemented yet. Returns an empty, immutable list instead of
     * {@code null} so callers such as {@code drawprocess} can iterate the result safely.
     */
    @Override
    public List<Output> postprocess(OrtSession.Result result, Mat img) throws OrtException {
        // TODO: decode the PaddleDetection output format.
        return List.of();
    }
}
YoloV5:
package com.bluefoxyu.model;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.bluefoxyu.output.DetectionOutput;
import com.bluefoxyu.utils.ImageUtil;
import com.bluefoxyu.model.domain.Onnx;
import com.bluefoxyu.output.Output;
import org.opencv.core.Mat;
import java.util.*;
/**
 * YOLOv5 detector: decodes raw anchor rows, applies per-class NMS and maps boxes back to the
 * original image.
 */
public class YoloV5 extends Onnx {
    /**
     * Loads the model.
     *
     * @param labels     class labels of the model
     * @param model_path path to the .onnx model file
     * @param gpu        whether to run on the GPU
     * @throws OrtException if the session cannot be created
     */
    public YoloV5(String[] labels, String model_path, boolean gpu) throws OrtException {
        super(labels, model_path, gpu);
    }

    /**
     * Decodes the YOLOv5 output. Rows are read as (cx, cy, w, h, objectness, class scores...),
     * as implied by the indexing below; TODO confirm against the exported model.
     *
     * <p>The final confidence is objectness multiplied by the best class score, as in the
     * reference YOLOv5 post-processing. Boxes below {@link #confThreshold} are dropped, the rest
     * are NMS-filtered per class and mapped back through the letterbox transform.
     */
    @Override
    public List<Output> postprocess(OrtSession.Result result, Mat img) throws OrtException {
        float[][] outputData = ((float[][][]) result.get(0).getValue())[0];
        Map<Integer, List<float[]>> class2Bbox = new HashMap<>();
        for (float[] bbox : outputData) {
            float objectness = bbox[4];
            // If class scores are <= 1 (presumably sigmoid outputs), a row whose objectness is
            // already below the threshold cannot pass, so skip it before copying the scores.
            if (objectness < confThreshold) continue;
            float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 5, bbox.length);
            int label = ImageUtil.argmax(conditionalProbabilities);
            float score = objectness * conditionalProbabilities[label];
            if (score < confThreshold) continue;
            bbox[4] = score; // NMS ranks boxes by index 4
            ImageUtil.xywh2xyxy(bbox);
            if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) continue;
            class2Bbox.computeIfAbsent(label, k -> new ArrayList<>()).add(bbox);
        }
        List<Output> outputList = new ArrayList<>();
        for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {
            List<float[]> bboxes = ImageUtil.nonMaxSuppression(entry.getValue(), this.nmsThreshold);
            for (float[] x : bboxes) {
                // Undo the letterbox padding and scaling from preprocessing.
                double x0 = (x[0] - this.dw) / this.ratio;
                double y0 = (x[1] - this.dh) / this.ratio;
                double x1 = (x[2] - this.dw) / this.ratio;
                double y1 = (x[3] - this.dh) / this.ratio;
                Output output = new DetectionOutput(1, (int) x0, (int) y0, (int) x1, (int) y1, entry.getKey(), x[4], labels[entry.getKey()]);
                outputList.add(output);
            }
        }
        return outputList;
    }
}
YoloV7:
package com.bluefoxyu.model;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.bluefoxyu.output.DetectionOutput;
import com.bluefoxyu.model.domain.Onnx;
import com.bluefoxyu.output.Output;
import org.opencv.core.Mat;
import java.util.ArrayList;
import java.util.List;
/**
 * YOLOv7 detector for an end-to-end export whose output already contains final boxes.
 */
public class YoloV7 extends Onnx {
    /**
     * Loads the model.
     *
     * @param labels     class labels of the model
     * @param model_path path to the .onnx model file
     * @param gpu        whether to run on the GPU
     * @throws OrtException if the session cannot be created
     */
    public YoloV7(String[] labels, String model_path, boolean gpu) throws OrtException {
        super(labels, model_path, gpu);
    }

    /**
     * Decodes output rows read as (batch_id, x0, y0, x1, y1, class_id, score), as implied by the
     * indexing below; TODO confirm against the exported model.
     *
     * <p>Rows scoring below {@link #confThreshold} are dropped, so this model honours the same
     * threshold as the YOLOv5/YOLOv8 decoders. Boxes are mapped back through the letterbox
     * transform.
     */
    @Override
    public List<Output> postprocess(OrtSession.Result result, Mat img) throws OrtException {
        float[][] outputData = (float[][]) result.get(0).getValue();
        List<Output> outputList = new ArrayList<>();
        for (float[] x : outputData) {
            float score = x[6];
            if (score < confThreshold) continue;
            // Undo the letterbox padding and scaling from preprocessing.
            double x0 = (x[1] - this.dw) / this.ratio;
            double y0 = (x[2] - this.dh) / this.ratio;
            double x1 = (x[3] - this.dw) / this.ratio;
            double y1 = (x[4] - this.dh) / this.ratio;
            int clsId = (int) x[5];
            Output output = new DetectionOutput((int) x[0], (int) x0, (int) y0, (int) x1, (int) y1, clsId, score, labels[clsId]);
            outputList.add(output);
        }
        return outputList;
    }
}
YoloV8:
package com.bluefoxyu.model;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.bluefoxyu.output.DetectionOutput;
import com.bluefoxyu.utils.ImageUtil;
import com.bluefoxyu.model.domain.Onnx;
import com.bluefoxyu.output.Output;
import org.opencv.core.Mat;
import java.util.*;
/**
 * YOLOv8 detector: transposes the raw output, picks the best class per candidate, applies
 * per-class NMS and maps boxes back to the original image.
 */
public class YoloV8 extends Onnx {
    /**
     * Loads the model.
     *
     * @param labels     class labels of the model
     * @param model_path path to the .onnx model file
     * @param gpu        whether to run on the GPU
     * @throws OrtException if the session cannot be created
     */
    public YoloV8(String[] labels, String model_path, boolean gpu) throws OrtException {
        super(labels, model_path, gpu);
    }

    /**
     * Decodes the YOLOv8 output. The first batch item is transposed so each row is one candidate,
     * read as (cx, cy, w, h, class scores...) per the indexing below; TODO confirm against the
     * exported model.
     */
    @Override
    public List<Output> postprocess(OrtSession.Result result, Mat img) throws OrtException {
        float[][] outputData = ((float[][][]) result.get(0).getValue())[0];
        outputData = ImageUtil.transposeMatrix(outputData);
        Map<Integer, List<float[]>> class2Bbox = new HashMap<>();
        for (float[] bbox : outputData) {
            // Class scores are the rest of THIS row (index 4 .. row length), not the candidate count.
            float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 4, bbox.length);
            int label = ImageUtil.argmax(conditionalProbabilities);
            float conf = conditionalProbabilities[label];
            if (conf < confThreshold) continue;
            bbox[4] = conf; // NMS ranks boxes by index 4
            ImageUtil.xywh2xyxy(bbox);
            if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) continue;
            class2Bbox.computeIfAbsent(label, k -> new ArrayList<>()).add(bbox);
        }
        List<Output> outputList = new ArrayList<>();
        for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {
            List<float[]> bboxes = ImageUtil.nonMaxSuppression(entry.getValue(), this.nmsThreshold);
            for (float[] x : bboxes) {
                // Undo the letterbox padding and scaling from preprocessing.
                double x0 = (x[0] - this.dw) / this.ratio;
                double y0 = (x[1] - this.dh) / this.ratio;
                double x1 = (x[2] - this.dw) / this.ratio;
                double y1 = (x[3] - this.dh) / this.ratio;
                Output output = new DetectionOutput(1, (int) x0, (int) y0, (int) x1, (int) y1, entry.getKey(), x[4], labels[entry.getKey()]);
                outputList.add(output);
            }
        }
        return outputList;
    }
}
2、output包
DetectionOutput:
package com.bluefoxyu.output;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * One detection / classification result: batch index, class id and name, confidence score,
 * and the four box corners in the order (x0,y0), (x1,y0), (x1,y1), (x0,y1).
 */
public class DetectionOutput implements Output {
    Float score;
    String name;
    Integer batchId;
    private Integer clsId;
    private List<Map<String, Integer>> location;

    public DetectionOutput(Integer batchId, Integer x0, Integer y0, Integer x1, Integer y1, Integer clsId, Float score, String name) {
        this.batchId = batchId;
        this.score = score;
        this.name = name;
        this.clsId = clsId;
        List<Map<String, Integer>> corners = new ArrayList<>();
        corners.add(corner(x0, y0));
        corners.add(corner(x1, y0));
        corners.add(corner(x1, y1));
        corners.add(corner(x0, y1));
        this.location = corners;
    }

    /** Builds a mutable {"x": x, "y": y} point. */
    private static Map<String, Integer> corner(Integer x, Integer y) {
        Map<String, Integer> point = new HashMap<>();
        point.put("x", x);
        point.put("y", y);
        return point;
    }

    public Float getScore() {
        return score;
    }

    public void setScore(Float score) {
        this.score = score;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public Integer getBatchId() {
        return batchId;
    }

    public void setBatchId(Integer batchId) {
        this.batchId = batchId;
    }

    public Integer getClsId() {
        return clsId;
    }

    public void setClsId(Integer clsId) {
        this.clsId = clsId;
    }

    public List<Map<String, Integer>> getLocation() {
        return location;
    }

    public void setLocation(List<Map<String, Integer>> location) {
        this.location = location;
    }

    /** Renders the name and the first four corners of {@code location}. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("DetectionOutput { name: ")
                .append(getName())
                .append(", location : [ ");
        for (int i = 0; i < 4; i++) {
            Map<String, Integer> point = location.get(i);
            if (i > 0) {
                sb.append(", ");
            }
            sb.append("{ x:").append(point.get("x")).append(" , y:").append(point.get("y")).append('}');
        }
        return sb.append("] }").toString();
    }
}
LicenseOutput:
package com.bluefoxyu.output;
/**
 * License-plate recognition output: a {@link DetectionOutput} plus the plate colour.
 */
public class LicenseOutput extends DetectionOutput {
    /** Plate colour; not set by the constructor, assign it with {@link #setColor(String)}. */
    private String color;

    public LicenseOutput(Integer batchId, Integer x0, Integer y0, Integer x1, Integer y1, Integer clsId, Float score, String name) {
        super(batchId, x0, y0, x1, y1, clsId, score, name);
    }

    /** @return the plate colour, or {@code null} if it was never set */
    public String getColor() {
        return color;
    }

    /** @param color the plate colour */
    public void setColor(String color) {
        this.color = color;
    }
}
Output:
package com.bluefoxyu.output;
import java.util.List;
import java.util.Map;
/**
 * Common result type returned by every model's postprocessing.
 */
public interface Output {
    /** @return box corners as {"x", "y"} maps; index 0 and 2 are used as opposite corners when drawing */
    List<Map<String, Integer>> getLocation();

    /** @return the class label */
    String getName();

    /** @return the class id (index into the model's label array) */
    Integer getClsId();
}
3、utils包
ImageUtil:
package com.bluefoxyu.utils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Array helpers for YOLO pre/post-processing: argmax, interleaved-to-planar channel
 * reordering, box-format conversion, IoU, greedy NMS and matrix transpose.
 */
public class ImageUtil {

    /**
     * Returns the index of the largest element. On ties the LAST maximal index wins.
     * Returns -1 for an empty array.
     */
    public static int argmax(float[] a) {
        float re = -Float.MAX_VALUE;
        int arg = -1;
        for (int i = 0; i < a.length; i++) {
            if (a[i] >= re) {
                re = a[i];
                arg = i;
            }
        }
        return arg;
    }

    /**
     * Copies 3-channel interleaved data ({@code c0 c1 c2 c0 c1 c2 ...}) from {@code src} into
     * {@code dst} as three consecutive planes, starting at {@code dst[start]}.
     */
    public static void whc2cwh(float[] src, float[] dst, int start) {
        int j = start;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                dst[j] = src[i];
                j++;
            }
        }
    }

    /** Converts (cx, cy, w, h) in {@code bbox[0..3]} to (x0, y0, x1, y1) in place. */
    public static void xywh2xyxy(float[] bbox) {
        float x = bbox[0];
        float y = bbox[1];
        float w = bbox[2];
        float h = bbox[3];
        bbox[0] = x - w * 0.5f;
        bbox[1] = y - h * 0.5f;
        bbox[2] = x + w * 0.5f;
        bbox[3] = y + h * 0.5f;
    }

    /**
     * Greedy non-maximum suppression on (x0, y0, x1, y1, score, ...) boxes.
     *
     * <p>Repeatedly keeps the highest-scoring box and discards every remaining box whose IoU with
     * it is not below {@code iouThreshold}. The input list is neither modified nor required to be
     * mutable.
     *
     * @return kept boxes, highest score first
     */
    public static List<float[]> nonMaxSuppression(List<float[]> bboxes, float iouThreshold) {
        List<float[]> remaining = new ArrayList<>(bboxes);
        // Ascending by score, so the best remaining box is always at the end.
        remaining.sort(Comparator.comparingDouble((float[] a) -> a[4]));
        List<float[]> bestBboxes = new ArrayList<>();
        while (!remaining.isEmpty()) {
            float[] bestBbox = remaining.remove(remaining.size() - 1);
            bestBboxes.add(bestBbox);
            remaining.removeIf(a -> !(computeIOU(a, bestBbox) < iouThreshold));
        }
        return bestBboxes;
    }

    /**
     * Intersection-over-union of two (x0, y0, x1, y1) boxes. The union is clamped to a small
     * epsilon so degenerate (zero-area) boxes yield 0 instead of NaN.
     */
    public static float computeIOU(float[] box1, float[] box2) {
        float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
        float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);
        float left = Math.max(box1[0], box2[0]);
        float top = Math.max(box1[1], box2[1]);
        float right = Math.min(box1[2], box2[2]);
        float bottom = Math.min(box1[3], box2[3]);
        float interArea = Math.max(right - left, 0) * Math.max(bottom - top, 0);
        float unionArea = area1 + area2 - interArea;
        return interArea / Math.max(unionArea, 1e-8f);
    }

    /** Returns the transpose of a rectangular matrix. */
    public static float[][] transposeMatrix(float[][] m) {
        float[][] temp = new float[m[0].length][m.length];
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[0].length; j++) {
                temp[j][i] = m[i][j];
            }
        }
        return temp;
    }

    /**
     * Maps an (x0, y0, x1, y1) box from letterboxed coordinates back to the original image and
     * clamps it to [0, orgW - 1] x [0, orgH - 1], in place. Static because it uses no instance
     * state.
     */
    public static void scaleCoords(float[] bbox, float orgW, float orgH, float padW, float padH, float gain) {
        bbox[0] = Math.max(0, Math.min(orgW - 1, (bbox[0] - padW) / gain));
        bbox[1] = Math.max(0, Math.min(orgH - 1, (bbox[1] - padH) / gain));
        bbox[2] = Math.max(0, Math.min(orgW - 1, (bbox[2] - padW) / gain));
        bbox[3] = Math.max(0, Math.min(orgH - 1, (bbox[3] - padH) / gain));
    }

    /** Returns 3-channel interleaved {@code src} reordered into three consecutive planes. */
    public static float[] whc2cwh(float[] src) {
        float[] chw = new float[src.length];
        int j = 0;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                chw[j] = src[i];
                j++;
            }
        }
        return chw;
    }

    /** Byte variant of {@link #whc2cwh(float[])}. */
    public static byte[] whc2cwh(byte[] src) {
        byte[] chw = new byte[src.length];
        int j = 0;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                chw[j] = src[i];
                j++;
            }
        }
        return chw;
    }
}
4、实现类
package com.bluefoxyu;
import ai.onnxruntime.OrtException;
import com.bluefoxyu.model.domain.Onnx;
import com.bluefoxyu.output.Output;
import com.bluefoxyu.model.YoloV7;
import org.opencv.core.Mat;
import org.opencv.imgcodecs.Imgcodecs;
import java.util.List;
/**
 * Command-line demo: detects objects in {@code test_img} with YOLOv7-tiny and saves an
 * annotated copy of the image.
 */
public class Main {
    static String model_path = "./dp/src/main/resources/model/yolov7-tiny.onnx";
    static String test_img = "./dp/images/some_people.png";
    /** Class labels; index i is the name of class id i. */
    static String[] names = {
            "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
            "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter",
            "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear",
            "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase",
            "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
            "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
            "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
            "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
            "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
            "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
            "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
            "teddy bear", "hair drier", "toothbrush"};

    public static void main(String[] args) throws OrtException {
        // 1. Load the model. Create it once and reuse it (e.g. as a Spring @Bean or at startup);
        //    do not create a new instance for every detection.
        Onnx onnx = new YoloV7(names, model_path, false);
        // Onnx onnx = new YoloV5(names, model_path, false);

        // 2. Read the image (for images received via an API, decode them with Imgcodecs.imdecode()).
        Mat img = Imgcodecs.imread(test_img);
        // imread does not throw for a missing/unreadable file; it returns an empty Mat.
        if (img.empty()) {
            throw new IllegalArgumentException("Could not read image: " + test_img);
        }

        // 3. Run inference. run() modifies its argument in place, so pass a copy and keep img
        //    untouched for drawing. The outputs can be returned to a client or printed.
        List<Output> outputs = onnx.run(img.clone());

        // 4. Optional: draw the boxes and save the annotated image locally.
        onnx.drawprocess(outputs, img);
    }
}
实现的思路:
(1)初始化模型
// 1. Initialize the model
// Create it once globally; never create a new instance for every use. Use @Bean, or initialize it once when the Spring app starts
Onnx onnx = new YoloV7(names,model_path,false);
//Onnx onnx = new YoloV5(names,model_path,false);
这行代码实例化了一个 `YoloV7` 类对象,`YoloV7` 是 `Onnx` 类的子类。这个对象负责加载 YOLOv7 模型(`yolov7-tiny.onnx`),并进行目标检测:
1、`model_path`: 模型的路径。
2、`names`: 包含模型可识别的分类标签的数组。
3、`false`: 表示是否使用 GPU 加速,`false` 表示不使用。
注意:构造函数的实际参数顺序为 `(names, model_path, gpu)`。
在这里只要修改一些模型的路径和使用的模型对象,就能切换到别的模型
(2)读取图像
使用 OpenCV 的 `Imgcodecs.imread` 方法从指定路径读取图像,并将其存储在 `Mat` 对象中。`Mat` 是 OpenCV 中用于表示图像的基本数据结构。
`test_img`: 图像文件路径。
(3)执行模型推理
// 3. Run model inference
// The results are ready after this step: return them to the frontend via an API, or loop over them and print them
List<Output> outputs = onnx.run(img.clone());
预处理方法:
1、`letterbox`: 对图像进行等比缩放和填充,使其符合模型输入的尺寸。
2、`cvtColor`: 将图像从 BGR 转换为 RGB 格式。
3、`inputTensor`: 创建 ONNX 的输入张量(根据张量类型创建相应的 `ByteBuffer` 或 `FloatBuffer`)。该方法将图像转换为模型所需的输入格式,并将其存储在 `OnnxTensor` 对象中,准备进行推理。
推理和后处理方法:
1、`run`: 依次执行预处理、`session.run` 推理和后处理,返回检测结果列表。
2、`postprocess`: 由子类实现,负责将模型的原始输出转换为可理解的 `Output` 对象列表。
(4)处理并保存图像
// 4. Process and save the image
// Optional: call this to view the annotated image locally
onnx.drawprocess(outputs,img);
1、`drawprocess`: 在图像上绘制检测框,并将处理后的图像保存到指定路径。
2、自动目录创建: 如果输出目录不存在,代码会自动创建。
3、动态生成文件名: 使用当前时间戳(精确到秒)生成文件名,避免覆盖之前的结果。
四、运行测试
测试前图片:
处理后:
五、接口改造
改造后大致模样:
1、pom.xml:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>com.bluefoxyu</groupId>
<artifactId>yolo-study</artifactId>
<version>1.0-SNAPSHOT</version>
</parent>
<artifactId>dp</artifactId>
<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<!-- https://mvnrepository.com/artifact/org.springframework.boot/spring-boot-starter-web -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
<version>3.2.4</version>
</dependency>
</dependencies>
</project>
2、编写config
/**
 * Registers each YOLO model as a Spring bean, so its ONNX session is created once when the
 * application starts and reused by every request.
 */
@Configuration
public class OnnxConfig {
    static String yolov8_model_path = "./dp/src/main/resources/model/yolov8s.onnx";
    static String yolov7_model_path = "./dp/src/main/resources/model/yolov7-tiny.onnx";
    /** Class labels shared by both models; index i must match class id i of the model output. */
    static String[] names = {
            "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
            "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter",
            "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear",
            "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase",
            "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
            "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
            "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
            "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
            "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
            "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
            "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
            "teddy bear", "hair drier", "toothbrush"};

    /** YOLOv8 model bean (CPU); change the path/class here to load a different model. */
    @Bean(name = "YoloV8Onnx")
    public Onnx YoloV8Onnx() throws OrtException {
        return new YoloV8(names, yolov8_model_path, false);
    }

    /** YOLOv7 model bean (CPU); change the path/class here to load a different model. */
    @Bean(name = "YoloV7Onnx")
    public Onnx YoloV7Onnx() throws OrtException {
        return new YoloV7(names, yolov7_model_path, false);
    }
}
3、编写controller
/**
 * REST endpoints that run YOLOv8 / YOLOv7 detection on a fixed test image and return the
 * detections as JSON.
 */
@Slf4j
@RestController
@RequestMapping("/api")
public class DetectController {
    /** Image used for every request; could be supplied by the client instead. */
    static String test_img = "./dp/images/some_people.png";

    @Resource
    private DetectService detectService;

    /** Runs the YOLOv8 model on {@code test_img}. */
    @PostMapping("/yoloV8/detect")
    public List<Output> yoloV8Detection() throws OrtException {
        log.info("yoloV8检测开始");
        List<Output> detections = detectService.yoloV8Detection(test_img);
        return detections;
    }

    /** Runs the YOLOv7 model on {@code test_img}. */
    @PostMapping("/yoloV7/detect")
    public List<Output> yoloV7Detection() throws OrtException {
        log.info("yoloV7检测开始");
        List<Output> detections = detectService.yoloV7Detection(test_img);
        return detections;
    }
}
4、编写service和impl
/**
 * Object detection on an image file with the YOLOv8 or YOLOv7 model.
 */
public interface DetectService {

    /**
     * Detects objects with the YOLOv8 model.
     *
     * @param test_img path to the image file
     * @return the detections
     * @throws OrtException if inference fails
     */
    List<Output> yoloV8Detection(String test_img) throws OrtException;

    /**
     * Detects objects with the YOLOv7 model.
     *
     * @param test_img path to the image file
     * @return the detections
     * @throws OrtException if inference fails
     */
    List<Output> yoloV7Detection(String test_img) throws OrtException;
}
/**
 * Runs detection with the model beans defined in {@code OnnxConfig}. The models are injected
 * rather than created per request, because loading a model is expensive.
 */
@Service
public class DetectServiceImpl implements DetectService {
    // Injects the bean named "YoloV8Onnx" produced by OnnxConfig's @Bean method.
    @Resource
    @Qualifier("YoloV8Onnx")
    private Onnx yoloV8Onnx;

    // Injects the bean named "YoloV7Onnx" produced by OnnxConfig's @Bean method.
    @Resource
    @Qualifier("YoloV7Onnx")
    private Onnx yoloV7Onnx;

    @Override
    public List<Output> yoloV8Detection(String test_img) throws OrtException {
        return detect(yoloV8Onnx, test_img);
    }

    @Override
    public List<Output> yoloV7Detection(String test_img) throws OrtException {
        return detect(yoloV7Onnx, test_img);
    }

    /**
     * Reads the image, runs inference on a copy, saves an annotated image, and always releases
     * the native OpenCV buffers so repeated requests do not accumulate native memory.
     *
     * @throws IllegalArgumentException if the image cannot be read
     */
    private List<Output> detect(Onnx model, String imagePath) throws OrtException {
        // (Images received via an API could be decoded with Imgcodecs.imdecode() instead.)
        Mat img = Imgcodecs.imread(imagePath);
        try {
            // imread returns an empty Mat instead of throwing when the file is missing/unreadable.
            if (img.empty()) {
                throw new IllegalArgumentException("Could not read image: " + imagePath);
            }
            // run() modifies its argument in place, so infer on a copy and draw on the original.
            Mat input = img.clone();
            List<Output> outputs;
            try {
                outputs = model.run(input);
            } finally {
                input.release();
            }
            // Optional: saves an annotated copy for local inspection.
            model.drawprocess(outputs, img);
            return outputs;
        } finally {
            img.release();
        }
    }
}
5、接口测试
测试yolov8模型:
postman:
idea:
yolov8 模型测试成功!
测试yolov7模型:
postman:
idea:
Github地址:GitHub - bluefoxyu/yolo-study: 学习yolo+java案例第一次提交
至此,演示结束。