通义千问 1.5 -7B fine-tune验证

news2025/1/16 13:51:46

尝试对对中文数据进行finetune验证,测试模型的可优化方向。下面是代码的详细情况

代码实现

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    GenerationConfig
)
from tqdm import tqdm
from trl import SFTTrainer
import torch
import time
import pandas as pd
import numpy as np
from huggingface_hub import interpreter_login
from datasets import Dataset, DatasetDict
from functools import partial
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import os
# 禁用权重和偏差
os.environ['WANDB_DISABLED']="true"

中文摘要相关数据

git clone https://www.modelscope.cn/datasets/DAMO_NLP/lcsts_test_set.git

data_train_pth ='../Fine-tune/data/lcsts_test_set/{}'.format('train.csv')
data_train = pd.read_csv(data_train_pth)

data_test_pth = '../Fine-tune/data/lcsts_test_set/{}'.format('test.csv')
data_test = pd.read_csv(data_test_pth)

print(data_train.shape)
print(data_test.shape)
data_train.head()

在这里插入图片描述

数据加载

#这里看到原本的训练集合很大
data_train = data_train.head(3000)

data_train = Dataset.from_pandas(data_train)
data_test = Dataset.from_pandas(data_test)

print(data_train)

模型加载

compute_dtype = getattr(torch, "float16")
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=False,
    )
model_name=r'G:\hugging_fase_model2\Qwen1.5-7B-Chat'
device_map = {"": 0}
original_model = AutoModelForCausalLM.from_pretrained(model_name, 
                                                      device_map=device_map,
                                                      quantization_config=bnb_config,
                                                      trust_remote_code=True,
                                                      use_auth_token=True)

tokenizer = AutoTokenizer.from_pretrained(model_name,trust_remote_code=True,padding_side="left",add_eos_token=True,add_bos_token=True,use_fast=False)
tokenizer.pad_token = tokenizer.eos_token                                                      

数据预处理

def create_prompt_formats(sample):
    """
    格式化示例的各个字段('instruction','output')
    然后使用两个换行符将它们连接起来
    :参数sample:样本字典
    """
    ROLE_PROMPT = "### 你是一个新闻工作者。"#校色说明
    INTRO_BLURB = " ###需要给文章起一个合适的标题,这里会给到已有的文章和标题,需要学习如何给文章起标题名称"#任务简介
    INSTRUCTION_KEY = "###文章内容:以下是文章的内容,"
    RESPONSE_KEY =  "### 标题名称:"
    END_KEY = "### 结束"

    role= f"\n{ROLE_PROMPT}"
    blurb = f"\n{INTRO_BLURB}"
    instruction = f"{INSTRUCTION_KEY}"
    input_context = f"{sample['text1']}" if sample["text1"] else None
    response = f"{RESPONSE_KEY}\n{sample['text2']}"
    end = f"{END_KEY}"
    
    parts = [part for part in [role,blurb, instruction, input_context, response, end] if part]

    formatted_prompt = "\n\n".join(parts)
    sample["text"] = formatted_prompt

    return sample

def get_max_length(model):
    conf = model.config
    max_length = None
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(model.config, length_setting, None)
        if max_length:
            print(f"Found max lenth: {max_length}")
            break
    if not max_length:
        max_length = 1024
        print(f"Using default max length: {max_length}")
    return max_length


def preprocess_batch(batch, tokenizer, max_length):
    """
    token处理
    """
    return tokenizer(
        batch["text"],
        max_length=max_length,
        truncation=True,
    )
    
def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int,seed, dataset):
    """
    格式化并标记它,以便为培训做好准备
    参数标记器(AutoTokenizer):模型标记器
    :param max_length (int):从标记器发出的标记的最大数量
    """
    
    # 在每个示例中添加提示
    print("开始数据预处理...")
    dataset = dataset.map(create_prompt_formats)#, batched=True)
    
    # 对每一批数据集&进行预处理
    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=['text1', 'text2'],
    )

    # 过滤掉input_ids超过max_length的样本
    dataset = dataset.filter(lambda sample: len(sample["input_ids"]) < max_length)
    
    # 打乱数据
    dataset = dataset.shuffle(seed=seed)

    return dataset

## 预处理
max_length = get_max_length(original_model)
print(max_length)

seed=123
train_data = preprocess_dataset(tokenizer, max_length,seed, data_train)
eval_data = preprocess_dataset(tokenizer, max_length,seed, data_test)

模型训练

