Inference with C# BERT NLP Deep Learning and ONNX Runtime

news2024/9/24 21:24:55

目录

效果

测试一

测试二

测试三

模型信息

项目

代码

下载


Inference with C# BERT NLP Deep Learning and ONNX Runtime

效果

测试一

Context :Bob is walking through the woods collecting blueberries and strawberries to make a pie.  

Question :What is his name?

测试二

Context :Bob is walking through the woods collecting blueberries and strawberries to make a pie.  

Question :What will he bring home?

测试三

Context :Bob is walking through the woods collecting blueberries and strawberries to make a pie.  

Question :Where is Bob?

模型信息

Inputs
-------------------------
name:unique_ids_raw_output___9:0
tensor:Int64[-1]
name:segment_ids:0
tensor:Int64[-1, 256]
name:input_mask:0
tensor:Int64[-1, 256]
name:input_ids:0
tensor:Int64[-1, 256]
---------------------------------------------------------------

Outputs
-------------------------
name:unstack:1
tensor:Float[-1, 256]
name:unstack:0
tensor:Float[-1, 256]
name:unique_ids:0
tensor:Int64[-1]
---------------------------------------------------------------

项目

代码

using BERTTokenizers;
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Windows.Forms;

namespace Inference_with_C__BERT_NLP_Deep_Learning_and_ONNX_Runtime
{
    public struct BertInput
    {
        public long[] InputIds { get; set; }
        public long[] InputMask { get; set; }
        public long[] SegmentIds { get; set; }
        public long[] UniqueIds { get; set; }
    }

    public partial class Form1 : Form
    {
        public Form1()
        {
            InitializeComponent();
        }

        RunOptions runOptions;
        InferenceSession session;
        BertUncasedLargeTokenizer tokenizer;
        Stopwatch stopWatch = new Stopwatch();

        private void Form1_Load(object sender, EventArgs e)
        {
            string modelPath = "bertsquad-10.onnx";
            runOptions = new RunOptions();
            session = new InferenceSession(modelPath);
            tokenizer = new BertUncasedLargeTokenizer();
        }

        int MaxAnswerLength = 30;
        int bestN = 20;

        private void button1_Click(object sender, EventArgs e)
        {
            txt_answer.Text = "";
            Application.DoEvents();

            string question = txt_question.Text.Trim();
            string context = txt_context.Text.Trim();

            // Get the sentence tokens.
            var tokens = tokenizer.Tokenize(question, context);

            // Encode the sentence and pass in the count of the tokens in the sentence.
            var encoded = tokenizer.Encode(tokens.Count(), question, context);

            var padding = Enumerable
              .Repeat(0L, 256 - tokens.Count)
              .ToList();

            var bertInput = new BertInput()
            {
                InputIds = encoded.Select(t => t.InputIds).Concat(padding).ToArray(),
                InputMask = encoded.Select(t => t.AttentionMask).Concat(padding).ToArray(),
                SegmentIds = encoded.Select(t => t.TokenTypeIds).Concat(padding).ToArray(),
                UniqueIds = new long[] { 0 }
            };

            // Create input tensors over the input data.
            var inputIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputIds,
                  new long[] { 1, bertInput.InputIds.Length });

            var inputMaskOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputMask,
                  new long[] { 1, bertInput.InputMask.Length });

