LLM - FastAPI 搭建简易问答 Server

news2025/1/10 21:47:54

目录

一.引言

二.辅助函数

1.黑名单

2.清除函数

三.模型函数

1.加载模型

2.生成配置

四.服务部署

1.post - predict

2.get - clean_cache

3.main - run_app

五.总结


一.引言

SFT workflow 微调工作流程 一文中我们介绍了模型微调从数据到最终应用的流程

FastAPI 实现 get、post 请求 一文中我们介绍了如何使用 FastAPI 搭建简易接口

结合以上两者,我们使用 FastAPI 搭建一个简易的问答 Server。

二.辅助函数

1.黑名单

def check_sentence(_sentence, _blacklist=[]):
    """
    检查句子中是否包含黑名单中的单词

    参数:
    sentence (str): 待检查的句子
    blacklist (list): 黑名单单词列表

    返回:
    bool: 如果句子中不包含黑名单中的单词,则返回 True,否则返回 False
    """
    for word in _blacklist:
        if word in _sentence:
            return False
    return True

黑名单的逻辑很简单,遍历 sentence 中的 word 是否在自定义提供的 blacklist 中即可,这里主要是保证服务生成的句子不包含敏感和违法词汇,确保服务的安全性。

2.清除函数

def clean_sentence(_sentence):
    # 删除网页链接
    text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '', _sentence)

    # 删除@提及和#话题标签
    text = re.sub(r'\@\w+|\#', '', text)

    # 删除标点符号和特殊字符
    text = re.sub(r'[%s]' % re.escape(punctuation), '', text)

    # 去除 \r\s\n\t
    text = re.sub(r'\\r|\\s|\\n|\\t|\r|\s|\n|\t', '', text)

    # 合并正文中过多的空格
    text = re.sub(r'\s+', ' ', text)

    # 去除\u200b字符
    text = text.replace('\u200b', '')

    return text

黑名单逻辑保证生成句子的安全性,清除函数保证生成句子的合理性,这里是几个常用逻辑,大家有可以根据自己场景的需求和模型生成句子的特点进行修改:

删除网页链接

删除@与话题词

删除标点符号与特殊字符

删除 \t \n 等转移符号

删除过多空格

三.模型函数

1.加载模型

def load_lora_model(model_path, ckpt_path, compute_type=torch.bfloat16):
    st = time.time()

    # 载入预训练模型与 Tokenizer
    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": None,
        "revision": 'main',
        "use_auth_token": None,
    }
    # 载入预训练模型
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="left", **config_kwargs)
    config = AutoConfig.from_pretrained(model_path, **config_kwargs)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,
        torch_dtype=compute_type,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        revision='main'
    )

    model = PeftModel.from_pretrained(model, ckpt_path)
    model = model.merge_and_unload()

    # 修正模型参数
    model.requires_grad_(False)
    # 精度减半[cast from fp32 to fp16] export 模型预测时不 half
    model = model.half() if model.config.torch_dtype == torch.float32 else model

    if torch.cuda.device_count() > 1:
        from accelerate import dispatch_model
        from accelerate.utils import infer_auto_device_map, get_balanced_memory
        device_map = infer_auto_device_map(model, max_memory=get_balanced_memory(model))
        model = dispatch_model(model, device_map)
        print('multi GPU predict => {}'.format(device_map))
    else:
        model = model.cuda()
        print("single GPU predict")
    print('config = ', model.config)

    end = time.time()
    print('time cost: {}'.format(end - st))

    return model, tokenizer

◆ 加载 LoRA 模型: LoRA 模型合并与保存

由于博主使用 LoRA 微调后的模型,所以涉及到加载 LoRA 参数,如果有完整的模型文件,忽略最后两行 PeftModel 和 merge_and_unload 方法即可。

多卡加载: ​​​​​​多卡加载与推理测试

多卡加载适合单卡内存不足以支持服务部署,或者希望多卡可以加速推理的情况。以 P40 为例,13B 模型使用 2 张 P40 部署服务,而 LLaMA-33B 则需要 4 张 P40。以 A800 为例,可以部署 2 x 13B 模型 + 1 x 7B 模型,或者单独部署一个 33B 模型。这里如果卡资源比较富裕,忽略即可。

量化加载: Model Load_in_8bit

