Tensorflow Serving部署推荐模型

news2025/1/4 17:08:51

Tensorflow Serving部署推荐模型

1、找到当前模型中定义的variables,并在此定义一个saver用于保存模型参数

def saveVariables(self):
    variables_dict = {}
    variables_dict[self.user_embedding.op.name] = self.user_embedding
    variables_dict[self.item_embedding.op.name] = self.item_embedding

    for v in self.reduce_dimension_layer.variables:
        variables_dict[v.op.name] = v

    self.saver = tf.train.Saver(variables_dict)

在模型的输入和输出的地方,尽量自行定义name,这样在之后的部署的时候会方便很多!

self.item_input = tf.placeholder("int32", [None, 1],name="gat_iteminput") 
self.user_input = tf.placeholder("int32", [None, 1],name="gat_userinput") 
...
self.prediction = tf.sigmoid(tf.reduce_sum(self.predict_vector, 1, keepdims=True),name="gat_predict")

2、在需要保存模型参数的地方调用save方法,一般建议在模型取到最高指标处保存模型

#此处的saver为上面模型中定义的saver
#sess即为session;weights_save_path为自定义的文件路径;global_step表示当前为第几次epoch
model.saver.save(sess, weights_save_path + '/weights', global_step=epoch)

3、最终会保存为如下图所示的文件

在这里插入图片描述

110、111、112、119、127是最近5次模型指标最高的5次记录,可以根据自己需要选择最高的记录,也可以在self.saver = tf.train.Saver(variables_dict)这里指定好保存的次数,例如保留最多两次:self.saver = tf.train.Saver(variables_dict,max_to_keep=2)

此处训练中,我得到的最高指标epoch为127,所以我使用127的weights

在这里插入图片描述

因为tensorflow serving需要saved_model的格式,所以我们需要将ckpt的格式转成savedModel格式,转换的代码如下:

import tensorflow as tf

#两个参数都是文件夹的名称,一个是ckpt文件所在文件夹,一个是之后导出的文件夹
def restore_and_save(input_checkpoint, export_path):
    checkpoint_file = tf.train.latest_checkpoint(input_checkpoint)
    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        sess = tf.Session(config=session_conf)

        with sess.as_default():
            # 载入保存好的meta graph,恢复图中变量,通过SavedModelBuilder保存可部署的模型
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)
            print("graph.get_name_scope()=",graph.get_name_scope())
            # for node in graph.as_graph_def().node:
            #     print(node.name)
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)

            # 建立签名映射,需要包括计算图中的placeholder(ChatInputs, SegInputs, Dropout)和
            # 我们需要的结果(project/logits,crf_loss/transitions)
            """
            build_tensor_info
            建立一个基于提供的参数构造的TensorInfo protocol buffer,
            输入:tensorflow graph中的tensor;
            输出:基于提供的参数(tensor)构建的包含TensorInfo的protocol buffer

            get_operation_by_name
            通过name获取checkpoint中保存的变量,能够进行这一步的前提是在模型保存的时候给对应的变量赋予name
            """
            user_inputs = tf.saved_model.utils.build_tensor_info(
                graph.get_tensor_by_name("gat_userinput:0"))
            item_inputs = tf.saved_model.utils.build_tensor_info(
                graph.get_tensor_by_name("gat_iteminput:0"))
            prediction = tf.saved_model.utils.build_tensor_info(
                graph.get_tensor_by_name("gat_predict:0"))
            print("user_inputs=",user_inputs)
            print("item_inputs=",item_inputs)
            print("prediction=",prediction)

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            """
            signature_constants
            SavedModel保存和恢复操作的签名常量。
            在序列标注的任务中,这里的method_name是"tensorflow/serving/predict"
            """
            # 定义模型的输入输出,建立调用接口与tensor签名之间的映射
            labeling_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        "user_inputs": user_inputs,
                        "item_inputs": item_inputs
                    },
                    outputs={
                        "prediction": prediction
                    },
                    method_name="tensorflow/serving/predict"
                ))

            """
            tf.group
            创建一个将多个操作分组的操作,返回一个可以执行所有输入的操作
            """
            # legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

            """
            add_meta_graph_and_variables
            建立一个Saver来保存session中的变量,输出对应的原图的定义,这个函数假设保存的变量已经被初始化;
            对于一个SavedModelBuilder,这个API必须被调用一次来保存meta graph;
            对于后面添加的图结构,可以使用函数 add_meta_graph()来进行添加
            """
            # 建立模型名称与模型签名之间的映射
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                # 保存模型的方法名,与客户端的request.model_spec.signature_name对应
                signature_def_map={
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:labeling_signature
                })

            builder.save()
            print("Build Done")