            var segmentIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.SegmentIds,
                  new long[] { 1, bertInput.SegmentIds.Length });

            var uniqueIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.UniqueIds,
                  new long[] { bertInput.UniqueIds.Length });

            var inputs = new Dictionary<string, OrtValue>
              {
                  { "unique_ids_raw_output___9:0", uniqueIdsOrtValue },
                  { "segment_ids:0", segmentIdsOrtValue},
                  { "input_mask:0", inputMaskOrtValue },
                  { "input_ids:0", inputIdsOrtValue }
              };

            stopWatch.Restart();
            // Run session and send the input data in to get inference output. 
            var output = session.Run(runOptions, inputs, session.OutputNames);
            stopWatch.Stop();

            var startLogits = output[1].GetTensorDataAsSpan<float>();

            var endLogits = output[0].GetTensorDataAsSpan<float>();

            var uniqueIds = output[2].GetTensorDataAsSpan<long>();

            var contextStart = tokens.FindIndex(o => o.Token == "[SEP]");

            var bestStartLogits = startLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestEndLogits = endLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestResultsWithScore = bestStartLogits
                .SelectMany(startLogit =>
                    bestEndLogits
                    .Select(endLogit =>
                        (
                            StartLogit: startLogit.Index,
                            EndLogit: endLogit.Index,
                            Score: startLogit.Logit + endLogit.Logit
                        )
                     )
                )
                .Where(entry => !(entry.EndLogit < entry.StartLogit || entry.EndLogit - entry.StartLogit > MaxAnswerLength || entry.StartLogit == 0 && entry.EndLogit == 0 || entry.StartLogit < contextStart))
                .Take(bestN);

            var (item, probability) = bestResultsWithScore
                .Softmax(o => o.Score)
                .OrderByDescending(o => o.Probability)
                .FirstOrDefault();

            int startIndex = item.StartLogit;
            int endIndex = item.EndLogit;

            var predictedTokens = tokens
                          .Skip(startIndex)
                          .Take(endIndex + 1 - startIndex)
                          .Select(o => tokenizer.IdToToken((int)o.VocabularyIndex))
                          .ToList();

            // Print the result.
            string answer = "answer:" + String.Join(" ", StitchSentenceBackTogether(predictedTokens))
                + "\r\nprobability:" + probability
                + $"\r\n推理耗时:{stopWatch.ElapsedMilliseconds}毫秒";

            txt_answer.Text = answer;
            Console.WriteLine(answer);

        }

        private List<string> StitchSentenceBackTogether(List<string> tokens)
        {
            var currentToken = string.Empty;

            tokens.Reverse();

            var tokensStitched = new List<string>();

            foreach (var token in tokens)
            {
                if (!token.StartsWith("##"))
                {
                    currentToken = token + currentToken;
                    tokensStitched.Add(currentToken);
                    currentToken = string.Empty;
                }
                else
                {
                    currentToken = token.Replace("##", "") + currentToken;
                }
            }

            tokensStitched.Reverse();

            return tokensStitched;
        }
    }
}
 

using BERTTokenizers;
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Windows.Forms;

namespace Inference_with_C__BERT_NLP_Deep_Learning_and_ONNX_Runtime
{
    public struct BertInput
    {
        public long[] InputIds { get; set; }
        public long[] InputMask { get; set; }
        public long[] SegmentIds { get; set; }
        public long[] UniqueIds { get; set; }
    }

    public partial class Form1 : Form
    {
        public Form1()
        {
            InitializeComponent();
        }

        RunOptions runOptions;
        InferenceSession session;
        BertUncasedLargeTokenizer tokenizer;
        Stopwatch stopWatch = new Stopwatch();

        private void Form1_Load(object sender, EventArgs e)
        {
            string modelPath = "bertsquad-10.onnx";
            runOptions = new RunOptions();
            session = new InferenceSession(modelPath);
            tokenizer = new BertUncasedLargeTokenizer();
        }

        int MaxAnswerLength = 30;
        int bestN = 20;

        private void button1_Click(object sender, EventArgs e)
        {
            txt_answer.Text = "";
            Application.DoEvents();

            string question = txt_question.Text.Trim();
            string context = txt_context.Text.Trim();

            // Get the sentence tokens.
            var tokens = tokenizer.Tokenize(question, context);

            // Encode the sentence and pass in the count of the tokens in the sentence.
            var encoded = tokenizer.Encode(tokens.Count(), question, context);

            var padding = Enumerable
              .Repeat(0L, 256 - tokens.Count)
              .ToList();

            var bertInput = new BertInput()
            {
                InputIds = encoded.Select(t => t.InputIds).Concat(padding).ToArray(),
                InputMask = encoded.Select(t => t.AttentionMask).Concat(padding).ToArray(),
                SegmentIds = encoded.Select(t => t.TokenTypeIds).Concat(padding).ToArray(),
                UniqueIds = new long[] { 0 }
            };

            // Create input tensors over the input data.
            var inputIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputIds,
                  new long[] { 1, bertInput.InputIds.Length });

