把Llama2封装为API服务并做一个互动网页

news2024/10/4 14:55:41

最近按照官方例子,把Llama2跑起来了测试通了,但是想封装成api服务,耗费了一些些力气

参考:https://github.com/facebookresearch/llama/pull/147/files

1. 准备的前提如下

  • 按照官方如下命令,可以运行成功
torchrun --nproc_per_node 1 example_chat_completion.py \
   --ckpt_dir llama-2-7b-chat/ \
   --tokenizer_path tokenizer.model \
   --max_seq_len 512 --max_batch_size 6
  • 使用的模型是llama-2-7b-chat

2. 第一步,增加依赖包

fastapi
uvicorn

3. 第二步,增加文件server.pyllama仓库的根目录下

from typing import Tuple
import os
import sys
import argparse
import torch
import time
import json

from pathlib import Path
from typing import List

from pydantic import BaseModel
from fastapi import FastAPI
import uvicorn
import torch.distributed as dist

from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from llama import ModelArgs, Transformer, Tokenizer, Llama


parser = argparse.ArgumentParser()
parser.add_argument('--ckpt_dir', type=str, default='llama-2-7b-chat')
parser.add_argument('--tokenizer_path', type=str, default='tokenizer.model')
parser.add_argument('--max_seq_len', type=int, default=512)
parser.add_argument('--max_batch_size', type=int, default=6)

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
os.environ['WORLD_SIZE'] = '1'

app = FastAPI()

def setup_model_parallel() -> Tuple[int, int]:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    print("world_size", world_size)
    print("loal_rank", local_rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> Llama:
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        model_parallel_size=1
    )
    return generator


def init_generator(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int = 512,
    max_batch_size: int = 8,
):
    local_rank, world_size = setup_model_parallel()
    if local_rank > 0:
        sys.stdout = open(os.devnull, "w")

    generator = load(
        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
    )

    return generator


if __name__ == "__main__":
    args = parser.parse_args()
    generator = init_generator(
        args.ckpt_dir,
        args.tokenizer_path,
        args.max_seq_len,
        args.max_batch_size,
    )

    class Config(BaseModel):
        prompts: List[str]
        system_bg: List[str]
        max_gen_len: int = 510
        temperature: float = 0.6
        top_p: float = 0.9

    if dist.get_rank() == 0:
        @app.post("/llama/")
        def generate(config: Config):
            dialogs: List[Dialog] = [
                [
                    {
                        "role": "system",
                        "content": config.system_bg[0],
                    },
                    {
                        "role": "user",
                        "content": config.prompts[0],
                    }
                ],
            ]
            results = generator.chat_completion(
                dialogs,  # type: ignore
                max_gen_len=config.max_gen_len,
                temperature=config.temperature,
                top_p=config.top_p,
            )

            return {"responses": results}

        uvicorn.run(app, host="0.0.0.0", port=8042)
    else:
        while True:
            config = [None] * 4
            try:
                dist.broadcast_object_list(config)
                generator.generate(
                    config[0], max_gen_len=config[1], temperature=config[2], top_p=config[3]
                )
            except:
                pass

4. 运行测试

直接运行python sever.py即可运行成功
提供了一个post接口,具体信息为

URL:http://localhost:8042/llama

Body:
{
    "prompts":["你好,你是谁?"],
    "system_bg":["你需要用中文回答问题"]
}

其中prompts为输入内容,system_bg为给提前设定的背景

5. 做一个互动的网页

想做一个类似OpenAI那样子的对话框,继续添加依赖

streamlit

添加如下文件chatbot.py

import streamlit as st
import requests
import json


st.title("llama-2-7b-chat Bot")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("What is up?"):
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    url = 'http://localhost:8042/llama'
    d = {"prompts": [prompt], "system_bg": [""]}
    print(d)
    r_resp_txt = requests.post(url, data=json.dumps(d))
    r_resp_dict = json.loads(r_resp_txt.text)

    response = r_resp_dict['responses'][0]['generation']['content']

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})

运行streamlit run chatbot.py,即可有如下效果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1457249.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【EI会议征稿通知】第五届信息科学与并行、分布式处理国际学术会议(ISPDS 2024)

第五届信息科学与并行、分布式处理国际学术会议(ISPDS 2024) 2024 5th International Conference on Information Science, Parallel and Distributed Systems 第五届信息科学与并行、分布式处理国际学术会议(ISPDS 2024)定于20…

查看 PyCharm 代码文件目录位置

查看 PyCharm 代码文件目录位置 1. Show in Files2. Copy PathReferences 1. Show in Files right click -> Show in Files / Show in Explorer 即可打开目录 2. Copy Path right click -> Copy Path 即可复制目录或文件路径 References [1] Yongqiang Cheng, http…

Android逆向学习(七)绕过root检测与smali修改学习

Android逆向学习(七)绕过root检测与smali修改学习 一、写在前面 这是吾爱破解正己大大教程的第五个作业,然后我的系统还是ubuntu, 这个是剩下作业的完成步骤。 二、任务目标 现在我们已经解决了一些问题,现在剩下的…

【C语言】长篇详解,字符系列篇2-----受长度限制的字符串函数,字符串函数的使用和模拟实现【图文详解】

欢迎来CILMY23的博客喔,本期系列为【【C语言】长篇详解,字符系列篇2-----“混杂”的字符串函数,字符串函数的使用和模拟实现【图文详解】,图文讲解各种字符串函数,带大家更深刻理解C语言中各种字符串函数的应用&#x…