量化加载对应 QLoRA,博主给出的示例并未使用 QLoRA,所以没有相关量化的步骤,有需要的同学可以参考上文,也可以到对应模型的 Git 界面,一些新模型内置简单易用的量化 API,可以更便捷的实现 LoRA。

2.生成配置

# 获取生成配置
def init_generation_args():
    gen_conf = {
        'do_sample': True,
        'temperature': 0.95,
        'top_p': 0.7,
        'top_k': 50,
        'num_beams': 1,
        'max_new_tokens': 512,
        'repetition_penalty': 1.0,
        'length_penalty': 1.0
    }
    return gen_conf

批量推理: model batch generate 生成文本

服务生成的配置,为了统一写到一个函数里,如果想要动态控制,也可以通过 post 方法传参。更多参数含义与批量生成的方法可以参考上面的链接。

四.服务部署

1.post - predict

@app.post("/predict")
async def predict(request: Request):
    now = datetime.datetime.now()
    print('TIME: {} 开始预测 ...'.format(now.strftime("%Y-%m-%d %H:%M:%S")))
    start = time.time()
    js = await request.json()
    req = js['question']
    template = (
        "{}"
    )
    query = template.format(req)
    response = ''
    i = 0
    while len(re.findall(u"([\u4e00-\u9fa5])", response)) < 6:
        input_ids = tokenizer(query, return_tensors="pt")['input_ids'].to(model.device)
        output_ids = model.generate(input_ids, gen_config)
        input_id_token_num = input_ids[0].shape[0]
        response = tokenizer.decode(output_ids[0][input_id_token_num:], skip_special_tokens=True)
        i += 1
        if i > 5:
            break
    end = time.time()
    cost = end - start
    print('问:{}=>答:{}'.format(req, response))
    if not check_sentence(response):
        response = ''
    print('time cost: {}'.format(cost))
    return {'question': req, 'result': clean_sentence(response)}

文章顶部的链接介绍了如何实现简单的 get 和 post 请求,由于 LLM 模型语言生成时需要传入对应的 query,所以我们的推理方法需要使用 post 请求。

Template

    req = js['question']
    template = (
        "{}"
    )
    query = template.format(req)

上面给了默认的 Template 即模板,这里模板最好和训练时候对应的模板相对应,例如 Baichuan、LLaMA 等模型,官方都应用了不同的 Template,所以如果存在多个模型,需要注意修改正确的 Template。

Sentence Length

    while len(re.findall(u"([\u4e00-\u9fa5])", response)) < 6:
        input_ids = tokenizer(query, return_tensors="pt")['input_ids'].to(model.device)
        output_ids = model.generate(input_ids, gen_config)
        input_id_token_num = input_ids[0].shape[0]
        response = tokenizer.decode(output_ids[0][input_id_token_num:], skip_special_tokens=True)
        i += 1
        if i > 5:
            break

第一个 While 循环条件 '[\u4e00-\u9fa5]' 这个正则表达式是用来匹配所有的中文字符。[\u4e00-\u9fa5] 是一个 Unicode 范围,代表了所有的中文字符。所以 re.findall 这行代码的含义是在 response 字符串中查找并返回所有的中文字符,而循环的要求需要生成的 response 中至少包含 6 个字符,否则持续生成,这里生成的配置根据上面的 init_generation_args 函数。

CheckAndClean

    print('问:{}=>答:{}'.format(req, response))
    if not check_sentence(response):
        response = ''
    print('time cost: {}'.format(cost))
    return {'question': req, 'result': clean_sentence(response)}

如果回答命中 black_list 则此次 response = '',除此之外还需要对 response 执行 clean 操作,去除无关的符号与字符,最终以 json 的形式返回。

2.get - clean_cache

@app.get("/clean_cache")
async def clean_cache():
    import torch
    torch.cuda.empty_cache()
    print('receive clean_cache instruction...')
    return {'flag': 'success'}

empty_cache 是 PyTorch 中的一个函数,它的作用是清理当前 CUDA 设备上未使用的缓存,以释放一些 GPU 内存。在你的程序长时间运行且需要频繁地进行张量创建和移动操作时我们可以调用该方法。然而要注意的是,频繁地调用此函数可能会导致效率下降,因为清理和重新填充缓存的操作本身也需要时间。由于该方法无需传递参数,所以使用 get 请求即可。

3.main - run_app

# -*- coding: utf-8 -*-
from fastapi import FastAPI, Request
import torch
import time
import datetime
import re
from string import punctuation
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel

app = FastAPI()

