ChatGLM2-6B Lora 微调训练医疗问答任务

news2024/10/6 16:22:03

一、ChatGLM2-6B Lora 微调

LoRA 微调技术的思想很简单,在原始 PLM (Pre-trained Language Model) 增加一个旁路,一般是在 transformer 层,做一个降维再升维的操作,模型的输入输出维度不变,来模拟 intrinsic rank,如下图的 AB。训练时冻结 PLM 的参数,只训练 AB ,,输出时将旁路输出与 PLM 的参数叠加,进而影响原始模型的效果。该方式,可以大大降低训练的参数量,而性能可以优于其它参数高效微调方法,甚至和全参数微调(Fine-Tuning)持平甚至超过。

对于 AB 参数的初始化,A 使用随机高斯分布,B 使用 0 矩阵,这样在最初时可以保证旁路为一个 0 矩阵,最开始时使用原始模型的能力。

在这里插入图片描述
对于 lora 微调的实现可以使用 HuggingFace 开源的 PEFT 库,地址如下:

https://github.com/huggingface/peft

下载依赖:

pip install peft -i https://pypi.tuna.tsinghua.edu.cn/simple

使用方式也很简单,例如先查看 ChatGLM2-6B 的模型结构:

from transformers import AutoModel

model_name = "chatglm-6b"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
print(model)

输出结果:

ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (embedding): Embedding(
      (word_embeddings): Embedding(65024, 4096)
    )
    (rotary_pos_emb): RotaryEmbedding()
    (encoder): GLMTransformer(
      (layers): ModuleList(
        (0-27): 28 x GLMBlock(
          (input_layernorm): RMSNorm()
          (self_attention): SelfAttention(
            (query_key_value): Linear(in_features=4096, out_features=4608, bias=True)
            (core_attention): CoreAttention(
              (attention_dropout): Dropout(p=0.0, inplace=False)
            )
            (dense): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (post_attention_layernorm): RMSNorm()
          (mlp): MLP(
            (dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False)
            (dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False)
          )
        )
      )
      (final_layernorm): RMSNorm()
    )
    (output_layer): Linear(in_features=4096, out_features=65024, bias=False)
  )
)

可以看出 ChatGLM 主要由 28 层的 GLMBlock 进行提取和理解语义特征,下面借助 PEFT 库将 Lora 旁路层注入到模型中,主要关注下 query_key_value 层的变化:

from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import LoraConfig, get_peft_model, TaskType

model_name = "chatglm-6b"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

config = LoraConfig(
    peft_type="LORA",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    fan_in_fan_out=False,
    bias='lora_only',
    target_modules=["query_key_value"]
)

model = get_peft_model(model, config)
print(model)

其中 r 就是 lora 中秩的大小。

输出结果:

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): ChatGLMForConditionalGeneration(
      (transformer): ChatGLMModel(
        (embedding): Embedding(
          (word_embeddings): Embedding(65024, 4096)
        )
        (rotary_pos_emb): RotaryEmbedding()
        (encoder): GLMTransformer(
          (layers): ModuleList(
            (0-27): 28 x GLMBlock(
              (input_layernorm): RMSNorm()
              (self_attention): SelfAttention(
                (query_key_value): Linear(
                  in_features=4096, out_features=4608, bias=True
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=4096, out_features=8, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=8, out_features=4608, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                )
                (core_attention): CoreAttention(
                  (attention_dropout): Dropout(p=0.0, inplace=False)
                )
                (dense): Linear(in_features=4096, out_features=4096, bias=False)
              )
              (post_attention_layernorm): RMSNorm()
              (mlp): MLP(
                (dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False)
                (dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False)
              )
            )
          )
          (final_layernorm): RMSNorm()
        )
        (output_layer): Linear(in_features=4096, out_features=65024, bias=False)
      )
    )
  )
)

可以对比下原始的 ChatGLM 模型结构, query_key_value 层中已经被加入下 loraAB 层,下面可以通过 model.print_trainable_parameters() 打印可训练的参数量:

trainable params: 2,078,720 || all params: 6,245,533,696 || trainable%: 0.03328330453698988

可以看到可训练的参数量只有 0.03328330453698988

下面依然借助前面文章使用的医疗问答数据集,在 ChatGLM2 lora 微调下的效果。

对该数据集不了解的小伙伴可以参考下面这篇文章:

