Bert文本分类和命名实体的模型架构剖析

news2024/11/26 19:58:17

文章目录

    • 介绍
    • Bert模型架构
    • 损失计算方式
      • BertForSequenceClassification
      • BertForTokenClassification
    • Bert 输出结果剖析
      • 例子
    • 参考资料

介绍

文本分类:给一句文本分类;
实体识别:从一句文本中,识别出其中的实体;

做命名实体识别,有2种方式:

  1. 基于Bert-lstm-crf 的Token分类;
  2. 生成式的从序列到序列的文本生成方法。比如:T5、UIE、大模型等;

如果你想体验完整命名实体识别教程请浏览:Huggingface Token 分类官方教程:https://huggingface.co/learn/nlp-course/zh-CN/chapter7/2

若实体识别采取Token分类的做法:
 那么文本分类是给一整句话做分类,实体识别是给一整句话中的每个词做分类。从本质上看,两者都是分类任务;

import torch
from transformers import (
    AutoTokenizer, 
    AutoModel,
    AutoModelForSequenceClassification,
    BertForSequenceClassification,
    AutoModelForTokenClassification,
)

Bert模型架构

基本的Bert模型结构:

model_name = "bert-base-chinese"
bert = AutoModel.from_pretrained(model_name)
bert

Output:

...
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

文本分类模型:

seq_cls_model = AutoModelForSequenceClassification.from_pretrained(model_name)
seq_cls_model

Output:

...
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
)

实体识别 token 分类:

token_cls_model = AutoModelForTokenClassification.from_pretrained(model_name)
token_cls_model

Output:

...
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
)

经过观察AutoModelForSequenceClassificationAutoModelForTokenClassification的模型架构一模一样,即分类与实体识别的模型架构一模一样。两者都是在基础的Bert模型尾部添加 dropoutclassifier层。

Q:它们是一模一样Bert模型架构,为何能实现不同的任务?
A:因为它们选取Bert输出不同,损失值计算也不同。

损失计算方式

BertForSequenceClassification

from transformers import BertForSequenceClassification

按住 Ctrl + 鼠标左键,查看源码

forward 函数中可以查看到loss的计算方式。

outputs = self.bert(
    input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
    position_ids=position_ids,
    head_mask=head_mask,
    inputs_embeds=inputs_embeds,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

pooled_output = outputs[1]

pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)

BertForSequenceClassification 使用 pooled_output = outputs[1]

BertForTokenClassification

from transformers import BertForTokenClassification

forward 函数中可以查看到loss的计算方式。

...
outputs = self.bert(
    input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
    position_ids=position_ids,
    head_mask=head_mask,
    inputs_embeds=inputs_embeds,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

sequence_output = outputs[0]

sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)

BertForTokenClassification 使用 sequence_output = outputs[0]

Bert 输出结果剖析

下述是BertModel的输出结果,既可以使用字典访问,也可以通过下标访问:

BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

outputs[0] 是last_hidden_state, outputs[1]是 pooler_output。

last_hidden_state 是输入到Bert模型每一个token的状态,pooler_output[CLS]的last_hidden_state经过pooler处理得到的状态。

在这里插入图片描述

在图片上,用红色字标出了 last_hidden_state 和 pooler_output 在模型架构的位置。

例子

接下来使用一个例子帮助各位读者深入理解Bert输出结果中的last_hidden_statepooler_output的区别。

from transformers import (
    AutoTokenizer, 
    # BertModel,
    AutoModel,
    DataCollatorForTokenClassification
)
model_name = "bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(model_name)
seq_cls_model = AutoModel.from_pretrained(model_name)

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
batch = data_collator([
                        tokenizer("今天天气真好,咱们出去放风筝吧!"),
                        tokenizer("起风了,还是在家待着吧!"),
                        ])
for k, v in batch.items():
    print(k, v.shape)

Output:

input_ids torch.Size([2, 18])
token_type_ids torch.Size([2, 18])
attention_mask torch.Size([2, 18])

Bert 模型推理

output = bert(**batch)
print(
    torch.equal(output[0], output["last_hidden_state"]),
    torch.equal(output[1], output["pooler_output"])
)
last_hidden_state = output["last_hidden_state"]
pooler_output = output["pooler_output"]

Output:

True True
# output[0] == output["last_hidden_state"] 为真
# 这意味着Bert的输出,既可以用下标访问,也可以用字典的键访问
print(
    f"last_hidden_state.shape: {last_hidden_state.shape}",
    f"pooler_output.shape: {pooler_output.shape}"
)

