【如何训练一个中英翻译模型】LSTM机器翻译模型部署之onnx(python)(四)

news2025/1/11 7:38:06

系列文章
【如何训练一个中英翻译模型】LSTM机器翻译seq2seq字符编码(一)

【如何训练一个中英翻译模型】LSTM机器翻译模型训练与保存(二)

【如何训练一个中英翻译模型】LSTM机器翻译模型部署(三)

【如何训练一个中英翻译模型】LSTM机器翻译模型部署之onnx(python)(四)

目录

  • 一、事前准备
  • 二、.h5模型保存为TFSaveModel格式样例
  • 三、模型转换
    • 1、encoder_model的转换
      • 1).h5模型保存为TFSaveModel
      • 2)TFSaveModel格式模型保存为onnx模型
      • 3)onnx模型简化
    • 2、decoder_model的转换
      • 1).h5模型保存为TFSaveModel
      • 2)TFSaveModel格式模型保存为onnx模型
      • 3)onnx模型简化
  • 4、onnx模型推理
      • 1)加载模型数据
      • 2)查看模型输入输出信息
      • 3)模型推理搭建
      • 4)模型推理
      • 5)完整代码

一、事前准备

先把要用到的几个工具说一下:

ncnn:https://github.com/Tencent/ncnn
tf2onnx:https://github.com/onnx/tensorflow-onnx
netron:https://netron.app
onnxsim:https://github.com/daquexian/onnx-simplifier
onnxruntime:https://github.com/microsoft/onnxruntime
以上工具的安装与使用后面会抽空补充一下,在这里先记录下,以免忘记了

有了工具之后,我们还需要以下几个文件:
在这里插入图片描述
这几个文件可以在前面的文章【如何训练一个中译英翻译器】LSTM机器翻译模型训练与保存(二)训练一个模型并保存模型得到,最快的方式就是运行文章最后的kaggle notebook,直接得到文件,然后下载下来即可

二、.h5模型保存为TFSaveModel格式样例

要将tf模型转为onnx模型,我们需要先将格式为.h5的tf模型保存为saved_model的格式,先给出样例:

import tensorflow as tf
from keras.models import load_model

# 加载Keras模型
model = load_model('encoder_model.h5')

# 转换为SavedModel类型
tf.saved_model.save(model, 'TFSaveModel')

三、模型转换

1、encoder_model的转换

1).h5模型保存为TFSaveModel

import tensorflow as tf
from keras.models import load_model

# 加载Keras模型
model = load_model('encoder_model.h5')

# 转换为SavedModel类型
tf.saved_model.save(model, 'TFSaveModel')

2)TFSaveModel格式模型保存为onnx模型

python3 -m tf2onnx.convert --saved-model TFSaveModel --output onnxModel/encoder_model.onnx

3)onnx模型简化

打开https://netron.app/来看下网络结构,主要是先看输入部分的维度(网络结构后面会细讲)
可以看到输入维度:
input_1:[unk__64、unk__65、62]
我们需要将 unk__64、unk__65 这两个改为具体数值,否则在导出ncnn模型时会报一些op不支持的错误,那么问题来了,要怎么改,我也不知道啊!!!
哈哈哈,开完笑的,都写出来了,怎么会不知道,请听我慢慢说来。
在这里插入图片描述[unk__64、unk__65、62]
其实数据第一个unk__64是batch,第二个unk__65是输入句子的最大长度,第三个62是字符总数量,我们在推理时,batch size一般为1,所以这个input_1的shape就是[1,max_encoder_seq_length, num_encoder_tokens](num_encoder_tokens模型已经帮我们填好了)
max_encoder_seq_length, num_encoder_tokens 这两个参数可以在训练的时候获取到了,拿到这个input shape 之后,对onnx模型进行simplify,我训练出来的模型时得到的shape是[1,16,62],因此执行以下命令:

python3 -m onnxsim onnxModel/encoder_model.onnx onnxModel/encoder_model-sim.onnx --overwrite-input-shape 1,16,62

可得到简化后的onnx模型
在这里插入图片描述
这个时候,我们再用https://netron.app打开encoder_model-sim.onnx,可以看到encoder模型的输出了,有两个输出,均为[1,256]的维度
在这里插入图片描述

