PyTorch学习笔记 7.TextCNN文本分类

news2025/1/19 17:12:45

PyTorch学习笔记 7.TextCNN文本分类

  • 一、模型结构
  • 二、文本分词与编码
    • 1. 分词与编码器
    • 2. 数据加载器
  • 二、模型定义
    • 1. 卷积层
    • 2. 池化层
    • 3. 全连接层
  • 三、训练过程
  • 四、测试过程
  • 五、预测过程

一、模型结构

在这里插入图片描述

2014年,Yoon Kim针对CNN的输入层做了一些变形,提出了文本分类模型textCNN。与传统图像的CNN网络相比, textCNN 在网络结构上没有任何变化,包含只有一层卷积,一层最大池化层, 最后将输出外接softmax 来进行n分类。
模型结构:
在这里插入图片描述
本文使用的数据集是 THUCNews 。

二、文本分词与编码

1. 分词与编码器

这里使用bert的预训练模型 bert-base-chinese 实现tokenizer过程。更多与bert分词编码相关知识可以移步到这里查看。

2. 数据加载器

数据加载器使用pytorch 的 dataset,关于DataSet更多知识可以移步到这里查看。

# 定义数据加载器
class Dataset(data.Dataset):
    def __init__(self, data_path):
        super().__init__()
        self.lines = open(data_path, encoding='utf-8').readlines()
        # 如果要指定缓存目录,可以使用 cache_dir='/kaggle/working/tokenizer'
        self.tokenizer = BertTokenizer.from_pretrained(BERT_TOKENIZER_MODEL)

    def __len__(self):
        return len(self.lines)

    # 取每条数据进行编码
    def __getitem__(self, index):
        text, label = self.lines[index].split('\t')
        tokenizer = self.tokenizer(text)

        input_ids = tokenizer['input_ids']
        attention_mask = tokenizer['attention_mask']
        # input_ids 和 attention_mask补全
        if len(input_ids) < TEXT_LEN:
            pad_len = (TEXT_LEN - len(input_ids))
            input_ids += [BERT_PAD_ID] * pad_len
            attention_mask += [0] * pad_len
        target = int(label)
        return torch.tensor(input_ids[:TEXT_LEN]), torch.tensor(attention_mask[:TEXT_LEN]), torch.tensor(target)

二、模型定义

1. 卷积层

模型定义3个卷积层,卷积大小分别是2,3,4。
卷积激活函数使用relu。

2. 池化层

卷积后进行最大池化,池化是在2维上进行,池化后进行降维处理。

3. 全连接层

根据池化层的输出和分类类别数量,构建全连接层,再经过softmax,得到最终的分类结果。
这里使用torch.nn.Linear(input_num, num_class)定义全连接层,其中input_num是池化层输出的维数,即m,num_class是分类任务的类别数量。

def conv_and_pool(conv, input):
    out = conv(input)
    # 第一次out.shape=[2,256,29,1]
    out = F.relu(out)
    # 池化在2维上进行,out.shape是范围大小,最后进行降维
    return F.max_pool2d(out, (out.shape[2], out.shape[3])).squeeze()


class TextCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_TOKENIZER_MODEL)
        # 固定bert的参数,只训练下游参数
        for name, param in self.bert.named_parameters():
            param.requires_grad = False
        # 从1 变为 256个通道
        # 这里定义3个层,卷积核大小分别是[2,3,4]
        self.conv1 = nn.Conv2d(1, NUM_FILTERS, (2, EMBEDDING_DIM))
        self.conv2 = nn.Conv2d(1, NUM_FILTERS, (3, EMBEDDING_DIM))
        self.conv3 = nn.Conv2d(1, NUM_FILTERS, (4, EMBEDDING_DIM))
        # 全连接
        self.linear = nn.Linear(NUM_FILTERS * 3, NUM_CLASSES)

    def forward(self, input, mask):
        # self.bert 第0元素 [2,30,768]
        # unsqueeze 进行升维,变成[2,1,30,768]
        out = self.bert(input, mask)[0].unsqueeze(1)
        # 第1层输出 [2,256]
        # 在1维上拼接,输出[256,3],3个层上进行拼接
        out1 = conv_and_pool(self.conv1, out)
        out2 = conv_and_pool(self.conv2, out)
        out3 = conv_and_pool(self.conv3, out)
        out = torch.cat([out1, out2, out3], dim=1)
        # 把3个层拼接,1个层是 out1 = self.conv_and_pool(self.conv1, out)
        # 输出[2,10]
        return self.linear(out)

