[oneAPI] 使用Bert进行中文文本分类

news2024/11/23 20:58:05

[oneAPI] 使用Bert进行中文文本分类

  • Intel® Optimization for PyTorch
  • 基于BERT的文本分类模型
    • 数据预处理
    • 数据集
      • 定义tokenize
      • 建立词表
      • 转换为Token序列
      • padding处理与mask
    • 模型
  • 结果
  • OneAPI
  • 参考资料

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

Intel® Optimization for PyTorch

在本次实验中,我们利用PyTorch和Intel® Optimization for PyTorch的强大功能,对PyTorch进行了精心的优化和扩展。这些优化举措极大地增强了PyTorch在各种任务中的性能,尤其是在英特尔硬件上的表现更加突出。通过这些优化策略,我们的模型在训练和推断过程中变得更加敏捷和高效,显著地减少了计算时间,提高了整体效能。我们通过深度融合硬件和软件的精巧设计,成功地释放了硬件潜力,使得模型的训练和应用变得更加快速和高效。这一系列优化举措为人工智能应用开辟了新的前景,带来了全新的可能性。
在这里插入图片描述

基于BERT的文本分类模型

基于BERT的文本分类模型就是在原始的BERT模型后再加上一个分类层即可,同时,对于分类层的输入(也就是原始BERT的输出),默认情况下取BERT输出结果中[CLS]位置对于的向量即可,当然也可以修改为其它方式,例如所有位置向量的均值等。因此,对于基于BERT的文本分类模型来说其输入就是BERT的输入,输出则是每个类别对应的logits值。

数据预处理

在构建数据集之前,我们首先需要知道的是模型到底应该接收什么样的输入,然后才能构建出正确的数据形式。在上面我们说到,基于BERT的文本分类模型的输入就等价于BERT模型的输入,同时BERT模型的输入如图1所示:
在这里插入图片描述

数据集