            var inputMaskOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputMask,
                  new long[] { 1, bertInput.InputMask.Length });

            var segmentIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.SegmentIds,
                  new long[] { 1, bertInput.SegmentIds.Length });

            var uniqueIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.UniqueIds,
                  new long[] { bertInput.UniqueIds.Length });

            var inputs = new Dictionary<string, OrtValue>
              {
                  { "unique_ids_raw_output___9:0", uniqueIdsOrtValue },
                  { "segment_ids:0", segmentIdsOrtValue},
                  { "input_mask:0", inputMaskOrtValue },
                  { "input_ids:0", inputIdsOrtValue }
              };

            stopWatch.Restart();
            // Run session and send the input data in to get inference output. 
            var output = session.Run(runOptions, inputs, session.OutputNames);
            stopWatch.Stop();

            var startLogits = output[1].GetTensorDataAsSpan<float>();

            var endLogits = output[0].GetTensorDataAsSpan<float>();

            var uniqueIds = output[2].GetTensorDataAsSpan<long>();

            var contextStart = tokens.FindIndex(o => o.Token == "[SEP]");

            var bestStartLogits = startLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestEndLogits = endLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestResultsWithScore = bestStartLogits
                .SelectMany(startLogit =>
                    bestEndLogits
                    .Select(endLogit =>
                        (
                            StartLogit: startLogit.Index,
                            EndLogit: endLogit.Index,
                            Score: startLogit.Logit + endLogit.Logit
                        )
                     )
                )
                .Where(entry => !(entry.EndLogit < entry.StartLogit || entry.EndLogit - entry.StartLogit > MaxAnswerLength || entry.StartLogit == 0 && entry.EndLogit == 0 || entry.StartLogit < contextStart))
                .Take(bestN);

            var (item, probability) = bestResultsWithScore
                .Softmax(o => o.Score)
                .OrderByDescending(o => o.Probability)
                .FirstOrDefault();

            int startIndex = item.StartLogit;
            int endIndex = item.EndLogit;

            var predictedTokens = tokens
                          .Skip(startIndex)
                          .Take(endIndex + 1 - startIndex)
                          .Select(o => tokenizer.IdToToken((int)o.VocabularyIndex))
                          .ToList();

            // Print the result.
            string answer = "answer:" + String.Join(" ", StitchSentenceBackTogether(predictedTokens))
                + "\r\nprobability:" + probability
                + $"\r\n推理耗时:{stopWatch.ElapsedMilliseconds}毫秒";

            txt_answer.Text = answer;
            Console.WriteLine(answer);

        }

        private List<string> StitchSentenceBackTogether(List<string> tokens)
        {
            var currentToken = string.Empty;

            tokens.Reverse();

            var tokensStitched = new List<string>();

            foreach (var token in tokens)
            {
                if (!token.StartsWith("##"))
                {
                    currentToken = token + currentToken;
                    tokensStitched.Add(currentToken);
                    currentToken = string.Empty;
                }
                else
                {
                    currentToken = token.Replace("##", "") + currentToken;
                }
            }

            tokensStitched.Reverse();

            return tokensStitched;
        }
    }
}

下载

源码下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1285135.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录刷题题Day5

刷题的第五天&#xff0c;希望自己能够不断坚持下去&#xff0c;迎来蜕变。&#x1f600;&#x1f600;&#x1f600; 刷题语言&#xff1a;C / Python Day5 任务 ● 哈希表理论基础 ● 242.有效的字母异位词 ● 349. 两个数组的交集 ● 202. 快乐数 ● 1. 两数之和 1 哈希表理…

