从零开发短视频电商 在AWS上SageMaker部署模型自定义日志输入和输出示例

news2025/1/31 8:23:30

从零开发短视频电商 在AWS上SageMaker部署模型自定义日志输入和输出示例

怎么部署自定义模型请看:从零开发短视频电商 在AWS上用SageMaker部署自定义模型

  • 都是huaggingface上的模型或者fine-tune后的。

为了适配jumpstart上部署的模型的http输入输出,我在自定义模型中自定义了适配的输入输出,可以做到兼容适配

code/inference.py

  • 容器的原始代码入口:https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/handler_service.py
  • 默认支持的decode和encode:https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
  • 可以用这个在sagemaker上使用jupyterlab:https://github.com/huggingface/notebooks/blob/main/sagemaker/17_custom_inference_script/sagemaker-notebook.ipynb

我们自定义的逻辑如下

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import json
import logging
// --------- 这块
logger = logging.getLogger()
logger.setLevel(logging.INFO)
// 自定义http输入,可以适配不同的content_type ,打印输入的日志
// 源码参见下面的 preprocess
def input_fn(input_data, content_type):
    logger.info(f"laker input_data {input_data} and content_type {content_type}")
    if content_type == "application/json":
        request = json.loads(input_data)
    elif content_type == "application/x-text":
        request = {"inputs": input_data.decode('utf-8')}
    else:
        request = {"inputs": input_data} 
    logger.info(f"laker input_fn request {request} ")
    return request
// 自定义输出
def output_fn(prediction, accept):
    return encode_json(prediction)  

// 来自https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py#L102C1-L113C6
    
class _JSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif hasattr(obj, "tolist"):
            return obj.tolist()
        elif isinstance(obj, datetime.datetime):
            return obj.__str__()
        elif isinstance(obj, Image.Image):
            with BytesIO() as out:
                obj.save(out, format="PNG")
                png_string = out.getvalue()
                return base64.b64encode(png_string).decode("utf-8")
        else:
            return super(_JSONEncoder, self).default(obj)
        
def encode_json(content):
    """
    encodes json with custom `JSONEncoder`
    """
    return json.dumps(
        content,
        ensure_ascii=False,
        allow_nan=False,
        indent=None,
        cls=_JSONEncoder,
        separators=(",", ":"),
    )
// --------- 这块  end ---

# Helper: Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def model_fn(model_dir):
  # Load model from HuggingFace Hub
  tokenizer = AutoTokenizer.from_pretrained(model_dir)
  model = AutoModel.from_pretrained(model_dir)
  return model, tokenizer

def predict_fn(data, model_and_tokenizer):
    # destruct model and tokenizer
    model, tokenizer = model_and_tokenizer
    
    # Tokenize sentences
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    
    # return dictonary, which will be json serializable
    return {"embedding": sentence_embeddings[0].tolist()}
import logging
from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder

logger = logging.getLogger(__name__)

 def preprocess(self, input_data, content_type, context=None):
        """
        The preprocess handler is responsible for deserializing the input data into
        an object for prediction, can handle JSON.
        The preprocess handler can be overridden for data or feature transformation.

        Args:
            input_data: the request payload serialized in the content_type format.
            content_type: the request content_type.
            context (obj): metadata on the incoming request data (default: None).

        Returns:
            decoded_input_data (dict): deserialized input_data into a Python dictonary.
        """
        # raises en error when using zero-shot-classification or table-question-answering, not possible due to nested properties
        if (
            os.environ.get("HF_TASK", None) == "zero-shot-classification"
            or os.environ.get("HF_TASK", None) == "table-question-answering"
        ) and content_type == content_types.CSV:
            raise PredictionException(
                f"content type {content_type} not support with {os.environ.get('HF_TASK', 'unknown task')}, use different content_type",
                400,
            )

        decoded_input_data = decoder_encoder.decode(input_data, content_type)
        return decoded_input_data
    
    
    

   logger.info(
        f"param1 {batch_size} and param2 {sequence_length}"
    )



