深度学习--自动化打标签

news2024/10/3 21:30:58

通过大模型(如羊驼模型)进行自动化打标签的方案,可以按照以下步骤来实现:

方案概要

  1. 初始标注:人工标记一部分数据,作为训练/验证集。
  2. 微调模型:将部分人工标注的数据用于微调大模型,如羊驼模型。
  3. 模型预测:将未标注数据抛给大模型,让其生成自动标签。
  4. 标签对比与迭代:通过与人工标注的标签对比,评估模型性能,直到模型与人工标注的标签非常接近为止。可以使用 F1-score、准确率等指标来衡量。
  5. 大规模标注:使用模型对剩余的大规模数据进行自动化标签预测。

方案详细步骤

  1. 数据准备

    将原始数据分为三部分:训练集(已标注,用于微调)、验证集(已标注,用于评估模型)和 未标注数据集(用于预测)。
  2. 初始模型微调

    通过 训练集 对羊驼模型进行微调,调整模型的参数,使其学习数据的标签分布。
  3. 模型预测与标签生成

    使用微调后的模型对 验证集 进行预测,并与人工标注进行对比,计算性能指标。
  4. 迭代训练

    如果模型的预测结果和人工标签的差异较大,则继续调整模型,直到模型的预测结果与人工标注结果非常接近。
  5. 大规模自动打标签

    当模型的预测结果稳定并达到较高的准确性后,将剩余的未标注数据抛给模型进行打标签。

代码实现(以羊驼模型为例)

Step 1: 数据准备

首先,你需要准备数据,将部分数据进行人工标注并分为训练集和验证集。

import pandas as pd
from sklearn.model_selection import train_test_split

# 假设我们有一个包含文本和标签的数据集
data = pd.read_csv("labeled_data.csv")  # 已经人工标记的部分数据

# 将数据划分为训练集、验证集和未标注的数据
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

# 未标注的数据集
unlabeled_data = pd.read_csv("unlabeled_data.csv")
Step 2: 微调羊驼模型

这里假设你已经有了羊驼模型的本地版本。我们可以使用 Hugging Face 的 transformers 库对羊驼模型进行微调。

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
import torch

# 加载羊驼模型和标记器
model_name = "decapoda-research/llama-7b-hf"  # 以羊驼模型为例
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# 将数据预处理为模型输入格式
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

# 准备训练数据集
train_texts = train_data["text"].tolist()
train_labels = train_data["label"].tolist()

# 准备验证数据集
val_texts = val_data["text"].tolist()
val_labels = val_data["label"].tolist()

# 转换为 PyTorch 的 Dataset 格式
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        item = self.texts[idx]
        label = self.labels[idx]
        encodings = tokenizer(item, truncation=True, padding="max_length", max_length=512)
        encodings['labels'] = torch.tensor(label)
        return {key: torch.tensor(val) for key, val in encodings.items()}

train_dataset = TextDataset(train_texts, train_labels)
val_dataset = TextDataset(val_texts, val_labels)

# 设置训练参数
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir='./logs',
    learning_rate=5e-5,
)

# 创建Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# 开始微调
trainer.train()
Step 3: 对验证集进行预测并与人工标签对比
from sklearn.metrics import accuracy_score, f1_score

# 模型预测
predictions = trainer.predict(val_dataset)
pred_labels = torch.argmax(predictions.predictions, axis=-1)

# 计算准确率和 F1-score
accuracy = accuracy_score(val_labels, pred_labels)
f1 = f1_score(val_labels, pred_labels, average="weighted")

print(f"验证集上的准确率: {accuracy}")
print(f"验证集上的 F1-score: {f1}")
Step 4: 对未标注数据集进行打标签

在验证模型效果后,如果模型表现良好,可以对未标注的数据进行自动打标签。

unlabeled_texts = unlabeled_data["text"].tolist()

