HuggingFace学习笔记--BitFit高效微调

news2025/2/26 17:45:38

1--BitFit高效微调

        BitFit,全称是 bias-term fine-tuning,其高效微调只去微调带有 bias 的参数,其余参数全部固定;

2--实例代码

from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from transformers import pipeline, TrainingArguments, Trainer

# 分词器
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")

# 函数内将instruction和response拆开分词的原因是:
# 为了便于mask掉不需要计算损失的labels, 即代码labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

if __name__ == "__main__":
    # 加载数据集
    dataset = load_from_disk("./PEFT/data/alpaca_data_zh")
    
    # 处理数据
    tokenized_ds = dataset.map(process_func, remove_columns = dataset.column_names)
    # print(tokenizer.decode(tokenized_ds[1]["input_ids"]))
    # print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))
    
    # 创建模型
    model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)
    # 基于bitfit只训练带有bias的参数
    for name, param in model.named_parameters():
        if "bias" not in name:
            param.requires_grad = False
            
    # 训练参数
    args = TrainingArguments(
        output_dir = "./chatbot",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 8,
        logging_steps = 10,
        num_train_epochs = 1
    )
    
    # trainer
    trainer = Trainer(
        model = model,
        args = args,
        train_dataset = tokenized_ds,
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
    )
    
    # 训练模型
    trainer.train()
    
    # 模型推理
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
    
    ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
    output = pipe(ipt, max_length=256, do_sample=True)
    print(output)

结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1278624.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

牛客算法题【HJ96 表示数字】golang实现

