NLP|LSTM+Attention文本分类

news2025/1/28 1:03:31

目录

一、Attention原理简介

二、LSTM+Attention文本分类实战

1、数据读取及预处理

2、文本序列编码

3、LSTM文本分类

三、划重点

少走10年弯路


        LSTM是一种特殊的循环神经网络(RNN),用于处理序列数据和时间序列数据的建模和预测。而在NLP和时间序列领域上Attention-注意力机制也早已有了大量应用,本文将介绍在LSTM基础上如何添加Attention来优化模型效果。

一、Attention原理简介

        注意力机制通过聚焦于重要的信息,忽略不重要的信息,从而有效地处理输入信息。在神经网络中,注意力机制可以帮助模型更好地关注输入中的重要特征,从而提高模型的性能。

        简单而言,在文本处理任务中,self-attention对每一个词会随机初始化q、k、v三个向量,用每个词的q向量和其他k向量做点积、再归一化得到这个词的权重向量w,用w给v向量加权求和得到z向量(该词attention之后的向量)。再延伸一点,其实可以初始化多组q、k、v矩阵,从而得到多组z矩阵拼接起来(类似于CNN中的多个卷积核、来提取不同信息),再乘上一个矩阵压缩回原来的维度,得到最终的embedding。

        细节原理相对繁琐,推荐大家可以去看一下这篇博客的bert介绍,其中self-attention部分详细且清晰。

https://blog.csdn.net/jiaowoshouzi/article/details/89073944

二、LSTM+Attention文本分类实战

1、数据读取及预处理

import re
import os
from sqlalchemy import create_engine
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve,roc_auc_score
import xgboost as xgb
from xgboost.sklearn import XGBClassifier
import lightgbm as lgb
import matplotlib.pyplot as plt
import gc

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras import optimizers

# 2、数据读取+预处理
data=pd.read_excel('Inshorts Cleaned Data.xlsx')
 
def data_preprocess(data):
    df=data.drop(['Publish Date','Time ','Headline'],axis=1).copy()
    df.rename(columns={'Source ':'Source'},inplace=True)
    df=df[df.Source.isin(['YouTube','India Today'])].reset_index(drop=True)
    df['y']=np.where(df.Source=='YouTube',1,0)
    df=df.drop(['Source'],axis=1)
    return df
 
df=data.pipe(data_preprocess)
print(df.shape)
df.head()

# 导入英文停用词
from nltk.corpus import stopwords  
from nltk.tokenize import sent_tokenize
stop_english=stopwords.words('english')  
stop_spanish=stopwords.words('spanish') 
stop_english

# 4、文本预处理:处理简写、小写化、去除停用词、词性还原
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords  
from nltk.tokenize import sent_tokenize
import nltk
 
def replace_abbreviation(text):
    
    rep_list=[
        ("it's", "it is"),
        ("i'm", "i am"),
        ("he's", "he is"),
        ("she's", "she is"),
        ("we're", "we are"),
        ("they're", "they are"),
        ("you're", "you are"),
        ("that's", "that is"),
        ("this's", "this is"),
        ("can't", "can not"),
        ("don't", "do not"),
        ("doesn't", "does not"),
        ("we've", "we have"),
        ("i've", " i have"),
        ("isn't", "is not"),
        ("won't", "will not"),
        ("hasn't", "has not"),
        ("wasn't", "was not"),
        ("weren't", "were not"),
        ("let's", "let us"),
        ("didn't", "did not"),
        ("hadn't", "had not"),
        ("waht's", "what is"),
        ("couldn't", "could not"),
        ("you'll", "you will"),
        ("i'll", "i will"),
        ("you've", "you have")
    ]
    result = text.lower()
    for word_replace in rep_list:
        result=result.replace(word_replace[0],word_replace[1])
#     result = result.replace("'s", "")
    
    return result
 
def drop_char(text):
    result=text.lower()
    result=re.sub('[^\w\s]',' ',result) # 去掉标点符号、特殊字符
    result=re.sub('\s+',' ',result) # 多空格处理为单空格
    return result
 
def stemed_words(text,stop_words,lemma):
    
    word_list = [lemma.lemmatize(word, pos='v') for word in text.split() if word not in stop_words]
    result=" ".join(word_list)
    return result
 
def text_preprocess(text_seq):
    stop_words = stopwords.words("english")
    lemma = WordNetLemmatizer()
    
    result=[]
    for text in text_seq:
        if pd.isnull(text):
            result.append(None)
            continue
        text=replace_abbreviation(text)
        text=drop_char(text)
        text=stemed_words(text,stop_words,lemma)
        result.append(text)
    return result
 
df['short']=text_preprocess(df.Short)
df[['Short','short']]

# 5、划分训练、测试集
test_index=list(df.sample(2000).index)
df['label']=np.where(df.index.isin(test_index),'test','train')
df['label'].value_counts()

