使用java加载、调用onnx模型(二)

news2024/12/27 8:49:49

目录

1、摘要

2、实现过程

2.1、依赖

2.2、imread

2.3、contiguous函数

2.3.1、转化示例

2.3.3、核心代码 

2.4、Flatten拉直

2.5、最终结果

3、完整代码


1、摘要

在上一篇文章中 使用java加载、调用onnx模型_onnx java-CSDN博客

发现使用Java加载调用模型的分类结果与使用Python加载调用,返回的多分类结果不一致,并且存在较大的误差。

经过与AI工程师通过后,发现。在进行图片加载调用模型的过程中,存在对图片的预处理操作。

主要由 transforms 依次经过 :

  1. Resize  缩放
  2. CenterCrop  中心剪裁
  3. HWC-->CHW,BGR-->RGB  通道、维度变换
  4. continuousMat   维度变换[2, 1, 0]
  5. ToTensor + Normalize   (0-1)区间,正则化

continuousMat函数

NumPy提供了 ascontiguousarray方法。该方法的作用是将输入数组转换为一个连续的内存块中的数组。如果输入数组已经是连续的,则该方法会返回输入数组的视图(即不复制数据);如果输入数组不是连续的,则该方法会返回输入数组的副本,并确保副本是连续的。这样,我们就可以确保在进行后续计算或调用其他库函数时,使用的是连续数组,从而避免潜在的性能问题或错误。

在python的 ToTensor 调用中发现

2、实现过程

2.1、依赖

其中opencv-320.jar 由网上手动下载,并手动导入

 <!-- ONNX Runtime -->
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>1.10.0</version>
        </dependency>

<!-- 加载lib目录下的opencv包 -->
        <dependency>
            <groupId>org.opencv</groupId>
            <artifactId>opencv</artifactId>
            <version>4.8.0</version>
            <scope>system</scope>
            <!--通过路径加载OpenCV480的jar包-->
            <systemPath>${basedir}/libs/opencv-320.jar</systemPath>
        </dependency>

 

<dependency>
            <groupId>cn.hutool</groupId>
            <artifactId>hutool-all</artifactId>
            <version>5.5.8</version>
        </dependency>

2.2、imread

借由 Imgcodecs.imread 函数读取原始图片,并形成Mat对象

:在基础的color255值的比对中,发现,同一个点位的像素值,在java 与python中存在细微的差距(1-3个值),但大部分都能一致。

2.3、contiguous函数

2.3.1、转化示例

上述转化过程表示

维度选择 [2, 1, 0]

1)在原始的【64,64,3】(x, y, z)的数组中

2)先将【0,0】、【0,1】....【0,64】每个中的z, y, x  分别放在 【0,0,0】、【1,0,0】、【2,0,0】变形存储

3)再按2的过程将【1,0】、【1,1】......【1,64】

4)转换结果表示为 【3,64,64】

HWC 转 CHW ,BGR转RGB

2.3.3、核心代码 

由于是三通道 ,所以下方的convert固定3,

public static double[][][] asContiguousArray(Mat mat, int[] dims) {
        int row = mat.rows();
        int col = mat.cols();

        double[][][] convert = new double[3][row][col];
        for (int r = 0; r < row; r++) {
            double[] d0 = new double[col];
            double[] d1 = new double[col];
            double[] d2 = new double[col];
            for (int c = 0; c < col; c++) {
                d0[c] = mat.get(r,c)[dims[0]];
                d1[c] = mat.get(r,c)[dims[1]];
                d2[c] = mat.get(r,c)[dims[2]];
            }
            convert[0][r] = d0;
            convert[1][r] = d1;
            convert[2][r] = d2;
        }
        return convert;
    }

2.4、Flatten拉直

在python中,可以采用多维的Tensor对象,输入到模型中处理,

但实际的Java过程,OnnxTensor需要 Buffer对象,最终导致,只接受一维的Float数组。

2.5、最终结果

1)python:

原始输出:tensor([[-1.19724, 1.18705,-0.73468, 1.9731日, 0.06351, 0.02216,-1.07441]])

SoftMax: tensor([[0.02208,0.23955,0.03506,0.52574,0.07788,0.07473,0.02496]])

2)java

原始输出:[[-1.1788176,1.2926369,-0.72909486,1.9525943,0.07915904,-0.0071295537,-1.0969812]]

SoftMax:[0.02215092526,0.26225931,0.0347299,0.507395,0.077933,0.071490,0.0240]

