如何使用uer做多分类任务

news2024/11/16 4:53:45

如何使用uer做多分类任务

语料集下载
在这里插入图片描述
找到这里点击即可
里面是这有json文件的
在这里插入图片描述
因此我们对此要做一些处理,将其转为tsv格式

# -*- coding: utf-8 -*-
import json
import csv
import chardet

# 检测文件编码
def detect_encoding(file_path):
    with open(file_path, 'rb') as f:
        raw_data = f.read()
    return chardet.detect(raw_data)['encoding']

# 输入文件名
input_file = './datasets/iflytek/train.json'
# 输出文件名
output_file = './datasets/iflytek/train.tsv'

# 检测输入文件的编码格式
file_encoding = detect_encoding(input_file)

# 打开输入的 JSON 文件和输出的 TSV 文件
with open(input_file, 'r', encoding=file_encoding) as json_file, open(output_file, 'w', newline='', encoding='utf-8') as tsv_file:
    # 准备 TSV 写入器
    tsv_writer = csv.writer(tsv_file, delimiter='\t')

    # 写入表头(列表['label', 'label_des', 'sentence']中要注意根据json文件中的键值做更换)
    tsv_writer.writerow(['label', 'label_des', 'sentence'])

    # 逐行读取 JSON 文件
    for line in json_file:
        try:
            # 解析每一行的 JSON 数据
            json_data = json.loads(line.strip())
            # 写入到 TSV 文件中,(列表['label', 'label_des', 'sentence']中要注意根据json文件中的键值做更换)
            tsv_writer.writerow([json_data['label'], json_data['label_des'], json_data['sentence']])
        except json.JSONDecodeError as e:
            print(f"无法解析的行: {line.strip()}")
            print(f"错误信息: {e}")

print(f"JSON 文件已成功转换为 TSV 文件,输入文件编码: {file_encoding}")

接着呢要把所有tsv文件的sentence表头名改成text_a,不然运行uer框架会报错,原因请看源代码逻辑

def read_dataset(args, path):
    dataset, columns = [], {}
    with open(path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id == 0:
                for i, column_name in enumerate(line.rstrip("\r\n").split("\t")):
                    columns[column_name] = i
                continue
            line = line.rstrip("\r\n").split("\t")
            tgt = int(line[columns["label"]])
            if args.soft_targets and "logits" in columns.keys():
                soft_tgt = [float(value) for value in line[columns["logits"]].split(" ")]
            if "text_b" not in columns:  # Sentence classification.
                text_a = line[columns["text_a"]]
                src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN])
                seg = [1] * len(src)
            else:  # Sentence-pair classification.
                text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]
                src_a = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN])
                src_b = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(text_b) + [SEP_TOKEN])
                src = src_a + src_b
                seg = [1] * len(src_a) + [2] * len(src_b)

            if len(src) > args.seq_length:
                src = src[: args.seq_length]
                seg = seg[: args.seq_length]
            if len(src) < args.seq_length:
                PAD_ID = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]
                src += [PAD_ID] * (args.seq_length - len(src))
                seg += [0] * (args.seq_length - len(seg))
            if args.soft_targets and "logits" in columns.keys():
                dataset.append((src, tgt, seg, soft_tgt))
            else:
                dataset.append((src, tgt, seg))

    return dataset

这里规定好了表头名只有label,text_a,text_b
搞完之后进入训练代码,我的显存只有16G,因此

python finetune/run_classifier.py --pretrained_model_path models/cluecorpussmall_roberta_wwm_large_seq512_model.bin --vocab_path models/google_zh_vocab.txt --config_path models/bert/large_config.json --train_path datasets/iflytek/train.tsv --dev_path datasets/iflytek/dev.tsv --output_model_path models/iflytek_classifier_model.bin --epochs_num 3 --batch_size 16 --seq_length 128

在这里插入图片描述
在这里插入图片描述
这里可以看到只有61.49的正确率,其实是因为显存还不够,训练不了那么大的,标准的参数应该设置为batch_size=32 seq_length=256
有能力的可以更改参数进行训练
接着来预测

