【Finetune】(三)、transformers之P-Tuning微调

news2024/9/23 9:44:34

文章目录

  • 0、P-Tuning基本原理
  • 1、代码实战
    • 1.1、导包
    • 1.2、加载数据集
    • 1.3、数据集预处理
    • 1.4、创建模型
    • 1.5、P-tuning*
      • 1.5.1、配置文件
      • 1.5.2、创建模型
    • 1.6、配置训练参数
    • 1.7、创建训练器
    • 1.8、模型训练
    • 1.9、模型推理

0、P-Tuning基本原理

 P-Tuning的基本思想是在prompt-tuning的基础上,对prompt部分进行进一步的编码计算,加速收敛。具体来说,PEFT支持两种编码方式,一种是LSTM,一种是MLP,与Prompt-tuning不同的是,Prompt的形式只有Soft Prompt。
在这里插入图片描述

1、代码实战

1.1、导包

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

1.2、加载数据集

ds = Dataset.load_from_disk("../Data/alpaca_data_zh/")
ds

1.3、数据集预处理

tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")
tokenizer
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

1.4、创建模型

model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh", low_cpu_mem_usage=True)

1.5、P-tuning*

1.5.1、配置文件

from peft import PromptEncoderConfig, TaskType, get_peft_model, PromptEncoderReparameterizationType

config = PromptEncoderConfig(
    task_type=TaskType.CAUSAL_LM, 
    num_virtual_tokens=10,
    encoder_reparameterization_type=PromptEncoderReparameterizationType.MLP,
    # encoder_reparameterization_type=PromptEncoderReparameterizationType.LSTM,
    encoder_dropout=0.1, 
    encoder_num_layers=5, 
    encoder_hidden_size=1024)
print(config)

1.5.2、创建模型

model = get_peft_model(model, config)
model.print_trainable_parameters()

1.6、配置训练参数

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=1
)

1.7、创建训练器

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=1
)

1.8、模型训练

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=1
)

1.9、模型推理