Rust 枚举和模式匹配

目录 1、枚举的定义 1.1 Option 枚举和其相对于空值的优势 2、match 控制流结构 2.1 匹配 Option 2.2 匹配是穷尽的 2.3 通配模式和 _ 占位符 3、if let 简洁控制流 1、枚举的定义 枚举&#xff08;enumerations&#xff09;&#xff0c;也被称作 enums。枚举允许你通过…

微信小程序 纯css画仪表盘

刚看到设计稿的时候第一时间想到的就是用canvas来做这个仪表盘&#xff0c;虽然本人的画布用的不是很好但还可以写一写&#x1f600;。话不多说直接上代码。最后有纯css方法 <!--wxml--> <canvas canvas-id"circle" class"circle" >// js dat…

文章解读与仿真程序复现思路——中国电机工程学报EI\CSCD\北大核心《考虑富氧燃烧技术的电–气–热综合能源系统低碳经济调度》

这个标题涉及到一个关于能源系统和经济调度的复杂主题。让我们逐步解读&#xff1a; 电–气–热综合能源系统&#xff1a; 指的是一个综合的能源系统&#xff0c;包括了电力、气体&#xff08;可能是天然气等&#xff09;、热能等多个能源形式。这种系统的设计和优化旨在使不同…

【数据结构(六)】排序算法介绍和算法的复杂度计算(1)

文章目录 1. 排序算法的介绍1.1. 排序的分类 2. 算法的时间复杂度2.1. 度量一个程序(算法)执行时间的两种方法2.2. 时间频度2.2.1. 忽略常数项2.2.2. 忽略低次项2.2.2. 忽略系数 2.3. 时间复杂度2.4. 常见的时间复杂度2.5. 平均时间复杂度和最坏时间复杂度 3. 算法的空间复杂度…

编码器-解码器(seq-seq)

1. 背景 encoder-decoder和seq-seq模型可以解决输入与输出都是不定长序列的问题。它们都用到了两个循环NN&#xff0c;分别叫做编码器(用来分析输入序列)与解码器(用来生成输出序列)。 2. 编码器 把一个不定长的输入序列变换成一个定长的背景变量c&#xff0c;并在其中编码输入…

分享72个节日PPT,总有一款适合您

分享72个节日PPT&#xff0c;总有一款适合您 72个节日PPT下载链接&#xff1a;https://pan.baidu.com/s/17Lmlvd_xN-xRSKu3FZUS9w?pwd6666 提取码&#xff1a;6666 Python采集代码下载链接&#xff1a;采集代码.zip - 蓝奏云 学习知识费力气&#xff0c;收集整理更不易…

驱动开发--内核添加新功能

Ubuntu下这个文件为开发板ls命令的结果 内核的内容&#xff1a; mm&#xff1a;内存管理 fs&#xff1a;文件系统 net&#xff1a;网络协议栈 drivers&#xff1a;驱动设备 arch与init&#xff1a;跟启动相关 kernel与ipc&#xff1a;任务&#xff0c;进程相关 向内核增…

java项目日常运维需要的文档资料

一、前言 java项目开发完成&#xff0c;部署上线&#xff0c;进入项目运维阶段&#xff0c;日常工作需要准备哪些资料和文档?当项目上线后&#xff0c;运行一段时间&#xff0c;或多或少会遇到一些运维上的问题&#xff0c;比如服务器磁盘饱满&#xff0c;服务器CPU&#xff0…

分享 | 顶刊高质量论文插图配色(含RGB值及16进制HEX码)(第一期)

我在很早之前出过一期高质量论文绘图配色&#xff0c;但当时觉得搜集太麻烦于是就没继续做&#xff0c;后来用MATLAB爬了上万张顶刊绘图&#xff0c;于是又想起来做这么一个系列&#xff0c;拿了一个多小时写了个提取论文图片颜色并得出RGB值和16进制码并标注在原图的代码&…