Output:

last_hidden_state.shape: torch.Size([2, 18, 768]) pooler_output.shape: torch.Size([2, 768])

仅仅看它们的shape,也能看出它们的区别

last_hidden_state:包括每一个token的状态;(所以用来做命名实体识别)
pooler_output:只有[CLS]的状态;([CLS]的输出向量被认为是整个序列的聚合表示,故用于分类任务。)

# [CLS] 在第一个token的位置,通过下标获取 `[CLS]`的tensor,再经过pooler处理
# 判断其与output["pooler_output"]是否相等
CLS_tensor = last_hidden_state[:,0,:].reshape(2, 1, -1)
torch.equal(
    bert.pooler(CLS_tensor),
    pooler_output
)

Output:

True

输出为True,这验证了 [CLS]的tensor经过pooler层后,便是output[“pooler_output”]。

参考资料

  • Huggingface Token 分类官方教程:https://huggingface.co/learn/nlp-course/zh-CN/chapter7/2 若你想使用Bert做命名实体识别,非常推荐浏览这篇官方教程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1952436.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

万界星空科技灯具行业MES系统:点亮生产管理的未来

在快速迭代的灯具行业中,高效、精准的生产管理是企业保持竞争力的关键。万界星空科技推出的灯具行业MES(制造执行系统)系统,以其强大的功能和完善的管理体系,正成为众多灯具生产企业的首选解决方案。本文将重点介绍万界…

构建高并发Web服务:Gunicorn与Flask在Docker中的完美融合

1. 引言 在数字化时代,Web服务的性能和可靠性对于任何在线业务的成功至关重要。随着用户基数的增长和业务需求的扩展,高并发处理能力成为了衡量一个Web服务质量的关键指标。高并发Web服务不仅能够确保用户体验的流畅性,还能在流量激增时保持…

抖音矩阵管理系统开发:全面解析与推荐

在数字时代,短视频平台如抖音已经成为人们生活中不可或缺的一部分。随着内容创作者数量的激增,如何高效地管理多个抖音账号,实现内容矩阵化运营,成为了众多创作者关注的焦点。今天,我们就来全面解析抖音矩阵管理系统的…

Android 生成Excel并导出全流程

前言 最近接到需求,要在安卓上离线完成根据数据生成Excel文件,但搜到了都不是能立马使用 例如 // implementation org.apache.poi:poi:3.17 // implementation com.alibaba:easyexcel:4.0.1 这两最大的问题是专用于java的,如果And…

【SpringBoot】7 数据库(MySQLMyBatis)

MySQL 前提:本地有安装 MySQL 。 连接 使用工具 Navicat Premium ,或者 IDEA 自带的 DB 工具,或者其他能连接 MySQL 数据库的工具都可以。 1)创建 MySQL Data Source 2)根据本地配置连接上 MySQL,点击…

移动UI:排行榜单页面如何设计,从这五点入手,附示例。

移动UI的排行榜单页面设计需要考虑以下几个方面: 1. 页面布局: 排行榜单页面的布局应该清晰明了,可以采用列表的形式展示排行榜内容,同时考虑到移动设备的屏幕大小,应该设计合理的滚动和分页机制,确保用户…

Android 软键盘挡住输入框

Android原生输入法软键盘挡住输入框,网上各种解法,但不起效。 输入框都是被挡住了,第二张图的小点,实际就是输入法的光标。 解法: packages\inputmethods\LatinIME\java\res\values-land config.xml <!-- <fraction name="config_min_keyboard_height"&g…

2024年国际高校数学建模大赛(IMMCHE)问题A:金字塔石的运输成品文章分享(仅供学习)

2024 International Mathematics Molding Contest for Higher Education Problem A: Transportation of Pyramid Stones&#xff08;2024年国际高校数学建模大赛&#xff08;IMMCHE&#xff09;问题A&#xff1a;金字塔石的运输&#xff09; 古埃及金字塔石材运输优化模型研究…

【单片机毕业设计选题24084】-基于嵌入式的16位AD采集系统设计

系统功能: 系统上电后显示“欢迎使用数模转换系统请稍后”后两秒后进入正常显示。 第一行显示ADS1115第一通道采集到的电压值 第二行显示ADS1115第二通道采集到的电压值 第一行显示ADS1115第三通道采集到的电压值 第二行显示ADS1115第四通道采集到的电压值 手动调节四个电…