在这里,我们使用到的数据集是今日头条开放的一个新闻分类数据集(https://github.com/aceimnorstuvwxz/toutiao-text-classfication-dataset),一共包含有382688条数据,15个类别,经过处理后数据集格式为:

千万不要乱申请网贷,否则后果很严重_!_4
10年前的今天,纪念5.12汶川大地震10周年_!_11
怎么看待杨毅在一NBA直播比赛中说詹姆斯的球场统治力已经超过乔丹、伯德和科比?_!_3
戴安娜王妃的车祸有什么谜团?_!_2

其中_!_左边为新闻标题,也就是后面需要用到的分类文本,右边为类别标签。

定义tokenize

将输入进来的文本序列tokenize到字符级别。对于中文语料来说就是将每个字和标点符号都给切分开。在这里,我们可以借用transformers包中的BertTokenizer方法来完成,如下所示:

1 if __name__ == '__main__':
2     model_config = ModelConfig()
3     tokenizer = BertTokenizer.from_pretrained(model_config.pretrained_model_dir).tokenize
4     print(tokenizer("青山不改,绿水长流,我们月来客栈见!"))
5     print(tokenizer("10年前的今天,纪念5.12汶川大地震10周年"))
6 
7 # ['青', '山', '不', '改', ',', '绿', '水', '长', '流', ',', '我', '们', '月', '来', '客', '栈', '见', '!']
8 # ['10', '年', '前', '的', '今', '天', ',', '纪', '念', '5', '.', '12', '汶', '川', '大', '地', '震', '10', '周', '年']

建立词表

将vocab.txt中的内容读取进来形成一个词表即可

1 class Vocab:
 2     UNK = '[UNK]'
 3     def __init__(self, vocab_path):
 4         self.stoi = {}
 5         self.itos = []
 6         with open(vocab_path, 'r', encoding='utf-8') as f:
 7             for i, word in enumerate(f):
 8                 w = word.strip('\n')
 9                 self.stoi[w] = i
10                 self.itos.append(w)
11 
12     def __getitem__(self, token):
13         return self.stoi.get(token, self.stoi.get(Vocab.UNK))
14 
15     def __len__(self):
16         return len(self.itos)

转换为Token序列

在得到构建的字典后,便可以通过如下函数来将训练集、验证集和测试集转换成Token序列:

 1 def data_process(self, filepath):
 2     raw_iter = open(filepath, encoding="utf8").readlines()
 3     data = []
 4     max_len = 0
 5     for raw in tqdm(raw_iter, ncols=80):
 6         line = raw.rstrip("\n").split(self.split_sep)
 7         s, l = line[0], line[1]
 8         tmp = [self.CLS_IDX] + [self.vocab[token] for token in self.tokenizer(s)]
 9         if len(tmp) > self.max_position_embeddings - 1:
10             tmp = tmp[:self.max_position_embeddings - 1]  # BERT预训练模型只取前512个字符
11         tmp += [self.SEP_IDX]
12         tensor_ = torch.tensor(tmp, dtype=torch.long)
13         l = torch.tensor(int(l), dtype=torch.long)
14         max_len = max(max_len, tensor_.size(0))
15         data.append((tensor_, l))
16     return data, max_len

padding处理与mask

对原始文本序列tokenize转换为Token ID后还需要对其进行padding处理。对于这一处理过程可以通过如下代码来完成:

 1 def pad_sequence(sequences, batch_first=False, max_len=None, padding_value=0):
 2     if max_len is None:
 3         max_len = max([s.size(0) for s in sequences])
 4     out_tensors = []
 5     for tensor in sequences:
 6         if tensor.size(0) < max_len:
 7             tensor = torch.cat([tensor, torch.tensor(
 8               [padding_value] * (max_len - tensor.size(0)))], dim=0)
 9         else:
10             tensor = tensor[:max_len]
11         out_tensors.append(tensor)
12     out_tensors = torch.stack(out_tensors, dim=1)
13     if batch_first:
14         return out_tensors.transpose(0, 1)
15     return out_tensors

模型

class BertModel(nn.Module):
    """

    """

    def __init__(self, config):
        super().__init__()
        self.bert_embeddings = BertEmbeddings(config)
        self.bert_encoder = BertEncoder(config)
        self.bert_pooler = BertPooler(config)
        self.config = config
        self._reset_parameters()

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None):
        """
        ***** 一定要注意,attention_mask中,被mask的Token用1(True)表示,没有mask的用0(false)表示
        这一点一定一定要注意
        :param input_ids:  [src_len, batch_size]
        :param attention_mask: [batch_size, src_len] mask掉padding部分的内容
        :param token_type_ids: [src_len, batch_size]  # 如果输入模型的只有一个序列,那么这个参数也不用传值
        :param position_ids: [1,src_len] # 在实际建模时这个参数其实可以不用传值
        :return:
        """
        embedding_output = self.bert_embeddings(input_ids=input_ids,
                                                position_ids=position_ids,
                                                token_type_ids=token_type_ids)
        # embedding_output: [src_len, batch_size, hidden_size]
        all_encoder_outputs = self.bert_encoder(embedding_output,
                                                attention_mask=attention_mask)
        # all_encoder_outputs 为一个包含有num_hidden_layers个层的输出
        sequence_output = all_encoder_outputs[-1]  # 取最后一层
        # sequence_output: [src_len, batch_size, hidden_size]
        pooled_output = self.bert_pooler(sequence_output)
        # 默认是最后一层的first token 即[cls]位置经dense + tanh 后的结果
        # pooled_output: [batch_size, hidden_size]
        return pooled_output, all_encoder_outputs

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        """
        初始化
        """
        for p in self.parameters():
            if p.dim() > 1:
                normal_(p, mean=0.0, std=self.config.initializer_range)

    @classmethod
    def from_pretrained(cls, config, pretrained_model_dir=None):
        model = cls(config)  # 初始化模型,cls为未实例化的对象,即一个未实例化的BertModel对象
        pretrained_model_path = os.path.join(pretrained_model_dir, "pytorch_model.bin")
        if not os.path.exists(pretrained_model_path):
            raise ValueError(f"<路径:{pretrained_model_path} 中的模型不存在,请仔细检查!>\n"
                             f"中文模型下载地址:https://huggingface.co/bert-base-chinese/tree/main\n"
                             f"英文模型下载地址:https://huggingface.co/bert-base-uncased/tree/main\n")
        loaded_paras = torch.load(pretrained_model_path)
        state_dict = deepcopy(model.state_dict())
        loaded_paras_names = list(loaded_paras.keys())[:-8]
        model_paras_names = list(state_dict.keys())[1:]
        if 'use_torch_multi_head' in config.__dict__ and config.use_torch_multi_head:
            torch_paras = format_paras_for_torch(loaded_paras_names, loaded_paras)
            for i in range(len(model_paras_names)):
                logging.debug(f"## 成功赋值参数:{model_paras_names[i]},形状为: {torch_paras[i].size()}")
                if "position_embeddings" in model_paras_names[i]:
                    # 这部分代码用来消除预训练模型只能输入小于512个字符的限制
                    if config.max_position_embeddings > 512:
                        new_embedding = replace_512_position(state_dict[model_paras_names[i]],
                                                             loaded_paras[loaded_paras_names[i]])
                        state_dict[model_paras_names[i]] = new_embedding
                        continue
                state_dict[model_paras_names[i]] = torch_paras[i]
            logging.info(f"## 注意,正在使用torch框架中的MultiHeadAttention实现")
        else:
            for i in range(len(loaded_paras_names)):
                logging.debug(f"## 成功将参数:{loaded_paras_names[i]}赋值给{model_paras_names[i]},"
                              f"参数形状为:{state_dict[model_paras_names[i]].size()}")
                if "position_embeddings" in model_paras_names[i]:
                    # 这部分代码用来消除预训练模型只能输入小于512个字符的限制
                    if config.max_position_embeddings > 512:
                        new_embedding = replace_512_position(state_dict[model_paras_names[i]],
                                                             loaded_paras[loaded_paras_names[i]])
                        state_dict[model_paras_names[i]] = new_embedding
                        continue
                state_dict[model_paras_names[i]] = loaded_paras[loaded_paras_names[i]]
            logging.info(f"## 注意,正在使用本地MyTransformer中的MyMultiHeadAttention实现,"
                         f"如需使用torch框架中的MultiHeadAttention模块可通过config.__dict__['use_torch_multi_head'] = True实现")
        model.load_state_dict(state_dict)
        return model

