DataWhale AI夏令营-英特尔-阿里天池LLM Hackathon

news2024/9/19 9:27:23

英特尔-阿里天池LLM Hackathon

  • 项目思路
    • 项目背景
    • 项目思路
  • Lora微调Qwen模型
  • 使用ipex_llm推理加速
  • Gradio交互

项目名称:医疗问答助手

项目思路

项目背景

在当今医疗领域,智能问答系统正在逐步成为辅助医疗诊断的重要工具。随着自然语言处理技术的发展,基于大模型的问答系统在处理复杂医疗问题时展现出了巨大的潜力。Qwen2-1.5B模型作为一个大型预训练语言模型,拥有强大的语言理解和生成能力,但在特定领域应用时,往往需要进一步的微调和优化。为了提升医疗问答系统的准确性,本项目采用了LoRA(Low-Rank Adaptation)微调方法,并通过ipex_llm框架在指定的CPU平台上进行推理加速。

项目思路

明确了项目需求之后可以将本次项目分为三个部分:Lora微调Qwen模型、使用ipex_llm在CPU上进行推理加速、使用Gradio交互。

Lora微调Qwen模型

我们本次项目的目的是完成一个医疗问答机器人,训练的首先需要收集数据,我们使用github上开源的医疗问答数据集,数据集包含了2.7w条真实的问答数据(github链接有点久远了时间我忘记了,如果有需要可以私信我我发给您)。
在这里插入图片描述
Qwen的Lora我们在之前的博客中有提到过,在这里就不细说了,详见Qwen2-1.5B微调+推理

import torch
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
from peft import LoraConfig, TaskType, get_peft_model, PeftModel

dataset = load_dataset("csv", data_files="./问答.csv", split="train")
dataset = dataset.filter(lambda x: x["answer"] is not None)
datasets = dataset.train_test_split(test_size=0.1)

tokenizer = AutoTokenizer.from_pretrained("./Qwen2-1.5B-Instruct", trust_remote_code=True)

def process_func(example):
    MAX_LENGTH = 768
    input_ids, attention_mask, labels = [], [], []
    instruction = example["question"].strip()     # query
    instruction = tokenizer(
        f"<|im_start|>system\n你是医学领域的人工助手章鱼哥<|im_end|>\n<|im_start|>user\n{example['question']}<|im_end|>\n<|im_start|>assistant\n",
        add_special_tokens=False,
    )
    response = tokenizer(f"{example['answer']}", add_special_tokens=False)        # \n response, 缺少eos token
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = (instruction["attention_mask"] + response["attention_mask"] + [1])
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

tokenized_ds = datasets['train'].map(process_func, remove_columns=['id', 'question', 'answer'])
tokenized_ts = datasets['test'].map(process_func, remove_columns=['id', 'question', 'answer'])

model = AutoModelForCausalLM.from_pretrained("./Qwen2-1.5B-Instruct", trust_remote_code=True)

config = LoraConfig(target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"], modules_to_save=["post_attention_layernorm"])

model = get_peft_model(model, config)

args = TrainingArguments(
    output_dir="./law",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    logging_steps=6,
    num_train_epochs=10,
    learning_rate=1e-4,
    remove_unused_columns=False,
    save_strategy="epoch"
)
model.enable_input_require_grads()

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds.select(range(400)),
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

训练结束得到微调后的权重,打包下载即可。
在这里插入图片描述

使用ipex_llm推理加速

导入需要的包
ipex是Intel公司研发优化大语言模型 (LLM) 在其硬件(Intel CPU)上运行而开发的一组扩展库和工具。

import os
import torch
import time
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM
from peft import PeftModel

由于实在Cpu推理可以根据核心数设置线程

# 设置OpenMP线程数为8, 优化CPU并行计算性能
os.environ["OMP_NUM_THREADS"] = "8"

# base_model_name = "qwen2chat_int4"
# model = AutoModelForCausalLM.load_low_bit(base_model_name, trust_remote_code=True)

# 加载基础模型和分词器
base_model_name = "Qwen2-1-5B-Instruct"  # 替换为你的基础模型名称
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

合并Lora

# 加载LoRA微调后的权重
lora_checkpoint = "./checkpoint-781"
lora_model = PeftModel.from_pretrained(model, lora_checkpoint)

输入Prompt测试

# 定义输入prompt
prompt = "头疼怎么治疗呢"

# 构建符合模型输入格式的消息列表
messages = [{"role": "user", "content": prompt}]

开启推理模式,在这部分其实有一个缺陷,就是合并Lora后的模型推理速度非常慢,大概是普通模型的五倍,欢迎有大佬能指点。

# 使用推理模式,减少内存使用并提高推理速度
with torch.inference_mode():
    # 应用聊天模板,将消息转换为模型输入格式的文本
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # 将文本转换为模型输入张量,并移至CPU (如果使用GPU,这里应改为.to('cuda'))
    model_inputs = tokenizer([text], return_tensors="pt").to('cpu')

    st = time.time()
    # 生成回答, max_new_tokens限制生成的最大token数
    generated_ids = lora_model.generate(model_inputs.input_ids, max_new_tokens=512)
    end = time.time()

    # 初始化一个空列表,用于存储处理后的generated_ids
    processed_generated_ids = []

    # 使用zip函数同时遍历model_inputs.input_ids和generated_ids
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):
        # 计算输入序列的长度
        input_length = len(input_ids)
        
        # 从output_ids中截取新生成的部分
        # 这是通过切片操作完成的,只保留input_length之后的部分
        new_tokens = output_ids[input_length:]
        
        # 将新生成的token添加到处理后的列表中
        processed_generated_ids.append(new_tokens)

    # 将处理后的列表赋值回generated_ids
    generated_ids = processed_generated_ids

    # 解码模型输出,转换为可读文本
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