def predict(self, data, model, context=None):
        """The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`
        on decoded_input_data deserialized in input_fn. Runs prediction on GPU if is available.
        The predict handler can be overridden to implement the model inference.

        Args:
            data (dict): deserialized decoded_input_data returned by the input_fn
            model : Model returned by the `load` method or if it is a custom module `model_fn`.
            context (obj): metadata on the incoming request data (default: None).

        Returns:
            obj (dict): prediction result.
        """

        # pop inputs for pipeline
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # pass inputs with all kwargs in data
        if parameters is not None:
            prediction = model(inputs, **parameters)
        else:
            prediction = model(inputs)
        return prediction

    def postprocess(self, prediction, accept, context=None):
        """
        The postprocess handler is responsible for serializing the prediction result to
        the desired accept type, can handle JSON.
        The postprocess handler can be overridden for inference response transformation.

        Args:
            prediction (dict): a prediction result from predict.
            accept (str): type which the output data needs to be serialized.
            context (obj): metadata on the incoming request data (default: None).
        Returns: output data serialized
        """
        return decoder_encoder.encode(prediction, accept)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1329676.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

和数软件:区块链技术重塑未来的分布式账本

区块链技术作为一项革命性的分布式账本技术,正在逐渐改变我们的生活方式和商业模式。它的出现为金融、供应链、物联网等领域带来了巨大的变革。本文将深入探讨区块链技术的原理、特点、应用以及未来发展趋势。 一、区块链技术原理 区块链是一种去中心化的数据库技术…

27.Java程序设计-基于Springboot的在线考试系统小程序设计与实现

1. 引言 随着数字化教育的发展,在线考试系统成为教育领域的一项重要工具。本论文旨在介绍一个基于Spring Boot框架的在线考试系统小程序的设计与实现。在线考试系统的开发旨在提高考试的效率,简化管理流程,并提供更好的用户体验。 2. 系统设…

Java之LinkedList核心源码解读

LinkedList核心源码解读 LinkedList 是一个基于双向链表实现的集合类,经常被拿来和 ArrayList 做比较 LinkedList 插入和删除元素的时间复杂度? 头部插入/删除:只需要修改头结点的指针即可完成插入/删除操作,因此时间复杂度为 O…

洛谷 NOIP2016 普及组 回文日期

这道题目本来是不难想思路的。。。。。。 然而我第一次做的时候改了蛮久才把代码完全改对,主要感觉还是不够细心,敲的时候也没注意见检查一些小错误,那么接下来不说废话,请看题干: 接下来请看输入输出的样例以及数据范…

基于Vite+Vue3 给项目引入Axios

基于ViteVue3 给项目引入Axios,方便与后端进行通信。 系列文章指路👉 系列文章-基于Vue3创建前端项目并引入、配置常用的库和工具类 文章目录 安装依赖新建src/config/config.js 用于存放常用配置进行简单封装解决跨域问题调用尝试 安装依赖 npm install axios …

YOLOv8改进 | 主干篇 | 利用MobileNetV3替换Backbone(轻量化网络结构)

一、本文介绍 本文给大家带来的改进机制是MobileNetV3,其主要改进思想集中在结合硬件感知的网络架构搜索(NAS)和NetAdapt算法,以优化移动设备CPU上的性能。它采用了新颖的架构设计,包括反转残差结构和线性瓶颈层&…

Mac电脑上soucetree账户更改

在开发公司项目的时候遇到一个问题。soucetree提示需要输入已离职员工-张三的密码。 问题:Mac电脑使用souetree,拉取仓库代码提示需要输入其他员工密码。 解决: Mac电脑 SourceTree去掉之前的账户 1、前往文件路径 /Library/Application Su…

linux 中 C++的环境搭建以及测试工具的简单介绍

文章目录 makefleCMakegdb调试 与 coredumpValgrind 内存检测gtest 单元测试 makefile 介绍 安装 : sudo apt install make makefile 的规则: 举例说明 包括:目标文件 、 依赖文件 、 生成规则 使用 : make make clean CMake : CMake是一个…

深度解析LinkedList

