目录
1、摘要
2、实现过程
2.1、依赖
2.2、imread
2.3、contiguous函数
2.3.1、转化示例
2.3.3、核心代码
2.4、Flatten拉直
2.5、最终结果
3、完整代码
1、摘要
在上一篇文章中 使用java加载、调用onnx模型_onnx java-CSDN博客
发现使用Java加载调用模型的分类结果与使用Python加载调用,返回的多分类结果不一致,并且存在较大的误差。
经过与AI工程师通过后,发现。在进行图片加载调用模型的过程中,存在对图片的预处理操作。
主要由 transforms 依次经过 :
- Resize 缩放
- CenterCrop 中心剪裁
- HWC-->CHW,BGR-->RGB 通道、维度变换
- continuousMat 维度变换[2, 1, 0]
- ToTensor + Normalize (0-1)区间,正则化
continuousMat函数
NumPy提供了 ascontiguousarray方法。该方法的作用是将输入数组转换为一个连续的内存块中的数组。如果输入数组已经是连续的,则该方法会返回输入数组的视图(即不复制数据);如果输入数组不是连续的,则该方法会返回输入数组的副本,并确保副本是连续的。这样,我们就可以确保在进行后续计算或调用其他库函数时,使用的是连续数组,从而避免潜在的性能问题或错误。
在python的 ToTensor 调用中发现
2、实现过程
2.1、依赖
其中opencv-320.jar 由网上手动下载,并手动导入
<!-- ONNX Runtime -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.10.0</version>
</dependency>
<!-- 加载lib目录下的opencv包 -->
<dependency>
<groupId>org.opencv</groupId>
<artifactId>opencv</artifactId>
<version>4.8.0</version>
<scope>system</scope>
<!--通过路径加载OpenCV480的jar包-->
<systemPath>${basedir}/libs/opencv-320.jar</systemPath>
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.5.8</version>
</dependency>
2.2、imread
借由 Imgcodecs.imread 函数读取原始图片,并形成Mat对象
注:在基础的color255值的比对中,发现,同一个点位的像素值,在java 与python中存在细微的差距(1-3个值),但大部分都能一致。
2.3、contiguous函数
2.3.1、转化示例
上述转化过程表示
维度选择 [2, 1, 0]
1)在原始的【64,64,3】(x, y, z)的数组中
2)先将【0,0】、【0,1】....【0,64】每个中的z, y, x 分别放在 【0,0,0】、【1,0,0】、【2,0,0】变形存储
3)再按2的过程将【1,0】、【1,1】......【1,64】
4)转换结果表示为 【3,64,64】
HWC 转 CHW ,BGR转RGB
2.3.3、核心代码
由于是三通道 ,所以下方的convert固定3,
public static double[][][] asContiguousArray(Mat mat, int[] dims) {
int row = mat.rows();
int col = mat.cols();
double[][][] convert = new double[3][row][col];
for (int r = 0; r < row; r++) {
double[] d0 = new double[col];
double[] d1 = new double[col];
double[] d2 = new double[col];
for (int c = 0; c < col; c++) {
d0[c] = mat.get(r,c)[dims[0]];
d1[c] = mat.get(r,c)[dims[1]];
d2[c] = mat.get(r,c)[dims[2]];
}
convert[0][r] = d0;
convert[1][r] = d1;
convert[2][r] = d2;
}
return convert;
}
2.4、Flatten拉直
在python中,可以采用多维的Tensor对象,输入到模型中处理,
但实际的Java过程,OnnxTensor需要 Buffer对象,最终导致,只接受一维的Float数组。
2.5、最终结果
1)python:
原始输出:tensor([[-1.19724, 1.18705,-0.73468, 1.9731日, 0.06351, 0.02216,-1.07441]])
SoftMax: tensor([[0.02208,0.23955,0.03506,0.52574,0.07788,0.07473,0.02496]])
2)java
原始输出:[[-1.1788176,1.2926369,-0.72909486,1.9525943,0.07915904,-0.0071295537,-1.0969812]]
SoftMax:[0.02215092526,0.26225931,0.0347299,0.507395,0.077933,0.071490,0.0240]
两组数据在对比后,可以看出存在细微的结果差异,但总体保持一致。
3、完整代码
1)session.run(Collections.singletonMap("images", inputTensor));
其中的imagesv传入模型的入参名,依据实际生成模型的入参决定
2)width =64
由于模型需要【3,64,64】尺寸的图片,所以并未对图片高度做定义,个人需要依据自身模型的图片要求做修改。
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import cn.hutool.json.JSONUtil;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Rect;
import org.opencv.core.Size;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.math.BigDecimal;
import java.nio.FloatBuffer;
import java.util.Collections;
/***
* @author xuancg
* @date 2024/8/15
*/
public class OnnxTest2 {
static { System.loadLibrary(Core.NATIVE_LIBRARY_NAME); }
public static void main(String[] args) throws Exception {
// 1. 加载 ONNX 模型
String modelPath = "best.onnx";
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
OrtSession session = env.createSession(modelPath, opts);
// 图片尺寸
int width = 64;
// 2. 加载并预处理图像
Mat img = Imgcodecs.imread("5.png");
System.out.println("--------------------------init--------------------------");
System.out.println(JSONUtil.toJsonStr(img.get(0,0)));
// 2.1 Resize 图像
Size size = new Size(width, width);
Imgproc.resize(img, img, size);
System.out.println("--------------------------after resize--------------------------");
System.out.println(JSONUtil.toJsonStr(img.get(0,0)));
System.out.println(JSONUtil.toJsonStr(img.get(0,1)));
// 2.2 中心裁剪图像
int cropSize = width;
int startX = (img.cols() - cropSize) / 2;
int startY = (img.rows() - cropSize) / 2;
Rect cropRect = new Rect(startX, startY, cropSize, cropSize);
Mat croppedImg = new Mat(img, cropRect);
System.out.println("--------------------------after centerCrop--------------------------");
System.out.println(JSONUtil.toJsonStr(croppedImg.get(0,0)));
double[][][] contigArr = asContiguousArray(img, new int[]{2, 1, 0});
System.out.println("--------------------------after continuousMat--------------------------");
System.out.println(JSONUtil.toJsonStr(contigArr[0][0]));
// 2.3 Normalize 图像
float[] mean = {0.485f, 0.456f, 0.406f};
float[] std = {0.229f, 0.224f, 0.225f};
// 2.4 将图像转换为 Tensor (ToTensor) + Normalize
for (int i = 0; i < contigArr.length; i++) {
for (int j = 0; j < contigArr[0].length; j++) {
for (int k = 0; k < contigArr[0][0].length; k++) {
double val = contigArr[i][j][k];
contigArr[i][j][k] = Double.valueOf((val / 255f - mean[i]) / std[i]).floatValue();
}
}
}
System.out.println("--------------------------after ToTensor + Normalize--------------------------");
System.out.println(JSONUtil.toJsonStr(contigArr[0][0]));
float[] imgData = new float[width * width * 3];
int idx = 0;
// 2.5 数据拉直
for (int i = 0; i < contigArr.length; i++) {
for (int j = 0; j < contigArr[0].length; j++) {
for (int k = 0; k < contigArr[0][0].length; k++) {
imgData[idx] = Double.valueOf(contigArr[i][j][k]).floatValue();
idx++;
}
}
}
// 3. 创建 ONNX Tensor 并推理
FloatBuffer tensorBuffer = FloatBuffer.wrap(imgData);
long[] shape = {1, 3, width, width}; // CHW format
OnnxTensor inputTensor = OnnxTensor.createTensor(env, tensorBuffer, shape);
OrtSession.Result result = session.run(Collections.singletonMap("images", inputTensor));
// 4. 处理结果
System.out.println(JSONUtil.toJsonStr(result.get(0).getValue())); // 二维数组 [[0.1,0,222....]]
System.out.println(JSONUtil.toJsonStr(softmax(((float[][]) result.get(0).getValue())[0])));
}
/**
* 转概率分布
*/
static double[] softmax(float[] input) {
double[] exps = new double[input.length];
double sum = 0.0;
for (int i = 0; i < input.length; i++) {
exps[i] = Math.exp(input[i]);
sum += exps[i];
}
double[] probabilities = new double[input.length];
for (int i = 0; i < input.length; i++) {
probabilities[i] = exps[i] / sum;
}
return probabilities;
}
public static double[][][] asContiguousArray(Mat mat, int[] dims) {
int row = mat.rows();
int col = mat.cols();
double[][][] convert = new double[3][row][col];
for (int r = 0; r < row; r++) {
double[] d0 = new double[col];
double[] d1 = new double[col];
double[] d2 = new double[col];
for (int c = 0; c < col; c++) {
d0[c] = mat.get(r,c)[dims[0]];
d1[c] = mat.get(r,c)[dims[1]];
d2[c] = mat.get(r,c)[dims[2]];
}
convert[0][r] = d0;
convert[1][r] = d1;
convert[2][r] = d2;
}
return convert;
}
}