model = model.cuda()
ipt = tokenizer("Human: {}\n{}".format("数学考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**ipt, max_length=256, do_sample=True)[0], skip_special_tokens=True))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2157223.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Boot管理用户数据

目录 学习目标前言Thymeleaf 模板JSON 数据步骤 1: 创建 Spring Boot 项目使用 Spring Initializr 创建项目使用 IDE 创建项目 步骤 2: 添加依赖步骤 3: 创建 Controller步骤 4: 新建index页面步骤 5: 运行应用程序 表单提交步骤 1: 添加 Thymeleaf 依赖在 Maven 中添加依赖 步…

Vue3.3新特性defineModel

defineModel的使用: defineModel选项可以帮我们省去很多麻烦 不仅需要上述操作,还需要进行一定的配置: 在vite.config.js中进行配置 defineModel是一个宏,所以不需要从vue中import导入,直接使用就可以了。这个宏可以用来声明一个…

Java之线程篇六

目录 CAS CAS伪代码 CAS的应用 实现原子类 实现自旋锁 CAS的ABA问题 ABA问题导致BUG的例子 相关面试题 synchronized原理 synchronized特性 加锁过程 相关面试题 Callable 相关面试题 JUC的常见类 ReentrantLock ReentrantLock 和 synchronized 的区别: 原…

Android 新增目录怎么加入git

工作中会遇到android系统源码系统增加第三方功能支持,需要增加目录,那么这个目录怎么提交到服务器上去呢?接下来我们就看下这个问题的解决 一服务器创建仓库 一个新的目录增加到仓库中,需要仓库有对应的仓库地址,这个让…

小记编程语言浮点精度问题

注意: 本文内容于 2024-09-15 20:21:12 创建,可能不会在此平台上进行更新。如果您希望查看最新版本或更多相关内容,请访问原文地址:小记编程语言浮点精度问题。感谢您的关注与支持! 浮点数在计算机中不能精确表示所有…

docker启动mysql未读取my.cnf配置文件问题

描述 在做mysql主从复制配置两台mysql时,从节点的my.cnf配置为: [mysqld] datadir /usr/local/mysql/slave1/data character-set-server utf8 lower-case-table-names 1 # 主从复制-从机配置# 从服务器唯一 ID server-id 2 # 启用中继日志 relay-l…

【小沐学CAD】3ds Max常见操作汇总

文章目录 1、简介2、二次开发2.1 C 和 3ds Max C SDK2.2 NET 和 3ds Max .NET API2.3 3ds Max 中的 Python 脚本2.4 3ds Max 中的 MAXScript 脚本 3、快捷键3.1 3Dmax键快捷键命令——按字母排序3.2 3dmax快捷键命令——数字键3.3 3dmax功能键快捷键命令3.4 3Dmax常用快捷键——…

对网页聊天项目进行性能测试, 使用JMeter对于基于WebSocket开发的webChat项目的聊天功能进行测试

登录功能 包括接口的设置和csv文件配置 ​​​​​​ 这里csv文件就是使用xlsx保存数据, 然后在浏览器找个网址转成csv文件 注册功能 这里因为需要每次注册的账号不能相同, 所以用了时间函数来当用户名, 保证尽可能的给正确的注册数据, 时间函数使用方法如下 这里输入分钟, 秒…

肝内胆管癌中三级淋巴结构分布与临床预后的相关性研究|文献精析·24-09-22

小罗碎碎念 这篇文章是关于肝内胆管癌(intrahepatic cholangiocarcinoma, iCCA)中三级淋巴结构(tertiary lymphoid structures, TLSs)的分布、密度及其对临床结果的预测价值的研究。 作者类型作者姓名单位名称(中文&a…

数据结构——串的模式匹配算法(BF算法和KMP算法)

算法目的: 确定主串中所含子串(模式串)第一次出现的位置(定位) 算法应用: 搜索引擎、拼写检查、语言翻译、数据压缩 算法种类: BF算法(Brute-Force,又称古典的…

【洛谷】P10417 [蓝桥杯 2023 国 A] 第 K 小的和 的题解

【洛谷】P10417 [蓝桥杯 2023 国 A] 第 K 小的和 的题解 题目传送门 题解 CSP-S1 补全程序,致敬全 A 的答案,和神奇的预言家。 写一下这篇的题解说不定能加 CSP 2024 的 RP 首先看到 k k k 这么大的一个常数,就想到了二分。然后写一个判…

《深入理解JAVA虚拟机(第2版)》- 第13章 - 学习笔记【终章】

第13章 线程安全与锁优化 13.1 概述 面向过程的编程思想 将数据和过程独立分开,数据是问题空间中的客体,程序代码是用来处理数据的,这种站在计算机角度来抽象和解决问题的思维方式,称为面向对象的编程思想。 面向对象的编程思想…

一劳永逸:用脚本实现夸克网盘内容自动更新

系统环境:debian/ubuntu 、 安装了python3 原作者项目:https://github.com/Cp0204/quark-auto-save 感谢 缘起 我喜欢看电影追剧,会经常转存一些资源到夸克网盘,电影还好,如果是电视剧,麻烦就来了。 对于一…

Kettle的安装及简单使用

Kettle的安装及简单使用一、kettle概述二、kettle安装部署和使用Windows下安装案例1:MySQL to MySQL案例2:使用作业执行上述转换,并且额外在表stu2中添加一条数据案例3:将hive表的数据输出到hdfs案例4:读取hdfs文件并将…

Jboss常⻅中间件漏洞

一.CVE-2015-7501 环境搭建 cd vulhub-master/jboss/JMXInvokerServlet-deserialization docker-compose up -d 1.POC,访问地址 172.16.1.4:8080/invoker/JMXInvokerServlet 返回如下,说明接⼝开放,此接⼝存在反序列化漏洞 2.下载 ysose…

7.C++程序中的基本数据类型-数据类型之间的转换

在C中,类型转换是将一个数据类型转为另外一个数据类型,其转换过程比较复杂,目前只讨论基本数据类型之间的转换。 类型转换分为两部分:隐式转换和显示转换 隐式转换又称为自动转换,显示转换又称为强制转换。 隐式转换…

[Linux] Linux进程PCB内部信息的深入理解

标题:[Linux] Linux进程PCB内部信息的深入理解 个人主页:水墨不写bug (图片来自网络) 目录 一.查看进程 二.认识并了解进程的关键信息 I,PID/PPID II,exe III,cwd 三、fork(&…

vue源码分析(九)—— 合并配置

文章目录 前言1.vue cli 创建一个基本的vue2 项目2.将mian.js文件改成如下3. 运行结果及其疑问? 一、使用 new Vue 创建过程的 2 种场景二、margeOption的详细说明1.margeOption的方法地址2.合并策略的具体使用3.defaultStrat 默认策略方法 三:以生命周期…

OpenResty安装及使用

🍓 简介:java系列技术分享(👉持续更新中…🔥) 🍓 初衷:一起学习、一起进步、坚持不懈 🍓 如果文章内容有误与您的想法不一致,欢迎大家在评论区指正🙏 🍓 希望这篇文章对你有所帮助,欢…

调用本地大模型服务出现PermissionDeniedError的解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…