深度学习--LSTM网络、使用方法、实战情感分类问题

news2024/11/17 9:48:14

1.LSTM基础

长短期记忆网络(Long Short-Term Memory,简称LSTM),是RNN的一种,为了解决RNN存在长期依赖问题而设计出来的。

LSTM的基本结构:

网络图

2.LSTM的具体说明

LSTM与RNN的结构相比,在参数更新的过程中,增加了三个门,由左到右分别是遗忘门(也称记忆门)、输入门、输出门。

图片来源:

LSTM的工作原理究竟是什么?深入了解LSTM-电子发烧友网

1.点乘操作决定多少信息可以传送过去,当为0时,不传送;当为1时,全部传送。

2.1 遗忘门

对于输入xt和ht-1,遗忘门会输出一个值域为[0, 1]的数字,放进Ct−1中。当为0时,全部删除;当为1时,全部保留。

遗忘门

2.2 输入门

对于对于输入xt和ht-1,输入门会选择信息的去留,并且通过tanh激活函数更新临时Ct

输入门

通过遗忘门和输入门输出累加,更新最终的Ct

更新Ct

2.3输出门

通过Ct和输出门,更新memory

输出门

3.PyTorch的LSTM使用方法

  1. __ init __(input _ size, hidden_size,num _layers)

  2. LSTM.foward():

​ out,[ht,ct] = lstm(x,[ht-1,ct-1])

​ x:[一句话单词数,batch几句话,表示的维度]

​ h/c:[层数,batch,记忆(参数)的维度]

​ out:[一句话单词数,batch,参数的维度]

 
import torch
import torch.nn as nn
lstm = nn.LSTM(input_size = 100,hidden_size = 20,num_layers = 4)
print(lstm)
#LSTM(100, 20, num_layers=4)
x = torch.randn(10,3,100)
out,(h,c)=lstm(x)
print(out.shape,h.shape,c.shape)
#torch.Size([10, 3, 20]) torch.Size([4, 3, 20]) torch.Size([4, 3, 20])

单层使用方法:

 
cell = nn.LSTMCell(input_size = 100,hidden_size=20)
x = torch.randn(10,3,100)
h = torch.zeros(3,20)
c = torch.zeros(3,20)
for xt in x:
h,c = cell(xt,[h,c])
print(h.shape,c.shape)
#torch.Size([3, 20]) torch.Size([3, 20])

LSTM实战--情感分类问题

Google CoLab环境,需要魔法。

 
import torch
from torch import nn, optim
from torchtext import data, datasets
print('GPU:', torch.cuda.is_available())
torch.manual_seed(123)
TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField(dtype=torch.float)
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
print('len of train data:', len(train_data))
print('len of test data:', len(test_data))
print(train_data.examples[15].text)
print(train_data.examples[15].label)
# word2vec, glove
TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')
LABEL.build_vocab(train_data)
batchsz = 30
device = torch.device('cuda')
train_iterator, test_iterator = data.BucketIterator.splits(
(train_data, test_data),
batch_size = batchsz,
device=device
)
class RNN(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim):
"""
"""
super(RNN, self).__init__()
# [0-10001] => [100]
self.embedding = nn.Embedding(vocab_size, embedding_dim)
# [100] => [256]
self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
bidirectional=True, dropout=0.5)
# [256*2] => [1]
self.fc = nn.Linear(hidden_dim*2, 1)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
"""
x: [seq_len, b] vs [b, 3, 28, 28]
"""
# [seq, b, 1] => [seq, b, 100]
embedding = self.dropout(self.embedding(x))
# output: [seq, b, hid_dim*2]
# hidden/h: [num_layers*2, b, hid_dim]
# cell/c: [num_layers*2, b, hid_di]
output, (hidden, cell) = self.rnn(embedding)
# [num_layers*2, b, hid_dim] => 2 of [b, hid_dim] => [b, hid_dim*2]
hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
# [b, hid_dim*2] => [b, 1]
hidden = self.dropout(hidden)
out = self.fc(hidden)
return out
rnn = RNN(len(TEXT.vocab), 100, 256)
pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')
optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
criteon = nn.BCEWithLogitsLoss().to(device)
rnn.to(device)
import numpy as np
def binary_acc(preds, y):
"""
get accuracy
"""
preds = torch.round(torch.sigmoid(preds))
correct = torch.eq(preds, y).float()
acc = correct.sum() / len(correct)
return acc
def train(rnn, iterator, optimizer, criteon):
avg_acc = []
rnn.train()
for i, batch in enumerate(iterator):
# [seq, b] => [b, 1] => [b]
pred = rnn(batch.text).squeeze(1)
#
loss = criteon(pred, batch.label)
acc = binary_acc(pred, batch.label).item()
avg_acc.append(acc)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i%10 == 0:
print(i, acc)
avg_acc = np.array(avg_acc).mean()
print('avg acc:', avg_acc)
def eval(rnn, iterator, criteon):
avg_acc = []
rnn.eval()
with torch.no_grad():
for batch in iterator:
# [b, 1] => [b]
pred = rnn(batch.text).squeeze(1)
#
loss = criteon(pred, batch.label)
acc = binary_acc(pred, batch.label).item()
avg_acc.append(acc)
avg_acc = np.array(avg_acc).mean()
print('>>test:', avg_acc)
for epoch in range(10):
eval(rnn, test_iterator, criteon)
train(rnn, train_iterator, optimizer, criteon)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/920674.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

