【2024】Datawhale AI夏令营 Task4笔记——vllm加速方式修改及llm推理参数调整上分

news2024/9/24 13:20:36

【2024】Datawhale AI夏令营 Task4笔记——vllm加速方式修改及llm推理参数调整上分

本文承接文章【2024】Datawhale AI夏令营 Task3笔记——Baseline2部分代码解读及初步上分思路,对其中vllm加速方式进行修改,推理速度获得了极大提升。另外,在延用多路投票的同时,通过调整大语言模型的参数获得了一些分数的提升。

🔴本文主要的注意点:

1、在使用vllm离线推理时,prompt信息需要装入messages并应用tokenizer的对话模板,否则回答会非常抽象。

2、llm推理参数调整对上分的帮助较小,大概在0.1左右。

一、vLLM加速方式修改

文章【2024】Datawhale AI夏令营 Task3笔记——Baseline2部分代码解读及初步上分思路中使用的vLLM加速方式是类openAI的API服务(vLLM启动的相关参数及解释可参考文章:VLLM参数解释-中文表格形式),本文使用的vLLM加速方式是离线批量推理

vLLM离线批量推理的参考文章:

Qwen-离线推理(仅实现离线推理,未实现批量)

使用vLLM和ChatGLM3-6b批量推理(实现离线批量推理,但不完全适用于本次比赛)

Using VLMs(官方文档,实现与图像相关的离线批量推理,但不完全适用于本次比赛)

本文最终使用的vLLM离线批量推理的代码如下。

1.1 引入相关包,创建LLM模型对象及tokenizer

from vllm import LLM, SamplingParams
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

device = "cuda"
model_path = './merged_model_ana_my'
llm = LLM(model_path) # 使用vllm.LLM()创建LLM对象
tokenizer = AutoTokenizer.from_pretrained(model_path) # 使用AutoTokenizer.from_pretrained()创建tokenizer

🔴注意:

1、只需要提供模型路径即可创建LLM对象。不需要另外使用类似model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16).eval()的代码创建模型对象,这样可能会导致加载模型权重时程序被Killed或者推理时内存不足(因为创建的模型对象会占用较大的内存空间)。

2、tokenizer还可以通过如下方式创建:

device = "cuda"
model_path = './merged_model_ana_my'
llm = LLM(model_path, model_path) # 第一个model_path表示使用该路径下的model,第二个model_path表示使用该路径下的tokenizer(不再使用AutoTokenizer.from_pretrained()创建tokenizer)

这种方式似乎更加简洁,但为何最终不使用这种方式?原因在后面会提到。

1.2 修改process_datas()函数,实现(多路)离线批量推理

def process_datas(datas, MODEL_NAME):
    prompts = []
    results = []
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 设置使用第1块GPU
    
    # 获取每个问题的prompt,并将prompt信息装入messages,(关键)再应用tokenizer的对话模板
    for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
        problem = data['problem']
        for id, question in enumerate(data['questions']):
            prompt = get_prompt(
                problem, 
                question['question'], 
                question['options'],
            )
            messages = [
                {"role": "user", "content": prompt}
            ]
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            prompts.append(text) # 将处理完成的prompt添加入prompts列表,准备输入vllm批量推理
    
    # 定义推理参数
    sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=512)
    
    # 开始推理
    # 单路投票推理
    # outputs = llm.generate(prompts, sampling_params)
    # 多路投票推理(这里通过进行三次推理,模仿多路投票的过程)
    outputs1 = llm.generate(prompts, sampling_params)
    outputs2 = llm.generate(prompts, sampling_params)
    outputs3 = llm.generate(prompts, sampling_params)

    '''
    单路投票
    '''
    # i = 0
    # for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
    #     for id, question in enumerate(data['questions']):
    #         generated_text = outputs[i].outputs[0].text
    #         i = i + 1
    #         extract_response= extract(generated_text)
    #         data['questions'][id]['answer'] = extract_response
    #         results.append(data)

    '''
    多路投票
    '''
    i = 0 # 由于outputs中存储的回答序号并不是与datas中的序号一一对应(因为一个问题背景下可能有多个问题),因此使用一个计数变量另外遍历outputs
    for data in tqdm(datas, desc="Extracting answers", total=len(datas)):
        for id, question in enumerate(data['questions']):
            # 获取每一路推理的回答文本
            generated_text1 = outputs1[i].outputs[0].text
            generated_text2 = outputs2[i].outputs[0].text
            generated_text3 = outputs3[i].outputs[0].text
            i = i + 1
            # 从文本中提取答案选项
            extract_response1, extract_response2, extract_response3 = extract(generated_text1),  extract(generated_text2),  extract(generated_text3)
            # 投票选择出现次数最多的选项作为答案
            ans = most_frequent_char(extract_response1, extract_response2, extract_response3)
            data['questions'][id]['answer'] = ans
            results.append(data)

    return results

