HuggingFace模型头的自定义

news2025/1/12 6:01:26

 

在线工具推荐:  Three.js AI纹理开发包 -  YOLO合成数据生成器 -  GLTF/GLB在线编辑 -  3D模型格式在线转换 -  可编程3D场景编辑器

在本文中我们将介绍如何使HuggingFace的模型适应你的任务,在Pytorch中建立自定义模型头并将其连接到HF模型的主体,并端到端地训练系统。

1、HF模型头和模型体

这是典型的HF模型的样子:

为什么我需要单独使用模型头(Model Head)和模型体(Model Body)?

一些HF的模型针对下游任务(例如提问或文本分类)训练,并包含有关其权重培训的数据的知识。

有时,尤其是当我们手头的任务包含很少的数据或领域特定(例如医学或运动特定任务)时,我们可以在HUB上使用其他任务训练的模型(不一定与我们的任务相同的任务 手但属于相同领域,例如运动或药物),并利用一些验证的知识来提高我们模型在我们自己任务的性能表现。

  • 一个非常简单的例子是,如果说我们有一个小数据集,比如分类某些财务报表是积极还是负面的。 但是,我们进入了HF,发现许多模型已经经过与金融相关的问答数据集的训练,那么 我们可以使用这些模型的某些层来改进自己的任务。
  • 另一个简单的示例是,某个特定领域的模型经过巨大数据集的训练学会了将文本从中分为5个类别。 假设我们有类似的分类任务,在同一域中的一个完全不同的数据集,只想将数据分类为2个类别而不是5。 这时我们也可以复用模型主体,添加自己的模型头来增强我们自己任务的特定领域知识。

这就是我们要做的事情的示意图:

2、自定义HF模型头

我们的任务是简单的,从Kaggle上的这个数据集进行讽刺检测。

你可以在此处查看完整的代码。 为了时间的考虑,我没有在下面包括预处理和一些训练的详细信息,因此请确保查看整个代码的笔记本。

我将使用一个在大量推文上训练的模型,有5个分类输出不同的情感类型。我们将提取模型体,在pytorch中添加自定义层(2个标签,讽刺/不讽刺),并训练新的模型。

注意:你可以在此示例中使用任何模型(不一定是对分类训练的模型),因为我们只会使用该模型主体并拆除模型头。

这就是我们的工作流程:

我将跳过数据预处理步骤,然后直接跳到主类,但是你可以在本节开头的链接中查看整个代码。

3、令牌化和动态填充

使用如下代码将文本转化为令牌并进行动态填充:

checkpoint = "cardiffnlp/twitter-roberta-base-emotion"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.model_max_len=512

def tokenize(batch):
  return tokenizer(batch["headline"], truncation=True,max_length=512)

tokenized_dataset = data.map(tokenize, batched=True)
print(tokenized_dataset)

tokenized_dataset.set_format("torch",columns=["input_ids", "attention_mask", "label"])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

结果如下:

DatasetDict({
    train: Dataset({
        features: ['headline', 'label', 'input_ids', 'attention_mask'],
        num_rows: 22802
    })
    test: Dataset({
        features: ['headline', 'label', 'input_ids', 'attention_mask'],
        num_rows: 2851
    })
    valid: Dataset({
        features: ['headline', 'label', 'input_ids', 'attention_mask'],
        num_rows: 2850
    })
})

4、提取模型体并添加我们自己的层

代码如下:

class CustomModel(nn.Module):
  def __init__(self,checkpoint,num_labels): 
    super(CustomModel,self).__init__() 
    self.num_labels = num_labels 

    #Load Model with given checkpoint and extract its body
    self.model = model = AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=True,output_hidden_states=True))
    self.dropout = nn.Dropout(0.1) 
    self.classifier = nn.Linear(768,num_labels) # load and initialize weights

  def forward(self, input_ids=None, attention_mask=None,labels=None):
    #Extract outputs from the body
    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

    #Add custom layers
    sequence_output = self.dropout(outputs[0]) #outputs[0]=last hidden state

    logits = self.classifier(sequence_output[:,0,:].view(-1,768)) # calculate losses
    
    loss = None
    if labels is not None:
      loss_fct = nn.CrossEntropyLoss()
      loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
    
    return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)

如你所见,我们首先是继承Pytorch中的 nn.Module,使用AutoModel(来自transformers库)提取加载了指定检查点的模型主体。

请注意, forward() 方法返回 TokenClassifierOutput,从而确保我们输出的格式与HF预训练模型一致。

5、端到端训练新的模型

代码如下:

from tqdm.auto import tqdm

progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(eval_dataloader)))