ChatGLM2-6B P-Tuning v2 微调训练医疗问答任务

二、ChatGLM2-6B Lora 微调

解析数据,构建 Dataset 数据集 qa_dataset.py

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as np


class QADataset(Dataset):
    def __init__(self, data_path, tokenizer, max_source_length, max_target_length) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.max_seq_length = self.max_source_length + self.max_target_length

        self.data = []
        with open(data_path, "r", encoding='utf-8') as f:
            for line in f:
                if not line or line == "":
                    continue
                json_line = json.loads(line)
                content = json_line["content"]
                summary = json_line["summary"]
                self.data.append({
                    "question": content,
                    "answer": summary
                })
        print("data load , size:", len(self.data))
    def preprocess(self, question, answer):
        prompt = self.tokenizer.build_prompt(question, None)

        a_ids = self.tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
                                      max_length=self.max_source_length)

        b_ids = self.tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
                                      max_length=self.max_target_length)

        context_length = len(a_ids)
        input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]
        labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]

        pad_len = self.max_seq_length - len(input_ids)
        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
        labels = labels + [self.tokenizer.pad_token_id] * pad_len
        labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]
        return input_ids, labels

    def __getitem__(self, index):
        item_data = self.data[index]

        input_ids, labels = self.preprocess(**item_data)

        return {
            "input_ids": torch.LongTensor(np.array(input_ids)),
            "labels": torch.LongTensor(np.array(labels))
        }

    def __len__(self):
        return len(self.data)

构造 Lora 结构,微调训练 train_lora.py

# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from qa_dataset import QADataset
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import torch
import os, time, sys


def train(epoch, model, device, loader, optimizer, gradient_accumulation_steps):
    model.train()
    time1 = time.time()
    for index, data in enumerate(tqdm(loader, file=sys.stdout, desc="Train Epoch: " + str(epoch))):
        input_ids = data['input_ids'].to(device, dtype=torch.long)
        labels = data['labels'].to(device, dtype=torch.long)

        outputs = model(
            input_ids=input_ids,
            labels=labels,
        )
        loss = outputs.loss
        # 反向传播,计算当前梯度
        loss.backward()
        # 梯度累积步数
        if (index % gradient_accumulation_steps == 0 and index != 0) or index == len(loader) - 1:
            # 更新网络参数
            optimizer.step()
            # 清空过往梯度
            optimizer.zero_grad()

        # 100轮打印一次 loss
        if index % 100 == 0 or index == len(loader) - 1:
            time2 = time.time()
            tqdm.write(
                f"{index}, epoch: {epoch} -loss: {str(loss)} ; each step's time spent: {(str(float(time2 - time1) / float(index + 0.0001)))}")


def validate(tokenizer, model, device, loader, max_length):
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for _, data in enumerate(tqdm(loader, file=sys.stdout, desc="Validation Data")):
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            labels = data['labels'].to(device, dtype=torch.long)
            generated_ids = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                do_sample=False,
                temperature=0
            )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in
                     generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in labels]
            predictions.extend(preds)
            actuals.extend(target)
    return predictions, actuals


def main():
    model_name = "chatglm-6b"
    train_json_path = "./data/train.json"
    val_json_path = "./data/val.json"
    max_source_length = 128
    max_target_length = 512
    epochs = 5
    batch_size = 1
    lr = 1e-4
    lora_rank = 8
    lora_alpha = 32
    gradient_accumulation_steps = 16
    model_output_dir = "output"
    # 设备
    device = torch.device("cuda:0")

    # 加载分词器和模型
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

    # setup peft
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=0.1
    )
    model = get_peft_model(model, peft_config)
    model.is_parallelizable = True
    model.model_parallel = True
    model.print_trainable_parameters()
    # 转为半精度
    model = model.half()
    model.float()

    print("Start Load Train Data...")
    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 0,
    }
    training_set = QADataset(train_json_path, tokenizer, max_source_length, max_target_length)
    training_loader = DataLoader(training_set, **train_params)
    print("Start Load Validation Data...")
    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 0,
    }
    val_set = QADataset(val_json_path, tokenizer, max_source_length, max_target_length)
    val_loader = DataLoader(val_set, **val_params)

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
    model = model.to(device)
    print("Start Training...")
    for epoch in range(epochs):
        train(epoch, model, device, training_loader, optimizer, gradient_accumulation_steps)
        print("Save Model To ", model_output_dir)
        model.save_pretrained(model_output_dir)
    # 验证
    print("Start Validation...")
    with torch.no_grad():
        predictions, actuals = validate(tokenizer, model, device, val_loader, max_target_length)
        # 验证结果存储
        final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
        val_data_path = os.path.join(model_output_dir, "predictions.csv")
        final_df.to_csv(val_data_path)
        print("Validation Data To ", val_data_path)


