中文完形填空

news2025/1/12 10:09:15

本文通过ChnSentiCorp数据集介绍了完型填空任务过程,主要使用预训练语言模型bert-base-chinese直接在测试集上进行测试,也简要介绍了模型训练流程,不过最后没有保存训练好的模型。

一.完形填空
完形填空应该大家都比较熟悉,就是把句子中的词挖掉,根据上下文推测挖掉的词是什么。

二.准备数据集
本文使用ChnSentiCorp数据集,不清楚的可以参考中文情感分类介绍。一些样例如下所示:

本文做法为将每句话截断为固定的30个词,同时将第15个词替换为[MASK],模型任务为根据上下文预测第15个词。

1.使用编码工具

def load_encode_tool(pretrained_model_name_or_path):
    token = BertTokenizer.from_pretrained(Path(f'{pretrained_model_name_or_path}'))
    return token
if __name__ == '__main__':
  # 测试编码工具
  pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'
  token = load_encode_tool(pretrained_model_name_or_path)
  print(token)

输出结果如下所示:

BertTokenizer(name_or_path='L:\20230713_HuggingFaceModel\bert-base-chinese', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

测试编码句子如下所示:

if __name__ == '__main__':
    # 测试编码工具
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'
    token = load_encode_tool(pretrained_model_name_or_path)
    # 测试编码句子
    out = token.batch_encode_plus(
        batch_text_or_text_pairs=[('不是一切大树,', '都被风暴折断。'),('不是一切种子,', '都找不到生根的土壤。')],
        truncation=True,
        padding='max_length',
        max_length=18,
        return_tensors='pt',
        return_length=True, # 返回长度
    )
    # 查看编码输出
    for k, v in out.items():
        print(k, v.shape)
    print(token.decode(out['input_ids'][0]))
    print(token.decode(out['input_ids'][1]))

结果输出如下所示:

input_ids torch.Size([2, 18])
token_type_ids torch.Size([2, 18])
length torch.Size([2])
attention_mask torch.Size([2, 18])
[CLS] 不 是 一 切 大 树 , [SEP] 都 被 风 暴 折 断 。 [SEP] [PAD]
[CLS] 不 是 一 切 种 子 , [SEP] 都 找 不 到 生 根 的 土 [SEP]

第1个句子长度为17,补了1个[PAD],第2个句子长度为18。return_length=True表示返回句子真实长度,即不包括[PAD]填充部分长度。如下所示:

编码结果如下所示:

2.定义数据集

def load_dataset_from_disk():
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\ChnSentiCorp'
    dataset = load_from_disk(pretrained_model_name_or_path)
    # batched=True表示批量处理
    # batch_size=1000表示每次处理1000个样本
    # num_proc=8表示使用8个线程操作
    # remove_columns=['text']表示移除text列
    dataset = dataset.map(f1, batched=True, batch_size=1000, num_proc=8, remove_columns=['text', 'label'])
    return dataset
if __name__ == '__main__':
    # 加载数据集
    dataset = load_dataset_from_disk()
    print(dataset)

结果输出如下所示:

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'length'],
        num_rows: 9600
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'length'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'length'],
        num_rows: 1200
    })
})

3.定义计算设备

# 定义计算设备
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
# print(device)

4.定义数据整理函数
本质是将每个句子第15个词替换为[MASK],同时将第15个词作为标签,即根据上下文要预测的词。如下所示:

# 数据整理函数
def collate_fn(data):
    # 取出编码结果
    input_ids = [i['input_ids'] for i in data]
    attention_mask = [i['attention_mask'] for i in data]
    token_type_ids = [i['token_type_ids'] for i in data]
    # 转换为Tensor格式
    input_ids = torch.LongTensor(input_ids)
    attention_mask = torch.LongTensor(attention_mask)
    token_type_ids = torch.LongTensor(token_type_ids)
    # 把第15个词替换为MASK
    labels = input_ids[:, 15].reshape(-1).clone()
    input_ids[:, 15] = token.get_vocab()[token.mask_token]
    # 移动到计算设备
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    labels = labels.to(device)
    return input_ids, attention_mask, token_type_ids, labels