for epoch in range(num_epochs):
  model.train()
  for batch in train_dataloader:
      batch = {k: v.to(device) for k, v in batch.items()}
      outputs = model(**batch)
      loss = outputs.loss
      loss.backward()

      optimizer.step()
      lr_scheduler.step()
      optimizer.zero_grad()
      progress_bar_train.update(1)

  model.eval()
  for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])
    progress_bar_eval.update(1)
    
  print(metric.compute())
  model.eval()

test_dataloader = DataLoader(
    tokenized_dataset["test"], batch_size=32, collate_fn=data_collator
)

for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

结果如下:

  0%|          | 0/2139 [00:00<?, ?it/s]
  0%|          | 0/270 [00:00<?, ?it/s]
{'f1': 0.9335347432024169}
{'f1': 0.9360090874668686}
{'f1': 0.9274912756882513}

如你所见,我们使用此方法实现了不错的性能。 请记住,该博客的目的不是分析此特定数据集的性能,而是要学习如何使用预训练的身体并添加自定义头。

6、结束语

在本文中,我们看到了如何在HF预训练模型上添加自定义层。

一些收获:

  • 在我们拥有特定于域的数据集并希望利用在同一域(任务 - 努力的task-agnostic)上训练的模型以增强小型数据集中的性能的情况下,此技术特别有用。
  • 我们可以选择接受过与自己任务不同的下游任务训练的模型,并且仍然使用该模型主体的知识。
  • 如果你的数据集足够大且通用,那么这可能根本不需要,在这种情况下,你可以使用 AutoModeForSequenceCecrification或使用 BERT 解决的任何其他任务。 实际上,如果是这样,我强烈建议不要建立自己的模型头。

原文链接:HF自定义模型头 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1198533.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023数据安全战场回顾:迅软科技助您稳固阵线

随着各行业的数字化转型不断深入&#xff0c;数据安全逐步进入法制化的强监管时代。然而&#xff0c;由于人为攻击、技术漏洞和监管缺位等原因&#xff0c;各种数据泄露事件频繁发生&#xff0c;企业数据安全威胁日益严峻。 以下是我对2023年第三季度安全事件的总结&#xff0c…

Maven Profile组设置

application.properties中xxxx

JS实现数据结构与算法

