【如何训练一个中英翻译模型】LSTM机器翻译模型部署之ncnn(python)(五)

news2025/1/13 6:13:49

系列文章
【如何训练一个中英翻译模型】LSTM机器翻译seq2seq字符编码(一)
【如何训练一个中英翻译模型】LSTM机器翻译模型训练与保存(二)
【如何训练一个中英翻译模型】LSTM机器翻译模型部署(三)
【如何训练一个中英翻译模型】LSTM机器翻译模型部署之onnx(python)(四)

目录

  • 一、事情准备
  • 二、模型转换
  • 三、ncnn模型加载与推理(python版)

一、事情准备

这篇是在【如何训练一个中译英翻译器】LSTM机器翻译模型部署之onnx(python)(四)的基础上进行的,要用到文件为:

input_words.txt
target_words.txt
config.json
encoder_model-sim.onnx
decoder_model-sim.onnx

其中的onnx就是用来转为ncnn模型的,这里借助了onnx这个中间商,所以前面我们需要先通过onnxsim对模型进行simplify,要不然在模型转换时会出现op不支持的情况(模型转换不仅有中间商这个例子,目前还可以通过pnnx直接将pytorch模型转为ncnn,感兴趣的小伙伴可以去折腾下)
老规矩,先给出工具:

onnx2ncnn:https://github.com/Tencent/ncnn
netron:https://netron.app

二、模型转换

这里进行onnx转ncnn,通过命令进行转换

onnx2ncnn onnxModel/encoder_model-sim.onnx ncnnModel/encoder_model.param ncnnModel/encoder_model.bin
onnx2ncnn onnxModel/decoder_model-sim.onnx ncnnModel/decoder_model.param ncnnModel/decoder_model.bin

转换成功可以看到:
在这里插入图片描述
转换之后可以对模型进行优化,但是奇怪的是,这里优化了不起作用,去不了MemoryData这些没用的op

ncnnoptimize ncnnModel/encoder_model.param ncnnModel/encoder_model.bin ncnnModel/encoder_model.param ncnnModel/encoder_model.bin 1
ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

三、ncnn模型加载与推理(python版)

跟onnx的推理比较类似,就是函数的调用方法有点不同,这里先用python实现,验证下是否没问题,方面后面部署到其它端,比如android。
主要包括:模型加载、推理模型搭建跟模型推理,但要注意的是这里的输入输出名称需要在param这个文件里面获取。

采用netron分别查看encoder与decoder的网络结构,获取输入输出名称:

encoder:
输入输出分别如图
在这里插入图片描述
decoder:

输入
在这里插入图片描述
输出:
在这里插入图片描述

推理代码如下,推理过程感觉没问题,但是推理输出结果相差很大(对比过第一层ncnn与onnx的推理结果了),可能问题出在模型转换环节的精度损失上,而且第二层模型转换后网络输出结果不一致了,很迷,还没找出原因,但是以下的推理是能运行通过,只不过输出结果有问题

import numpy as np
import ncnn


# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:
    input_words = f.readlines()
    input_characters = [line.rstrip('\n') for line in input_words]

# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:
    target_words = [line.strip() for line in f.readlines()]
    target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]

#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])

# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量

import json
with open('config/config.json', 'r') as file:
    loaded_data = json.load(file)

# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]





# Load the ncnn models for the encoder and decoder
encoderNet = ncnn.Net()
encoderNet.load_param("ncnnModel/encoder_model.param")
encoderNet.load_model("ncnnModel/encoder_model.bin")

decoderNet = ncnn.Net()
decoderNet.load_param("ncnnModel/decoder_model.param")
decoderNet.load_model("ncnnModel/decoder_model.bin")