# 生成未标注数据的预测标签
unlabeled_dataset = TextDataset(unlabeled_texts, [0]*len(unlabeled_texts))  # 此处的label只是占位
predictions = trainer.predict(unlabeled_dataset)
pred_labels = torch.argmax(predictions.predictions, axis=-1)

# 将打的标签添加到未标注数据中
unlabeled_data["predicted_label"] = pred_labels.numpy()

# 保存打好标签的数据
unlabeled_data.to_csv("labeled_unlabeled_data.csv", index=False)
Step 5: 迭代过程

如果验证集上的标签与人工标注的差距仍较大,可以调整训练超参数或微调模型,直到模型的性能稳定。

总结

  1. 数据准备:人工标记一部分数据,并分成训练集和验证集。
  2. 模型微调:通过羊驼模型微调,让模型学习数据的标签分布。
  3. 预测与对比:模型在验证集上进行预测,并与人工标签对比,使用准确率和 F1-score 等指标评估模型性能。
  4. 自动打标签:模型通过自动化方式对未标注数据进行打标签。

通过这种方法,能够有效利用大模型进行大规模数据的标签生成,同时减少人工标注的成本和工作量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2186892.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

flutter_鸿蒙next(win)环境搭建

第一步 拉取鸿蒙版本flutterSDK仓库 仓库地址:OpenHarmony-SIG/flutter_flutter 第二步 找到拉取的仓库中的README.md 并根据说明配置环境 第三步 配置好环境变量之后 用管理员开启cmd 输入:flutter dcotor 并查看此时flutter所支持的系统 包括&…

Cpp::STL—string类的模拟实现(12)

文章目录 前言一、string类各函数接口总览二、默认构造函数string(const char* str "");string(const string& str);传统拷贝写法现代拷贝写法 string& operator(const string& str);传统赋值构造现代赋值构造 ~string(); 三、迭代器相关函数begin &…

leetcode打卡001-约瑟夫问题

约瑟夫问题 其背景故事是关于一组人站成一个圈,从某个人开始报数,每数到特定数字的人将被淘汰出圈,然后从被淘汰人的下一个人重新开始报数,直到最后剩下一个人。问题的目标是确定最后剩下的那个人在最初的位置。 关键词 递归&a…

HCIP-HarmonyOS Application Developer 习题(四)

1、以下哪个Harmonyos的AI能力可以提供文档翻拍过程中的辅助增强功能? A.文档检测矫正 B.通用文字识别 C.分词 D.图像超分辨率 答案:A 分析:文档校正提供了文档翻拍过程的辅助增强功能,包含两个子功能: 文档检测:能够…

基于单片机人体反应速度测试仪系统

** 文章目录 前言概要设计思路 软件设计效果图 程序文章目录 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师,一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP430/AVR等单片机设计 主要对象是咱们…

kubernetes基础操作(pod生命周期)