队列 1、普通队列 利用数组push和shif 就可以简单实现 2、利用链表的方式实现队列 class MyQueue {constructor(){this.head nullthis.tail nullthis.length 0}add(value){let node {value}if(this.length 0){this.head nodethis.tail node}else{this.tail.next no…

hosts文件地址

Hosts是一个没有扩展名的系统文件&#xff0c;可以用记事本等工具打开&#xff0c;其作用就是将一些常用的网址域名与其对应的IP地址建立一个关联“数据库”&#xff0c;当用户在浏览器中输入一个需要登录的网址时&#xff0c;系统会首先自动从Hosts文件中寻找对应的IP地址&…

Typescript -尚硅谷

基础 1.ts是以js为基础构建的语言&#xff0c;是一个js的超集(对js进行了扩展)&#xff1b; 2.ts(type)最主要的功能是在js的基础上引入了类型的概念; Js的类型是只针对于值而言&#xff0c;ts的类型是针对于变量而言 Ts可以被编译成任意版本的js&#xff0c;从而进一步解决了…

企业邮箱本地私有化部署解决方案

随着互联网化进程不断深入&#xff0c;加快推进企业信息化系统建设&#xff0c;已经成为提高企业核心竞争力的重要途径。企业对企业邮箱系统的需求越来越大&#xff0c;企业邮箱系统作为企业级通讯工具中的利器&#xff0c;在协同办公和内外业务交流上发挥着无可替代的巨大作用…

LLM代码生成器的挑战【GDELT早期观察】

越来越多的研究开始对LLM大模型生成的代码的质量提出质疑&#xff0c;尽管科技行业不断推出越来越多的旨在增强甚至取代人类编码员的工具。 随着我们&#xff08;GDELT&#xff09;继续探索和评估越来越多的此类工具&#xff0c;以下是我们的一些早期观察结果。 在线工具推荐&a…

将复数中的虚部取反 即对复数求共轭 numpy.conjugate()

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 将复数中的虚部取反 即对复数求共轭 numpy.conjugate() [太阳]选择题 请问以下代码中执行语句输出结果是&#xff1f; import numpy as np a np.array([1 2j, 3 - 4j]) print("【显示…

Linux学习教程(第二章 Linux系统安装)1

第二章 Linux系统安装 学习 Linux&#xff0c;首先要学会搭建 Linux 系统环境&#xff0c;也就是学会在你的电脑上安装 Linux 系统。 很多初学者对 Linux 望而生畏&#xff0c;多数是因为对 Linux 系统安装的恐惧&#xff0c;害怕破坏电脑本身的系统&#xff0c;害怕硬盘数据…

第二十五节——Vuex--历史遗留

文档地址 Vuex 是什么&#xff1f; | Vuex version V4.x 一、概念 Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式 库。它采用集中式存储管理应用的所有组件的状态&#xff0c;并以相应的规则保证状态以一种可预测的方式发生变化。一个状态自管理应用包含以下几个部…

大数据Doris(二十一):数据导入演示

文章目录 数据导入演示 一、启动zookeeper集群(三台节点都启动) 二、启动hdfs集群

计算机网络——物理层-传输方式(串行传输、并行传输,同步传输、异步传输,单工、半双工和全双工通信)

目录 串行传输和并行传输 同步传输和异步传输 单工、半双工和全双工通信 串行传输和并行传输 串行传输是指数据是一个比特一个比特依次发送的。因此在发送端和接收端之间&#xff0c;只需要一条数据传输线路即可。 并行传输是指一次发送n个比特&#xff0c;而不是一个比特&…

图论13-最小生成树-Kruskal算法+Prim算法

文章目录 1 最小生成树2 最小生成树Kruskal算法的实现2.1 算法思想2.2 算法实现2.2.1 如果图不联通&#xff0c;直接返回空&#xff0c;该图没有mst2.2.2 获得图中的所有边&#xff0c;并且进行排序2.2.2.1 Edge类要实现Comparable接口&#xff0c;并重写compareTo方法 2.2.3 取…

强化学习中蒙特卡罗方法

一、蒙特卡洛方法 这里将介绍一个学习方法和发现最优策略的方法&#xff0c;用于估计价值函数。与前文不同&#xff0c;这里我们不假设完全了解环境。蒙特卡罗方法只需要经验——来自实际或模拟与环境的交互的样本序列的状态、动作和奖励。从实际经验中学习是引人注目的&#x…

如何使用pngPackerGUI_V2.0,将png图片打包成plist的工具

pngPackerGUI_V2.0&#xff0c;此软件是在pngpacker_V1.1软件基础之后&#xff0c;开发的界面化操作软件&#xff0c;方便不太懂命令行的小白快捷上手使用。 具体的使用步骤如下&#xff1a; 1.下载并解压缩软件&#xff0c;得到如下目录&#xff0c;双击打开 pngPackerGUI.e…

SFTP远程终端访问

远程终端访问 当服务器部署好以后&#xff0c;除了直接在服务器上操作&#xff0c;还可以通过网络进行远程连接访问CentOS 7默认支持SSH(Secure Shell, 安全Shell 协议),该协议通过高强度的加密算法提高了数据在网络传输中的安全性&#xff0c;可有效防止中间人攻击(Man-in-th…

AI绘画神器DALLE 3的解码器:一步生成的扩散模型之Consistency Models

前言 关于为何写此文&#xff0c;说来同样话长啊&#xff0c;历程如下 我司LLM项目团队于23年11月份在给一些B端客户做文生图的应用时&#xff0c;对比了各种同类工具&#xff0c;发现DALLE 3确实强&#xff0c;加之也要在论文100课上讲DALLE三代的三篇论文&#xff0c;故此文…

web:[网鼎杯 2018]Fakebook

题目 点进页面&#xff0c;页面显示为 查看源代码 用dirsearch扫一下&#xff0c;看一下有什么敏感信息泄露 扫出另一个flag.php和robots.txt&#xff0c;访问flag.php回显内容为空 请求robots.txt 网页提示/user.php.bak&#xff0c;直接访问会自动下载.bak备份文件 进行代码…

Flink在汽车行业的应用【面试加分系列】

很多同学问我为什么要发这些大数据前沿汇报&#xff1f; 一方面是自己学习完后觉得非常好&#xff0c;然后总结发出来方便大家阅读&#xff1b;另外一方面&#xff0c;看这些汇报对你的面试帮助会很大&#xff0c;特别是面试前可以看看即将面试公司在大数据前沿的发展动向&…

DevOps平台两种实现模式

我们需要一个DevOps平台 要讨论DevOps平台的实现模式&#xff0c;似乎就必须讨论它们的概念定义。然而&#xff0c;当大家要讨论它们的定义时&#xff0c;就像在讨论薛定谔的猫。 A公司认为它不过是自动化执行Shell脚本的平台&#xff0c;有些人认为它是一场运动&#xff0c;另…