5.定义数据集加载器

# 数据集加载器
loader = torch.utils.data.DataLoader(dataset=dataset['train'], batch_size=16, collate_fn=collate_fn, shuffle=True, drop_last=True)
print(len(loader)) #600=9600/16

查看样例数据如下所示:

# 查看数据样例
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break
print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)

输出结果如下所示:

torch.Size([16, 30])
torch.Size([16, 30])
torch.Size([16, 30])
tensor([4638, 8024, 3198, 6206, 6392, 4761, 3449, 2128, 3341,  119, 3315, 2697,
        2523, 2769, 6814, 1086], device='cuda:0')

三.定义模型
1.加载预训练模型
加载模型并移动到device(CPU或GPU)中,如下所示:

pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'
# 加载预训练模型
pretrained = BertModel.from_pretrained(Path(f'{pretrained_model_name_or_path}'))
pretrained.to(device) 

2.定义下游任务模型
下游任务模型将BERT提取第15个词的特征(16×768),输入到全连接神经网络(768×21128),得到16×21128,即把第15个词的特征投影到全体词表空间中,还原为词典中的某个词。

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = torch.nn.Linear(in_features=768, out_features=token.vocab_size, bias=False)
        # 重新将decode中的bias参数初始化为全o
        self.bias = torch.nn.Parameter(data=torch.zeros(token.vocab_size))
        self.decoder.bias = self.bias
        # 定义 Dropout层,防止过拟合
        self.Dropout = torch.nn.Dropout(p=0.5)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # 使用预训练模型抽取数据特征
        with torch.no_grad():
            out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # 把第15个词的特征投影到全字典范围内
        out = self.Dropout(out.last_hidden_state[:, 15])
        out = self.decoder(out)
        return out

四.训练和测试
1.训练
定义了AdamW优化器、loss损失函数(交叉熵损失函数)和线性学习率调节器,如下所示:

def train():
    # 定义优化器
    optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1.0)
    # 定义1oss函数
    criterion = torch.nn.CrossEntropyLoss()
    # 定义学习率调节器
    scheduler = get_scheduler(name='linear', num_warmup_steps=0, num_training_steps=len(loader) * 5, optimizer=optimizer)
    # 将模型切换到训练模式
    model.train()
    # 共训练5个epoch
    for epoch in range(5):
        # 按批次遍历训练集中的数据
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
            # 模型计算
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            # 计算loss并使用梯度下降法优化模型参数
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            # 输出各项数据的情况,便于观察
            if i % 50 == 0:
                out = out.argmax(dim=1)
                accuracy = (out == labels).sum().item() / len(labels)
                lr = optimizer.state_dict()['param_groups'][0]['lr']
                print(epoch, 1, loss.item(), lr, accuracy)

输出部分结果如下所示:

0 1 10.123428344726562 0.0004998333333333334 0.0
0 1 8.659417152404785 0.0004915 0.0625
0 1 7.431852340698242 0.0004831666666666667 0.0625
0 1 7.261701583862305 0.00047483333333333335 0.0625
0 1 6.693362236022949 0.0004665 0.125
0 1 4.0811614990234375 0.00045816666666666667 0.375
0 1 7.034963607788086 0.00044983333333333334 0.1875

2.测试
使用测试数据集进行测试,如下所示:

def test():
    # 定义测试数据集加载器
    loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],  batch_size=32, collate_fn=collate_fn, shuffle=True, drop_last=True)
    # 将下游任务模型切换到运行模式
    model.eval()
    correct = 0
    total = 0
    # 按批次遍历测试集中的数据
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        # 计算15个批次即可,不需要全部遍历
        if i == 15:
            break
        print(i)
        # 计算
        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # 统计正确率
        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)
    print(correct / total)

参考文献:
[1]HuggingFace自然语言处理详解:基于BERT中文模型的任务实战
[2]https://github.com/ai408/nlp-engineering/blob/main/20230625_HuggingFace自然语言处理详解/第8章:完形填空.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/961559.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文件上传漏洞-upload靶场5-12关