两组数据在对比后,可以看出存在细微的结果差异,但总体保持一致。

3、完整代码

1)session.run(Collections.singletonMap("images", inputTensor));

其中的imagesv传入模型的入参名,依据实际生成模型的入参决定

2)width =64

由于模型需要【3,64,64】尺寸的图片,所以并未对图片高度做定义,个人需要依据自身模型的图片要求做修改。

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import cn.hutool.json.JSONUtil;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Rect;
import org.opencv.core.Size;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

import java.math.BigDecimal;
import java.nio.FloatBuffer;
import java.util.Collections;

/***
 * @author xuancg
 * @date 2024/8/15
 */
public class OnnxTest2 {


    static { System.loadLibrary(Core.NATIVE_LIBRARY_NAME); }

    public static void main(String[] args) throws Exception {
        // 1. 加载 ONNX 模型
        String modelPath = "best.onnx";
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        OrtSession session = env.createSession(modelPath, opts);

        // 图片尺寸
        int width = 64;

        // 2. 加载并预处理图像
        Mat img = Imgcodecs.imread("5.png");
        System.out.println("--------------------------init--------------------------");
        System.out.println(JSONUtil.toJsonStr(img.get(0,0)));
        // 2.1 Resize 图像
        Size size = new Size(width, width); 
        Imgproc.resize(img, img, size);
        System.out.println("--------------------------after resize--------------------------");
        System.out.println(JSONUtil.toJsonStr(img.get(0,0)));
        System.out.println(JSONUtil.toJsonStr(img.get(0,1)));

        // 2.2 中心裁剪图像
        int cropSize = width;
        int startX = (img.cols() - cropSize) / 2;
        int startY = (img.rows() - cropSize) / 2;
        Rect cropRect = new Rect(startX, startY, cropSize, cropSize);
        Mat croppedImg = new Mat(img, cropRect);

        System.out.println("--------------------------after centerCrop--------------------------");
        System.out.println(JSONUtil.toJsonStr(croppedImg.get(0,0)));

        double[][][] contigArr = asContiguousArray(img, new int[]{2, 1, 0});
        System.out.println("--------------------------after continuousMat--------------------------");
        System.out.println(JSONUtil.toJsonStr(contigArr[0][0]));

        // 2.3 Normalize 图像
        float[] mean = {0.485f, 0.456f, 0.406f};
        float[] std = {0.229f, 0.224f, 0.225f};

        // 2.4 将图像转换为 Tensor (ToTensor) + Normalize
        for (int i = 0; i < contigArr.length; i++) {
            for (int j = 0; j < contigArr[0].length; j++) {
                for (int k = 0; k < contigArr[0][0].length; k++) {
                    double val = contigArr[i][j][k];
                    contigArr[i][j][k] = Double.valueOf((val / 255f - mean[i]) / std[i]).floatValue();
                }
            }
        }

        System.out.println("--------------------------after ToTensor + Normalize--------------------------");
        System.out.println(JSONUtil.toJsonStr(contigArr[0][0]));

        float[] imgData = new float[width * width * 3];
        int idx = 0;
        // 2.5 数据拉直
        for (int i = 0; i < contigArr.length; i++) {
            for (int j = 0; j < contigArr[0].length; j++) {
                for (int k = 0; k < contigArr[0][0].length; k++) {
                    imgData[idx] = Double.valueOf(contigArr[i][j][k]).floatValue();
                    idx++;
                }
            }
        }



        // 3. 创建 ONNX Tensor 并推理
        FloatBuffer tensorBuffer =  FloatBuffer.wrap(imgData);
        long[] shape = {1, 3, width, width}; // CHW format
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, tensorBuffer, shape);

        OrtSession.Result result = session.run(Collections.singletonMap("images", inputTensor));