【产品应用】一体化伺服电机在AGV小车中的应用

随着自动化技术的快速发展&#xff0c;自动引导车&#xff08;AGV&#xff0c;Automated Guided Vehicle&#xff09;在物流、仓储和生产等领域的应用日益广泛。 作为智能物流体系中的重要设备&#xff0c;AGV小车通过先进的控制技术、传感器技术和导航系统&#xff0c;实现了…

潜水通信定位系统的功能概述_鼎跃安全

水域救援是一项极具挑战性的救援行动&#xff0c;其特点鲜明&#xff0c;集突发性、时间敏感性、技术精密性、难度系数高及潜在危险性之大成。这类救援任务往往要求在极短的时间内迅速响应&#xff0c;面对复杂多变的水域环境&#xff0c;救援人员必须具备高超的专业技能和冷静…

23万一张的天价卡牌,如何撑起一个港股IPO?

23万&#xff0c;可以买到什么&#xff1f; 是拿下一辆涨价后的宝马i3&#xff1f;还是在三线城市全款盘下一套房&#xff1f;又或是来一次环球旅行&#xff1f;这些都已经过时了&#xff0c;对于现在的年轻人来说&#xff0c;他们或许会选择拿这些钱去二手市场&#xff0c;收…

pycharm关闭项目时,页面卡住了,怎么办?

问题 在关闭pycharm时&#xff0c;有时会遇到卡在退出进度条的界面&#xff0c;很讨厌&#xff0c;那我们要怎么办才能退出呢&#xff1f; 说明&#xff1a;本篇文章不是从根源上解决这个问题&#xff0c;无法避免这种情况。 解决方法 方法一&#xff1a; 在卡住时&#xf…

Golang | Leetcode Golang题解之第287题寻找重复数

题目&#xff1a; 题解&#xff1a; func findDuplicate(nums []int) int {slow, fast : 0, 0for slow, fast nums[slow], nums[nums[fast]]; slow ! fast; slow, fast nums[slow], nums[nums[fast]] { }slow 0for slow ! fast {slow nums[slow]fast nums[fast]}return s…

STM32-寄存器时钟配置指南

目录 启动 SystemInit SetSysClock 总结 启动 从startup_stm32f0xx.s内的开头的Description可以看到 ;* Description : STM32F051 devices vector table for EWARM toolchain. ;* This module performs: ;* - Set the in…

【中项第三版】系统集成项目管理工程师 | 第 11 章 规划过程组⑤ | 11.13 - 11.14

前言 第11章对应的内容选择题和案例分析都会进行考查&#xff0c;这一章节属于10大管理的内容&#xff0c;学习要以教材为准。本章上午题分值预计在15分。 目录 11.13 制定预算 11.13.1 主要输入 11.13.2 主要输出 11.14 规划质量管理 11.14.1 主要输入 11.14.2 主要工…

MySQL查询优化 limit 100000,10加载很慢该怎么优化

需求&#xff1a;查询19年以后发布的商品 数据库表结构如下&#xff1a; 目前数据量&#xff1a;264751 优化前执行时间&#xff1a;0.790s 优化后执行时间&#xff1a;0.467s select id,no,title,cart_title,cid_name from tb_item where id > (select id from tb_item …

Redis 缓存

安装 安装 Redis 下载&#xff1a; Releases tporadowski/redis (github.com) winr ----services.msc-----将redis 设置为手动(只是学习&#xff0c;如果经常用可以设置为自动) 安装 redis-py 库 pip install redis-py Redis 和 StrictRedis redis-py 提供 Redis 和 Str…

记忆注意力用于多模态情感计算!

记忆注意力用于多模态情感计算&#xff01; 目录 情感计算 一、概述 二、研究背景 三、模型结构和代码 六、数据集介绍 七、性能展示 八、复现过程 九、运行过程 模型总结 本文所涉及所有资源均在传知代码平台可获取。 情感计算 近年来&#xff0c;社交媒体的快速扩张推动了用户…

跨境电商独立站:Shopify/Wordpress/店匠选哪个?

在面对不断增加的平台运营压力时&#xff0c;不少跨境电商的商家逐渐将注意力转向建立自己的独立站。据《中国跨境出口电商发展报告&#xff08;2022&#xff09;》所示&#xff0c;中国拥有的独立站数量在2022年已接近20万个&#xff0c;这表明独立站已成为卖家拓展海外市场的…