# 模型格式转换
restore_and_save("ckpt","saved_model")

最后转换成功后的文件夹目录如下:

在这里插入图片描述

下面是一些查看节点和张量的代码:

import tensorflow as tf
# from parser_predict import parse_args
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import graph_util

import os


def get_pretrain_path():
    pretrain_path= "ckpt"
    print("get_pretrain_path="+str(pretrain_path))
    return pretrain_path


def get_ckpt_file():
    return tf.train.get_checkpoint_state(os.path.dirname(get_pretrain_path() + '/checkpoint'))


def get_tensors_name():
    ckpt = get_ckpt_file()
    if ckpt and ckpt.model_checkpoint_path:
        reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
            print('tensor_name: ', key)
    else:
        print('wrong path')


def get_nodes_name():
    ckpt = get_ckpt_file()
    print("ckpt.model_checkpoint_path="+str(ckpt.model_checkpoint_path))
    if ckpt and ckpt.model_checkpoint_path:
        saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True)
        with tf.Session() as sess:
            saver.restore(sess, ckpt.model_checkpoint_path)
            graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
            node_list = [n.name for n in graph_def.node]
            for node in node_list:
                print('node_name: ', node)


def get_operations_name():
    ckpt = get_ckpt_file()
    if ckpt and ckpt.model_checkpoint_path:
        with tf.Session() as sess, open('../log/gat_operation_log', 'w', encoding='utf-8') as log:
            sess.run(tf.global_variables_initializer())
            tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True)
            for operation in tf.get_default_graph().get_operations():
                log.write(operation.name + '\n')


def freeze_graph( pb_file_path):
    ckpt = get_ckpt_file()
    if ckpt and ckpt.model_checkpoint_path:
        saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True)
        graph = tf.get_default_graph()
        input_graph_def = graph.as_graph_def()

        with tf.Session() as sess:
            # 恢复图并得到数据
            print(sess.run(tf.global_variables_initializer()))
            print(sess.run(tf.local_variables_initializer()))
            saver.restore(sess, ckpt.model_checkpoint_path)
            # 模型持久化,将变量值固定
            output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                output_node_names=['gat_predict'])  # 注意是节点名称不是张量名称

            # 保存模型
            with tf.gfile.GFile(pb_file_path, 'wb') as f:
                # 序列化输出
                f.write(output_graph_def.SerializeToString())
            # 得到当前图的操作节点
            nodes = output_graph_def.node
            print('%d ops in the final graph.' % len(nodes))


def load_freeze_graph(pb_file_path):
    with tf.Session() as sess:
        graph = load_pb(pb_file_path)

        # 定义输入的张量名称,对应网络结构的输入张量
        input_user_id = graph.get_tensor_by_name('user_embedding')

        # 定义输出的张量名称
        output_tensor_name = graph.get_tensor_by_name('item_embedding')

        # 测试读出来的模型是否正确
        # 注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
        out = sess.run(output_tensor_name, feed_dict={input_user_id: [0]})
        print('out:', out)