        // 4. 处理结果
        System.out.println(JSONUtil.toJsonStr(result.get(0).getValue())); // 二维数组 [[0.1,0,222....]]
        System.out.println(JSONUtil.toJsonStr(softmax(((float[][]) result.get(0).getValue())[0])));
    }


    /**
     * 转概率分布
     */
    static double[] softmax(float[] input) {
        double[] exps = new double[input.length];
        double sum = 0.0;

        for (int i = 0; i < input.length; i++) {
            exps[i] = Math.exp(input[i]);
            sum += exps[i];
        }

        double[] probabilities = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            probabilities[i] = exps[i] / sum;
        }

        return probabilities;
    }



    public static double[][][] asContiguousArray(Mat mat, int[] dims) {
        int row = mat.rows();
        int col = mat.cols();

        double[][][] convert = new double[3][row][col];
        for (int r = 0; r < row; r++) {
            double[] d0 = new double[col];
            double[] d1 = new double[col];
            double[] d2 = new double[col];
            for (int c = 0; c < col; c++) {
                d0[c] = mat.get(r,c)[dims[0]];
                d1[c] = mat.get(r,c)[dims[1]];
                d2[c] = mat.get(r,c)[dims[2]];
            }
            convert[0][r] = d0;
            convert[1][r] = d1;
            convert[2][r] = d2;
        }
        return convert;
    }

}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2052160.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算组合数:从n个不同元素中,选k个元素的方式数量math.comb()

【小白从小学Python、C、Java】 【考研初试复试毕业设计】 【Python基础AI数据分析】 计算组合数&#xff1a; 从n个不同元素中&#xff0c; 选k个元素的方式数量 math.comb() 请问关于以下代码表述正确的选项是&#xff1f; import math print("【执行】math.comb(3, 2)…

线性代数证明:把行列式的某一行(列)的k倍加到另一行(列),行列式的值不变

线性代数证明 把行列式的某一行&#xff08;列&#xff09;的k倍加到另一行&#xff08;列&#xff09;&#xff0c;行列式的值不变&#xff1a; 注意五角星的位置要用到另一条性质&#xff1a;若行列式的某一行&#xff08;列&#xff09;的元素都是两数之和&#xff0c;则可以…

Ajax笔记总结(Xmind格式):第一天

Xmind鸟瞰图&#xff1a; 简单文字总结&#xff1a; ajax知识总结&#xff1a; 网络的参考模型&#xff1a; 1.物理层&#xff1a;源设备到目的设备 底层传输就是比特流 2.数据链路层 进行电信号的处理 进行数据的分组 3.网路层 进行数据包的传递 进行不同网络的…

菱形继承和虚继承

菱形继承&#xff08;Diamond Inheritance&#xff09;是指在多重继承的情况下&#xff0c;某个类继承自两个类&#xff0c;而这两个类又都继承自同一个基类的情况。 在这个结构中&#xff0c;D 直接从 A 继承了 A 的所有特性&#xff0c;但通过 B 和 C 继承&#xff0c;这会导…

Avue实现动态查询与数据展示(附Demo)

目录 前言1. 基本知识2. Demo 前言 此框架为Avue-crud&#xff0c;推荐阅读&#xff1a; 【vue】avue-crud表单属性配置&#xff08;表格以及列&#xff09;Avue实现批量删除等功能&#xff08;附Demo&#xff09;Avue实现选择下拉框的多种方式Avue框架实现图表的基本知识 | …

凌晨突发!核心系统瘫痪,通过Signleton单例模式轻松搞定,但还是被裁员了...

&#x1f345; 作者简介&#xff1a;哪吒&#xff0c;CSDN2021博客之星亚军&#x1f3c6;、新星计划导师✌、博客专家&#x1f4aa; &#x1f345; 哪吒多年工作总结&#xff1a;Java学习路线总结&#xff0c;搬砖工逆袭Java架构师 &#x1f345; 技术交流&#xff1a;定期更新…

selenium底层原理详解

目录 1、selenium版本的演变 1.1、Selenium 1.x&#xff08;Selenium RC时代&#xff09; 1.2、Selenium 2.x&#xff08;WebDriver整合时代&#xff09; 1.3、Selenium 3.x 2、selenium原理说明 3、源码说明 3.1、启动webdriver服务建立连接 3.2、发送操作 1、seleni…

flink车联网项目:维表离线同步(第69天)

系列文章目录 3.3 维表离线同步 3.3.1 思路 3.3.2 示例 3.3.3 其他表开发 3.3.4 部署 3.3.1.1 将表提交到生成环境 3.3.1.2 添加虚拟节点 3.3.1.3 配置计算节点 3.3.1.4 添加虚拟结束节点 3.3.1.5 提交到生产环境 3.3.1.6 发布 3.3.1.7 运维中心 3.3.1.8 补数据 3.3.1.9 补数据…

c++进阶------多态

作者前言 &#x1f382; ✨✨✨✨✨✨&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f382; ​&#x1f382; 作者介绍&#xff1a; &#x1f382;&#x1f382; &#x1f382; &#x1f389;&#x1f389;&#x1f389…