output_dir = f'./QWEn_peft-dialogue-summary-training-{str(int(time.time()))}'
peft_config = LoraConfig(
        lora_alpha=16, 
        lora_dropout=0.1,
        r=64,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
        #inplace=False,
        #target_modules=[
        #    'q_proj',
        #    'k_proj',
        #    'v_proj',
        #    'dense'
    #],
)




training_arguments = TrainingArguments(
    output_dir=output_dir,                    # 保存训练日志和检查点的目录
    num_train_epochs=5,                       # 为其训练模型的历元数。一个epoch通常指的是通过整个训练数据集一次的前向传播和反向传播过程。
    #num_train_epochs 被设置为3,意味着模型将完整地遍历训练数据集3次。
    per_device_train_batch_size=1,            # 每个设备上每个批次的样本数。
    gradient_accumulation_steps=8,            #  执行向后/更新过程之前的步骤数
    gradient_checkpointing=True,              # 使用渐变检查点保存内存
    optim="paged_adamw_8bit",                 #"paged_adamw_8bit"/"paged_adamw_32bit" 用于训练模型的优化器
    save_steps=50,
    logging_steps=50,                         # 记录训练指标的步骤数。它被设置为50,意味着每50个训练步骤,训练指标将被记录一次。
    learning_rate=2e-4,                       # 学习率
    weight_decay=0.001,
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,                        # 基于QLoRA的最大梯度范数
    max_steps=500, #1000,                     #这个建议设置上,不然会出现很多次的训练轮
    warmup_ratio=0.03,                        # 基于QLoRA的预热比
    group_by_length=True,
    lr_scheduler_type="cosine",               # 使用余弦学习率调度
    report_to="tensorboard",                  # 向tensorboard报告指标  可选"none"
    evaluation_strategy="epoch",               # 每个纪元保存检查点 可选"steps" 这个参数设置了评估策略。
    #代码中设置为"epoch",意味着评估将在每个epoch结束后进行。由于eval_steps也设置为50,这可能意味着评估将在每50个训练步骤或每个epoch
    #warmup_steps = 1
    #logging_dir="./logs",
    save_strategy="steps",
    eval_steps=50,#意味着每50个训练步骤,模型将在验证集上进行一次评估。
    do_eval=True,
    overwrite_output_dir  =True
)
"""
上述参数,模型将在以下情况下停止训练:

完成3个epoch的训练,无论eval_steps条件是否满足。
如果训练数据集的大小导致在3个epoch内无法达到50个训练步骤,那么模型将在完成所有训练步骤后停止。
至于评估输出,由于logging_steps和eval_steps都设置为50,这意味着:

每50个训练步骤,训练指标将被记录一次。
每50个训练步骤,模型将在验证集上进行一次评估。
"""


#training_arguments.config.use_cache = False
#transformers.Trainer
"""
transformers.Trainer:如果你有一个大的数据集,并且需要为你的培训循环或复杂的培训工作流程进行广泛的定制。
使用SFTTrainer:如果你有一个预训练的模型和相对较小的数据集,并且想要更简单、更快的微调体验和高效的内存使用。

如果训练数据集较小,可能导致在每个epoch中训练步骤数少于50步,那么eval_steps条件可能不会触发,评估将在每个epoch结束后根据evaluation_strategy参数的设置进行。

另外,max_steps参数通常用于设置训练的最大步骤数,以防止训练超过预定的epoch数。
代码中,max_steps被设置为1000,这个值远大于由num_train_epochs和per_device_train_batch_size参数隐式定义的训练步骤数。
因此,除非训练数据集非常大,否则这个参数在上下文中可能不会起作用。
如果max_steps的值大于完成所有epoch所需步骤数的总和,训练将在完成所有epoch后停止,而不是在达到max_steps指定的步骤数时停止。
"""
trainer = SFTTrainer(
    model=original_model,
    args=training_arguments,#
    train_dataset=train_data,
    eval_dataset=eval_data,
    peft_config=peft_config,# 模型配置文件
    dataset_text_field="text",
    tokenizer=tokenizer,
    max_seq_length=1024,
    packing=False,
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": False,
    }
)

"""
#可选
trainer = transformers.Trainer(
    model=peft_model,
    train_dataset=train_data,
    eval_dataset=eval_data,
    args=training_arguments,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
"""

trainer.train()
#模型保存
trainer.save_model()
#保存token
tokenizer.save_pretrained(output_dir)

运行结果:

在这里插入图片描述
这里看到loss 没有很好的出现下降的情况