... ...

if __name__ == '__main__':
    # 加载模型与生成参数
    model_path = ""
    ckpt_path = ""

    model, tokenizer = load_lora_model(model_path, ckpt_path)
    print("Finish Load Model...")

    gen_kwargs = init_generation_args()
    gen_config = GenerationConfig(**gen_kwargs)

    print('generating_args = {}'.format(gen_kwargs))
    print('gen_config = {}'.format(gen_config.to_dict()))

    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=8098)

把上面的函数添加到 ... 处,传入自己对应的模型地址和 LoRA 参数地址,uvicorn.run 运行服务即可。可以使用 post 方法调用获取模型文本生成的回答,也可以 get 方法清除内存。

五.总结

结合前面的文章,我们实现了 LLM 的训练与实践,通过服务的部署,我们可以将自己垂直领域微调得到的模型应用到自己的业务场景中,使用 LLM 的力量助力业务的扩展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1068135.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言 选择(分支)

if 语句&#xff08;分支语句/选择语句&#xff09; 结构&#xff1a; if ( expressio ) statement 如果对 expression 求值为真&#xff08;非0&#xff09;&#xff0c;则执行 statement &#xff1b;否则&#xff0c;跳过 statement 。与 while 循环一样&#xff0c…

推荐两款不错的打字练习网站~

前言 对于写论文或者编程工作&#xff0c; 打字是其中十分耗费体力的一环&#xff0c;如果学会了盲打&#xff0c;那么可以起到事倍功半的作用。为了提高工作效率&#xff0c;我特意在网路上搜寻了大量打字练习的网站&#xff0c;最终发现有两款打字网站十分不错&#xff0c;同…

论文阅读-- A simple transmit diversity technique for wireless communications

一种简单的无线通信发射分集技术 论文信息&#xff1a; Alamouti S M. A simple transmit diversity technique for wireless communications[J]. IEEE Journal on selected areas in communications, 1998, 16(8): 1451-1458. 创新性&#xff1a; 提出了一种新的发射分集方…

搭建环境遇到的坑

office2010装完没法激活&#xff0c;因为没有关闭杀毒软件和防火墙。AWTK designer编译时报这个错&#xff0c;scons按这个方法装之后就好了。 装AWTK designer后&#xff0c;打不开软件&#xff0c;总是闪退&#xff0c;装了VS后就打得开了装IAR时找不到ActivationInfo.txt&am…

abp中iquery类使用orderBy接口功能报错问题

在后端写排序时&#xff0c;当使用如下OrderBy(排序字段)时&#xff0c;只引用System.Linq时如下错误&#xff1a; 只是因为缺少一个引用&#xff1a;System.Linq.Dynamic.Core  在如下类文件中引用 System.Linq.Dynamic.Core  注意&#xff1a;切记不能删掉System.Linq的引…

STM32 CubeMX ADC采集(HAL库)

STM32 CubeMX ADC采集&#xff08;HAL库&#xff09; STM32 CubeMX STM32 CubeMX ADC采集&#xff08;HAL库&#xff09;ADC介绍ADC主要特征一、STM32 CubeMX设置二、代码部分三&#xff0c;单通道轮询采样速度总结 ADC介绍 12位ADC是一种逐次逼近型模拟数字转换器。它有多达1…

springboot整合rabbitmq入门(三)