2、decoder_model的转换

然后我们需要对decoder_model.h5也进行转换,

1).h5模型保存为TFSaveModel

import tensorflow as tf
from keras.models import load_model

# 加载Keras模型
model = load_model('decoder_model.h5')

# 转换为SavedModel类型
tf.saved_model.save(model, 'TFSaveModel')

2)TFSaveModel格式模型保存为onnx模型

python3 -m tf2onnx.convert --saved-model TFSaveModel --output onnxModel/decoder_model.onnx

3)onnx模型简化

同样打开模型来看,能看到一共有三个输入:
input_2:[unk__55,unk__56,849]
input_3:[unk__57,256]
input_4:[unk__58,256]
其中,input_3、input_4为encoder的输出,因此可以得到这两个输入维度均为[1,256]
那么,input_2的维度是多少,我们接着往下看。
在这里插入图片描述
我们想一想,解码器除了接受编码器的数据,还有什么数据没给它,没有错,就是target_characters的特征,对于英译中而言就是中文的字符,要解码器解出中文,肯定要把中文数据给它,要不然你让解码器去解空气啊,实际上这个 input_2的维度就是

target_seq = np.zeros((1, 1, num_decoder_tokens))

num_decoder_tokens同样可以在训练的时候获取到(至于不知道怎么来的,可以看这个系列文章的第一、二篇),我这边得到的num_decoder_tokens是849,当然实际上这个模型的 input_2:[unk__55,unk__56,849]已经给了num_decoder_tokens,我们只需要把unk__55,unk__56都改为1就可以了,即[1,1,849],那么对onnx进行simplify

python3 -m onnxsim onnxModel/decoder_model.onnx onnxModel/decoder_model-sim.onnx --overwrite-input-shape input_2:1,1,849 input_3:1,256 input_4:1,256

成功完成simplify可得到:
在这里插入图片描述

4、onnx模型推理

到最后一步了,导出onnx模型后,要试试这个模型怎么样,所以拿过来推理一波,推理代码是从前面文章【如何训练一个中译英翻译器】LSTM机器翻译模型训练与保存(二)的第小6节模型加载与推理里面的代码改过来的,感兴趣的小伙伴可以去看看两者的差异

1)加载模型数据

模型数据的加载主要是加载input_words.txt、target_words.txt、config.json、encoder_model-sim.onnx、decoder_model-sim.onnx 这几个文件

input_words.txt、target_words.txt:为输入输出字符表
config.json:为最长输入长度与最长输出长度
encoder_model-sim.onnx、decoder_model-sim.onnx :为导出的onnx模型

import onnxruntime
import numpy as np
# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:
    input_words = f.readlines()
    input_characters = [line.rstrip('\n') for line in input_words]

# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:
    target_words = [line.strip() for line in f.readlines()]
    target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]

#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])

# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量

import json
with open('config/config.json', 'r') as file:
    loaded_data = json.load(file)

# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]



encoderSess = onnxruntime.InferenceSession('onnxModel/encoder_model-sim.onnx')
decoderSess = onnxruntime.InferenceSession('onnxModel/decoder_model-sim.onnx')

2)查看模型输入输出信息

查看输入输出信息主要是为了获取输入名称,在进行模型输入的时候,要先知道模型有哪些输入,维度是多少,才能输入正确的数据


print("----------------- 输入部分 -----------------")
input_tensors = encoderSess.get_inputs()  # 该 API 会返回列表
for input_tensor in input_tensors:         # 因为可能有多个输入,所以为列表
    
    input_info = {
        "name" : input_tensor.name,
        "type" : input_tensor.type,
        "shape": input_tensor.shape,
    }
    print(input_info)

print("----------------- 输出部分 -----------------")
output_tensors = encoderSess.get_outputs()  # 该 API 会返回列表
for output_tensor in output_tensors:         # 因为可能有多个输出,所以为列表
    
    output_info = {
        "name" : output_tensor.name,
        "type" : output_tensor.type,
        "shape": output_tensor.shape,
    }
    print(output_info)