这样修改后,在与前一篇文章同样的环境下,模型推理完成全部问题只需使用约3min30s,相较于原先的7h提升很多。造成这种差异的原因可能是原先每推理一个问题就需要启动一次vllm,启动耗时较大,因此整体速度慢。现在能够将所有问题的prompt一次性传入vllm进行离线批量推理,速度更快。

🔴注意:prompt的内容影响模型的性能。在进行推理时,如果传入的prompt没有经过messages包装、没有应用tokenizer的对话模板,推理出来的文本会非常抽象,例如对于如下问题:

{"problem": "有一群人和一些食物类型。下列是关于这些个体和食物的已知信息:\n\n1. 鸡肉是一种食物。\n2. 苹果是一种食物。\n3. 如果X吃了Y,且X活着,则Y是一种食物。\n4. Bill存活。\n5. Bill吃了花生。\n6. John吃所有食物。\n7. Sue吃所有Bill吃的食物。\n8. John喜欢所有食物。\n\n根据以上信息,回答以下选择题:", "questions": [{"question": "选择题 1:\n谁喜欢吃花生?", "options": ["Bill", "Sue", "John", "None of the above"]}], "id": "round1_test_data_000"}

它的回答是这样的(无中生有了更多选择题):
在这里插入图片描述

对其他问题,回答甚至可能是这样的:

在这里插入图片描述

可以说是非常抽象、已读乱回。

prompt经过messages包装、应用tokenizer的对话模板后就正常多了(但是这一步为什么这么关键,我也还不是很懂):
在这里插入图片描述

这也是为什么在前面要单独创建tokenizer,就是为了在后面能够对prompt应用tokenizer的对话模板。

二、llm推理参数调整上分

其实这只是一个比较低级的trick,还不涉及微调、数据集等技术(时间较短,还未来得及学习应用其他技术)。主要调整llm参数的地方就在process_datas函数中sampling_params定义的位置。

sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=512)

关于SamplingParams参数的解释可以查看文档Sampling Parameters,这里这设置了一部分推理参数:temperaturetop_prepetition_penalty

SamplingParams参数的解释可以查看文档Sampling Parameters,这里这设置了一部分推理参数:temperaturetop_prepetition_penalty

这部分是否真的能够提分还没有做对比实验(毕竟验证会消耗提交次数),但是与前一篇文章中的最高分相比,使用此篇文章的代码再次推理出答案后,得到的分数提升了0.1。而本文代码与前一篇文章的代码相比,与推理准确度有关的部分只做了这一方面的改动,vllm加速方式的改动应该不影响推理准确度,所以暂且认为这部分参数的调整有助于微小提分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1976204.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JAVA入门】Day17 - GUI

【JAVA入门】Day17 - GUI 文章目录 【JAVA入门】Day17 - GUI一、组件二、事件 GUI 即图形化界面。 一、组件 一个 Java 的图形化界面项目主要用到了下面几种组件。 Java 中最外层的窗体叫做 JFrame。Java 中最上层的菜单叫做 JMenuBar。Java 中管理文字和图片的容器叫做 JLab…

快速体验LLaMA3模型微调(超算互联网平台国产异构加速卡DCU)

序言 本文以 LLaMA-Factory 为例,在超算互联网平台SCNet上使用异构加速卡AI 显存64GB PCIE,对 Llama3-8B-Instruct 模型进行 LoRA 微调、推理和合并。 超算互联网平台 异构加速卡AI 显存64GB PCIE 一、参考资料 github仓库代码:LLaMA-Fac…

C#中的Winform基础

program 每个Windows应用程序都会有一个Program类——程序入口点 [STAThread] ----指示应用程序的COM线程模型是单线程单元(如果无此特性,无法工作) static voidMain() —— 入口 System.Windows.Forms.Application类提供一系列静态方法和…

【C++】一堆数组案例 元素逆置

所谓元素逆置就是把一堆数组的元素顺序反过来 例如一堆数组的为 1,2,3,4 那么它的逆置为 4,3,2,1 逆置过程运用赋值存储的思想,先把第一个数组存贮到一个变量中,然后把末尾数组…

开源LivePortrait,快速实现表情包自定义

最近可灵AI很火,看到网上生成的效果也很赞啊,之前发现快手可灵开源了LivePortrait,今天去玩了一下,很有意思。 比如下图官方展示效果: 这些图片开始自带表情了,主要就是通过LivePortrait来实现。 LivePor…

浏览器用户文件夹详解 - Top Sites(七)

1. TopSites简介 1.1 什么是TopSites文件? TopSites文件是Chromium浏览器中用于存储用户访问频率最高的网站信息的一个重要文件。每当用户在浏览器中访问网站时,这些信息都会被记录在TopSites文件中。通过这些记录,浏览器可以为用户提供个性…

校园抢课助手【7】-抢课接口限流