其他方案调整

调整promt 工程

告诉大模型 任务是文本的摘要
在这里插入图片描述

尝试调整

调整学习率1e-4
在这里插入图片描述
只是延缓了过拟合的情况发生。

增加训练集
在这里插入图片描述
后期过拟合的只有更快了。
以上是文本的全部内容,有好的方法希望一起讨论。感谢。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1674894.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

hadoop学习---基于Hive的数据仓库相关函数机制及其优化方案

Hive相关函数&#xff08;部分&#xff09;&#xff1a; if函数: 作用: 用于进行逻辑判断操作 语法: if(条件, true返回信息,false返回信息) 注意: if函数支持嵌套使用 select if(aa,’bbbb’,111) fromlxw_dual; bbbb select if(1<2,100,200) fromlxw_dual; 200nvl函数:…

【Python】理解WOE(Weight of Evidence)和IV(Information Value)

忠孝东路走九遍 脚底下踏著曾经你我的点点 我从日走到夜 心从灰跳到黑 我多想跳上车子离开伤心的台北 忠孝东路走九遍 穿过陌生人潮搜寻你的脸 有人走的匆忙 有人爱的甜美 谁会在意擦肩而过的心碎 &#x1f3b5; 动力火车《忠孝东路走九遍》 在信用评分和…

教育型内容的制胜秘诀:Kompas.ai如何结合知识与营销

在数字化营销的浪潮中&#xff0c;教育型内容已经成为品牌建立权威性和提供价值的重要手段。通过分享专业知识和见解&#xff0c;品牌不仅能够吸引目标受众&#xff0c;还能够在潜在客户心中建立起专业和可信赖的形象。本文将深入分析教育型内容的重要性&#xff0c;详细介绍Ko…

VALSE 2024合合信息 | 文档解析与向量化技术加速多模态大模型训练与应用

第十四届视觉与学习青年学者研讨会&#xff08;VALSE 2024&#xff09;近期在重庆悦来国际会议中心圆满举行&#xff0c;由中国人工智能学会&#xff08;CAAI&#xff09;、中国图象图形学会&#xff08;CSIG&#xff09;、中国民族贸易促进会主办&#xff0c;重庆邮电大学承办…

数据库系统概论(个人笔记)(第二部分)

数据库系统概论&#xff08;个人笔记&#xff09; 文章目录 数据库系统概论&#xff08;个人笔记&#xff09;2、关系模型简介2.1 关系数据库的结构2.2 数据库模式2.3 键2.4 模式图2.5 关系查询语言2.6 关系代数 2、关系模型简介 2.1 关系数据库的结构 Structure of Relational…

【目标检测】YOLOv5|YOLOv8模型QT界面可视化部署

YOLO-Deploy-QT_Interface 最近笔者做了YOLO系列算法的部署工作,现做一个总结。主要工作是做了用于部署YOLOv5和YOLOv8的可视化QT界面,可实现图片、文件夹、视频、摄像头的ONNX与OpenVino部署,具体效果如下: 代码链接:https://github.com/Zency-Sun/YOLO-Deploy-QT_Inte…

开源之夏又有新项目发布!快来认领!¥12,000 奖金等你来!

又有新项目加入开源之夏啦&#xff01;KubeBlocks 的亲兄弟 WeScale 也加入开源之夏啦&#xff01;一起来了解下 WeScale 社区及它带来的项目吧&#xff01; WeScale 是什么&#xff1f; WeScale 社区源自对更好的数据库访问体验的追求。在现代应用程序开发中&#xff0c;数据…

一文扫盲(13):电商管理系统的功能模块和设计要点

电商管理系统是一种用于管理和运营电子商务平台的软件系统。它提供了一系列功能模块&#xff0c;帮助企业进行商品管理、订单管理、会员管理、营销推广、数据分析等工作。本文将从以下四个方面介绍电商管理系统。 一、什么是电商管理系统 电商管理系统是一种集成了各种功能模块…

Android Q - 音频通路调试

对于当前模块不是很清楚&#xff0c;刚好有个项目这方面有点问题&#xff0c;根据展锐支持文档一步步检查就可以了。首先得先弄清楚硬件具体是怎么连接的&#xff0c;比如文档提到的案例&#xff1a;sprd codec speaker output 连接外部 PA。 耳机接的是什么&#xff0c;speake…

国产化开源鸿蒙系统智能终端RK3568主板在电子班牌项目的应用