小型医院医疗设备管理系统|基于springboot小型医院医疗设备管理系统设计与实现(源码+数据库+文档)

小型医院医疗设备管理系统目录 目录 基于springboot小型医院医疗设备管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、职员信息管理 2、设备信息管理 3、库房信息管理 4、公告信息管理 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、…

【性能测试入门必看】性能测试理论知识

一、性能测试理论知识 1、常用的七种性能测试方法 (1) 后端性能测试:其实,你平时听到的性能测试,大多数情况下指的是后端性能测试,也就是服务器端性能测试。后端性能测试,是通过性能测试工具模拟大量的并发用户请求&…

20240219画图程序

1. PTZ在惯性态时,不同视场角下的【发送】角速度和【理论响应】角速度 1.1 优化前 import numpy as np import matplotlib.pyplot as plt# PTZ在惯性态时,不同视场角下的【发送】角速度和【理论响应】角速度 ATROffset_x np.linspace(0, 60, 120) y2 …

OpenAI 全新发布文生视频模型 Sora,支持 60s 超长长度,有哪些突破?将带来哪些影响?

Sora大模型简介 OpenAI 的官方解释了在视频数据基础上进行大规模训练生成模型的方法。 我们下面会摘取其中的关键部分罗列让大家快速get重点。 喜欢钻研的伙伴可以到官网查看技术报告: https://openai.com/research/video-generation-models-as-world-simulator…

【图论经典题目讲解】CF786B - Legacy 一道线段树优化建图的经典题目

C F 786 B − L e g a c y \mathrm{CF786B - Legacy} CF786B−Legacy D e s c r i p t i o n \mathrm{Description} Description 给定 1 1 1 张 n n n 个点的有向图,初始没有边,接下来有 q q q 次操作,形式如下: 1 u v w 表示…

GO和KEGG富集分析

写在前面 我们《复现SCI文章系列教程》专栏现在是免费开放,推出这个专栏差不多半年的时间,但是由于个人的精力和时间有限,只更新了一部分。后续的更新太慢了。因此,最终考虑后还是免费开放吧,反正不是什么那么神秘的东…

关于数据结构的定义以及基本的数据结构

在计算机科学中,数据结构是指用于组织和存储数据的方式或方法。它涉及到在计算机内存中存储、管理和操作数据的技术和原则。数据结构不仅仅是简单地存储数据,还可以提供高效的数据访问和操作方式,以满足特定的需求。 以下是每个数据结构的详细…

mkcert安装教程

1、下载 官方文档:https://github.com/FiloSottile/mkcert#mkcert 下载链接:https://github.com/FiloSottile/mkcert/releases 2、安装,该文件目录下打开cmd(可以把文件复制到别的文件夹),执行命令 //命令…

开源模型应用落地-工具使用篇-向量数据库进阶(四)

一、前言 通过学习"开源模型应用落地"系列文章,我们成功地建立了一个完整可实施的AI交付流程。现在,我们要引入向量数据库,作为我们AI服务的二级缓存。本文将继续基于上一篇“开源模型应用落地-工具使用篇-向量数据库(三…

FreeRTOS移植到GD32

目录 一、GD32基础工程创建: 1、创建如下文件夹 2、在keil5创建工程 3、在工程添加相关.c文件和头文件路径 4、实例:实现LED闪烁功能 二、在基础工程添加FreeRTOS: 1、FreeRTOS中的文件: 2、添加的源文件: 3、添加的头文件路径: 4、…

机器人常用传感器分类及一般性要求

机器人传感器的分类 传感技术是先进机器人的三大要素(感知、决策和动作)之一。根据用途不同,机器人传感器可以分为两大类:用于检测机器人自身状态的内部传感器和用于检测机器人相关环境参数的外部传感器。 内部传感器 内部传感…

【JavaEE】_HTML常用标签

目录 1.HTML结构 2. HTML常用标签 2.1 注释标签 2.2 标题标签:h1~h6 2.3 段落标签:p 2.4 换行标签:br 2.5 格式化标签 2.6 图片标签:img 2.7 超链接标签:a 2.8 表格标签 2.9 列表标签 2.10 表单标签 2.10…

航班进出港|航班进出港管理系统|基于springboot航班进出港管理系统设计与实现(源码+数据库+文档)

航班进出港管理系统目录 目录 基于springboot航班进出港管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 5、航班信息管理 (1) 航班信息管理 (2)起飞降落申请管理 (3)公告管理 &…

辽宁博学优晨教育科技有限公司视频剪辑培训专业之选

随着数字时代的到来,视频剪辑技术已成为各行各业不可或缺的一项技能。为了满足市场需求,辽宁博学优晨教育科技有限公司(以下简称“博学优晨”)推出了专业的视频剪辑培训课程,旨在为广大学员提供系统、高效的学习机会。…

AMD FPGA设计优化宝典笔记(4)复位桥

高亚军老师的这本书《AMD FPGA设计优化宝典》,他主要讲了两个东西: 第一个东西是代码的良好风格; 第二个是设计收敛等的本质。 这个书的结构是一个总论,加上另外的9个优化,包含的有:时钟网络、组合逻辑、触…

面试系列之《Spark》(持续更新...)

1.job&stage&task如何划分? job:应用程序中每遇到一个action算子就会划分为一个job。 stage:一个job任务中从后往前划分,分区间每产生了shuffle也就是宽依赖则划分为一个stage,stage这体现了spark的pipeline思…