pod生命周期 一、Pod生命周期 我们一般将pod对象从创建至终的这段时间范围称为pod的生命周期,它主要包含下面的过程: ◎pod创建过程 ◎运行初始化容器(init container)过程 ◎运行主容器(main container&#xff…

【Redis入门到精通九】Redis中的主从复制

目录 主从复制 1.配置主从复制 2.主从复制中的拓扑结构 3.主从复制原理 4.主从复制总结 主从复制 在分布式系统中为了解决单点问题,通常会把数据复制多个副本部署到其他服务器,满⾜故障恢复和负载均衡等需求。Redis 也是如此,它为我们提…

kafka基本概念以及用法

kafka基本概念以及用法目录 文章目录 kafka基本概念以及用法目录一、什么是kafka?二、为什么要使用kafka?三、kafka的基本概念四、安装kafka(windows版本)五、命令行控制kafka生产消费数据,创建 删除topic六、java操作kafka消费生产 提示:以…

Ubuntu操作系统版本服务支持时间(更新到24.04)

文章参考链接 以下是解释: 开发代号:Ubuntu的每个版本都有一个开发代号,例如“Mantic Minotaur”。 版本命名:Ubuntu的版本号是根据发布年份和月份来命名的。例如,Ubuntu 23.10是在2023年10月发布的。 LTS版本&…

Windows 11 24H2正式发布

微软最近正式发布了Windows 11 24H2,这是Windows 11的最新功能更新,带来了多项新特性和改进。 主要新功能: 人工智能增强:此次更新特别强调AI能力,推出了如Windows Copilot的增强版本。Copilot的界面得到了改善&#…

【微服务】注册中心 - Eureka(day3)

CAP理论 P是分区容错性。简单来说,分区容错性表示分布式服务中一个节点挂掉了,并不影响其他节点对外提供服务。也就是一台服务器出错了,仍然可以对外进行响应,不会因为某一台服务器出错而导致所有的请求都无法响应。综上所述&…

关于Mybatis框架操作时注意的细节,常见的错误!(博主亲生体会的细节!)

目录 1.在对DB进行CRUD时,除了查,其余的操作都要进行事务的提交否则不成功。 2.用sqlSession原生方法时,第一个参数方法名,是xml文件中定义的id名,底层找的是你这个接口所定义的方法名。 3.以包为单位引入映射文件 …

第三节-类与对象(2)默认成员函数详解

1.类的6个默认成员函数 如果一个类中什么成员都没有,简称为空类(空类大小为1)。 空类中真的什么都没有吗?并不是,任何类在什么都不写时,编译器会自动生成以下6个默认成员函数。 默认成员函数:…

DOM树(下) -- 第八课

文章目录 前言一、DOM属性操作1. 获取属性值2. 设置属性值3. 移除属性值 二、节点1.什么是节点?2. 节点层级1. 获取父级节点2. 获取兄弟节点3. 获取子节点 3. 节点操作1. 创建节点2. 添加和删除节点 三、事件进阶1. 注册事件1. 传统方式2. 监听方式 2. 删除事件3. 事件流 四、…

第4篇:MSSQL日志分析----应急响应之日志分析篇

常见的数据库攻击包括弱口令、SQL注入、提升权限、窃取备份等。对数据库日志进行分析,可以发现攻击行为,进一步还原攻击场景及追溯攻击源。 0x01 MSSQL日志分析 首先,MSSQL数据库应启用日志记录功能,默认配置仅限失败的登录&…

Veritus netbackup 管理控制台无法连接:未知错误

节假日停电,netbackup服务器意外停机后重新开机,使用netbackup管理控制台无法连接,提示未知错误。 ssh连接到服务器,操作系统正常,那应该是应用有问题,先试一下重启服务器看看。重新正常关机,重…

【Ubuntu】使用阿里云apt源来更新apt源

1.前言 我在京东云买了一个云服务器,但是我第一次使用apt的时候,发现遇到了下面这些情况 后面听老师讲,还需要执行下面这个 但是我再次使用apt下载软件的时候,还是出现了下面这个情况 后面问了老师才知道是apt源的问题&#x…

解决Github打不开或速度慢的问题

一、原因 我们先分析一下Github在国内访问慢或有时候登陆不上去的问题原因:其实这都是因为我们访问github官网时是直接访问域名即github.com,那么中间有个域名通过DNS解析的过程,将域名解析为对应的ip地址,其实主要时间都是花在了…

【寻找one piece的算法之路】——双指针算法!他与她是否会相遇呢?

💐个人主页:初晴~ 📚相关专栏:寻找one piece的刷题之路 什么是双指针算法 双指针算法是一种常用的编程技巧,尤其在处理数组和字符串问题时非常有效。这种方法的核心思想是使用两个指针来遍历数据结构,这两…

学习记录:js算法(五十二):验证二叉搜索树

文章目录 验证二叉搜索树我的思路网上思路 总结 验证二叉搜索树 给你一个二叉树的根节点 root ,判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下: 节点的 左子树 只包含 小于 当前节点的数。 节点的 右子树 只包含 大于 当前节点的数。 所有…