文件上传漏洞-upload靶场5-12关通关笔记(windows环境漏洞) 简介 ​ 在前两篇文章中,已经说了分析上传漏的思路,在本篇文章中,将带领大家熟悉winodws系统存在的一些上传漏洞。 upload 第五关 (大小写绕过…

如何用VMware虚拟机连上Xshell

目录 前言废话1.1设置虚拟机设置1.2 设置虚拟网络编辑器方法一:方法二: 1.3 配置静态IP地址1.4 Xshell连接虚拟机2.1 解决可能出现的一些问题2.1.1 虚拟机Ping不通网络2.1.2 我可以Ping通百度了,但是宿主机和虚拟机互相Ping不通。2.1.3 更离谱…

【8 排序】简单选择排序。

顺序表&#xff1a; void Swap(int &a,int &b){int temp;tempa;ab;btemp; } void SelectSort(int A[],int n){int min,i,j;for(i0;i<n-1;i){mini;for(ji1;j<n;j)if(A[j]<A[min])minj;if(min!i)Swap(A[i],A[min]);} } 单链表&#xff1a; void SelectSort…

【leetcode 力扣刷题】数学题之除法:哈希表解决商的循环节➕快速乘求解商

两道和除法相关的力扣题目 166. 分数到小数29. 两数相除快速乘解法一&#xff1a;快速乘变种解法二&#xff1a; 二分查找 快速乘 166. 分数到小数 题目链接&#xff1a;166. 分数到小数 题目内容&#xff1a; 题目是要求我们把一个分数变成一个小数&#xff0c;并以字符串的…

go锁-waitgroup

如果被等待的协程没了&#xff0c;直接返回 否则&#xff0c;waiter加一&#xff0c;陷入sema add counter 被等待协程没做完&#xff0c;或者没人在等待&#xff0c;返回 被等待协程都做完&#xff0c;且有人在等待&#xff0c;唤醒所有sema中的协程 WaitGroup实现了一组协程…

【MySQL】基础语法总结

MySQL 基础语句 一、DDL 数据库定义语言 1.1CREATE 创建 1.1.1 创建数据库 语法结构 CREATE DATABASE database_name;示例 CREATE DATABASE demo;1.1.2 创建表 语法结构 CREATE TABLE 表名 (列1 数据类型,列2 数据类型,... );示例 CREATE TABLE new_user (id INT PRIMARY KE…

python爬虫数据解析xpath

一、环境配置 1、安装xpath 下载地址&#xff1a;百度网盘 请输入提取码 第一步&#xff1a; 下载好文件后会得到一个没有扩展名的文件&#xff0c;重命名该文件将其改为.rar或者.zip等压缩文件&#xff0c;解压之后会得到一个.crx文件和一个.pem文件。新建一个文件夹&…

AI工人操作行为流程规范识别算法

AI工人操作行为流程规范识别算法通过yolov7python网络模型框架&#xff0c;AI工人操作行为流程规范识别算法对作业人员的操作行为进行实时分析&#xff0c;根据设定算法规则判断操作行为是否符合作业标准规定的SOP流程。Yolo意思是You Only Look Once&#xff0c;它并没有真正的…

怎样免费在公司访问家中的树莓派

最近拿起了大学时买的树莓派&#xff0c;刚好看到了一篇文章写到无公网IP&#xff0c;从公网SSH远程访问家中的树莓派 便来试试&#xff1a; 我的树莓派之前装过ssh&#xff0c;所以插上电就能用了。其实过程很简单&#xff0c;只需要在树莓派中下载一个cpolar即可。 curl -…

CSS3常用的新功能总结

CSS3常用的新功能包括圆角、阴渐变、2D变换、3D旋转、动画、viewpor和媒体查询。 圆角、阴影 border-redius 对一个元素实现圆角效果&#xff0c;是通过border-redius完成的。属性为两种方式&#xff1a; 一个属性值&#xff0c;表示设置所有四个角的半径为相同值&#xff…

UE5 实现Niagara粒子特效拖尾效果

文章目录 前言实现效果闪现示例疾跑示例实现新建Niagara系统应用Niagara系统实现拖尾效果应用拖尾颜色前言 本文采用虚幻5.2.1版本,对角色粒子特效拖尾效果进行讲解,从零开始,来实现此效果。此效果可以在角色使用某一技能时触发,比如使用闪现、疾跑等等。 实现效果 闪现示…

深入剖析 Golang 程序启动原理 - 从 ELF 入口点到GMP初始化到执行 main!

大家好&#xff0c;我是飞哥&#xff01; 在过去的开发工作中&#xff0c;大家都是通过创建进程或者线程来工作的。Linux进程是如何创建出来的&#xff1f; 、聊聊Linux中线程和进程的联系与区别&#xff01; 和你的新进程是如何被内核调度执行到的&#xff1f; 这几篇文章就是…

每日一题(链表中倒数第k个节点)

每日一题&#xff08;链表中倒数第k个节点&#xff09; 链表中倒数第k个结点_牛客网 (nowcoder.com) 思路: 如下图所示&#xff1a;此题仍然定义两个指针&#xff0c;fast指针和slow指针&#xff0c;假设链表的长度是5&#xff0c;k是3&#xff0c;那么倒数第3个节点就是值为…

解决WebSocket通信:前端拿不到最后一条数据的问题

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

最新智能AI系统ChatGPT网站程序源码+详细图文搭建教程/支持GPT4/WEB-H5端+微信公众号版源码

一、AI系统 如何搭建部署AI创作ChatGPT系统呢&#xff1f;小编这里写一个详细图文教程吧&#xff01;SparkAi使用Nestjs和Vue3框架技术&#xff0c;持续集成AI能力到AIGC系统&#xff01; 1.1 程序核心功能 程序已支持ChatGPT3.5/GPT-4提问、AI绘画、Midjourney绘画&#xf…

MySQL高阶语句(三)

一、NULL值 在 SQL 语句使用过程中&#xff0c;经常会碰到 NULL 这几个字符。通常使用 NULL 来表示缺失 的值&#xff0c;也就是在表中该字段是没有值的。如果在创建表时&#xff0c;限制某些字段不为空&#xff0c;则可以使用 NOT NULL 关键字&#xff0c;不使用则默认可以为空…

Vue中过滤器如何使用?

过滤器是对即将显示的数据做进⼀步的筛选处理&#xff0c;然后进⾏显示&#xff0c;值得注意的是过滤器并没有改变原来 的数据&#xff0c;只是在原数据的基础上产⽣新的数据。过滤器分全局过滤器和本地过滤器&#xff08;局部过滤器&#xff09;。 目录 全局过滤器 本地过滤器…

Python之父加入微软三年后,Python嵌入Excel!

近日&#xff0c;微软传发布消息&#xff0c;Python被嵌入Excel&#xff0c;从此Excel里可以平民化地进行机器学习了。只要直接在单元格里输入“PY”&#xff0c;回车&#xff0c;调出Python&#xff0c;马上可以轻松实现数据清理、预测分析、可视化等等等等任务&#xff0c;甚…

好马配好鞍:Linux Kernel 4.12 正式发布

Linus Torvalds 在内核邮件列表上宣布释出 Linux 4.12&#xff0c;Linux 4.12 的主要特性包括&#xff1a; BFQ 和 Kyber block I/O 调度器&#xff0c;livepatch 改用混合一致性模型&#xff0c;信任的执行环境框架&#xff0c;epoll 加入 busy poll 支持等等&#xff0c;其它…

从零开始,探索C语言中的字符串

字符串 1. 前言2. 预备知识2.1 字符2.2 字符数组 3. 什么是字符串4. \04.1 \0是什么4.2 \0的作用4.2.1 打印字符串4.2.2 求字符串长度 1. 前言 大家好&#xff0c;我是努力学习游泳的鱼。你已经学会了如何使用变量和常量&#xff0c;也知道了字符的概念。但是你可能还不了解由…