题目 HJ96 表示数字 golang实现 package mainimport ("fmt""unicode" )func main() {s : ""var s_o stringvar char_pre, r runefor {n, _ : fmt.Scan(&s)if n 0 {break} else {for _, r range s {if unicode.IsDigit(r) {if !unicode.…

Kubernetes技术与架构-策略

Kubernetes集群提供系统支持的策略,也提供开放接口给第三方定义的策略,这些策略用于可定义的配置文件或者Kubernetes集群的运行时环境,其中包括进程ID数量的申请与限制策略,服务器节点Node内的进程ID的数量限制策略,Po…

MySQL 临时数据空间不足导致SQL被killed 的问题与扩展

开头还是介绍一下群,如果感兴趣PolarDB ,MongoDB ,MySQL ,PostgreSQL ,Redis, Oceanbase, Sql Server等有问题,有需求都可以加群群内,可以解决你的问题。加群请联系 liuaustin3 ,(共1730人左右 1 2 3 4 5&#xff0…

网站有必要使用SSL证书吗

随着互联网的快速发展,网络安全问题也变得日益突出,SSL证书的作用日益凸显。 什么是SSL证书? SSL证书(Secure Sockets Layer Certificate),也称为TLS证书(Transport Layer Security Certifica…

【论文阅读】CAN网络中基于时序信道的隐蔽认证算法

文章目录 摘要一、引言和动机A 相关工作 二、背景及实验设置A 以前工作中的时钟偏差和局限性B.最坏到达时间C.安装组件 三、优化流量分配A.问题陈述B.优化帧调度 四、协议和结果A.主协议B.对手模型C. 优化流量和单一发送者的结果D.多发送方情况和噪声信道E.信道数据速率&#x…

HuggingFace学习笔记--Prompt-Tuning高效微调

1--Prompt-Tuning介绍 Prompt-Tuning 高效微调只会训练新增的Prompt的表示层,模型的其余参数全部固定; 新增的 Prompt 内容可以分为 Hard Prompt 和 Soft Prompt 两类; Soft prompt 通常指的是一种较为宽泛或模糊的提示,允许模型在…

ELK配置记录

1. filebeat.yml配置 启动命令: ./filebeat -e -c filebeat.yml # 输入 filebeat.inputs: - type: logenabled: truepaths:- /soft/log/base.*#跨行日志正则,从有时间的开始,到下一个时间之前结束multiline.pattern: ^\[[0-9]{4}-[0-9]{2}…

Concurrent Security of Anonymous Credentials Light, Revisited

目录 摘要引言 摘要 我们重新审视了著名的匿名证书轻(ACL)方案(Baldimtsi和Lysyanskaya,CCS’13)的并发安全保证。该方案最初被证明在按顺序执行时是安全的,其并发安全性仍然是一个悬而未决的问题。Benham…

【PyTorch】线性回归

文章目录 1. 代码实现1.1 一元线性回归模型的训练 2. 代码解读2.1. tensorboardX2.1.1. tensorboardX的安装2.1.2. tensorboardX的使用 1. 代码实现 波士顿房价数据集下载 1.1 一元线性回归模型的训练 import numpy as np import torch import torch.nn as nn from torch.ut…

[计算机网络] 高手常用的几个抓包工具(上)

文章目录 高手常用的抓包工具一览什么是抓包工具优秀抓包工具WiresharkFiddlerTcpdumpCharles 高手常用的抓包工具一览 什么是抓包工具 抓包工具是一种可以捕获、分析和修改网络流量的软件。它可以帮助您进行网络调试、性能测试、安全审计等任务。 抓包工具可以实时地显示网…

使用trigger-forward跨流水线传递参数

参考文档:https://docs.gitlab.com/ee/ci/yaml/#triggerforward 今天给大家介绍一个gitlab CI/CD的关键字 - forward,该关键字是一个比较偏的功能,但同时也是一个很实用的功能,我们通过在gitlab的ci文件中使用forward关键字&#…

Android HCI日志分析案例1

案例1--蓝牙扫描设备过程分析 应用层发起搜索蓝牙设备,Android 官方提供的蓝牙扫描方式有三种,分别如下: BluetoothAdapter.startDiscovery(); //可以扫描经典蓝牙和BLE两种。BluetoothAdapter.startLeScan();//扫描低功耗蓝牙,…

深入理解同源限制:网络安全的守护者(上)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

超简单的node脚本,将xlsx文件转化为json

开发场景,在一个官网中,官网的设计非常简单,就是一个纯静态的页面,全网站仅一个地方调一下接口,发一下用户填写的信息到运营同学的邮箱,这些数据不会记录在数据库,我需要做一个这样的下拉框。 但…

python使用记录

1、VSCode添加多个python解释器 只需要将对应的python.exe的目录,添加到系统环境变量的Path中即可,VSCode会自动识别及添加 2、pip 使用 pip常用命令和一些坑 查看已安装库的版本号 pip show 库名称 通过git 仓库安装第三方库 pip install git仓库地…

SQL Server 2016(分离和附加数据库)

1、实验环境。 基于上一个实验《SQL Server(创建数据库)》 2、需求描述。 class数据库的数据文件和事务日志文件都位于C:\db_class目录下。现在需要把class数据库的数据文件和事务日志文件分开存放,数据文件class.mdf存放于原位置&#xff0…

DQN原理及PyTorch实现【强化学习】

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 欢迎来到我们的强化学习系列的第三部分。 在上两篇博客中,我们介绍了强化学习中…

【Linux】命令行参数

文章目录 前言一、C语言main函数的参数二、环境变量总结 前言 我们在Linux命令行输入命令的时候,一般都会跟上一些参数选项,比如l命令,ls -a -l。以前我总是觉得这是理所当然的,没深究其本质究竟是什么,今天才终于知道…

高压配电室智能运维

高压配电室智能运维是指通过运用先进的物联网、大数据、云计算等技术,对高压配电室进行智能化、远程化的运行维护,实现高压配电室的安全、高效、经济运行。以下是高压配电室智能运维的主要功能和优势: 实时监测:通过传感器和监测设…

分页助手入门以及小bug,报sql语法错误

导入坐标 5版本以上的分页助手 可以不用手动指定数据库语言&#xff0c;它会自动识别 <dependency> <groupId>com.github.pagehelper</groupId> <artifactId>pagehelper</artifactId> <version>5.3.2</version> </dependency&g…