中文句子关系推断

news2025/1/16 12:44:35

本文通过ChnSentiCorp数据集介绍了中文句子关系推断任务过程,主要使用预训练语言模型bert-base-chinese直接在测试集上进行测试,也简要介绍了模型训练流程,不过最后没有保存训练好的模型。

一.任务简介和数据集
通过模型来判断2个句子是否连续,使用ChnSentiCorp数据集,不清楚的可以参考中文情感分类介绍。句子关系推断数据集样例如下所示:

二.准备数据集
1.使用编码工具
编码工具详细介绍可参考使用编码工具。如下所示:

def load_encode_tool(pretrained_model_name_or_path):
    token = BertTokenizer.from_pretrained(Path(f'{pretrained_model_name_or_path}'))
    return token
if __name__ == '__main__':
    # 测试编码工具
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'
    token = load_encode_tool(pretrained_model_name_or_path)
    print(token)
    # 测试编码句子
    out = token.batch_encode_plus(
        batch_text_or_text_pairs=[('不是一切大树,', '都被风暴折断。'),('不是一切种子,', '都找不到生根的土壤。')],
        truncation=True,
        padding='max_length',
        max_length=18,
        return_tensors='pt',
        return_length=True, # 返回长度
    )
    # 查看编码输出
    for k, v in out.items():
        print(k, v.shape)
    print(token.decode(out['input_ids'][0]))
    print(token.decode(out['input_ids'][1]))

输出结果如下所示:

BertTokenizer(name_or_path='L:\20230713_HuggingFaceModel\bert-base-chinese', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)
input_ids torch.Size([2, 18])
token_type_ids torch.Size([2, 18])
length torch.Size([2])
attention_mask torch.Size([2, 18])
[CLS] 不 是 一 切 大 树 , [SEP] 都 被 风 暴 折 断 。 [SEP] [PAD]
[CLS] 不 是 一 切 种 子 , [SEP] 都 找 不 到 生 根 的 土 [SEP]

编码结果如下所示:

2.定义数据集
首先在__init__()中加载ChnSentiCorp数据集,然后过滤掉小于40个字的句子。在__getitem__()中将一句话分隔为各20个字的两句话,并且有50%概率把后半句替换为无关的句子,这样就构建完成本文所需要数据集。

class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\ChnSentiCorp'
        dataset = load_from_disk(pretrained_model_name_or_path)[split]
        # 过滤长度大于40的句子
        self.dataset = dataset.filter(lambda data: len(data['text']) > 40)
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, i):
        text = self.dataset[i]['text']
        # 将一句话切分为前半句和后半句
        sentence1 = text[:20]
        sentence2 = text[20:40]
        # 随机整数,取值为0和1
        label = random.randint(0, 1)
        # 有一半概率把后半句替换为无关的句子
        if label == 1:
            j = random.randint(0, len(self.dataset) - 1) # 随机取出一句话
            sentence2 = self.dataset[j]['text'][20:40] # 取出后半句
        return sentence1, sentence2, label # 返回前半句、后半句和标签
if __name__ == '__main__':
    # 加载数据集
    dataset = Dataset('train')
    sentence1, sentence2, label = dataset[7]
    print(len(dataset), sentence1, sentence2, label)

输出结果如下所示:

8001
地理位置佳,在市中心。酒店服务好、早餐品
种丰富。我住的商务数码房电脑宽带速度满意
0

其中,8001表示训练数据集,每条训练数据包括两句话和一个标签,标签表示这两句话是否有关。

3.定义计算设备

# 定义计算设备
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

4.定义数据整理函数
collate_fn(data)函数中的data表示一个batch的数据,每条记录包括句子对和标签,该函数功能为编码一个batch文本数据。

def collate_fn(data):
    sents = [i[:2] for i in data]
    labels = [i[2] for i in data]
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents, # 输入句子对
                                   truncation=True, # 截断
                                   padding='max_length', # [PAD]
                                   max_length=45, # 最大长度
                                   return_tensors='pt', # 返回pytorch张量
                                   return_length=True, # 返回长度
                                   add_special_tokens=True) # 添加特殊符号
    # input_ids:编码之后的数字
    # attention_mask:补零的位置是0, 其他位置是1
    # token_type_ids:第1个句子和特殊符号的位置是0, 第2个句子的位置是1
    input_ids = data['input_ids'].to(device)
    attention_mask = data['attention_mask'].to(device)
    token_type_ids = data['token_type_ids'].to(device)
    labels = torch.LongTensor(labels).to(device)
    return input_ids, attention_mask, token_type_ids, labels

输入参数data数据样例如下所示:

