文本分类TextRNN_Att模型(pytorch实现)

news2025/1/20 3:52:47

TextRNN_Att

        • TextRNN-Att简介
        • 模型结构:
        • pytorch代码实现:

TextRNN-Att简介

TextRNN前面已经介绍过了,主体结构就是一个双向/单向的LSTM层,由于LSTM获得每个时间点的输出信息之间的“影响程度”都是一样的,而在关系分类中,为了能够突出部分输出结果对分类的重要性,引入加权的思想。而本篇模型在LSTM层之后引入了attention层,其实就是对lstm每刻的隐层进行加权平均。

在这里插入图片描述

模型结构:
  • 输入层:输入是一个一个的句子,通过对它进行划分batch,sentence,然后进行编码

  • 词嵌入层:将文本中的离散词汇表示(如单词或者字符)转换为连续的实值向量表示,也称为词嵌入(Word Embedding)。这些实值向量具有语义信息,能够捕捉词汇之间的语义关系,从而提供更丰富的特征表示。

  • LSTM层:双向LSTM是RNN的一种改进,其主要包括前后向传播,每个时间点包含一个LSTM单元用来选择性的记忆、遗忘和输出信息。模型的输出包括前后向两个结果,通过拼接作为最终的Bi-LSTM输出。公式如下:

  • 注意力层:对lstm每刻的隐层进行加权平均,将词级别的特征合并到句子级别的特征。

M = tanh ⁡ ( H ) M=\tanh \left(H \right) M=tanh(H)

α = s o f t max ⁡ ( W T M ) \alpha =soft\max \left(W^TM \right) α=softmax(WTM)

r = H α T r=H\alpha ^T r=HαT

  • 输出层:将句子层级的特征用于关系分类。
pytorch代码实现:
  1. 模型输入: [batch_size, seq_len]
  2. 经过embedding层:加载预训练词向量或者随机初始化, 词向量维度为embed_size: [batch_size, seq_len, embed_size]
  3. 双向LSTM:隐层大小为hidden_size,得到所有时刻的隐层状态(前向隐层和后向隐层拼接) [batch_size, seq_len, hidden_size * 2]
  4. 初始化一个可学习的权重矩阵w=[hidden_size * 2, 1]
  5. 对LSTM的输出进行非线性激活后与w进行矩阵相乘,并经行softmax归一化,得到每时刻的分值:[batch_size, seq_len, 1]
  6. 将LSTM的每一时刻的隐层状态乘对应的分值后求和,得到加权平均后的终极隐层值[batch_size, hidden_size * 2]
  7. 对终极隐层值进行非线性激活后送入两个连续的全连接层[batch_size, num_class]
  8. 预测:softmax归一化,将num_class个数中最大的数对应的类作为最终预测[batch_size, 1]
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Config(object):
    """配置参数"""

    def __init__(self):
        self.model_name = 'TextRNN_Att'
        self.dropout = 0.5  # 随机失活
        self.require_improvement = 1000  # 若超过1000batch效果还没提升,则提前结束训练
        self.num_classes = 10 # 类别数
        self.n_vocab = 10000  # 词表大小,在运行时赋值
        self.num_epochs = 10  # epoch数
        self.batch_size = 128  # mini-batch大小
        self.pad_size = 32  # 每句话处理成的长度(短填长切)
        self.learning_rate = 1e-3  # 学习率
        self.embed =  300  # 字向量维度, 若使用了预训练词向量,则维度统一
        self.hidden_size = 128  # lstm隐藏层
        self.num_layers = 2  # lstm层数
        self.hidden_size2 = 64

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.tanh1 = nn.Tanh()
        self.w = nn.Parameter(torch.zeros(config.hidden_size * 2))
        self.tanh2 = nn.Tanh()
        self.fc1 = nn.Linear(config.hidden_size * 2, config.hidden_size2)
        self.fc = nn.Linear(config.hidden_size2, config.num_classes)

    def forward(self, x):
        x, _ = x
        # 词嵌入层
        emb = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]
        # LSTM层
        H, _ = self.lstm(emb)  # [batch_size, seq_len, hidden_size * num_direction]=[128, 32, 256]
        # 注意力层
        M = self.tanh1(H)  # [128, 32, 256]
        alpha = F.softmax(torch.matmul(M, self.w), dim=1).unsqueeze(-1)  # [128, 32, 1]
        out = H * alpha  # [128, 32, 256]
        #输出层
        out = torch.sum(out, 1)  # [128, 256]
        out = F.relu(out)  # [128, 256]
        out = self.fc1(out)  # [128, 64]
        out = self.fc(out)  # [128, 10]
        return out