print("----------------- 输入部分 -----------------")
input_tensors = decoderSess.get_inputs()  # 该 API 会返回列表
for input_tensor in input_tensors:         # 因为可能有多个输入,所以为列表
    
    input_info = {
        "name" : input_tensor.name,
        "type" : input_tensor.type,
        "shape": input_tensor.shape,
    }
    print(input_info)

print("----------------- 输出部分 -----------------")
output_tensors = decoderSess.get_outputs()  # 该 API 会返回列表
for output_tensor in output_tensors:         # 因为可能有多个输出,所以为列表
    
    output_info = {
        "name" : output_tensor.name,
        "type" : output_tensor.type,
        "shape": output_tensor.shape,
    }
    print(output_info)

3)模型推理搭建


def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoderSess.run(None, {'input_1': input_seq})
    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens), dtype=np.float32)
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.
    # this target_seq you can treat as initial state
    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoderSess.run(None, {'input_2': target_seq, 'input_3': states_value[0], 'input_4': states_value[1]})
        # Sample a token
        # argmax: Returns the indices of the maximum values along an axis
        # just like find the most possible char
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        # find char using index
        sampled_char = reverse_target_char_index[sampled_token_index]
        # and append sentence
        decoded_sentence += sampled_char
        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True
        # Update the target sequence (of length 1).
        # append then ?
        # creating another new target_seq
        # and this time assume sampled_token_index to 1.0
        target_seq = np.zeros((1, 1, num_decoder_tokens), dtype=np.float32)
        target_seq[0, 0, sampled_token_index] = 1.
        # Update states
        # update states, frome the front parts
        states_value = [h, c]
    return decoded_sentence


input_text = "Call me."
encoder_input_data = np.zeros(
    (1,max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
for t, char in enumerate(input_text):
    # 3D vector only z-index has char its value equals 1.0
    encoder_input_data[0,t, input_token_index[char]] = 1.

4)模型推理

input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)

5)完整代码

import onnxruntime
import numpy as np
# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:
    input_words = f.readlines()
    input_characters = [line.rstrip('\n') for line in input_words]

# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:
    target_words = [line.strip() for line in f.readlines()]
    target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]

#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])

# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量

import json
with open('config/config.json', 'r') as file:
    loaded_data = json.load(file)

# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]



encoderSess = onnxruntime.InferenceSession('onnxModel/encoder_model-sim.onnx')
decoderSess = onnxruntime.InferenceSession('onnxModel/decoder_model-sim.onnx')


print("----------------- 输入部分 -----------------")
input_tensors = encoderSess.get_inputs()  # 该 API 会返回列表
for input_tensor in input_tensors:         # 因为可能有多个输入,所以为列表
    
    input_info = {
        "name" : input_tensor.name,
        "type" : input_tensor.type,
        "shape": input_tensor.shape,
    }
    print(input_info)

print("----------------- 输出部分 -----------------")
output_tensors = encoderSess.get_outputs()  # 该 API 会返回列表
for output_tensor in output_tensors:         # 因为可能有多个输出,所以为列表
    
    output_info = {
        "name" : output_tensor.name,
        "type" : output_tensor.type,
        "shape": output_tensor.shape,
    }
    print(output_info)



print("----------------- 输入部分 -----------------")
input_tensors = decoderSess.get_inputs()  # 该 API 会返回列表
for input_tensor in input_tensors:         # 因为可能有多个输入,所以为列表
    
    input_info = {
        "name" : input_tensor.name,
        "type" : input_tensor.type,
        "shape": input_tensor.shape,
    }
    print(input_info)

print("----------------- 输出部分 -----------------")
output_tensors = decoderSess.get_outputs()  # 该 API 会返回列表
for output_tensor in output_tensors:         # 因为可能有多个输出,所以为列表
    
    output_info = {
        "name" : output_tensor.name,
        "type" : output_tensor.type,
        "shape": output_tensor.shape,
    }
    print(output_info)