2、文本序列编码

        按照词频排序,创建长度为6000的高频词词典、来对文本进行序列化编码。

from tensorflow.keras.preprocessing.text import Tokenizer
def word_dict_fit(train_text_list,num_words):
    '''
        train_text_list: ['some thing today ','some thing today2']
    '''
    tok_params={
        'num_words':num_words,  # 词典的长度,仅保留词频top的num_words个词
        'filters':'!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
        'lower':True, 
        'split':' ', 
        'char_level':False, 
        'oov_token':None, # 设定词典外的词编码
    }
    tok = Tokenizer(**tok_params) # 分词
    tok.fit_on_texts(train_text_list)
    
    return tok

def word_dict_apply_sequences(tok_model,text_list,len_vec):
    '''
        text_list: ['some thing today ','some thing today2']
    '''
    list_tok = tok_model.texts_to_sequences(text_list) # 编码映射
    
    pad_params={
        'sequences':list_tok,
        'maxlen':len_vec,  # 补全后向量长度
        'padding':'pre', # 'pre' or 'post',在前、在后补全
        'truncating':'pre', # 'pre' or 'post',在前、在后删除长度多余的部分
        'value':0, # 补全0
    }
    seq_tok = pad_sequences(**pad_params) # 补全编码向量,返回二维array
    return seq_tok

num_words,len_vec = 6000,40
tok_model= word_dict_fit(df[df.label=='train'].short,num_words)
tok_train = word_dict_apply_sequences(tok_model,df[df.label=='train'].short,len_vec)
tok_test = word_dict_apply_sequences(tok_model,df[df.label=='test'].short,len_vec)
tok_test

图片

3、LSTM文本分类

        LSTM层的输入是三维张量(batch_size, timesteps, input_dim),所以使用的数据可以是时间序列、也可以是文本数据的embedding;输出设置return_sequences为False,返回尺寸为 (batch_size, units) 的 2D 张量。

'''
LSTM层核心参数
    units:输出维度
    activation:激活函数
    recurrent_activation: RNN循环激活函数
    use_bias: 布尔值,是否使用偏置项
    dropout:0~1之间的浮点数,神经元失活比例
    recurrent_dropout:0~1之间的浮点数,循环状态的神经元失活比例
    return_sequences: True时返回RNN全部输出序列(3D),False时输出序列的最后一个输出(2D)
'''
def init_lstm_model(max_features, embed_size):
    model = Sequential()
    model.add(Embedding(input_dim=max_features, output_dim=embed_size))
    model.add(Bidirectional(LSTM(units=32,activation='relu', recurrent_dropout=0.1)))
    model.add(Dropout(0.25,seed=1))
    model.add(Dense(64))
    model.add(Dropout(0.3,seed=1))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    return model

def model_fit(model, x, y,test_x,test_y):
    return model.fit(x, y, batch_size=100, epochs=2, validation_data=(test_x,test_y))

embed_size = 128
lstm_model=init_lstm_model(num_words, embed_size)
model_train=model_fit(lstm_model,tok_train,np.array(df[df.label=='train'].y),tok_test,np.array(df[df.label=='test'].y))
lstm_model.summary()

def model_fit(model, x, y,test_x,test_y):
    return model.fit(x, y, batch_size=100, epochs=2, validation_data=(test_x,test_y))

embed_size = 128
lstm_model=init_lstm_model(num_words, embed_size)
model_train=model_fit(lstm_model,tok_train,np.array(df[df.label=='train'].y),tok_test,np.array(df[df.label=='test'].y))
lstm_model.summary()

 

def ks_auc_value(y_value,y_pred):
    fpr,tpr,thresholds= roc_curve(list(y_value),list(y_pred))
    ks=max(tpr-fpr)
    auc= roc_auc_score(list(y_value),list(y_pred))
    return ks,auc

print('train_ks_auc',ks_auc_value(df[df.label=='train'].y,lstm_model.predict(tok_train)))
print('test_ks_auc',ks_auc_value(df[df.label=='test'].y,lstm_model.predict(tok_test)))

'''
    train_ks_auc (0.7223217797649937, 0.922939132379851)
    test_ks_auc (0.7046603930606234, 0.9140880065296716)
'''

4、LSTM+Attention文本分类

        在LSTM层之后添加Attention层优化效果。

from tensorflow.keras.models import Model
def init_lstm_model(max_features, embed_size ,embedding_matrix):
    input_=layers.Input(shape=(40,))
    x=Embedding(input_dim=max_features, output_dim=embed_size,weights=[embedding_matrix],trainable=False)(input_)
    x=Bidirectional(layers.LSTM(units=32,activation='relu', recurrent_dropout=0.1,return_sequences=True))(x)
    x=layers.Attention(40)([x,x])
    x=Dropout(0.25)(x)
    x=layers.Flatten()(x)
    x=Dense(64)(x)
    x=Dropout(0.3)(x)
    x=Dense(1,activation='sigmoid')(x)
    model = Model(inputs=input_, outputs=x)

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    return model