在上一篇文章中介绍了rabbitmq的fanout模式。今天继续学习另一种模式——direct模式。这种模式是rabbitmq的最简单一种模式。 首先创建一个名为helloDirect1的对列 Configuration public class DirectRabbitConfig {Beanpublic Queue directA(){return new Queue("hell…

【翻译】NCLS: Neural Cross-Lingual Summarization

Abstract 跨语言摘要&#xff08;CLS&#xff09;是为不同语言的源文件生成特定语言摘要的任务。现有方法通常将此任务分为两个步骤&#xff1a;摘要和翻译&#xff0c;导致错误传播的问题。为了解决这个问题&#xff0c;我们首次提出了一种端到端的CLS框架&#xff0c;我们称…

淘宝商品销量数据接口,淘宝商品销量数据API接口

淘宝商品销量数据接口是淘宝开放平台提供的一种API接口&#xff0c;通过该接口&#xff0c;商家可以获取到淘宝平台上某一商品的销量数据&#xff0c;包括商品的总销量、近期销量、销售趋势等。 该接口的使用方法是&#xff0c;商家先注册淘宝开放平台账号&#xff0c;申请App…

体会jdk17对于空指针的增强

jdk17 // 可以清楚的看出来a.b.c.num中由于c是空指针&#xff0c;所以导致异常 jdk11 // 只报第6行空指针了&#xff0c;但是因为哪个变量&#xff0c;不知道

打exit_hook,如何找exit_hook的偏移

最近在打比赛的时候&#xff0c;一眼就知道咋做了&#xff0c;打exit_hook为one_gadget 但是泄露出libc之后&#xff0c;没法找到exit_hook的偏移&#xff0c;所以没能拿到一血。搜了好多&#xff0c; 总结大概就是exit_hook不是真实存在的&#xff0c;而是函数指针&#xff0c…

ansible - Role

1、简介&#xff1a; Ansible 中的角色&#xff08;Role&#xff09;是一种组织和封装Playbook的方法&#xff0c;用于管理和组织 Ansible代码。它可以将任务和配置逻辑模块化&#xff0c;以便在不同的Playbook中共享和重用。 2、通过 role 远程部署并配置 nginx (1) 准备目…

【FISCO-BCOS】十七、角色的权限控制

目录 一、角色定义 二、账户权限控制 1.委员新增、撤销与查询 2.委员权重修改 3.委员投票生效阈值修改 4. 运维新增、撤销与查询 一、角色定义 分为治理方、运维方、监管方和业务方。考虑到权责分离&#xff0c;治理方、运维方和开发方权责分离&#xff0c;角色互斥。 治理…

pytorch_神经网络构建2(数学原理)

文章目录 深层神经网络多分类深层网络反向传播算法优化算法动量算法Adam 算法 深层神经网络 分类基础理论: 交叉熵是信息论中用来衡量两个分布相似性的一种量化方式 之前讲述二分类的loss函数时我们使用公式-(y*log(y_)(1-y)*log(1-y_)进行概率计算 y表示真实值,y_表示预测值 …

MyBatisPlus(十五)分页查询

说明 MyBatisPlus 提供了分页查询的功能。 MyBatisPlus 的分页功能&#xff0c;是通过分页插件实现的。要使用分页功能&#xff0c;需要配置分页插件的拦截器。 MyBatisPlus 的分页功能&#xff0c;可以通过内置的API接口实现&#xff1b;也可以通过自定义的 mapper#method …

第七章 正交实验法用例评审bug管理流程

一、正交试验法 利用因果图来设计测试用例时,作为输入条件的原因与输出结果之间的因果关系,有时很难从软件需求规格说明中得到。往往因果关系非常庞大,以至于据此因果图得到的测试用例数目多的惊人,给软件测试带来沉重的负担,为了有效地合理地减少测试的工时与费用,可利…

合并不同门店数据-上下合并

项目背景&#xff1a;线下超市分店&#xff0c;统计产品的销售数量和销售额&#xff0c;并用透视表计算求和 merge()函数可以根据链接键横向连接两张不同表&#xff0c;concat()函数可以上下合并和左右合并2种不同的合并方式。merge()函数只能横向连接两张表&#xff0c;而con…

海信电视U8“死磕”技术,家庭影音娱乐的体验突围

最近《奥本海默》上映&#xff0c;导致我无心工作&#xff0c;沉迷抢票。因为北京能播放该片IMAX 1.43画幅的影院只有一家&#xff0c;场场爆满&#xff0c;一票难求。 这次经历也让我对科技产业多了一点思考。流媒体火爆之后&#xff0c;电影发烧友还是会去影院看电影&#xf…

企业信息查询平台:天眼销正式上线!

在团队的不断努力下&#xff0c;天眼销平台终于和大家见面了&#xff01;总所周知我们是一家数据服务提供商&#xff0c;作为西南数据交易所的数据提供商之一&#xff0c;之前主要是面向B端客户提供数据服务。现在&#xff0c;我们上线的天眼销&#xff08;tianyanxiao.)主要面…

Android ncnn-android-yolov8-seg源码解析 : 实现人像分割

1. 前言 上篇文章&#xff0c;我们已经将人像分割的ncnn-android-yolov8-seg项目运行起来了&#xff0c;后续文章我们会抽取出Demo中的核心代码&#xff0c;在自己的项目中&#xff0c;来接入人体识别和人像分割功能。 先来看下效果&#xff0c;整个图像的是相机的原图&#…