基于 Qwen2-1.5B Lora 微调训练医疗问答任务

news2025/1/16 4:02:20

一、Qwen2 Lora 微调

Qwen是阿里巴巴集团Qwen团队研发的大语言模型和大型多模态模型系列。Qwen2Qwen1.5 的重大升级。无论是语言模型还是多模态模型,均在大规模多语言和多模态数据上进行预训练,并通过高质量数据进行后期微调以贴近人类偏好。Qwen具备自然语言理解、文本生成、视觉理解、音频理解、工具使用、角色扮演、作为AI Agent进行互动等多种能力。

Qwen2有以下特点:

  • 5种模型规模,包括0.5B、1.5B、7B、57B-A14B72B
  • 针对每种尺寸提供基础模型和指令微调模型,并确保指令微调模型按照人类偏好进行校准;
  • 基础模型和指令微调模型的多语言支持;
  • 所有模型均稳定支持32K长度上下文;Qwen2-7B-InstructQwen2-72B-Instruct可支持128K上下文(需额外配置)
  • 支持工具调用、RAG(检索增强文本生成)、角色扮演、AI Agent等;

更多详细的介绍可以参考官方文档:

https://qwen.readthedocs.io/zh-cn/latest/

下面实验所使用的核心依赖版本如下:

torch==1.13.1+cu116
peft==0.12.0
transformers==4.37.0
tensorboard==2.17.1

二、构建 Qwen2-1.5B Lora 模型

LoRA 微调技术的思想很简单,在原始 PLM (Pre-trained Language Model) 增加一个旁路,一般是在 transformer 层,做一个降维再升维的操作,模型的输入输出维度不变,来模拟 intrinsic rank,如下图的 AB。训练时冻结 PLM 的参数,只训练 AB ,,输出时将旁路输出与 PLM 的参数叠加,进而影响原始模型的效果。该方式,可以大大降低训练的参数量,而性能可以优于其它参数高效微调方法,甚至和全参数微调(Fine-Tuning)持平甚至超过。

对于 AB 参数的初始化,A 使用随机高斯分布,B 使用 0 矩阵,这样在最初时可以保证旁路为一个 0 矩阵,最开始时使用原始模型的能力。

在这里插入图片描述

在构建 Qwen2-1.5B Lora 结构模型前,先了解下现在 Qwen2-1.5B 的结构:

这里直接使用 PyTorch 的模型打印方式,主要看模型的组成:

from transformers import AutoModelForCausalLM

model_path = "model/Qwen2-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
print(model)

输出结果:

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0): Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
      .
      . 省略中间结构
      .
      (27): Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear(in_features=1536, out_features=151936, bias=False)
)

从上面的结构可以看出 Qwen2-1.5B 的结构其实并不复杂,由 27DecoderLayer 构成,每个 Decoder 主要的核心是 self_attentionmlp,因此可以尝试在 q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj 层添加 Lora 结构,下面使用 PEFT 库实现,这里 r 使用 8lora_alpha 使用 32

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType

model_path = "model/Qwen2-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
print(model)

输出结果:

trainable params: 9,232,384 || all params: 1,552,946,688 || trainable%: 0.5945
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen2ForCausalLM(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 1536)
        (layers): ModuleList(
          (0): Qwen2DecoderLayer(
            (self_attn): Qwen2Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=1536, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=256, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=256, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (v_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=256, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=256, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (o_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=1536, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (rotary_emb): Qwen2RotaryEmbedding()
            )
            (mlp): Qwen2MLP(
              (gate_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=8960, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8960, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (up_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=8960, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8960, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (down_proj): lora.Linear(
                (base_layer): Linear(in_features=8960, out_features=1536, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=8960, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (act_fn): SiLU()
            )
            (input_layernorm): Qwen2RMSNorm()
            (post_attention_layernorm): Qwen2RMSNorm()
          )
	      .
	      . 省略中间结构
	      .  
          (27): Qwen2DecoderLayer(
            (self_attn): Qwen2Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=1536, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=256, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=256, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (v_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=256, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=256, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (o_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=1536, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (rotary_emb): Qwen2RotaryEmbedding()
            )
            (mlp): Qwen2MLP(
              (gate_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=8960, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8960, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (up_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=8960, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8960, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (down_proj): lora.Linear(
                (base_layer): Linear(in_features=8960, out_features=1536, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=8960, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (act_fn): SiLU()
            )
            (input_layernorm): Qwen2RMSNorm()
            (post_attention_layernorm): Qwen2RMSNorm()
          )
        )
        (norm): Qwen2RMSNorm()
      )
      (lm_head): Linear(in_features=1536, out_features=151936, bias=False)
    )
  )
)


从结果可以看出,Lora 之后在每一层都增加了一个 lora_Alora_B 结构来实现降维升维的作用。

三、准备训练数据集

数据集采用 GitHub 上的 Chinese-medical-dialogue-data 中文医疗对话数据集。

GitHub 地址如下:

https://github.com/Toyhom/Chinese-medical-dialogue-data

数据分了 6 个科目类型:

在这里插入图片描述

数据格式如下所示:

在这里插入图片描述

其中 ask 为病症的问题描述,answer 为病症的回答。

该数据集在本专栏的前面文章中,已经被使用在 ChatGLM2ChatYuan-large 模型上做过微调实验,感兴趣的小伙伴可以参考一下:

ChatGLM2-6B Lora 微调训练医疗问答任务

基于第二代 ChatGLM2-6B P-Tuning v2 微调训练医疗问答任务

ChatYuan-large-v2 微调训练 医疗问答 任务

由于整体数据比较多,这里为了演示效果,选取 内科、肿瘤科、儿科、外科 四个科目的数据进行实验,并且每个科目取前 10000 条数据进行训练、2000 条数据进行验证。

首先将数据集转为 json 格式方便后续读取:

import json
import pandas as pd

data_path = [
    "./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv",
    "./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
]

train_json_path = "./data/train.json"
val_json_path = "./data/val.json"
# 每个数据取 10000 条作为训练
train_size = 10000
# 每个数据取 2000 条作为验证
val_size = 2000


def main():
    train_f = open(train_json_path, "a", encoding='utf-8')
    val_f = open(val_json_path, "a", encoding='utf-8')
    for path in data_path:
        data = pd.read_csv(path, encoding='ANSI')
        train_count = 0
        val_count = 0
        for index, row in data.iterrows():
            question = row["ask"]
            answer = row["answer"]
            line = {
                "question": question,
                "answer": answer
            }
            line = json.dumps(line, ensure_ascii=False)
            if train_count < train_size:
                train_f.write(line + "\n")
                train_count = train_count + 1
            elif val_count < val_size:
                val_f.write(line + "\n")
                val_count = val_count + 1
            else:
                break
    print("数据处理完毕!")
    train_f.close()
    val_f.close()


if __name__ == '__main__':
    main()

处理之后可以看到两个生成的文件:

在这里插入图片描述

四、微调训练

解析数据,构建 Dataset 数据集

qa_dataset.py:

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as np


class QADataset(Dataset):
    def __init__(self, data_path, tokenizer, max_source_length, max_target_length) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.max_seq_length = self.max_source_length + self.max_target_length

        self.data = []
        if data_path:
            with open(data_path, "r", encoding='utf-8') as f:
                for line in f:
                    if not line or line == "":
                        continue
                    json_line = json.loads(line)
                    question = json_line["question"]
                    answer = json_line["answer"]
                    self.data.append({
                        "question": question,
                        "answer": answer
                    })
        print("data load , size:", len(self.data))

    def preprocess(self, question, answer):
        messages = [
            {"role": "system", "content": "你是一个医疗方面的专家,可以根据患者的问题进行解答。"},
            {"role": "user", "content": question}
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        instruction = self.tokenizer(prompt, add_special_tokens=False, max_length=self.max_source_length)
        response = self.tokenizer(answer, add_special_tokens=False, max_length=self.max_target_length)
        input_ids = instruction["input_ids"] + response["input_ids"] + [self.tokenizer.pad_token_id]
        attention_mask = (instruction["attention_mask"] + response["attention_mask"] + [1])
        labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [self.tokenizer.pad_token_id]
        if len(input_ids) > self.max_seq_length:
            input_ids = input_ids[:self.max_seq_length]
            attention_mask = attention_mask[:self.max_seq_length]
            labels = labels[:self.max_seq_length]
        return input_ids, attention_mask, labels

    def __getitem__(self, index):
        item_data = self.data[index]

        input_ids, attention_mask, labels = self.preprocess(**item_data)

        return {
            "input_ids": torch.LongTensor(np.array(input_ids)),
            "attention_mask": torch.LongTensor(np.array(attention_mask)),
            "labels": torch.LongTensor(np.array(labels))
        }

    def __len__(self):
        return len(self.data)

训练:

# -*- coding: utf-8 -*-
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import pandas as pd
from qa_dataset import QADataset
from tqdm import tqdm
import os, time, sys


def train_model(model, train_loader, val_loader, optimizer, gradient_accumulation_steps,
                device, num_epochs, model_output_dir, writer):
    batch_step = 0
    for epoch in range(num_epochs):
        time1 = time.time()
        model.train()
        for index, data in enumerate(tqdm(train_loader, file=sys.stdout, desc="Train Epoch: " + str(epoch))):
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(device, dtype=torch.long)
            labels = data['labels'].to(device, dtype=torch.long)
            # 前向传播
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            # 反向传播,计算当前梯度
            loss.backward()
            # 梯度累积步数
            if (index % gradient_accumulation_steps == 0 and index != 0) or index == len(train_loader) - 1:
                # 更新网络参数
                optimizer.step()
                # 清空过往梯度
                optimizer.zero_grad()
                writer.add_scalar('Loss/train', loss, batch_step)
                batch_step += 1
            # 100轮打印一次 loss
            if index % 100 == 0 or index == len(train_loader) - 1:
                time2 = time.time()
                tqdm.write(
                    f"{index}, epoch: {epoch} -loss: {str(loss)} ; each step's time spent: {(str(float(time2 - time1) / float(index + 0.0001)))}")
        # 验证
        model.eval()
        val_loss = validate_model(model, val_loader, device)
        writer.add_scalar('Loss/val', val_loss, epoch)
        print(f"val loss: {val_loss} , epoch: {epoch}")
        print("Save Model To ", model_output_dir)
        model.save_pretrained(model_output_dir)


def validate_model(model, device, val_loader):
    running_loss = 0.0
    with torch.no_grad():
        for _, data in enumerate(tqdm(val_loader, file=sys.stdout, desc="Validation Data")):
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(device, dtype=torch.long)
            labels = data['labels'].to(device, dtype=torch.long)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            running_loss += loss.item()
    return running_loss / len(val_loader)


def main():
    # 基础模型位置
    model_name = "model/Qwen2-1.5B-Instruct"
    # 训练集
    train_json_path = "./data/train.json"
    # 验证集
    val_json_path = "./data/val.json"
    max_source_length = 128
    max_target_length = 256
    epochs = 10
    batch_size = 1
    lr = 1e-4
    gradient_accumulation_steps = 16
    lora_rank = 8
    lora_alpha = 32
    model_output_dir = "output"
    logs_dir = "logs"
    # 设备
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 加载分词器和模型
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    # setup peft
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        inference_mode=False,
        r=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=0.1
    )
    model = get_peft_model(model, peft_config)
    model.is_parallelizable = True
    model.model_parallel = True
    model.print_trainable_parameters()
    print("Start Load Train Data...")
    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 0,
    }
    training_set = QADataset(train_json_path, tokenizer, max_source_length, max_target_length)
    training_loader = DataLoader(training_set, **train_params)
    print("Start Load Validation Data...")
    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 0,
    }
    val_set = QADataset(val_json_path, tokenizer, max_source_length, max_target_length)
    val_loader = DataLoader(val_set, **val_params)
    # 日志记录
    writer = SummaryWriter(logs_dir)
    # 优化器
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
    model = model.to(device)
    # 开始训练
    print("Start Training...")
    train_model(
        model=model,
        train_loader=training_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        gradient_accumulation_steps=gradient_accumulation_steps,
        device=device,
        num_epochs=epochs,
        model_output_dir=model_output_dir,
        writer=writer
    )
    writer.close()


if __name__ == '__main__':
    main()

训练过程:

在这里插入图片描述

训练结束后,可以在 output 中看到 lora 模型:

在这里插入图片描述

五、模型测试

# -*- coding: utf-8 -*-
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

model_path = "model/Qwen2-1.5B-Instruct"
lora_dir = "output"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = PeftModel.from_pretrained(model, lora_dir)
model.to(device)

prompt = """
5月至今上腹靠右隐痛,右背隐痛带酸,便秘,喜睡,时有腹痛,头痛,腰酸症状?
"""
messages = [
    {"role": "system", "content": "你是一个医疗方面的专家,可以根据患者的问题进行解答。"},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=258)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

在这里插入图片描述

模型回答:根据你的叙述,胃炎胆汁反流性胃炎的可能性大,建议口服奥美拉唑,吗丁啉救治,清淡易消化饮食,忌辛辣打击食物,留意歇息,不要加班除了正规救治胃痛外,患者还需要有看重护理方面,比如恰当饮食,始终保持心情愉快。与此同时患者还要留意决定一家专业医院诊病,这样才能获得良好的治疗效果。

六、模型合并

上面测试还是分开加载的基础模型和lora模型,可以将两个合并为一个,方便后续部署:

# -*- coding: utf-8 -*-
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = "model/Qwen2-1.5B-Instruct"
lora_dir = "output"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
model = PeftModel.from_pretrained(model, lora_dir).to(device)
print(model)
# 合并model, 同时保存 token
model = model.merge_and_unload()
model.save_pretrained("lora_output")
tokenizer.save_pretrained("lora_output")

合并后的结构:

在这里插入图片描述

后面就不需要再通过 PeftModel 直接加载模型既可使用:

# -*- coding: utf-8 -*-
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_path = "lora_output"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.to(device)

prompt = """
5月至今上腹靠右隐痛,右背隐痛带酸,便秘,喜睡,时有腹痛,头痛,腰酸症状?
"""
messages = [
    {"role": "system", "content": "你是一个医疗方面的专家,可以根据患者的问题进行解答。"},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=258)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2154682.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redisson分布式锁分析,可重入、可续锁(看门狗)

前言 在此说明&#xff0c;本文章不只是讲一些抽象的概念&#xff0c;而是可落地的&#xff0c;在日常工作中基本上进行修改一下便可以使用。书接上回&#xff0c;上篇自研分布式锁的文章使用是一个自己手写的一个分布式锁&#xff0c;按照JUC里面java.util.concurrent.locks.L…

Linux根文件系统构建

直接参考【正点原子】I.MX6U嵌入式Linux驱动开发指南V1.81 本文仅作为个人笔记使用&#xff0c;方便进一步记录自己的实践总结。 Linux“三巨头”已经完成了 2 个了&#xff0c;就剩最后一个 rootfs(根文件系统)了&#xff0c;本章我们就来学习一下根文件系统的组成以及如何构建…

苹果叶片病理分类数据集

苹果叶片病理分类数据集&#xff0c;数据集包括&#xff08;a&#xff09;健康叶片; (b) 苹果链格孢叶斑病&#xff1b; (c) 褐斑病&#xff1b; (d) 蛙叶斑病&#xff1b; (e) 灰斑&#xff1b; (f) 苹果花叶病&#xff1b; (g) 白粉病&#xff1b; (h) 叶片锈病&#xff1b; …

Redis数据结构之哈希表

这里的哈希表说的是value的类型是哈希表 一.相关命令 1.hset key field value 一次可以设置多个 返回值是设置成功的个数 注意&#xff0c;哈希表中的键值对&#xff0c;键是唯一的而值可以重复 所以有下面的结果&#xff1a; key中原来已经有了f1&#xff0c;所以再使用hse…

AI智能跟踪技术核心!

1. 目标检测技术 在视频序列的第一帧中&#xff0c;通过目标检测算法确定要追踪的目标对象的位置和大小。 技术实现&#xff1a;目标检测算法可以基于传统的图像处理技术&#xff0c;如颜色、纹理、形状等特征&#xff0c;也可以基于深度学习方法&#xff0c;如卷积神经网络&…

利用人工智能改变视频智能

人工智能视频分析正在将安全摄像头变成强大的传感器&#xff0c;可以改善您监控站点安全的方式。借助人工智能 (AI)&#xff0c;摄像头可以独立准确地检测威胁&#xff0c;而无需人工不断观看视频。 这并不奇怪——过去几年&#xff0c;这一直是安全行业协会 (SIA) 提出的几大…

Linux命令--05----find 日志

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 find1.语法语法&#xff1a;find 目标目录(路径) <选项> 参数 2.示例3.find 结合 xargs4.案例.* 模糊匹配 绝对路径 find 在 Linux 命令中&#xff0c;fin…

nginx架构篇(三)

文章目录 一、Nginx实现服务器端集群搭建1.1 Nginx与Tomcat部署1. 环境准备(Tomcat)2. 环境准备(Nginx) 1.2. Nginx实现动静分离1.2.1. 需求分析1.2.2. 动静实现步骤 1.3. Nginx实现Tomcat集群搭建1.4. Nginx高可用解决方案1.4.1. Keepalived1.4.2. VRRP介绍1.4.3. 环境搭建环境…

银河麒麟V10系统崩溃后的处理

银河麒麟V10系统崩溃后的处理 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 当银河麒麟桌面操作系统V10崩溃无法启动时&#xff0c;直接使用备份还原工具不可行。此时&#xff0c;应采取以下步骤&#xff1a; 进入救援模式或LiveCD&#x…

基于Java springboot+mybatis 网上商城系统

基于Java springbootmybatis 网上商城系统 一、系统介绍二、功能展示1.主页(客户)2.登陆&#xff08;客户&#xff09;3.注册&#xff08;客户&#xff09;4.购物车(客户)5.我的订单&#xff08;客户&#xff09;6.用户管理&#xff08;管理员&#xff09;7.分类管理&#xff0…

ICM20948 DMP代码详解(35)

接前一篇文章&#xff1a;ICM20948 DMP代码详解&#xff08;34&#xff09; 上一回终于解析完了inv_icm20948_initialize_lower_driver函数&#xff0c;本回回到icm20948_sensor_setup函数&#xff0c;继续往下进行解析。为了便于理解和回顾&#xff0c;再次贴出icm20948_senso…

缓存中间件Redis进阶之路二(快速安装Redis)

一、快速安装Redis “工欲善其事&#xff0c;必先利其器”&#xff0c;在代码实战之前&#xff0c;需要在本地开发环境安装好Redis。对于不同的开发环境&#xff0c;Redis的安装及配置方式也不尽相同&#xff0c;本文将以Windows开发环境为例&#xff0c;介绍Redis的快速安装与…

Kafka 为什么这么快?

Kafka 是一款性能非常优秀的消息队列&#xff0c;每秒处理的消息体量可以达到千万级别。今天来聊一聊 Kafka 高性能背后的技术原理。 1 批量发送 Kafka 收发消息都是批量进行处理的。我们看一下 Kafka 生产者发送消息的代码&#xff1a; private Future<RecordMetadata>…

【数据结构C语言】【入门】【首次万字详细解析】入门阶段数据结构可能用到的C语言知识,一章让你看懂数据结构!!!!!!!

前言&#xff1a;欢迎各位光临本博客&#xff0c;这里小编带你直接手撕入门阶段的数据结构的C语言知识&#xff0c;让你不再看见数据结构就走不动道。文章并不复杂&#xff0c;愿诸君耐其心性&#xff0c;忘却杂尘&#xff0c;道有所长&#xff01;&#xff01;&#xff01;&am…

论文阅读 - MDFEND: Multi-domain Fake News Detection

https://arxiv.org/pdf/2201.00987 目录 ABSTRACT INTRODUCTION 2 RELATED WORK 3 WEIBO21: A NEW DATASET FOR MFND 3.1 Data Collection 3.2 Domain Annotation 4 MDFEND: MULTI-DOMAIN FAKE NEWS DETECTION MODEL 4.1 Representation Extraction 4.2 Domain Gate 4.…

【Stm32】从零建立一个工程

这里我们创建“STM32F103”系列的文件&#xff0c;基于“固件库” 1.固件库获取 https://www.st.com.cn/zh/embedded-software/stm32-standard-peripheral-libraries.html 2.使用Keil创建.uvprojx文件 前提是已经下载好了“芯片对应的固件” 3.复制底层驱动代码 将固件库下的…

【图表如何自动排序】

前提&#xff1a;有这样一个表&#xff1a;第1列为姓名&#xff08;用字母A~I代替&#xff09;&#xff0c;第2列假设为销量 用到的函数&#xff1a; Sort(自动排序的区域&#xff0c;以哪一列为基础来排序&#xff0c;降序or升序&#xff09; Choosecols(选择的区域&#xf…

Spring在不同类型之间也能相互拷贝?

场景还原 日常开发中&#xff0c;我们会定义非常多的实体&#xff0c;例如VO、DTO等&#xff0c;在涉及实体类的相互转换时&#xff0c;常使用Spring提供的BeanUtils.copyProperties&#xff0c;该类虽好&#xff0c;可不能贪用。 这不在使用过程中就遇到一个大坑&#xff0c…

二十、功率放大电路

功率放大电路 1、乙类功率放大器的工作过程以及交越失真; 2、复合三极管的复合规则 3、甲乙类功率放大器的工作原理、自举过程

智谱AI:CogVideoX-2b——视频生成模型的得力工具

智谱AI&#xff1a;CogVideoX-2b——视频生成模型的得力工具 文章目录 CogVideoX 简介——它是什么&#xff1f;CogVideoX 具体部署与实践指南一、创建丹摩实例二、配置环境和依赖三、上传模型与配置文件四、开始运行五、Web UI 演示 CogVideoX 简介——它是什么&#xff1f; …