结果

在这里插入图片描述

OneAPI

import intel_extension_for_pytorch as ipex

model = model.to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

参考资料

基于BERT预训练模型的中文文本分类任务: https://www.ylkz.life/deeplearning/p10979382/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/902742.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

标速高就是好?不看4K随机就别买SSD!

游戏玩家心心念念的SSD终于降到了白菜价&#xff0c;1TB的固态硬盘甚至比机械硬盘都便宜了&#xff0c;不过如果只看到动辄3000MB/s的读速&#xff0c;那你下单的时候还真的会被骗。 之所以这么说&#xff0c;是因为商品页面标注的速度都是连续读写速度&#xff0c;也就是直接向…

面试官:JVM是如何判定对象已死的?学JVM必会的知识!

本文已收录至GitHub&#xff0c;推荐阅读 &#x1f449; Java随想录 文章目录 引用计数算法可达性分析算法引用类型Dead Or Alive永久代真的"永久"吗&#xff1f;垃圾收集算法标记-清除算法标记-复制算法标记-整理算法标记-清除 VS 标记-整理 作为一名Java程序员&…

Autosar存储入门系列02_NVM之CRC校验及显隐式同步机制

本文框架 0.前言1. NVM中CRC校验2. NVM的显隐式同步机制2.1 隐式同步2.2 显式同步 0.前言 本系列是Autosar存储入门系列&#xff0c;希望能从学习者的角度把存储相关的知识点梳理一遍&#xff0c;这个过程中如果大家觉得有讲得不对或者不够清晰的地方&#xff0c;还请一定指出…

Linus对AMD的fTPM 漏洞表示”沮丧” 呼吁禁用该功能

导读AMD 的 fTPM 问题在业内众所周知&#xff0c;经常导致系统崩溃和卡死。Linux 的创建者 Linus Torvalds 对该功能表示失望&#xff0c;称其为内核的”瘟疫”。 简单回顾一下&#xff0c;可信平台模块&#xff08;Trusted Platform Module 或 TPM&#xff09;是一种安全检查…

抖音火山引擎推出免费域名DNS和公共DNS服务

抖音旗下的云计算服务火山引擎最近推出了"TrafficRoute DNS 套件"服务&#xff0c;其中包括两款产品&#xff0c;对软希网来说非常有用。 1.域名DNS&#xff1a; 这是一个用于网站域名的DNS服务&#xff0c;可以加速域名解析速度&#xff0c;从而提升网站的速度。如…

【100天精通python】Day42:python网络爬虫开发_HTTP请求库requests 常用语法与实战

目录 1 HTTP协议 2 HTTP与HTTPS 3 HTTP请求过程 3.1 HTTP请求过程 3.2 GET请求与POST请求 3.3 常用请求报头 3.4 HTTP响应 4 HTTP请求库requests 常用语法 4.1 发送GET请求 4.2 发送POST请求 4.3 请求参数和头部 4.4 编码格式 4.5 requests高级操作-文件上传 4.6 …

线性代数的学习和整理4: 求逆矩阵的多种方法汇总

目录 原始问题&#xff1a;如何求逆矩阵&#xff1f; 1 EXCEL里&#xff0c;直接可以用黑盒表内公式 minverse() 数组公式求A- 2 非线性代数方法&#xff1a;解方程组的方法 3 增广矩阵的方法 4 用行列式的方法计算&#xff08;未验证&#xff09; 5 A-1/|A|*A* &…

构建 NodeJS 影院微服务并使用 docker 部署【01/4】

图片来自谷歌 — 封面由我制作 一、说明 构建一个微服务的电影网站&#xff0c;需要Docker、NodeJS、MongoDB&#xff0c;这样的案例您见过吗&#xff1f;如果对此有兴趣&#xff0c;您就继续往下看吧。 在本系列中&#xff0c;我们将构建一个 NodeJS 微服务&#xff0c;并使用…