三、训练过程

按批次取训练数据,调用模型进行训练,主要是以下几个步骤:

  1. 获取loss:输入数据和标签,计算得到预测值,计算损失函数;
  2. optimizer.zero_grad() 清空梯度;
  3. loss.backward() 反向传播,计算当前梯度;
  4. optimizer.step() 根据梯度更新网络参数
        for batch, (input, mask, target) in enumerate(train_loader):
            input = input.to(DEVICE)
            mask = mask.to(DEVICE)
            target = target.to(DEVICE)
            # 预测,形状10*10
            pred = model(input, mask)
            loss = loss_fn(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

四、测试过程

测试过程对每次正确率累加,最后打印整体的测试结果:

def test():
    test_dataset = Dataset(TEST_SAMPLE_PATH)
    test_loader = data.DataLoader(test_dataset, batch_size=100, shuffle=False)

    loss_fn = nn.CrossEntropyLoss()

    y_pred = []
    y_true = []

    with torch.no_grad():
        for batch, (input, mask, target) in enumerate(test_loader):
            input = input.to(DEVICE)
            mask = mask.to(DEVICE)
            target = target.to(DEVICE)
            test_pred = model(input, mask)
            loss = loss_fn(test_pred, target)
            print('>> batch:', batch, 'loss:', round(loss.item(), 5))
            test_pred_ = torch.argmax(test_pred, dim=1)
            # 计算整体正确率
            y_pred += test_pred_.data.tolist()
            y_true += target.data.tolist()
    # 打印整体的测试指标
    print(evaluate(y_pred, y_true, id2labels))

五、预测过程

  1. 把输入文本进行分词编码
  2. 输入模型,通过argmax计算预测值
  3. 通过id转标签函数计算标签值
def predict(texts):
    # 分词
    tokenizer = BertTokenizer.from_pretrained(BERT_TOKENIZER_MODEL)

    batch_input_ids = []
    batch_mask = []
    start = time.time()
    for text in texts:
        tokenizers = tokenizer(text)
        input_ids = tokenizers['input_ids']
        attention_masks = tokenizers['attention_mask']
        if len(input_ids) < TEXT_LEN:
            pad_len = (TEXT_LEN - len(input_ids))
            input_ids += [BERT_PAD_ID] * pad_len
            attention_masks += [0] * pad_len
        batch_input_ids.append(input_ids[:TEXT_LEN])
        batch_mask.append(attention_masks[:TEXT_LEN])

    batch_input_ids = torch.tensor(batch_input_ids)
    batch_mask = torch.tensor(batch_mask)
    pred = model(batch_input_ids.to(DEVICE), batch_mask.to(DEVICE))
    pred_ = torch.argmax(pred, dim=1)

    ret = ([id2labels[index] for index in pred_])
    end = time.time()
    runTime = end - start
    print("共", len(texts), '条数据,运行时间:', runTime, '秒,平均每条时间', runTime / len(texts), '秒')
    return ret

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/128302.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis事件循环

Redis事件循环文件事件时间事件事件调度和执行客户端部分关于客户端输出缓冲区限制ServerCron周期函数服务器启动流程小结Redis服务器是一个事件驱动程序, 主要处理两类事件: 文件事件 (File Event) : 对套接字操作的抽象&#xff0c;服务器与客户端的通信过程会产生相应的文件…

Java 中的继承和多态

面向对象的三大特性&#xff1a;封装、继承、多态。在这三个特性中&#xff0c;如果没有封装和继承&#xff0c;也不会有多态。 那么多态实现的途径和必要条件是什么呢&#xff1f;以及多态中的重写和重载在JVM中的表现是怎么样&#xff1f;在Java中是如何展现继承的特性呢&am…

常用密码算法介绍

算法种类 根据技术特征&#xff0c;现代密码学可分为三类&#xff1a; 对称算法 说明&#xff1a;加密密钥和解密密钥相同&#xff0c;对明文、密文长度没有限制 子算法&#xff1a; 流密码算法&#xff1a;每次加密或解密一位或一字节的明文或密文 分组密码算法&#xff…

LiveGBS流媒体平台国标GB/T28181功能-国标流媒体服务平台作为上级接入海康大华华为宇视等下级平台及摄像头

LiveGBS国标流媒体服务平台作为上级接入海康大华华为宇视等下级平台及摄像头1、背景说明2、部署国标平台2.1、安装使用说明2.2、服务器网络环境2.3、信令服务配置3、监控摄像头设备接入3.1、海康GB28181接入示例3.2、大华GB28181接入示例3.3、华为IPC GB28181接入示例4、硬件NV…

mysql 存储过程实现从一张表数据迁移到另一种表

通过存储过程迁移数据&#xff1a; 创建表 CREATE TABLE test1 ( idp varchar(255) DEFAULT NULL, brandIdp varchar(255) DEFAULT NULL, namep varchar(1000) DEFAULT NULL, urlp varchar(1000) DEFAULT NULL ) ENGINEInnoDB DEFAULT CHARSETkeybcs2; INSERT INTO t…

2023美国大学生数学建模竞赛(MCM/ICM)报名流程指南

数模乐园作为国内美赛报名最大官方平台&#xff0c;为参加美赛的同学解决国际支付报名难的问题&#xff0c;为同学们省去大部分繁琐流程的同时还附赠纸质证书打印邮寄、美赛赛题解析、美赛专属礼包、赛题翻译等备赛资料 数模乐园已累计为10万同学完成了美赛辅助报名&#xff0…

Android 音视频编解码(三) -- 视频编码和H264格式原理讲解

Android 音视频编解码(一) – MediaCodec 初探 Android 音视频编解码(二) – MediaCodec 解码(同步和异步) 前面学习了 MediaCodec 的基本原理&#xff0c;以及如何解码&#xff0c;在学习MediaCodec 编码之前&#xff0c;先来学习视频是如何编码的&#xff0c;以及最常用的 H2…

亚马逊vs Starday :做跨境电商生意,从哪里开始?

据有关数据统计&#xff0c;中国跨境电商进出口五年增长近十倍&#xff0c;在一众行业面前脱颖而出&#xff0c;成为我国对外贸易新的增长极&#xff0c;然而也正是这样的趋势&#xff0c;使得许多原本从事电商行业的卖家和资本纷纷闻风而动&#xff0c;想要进入市场分一杯羹&a…

3d打印的翘边问题

如何解决3D打印翘边问题 翘边是3D打印中常见的问题之一。为什么在打印的过程中会遇到翘边呢&#xff1f;主要是因为塑料的热胀冷缩&#xff0c;从喷嘴挤出来的塑料在冷却时候会收缩&#xff0c;进而导致模型边缘或者两头翘了起来与平台出现分离。那么如何避免或解决翘边问题呢…

“消费盲返”爆火,一个月能赚1000w?

寒冬已至&#xff0c;疫情还是在断断续续的复发&#xff0c;很多城市也受到严重的影响&#xff0c;封城的通告一出&#xff0c;无疑是给不少的实体企业增添了相当大的噩耗打击&#xff0c;这时候更为磨炼实体企业和创业人看待事情的立场&#xff0c;有些人会觉得疫情的袭来什么…

SSM框架学习记录-SpringBoot_day01

1.SpringBoot简介 SpringBoot是用来简化Spring应用的初始搭建以及开发过程 先回顾一下SpringMVC的开发过程&#xff1a; 创建工程&#xff0c;并在pom.xml配置文件中配置所依赖的坐标&#xff1a; <dependencies><dependency><groupId>javax.servlet</gro…

阶段性回顾(3)

1. 学习指针必须得了解清楚内存&#xff0c;而内存到底是什么东西呢&#xff1f;内存就是电脑上的存储设备&#xff08;除了内存之外&#xff0c;还有硬盘&#xff0c;寄存器等等&#xff09;&#xff0c;那内存到底是来干啥的呢&#xff1f;程序运行的时候会载入到内存当中&am…

Fast Report .NET 2023.1.7-2022-最后版本

通过使用 Fast Report .NET&#xff0c;用户可以构建和创建本质上独立的应用程序以及报表。网。换句话说&#xff0c;这意味着 Fast Report .NET 可以作为所有用户的独立报告工具独立使用。它可以包括一个强大的可视化报告&#xff0c;用于创建和修改报告的过程。用户应用程序可…

Selenium Webdriver 实现原理详解-手工用Postman调用webdriver执行UI测试

目录 1. Selenium 概述 2. 术语解释&#xff1a; 3. Selenium WebDriver 实现原理 4. 安装selenium 客户端&#xff0c;浏览器&#xff0c;驱动 4.1 安装selenium client lib 4.2 安装浏览器和浏览器驱动 4.3 例子代码 4.4 省略浏览器驱动的方法 4.5 测试代码与Webdr…

Linux近期补充

Linux近期补充Linux命令的近期补充Linux命令的近期补充 1.本地服务器链接远端服务器 命令 ssh 远端服务器ip 如 ssh 121.5.151.236 会弹出 登录框 自己输入密码即可2.当前位置 pwd3.查看网络设备 ifconfig4.查看服务器内存 free -h可以看到还有2.3G内存可以用 5.查看磁盘…

ES学习1~23(ECMAcript相关介绍+ECMASript 6新特性)

1 ECMAcript相关介绍 1.1 什么是ECMA ECMA(European Computer Manufacturers Association)中文名称为欧洲计算机制造商协会&#xff0c;这个组织的目标是评估、开发和认可电信和计算机标准。1994年后该组织改名为Ecma国际。 1.2 什么是ECMScript ECMAScript是由Ecma国际通过…

外网远程访问本地MySQL数据库【cpolar内网穿透】

作为网站运行必备组件之一的数据库&#xff0c;免不了随时对其进行管理维护。若我们没有在安装数据库的电脑旁&#xff0c;但又需要立即对数据库进行管理时&#xff0c;应该如何处理&#xff1f;这时我们可以使用cpolar对内网进行穿透&#xff0c;远程管理和操作MySQL数据库。现…

三叠云甘特图新亮点,可翻页查看数据啦

表单管理 路径 表单 >> 表单设计 功能简介 1.「甘特视图」新增“翻页”功能&#xff0c;用户可以通过翻页查阅更多的数据。 2. 滑动超过显示区域时显示“标记点”&#xff0c;用户可以通过点击标记点快速定位到相应的数据。 3.「列表视图」条件着色功能,修复“系统字…

Linux进程管理

1.什么是程序&#xff1f;具有执行代码和执行权限的文本文件 2.什么是进程&#xff1f;是已启动的可执行程序的运行实例 3.进程的生命周期&#xff1a;由系统程序fork出来的子程序&#xff0c;具备一定的父资源&#xff0c;直到运行完毕 4.进程有哪些组成部分&#xff1f; …

操作系统真相还原_第3章:实模式下跳转指令补充

文章目录数据类型伪指令ret指令call指令jmp指令标志寄存器flags与条件转移数据类型伪指令 byte&#xff1a;字节 word&#xff1a;字 dword&#xff1a;双字 qword&#xff1a;四字 跳转指令指定目标操作数大小 short&#xff1a;字节 near&#xff1a;字 far&#xff1a;双字…