GPTs每日推荐--生化危机【典藏版】

今天给大家推荐一个游戏性质的GPTs&#xff0c;叫做生化危机典藏版&#xff0c;国内点击可玩。 开篇&#xff1a;玩家从末日中醒来。 选择&#xff1a;玩家会遇到各种资源、任务、剧情&#xff0c;需要自行选择相关的分支剧情&#xff0c;一旦选错&#xff0c;无法重选。 结局…

一次北斗接收机调试总结

作者&#xff1a;朱金灿 来源&#xff1a;clever101的专栏 为什么大多数人学不会人工智能编程&#xff1f;>>> 最近项目中要用到北斗接收机&#xff0c;它的样子是长这样的&#xff1a; 这部机器里面是没有操作系统的&#xff0c;由单片机控制。最近我们要根据协议…

Linux socket编程(10):UDP详解、聊天室实现及进阶知识

首先来回顾以下TCP的知识&#xff0c;TCP是一种面向连接的、可靠的传输协议&#xff0c;具有以下特点&#xff1a; TCP通过三次握手建立连接&#xff0c;确保通信的可靠性和完整性使用流控制和拥塞控制机制&#xff0c;有效地调整数据传输的速率&#xff0c;防止网络拥塞TCP提…

使用 PyTorch 进行 K 折交叉验证

一、说明 中号机器学习模型在训练后必须使用测试集进行评估。我们这样做是为了确保模型不会过度拟合&#xff0c;并确保它们适用于现实生活中的数据集&#xff0c;与训练集相比&#xff0c;现实数据集的分布可能略有偏差。 但为了使您的模型真正稳健&#xff0c;仅仅通过训练/测…

OneNote for Windows10 彻底删除笔记本

找了超多方法&#xff0c;都没有用&#xff0c;我的OneNote都没有文件选项&#xff0c;要在OneDrive中删除&#xff0c;但是一直登不进&#xff0c;然后又找到一个方法&#xff1a; 在网页中打开Office的控制面板 "Sign in to your Microsoft account" 在“最近”一…

k8s volumes and data

Overview 传统上&#xff0c;容器引擎(Container Engine)不提供比容器寿命更长的存储。由于容器被认为是瞬态(transient)的&#xff0c;这可能会导致数据丢失或复杂的外部存储选项。Kubernetes卷共享 Pod 生命周期&#xff0c;而不是其中的容器。如果容器终止&#xff0c;数据…

ctfhub技能树_web_信息泄露

目录 二、信息泄露 2.1、目录遍历 2.2、Phpinfo 2.3、备份文件下载 2.3.1、网站源码 2.3.2、bak文件 2.3.3、vim缓存 2.3.4、.DS_Store 2.4、Git泄露 2.4.1、log 2.4.2、stash 2.4.3、index 2.5、SVN泄露 2.6、HG泄露 二、信息泄露 2.1、目录遍历 注&#xff1…

POI Excel导入导出(下)

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 上一篇通过四个简单的小…

力扣刷题day1(两数相加,回文数,罗马数转整数)

题目1&#xff1a;1.两数之和 思路1和解析&#xff1a; //1.暴力枚举解法(历遍两次数组&#xff0c;时间复杂度O&#xff08;N^2)&#xff0c;空间复杂度O&#xff08;1&#xff09; int* twoSum(int* nums, int numsSize, int target, int* returnSize) {for (int i 0; i &…

短波红外相机的原理及应用场景

短波红外 (简称SWIR&#xff0c;通常指0.9~1.7μm波长的光线) 是一种比可见光波长更长的光。这些光不能通过“肉眼”看到&#xff0c;也不能用“普通相机”检测到。由于被检测物体的材料特性&#xff0c;一些在可见光下无法看到的特性&#xff0c;却能在近红外光下呈现出来&…