【排序】插入排序 希尔排序(改进)

文章目录 插入排序时间复杂度空间复杂度 代码希尔排序时间复杂度空间复杂度 代码 以从小到大排序为例进行说明。 插入排序 插入排序就是从前向后&#xff08;i1开始&#xff09;进行选择&#xff0c;如果找到在i之前&#xff08;分配一个j下标进行寻找&#xff09;有比array[i…

第 7 章 排序算法(2)(冒泡排序)

7.5冒泡排序 7.5.1基本介绍 冒泡排序&#xff08;Bubble Sorting&#xff09;的基本思想是&#xff1a;通过对待排序序列从前向后&#xff08;从下标较小的元素开始&#xff09;,依次比较相邻元素的值&#xff0c;若发现逆序则交换&#xff0c;使值较大的元素逐渐从前移向后部…

手写Promise一:结构的设计

手写Promise 这里写目录标题 手写Promise手写Promise的规范手册promisesaplus官网手写Promise-结构的设计 手写Promise的规范手册promisesaplus官网 链接: 官网链接 手写Promise-结构的设计 // 手写Promsie const PROMISE_STATUS_PENDING pending //等待状态 const PROMIS…

自动驾驶仿真:基于Carsim开发的加速度请求模型

文章目录 前言一、加速度输出变量问题澄清二、配置Carsim动力学模型三、配置Carsim驾驶员模型四、添加VS Command代码五、Run Control联合仿真六、加速度模型效果验证 前言 1、自动驾驶行业中&#xff0c;算法端对于纵向控制的功能预留接口基本都是加速度&#xff0c;我们需要…

【LeetCode75】第三十四题 叶子相似的树

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 给我们两棵二叉树&#xff0c;让我们判断这两棵二叉树的从左到右的叶子节点组成的叶子序列是否一致&#xff0c;即从左到右的叶子节点的数…

python爬虫实战零基础(2)——网页图片

网页图片的批量爬取保存 分析思路预备知识xpath用法response.text和 response.content两者的区别 代码实战请求网页内容批量图片保存 分析思路 还是基于request和xpath的爬虫代码 定位目标网址&#xff08;里面图片还是很好看的 https://pic.netbian.com/4kdongman/index.html&…

漏洞指北-VulFocus靶场专栏-入门

漏洞指北-VulFocus靶场01-入门 VulFocus靶场前置条件&#xff1a;入门001 命令执行漏洞step1&#xff1a; 输入默认index的提示step2&#xff1a; 入门002 目录浏览漏洞step1&#xff1a;进入默认页面&#xff0c;找到tmp目录step2 进入tmp目录获取flag文件 VulFocus靶场前置条…

Linux 线程库中的接口介绍

1.pthread_create()创建线程 pthread_create()的语法形式&#xff1a; 参数解释&#xff1a; 第一个参数thread&#xff1a;事先创建好的pthread_t类型的参数。成功时thread指向的内存单元被设置为新创建线程的线程ID。 第二个参数attr&#xff1a;用于定制各种不同的线程属性…

三角形添加数--夏令营

题目 tips&#xff1a; 1.本题不要求正三角形输出&#xff0c;只要输出左下三角即可 2.这种输入三角形的&#xff0c;都是可以理解为左下三角形的模型&#xff0c;然后去写f[i][j]f[i-1][j]f[i-1][j1]&#xff0c;写行列 3.还有双重for循环输入输出三角形&#xff0c;注意第二…

数据处理与统计分析——MySQL与SQL

这里写目录标题 1、初识数据库1.1、什么是数据库1.2、数据库分类1.3、相关概念1.4、MySQL及其安装1.5、基本命令 2、基本命令2.1、操作数据库2.2、数据库的列类型2.3、数据库的字段属性2.4 创建和删除数据库表2.5、数据库存储引擎2.6、修改数据库 3、MySQL数据管理3.1、外键 My…

YOLOv5+deepsort实现目标追踪。(附有各种错误解决办法)

一、YOLOv5算法相关配置 🐸这里如果是自己只想跑一跑YOLOV5的话,可以参考本章节。只想跑通YOLOv5+deepsort的看官移步到下一章节。 1.1 yolov5下载 🐸yolov5源码在github下载地址上或者Gitee上面都有。需要注意的是由于yolov5的代码库作者一直在维护,所以下载的时候需…

【前端】vscode javascript 代码片段失效问题解决

1. 文件--首选项--用户代码片段-vue.json : 添加 // { // // Place your global snippets here. Each snippet is defined under a snippet name and has a scope, prefix, body and // // description. Add comma separated ids of the languages where the snippet is app…