LinkedList是Java集合框架中List接口的实现之一,它以双向链表的形式存储元素。与传统的数组相比,链表具有更高的灵活性,特别适用于频繁的插入和删除操作。让我们从底层实现开始深入了解这个强大的数据结构。 底层数据结构 LinkedList的底层数…

LTO-3 磁带机种草终于是用上了

跑来跑去,买了不少配件,终于是把这磁带机给用上了,已经备份好了300 多 GB 的数据。 我们用了 NAS 的数据压缩功能,把需要备份的文件用 NAS 压缩成一个 Zip 文件,如果你可以 tar 的话也行。 这样传输速度更快&#xf…

Android Studio各种Gradle常见报错问题及解决方案

大家好,我是咕噜铁蛋!在开发Android应用程序时,我们可能会遇到各种Gradle错误。这些错误可能来自不同的原因,例如依赖项问题、配置错误、版本冲突等。今天我通过搜索整理了一下,在这篇文章中,我将分享一些常…

SpringIOC之BeanFactoryResolver

博主介绍:✌全网粉丝5W,全栈开发工程师,从事多年软件开发,在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战,博主也曾写过优秀论文,查重率极低,在这方面有丰富的经验…

Canal使用详解

Canal介绍 Canal是阿里巴巴开发的MySQL binlog增量订阅&消费组件,Canal是基于MySQL二进制日志的高性能数据同步系统。在阿里巴巴集团中被广泛使用,以提供可靠的低延迟增量数据管道。Canal Server能够解析MySQL Binlog并订阅数据更改,而C…

python3 数据分析项目案例,用python做数据分析案例

本篇文章给大家谈谈python3 数据分析项目案例,以及用python做数据分析案例,希望对各位有所帮助,不要忘了收藏本站喔。 目录 一丶可视化绘图案例 1.曲线图 2.柱形图 3.点线图 4.3D散点图 5. 绘制漏斗图 6. 绘制词云图 二丶包/模块使用示例 (1)…

用Python处理PDF:拆分与合并PDF文档

PDF文档在信息共享和数据保存方面被广泛使用,处理PDF文档也成为常见需求。其中,合并和拆分PDF文档能够帮助我们更有效地管理PDF文档,使文档内容分布更合理。通过合并,可以将相关文档整合成一个文件,以便更好地组织和提…

蓝桥杯2019年10月青少组Python程序设计省赛真题

1:有n个人围成一个圈,按顺序排好号然后从第一个人开始报数(从1到3报数),报到3的人退出圈子,然后继续从1到3报数,直到最后留下一个人游戏结束,问最后留下的是原来第几号输人描迹:输人一个正整数n 输出描迹:输出最后留下的是原来的第几号 [样例输人] [样例输出] 2: 3、 […

P1883 函数

题目链接 P1883 函数 思路 举例 题目中的 F ( x ) F(x) F(x) 看起来很复杂,但由于每个 f ( x ) f(x) f(x) 的二次项系数 a a a 都不是负数,故 F ( x ) F(x) F(x) 是一个单谷函数。直接说出结论可能有些令人难以接受,不妨举出两个例子…

mysql忘记了密码

1.查找mysql的配置文件 find / -name my.cnf 2.编辑my.cnf vim /etc/my.cnf 3. 在最后一行添加skip-grant-tables跳过密码校验 4.检查mysql服务是否已正常启动 service mysqld status 5.修改完配置重启服务 systemctl restart msyqld 6.键入 msyql直接进入mysql mysql 7.进入my…

Qt制作定时关机小程序

文章目录 完成效果图ui界面ui样图 main函数窗口文件头文件cpp文件 引言 一般定时关机采用命令行模式&#xff0c;还需要我们计算在多久后关机&#xff0c;我们可以做一个小程序来定时关机 完成效果图 ui界面 <?xml version"1.0" encoding"UTF-8"?>…

Linux--Shell脚本应用实战

实验环境 随着业务的不断发展&#xff0c;某公司所使用的Linux服务器也越来越多。在系统管理和维护过程中&#xff0c;经 常需要编写一些实用的小脚本&#xff0c;以辅助运维工作&#xff0c;提高工作效率。 需求描述 > 编写一个名为getarp.sh的小脚本&#xff0c;记录局域…