def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoderSess.run(None, {'input_1': input_seq})
    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens), dtype=np.float32)
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.
    # this target_seq you can treat as initial state
    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoderSess.run(None, {'input_2': target_seq, 'input_3': states_value[0], 'input_4': states_value[1]})
        # Sample a token
        # argmax: Returns the indices of the maximum values along an axis
        # just like find the most possible char
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        # find char using index
        sampled_char = reverse_target_char_index[sampled_token_index]
        # and append sentence
        decoded_sentence += sampled_char
        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True
        # Update the target sequence (of length 1).
        # append then ?
        # creating another new target_seq
        # and this time assume sampled_token_index to 1.0
        target_seq = np.zeros((1, 1, num_decoder_tokens), dtype=np.float32)
        target_seq[0, 0, sampled_token_index] = 1.
        # Update states
        # update states, frome the front parts
        states_value = [h, c]
    return decoded_sentence


input_text = "Call me."
encoder_input_data = np.zeros(
    (1,max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
for t, char in enumerate(input_text):
    # 3D vector only z-index has char its value equals 1.0
    encoder_input_data[0,t, input_token_index[char]] = 1.


input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)


可以看到运行结果:
在这里插入图片描述
代码比较简单,然后也有加一些注释,就不再细讲了,要不然就显得有点啰嗦,有疑问的可以留言,欢迎交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/806173.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

极简每周计划应用程序WeekToDo

什么是 WeekToDo ? WeekToDo 是一款免费的极简每周计划应用程序,专注于隐私。使用待办事项列表和日历安排您的任务和项目。适用于 Windows、Mac、Linux 或在线。 WeekToDo 是一个免费且开源的极简每周计划程序。借助 WeekToDo,您可以以简单直观的方式定…

Matplotlib_绘制柱状图

绘制柱状图 🧩bar方法 bar()是Matplotlib.pyplot库中用于绘制条形图(bar chart)的函数。条形图是一种常见的数据可视化图表,用于显示不同类别之间的比较。 函数签名: matplotlib.pyplot.bar(x, height, width0.8, …

ICMP协议(网际报文控制协议)详解

ICMP协议(网际报文控制协议)详解 ICMP协议的功能ICMP的报文格式常见的ICMP报文差错报文目的站不可达数据报超时 查询报文回送请求或回答 ICMP协议是一个网络层协议。 一个新搭建好的网络,往往需要先进行一个简单的测试,来验证网络…

JDBC Some Templates

JDBCTemplate 是Spring对JDBC的封装&#xff0c;使用JDBCTemplate方便实现对数据的操作。 <!-- orm:Object relationship mapping m对象 关系 映射-->引入依赖 <!-- 基于Maven依赖的传递性&#xff0c;导入spring-content依赖即可导入当前所需的所有…

Spring项目启动报错无法访问org.springframework.boot.SpringApplication:6

当springBoot项目启动后报错如下 解决办法如下&#xff1a;将jdk版本调为11,springBoot版本降低为2.7.12。然后clean&#xff0c;再package重新打包。最后重新启动项目

存储论——经济订货批量的R实现

存储论又称库存理论&#xff0c;是运筹学中发展较早的分支。早在 1915 年&#xff0c;哈李斯&#xff08;F.Harris&#xff09;针对银行货币的储备问题进行了详细的研究&#xff0c;建立了一个确定性的存储费用模型&#xff0c;并求得了最佳批量公式。1934 年威尔逊&#xff08…

第五章 HL7 架构和可用工具 - 创建新的自定义架构

文章目录 第五章 HL7 架构和可用工具 - 创建新的自定义架构创建新的自定义架构定义新段 第五章 HL7 架构和可用工具 - 创建新的自定义架构 创建新的自定义架构 要从管理门户启动自定义架构编辑器&#xff0c;请从主页选择互操作性 > 互操作 > HL7 v2.x >HL7 v2.x 架…

单机和集群以及分布式的浅析

假设一个大系统分为A、B、C、D、E五个模块&#xff0c;也可以认为是五个基本的服务&#xff0c;该系统靠这五个模块协同工作&#xff0c;共同为用户提供服务。 单机 单机&#xff1a;显然&#xff0c;单机表名该系统完完全全的部署在该台机器上&#xff0c;拥有完整的服务&am…

算法38:反转链表【O(n)方案】

一、需求 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1] 示例 2&#xff1a; 输入&#xff1a;head [1,2] 输出&#xff1a;[2,1] 示例3&#xff…

监听镜像版本变化触发 GitOps工作流

文章目录 前言工作流总览安装和配置 ArgoCD Image Updater创建 Image Pull Secret&#xff08;可选&#xff09;创建 Helm Chart 仓库创建 ArgoCD Application删除旧应用&#xff08;可选&#xff09;配置仓库访问权限创建 ArgoCD 应用 体验 GitOps 工作流总结 前言 在【GitOps…

FastDeploy的方式在OK3588上部署yolov7-- C++

FastDeploy介绍 ⚡️FastDeploy是一款全场景、易用灵活、极致高效的AI推理部署工具&#xff0c; 支持云边端部署。提供超过 &#x1f525;160 Text&#xff0c;Vision&#xff0c; Speech和跨模态模型&#x1f4e6;开箱即用的部署体验&#xff0c;并实现&#x1f51a;端到端的…

附录1-将uni-app运行到微信开发者工具

目录 1 在manifest.json写入AppID 2 配置微信开发者工具的安装路径 3 微信开发者工具的安全设置 4 运行 5 修改一些配置项 1 在manifest.json写入AppID 2 配置微信开发者工具的安装路径 如果你忘了安装在哪里了&#xff0c;可以右键快捷方式看一下属性 在运行设置…

邻接矩阵与邻接表

文章目录 0 前面几种数据结构的回顾1 图1.1 图的定义1.2 常见术语1.3 图的抽象数据类型定义1.4 表示一个图1.4.1 邻接矩阵表示法1.4.2 邻接表 1.5 图的构建1.5.1 邻接矩阵法1.5.2 邻接表法 0 前面几种数据结构的回顾 1 图 1.1 图的定义 图&#xff1a; G (V,E) // Graph (V…

Moke 一百万条 Mysql 的数据

文章目录 前言创建数据库创建表结构生成数据 前言 想研究一下&#xff0c;数据量大的情况下&#xff0c;如何优化前端分页&#xff0c;所以需要 Moke 一些数据 创建数据库 在 Mysql的基础上&#xff0c;可以写个语句执行 CREATE DATABASE test_oneMillion; USE test_oneMi…

Jmeter —— 录制脚本

1. 第一步&#xff1a;添加http代理服务器&#xff0c;在测试计划--》添加--》非测试元件--》http代理服务器 2. 第二步&#xff1a;添加线程组&#xff08;这个线程组是用来放录制的脚本&#xff0c;不添加也可以&#xff0c;就直接放在代理服务器下&#xff09; 测试计划--》…

【Linux】sed修改文件指定内容

sed修改文件指定内容&#xff1a; 参考&#xff1a;(5条消息) Linux系列讲解 —— 【cat echo sed】操作读写文件内容_shell命令修改文件内容_星际工程师的博客-CSDN博客

如何连接远程服务器?快解析内内网穿透可以吗?

如今我们迎来了数字化转型的时代&#xff0c;众多企业来为了更好地推动业务的发展&#xff0c;常常需要在公司内部搭建一个远程服务器。然而&#xff0c;对于企业员工来说&#xff0c;在工作过程中经常需要与这个服务器进行互动&#xff0c;而服务器位于公司的局域网中&#xf…

活动目录(Active Directory) 管理工具

每个IT管理员几乎每天都在Active Directory管理中面临许多挑战&#xff0c;尤其是在管理Active Directory用户帐户方面。手动配置用户属性非常耗时、令人厌烦且容易出错&#xff0c;尤其是在大型、复杂的 Windows 网络中。Active Directory管理员和IT经理大多必须执行重复和世俗…

20.3 HTML 表格

1. table表格 table标签是HTML中用来创建表格的元素. table标签通常包含以下子标签: - th标签: 表示表格的表头单元格(table header), 用于描述列的标题. - tr标签: 表示表格的行(table row). - td标签: 表示表格的单元格(table data), 通常位于tr标签内, 用于放置单元格中的…

进阶高级测试专项,Pytest自动化测试框架总结(二)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、pyets种有四种…