python inference/run_classifier_infer.py --load_model_path models/iflytek_classifier_model.bin --vocab_path models/google_zh_vocab.txt --config_path models/bert/large_config.json --test_path datasets/iflytek/test.tsv --prediction_path datasets/iflytek/prediction.tsv --seq_length 256 --labels_num 119

在这里插入图片描述
最后自行查看预测效果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1901131.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OceanBase Meetup北京站|跨行业应用场景中的一体化分布式数据库:AI赋能下的探索与实践

随着业务规模的不断扩张和数据处理需求的日益复杂化&#xff0c;传统数据库架构逐渐暴露出业务稳定性波动、扩展性受限、处理效率降低以及运营成本高等一系列问题。众多行业及其业务场景纷纷踏上了数据库现代化升级之路。 为应对这些挑战&#xff0c;7月6日&#xff0c;OceanB…

在Linux环境下搭建Redis服务结合内网穿透实现通过GUI工具远程管理数据库

文章目录 前言1. 安装Docker步骤2. 使用docker拉取redis镜像3. 启动redis容器4. 本地连接测试4.1 安装redis图形化界面工具4.2 使用RDM连接测试 5. 公网远程访问本地redis5.1 内网穿透工具安装5.2 创建远程连接公网地址5.3 使用固定TCP地址远程访问 前言 本文主要介绍如何在Li…

1.Python学习笔记

一、环境配置 1.Python解释器 把程序员用编程语言编写的程序&#xff0c;翻译成计算机可以执行的机器语言 安装&#xff1a; 双击Python3.7.0-选择自定义安装【Customize installation】-勾选配置环境变量 如果没有勾选配置环境变量&#xff0c;输入python就会提示找不到命令…

第四届BPAA算法大赛成功举办!共研算法未来

大家好&#xff0c;我是herosunly。985院校硕士毕业&#xff0c;现担任算法研究员一职&#xff0c;热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名&#xff0c;CCF比赛第二名&#xff0c;科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的…

vue学习笔记(购物车小案例)

用一个简单的购物车demo来回顾一下其中需要注意的细节。 先看一下最终效果 功能&#xff1a; &#xff08;1&#xff09;全选按钮和下面的商品项的选中状态同步&#xff0c;当下面的商品全部选中时&#xff0c;全选勾选&#xff0c;反之&#xff0c;则不勾选。 &#xff08…

最短路:Dijkstra

原始模板&#xff1a; 时间复杂度O() 使用于图很满的情况 struct Node{int y,v;Node(int _y,int _v){y_y;v_v;} };vector<Node> edge[N1]; int n,m,dist[N1]; bool b[N1];int Dijistra(int s,int t){memset(b,false,sizeof(b));memset(dist,127,sizeof(dist));dist[s]…

Windows 11文件资源管理器选项卡的4个高级用法,肯定有你喜欢的

作为一个每天使用文件资源管理器来管理我的工作流程的人,选项卡帮助我为处于不同完成阶段的工作创建了不同的文件夹。以下是我使用选项卡提高工作效率的最佳技巧。 打开和关闭选项卡 假设你的计算机上安装了Windows 11的最新更新,请按Ctrl+E打开文件资源管理器。在我发现“…

软件工程学面向对象

一、面向对象方法学概述 传统的生命周期方法学在消除软件非结构化、促进软件开发工程化方面起了积极的作用&#xff0c;但仍有许多不足&#xff0c;存在的主要问题有&#xff1a;①生产率提高的幅度不能满足需要&#xff1b; ②软件重用程度很低&#xff1b; ③软件很难维护&a…

Canary,三种优雅姿势绕过

Canary&#xff08;金丝雀&#xff09;&#xff0c;栈溢出保护 canary保护是防止栈溢出的一种措施&#xff0c;其在调用函数时&#xff0c;在栈帧的上方放入一个随机值 &#xff0c;绕过canary时首先需要泄漏这个随机值&#xff0c;然后再钩爪ROP链时将其作为垃圾数据写入&…

Python数据分析案例50——基于EEMD-LSTM的石油价格预测