打印推理时间和结果

    # 打印推理时间
    print(f'Inference time: {end-st:.2f} s')
    # 打印原始prompt
    print('-'*20, 'Prompt', '-'*20)
    print(text)
    # 打印模型生成的输出
    print('-'*20, 'Output', '-'*20)
    print(response)

一站式py脚本

import os
import torch
import time
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM
from peft import PeftModel

# 设置OpenMP线程数为8, 优化CPU并行计算性能
os.environ["OMP_NUM_THREADS"] = "8"

# 加载基础模型和分词器
base_model_name = "Qwen2-1-5B-Instruct"  # 替换为你的基础模型名称
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

# 加载LoRA微调后的权重
lora_checkpoint = "./checkpoint-5000"
lora_model = PeftModel.from_pretrained(model, lora_checkpoint)

# 定义输入prompt
prompt = "头疼怎么治疗呢"

# 构建符合模型输入格式的消息列表
messages = [{"role": "user", "content": prompt}]

# 使用推理模式,减少内存使用并提高推理速度
with torch.inference_mode():
    # 应用聊天模板,将消息转换为模型输入格式的文本
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # 将文本转换为模型输入张量,并移至CPU (如果使用GPU,这里应改为.to('cuda'))
    model_inputs = tokenizer([text], return_tensors="pt").to('cpu')

    st = time.time()
    # 生成回答, max_new_tokens限制生成的最大token数
    generated_ids = lora_model.generate(model_inputs.input_ids, max_new_tokens=512)
    end = time.time()

    # 初始化一个空列表,用于存储处理后的generated_ids
    processed_generated_ids = []

    # 使用zip函数同时遍历model_inputs.input_ids和generated_ids
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):
        # 计算输入序列的长度
        input_length = len(input_ids)
        
        # 从output_ids中截取新生成的部分
        # 这是通过切片操作完成的,只保留input_length之后的部分
        new_tokens = output_ids[input_length:]
        
        # 将新生成的token添加到处理后的列表中
        processed_generated_ids.append(new_tokens)

    # 将处理后的列表赋值回generated_ids
    generated_ids = processed_generated_ids

    # 解码模型输出,转换为可读文本
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    # 打印推理时间
    print(f'Inference time: {end-st:.2f} s')
    # 打印原始prompt
    print('-'*20, 'Prompt', '-'*20)
    print(text)
    # 打印模型生成的输出
    print('-'*20, 'Output', '-'*20)
    print(response)

Gradio交互

Gradio是一个功能强大的Web交互页面,Gradio的特点是可以非常简单的使用几行代码实现前端的页面,在这里我只是简单的使用了比赛baseline提供的一个简单的Gradio,后续有时间我也会专门补一篇gradio使用教程。

import os
import torch
import time
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
from threading import Event

# 设置OpenMP线程数为8, 优化CPU并行计算性能
os.environ["OMP_NUM_THREADS"] = "8"

# 加载基础模型和分词器
base_model_name = "Qwen2-1-5B-Instruct"  # 替换为你的基础模型名称
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

# 加载LoRA微调后的权重
lora_checkpoint = "./checkpoint-781"
lora_model = PeftModel.from_pretrained(model, lora_checkpoint)

# 创建一个停止事件,用于控制生成过程的中断
stop_event = Event()

# 定义用户输入处理函数
def user(user_message, history):
    return "", history + [[user_message, None]]