def load_pb(pb_file_path):
    with tf.gfile.GFile(pb_file_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph


if __name__ == '__main__':
    # get_tensors_name()
    # get_nodes_name()
    #freeze_graph('gat.pb') #此处功能不正确,但是可以参考下
    # get_operations_name()
    # load_freeze_graph('gat.pb')#此处功能不正确,但是可以参考下

4、安装好docker,并下载tensorflow serving的镜像

在这里插入图片描述

5、编写dockerfile

from:表示使用tensorflow/serving:latest的docker基础镜像

RUN:在镜像变成容器时,可以执行的命令(直接docker run进行/挂载也是一样的,这里为了方便);此处的命令意思为创建文件夹models,并在models下创建gat文件夹,并在gat文件夹下面创建1文件夹

ADD:有两个参数,前一个参数指的是当前机器上的文件夹名称,后一个参数指的是创建出来的容器的文件夹名称。此处的命令意思是将当前机器上的gat(相对路径)文件夹下的所有东西,都复制到docker创建出来的容器中的/models/gat/1/文件夹下

ENV:设置系统变量

from tensorflow/serving:latest

RUN mkdir -p /models/gat/1
ADD gat  /models/gat/1/
ENV MODEL_NAME=gat

当前dockerfile文件名称为:GatDockerFile

6、构建镜像

#-f后面跟着dockerfile文件名称;-t后面跟着自定义的创建出来的镜像名称;冒号后面的1.0代表自定义的版本号;最后的那个点号一定要加上
docker build -f GatDockerFile -t gat:1.0 .

构建完成后,会出现下面的提示:

在这里插入图片描述

然后使用docker image ls可以看到已经成功构建了镜像:

在这里插入图片描述

7、创建容器实例:

#8508:指的是本机的端口号(此处可以更改为其他,只要本机上的这个端口号没有被占用就可以)
#8501:指的是容器中的端口号,因为tensorflow serving是通过8501提供服务的,所以此处最好不要修改
#--name:后面跟着的是自己创建的容器名称,自定义
#-d:指的是后台运行
#gat:1.0是刚才我们创建的镜像名称+版本号
docker run -p 8508:8501 --name gatContainer -d gat:1.0

输入命令后,如果出现一串很长的字符,并且没有报错,则代表创建成功,可以使用docker ps 查看

在这里插入图片描述

STATUS为Up… 说明已经启动成功

http://主机ip地址:8508/v1/models/gat
http://主机ip地址:8508/v1/models/gat/metadata

整体格式如下:

在这里插入图片描述

细节:

此处就是模型的名称:

在这里插入图片描述

此处是模型的输入和输出:

  • 输入就是:user_inputsitem_inputs
  • 输出就是:prediction

参数名是上面定义的,具体的数据格式如图所示了

在这里插入图片描述
在这里插入图片描述

可以使用postman进行测试:

http://主机IP地址:8508/v1/models/gat:predict

在这里插入图片描述

参考:

Tensorflow Serving部署推荐模型_默默然咯的博客-CSDN博客_tensorflow 推荐模型

整套流程+多模型部署,强烈建议参考这篇:https://blog.csdn.net/tianyunzqs/article/details/103842894

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/100141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode】1971. 寻找图中是否存在路径

题目描述 有一个具有 n 个顶点的 双向 图,其中每个顶点标记从 0 到 n - 1(包含 0 和 n - 1)。图中的边用一个二维整数数组 edges 表示,其中 edges[i] [ui, vi] 表示顶点 ui 和顶点 vi 之间的双向边。 每个顶点对由 最多一条 边连…

犀牛插件开发-基础核心-技术概览-总体架构-教程

文章目录1.概述2.基础核心2.1.C Rhino 核心2.2.openNURBS2.3.C SDK3.C Stack3.1.C Plugins3.2.RhinoScript4.NET Stack4.1.C API4.2.NET Framework4.3.RhinoCommon4.4.Eto4.5.net插件4.6.Grasshopper组件4.7.Python脚本5.相关主题1.概述 《Rhinoceros》由许多层组成——用多种…

细说OA系统的繁荣发展

改革开放以来,科技发展突飞猛进,我们生活的方方面面都受到了巨大影响。随着信息化时代的到来,企业的办公方式也发生了巨大的改变,OA系统随之走进了大众的视野。细数这四十几年,OA办公系统已经由一个异想天开的想法变成…

centos7.8离线安装pg和postgis

安装包下载地址: 链接:https://pan.baidu.com/s/1MxJc-5Ws6OPTRAoC-2srJw 提取码:is2q 1.centos7.8 离线安装pg操作步骤 这里基于centos7.8空白系统操作实践写的文档,系统一致的情况下可以照搬教程操作安装,镜像为…

1.0、Hibernate-快速入门初体验

1.0、Hibernate-快速入门初体验 Hibernate 和 mybatis 一样是 ORM (Object Relation Mapping) 对象关系映射框架,将面向对象映射成面向关系。 如何使用呢? 1、导入依赖; 2、创建 Hibernate 配置文件; 3、创建实体类; 4…

Allegro172版本多人协同在线设计操作指导

Allegro172版本多人协同在线设计操作指导 Allegro升级到172版本,可以支持多人协同设计,并且实时同步,具体操作如下 首先用户需要在同一个局域网下,并且Allegro172的版本必须一致,比如都是S082的版本 第一个用户打开PCB,选择Symphony Team design 选择 Start Symphony …

2022年度总结

自我介绍 大家好,我又回来了!我在一年之前在 CSDN 写了第一篇文章,到现在也有一年时间了。这次回来呢,也是因为 CSDN 官方发的消息,让写一篇年度总结的文章。在离开的这几个月里,主要是因为工作繁忙&#…

ASO优化:总结APP被下架的5点原因

随着苹果的App Store的监管力度的不断加强,每个APP都会有被下架的风险,而对于开发者来说,APP被下架是一件很严重的事情,不仅会造成用户的流失,还会降低用户对APP 的信任。所以,我们要了解APP被下架的原因&a…

【大数据技术Spark】Spark SQL操作Dataframe、读写MySQL、Hive数据库实战(附源码)

需要源码和依赖请点赞关注收藏后评论区留言私信~~~ 一、Dataframe操作 步骤如下 1)利用IntelliJ IDEA新建一个maven工程,界面如下 2)修改pom.XML添加相关依赖包 3)在工程名处点右键,选择Open Module Settings 4&a…