def decode_sequence(input_seq):
    # Encode the input as state vectors.
    # print(input_seq)
    ex_encoder = encoderNet.create_extractor()
    ex_encoder.input("input_1", ncnn.Mat(input_seq))
    states_value = []

    _, LSTM_1 = ex_encoder.extract("lstm")
    _, LSTM_2 = ex_encoder.extract("lstm_1")


    states_value.append(LSTM_1)
    states_value.append(LSTM_2)


    # print(ncnn.Mat(input_seq))
    # print(vgdgd)
    
    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, 849))

    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.
    # this target_seq you can treat as initial state



    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    ex_decoder = decoderNet.create_extractor()
    while not stop_condition:
        
        
        #print(ncnn.Mat(target_seq))
        
        print("---------")

        
        ex_decoder.input("input_2", ncnn.Mat(target_seq))
        ex_decoder.input("input_3", states_value[0])
        ex_decoder.input("input_4", states_value[1])
        _, output_tokens = ex_decoder.extract("dense")
        _, h = ex_decoder.extract("lstm_1")
        _, c = ex_decoder.extract("lstm_1_1")

        print(output_tokens)


        tk = []
        for i in range(849):
            tk.append(output_tokens[849*i])

        tk = np.array(tk)
        output_tokens = tk.reshape(1,1,849)

        print(output_tokens)



        # print(fdgd)
        
        print(h)
        print(c)
        
        
        # output_tokens = np.array(output_tokens)
        # output_tokens = output_tokens.reshape(1, 1, -1)


        # # h = np.array(h)
        # # c = np.array(c)
        # print(output_tokens.shape)
        # print(h.shape)
        # print(c.shape)
        
        
        #output_tokens, h, c = decoder_model.predict([target_seq] + states_value)

        # Sample a token
        # argmax: Returns the indices of the maximum values along an axis
        # just like find the most possible char
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        # find char using index
        sampled_char = reverse_target_char_index[sampled_token_index]
        # and append sentence
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        # append then ?
        # creating another new target_seq
        # and this time assume sampled_token_index to 1.0
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        print(sampled_token_index)

        # Update states
        # update states, frome the front parts
        
        states_value = [h, c]

    return decoded_sentence
    

import numpy as np

input_text = "Call me."
encoder_input_data = np.zeros(
    (1,max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
for t, char in enumerate(input_text):
    print(char)
    # 3D vector only z-index has char its value equals 1.0
    encoder_input_data[0,t, input_token_index[char]] = 1.


input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)

decoder的模型输出为849*849,感觉怪怪的,然后我们把模型的输入固定下来看看是不是模型的问题。
打开decoder_model.param,把输入层固定下来,0=w 1=h 2=c,那么:
input_2:0=849 1=1 2=1
input_3:0=256 1=1
input_4:0=256 1=1

运行以下命令进行优化

ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

结果如下:
在这里插入图片描述
打开网络来看一下:
可以看到输出确实是849849(红色框),那就是模型转换有问题了
在这里插入图片描述
仔细看,能够看到有两个shape(蓝色框)分别为849跟849
1,这两个不同维度的网络进行BinaryOP之后,就变成849849了,那么,我们把Reshape这个网络去掉试试(不把前面InnerProduct的输入维度有849reshape为8491),下面来看手术刀怎么操作。

我们需要在没经过固定维度并ncnnoptimize的模型上操作(也就是没经过上面0=w 1=h 2=c修改的模型上操作)
根据名字我们找到Reshape那一层:
在这里插入图片描述
然后找到与reshape那一层相连接的上一层(红色框)与下一层(蓝色框)
在这里插入图片描述
通过红色框与蓝色框里面的名字我们找到了上层与下层分别为InnerProduct与BinaryOp
在这里插入图片描述
这时候,把InnerProduct与BinaryOp接上,把Reshape删掉
在这里插入图片描述
再改一下最上面的层数,把19改为18,因为我们删掉了一层
在这里插入图片描述保存之后再次执行

ncnnoptimize ncnnModel/decoder_model.param ncnnModel/decoder_model.bin ncnnModel/decoder_model.param ncnnModel/decoder_model.bin 1

执行后可以看到网络层数跟blob数都更新了
在这里插入图片描述

这时候改一下固定一下输入层数,并运行ncnnoptimize,再打开netron看一下网络结构,可以看到输出维度正常了
在这里插入图片描述
但是通过推理结果还是不对,没找到原因,推理代码如下:

import numpy as np
import ncnn


# 加载字符
# 从 input_words.txt 文件中读取字符串
with open('config/input_words.txt', 'r') as f:
    input_words = f.readlines()
    input_characters = [line.rstrip('\n') for line in input_words]

# 从 target_words.txt 文件中读取字符串
with open('config/target_words.txt', 'r', newline='') as f:
    target_words = [line.strip() for line in f.readlines()]
    target_characters = [char.replace('\\t', '\t').replace('\\n', '\n') for char in target_words]

#字符处理,以方便进行编码
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])

# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())
num_encoder_tokens = len(input_characters) # 英文字符数量
num_decoder_tokens = len(target_characters) # 中文文字数量

import json
with open('config/config.json', 'r') as file:
    loaded_data = json.load(file)

# 从加载的数据中获取max_encoder_seq_length和max_decoder_seq_length的值
max_encoder_seq_length = loaded_data["max_encoder_seq_length"]
max_decoder_seq_length = loaded_data["max_decoder_seq_length"]





# Load the ncnn models for the encoder and decoder
encoderNet = ncnn.Net()
encoderNet.load_param("ncnnModel/encoder_model.param")
encoderNet.load_model("ncnnModel/encoder_model.bin")

decoderNet = ncnn.Net()
decoderNet.load_param("ncnnModel/decoder_model.param")
decoderNet.load_model("ncnnModel/decoder_model.bin")





def decode_sequence(input_seq):
    # Encode the input as state vectors.
    # print(input_seq)
    ex_encoder = encoderNet.create_extractor()
    ex_encoder.input("input_1", ncnn.Mat(input_seq))
    states_value = []

    _, LSTM_1 = ex_encoder.extract("lstm")
    _, LSTM_2 = ex_encoder.extract("lstm_1")


    states_value.append(LSTM_1)
    states_value.append(LSTM_2)


    # print(ncnn.Mat(input_seq))
    # print(vgdgd)
    
    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, 849))

    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.
    # this target_seq you can treat as initial state



    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    ex_decoder = decoderNet.create_extractor()
    while not stop_condition:
        
        
        #print(ncnn.Mat(target_seq))
        
        print("---------")

        
        ex_decoder.input("input_2", ncnn.Mat(target_seq))
        ex_decoder.input("input_3", states_value[0])
        ex_decoder.input("input_4", states_value[1])
        _, output_tokens = ex_decoder.extract("dense")
        _, h = ex_decoder.extract("lstm_1")
        _, c = ex_decoder.extract("lstm_1_1")

        print(output_tokens)


        # print(ghfhf)


        # tk = []
        # for i in range(849):
        #     tk.append(output_tokens[849*i])

        # tk = np.array(tk)
        # output_tokens = tk.reshape(1,1,849)

        # print(output_tokens)



        # print(fdgd)
        
        print(h)
        print(c)
        
        
        output_tokens = np.array(output_tokens)
        output_tokens = output_tokens.reshape(1, 1, -1)


        # # h = np.array(h)
        # # c = np.array(c)
        # print(output_tokens.shape)
        # print(h.shape)
        # print(c.shape)
        
        
        #output_tokens, h, c = decoder_model.predict([target_seq] + states_value)

        # Sample a token
        # argmax: Returns the indices of the maximum values along an axis
        # just like find the most possible char
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        # find char using index
        sampled_char = reverse_target_char_index[sampled_token_index]
        # and append sentence
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        # append then ?
        # creating another new target_seq
        # and this time assume sampled_token_index to 1.0
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        print(sampled_token_index)

        # Update states
        # update states, frome the front parts
        
        states_value = [h, c]

    return decoded_sentence
    

import numpy as np