# 定义机器人回复生成函数
def bot(history):
    stop_event.clear()
    prompt = history[-1][0]
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to('cpu')

    print(f"\n用户输入: {prompt}")
    print("模型输出: ", end="", flush=True)
    start_time = time.time()

    with torch.inference_mode():
        generated_ids = lora_model.generate(model_inputs.input_ids, max_new_tokens=512)

        processed_generated_ids = []
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):
            input_length = len(input_ids)
            new_tokens = output_ids[input_length:]
            processed_generated_ids.append(new_tokens)
        generated_ids = processed_generated_ids

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    history[-1][1] = response
    end_time = time.time()
    print(f"\n生成完成,用时: {end_time - start_time:.2f} 秒")

    return history

def stop_generation():
    stop_event.set()

with gr.Blocks() as demo:
    gr.Markdown("# Qwen 聊天机器人")
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("清除")
    stop = gr.Button("停止生成")

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    stop.click(stop_generation, queue=False)

if __name__ == "__main__":
    print("启动 Gradio 界面...")
    demo.queue()
    demo.launch(root_path='/dsw-607012/proxy/7860/')

运行代码即可启动本次项目的界面,测试界面如下:
请添加图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1986984.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于STM32的智能家居灯光控制系统

目录 引言环境准备工作 硬件准备软件安装与配置系统设计 系统架构硬件连接代码实现 初始化代码灯光控制代码应用场景 智能家居灯光控制办公环境智能照明常见问题及解决方案 常见问题解决方案结论 1. 引言 随着智能家居技术的发展&#xff0c;灯光控制系统在提升家居生活品质…

尝鲜 HarmonyOS NEXT 开发环境搭建

申请好 HarmonyOS NEXT的开发套件白名单后&#xff0c;就可以下载最的开发套件了&#xff0c;最新的开发工具更新时间是2024-06-17&#xff0c;DevEcoStudio5.0-API12-x86-402。下载后是这样的&#xff1a; 我用的是 MAC PRO&#xff0c;所以下载的是 MAC 版&#xff0c;这里有…

VMware Linux 虚拟机设置了共享文件夹找不到如何解决?

如果在‌虚拟机中设置了‌共享文件夹但找不到&#xff0c;可能是因为没有正确执行挂载操作。挂载操作是将主机上的共享文件夹与虚拟机中的某个目录关联起来的步骤。 目前已经设置了共享文件夹&#xff0c;但是在Linux 上并没有找到 执行以下操作&#xff1a; mkdir /mnt/hgf…

云原生第一次作业

一、实验准备 1、准备一台rhel7的主机,并开启主机的图形 2、配置好可用IP 3、做kickstart自动安装脚本后面需要用到DHCP&#xff0c;关闭VMware DHCP功能 一、kickstart的安装和配置 安装 yum install system-config-kickstart 配置 安装httpd yum install httpd -y\n\n…

投资充电桩源码 共享充电桩投资理财源码 金融理财源码 最新理财投资源码php 投资理财网站源码

海外共享项目投资源码&#xff0c;投资充电桩源码 共享充电桩投资理财源码 金融理财源码 最新理财投资源码php 投资理财网站源码 源码下载&#xff1a;https://download.csdn.net/download/m0_66047725/89612921 更多资源下载&#xff1a;关注我。

软件测试学习笔记

测试学习 1. 测试流程2. Bug的提出什么是bugbug 的描述bug 级别 3. 测试用例的设计什么是测试用例测试用例应如何设计基于需求的设计方法等价类边界值场景法正交表法判定表法错误猜测法 4. 自动化测试回归测试自动化分类 5. 安装 webdriver-manager 和 selenium第一个web自动化…

SAP MM学习笔记 - 豆知识05 - Customer Exit 实例,MM01上定义Customer Exit 来Check评估Class

上一章讲了一些MM模块的豆知识。 - MM01中设定的安全在库和最小安全在库 - MM01/MMSC/Customize自动 扩张物料的保管场所 - MM01中定义生产订单的默认入库保管场所 - VA01受注票中设定禁止贩卖某个物料 SAP MM学习笔记 - 豆知识03 - 安全在库和最小安全在库&#xff0c;扩…

java使用opencv

一、windows安装opencv 下载地址&#xff1a;https://opencv.org/releases/ 下载后安装 本人安装目录 目录说明&#xff1a; build&#xff1a;基于windows构建 java&#xff1a;开发关注 x64、x86对应windows操作系统位数 sources&#xff1a;开源源码 二、java使用ope…

java之多线程篇

一、基本概念 1.什么是线程&#xff1f; 线程就是&#xff0c;操作系统能够进行运算调度的最小单位。它被包含在进程之中&#xff0c;是进程中的实际运作单位。简单理解就是&#xff1a;应用软件中互相独立&#xff0c;可以同时运行的功能 2.什么是多线程&#xff1f; 有了多线…

高清无水印视频素材哪里找?分享几个热门的高清无水印素材网站