整数的大小端序

在存储整数时,一般按字节为逻辑单位进行存储,有“小端序”和“大端序”之分。小端序(little-endian) 是指将表示整数的低位字节存储在内存地址的低位,高位字节存储在内存地址的高位。如果将整数 1982062410 存储至内存…

【CANN训练营第三季】2022年度第三季新手班之昇腾AI入门课

本次参加CANN训练营,本来我报名的是进阶班课程,再看一遍新手班,学习一下目前CANN的最新进展也是不错的,巩固一下。 视频课程大家可以从这里看到 (1)【CANN训练营第三季】- 昇腾AI入门课(上&am…

使用Keepalived工具实现集群节点的高可用

GreatSQL社区原创内容未经授权不得随意使用,转载请联系小编并注明来源。GreatSQL是MySQL的国产分支版本,使用上与MySQL一致。作者:蟹黄瓜子文章来源:社区投稿 1.前言 在集群当中离不开的一个词就是是高可用,用本文来…

OpenWrt + 每步科技DDNS 实现ipv6动态域名解析方法

其实好几个月前我就已经把这个动态域名设置好了,后面重新刷了系统,忘记保存,又得重新再来,这次把过程记录一下,免得下次再从头百度。 工具 刷好openWrt的路由器一个每步科技注册的域名(我为什么选择这个&…

数字电子技术(八)D/A和A/D转换

D/A和A/D转换概述D/A转换A/D转换例题练习模拟信号:在时间与数值上都连续 数字信号:在时间与数值上都离散 概述 D/A转换:数字信号——模拟信号 (D/A转换器简称DAC)A/D转换:模拟信号——数字信号 &#xff0…

修改物料编号格式及长度

修改物料编号格式及长度(OMSL) 路径:IMG--后勤常规--物料主数据--基本设置--定义物料编号的输出格式

毕业设计 - 基于JSP的合同信息管理系统【源码+论文】

文章目录前言一、项目设计1. 模块设计数据库设计2. 实现效果二、部分源码项目源码前言 今天学长向大家分享一个 java web jsp 项目: 基于JSP的合同信息管理系统 适合用于毕业设计、课程设计 一、项目设计 1. 模块设计 需求分析是从客户的需求中提取出软件系统能够帮助用户…

java互联网医院系统HIS源码带本地搭建教程

技术架构 技术框架:SpringBoot MySql MyBatis nginx Vue2.6 原生APP 运行环境:jdk8 IntelliJ IDEA maven 宝塔面板 Android Studio 文字本地搭建教程 下载源码,小皮面板安装mysql5.7数据库,创建一个新数据库,…

引力波探测,冷冻电镜研究:两项诺奖GPU功不可没

我们的日常工作固然重要,但并非每一份重要的工作都能够助力他人获得诺贝尔奖。然而,就在2017年10月,GPU 计算便两度成为了助力获得诺贝尔奖的幕后英雄。 三名美国物理学家Rainer Weiss、Barry Barish和Kip Thorne因探测到了爱因斯坦百年前预测…

从“跨域融合”到“中央计算”,这家Tier1如何率先抢跑?

全球汽车产业已经进入以智能化为主旋律的下半场竞赛,同时整车电子电气架构也在加速跨入集中式电子电气架构时代。 在这样的背景之下,智能驾驶域控制器成为了当前最大的增量市场之一,由此也带动了上游芯片、OS、中间件等域控相关软硬件产品的…

第13讲:Python列表对象中元素的删操作

文章目录1.列表元素删操作的方法2.调用remove方法一次删除一个指定的元素3.调用pop方法一次只删除一个指定索引的元素3.1.使用pop方法删除列表中索引为2的元素3.2.使用pop方法不指定索引3.3.使用pop方法指定的索引不存在时同样也会抛出错误4.使用del语句一次至少删除一个元素4.…