data = [('酒店还是非常的不错,我预定的是套间,服务', '非常好,随叫随到,结账非常快。',
0),
('外观很漂亮,性价比感觉还不错,功能简', '单,适合出差携带。蓝牙摄像头都有了。',
0),
('《穆斯林的葬礼》我已闻名很久,只是一直没', '怎能享受4星的服务,连空调都不能
用的。', 1)]

5.定义数据集加载器

# 数据集加载器
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=8, collate_fn=collate_fn, shuffle=True, drop_last=True)
# print(len(loader))

三.定义模型
1.加载预训练模型

pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\bert-base-chinese'
pretrained = BertModel.from_pretrained(Path(f'{pretrained_model_name_or_path}'))
pretrained.to(device)

2.定义下游任务模型
下游任务模型为线性神经网络,权重矩阵为768×2,即把一个768维度的向量转换到2维空间(相关或无关)。其中,out.last_hidden_state[:, 0, :]表示只使用了第1个词([CLS])的特征用于相关或无关的判断。如下所示:

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # 使用预训练模型抽取数据特征
        with torch.no_grad():
            out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # 对抽取的特征只取第1个字的结果进行分类即可
        out = self.fc(out.last_hidden_state[:, 0, :])
        out = out.softmax(dim=1)
        return out #torch.Size([16, 2])

四.训练和测试
1.训练

def train():
    # 定义优化器
    optimizer = AdamW(model.parameters(), lr=5e-5)
    # 定义1oss函数
    criterion = torch.nn.CrossEntropyLoss()
    # 定义学习率调节器
    scheduler = get_scheduler(name='linear', num_warmup_steps=0, num_training_steps=len(loader), optimizer=optimizer)
    # 将模型切换到训练模式
    model.train()
    # 按批次遍历训练集中的数据
    for epoch in range(5):
        # 按批次遍历训练集中的数据
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
            # 模型计算
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            # 计算loss并使用梯度下降法优化模型参数
            loss = criterion(out, labels)
            loss.backward() # 反向传播
            optimizer.step() # 梯度下降法优化模型参数
            scheduler.step() # 学习率调节器
            optimizer.zero_grad() # 清空梯度
            # 输出各项数据的情况,便于观察
            if i % 20 == 0:
                out = out.argmax(dim=1) # 取出最大值的索引
                accuracy = (out == labels).sum().item() / len(labels) # 计算准确率
                lr = optimizer.state_dict()['param_groups'][0]['lr'] # 获取当前学习率
                print(epoch, 1, loss.item(), lr, accuracy)

2.测试

def test():
    # 定义测试数据集加载器
    dataset = Dataset('test')
    loader_test = torch.utils.data.DataLoader(dataset=dataset,  batch_size=32, collate_fn=collate_fn, shuffle=True, drop_last=True)
    # 将下游任务模型切换到运行模式
    model.eval()
    correct = 0
    total = 0
    # 按批次遍历测试集中的数据
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        # 计算5个批次即可,不需要全部遍历
        if i == 5:
            break
        print(i)
        # 计算
        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # 统计正确率
        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)
    print(correct / total)

参考文献:
[1]HuggingFace自然语言处理详解:基于BERT中文模型的任务实战
[2]代码链接:https://github.com/ai408/nlp-engineering/blob/main/20230625_HuggingFace自然语言处理详解/第9章:中文句子关系推断.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/960611.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Gazebo仿真环境下的强化学习实现

Gazebo仿真环境下的强化学习实现 主体源码参照《Goal-Driven Autonomous Exploration Through Deep Reinforcement Learning》 文章目录 Gazebo仿真环境下的强化学习实现1. 源码拉取2. 强化学习实现2.1 环境2.2 动作空间2.3 状态空间2.4 奖励空间2.5 TD3训练 3. 总结 1. 源码…

【LeetCode算法系列题解】第41~45题