config=Config()
model=Model(config)
print(model)

输出:

Model(
  (embedding): Embedding(10000, 300, padding_idx=9999)
  (lstm): LSTM(300, 128, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (tanh1): Tanh()
  (tanh2): Tanh()
  (fc1): Linear(in_features=256, out_features=64, bias=True)
  (fc): Linear(in_features=64, out_features=10, bias=True)
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1682137.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[vue] nvm use时报错 exit status 1:一堆乱码,exit status 5

报错exit status 5:�ܾ����ʡ� 原因:因为当前命令提示符窗口是user权限, 解决:cmd使用管理员方式打开就可以 参考: vm use时报错 exit status 1…

web安全学习笔记(16)

记一下第27-28课的内容。Token 验证 URL跳转漏洞的类型与三种跳转形式;URL跳转漏洞修复 短信轰炸漏洞绕过挖掘 一、token有关知识 什么是token?token是用来干嘛的?_token是什么意思-CSDN博客 二、URL跳转漏洞 我们在靶场中,…

ES 数据写入方式:直连 VS Flink 集成系统

ES 作为一个分布式搜索引擎,从扩展能力和搜索特性上而言无出其右,然而它有自身的弱势存在,其作为近实时存储系统,由于其分片和复制的设计原理,也使其在数据延迟和一致性方面都是无法和 OLTP(Online Transac…

python实现贪吃蛇游戏,python贪吃蛇

欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一.前言 二.代码 三.使用 四.总结 一.前言 贪吃蛇游戏是一款经典的休闲益智类游戏,以下是关于该游戏的详细介绍: 游戏类型与平台:

玩转网络调试利器:深入剖析ip命令的强大功能

欢迎来到我的博客,代码的世界里,每一行都是一个故事 玩转网络调试利器:深入剖析ip命令的强大功能 前言ip命令概述网络接口管理ip地址配置路由管理邻居关系查看 前言 在我们的日常网络使用中,我们经常需要管理和调试网络接口、路由…

【大模型微调】一文掌握7种大模型微调的方法

本篇文章深入分析了大型模型微调的基本理念和多样化技术,细致介绍了LoRA、适配器调整(Adapter Tuning)、前缀调整(Prefix Tuning)等多个微调方法。详细讨论了每一种策略的基本原则、主要优点以及适宜应用场景,使得读者可以依据特定的应用要求和计算资源限…

前端开发切入第三方页面显示不全问题解决方案

前端开发切入第三方页面显示不全问题解决方案 最近做一个电视大屏,大屏分为三个部分,又分为上下结构,下部分分为左右结构布局,第一个部分是设备架构图、第二个部分是本市设备网点告警图,第一、二部分是我自己开发的,采用自动计算、自动适配各种场景缩放,兼容性没有任何…

uni-app 开发准备工作(一次开发,多端部署)

前言 uni-app 是一个使用 Vue.js 开发所有前端应用的框架,开发者编写一套代码,可发布到iOS、Android、Web(响应式)、以及各种小程序(微信/支付宝/百度/头条/飞书/QQ/快手/钉钉/淘宝)、快应用等多个平台。 …

车辆超龄无法注册滴滴司机怎么办理账号

车辆超龄无法注册滴滴司机,别担心这个视频教你如何解决,滴滴司机注册过程中 车辆年限是一个常见的限制条件,如果您的车辆超过了8年,那么注册滴滴可能会遇到困难,但是不要因此而放弃成为滴滴司机的机会,《 …

nestJs链接redis

给大家推荐一个库,地址:Yarn service import { Injectable } from nestjs/common; import { RedisService as RedisServices, DEFAULT_REDIS_NAMESPACE } from liaoliaots/nestjs-redis; import Redis from ioredis;Injectable() export class RedisService {priva…

入职java开发第一天,不会VUE竟然被.........

Vue2 技术栈 第 1 章:Vue 核心1.1. Vue 简介1.1.1. 官网1.1.2. 介绍与描述1.1.3. Vue 的特点1.1.4. 与其它 JS 框架的关联1.1.5. Vue 周边库 1.2. 初识 Vue1.3. 模板语法1.3.1. 效果1.3.2. 模板的理解1.3.3. 插值语法1.3.4. 指令语法 1.4. 数据绑定1.4.1. 效果1.4.2…

012.使用传统.NET事件进行通知操作

Rx的目标是协调和编排来自各种来源的基于事件的异步计算,如社交网络、传感器、UI事件等。例如,建筑物周围的安全摄像头,以及当有人可能在建筑物附近时触发的运动传感器,会向我们发送最近摄像头的照片。Rx还可以统计包含选举候选人…

关爱内向儿童:理解与支持助力成长

引言 每个孩子都是独特的,有些孩子天生性格外向,善于表达,而有些孩子则比较内向,喜欢独处。内向并不是缺点,而是一种性格特质。然而,内向的孩子在社交和学习过程中可能会面临一些挑战。本文将探讨内向儿童…

光伏运维系统在光伏电站的应用

摘要:全球化经济社会的快速发展,加快了传统能源的消耗,导致能源日益短缺,与此同时还带来了严重的环境污染。因此,利用没有环境污染的太阳能进行光伏发电获得了社会的普遍关注。本文根据传统式光伏电站行业的发展背景及其监控系统的技术设备,给出了现代化光伏电站数据…

Python 机器学习 基础 之 监督学习 【分类器的不确定度估计】 的简单说明

Python 机器学习 基础 之 监督学习 【分类器的不确定度估计】 的简单说明 目录 Python 机器学习 基础 之 监督学习 【分类器的不确定度估计】 的简单说明 一、简单介绍 二、监督学习 算法 说明前的 数据集 说明 三、监督学习 之 分类器的不确定度估计 1、决策函数 2、预测…

20232831 袁思承 2023-2024-2 《网络攻防实践》第10次作业

目录 20232831 袁思承 2023-2024-2 《网络攻防实践》第10次作业1.实验内容2.实验过程(1)SEED SQL注入攻击与防御实验①熟悉SQL语句②对SELECT语句的SQL注入攻击③对UPDATE语句的SQL注入攻击④SQL对抗 (2)SEED XSS跨站脚本攻击实验…

github新手用法

目录 1,github账号注册2,github登录3,新建一个仓库4,往仓库里面写入东西或者上传东西5, 下载Git软件并安装6 ,获取ssh密钥7, 绑定ssh密钥8, 测试本地和github是否联通9,从…

防火请技术基础篇:令牌桶机制的剖析与应用

防火墙中的令牌桶机制:深度剖析与应用 在现代网络通信中,防火墙技术发挥着至关重要的作用,它不仅能够实现网络安全防御,还能通过诸如令牌桶算法等机制来有效管理网络流量,保证网络服务的质量。本文将全面深入地探讨防…

Linux(十) 线程,线程控制

目录 一、认识线程 1.1 线程是什么 1.2 为啥要有线程 并行与并发 为什么要有线程(线程的优点) 为什么线程的切换成本更低 1.3 线程的缺点 1.4 线程和进程的区别 二、线程控制 2.1 线程创建 进程ID和线程ID 2.2 线程终止 2.3 线程等待 2.4 线程分离 三、注意 一、…