一个好的短视频离不开精彩的素材&#xff0c;但高清视频素材哪里找&#xff1f;今天小编就跟大家分享五个可以下载高清无水印短视频素材的网站&#xff0c;如果你还不知道从哪里可以下载高清视频素材&#xff0c;赶紧进来看看吧&#xff01;&#xff5e; 1、稻虎网 首推的是稻…

leetcode数论(​3044. 出现频率最高的质数)-质数判断

前言 经过前期的基础训练以及部分实战练习&#xff0c;粗略掌握了各种题型的解题思路。现阶段开始专项练习。 描述 给你一个大小为 m x n 、下标从 0 开始的二维矩阵 mat 。在每个单元格&#xff0c;你可以按以下方式生成数字&#xff1a; 最多有 8 条路径可以选择&#xff1…

宝兰德JVM参数查看及优化

最近生产环境宝兰德服务总是莫名奇妙的宕掉&#xff0c;很是搞人心态&#xff08;幸好是集群服务器多&#xff0c;总有一台提供服务&#xff09;&#xff0c;初步排查是内存溢出导致&#xff0c;需要进行宝兰德JVM进行调整 调整宝兰德&#xff08;BES&#xff09;JVM参数通常涉…

Spring源码解析(29)之AOP动态代理对象创建过程分析

一、前言 在上一节中我们已经介绍了在createBean过程中去执行AspectJAutoProxyCreator的after方法&#xff0c;然后去获取当前bean适配的advisor&#xff0c;如果还不熟悉的可以去看下之前的博客&#xff0c;接下来我们分析Spring AOP是如何创建代理对象的&#xff0c;在此之前…

38. 115.不同的子序列,583. 两个字符串的删除操作,72. 编辑距离,编辑距离总结篇

确定dp数组以及下标的含义。dp[i][j]&#xff1a;以i-1为结尾的s子序列中出现以j-1为结尾的t的个数为dp[i][j]。确定递推公式。这一类问题&#xff0c;基本是要分析两种情况&#xff1a;s[i - 1] 与 t[j - 1]相等&#xff1b;s[i - 1] 与 t[j - 1] 不相等。当s[i - 1] 与 t[j -…

【屏驱MCU】RT-Thread 文件系统接口解析

本文主要介绍【屏驱MCU】基于RT-Thread 系统的文件系统原理介绍与代码接口梳理 目录 0. 个人简介 && 授权须知1. 文件系统架构1.1 虚拟文件系统目录架构 2. menuconfig 分析3. 代码接口分析3.1 DFS框架挂载目录3.2 【FAL抽象层】分区表和设备表3.3 如何将【文件路径】挂…

计算机毕业设计PySpark+Django考研推荐系统 考研分数线预测 中公考研爬虫 混合神经网络推荐算法 考研可视化 机器学习 深度学习 大数据毕业设计

《PySparkDjango考研推荐系统》开题报告 一、研究背景与意义 1.1 研究背景 随着社会对高学历人才需求的不断增加&#xff0c;研究生入学考试&#xff08;考研&#xff09;已成为众多大学毕业生追求深造的重要途径。然而&#xff0c;考研涉及的知识面广泛且复杂&#xff0c;考…

Unity补完计划 之Tilemap

本文仅作笔记学习和分享&#xff0c;不用做任何商业用途 本文包括但不限于unity官方手册&#xff0c;unity唐老狮等教程知识&#xff0c;如有不足还请斧正 1.Tilemap 是什么 Q&#xff1a;和 SpriteShape有什么区别&#xff1f; A&#xff1a;tilemap强项在于做重的复背景&…

产品经理-​桌面端、手机端、电视端、平板端在设计上的异同(29)

在互联网产品当中,产品形态,pc网页端,客户端,安卓,苹果端,小程序端等 不同的设备,交互设计、产品设计是不一样的,面对的用户群体不一样,产品的设计,规则也是不一样的 这个考查的是PM的交互设计知识&#xff0c;需要知道一般性的交互设计原则与各端设计差异 互联网的各端产品&am…

树莓派新版本在interface options中找不到camera选项

文章目录 问题原因&#xff1a; 操作方法&#xff1a; 1.系统升级 2. 安装libcamera 3. 测试拍照 4. 拍照和视频 5. 查看图片 问题原因&#xff1a; 版本问题&#xff0c;自2023.10之后的新版本中&#xff0c;树莓派去除了原先使用的picamera库&#xff0c;所以不能通过…

Unity补完计划之 Tile Palette

1.Tile Palette Creating a Tile Palette - Unity 手册 瓦片调色板&#xff08;Tile Palette&#xff09;是 Unity 引擎中用于在瓦片地图上进行绘制的工具。它允许您选择和管理颜色、纹理和瓦片&#xff0c;以便在游戏场景中创建地图、背景和其他2D元素 说白了&#xff0c;Ti…