def model_fit(model, x, y,test_x,test_y):
    return model.fit(x, y, batch_size=100, epochs=5, validation_data=(test_x,test_y))

num_words,embed_size = 6000,128
lstm_model2=init_lstm_model(num_words, embed_size ,embedding_matrix)
model_train=model_fit(lstm_model2,tok_train,np.array(df[df.label=='train'].y),tok_test,np.array(df[df.label=='test'].y))

print('train_ks_auc',ks_auc_value(df[df.label=='train'].y,gru_model.predict(tok_train)))
print('test_ks_auc',ks_auc_value(df[df.label=='test'].y,gru_model.predict(tok_test)))
'''
    train_ks_auc (0.7126925954159541, 0.9199721561742299)
    test_ks_auc (0.7239373279559567, 0.917086274086166)
'''

三、划重点

少走10年弯路

        关注威信公众号 Python风控模型与数据分析,回复 文本分类5 获取本篇数据及代码

        还有更多理论、代码分享等你来拿

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1372775.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

66、python - 代码仓库介绍

上一节,我们可以用自己手写的算法以及手动搭建的神经网络完成预测了,不知各位同学有没有自己尝试来预测一只猫或者一只狗,看看准确度如何? 本节应一位同学的建议,来介绍下 python 代码仓库的目录结构,以及每一部分是做什么? 我们这个小课的代码实战仓库链接为:cv_lea…

springboot医院信管系统源码和论文

随着信息技术和网络技术的飞速发展&#xff0c;人类已进入全新信息化时代&#xff0c;传统管理技术已无法高效&#xff0c;便捷地管理信息。为了迎合时代需求&#xff0c;优化管理效率&#xff0c;各种各样的管理系统应运而生&#xff0c;各行各业相继进入信息管理时代&#xf…

【Linux 内核源码分析笔记】系统调用

在Linux内核中&#xff0c;系统调用是用户空间程序与内核之间的接口&#xff0c;它允许用户空间程序请求内核执行特权操作或访问受保护的内核资源。系统调用提供了一种安全可控的方式&#xff0c;使用户程序能够利用内核功能而不直接访问底层硬件。 系统调用&#xff1a; 通过…

代理IP连接不上/网速过慢?如何应对?

当您使用代理时&#xff0c;您可能会遇到不同的代理错误代码显示代理IP连不通、访问失败、网速过慢等种种问题。 在本文中中&#xff0c;我们将讨论您在使用代理IP时可能遇到的常见错误、发生这些错误的原因以及解决方法。 一、常见代理服务器错误 当您尝试访问网站时&#…

关于Geek软件的下载

直接百度搜geek出来的前几条似乎都是广告&#xff1a; 点进去之后是这个界面&#xff1a; 然后安装到最后一步提示要付费才能安装成功&#xff1a; 然后如果是用谷歌搜索&#xff1a; 有free版和pro版&#xff1a; free版下载之后压缩包解压就是exe不需要安装 综上&#xff0c…

金蝶EAS pdfviewlocal 任意文件读取漏洞

产品简介 金蝶EAS 为集团型企业提供功能全面、性能稳定、扩展性强的数字化平台&#xff0c;帮助企业链接外部产业链上下游&#xff0c;实现信息共享、风险共担&#xff0c;优化生态圈资源配置&#xff0c;构筑产业生态的护城河&#xff0c;同时打通企业内部价值链的数据链条&a…

【leetcode】力扣热门算法之K个一组翻转链表【困难】

题目描述 给你链表的头节点 head &#xff0c;每 k 个节点一组进行翻转&#xff0c;请你返回修改后的链表。 k 是一个正整数&#xff0c;它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍&#xff0c;那么请将最后剩余的节点保持原有顺序。 你不能只是单纯的改变节…

JS-基础语法(一)

JavaScript简单介绍 变量 常量 数据类型 类型转换 案例 1.JavaScript简单介绍 JavaScript 是什么&#xff1f; 是一种运行在客户端&#xff08;浏览器&#xff09;的编程语言&#xff0c;可以实现人机交互效果。 JS的作用 JavaScript的组成 JSECMAScript( 基础语法 )…

JavaSE 反射、枚举及Lambda的使用

目录 1 反射1.1 定义1.2 用途1.3 反射基本信息1.4 反射相关的类1.4.1 Class类(反射机制的起源 )1.4.1.1 Class类中的相关方法 1.4.2 反射示例1.4.2.1 获得Class对象的三种方式1.4.2.2 反射的使用 1.5 优缺点 2 枚举2.1 背景及定义2.2 使用2.3 优缺点2.4 枚举和反射2.5 总结2.6 …