案例背景 很久没更新时间序列预测有关的东西了。 之前写了很多CNN-LSTM&#xff0c;GRU-attention&#xff0c;这种神经网络之内的不同模型的缝合&#xff0c;现在写一个模态分解算法和神经网络的缝合。 虽然eemd-lstm已经在学术界被做烂了&#xff0c;但是还是很多新手小白或…

Go 中的类型推断

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

HDF4文件转TIF格式

HDF4 HDF4&#xff08;Hierarchical Data Format version 4&#xff09;是一种用于存储和管理机器间数据的库和多功能文件格式。它是一种自描述的文件格式&#xff0c;用于存档和管理数据。 HDF4与HDF5是两种截然不同的技术&#xff0c;HDF5解决了HDF4的一些重要缺陷。因此&am…

【文献解析】Voxelmap——一种自适应体素地图

Efficient and Probabilistic Adaptive Voxel Mapping for Accurate Online LiDAR Odometry 论文地址&#xff1a;https://ieeexplore.ieee.org/stamp/stamp.jsp?tp&arnumber9813516 代码&#xff1a;GitHub - hku-mars/VoxelMap: [RA-L 2022] An efficient and probabili…

DatawhaleAI夏令营2024 Task2

#AI夏令营 #Datawhale #夏令营 赛题解析一、Baseline详解1.1 环境配置1.2 数据处理任务理解2.3 prompt设计2.4 数据抽取 二、完整代码总结 赛题解析 赛事背景 在数字化时代&#xff0c;企业积累了大量对话数据&#xff0c;这些数据不仅是交流记录&#xff0c;还隐藏着宝贵的信…

2000-2022年地级市数字经济指数(含控制变量)

2000-2022年地级市数字经济指数&#xff08;含控制变量&#xff09; 目录 数字经济对区域经济发展的影响实证研究 一、引言 二、文献综述 三、数据来源与变量说明 四、实证模型 五、程序代码与运行结果 数字经济对区域经济发展的影响实证研究 摘要&#xff1a; 本文旨在…

Git-Unity项目版本管理

目录 准备GitHub新建项目并添加ssh密钥Unity文件夹 本文记录如何用git对unity 项目进行版本管理&#xff0c;并可传至GitHub远端。 准备 名称版本windows11Unity2202.3.9.f1gitN.A.githubN.A. GitHub新建项目并添加ssh密钥 GitHub新建一个repositorywindows11 生成ssh-key&…

数字信号处理中的难点

数字信号处理中的难点可以归纳为多个方面&#xff0c;这些难点不仅体现在理论知识的理解和掌握上&#xff0c;还涉及到实际工程应用中的各种问题。以下是对这些难点的详细分析&#xff1a; 一、理论知识的难点 信号与系统的基本概念&#xff1a; 理解和区分连续时间信号与离…

数字时代如果你的企业还未上线B端系统助力则后果很严重

**数字时代如果你的企业还未上线B端系统助力则后果很严重** 数字化浪潮席卷全球&#xff0c;企业对于数字化转型的重视程度日益提高。B端系统&#xff0c;作为企业数字化转型的核心组成部分&#xff0c;其重要性不言而喻。如果你的企业还未上线B端系统助力&#xff0c;那么后果…

3-3 超参数

3-3 超参数 什么是超参数 超参数也是一种参数&#xff0c;它具有参数的特性&#xff0c;比如未知&#xff0c;也就是它不是一个已知常量。是一种手工可配置的设置&#xff0c;需要为它根据已有或现有的经验&#xff0c;指定“正确”的值&#xff0c;也就是人为为它设定一个值&…

美国国家航空航天局(NASA)的载人登月计划:阿耳忒弥斯计划

本文首发于公众号“AntDream”&#xff0c;欢迎微信搜索“AntDream”或扫描文章底部二维码关注&#xff0c;和我一起每天进步一点点 Artemis计划是美国国家航空航天局&#xff08;NASA&#xff09;主导的一项雄心勃勃的月球探索计划&#xff0c;旨在2020年代重新将人类送上月球…