机器学习/数据分析--通俗语言带你入门线性回归(结合案例)

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 前言 机器学习是深度学习和数据分析的基础&#xff0c;接下来将更新常见的机器学习算法注意&#xff1a;在打数学建模比赛中&#xff0c;机器学习用的也很多&a…

探索GitLab:从搭建到高效使用的实用指南

企业里为什么喜欢使用GitLab 一、GitLab简介二、搭建GitLab三、GitLab的权限管理3.1、用户注册3.2、创建用户组3.3、为用户组添加用户3.4、为工程添加访问权限 四、GitLab的code review五、团队知识管理六、总结 一、GitLab简介 GitLab是利用 Ruby on Rails 一个开源的版本管理…

Adobe Media Encoder ME 2023-23.6.6.2 解锁版下载安装教程 (专业的视频和音频编码渲染工具)

前言 Adobe Media Encoder&#xff08;简称Me&#xff09;是一款专业的音视频格式转码软件&#xff0c;文件格式转换软件。主要用来对音频和视频文件进行编码转换&#xff0c;支持格式非常多&#xff0c;使用系统预设设置&#xff0c;能更好的导出与相关设备兼容的文件。 一、…

网站怎么做敏感词过滤,敏感词过滤的思路和实践

敏感词过滤是一种在网站、应用程序或平台中实现内容审查的技术&#xff0c;用于阻止用户发布包含不适当、非法或不符合政策的内容。我们在实际的网站运营过程中&#xff0c;往往需要担心某些用户发布的内容中包含敏感词汇&#xff0c;这些词汇往往会导致我们的网站被用户举报&a…

JVM的组成

JVM 运行在操作系统之上 java二进制字节码文件的运行环境 JVM的组成部分 java代码在编写完成后编译成字节码文件通过类加载器 来到运行数据区,主要作用是加载字节码到内存 包含 方法区/元空间 堆 程序计数器,虚拟机栈,本地方法栈等等 随后来到执行引擎,主要作用是翻译字…

系统工程与信息系统(上)

系统工程 概念 【系统工程】是一种组织管理技术。 【系统工程】是为了最好的实现系统的目的&#xff0c;对系统的组成要素、组织结构、信息流、控制机构进行分析研究的科学方法。 【系统工程】从整体出发、从系统观念出发&#xff0c;以求【整体最优】 【系统工程】利用计算机…

信息搜集--敏感文件Banner

免责声明:本文仅做分享参考... git安装: Windows10下安装Git_win10安装git好慢-CSDN博客 git目录结构: Git 仓库目录 .git 详解-CSDN博客 敏感目录泄露 1-git泄露 Git是一个开源的分布式版本控制系统,我们简单的理解为Git 是一个*内容寻址文件系统*&#xff0c;也就是说Gi…

二十四、解释器模式

文章目录 1 基本介绍2 案例2.1 Instruction 接口2.2 StartInstruction 类2.3 PrimitiveInstruction 类2.4 RepeatInstruction 类2.5 InstructionList 类2.6 Context 类2.7 Client 类2.8 Client 类的运行结果2.9 总结 3 各角色之间的关系3.1 角色3.1.1 AbstractExpression ( 抽象…

Nexpose漏扫

免责声明:本文仅做分享参考... nexpose官网: Nexpose On-Premise Vulnerability Scanner - Rapid7 Rapid7的Nexpose是一款非常专业的漏洞扫描软件。有community版本和enterprise版本。 其中community版是免费的&#xff0c;但是功能简单&#xff1b;enterprise版本功能强大.…

适用于 Windows 10 的最佳免费数据恢复软件是什么?

有没有适用于 Windows 10 的真正免费的数据恢复软件&#xff1f; 丢失重要数据&#xff0c;无论是由于硬件问题、软件问题、意外删除、格式化还是病毒和恶意软件&#xff0c;确实很麻烦。当你面临数据丢失时&#xff0c;你可能真心希望找到一款免费的数据恢复软件&#xff0c;…

【C++指南】深入剖析:C++中的引用

&#x1f493; 博客主页&#xff1a;倔强的石头的CSDN主页 &#x1f4dd;Gitee主页&#xff1a;倔强的石头的gitee主页 ⏩ 文章专栏&#xff1a;《C指南》 期待您的关注 目录 引言&#xff1a; 一、引用的基本概念 1. 定义与特性 2. 语法与声明 二、引用的进阶用法 1. 函…