聊一下最近有个网红铁头惩恶扬善举报新东方校外补课引起争议

最近有个网红,铁头打假,举报新东方校外补课上了热搜,引起了争议 最近他自己在一次直播带货当中,翻车了的 铁头敢端了学生的课桌,家长就敢掀了他的直播间 而因自己,我不用读书,我有社会经验&…

双碳目标下DNDC模型教程

详情点击链接:双碳目标下DNDC模型建模方法及在土壤碳储量、温室气体排放、农田减排、土地变化、气候变化中的实践技术应用教程 前沿 碳循环的精确模拟是实现“双碳”行动的关键。DNDC(Denitrification-Decomposition,反硝化-分解模型&#…

前端三部曲之一HTML

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

同样都是手机卡,为什么线下的手机卡和线上的手机卡差距这么大?

大家好,我是搜卡之家,今天这篇文章就带大家了解一下线上流量卡和线下流量卡有哪些区别? ​ 众所周知,如果我们在营业厅办理流量业务,30G的流量不管是哪个运营商可能就需要将近100块钱,是为什么线上申请的流…

简单聊聊uniapp和uview组件库一起开发

简单的聊聊uniapp和uview组件库的开发 uniapp是一个基于Vue.js的跨平台开发框架,可以同时开发H5、微信小程序、App等多个平台的应用。这样可以减少开发人员的工作量,提高开发效率。 官网:https://uniapp.dcloud.net.cn/ uView是uni-app生态…

hive-列转行

转成 select customer_code,product_type from temp.temp_xx LATERAL VIEW explode(SPLIT(product_types,,)) table_tmp AS product_type where customer_code K100515182

DNS解析中的A记录、AAAA记录、CNAME记录、MX记录、NS记录、TXT记录、SRV记录、URL转发等

DNS解析中的A记录、AAAA记录、CNAME记录、MX记录、NS记录、TXT记录、SRV记录、URL转发等 1. DNS域名解析中添加的各项解析记录2. DNS解析中一些问题简要的介绍DNS 的 SOA记录:参考资料 域名注册完成后首先需要做域名解析,域名解析就是把域名指向网站所在…

数组习题答案

基础题目 第一题:需求实现 模拟大乐透号码: 一组大乐透号码由10个1-99之间的数字组成定义方法,打印大乐透号码信息 代码实现,效果如图所示: 开发提示: 使用数组保存录入的号码 参考答案: p…

浅析三维模型OBJ格式轻量化处理常见问题与处理措施

浅析三维模型OBJ格式轻量化处理常见问题与处理措施 在三维模型OBJ格式轻量化处理过程中,可能会遇到一些问题。以下是一些常见问题以及相应的解决方法: 1、文件大小过大: OBJ格式的三维模型文件通常包含大量的顶点、面片和纹理信息&#xff0…

【Windows iTunes】Windows 10 下如何不通过 Microsoft Store 下载 iTunes,Apple 官网直链下载,图文教程

目录 写在前头(解决办法)图文教程  第一步 搜索  第二步 下载 写在前头(解决办法) 在 Apple 官网(https://www.apple.com.cn/)搜索“ iTunes 下载 ”,进入下载页面(https://www.…

【深入理解Linux内核锁】四、自旋锁

我的圈子: 高级工程师聚集地 我是董哥,高级嵌入式软件开发工程师,从事嵌入式Linux驱动开发和系统开发,曾就职于世界500强企业! 创作理念:专注分享高质量嵌入式文章,让大家读有所得! 文章目录 1、什么是自旋锁?2、自旋锁思想3、自旋锁的定义及实现3.1 API接口3.2 API实…

LAMP配置与应用

目录 一、LAMP架构的组成 1、WEB资源类型 2、LAMP架构的组成 二、编译安装LAMP 编译安装apache 1、环境准备 2、导入apache相关压缩安装包,然后安装编译环境 3、解压软件包,并移动apr包与apr-util包到安装目录中,并切换到http解压出…

小米汽车开启工人招聘:年产能30万辆,雷军造车目标又迈进一步

新浪科技报道指出,小米汽车近日已经开启了工人招聘,包括涂装操作工、电池车间操作工等多个岗位。这表明小米汽车即将进入生产阶段,也进一步证实了此前关于小米获得造车资质的传闻。 根据小米此前给出的时间表,在2024年上半年&…

Java——一个简单的油耗计算机程序

该代码是一个简单的油耗计算机程序,使用了Java的图形化界面库Swing。具体分析如下: 导入必要的类和包: import javax.swing.*; import java.awt.*;代码中导入了用于创建图形界面的类和其他必要的类。 定义main类: public class f…

基于java+swing贪吃蛇小游戏

基于javaswing贪吃蛇小游戏 一、系统介绍二、效果展示三、其他系统实现四、获取源码 一、系统介绍 项目类型:Java SE项目 项目名称:基于Java的贪吃蛇小游戏(snake_game) 当前版本:V1.0.0版本 运行工具:Eclipse/MyE…

JDBC详解

文章目录 一、引言1.1 如何操作数据库1.2 实际开发中,会采用客户端操作数据库吗? 二、JDBC(Java Database Connectivity)2.1 什么是 JDBC?2.2 JDBC 核心思想2.2.1 MySQL 数据库驱动2.2.2 JDBC API 2.3 环境搭建 三、JD…

【Unity自制手册】游戏基础API大全

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:Uni…

无人机空管电台-中大型无人机远程VHF语音电台系统

方案背景 中大型无人机在执行飞行任务时,特别是在管制空域飞行时地面航管人员需要通过语音与无人机通信。按《无人驾驶航空器飞行管理暂行条例》规定,中大型无人机应当进行适航管理。物流无人机和载人eVTOL都将进行适航管理,所以无人机也要有…

水经微图网页版基础名词

水经微图网页版,可轻松将关注的地点制作成您的个人地图。 您可以在任意位置添加标注点或绘制地图,查找地点并将其保存到您的地图中,或导入地图数据迅速制作地图并保存,您还可以运用图标和颜色展示个性风采,从而可让每…

ACM模式(基础输入输出)

import java.lang.*; import java.util.*; public class Main{public static void main(String[] args){Scanner in new Scanner(System.in);while(in.hasNextInt()){//下一行是否有数据int ain.nextInt();int bin.nextInt();System.out.println(ab);}} }Java方法间的调用 http…