在上一节中,该接口已经接受过风控的处理,过滤掉了机器人脚本请求,剩下都是人为的下单请求。为了防止用户短时间内高频率点击抢课链接,海量请求造成服务器过载,这里使用接口限流算法。 先介绍下几种常用的接口限流策略…

脚拉脚模型笔记

脚拉脚模型 ⌈♪⌋例题: 辅助线(中点)做法: 倍长中线Rt △ △ △ 斜边中线等腰 △ △ △ 三线合一中位线 需要:两个等腰三角形,顶角互补 共__底点__ 底角需要连接 解: ∵ D Q 1 / 2 A B O…

中国人工智能最好50所大学排名-2024年最强学校名单

人工智能最强的学校包含:清华大学、上海交通大学、南京大学、西安电子科技大学、电子科技大学、中国科学技术大学、哈尔滨工业大学、华中科技大学、东南大学、浙江大学等学校。这些都是人工智能专业排名全国前十的名牌大学。 圆梦小灯塔将在下文继续为2024年高考生…

鸿蒙应用开发 DevEcoStudio 汉化

步骤 DevEcoStudio 是默认支持中文的,只是默认是关闭的,需要在已安装的插件中搜索 Chinese 关键字,然后启用并重启即可(注意:是在已安装的插件中搜索)。 1. 2. 3. 重启就行

滚珠花键:新能源汽车传动系统的核心动力传递者

在日常生活中,汽车已经成为了必不可少的交通工具,尤其是新能源汽车。而滚珠花键作为传动系统中的重要组成部分,在传动系统方面的作用不容忽视。 随着科技的不断发展,汽车行业也在不断进步,滚珠花键作为高精度的机械传动…

PE安装win11原版系统“无法创建新的分区,也找不到现有的分区”和“windows无法对计算机进行启动到下一个安装阶段”的解决办法

问题1 针对“无法创建新的分区,也找不到现有的分区”: 解决办法: 用Diskgenius等分区工具删除整个分区,不要在分区工具里新建分区,而是在安装系统选择安装磁盘的时候,直接选择这个磁盘,从而完成…

五. TensorRT API的基本使用-build-model-from-scratch

目录 前言0. 简述1. 案例运行2. 代码分析2.1 main.cpp2.2 model.cpp 3. 案例3.1 sample_conv3.2 sample_permute3.3 sample_reshape3.4 sample_batchNorm3.5 sample_cbr 4. 补充说明总结下载链接参考 前言 自动驾驶之心推出的 《CUDA与TensorRT部署实战课程》,链接。…

《学会 SpringMVC 系列 · 写入拦截器 ResponseBodyAdvice》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 CSDN入驻不久,希望大家多多支持,后续会继续提升文章质量,绝不滥竽充数…

3.4数组和特殊矩阵

3.4.1数组的定义 数组是由n个相同类型的数据元素构成的有序序列 数组是线性表的推广,一个数组可以视为一个线性表 数组一旦被定义,其长度不会再改变,所以数组只会有存取元素和修改元素的操作 3.4.2数组的存储结构 多维数组 有两种映射方法:按行优先和按列优先 按行优先 …

2024 年最值得阅读的 10 个外国技术网站

从网络上数以千计的博客中挑选出最好的技术网站,并根据相关性、权威性、社交媒体关注者和新鲜度进行排名。 1. TechCrunch TechCrunch 是一家领先的科技媒体,致力于深入分析初创公司、评论新的互联网产品和发布科技新闻。该网站是科技专业人士和爱好者…

【传知代码】实体关系抽取(论文复现)

当谈论信息提取领域的最前沿时,实体关系抽取无疑是其中一颗耀眼的明星。从大数据时代的信息海洋中提炼出有意义的关系,不仅是科技进步的体现,更是人类对知识管理和智能决策迫切需求的响应。本文将探索实体关系抽取的核心技术、应用场景及其在…

域控搭建(windows 2012 R2和win10)

域控搭建 环境准备 两台windows虚拟机 主域控为:windows server2012 子域为:win10 虚拟机设置网段 Win10网络设置 Windows server2012网络设置 Windows server2012网络适配器 设置 识别成功 更改计算机名字 等待重启 Win10网络适配器 设置 识别成功 …

opencv-图像透视变换

透射变换是视角变化的结果,是指利用透视中心,像点,目标点共线的条件,按透视旋转定律使承影面(透视面)绕迹线(透视轴旋转某一角度,破坏原有的投影光束,仍能保持承影面上投影几何图形不变的变化) 它的本质将图…

QT实现步进电机控制和IMU数据读取显示

实现功能: 1.两步进电机分别使能和循环运动,可以设置循环次数、循环里分别运行的角度、旋转的速度和加减速度等等,在最下方的表格里显示发送和接收的CAN报文 2.读取水平电机当前位置和速度并画图显示,示波器暂停、缩放、滑动等功…