CONTENTS LeetCode 41. 缺失的第一个正数(困难)LeetCode 42. 接雨水(困难)LeetCode 43. 字符串相乘(中等)LeetCode 44. 通配符匹配(困难)LeetCode 45. 跳跃游戏 II(中等&…

字符和字符串的库函数模拟与实现

前言: 相信大家平常在写代码的时候,用代码解决实际问题时苦于某种功能的实现,而望而止步,这个时候库函数的好处就体现出来了,当然个人代码编写能力强的可以自己创建一个函数,不过相当于库函数来说却是浪费了…

SSM 基于注解的整合实现

一、pom.xml <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"><modelV…

C++的基类和派生类构造函数

基类的成员函数可以被继承&#xff0c;可以通过派生类的对象访问&#xff0c;但这仅仅指的是普通的成员函数&#xff0c;类的构造函数不能被继承。构造函数不能被继承是有道理的&#xff0c;因为即使继承了&#xff0c;它的名字和派生类的名字也不一样&#xff0c;不能成为派生…

CSC2121A

半桥架构的栅极驱动电路CSC2121A CSC2121系列是一款高性价比的半桥架构的栅极驱动专用电路&#xff0c;用于大功率MOS管、IGBT管栅极驱动。IC内部集成了逻辑信号处理电路、死区时间控制电路、欠压保护电路、电平位移电路、脉冲滤波电路及输出驱动电路&#xff0c;专用于无刷电…

ClickHouse配置Hdfs存储数据

文章目录 背景配置单机配置HA高可用Hdfs集群参考文档 背景 由于公司初始使用Hadoop这一套&#xff0c;所以希望ClickHouse也能使用Hdfs作为存储 看了下ClickHouse的文档&#xff0c;拿Hdfs举例来说&#xff0c;有两种方式来完成&#xff0c;一种是直接关联Hdfs上的数据文件&am…

【python爬虫案例】用python爬豆瓣音乐TOP250排行榜!

文章目录 一、爬虫对象-豆瓣音乐TOP250二、python爬虫代码讲解三、同步视频四、获取完整源码 一、爬虫对象-豆瓣音乐TOP250 您好&#xff0c;我是 马哥python说 &#xff0c;一名10年程序猿。 今天我们分享一期python爬虫案例讲解。爬取对象是&#xff0c;豆瓣音乐TOP250排行…

汇报下卢松松自媒体收入情况

我是卢松松&#xff0c;点点上面的头像&#xff0c;欢迎关注我哦&#xff01; 9月1日了&#xff0c;2023年还剩最后4个月了&#xff0c;怎么样?梦想是不是越来越远了? 我看很多人都在涌入做自媒体行业&#xff0c;那今天卢松松就给大家汇报下近期卢松松自媒体帐号的收益情况…

2023开学礼《乡村振兴战略下传统村落文化旅游设计》南京农大许少辉八一新书

2023开学礼《乡村振兴战略下传统村落文化旅游设计》南京农大许少辉八一新书

MVC、MVP、MVVM的成本角度结合业务,如何考虑选型?一文了解方方面面

大家都知道&#xff0c;使用架构的目的是使程序模块化&#xff0c;做到模块内部的高聚合和模块之间的低耦合&#xff0c;使得程序在开发的过程中&#xff0c;开发人员只需要专注于一点&#xff0c;提高程序开发的效率。那么MVC、MVP、MVVM&#xff0c;该怎么选&#xff1f;在什…

百万级并发IM即时消息系统(2)

1.用户model type UserBasic struct {gorm.ModelName stringPassWord stringPhone string valid:"matches(^1[3-9]{1}\\d{9}$)"Email string valid:"email"Avatar string //头像Identity stringClientIp s…

Java注解和反射

注解(Java.Annotation) 什么是注解&#xff08;Annotation&#xff09;&#xff1f; Annotation是从JDK5.0开始引入的新技术 Annotation的作用: 不是程序本身&#xff0c;可以对程序作出解释(这一点和注释(comment)没什么区别)可以被其他程序(比如:编译器等)读取Annotation的…

解决uniapp下拉框 内容被覆盖的问题

1. 下拉框 内容被覆盖的问题 场景: 现在是下拉框被表格覆盖了 解决办法: 在表格上添加css 样式来解决这个问题 .add-table{display: static;overflow: visible; } display: static: 将元素会按照默认的布局方式进行显示&#xff0c;不会分为块状或行内元素。 overflow: vi…

【力扣每日一题】2023.9.1 买钢笔和铅笔的方案数

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目给我们三个数&#xff0c;一个是我们拥有的钱&#xff0c;一个是钢笔的价格&#xff0c;另一个是铅笔的价格。 问我们一共有几种买笔…

需求管理系统大盘点,谁能问鼎行业王者宝座?

“需求管理系统有哪些&#xff1f;这些品牌引领行业新风潮&#xff1a;Zoho Projects、SAP SuccessFactors、Oracle NetSuite、Microsoft Dynamics 365、Infor CloudSuite、JDA Software。” 需求管理系统是一种专门用于收集、分析和跟踪客户需求的工具&#xff0c;可以帮助企业…

9.1 消息 字体 颜色 文件对话框 发布软件

保存 void Widget::on_savebtn_clicked() {QString filename QFileDialog::getSaveFileName(this, "保存", "C:/Users/yc/Desktop/", "图片 (*.png *.xpm *.jpg);;文本 (*.txt);;所有文件 (*.*)");if(filename.isNull()){QMessageBox::informa…

砍价活动制作秘籍,打造抢购风潮

砍价活动作为一种吸引用户参与、提高销售量的营销手段&#xff0c;已经成为了电商行业的热门选择。在如今竞争激烈的市场环境下&#xff0c;如何制作出成功的砍价活动&#xff0c;成为了每位电商从业者亟需解决的问题。在本文中&#xff0c;我们将为大家揭秘一种制作成功砍价活…

十分钟学会三种管理交换机的方法

一、eNSP脚本配置 拓扑 cloud配置 配置两个ip地址&#xff0c;可以相同 测试通信 二、Console 登录方式 Console 就是用串口线连上去直接可以访问 比如 Please Press ENTER.<Huawei>sy Enter system view, return user view with CtrlZ. [Huawei]dis cu # sysname Hu…

CSC7203S 应用注意事项

CSC7203S 为高性能电流模式 PWM 开关电源功率转换器&#xff0c;满足绿色环保标准&#xff1b;广泛适用于经济型开关电源&#xff0c;如 DVD、机顶盒、传真机、打印机、LCD 显示器等。CSC7203S采用SOP-8封装。  内置 700V 高压功率开关管  输入电压&#xff08;85V~265V&a…