if __name__ == '__main__':
    main()

开始训练:

在这里插入图片描述

等待训练结束后,可以在输出目录看到保存的模型,仅只有 lora 层的参数,所以模型比较小:

在这里插入图片描述

此时可以查看下 predictions.csv 中验证集的效果。

三、模型测试

from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftConfig, PeftModel, LoraConfig, get_peft_model, TaskType
import torch


def load_lora_config(model):
    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["query_key_value"]
    )
    return get_peft_model(model, config)

device = torch.device("cuda:0")

model_name = "chatglm-6b"
lora_dir = "output"

model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

config = PeftConfig.from_pretrained(lora_dir)
model = PeftModel.from_pretrained(model, lora_dir)

model = model.to(device)
model.eval()

response, history = model.chat(tokenizer, "5月至今上腹靠右隐痛,右背隐痛带酸,便秘,喜睡,时有腹痛,头痛,腰酸症状?", history=[])
print("回答:", response)

输出:

在这里插入图片描述

回答: 你好,根据你的叙述,考虑是胃炎引来的。建议你平时留意饮食规律,不要吃辛辣刺激性食物,多喝热水,可以口服奥美拉唑肠溶胶囊和阿莫西林胶囊实施救治,如果效果不好,建议去医院做胃镜仔细检查。除了及时救治胃痛外,患者朋友理应始终保持愉快的心态去直面疾病,只有这样才能令得患者及时对症救治,同时要多看重自身饮食护理,多观注自身的症状变动,认为这样一定能将胃痛撵走。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1015428.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【简单教程】利用Net2FTP构建免费个人网盘,实现便捷的文件管理

文章目录 1.前言2. Net2FTP网站搭建2.1. Net2FTP下载和安装2.2. Net2FTP网页测试 3. cpolar内网穿透3.1.Cpolar云端设置3.2.Cpolar本地设置 4.公网访问测试5.结语 1.前言 文件传输可以说是互联网最主要的应用之一,特别是智能设备的大面积使用,无论是个人…

深入理解右值引用与移动语义

文章目录 写在前面1. 什么是右值,什么是左值?1.1右值引用可以引用左值吗1.2 左值引用、右值引用本身是左值还是右值?1.3 特殊的 const 左值引用 2. 右值引用与移动构造的意义3. 移动构造函数的使用4. move的实现原理5. 完美转发 写在前面 本…

Ocoya:快速创建社交媒体上的文字

【产品介绍】 名称 Ocoya 具体描述 Ocoya是一个人工智能文字创建平台,速度能提高10倍,节省高达80%的时间,在几分钟内完成内容营销、文案写作和社交媒体!可用于创建和安排社交媒体内容,减轻你的团队负担。 【团队介绍】…

KMP算法(C++)

KMP算法与BF算法不一样的在于,当主串与子串不匹配时,主串不回溯,选择了子串回溯,大大提高了运算效率。 借用了next1【】数组,让子串回溯。get_next函数求next1【】数组,get_next函数的实现难点在于下列几行…

数据库开发-MySQL基础DQL和多表设计

1. 数据库操作-DQL DQL英文全称是Data Query Language(数据查询语言),用来查询数据库表中的记录。 1.1 介绍 查询关键字:SELECT 查询操作是所有SQL语句当中最为常见,也是最为重要的操作。在一个正常的业务系统中,查询操作的使…

MC-4/11/01/400 ELAU 软件允许用户完全访问相机设置

MC-4/11/01/400 ELAU 软件允许用户完全访问相机设置 一个完整的Sentinel模具保护解决方案包括一到四台冲击式摄像机、专用红外LED照明和镜头、Sentinel软件以及所有与模压机连接的必要互连组件。摄像机支架基于磁性,可快速、安全、灵活地部署。此外,一个…