调用导致堆栈不对称。原因可能是托管的 PInvoke 签名与非托管的目标签名不匹配。请检查 PInvoke 签名的调用约定和参数与非托管的目标签名是否匹配

作者推荐 【动态规划】C算法312 戳气球 关键字&#xff1a; 函数调用约定 混合编程 __stdcall c WINAPI APIENTRY _cdecl 调用方出错提示如下&#xff1a; 调用导致堆栈不对称。原因可能是托管的 PInvoke 签名与非托管的目标签名不匹配。请检查 PInvoke 签名的调用约定和参…

C++qt-信号-信号槽

1、概念 信号和槽是两种函数&#xff0c;这是Qt在C基础上新增的特性&#xff0c;类似于其他技术中的回调的概念。 信号和槽通过程序员提前设定的“约定”&#xff0c;可以实现对象之间的通信&#xff0c;有两个先决的条件&#xff1a; 通信的对象必须都是从QObject类中派生出来…

threejs 光带扩散动画

目录 一、创建光带 (1) 设置光带顶点 (2) 设置光带顶点透明度属性 二、光带动画 完整代码 html文件代码 js文件代码 最后展示一下项目里的效果&#xff1a; 最近项目中要求做一段光带效果动画&#xff0c;尝试着写了一下&#xff0c;下面是本次分享光带扩散动画的效果预…

地铁判官(外包)

到处都是说外包不好不好的&#xff0c;从没有想过自身问题。 例如&#xff1a; 技术人员动不动就是说&#xff0c;进了外包三天&#xff0c;一年&#xff0c;三年之后技术退步很多。就算你这样的人进了甲方&#xff0c;也是个渣渣。(声明一下&#xff0c;我也是外包&#xff0…

CMU15-445-Spring-2023-Project #2 - 前置知识(lec07-010)

Lecture #07_ Hash Tables Data Structures Hash Table 哈希表将键映射到值。它提供平均 O (1) 的操作复杂度&#xff08;最坏情况下为 O (n)&#xff09;和 O (n) 的存储复杂度。 由两部分组成&#xff1a; Hash Function和Hashing Scheme&#xff08;发生冲突后的处理&…

阿里云99元一年2核2G3M云服务器值得买吗?

阿里云作为国内领先的云服务提供商&#xff0c;一直致力于为用户提供优质、高效的服务。目前&#xff0c;阿里云推出的99元一年2核2G3M云服务器&#xff0c;更是引发了广大用户的关注。本文将详细解析这款云服务器的特点、优势以及适用场景&#xff0c;为大家上云提供参考。 一…

Android逆向学习(六)绕过app签名校验,通过frida,io重定向(上)

Android逆向学习&#xff08;六&#xff09;绕过app签名校验&#xff0c;通过frida&#xff0c;io重定向&#xff08;上&#xff09; 一、写在前面 这是吾爱破解正己大大教程的第五个作业&#xff0c;然后我的系统还是ubuntu&#xff0c;建议先看一下上一个博客&#xff0c;关…

阿赵UE学习笔记——8、贴图导入设置

阿赵UE学习笔记目录 大家好&#xff0c;我是阿赵。   继续学习虚幻引擎的用法&#xff0c;这次来说一下贴图的导入设置。   在内容浏览器里面可以看到纹理类型的资源&#xff0c;就是贴图了&#xff0c;鼠标悬浮在上面可以看到这个纹理贴图的信息&#xff1a; 双击纹理贴图…

使用Scikit Learn 进行识别手写数字

使用Scikit Learn 进行识别手写数字 作者&#xff1a;i阿极 作者简介&#xff1a;数据分析领域优质创作者、多项比赛获奖者&#xff1a;博主个人首页 &#x1f60a;&#x1f60a;&#x1f60a;如果觉得文章不错或能帮助到你学习&#xff0c;可以点赞&#x1f44d;收藏&#x1f…

虽迟但到!MySQL 可以用 JavaScript 写存储过程了!

任何能用 JavaScript 来干的事情&#xff0c;最终都会用 JavaScript 来干 背景 不久前&#xff0c;Oracle 在 MySQL 官方博客官宣了在 MySQL 中支持用 JavaScript 来写存储过程。 最流行的编程语言 最流行的数据库。程序员不做选择&#xff0c;当然是全都要。 使用方法 用 J…

压测必经之路,Jmeter分布式压测教程

01、分布式压测原理 Jemter分布式压测是选择其中一台作为调度机&#xff08;master&#xff09;&#xff0c;其他机器作为执行机&#xff08;slave&#xff09;&#xff1b;当然一台机器也可以既做调度机&#xff0c;也做执行机。 调度机执行脚本的时候&#xff0c;master将会…