input_text = "Call me."
encoder_input_data = np.zeros(
    (1,max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
for t, char in enumerate(input_text):
    print(char)
    # 3D vector only z-index has char its value equals 1.0
    encoder_input_data[0,t, input_token_index[char]] = 1.


input_seq = encoder_input_data
decoded_sentence = decode_sequence(input_seq)
print('-')
print('Input sentence:', input_text)
print('Decoded sentence:', decoded_sentence)



参考文献:https://github.com/Tencent/ncnn/issues/2586

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/794723.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MTK系统启动流程

MTK系统启动流程 boot rom -> preloader ->lk ->kernel ->Native -> Android 1、Boot rom:系统开机,最先执行的是固化在芯片内部的bootrom,其作用主要有 a.初始化ISRAM和EMMC b.当系统全擦后 ,也会配置USB,用来仿…

Android:RecyclerView封装,打造列表极简加载

前言 mBinding.recycler.linear().divider().set<OrdinaryListBean> {addLayout(R.layout.layout_ordinary_item)}.setList(getList()) 如果我要说&#xff0c;除了数据和布局之外&#xff0c;以上的几行代码&#xff0c;就实现了一个列表加载&#xff0c;有老铁会相信…

T形积木(T puzzle)

目录 积木绘制 积木拼接 练习 1. 停止标志 2. 跳跃旋转 3. 小步平移 积木绘制 &#xff08;1&#xff09;复数欧拉公式&#xff1a; &#xff08;2&#xff09;复数的极坐标形式&#xff1a; 其中 &#xff08;3&#xff09;T形积木问题利用了复数乘以将该复数值旋转b角的…

Spring源码解析(五):循环依赖

Spring源码系列文章 Spring源码解析(一)&#xff1a;环境搭建 Spring源码解析(二)&#xff1a;bean容器的创建、默认后置处理器、扫描包路径bean Spring源码解析(三)&#xff1a;bean容器的刷新 Spring源码解析(四)&#xff1a;单例bean的创建流程 Spring源码解析(五)&…

Ubuntu 20.04.4 LTS安装Terminator终端(Linux系统推荐)

Terminator终端可以在一个窗口中创建多个终端&#xff0c;并且可以水平、垂直分割&#xff0c;运行ROS时很方便。 sudo apt install terminator这样安装完成后&#xff0c;使用快捷键Ctrl Alt T打开的就是新安装的terminator终端&#xff0c;可以使用以下方法仍然打开ubuntu默…

IAR for STM8L标准库基于DMP库驱动MPU6050

IAR for STM8L标准库基于DMP库驱动MPU6050 ✨移植到STM8上&#xff0c;主要对接的是I2C对应的接口函数&#xff0c;也没有什么难度&#xff0c;该型号目前不属于新设计推荐的型号了&#xff0c;如果使用DMP库最好还是需要配合磁力计才能输出比较稳定的数据&#xff0c;使用MPU9…

【App管理04-Bug修正 Objective-C语言】

一、咱们刚才已经把这个给大家做完了吧 1.这个Label怎么显示到上面去了, 我们现在是把它加到我们的控制器的View里面吧 我们看一下这个坐标是怎么算的,来,我们找一个坐标, 咱们的坐标,是不是用这个View的frame,减的吧 来,咱们在这里,输出一下这个Frame,看一下吧 在…

idea的Plugins中搜索不到插件

1、ctrlalts 打开设置 ; 2、搜索框输入plugins; 3、点击plugins; 4、点齿轮按钮&#xff0c;选择HTTP Proxy settings; 如下操作&#xff1a; 5、刷新DNS&#xff0c;ipconfig /flushdns 6、重新打开idea 的plugins 插件列表出来了

EC200U-CN学习(五)

预留抓取CP日志 U15-B引脚学习 LOUDSPK_P: 一个连接到一个功率放大器的音频输出引脚。 MIC_P: MIC_P是一个用于连接麦克风的阳性引脚。通常情况下&#xff0c;MIC_P引脚用于连接麦克风的正极&#xff0c;而MIC_N引脚则用于连接麦克风的负极。麦克风的正极和负极之间通常需要连…

springboot+mybatis-plus+vue+element+vant2实现短视频网站,模拟西瓜视频移动端

目录 一、前言 二、管理后台 1.登录 2.登录成功&#xff0c;进入欢迎页 ​编辑 3.视频分类管理 4. 视频标签管理 5.视频管理 6.评论管理 ​编辑 7.用户管理 8.字典管理 &#xff08;类似于后端的枚举&#xff09; 9.参数管理&#xff08;富文本录入&#xff09; 10.管…

Flask get post请求

Flask get &post请求 一、环境描述二、初始化flask 程序三、get请求3.1 代码3.2 分析3.3 验证3.4 请求结果 四、post请求4.1 代码4.2 分析4.3 验证4.3.1 postman 请求头application/json参数4.3.2 postman 请求头application/x-www-form-urlencoded参数4.3.3 postman 请求头…

【ESP32】Espressif-IDE及ESP-IDF安装

一、下载Espressif-IDE 2.10.0 with ESP-IDF v5.0.2 1.打开ESP-IDF 编程指南 2.点击快速入门–>安装–>手动安装–>Windows Installer–>Windows Installer Download 3.点击下载Espressif-IDE 2.10.0 with ESP-IDF v5.0.2 二、安装Espressif-IDE 2.10.0 wit…

Docker 全栈体系(七)

Docker 体系&#xff08;高级篇&#xff09; 五、Docker-compose容器编排 1. 是什么 Compose 是 Docker 公司推出的一个工具软件&#xff0c;可以管理多个 Docker 容器组成一个应用。你需要定义一个 YAML 格式的配置文件docker-compose.yml&#xff0c;写好多个容器之间的调…

4、Linux驱动开发:设备-设备号设备号注册

目录 &#x1f345;点击这里查看所有博文 随着自己工作的进行&#xff0c;接触到的技术栈也越来越多。给我一个很直观的感受就是&#xff0c;某一项技术/经验在刚开始接触的时候都记得很清楚。往往过了几个月都会忘记的差不多了&#xff0c;只有经常会用到的东西才有可能真正记…

Unity游戏源码分享-ARPG游戏Darklight.rar

Unity游戏源码分享-ARPG游戏Darklight.rar 玩法 项目地址&#xff1a;https://download.csdn.net/download/Highning0007/88105464

1.jquery遍历数组2.layui框架的理解

1.jquery遍历数组 2.layui框架的理解 layui.use(["form", "laydate"], 是使用 layui 框架中的模块加载方法 use 来加载并使用 form 和 laydate 这两个模块。 在 layui 框架中&#xff0c;可以使用 use 方法来加载所需的模块&#xff0c;然后使用这些模块…

[CrackMe]Chafe.1.exe的逆向及注册机编写

上手先试一下, 发现其没有对话框, 只有字符串, 搜索"Your serial is not valid"字符串 \ 上来就直接发现关键跳转, 难道这题这么简单吗? 仔细一看实际上远远要复杂 往上一翻发现没有生成serial key的代码, 而是看到了一个SetTimer, 时间间隔设置成了1ms, 之前输入…

访问:http://localhost:8070/actuator/bus-refresh 问题

1、请求发送不出去 原因&#xff1a; 自己 config-server端 application.yml 配置的端口号是8888&#xff0c;访问server修改为配置的端口号 2、请求报错405 几个解决办法&#xff1a; 1、版本问题变为busrefresh 2、bus-refresh加单引号或双引号尝试 3、加配置尝试&#xff1a…

嵌入式_GD32看门狗配置

嵌入式_GD32独立看门狗配置与注意事项 文章目录 嵌入式_GD32独立看门狗配置与注意事项前言一、什么是独立看门狗定时器&#xff08;FWDGT&#xff09;二、独立看门狗定时器原理三、独立看门狗定时器配置过程与注意事项总结 前言 使用GD3单片机时&#xff0c;为了提供了更高的安…

服务器数据恢复-Windows服务器RAID5数据恢复案例

服务器数据恢复环境&#xff1a; 一台服务器挂载三台IBM某型号存储设备&#xff0c;共64块SAS硬盘&#xff0c;组建RAID5磁盘阵列&#xff1b; 服务器操作系统&#xff1a;Windows Server&#xff1b;文件系统&#xff1a;NTFS。 服务器故障&#xff1a; 一台存储中的一块硬盘离…