【Finetune】(二)、transformers之Prompt-Tuning微调

news2024/9/19 10:24:15

文章目录

  • 0、prompt-tuning基本原理
  • 1、实战
    • 1.1、导包
    • 1.2、加载数据
    • 1.3、数据预处理
    • 1.4、创建模型
    • 1.5、Prompt Tuning*
      • 1.5.1、配置文件
      • 1.5.2、创建模型
    • 1.6、配置训练参数
    • 1.7、创建训练器
    • 1.8、模型训练
    • 1.9、推理:加载预训练好的模型

0、prompt-tuning基本原理

 prompt-tuning的基本思想就是冻结主模型的全部参数,在训练数据前加入一小段Prompt,只训练Prompt的表示向量,即一个Embedding模块。其中,prompt又存在两种形式,一种是hard prompt,一种是soft prompt。

在这里插入图片描述

1、实战

1.1、导包

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

1.2、加载数据

ds = Dataset.load_from_disk("../Data/alpaca_data_zh/")

1.3、数据预处理

tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")
tokenizer

def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

1.4、创建模型

model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh",low_cpu_mem_usage=True)

1.5、Prompt Tuning*

1.5.1、配置文件

#soft prompt

# config = PromptTuningConfig(
#     task_type=TaskType.CAUSAL_LM,
#     num_virtual_tokens=10,
#     )
# config
#hard prompt
config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init = PromptTuningInit.TEXT,
    prompt_tuning_init_text = '下面是一段机器人的对话:',
    num_virtual_tokens=len(tokenizer('下面是一段机器人的对话:')['input_ids']),
    tokenizer_name_or_path='../Model/bloom-389m-zh',

    )
config

1.5.2、创建模型

model= get_peft_model(model,config)
model

打印模型训练参数

model.print_trainable_parameters()

1.6、配置训练参数

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=1
)

1.7、创建训练器

trainer = Trainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True, )
)

1.8、模型训练

trainer.train()

1.9、推理:加载预训练好的模型