国产化开源鸿蒙系统智能终端主板AIoT-3568A、人脸识别算法的的电子班牌方案可支持校园信息发布、人脸识别考勤、考场管理、查询互动等多项功能&#xff0c;助力学校在硬件上实现信息化、网络化、数字化&#xff0c;构建“学校、教师、学生”三个维度的智慧教育空间。 方案优势 …

mysql5和mysql8同时存在

Win10安装两个不同版本MySQL数据库&#xff08;一个5.7&#xff0c;一个8.0.17&#xff09;_两个数据库的版本不同(我本地的是mysql5.7,线上是mysql8),怎么进行数据的同步?-CSDN博客 安装两个版本的mysql出现的问题和解决_mysql安装两个版本其中一个不适用-CSDN博客 一台电脑…

增程SUV价格即将崩盘?买车一定要再等等!

文 | AUTO芯球 作者 | 雷歌​ 真是“离谱”啊&#xff0c;车圈真是逗比欢乐多&#xff0c; 我这两天看一个博主连续40多小时开车直播&#xff0c;充电口、油箱盖全部封死&#xff0c;全程视频直播没断过&#xff0c; 就为了测试这两天刚上市的星际元ET续航有多远。 另一个…

drippingblues 靶机实战

信息收集&#xff1a; Nmap: 存活&#xff1a; 靶机ip&#xff1a;192.168.10.110 端口&#xff1a; 服务&#xff1a; 发现ftp服务可以匿名登录。且用户名是FTP。 发现一个压缩包&#xff0c;下载并爆破。 得到密码 072528035。发现关键字 drip。里面还有一个 secret.zip(…

C语言写扫雷游戏(数组和函数实践)

目录 最后是代码啦&#xff01; 手把手教你用C语言写一个扫雷游戏&#xff01; 1.我们搭建一下这个多文件形式的扫雷游戏文件结构 2.在主函数里面设置一个包含游戏框架的菜单 菜单可以方便游戏玩家选择要进行的动作和不断地进行下一局。 3.switch语句连接不同的结果 菜单可…

49.乐理基础-拍号的类型-单拍子、复拍子

当前写的东西&#xff0c;如果只是想要看懂乐谱的话&#xff0c;它是没什么意义的&#xff0c;就像我们要把 0&#xff0c;1&#xff0c;2&#xff0c;3&#xff0c;4&#xff0c;5。。。称为自然数&#xff0c;1&#xff0c;2&#xff0c;3&#xff0c;4&#xff0c;5称为正整…

【提示学习论文】TCP:Textual-based Class-aware Prompt tuning for Visual-Language Model

TCP:Textual-based Class-aware Prompt tuning for Visual-Language Model&#xff08;CVPR2024&#xff09; 基于文本的类感知提示调优的VLMKgCoOp为baseline&#xff0c;进行改进&#xff0c;把 w c l i p w_{clip} wclip​进行投影&#xff0c;然后与Learnable prompts进行…

考研数学|强化《660》+《880》这样刷,太丝滑了❗️

660题880题需要大概两个月才能做完 660题和880题都是很高质量的题集&#xff0c;所以做起来一点也不轻松。 每年都会有学生暑假两个月只做了一本660题的情况&#xff0c;因为题目实在是太难&#xff0c;有点做不下去的感觉。 不过不要担心&#xff0c;暑假就是刷题发现问题的…

Ubuntu安装k8s集群

文章目录 Ubuntu安装k8s3台主机前置操作&#xff1a;3台主机k8s前置安装命令&#xff1a;k8s安装命令&#xff1a; 节点加入 Ubuntu安装k8s 官方文档&#xff1a;https://kubernetes.io/zh-cn/docs/setup/production-environment/tools/kubeadm/install-kubeadm/ 默认3台机子 注…

如何为域名生成证书签发请求CSR

最近我们在Hostease购买了服务器产品&#xff0c;为了保障我们网站的安全&#xff0c;我们额外还购买了SSL证书产品。在Hostease技术客服的帮助下&#xff0c;我们成功签发了SSL证书。 在签发证书前需要生成一个证书签名请求CSR&#xff0c;证书签名请求(CSR)是一个包含有关你…

轻松拿下指针(5)

文章目录 一、回调函数是什么二、qsort使用举例三、qsort函数的模拟实现 一、回调函数是什么 回调函数就是⼀个通过函数指针调⽤的函数。 如果你把函数的指针&#xff08;地址&#xff09;作为参数传递给另⼀个函数&#xff0c;当这个指针被⽤来调⽤其所指向的函数 时&#x…