Kotlin Android中错误及异常处理最佳实践

Kotlin Android中错误及异常处理最佳实践 Kotlin在Android开发中的错误处理机制以及其优势 Kotlin具有强大的错误处理功能:Kotlin提供了强大的错误处理功能,使处理错误变得简洁而直接。这个特性帮助开发人员快速识别和解决错误,减少了调试代…

在SpringSecurity + SpringSession项目中如何实现当前在线用户的查询、剔除登录用户等操作

1、前言 在前一篇《在SpringBoot项目中整合SpringSession,基于Redis实现对Session的管理和事件监听》笔记中,已经实践了在SpringBoot SpringSecurity 项目中整合SpringSession,这里我们继续尝试如何统计当前在线用户,思路如下&am…

看好多人都在劝退学计算机,可是张雪峰又 推荐过计算机,所以计算机到底是什么样 的?

张雪峰高考四百多分,但是他现在就瞧不起400多分的学生。说难听点,六七百分的 热门专业随便报谁不会啊? 计算机专业全世界都是过剩的,今年桂林电子科技,以前还是华为的校招大学,今年 计算机2/3待业。这个世…

程序员兼职社区招募(内含技术指导)

👨‍💻作者简介:大数据专业硕士在读,CSDN人工智能领域博客专家,阿里云专家博主,专注大数据与人工智能知识分享。公众号: GoAI的学习小屋,免费分享书籍、简历、导图等资料&#xff0c…

flex:1的大坑

一、问题描述 整个类名为roomList 的大盒子设置了flex为1,与它同级的其他盒子都已经设置了宽高,但roomList 依然被内容撑开了,没有自适应 .roomList { flex: 1; } 二、原因分析 roomList的整个高度溢出,对于包裹roomList的父盒子…

pycharm安装(windows)

一、下载及安装 1.下载进入PyCharm官方下载地址: https://www.jetbrains.com/pycharm/download/ 下拉一下,直接下载社区版就行,是免费的,功能足够用了。 2.安装 (1) 找到你下载PyCharm的路径,双击.exe文件进行安装…

每日一题~合并二叉树

题目链接:617. 合并二叉树 - 力扣(LeetCode) 题目描述: 思路分析: 由图可知,当两个位置都有节点的时候,直接将两个节点的 val 相加就是结果,如果在一个位置两棵树只有一棵在此位置上…

Vim的基础操作

前言 本文将向您介绍关于vim的基础操作 基础操作 在讲配置之前,我们可以新建一个文件 .vimrc,并用vim打开在里面输入set nu 先给界面加上行数,然后shift ;输入wq退出 默认打开:命令模式 在命令模式中&#xff1a…

06乐观锁与悲观锁

乐观锁与悲观锁 悲观锁: 悲观锁比较适合插入数据,简单粗暴但是性能一般 乐观锁: 比较适合更新数据, 性能好但是成功率低(多个线程同时执行时只有一个可以执行成功),还需要访问数据库造成数据库压力过大 模拟乐观锁实现流程 第一步: 数据库中增加商品表t_product并插入一条数…

【MySQL】基础SQL语句——表的操作

文章目录 一. 创建表二. 查看表结构三. 修改表3.1 修改表名或列名3.2 插入数据3.3 添加列3.4 修改列类型3.5 删除列 四. 删除表结束语 一. 创建表 create table table_name(field1 datatype,field2 datatype...) charset 字符集 collate 校验规则 engine 存储引擎; 创建表 fiel…

07通用枚举和表的代码生成器

通用枚举 通用枚举 如果表中的有些字段值是固定的例如性别(男或女),此时我们可以使用MyBatis-Plus的通用枚举来为属性赋值 需求: 在数据库表添加字段sex 第一步: 设置枚举类型,使用EnumValue注解将注解所标识的属性值存储到数据库中 // 枚举类型只要设置getter方法 Getter …

2023/09/15 qt day1

代码实现图形化界面 #include "denglu.h" #include "ui_denglu.h" #include <QDebug> #include <QIcon> #include <QLabel> #include <QLineEdit> #include <QPushButton> denglu::denglu(QWidget *parent): QMainWindow(p…

JavaScript-promise使用+状态

Promise 什么是PromisePromise对象就是异步操作的最终完成和失败的结果&#xff1b; Promise的基本使用&#xff1a; 代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compati…