from peft import PeftModel
peft_model =  PeftModel.from_pretrained(model=model,model_id='./chat_bot/checkpoint500/')
from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
pipe(ipt, max_length=256, do_sample=True, temperature=0.5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2146087.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【机器学习】任务五:葡萄酒和鸢尾花数据集分类任务

目录 1.实验基础知识 1.1 集成学习 (1)随机森林 (2)梯度提升决策树(GBDT) (3)XGBoost (4)LightGBM 1.2 参数优化 (1)网格搜索…

编写第一个hadoop3.3.6的mapreduce程序

hadoop还是用的上个伪分布环境。 hadoop安装在龙蜥anolis8.9上,开发是在windows下。 1、windows下首先要下载hadoop的包,hadoop-3.3.6.tar.gz,比如我的解压到d:\java\hadoop-3.3.6中。 配置环境:HADOOP_HOME,内容为&am…

ava总结篇系列:Java泛型Java sort用法详解

一. 泛型概念的提出(为什么需要泛型)? 首先,我们看下下面这段简短的代码: 1 public class GenericTest { 2 3 public static void main(String[] args) { 4 List list new ArrayList(); 5 list.add(&q…

【Elasticsearch系列四】ELK Stack

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

mysql事务的隔离级别学习

事务的隔离级别: ⅰ. 读未提交 ⅱ. 对已提交 (解决 脏读) ⅲ. 可重复读 (解决 不可重复读) ⅳ. 串行化 (解决 脏读 不可重复读 幻读 问题 ) 隔离级别分类如下,在不同的隔离级别下可能产生不…

网络安全。

文章目录 目录 文章目录 一. 网络安全概述 二. 密码学原理 三. 报文完整性和数字签名 密码散列函数 报文鉴别码 数字签名 公钥认证 四. HTTPS通信 总结 一. 网络安全概述 网络安全是保护计算机网络及其数据免受各种威胁和攻击的实践和技术。随着互联网的普及和数字化…

Linux shell编程学习笔记81:zcat命令——快速查看压缩文件内容

0 引言 在 Linux shell编程学习笔记80:gzip命令——让文件瘦身-CSDN博客https://blog.csdn.net/Purpleendurer/article/details/141862213?spm1001.2014.3001.5501中,我们使用gzip命令可以创建压缩文件。那么,我们可以使用zcat命令来查看压…

传输层协议——udp/tcp

目录 再谈端口号 udp 协议 理解报头 udp特点 缓冲区 udp使用的注意事项 tcp协议 TCP的可靠性与提高效率的策略 序号/确认序号 窗口大小 ACK: PSH URG RST 保活机制 重传 三次握手(SYN) 四次挥手(FIN) 流量控制 滑动窗口 拥塞控制 延迟应答 捎带应答 面…

JavaScript match() 方法

match() 方法可在字符串内检索指定的值,或找到一个或多个正则表达式的匹配。 如果想了解更多正则表达式教程请查看: RegExp 教程 和我们的 RegExp 对象参考手册。 注意: match() 方法将检索字符串 String Object,以找到一个或多个…

Vue3 项目引入阿里 iconfont 图标和字体的多种方式

🚀 个人简介:某大型国企资深软件研发工程师,信息系统项目管理师、CSDN优质创作者、阿里云专家博主,华为云云享专家,分享前端后端相关技术与工作常见问题~ 💟 作 者:码喽的自我修养&#x1f9…

计算机人工智能前沿进展-大语言模型方向-2024-09-19

计算机人工智能前沿进展-大语言模型方向-2024-09-19 1. SAM4MLLM: Enhance Multi-Modal Large Language Model for Referring Expression Segmentation Authors: Yi-Chia Chen, Wei-Hua Li, Cheng Sun, Yu-Chiang Frank Wang, Chu-Song Chen SAM4MLLM: 增强多模态大型语言模型…

Java面试篇基础部分-Java线程生命周期

线程的生命周期分别为 新建(New)、就绪(Runnable)、运行(Running)、阻塞(Blocked)和死亡(Dead)这五种状态。   在系统运行过程中有线程不断地被创建,而旧的线程在执行完毕之后被清理,线程通过排队的方式获取共享资源或者锁的时候被阻塞,所以运行中的线程就会在…

基于yolov8的红外小目标无人机飞鸟检测系统python源码+onnx模型+评估指标曲线+精美GUI界面

【算法介绍】 基于YOLOv8的红外小目标无人机与飞鸟检测系统是一项集成了前沿技术的创新解决方案。该系统利用YOLOv8深度学习模型的强大目标检测能力,结合红外成像技术,实现了对小型无人机和飞鸟等低空飞行目标的快速、准确检测。 YOLOv8作为YOLO系列的…

supermap iclient3d for cesium中的平移,旋转

昨天写的模型机头不是速度的方向 基础知识 屏幕坐标系,笛卡尔空间直角坐标系,大地坐标系 平移和旋转都是基于笛卡尔空间直角坐标系,也就是基于地心。但是我们想实现模型的旋转是基于模型的局部坐标系,那么就要坐标转换。 向量归…

C++调用C# DLL之踩坑记录

C是非托管代码,C#则是托管代码,无法直接调用 CLR的介绍见CLR简介 MSDN提到了两种非托管-托管的交互技术:CLR Interop和COM Interop 后者要将C# 类库注册为COM组件,本文只探讨CLR,要通过C CLR写中间层代码 方式一&…

全新 HLOB 模型:预测限价订单簿中间价格变化方向的利器

作者:老余捞鱼 原创不易,转载请标明出处及原作者。 写在前面的话: 本文介绍了一个名为HLOB的新型大规模深度学习模型,用于预测限价订单簿中间价格的变化。该模型利用信息过滤网络(特别是三角最大化过滤图)来揭示订单簿中不同成交量水平间的深层和非平凡依赖结构,…

[Meachines] [Medium] Bart Server Monitor+Internal Chat+UA投毒+Winlogon用户密码泄露权限提升

信息收集 IP AddressOpening Ports10.10.10.81TCP:80 $ nmap -p- 10.10.10.81 --min-rate 1000 -sC -sV PORT STATE SERVICE VERSION 80/tcp open http Microsoft IIS httpd 10.0 | http-methods: |_ Potentially risky methods: TRACE |_http-server-header: Micros…

【java系】记录一次ClassLoader.getResourceAsStream获取不到文件流

问题描述反馈,开发同事在本地获取对应文件流是可以正常业务操作,发布到linux服务器对应环境就会出现异常。 源码部分截图 看到这里,我猜想是否和window底层文件操作系统不区分大小写有关呢?而服务器linux是严格区分大小写这个应该…

ZLMediaKit Windows编译以及使用

1.运行ZLMediaKit 2.通过ffmpeg把视频源推流给ZLMediaKit 执行以下命令,将本地视频通过RTSP协议推流给ZLMediaKit。 ffmpeg -re -stream_loop -1 -i "D:\workplace\armgb\public\1.fileh264" -vcodec h264 -f rtsp rtsp://127.0.0.1/live/test 若想将本…

网络安全:建筑公司会计软件遭受暴力攻击

黑客正在暴力破解基金会会计服务器上高权限账户的密码,这些账户广泛用于建筑行业,从而侵入企业网络。 这一恶意活动最先被 Huntress 发现,其研究人员于 2024 年 9 月 14 日检测到了此次攻击。 Huntress 已经发现这些攻击对管道、暖通空调、…