LSTM简单介绍—然后使用LSTM对FashionMNIST数据集处理

news2024/11/27 0:45:52

文章目录

  • LSTM 简单介绍
    • LSTM的基本结构
    • LSTM的工作原理
      • 输入门
      • 遗忘门
      • 输出门
      • 细胞状态更新
      • 输出计算
    • 总结
    • 代码实例

LSTM 简单介绍

在自然语言处理、语音识别等领域,长短时记忆网络 (Long Short-Term Memory, LSTM) 已经成为了常用的模型之一。本文将介绍 LSTM 的基本结构和工作原理,帮助读者入门 LSTM。

LSTM的基本结构

LSTM 是一种递归神经网络(Recurrent Neural Network, RNN)。与传统 RNN 不同,LSTM 在网络的循环单元中引入了三个门控,以控制输入、输出和记忆的流程。一个 LSTM 单元的基本结构如下图所示:

请添加图片描述

图片来自博客:Pytorch入门之RNN_pytorch rnn_Ton10的博客-CSDN博客

其中, x t x_t xt 表示输入, h t h_t ht 表示输出, c t c_t ct 表示细胞状态(cell state)。 σ \sigma σ 表示 sigmoid 函数, ⊙ \odot 表示逐元素相乘。

从图中可以看到,LSTM 单元有三个门控:输入门、遗忘门和输出门。它们分别控制输入、遗忘和输出的流程。

LSTM的工作原理

下面我们将详细介绍 LSTM 的工作原理。在讲解之前,我们先定义一些符号:

  • x t x_t xt:时刻 t t t 的输入;
  • h t h_t ht:时刻 t t t 的输出;(输出t时刻的隐藏状态,该状态在最终时刻是LSTM系统的输出)
  • c t c_t ct:时刻 t t t 的细胞状态(cell state);
  • i t i_t it:时刻 t t t 的输入门(input gate);
  • f t f_t ft:时刻 t t t 的遗忘门(forget gate);
  • o t o_t ot:时刻 t t t 的输出门(output gate);
  • W W W:权重矩阵;
  • b b b:偏置向量;
  • σ ( x ) \sigma(x) σ(x):sigmoid 函数;
  • tanh ⁡ ( x ) \tanh(x) tanh(x):tanh 函数。

输入门

输入门控制着当前输入 x t x_t xt 对于细胞状态的影响程度。输入门的计算方式如下:
i t = σ ( W i x t + U i h t − 1 + b i ) i_t=\sigma(W_ix_t+U_ih_{t-1}+b_i) it=σ(Wixt+Uiht1+bi)
其中, W i W_i Wi U i U_i Ui 分别是输入 x t x_t xt 和上一时刻输出 h t − 1 h_{t-1} ht1 的权重矩阵, b i b_i bi 是输入门的偏置向量。 σ \sigma σ 表示 sigmoid 函数。

输入门的输出结果将与 c t ~ \tilde{c_t} ct~ 相乘,作为新的细胞状态的一部分。其中, c t ~ \tilde{c_t} ct~ 的计算方式如下:
c t ~ = t a n h ( W c x t + U c h t − 1 + b c ) \tilde{c_t}=\mathrm{tanh}(W_cx_t+U_ch_{t-1}+b_c) ct~=tanh(Wcxt+Ucht1+bc)
公式中的符号代表的含义如下:

  • x t x_t xt:表示当前时刻的输入,即该时刻的输入特征向量。
  • h t − 1 h_{t-1} ht1:表示上一个时刻的输出,即输出特征向量。
  • W i W_i Wi U i U_i Ui:分别是输入 x t x_t xt 和上一时刻输出 h t − 1 h_{t-1} ht1 的权重矩阵,用于计算输入门的输出 i t i_t it
  • b i b_i bi:表示输入门的偏置向量,用于计算输入门的输出 i t i_t it
  • σ \sigma σ:表示 sigmoid 函数,用于将输入门的输入值 W i x t + U i h t − 1 + b i W_ix_t+U_ih_{t-1}+b_i Wixt+Uiht1+bi 映射到一个介于 0 到 1 之间的范围内,表示当前输入 x t x_t xt 对于细胞状态的影响程度。
  • i t i_t it:表示输入门的输出,代表当前输入对于细胞状态的影响程度。
  • c t ~ \tilde{c_t} ct~:表示更新后的细胞状态的一部分,用于更新当前时刻的细胞状态。
  • W c W_c Wc U c U_c Uc:分别是输入 x t x_t xt 和上一时刻输出 h t − 1 h_{t-1} ht1 的权重矩阵,用于计算 c t ~ \tilde{c_t} ct~
  • b c b_c bc:表示 c t ~ \tilde{c_t} ct~ 的偏置向量。
  • t a n h \mathrm{tanh} tanh:表示双曲正切函数,用于将 c t ~ \tilde{c_t} ct~ 的输入值 W c x t + U c h t − 1 + b c W_cx_t+U_ch_{t-1}+b_c Wcxt+Ucht1+bc 映射到介于 -1 到 1 之间的范围内,表示当前时刻的输入和上一时刻的输出对于更新后的细胞状态的影响程度。

遗忘门

遗忘门控制着上一时刻的细胞状态 c t − 1 c_{t-1} ct1 对于当前细胞状态 c t c_t ct 的影响程度。遗忘门的计算方式如下:
f t = σ ( W f x t + U f h t − 1 + b f ) f_t=\sigma(W_fx_t+U_fh_{t-1}+b_f) ft=σ(Wfxt+Ufht1+bf)
其中, W f W_f Wf U f U_f Uf 分别是输入 x t x_t xt 和上一时刻输出 h t − 1 h_{t-1} ht1 的权重矩阵, b f b_f bf 是遗忘门的偏置向量。 σ \sigma σ 表示 sigmoid 函数。

遗忘门的输出结果将与上一时刻的细胞状态 c t − 1 c_{t-1} ct1 相乘,作为新的细胞状态的一部分。

输出门

输出门控制着当前细胞状态 c t c_t ct 对于输出 h t h_t ht 的影响程度。输出门的计算方式如下:
o t = σ ( W o x t + W o h t − 1 + b o ) o_t=\sigma(W_ox_t+W_oh_{t-1}+b_o) ot=σ(Woxt+Woht1+bo)
其中, W o W_o Wo U o U_o Uo 分别是输入 x t x_t xt 和上一时刻输出 h t − 1 h_{t-1} ht1 的权重矩阵, b o b_o bo 是输出门的偏置向量。 σ \sigma σ 表示 sigmoid 函数。

输出门的输出结果将与 tanh ⁡ ( c t ) \tanh(c_t) tanh(ct) 相乘,作为当前时刻的输出 h t h_t ht。其中, tanh ⁡ \tanh tanh 表示 tanh 函数。

细胞状态更新

在输入门、遗忘门和输出门的计算过程中,LSTM 还需要更新当前的细胞状态 c t c_t ct。细胞状态的更新方式如下:
c t = f t ⊙ c t − 1 + i t ⊙ c t ~ c_t=f_t\odot c_{t-1}+i_t\odot \tilde{c_t} ct=ftct1+itct~
其中, ⊙ \odot 表示逐元素相乘。 f t ⊙ c t − 1 f_t \odot c_{t-1} ftct1 表示上一时刻的细胞状态通过遗忘门传递到当前时刻, i t ⊙ c t ~ i_t \odot \tilde{c_t} itct~ 表示当前时刻的输入通过输入门影响到细胞状态。最终得到的 c t c_t ct 即为当前时刻的细胞状态。

输出计算

最后,LSTM 的输出计算方式如下:
h t = o t ⊙ t a n h ( c t ) h_t=o_t\odot \mathrm{tanh}(c_t) ht=ottanh(ct)
其中, o t o_t ot 表示输出门的输出结果, tanh ⁡ ( c t ) \tanh(c_t) tanh(ct) 表示细胞状态经过 tanh 函数处理后的结果。最终得到的 h t h_t ht 即为当前时刻的输出。

总结

LSTM 是一种递归神经网络,用于处理序列数据。LSTM 在循环单元中引入了三个门控,以控制输入、输出和记忆的流程。LSTM 的基本结构包括输入门、遗忘门、输出门和细胞状态,通过门控机制实现了对于输入、输出和记忆的精细控制。

希望这份 LSTM 的入门教程能够帮助你更好地理解 LSTM 的原理和运作方式。如果你想深入学习 LSTM,可以了解 LSTM 的变种结构,如 peephole LSTM、GRU 等,以及应用场景和实现细节。

同时,如果你想进一步了解深度学习,还可以学习其他类型的神经网络,如卷积神经网络、自编码器、生成对抗网络等。神经网络在计算机视觉、自然语言处理、语音识别等领域都得到了广泛的应用,是人工智能领域中不可或缺的技术之一。

代码实例

使用PyTorch框架写的LSTM网络,用于对FashionMNIST数据处理。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# 设置设备为GPU,只有当有可用GPU时才使用,否则使用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义LSTM模型
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        # 初始化隐藏状态和细胞状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        
        # 前向传播 LSTM
        out, _ = self.lstm(x, (h0, c0))  
        
        # 取最后一个时间步的输出
        out = self.fc(out[:, -1, :])
        return out

# 超参数
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 5

# 加载FashionMNIST数据集并进行预处理
train_dataset = torchvision.datasets.FashionMNIST(root='./data', 
                                                  train=True, 
                                                  transform=transforms.ToTensor(),
                                                  download=True)

test_dataset = torchvision.datasets.FashionMNIST(root='./data', 
                                                 train=False, 
                                                 transform=transforms.ToTensor())

# 数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

# 初始化模型并移动模型到GPU(如果可用)
model = LSTM(input_size, hidden_size, num_layers, num_classes).to(device)

# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # 将图像数据移动到GPU(如果可用)
        images = images.reshape(-1, 28, 28).to(device)
        labels = labels.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化器步骤
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 每隔100个批次打印一次训练信息
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# 测试模型
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        # 将图像数据移动到GPU(如果可用)
        images = images.reshape(-1, 28, 28).to(device)
        labels = labels.to(device)
        
        # 前向传播
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        
        # 统计预测正确的样本数和总样本数
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/460567.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

gpt在线使用-免费的 GPT在哪下载

免费的 GPT(Generative Pre-trained Transformer) 。现在您可以免费体验我们的 GPT 技术,来让您的业务或项目更加智能。 GPT 是一种基于最前沿的自然语言处理技术,它展现出了令人惊叹的预测能力和交互性能。我们的 GPT 是在世界顶…

警惕读书无用论,要知道一个人最可怕的就是精神世界的贫瘠和荒凉

孔乙已是鲁迅笔下人物,穷困流倒还穿着象征读书人的长衫,迁腐、麻木。最近,大家自我调佩是“当代孔乙己”,学历成为思想负担,找工作时高不成低不就。 一、社会对于学历和职业之间的关系认知是怎样的? 学历不…

Forefront GPT-4免费版:开启无限畅聊时代,乐享人工智能快感,无限制“白嫖”,还能和N多角色一起聊天?赶紧注册,再过些时间估计就要收费了

目录 前言注册登录方式应用体验聊天体验绘图体验 “是打算先免费后收费吗?”建议其它资料下载 前言 近期,人工智能技术迎来重大飞跃,OpenAI的ChatGPT等工具成为全球数亿人探索提高生产力和增强创造力的新方法。人们现在可以使用人工智能驱动…

绩效管理系统有哪些推荐?

绩效管理系统有哪些推荐?市面上的绩效管理系统五花八门,这就来给大家推荐几款优质的! 一、如何选择绩效管理系统 在选择绩效管理系统之前,需要先考虑以下几个问题: 了解你的企业目标和需求:在选择绩效管…

nacos注册中心替换成eureka

背景 项目使用的springcloud、nacos、redis等插件,但是nacos比较重,小项目使用不到,想用一个tomcat部署项目,所以准备用eureka替换nacos; eureka Eureak 是Netflix 开源微服务框架中一系列项目中的一个。Spring Clo…

JVM 垃圾收集器

一,常用的垃圾收集器 如果说收集算法是内存回收的方法论,那么垃圾收集器就是内存回收的具体实现。 如下图为年轻代和老年代的垃圾回收器,划线表示可以同时存在。 1,Serial Serial收集器是最基本、发展历史最悠久的收集器&…

怎么把录音文件转换成mp3格式,3个高效方法

在工作中,我们可能会选择录音来记录会议内容,以便之后整理会议纪要。但是我们知道录音文件的格式千差万别。比如在手机上录制的音频文件通常以M4A、WAV等多种格式存储,然而这些格式可能会存在不兼容的问题,导致我们无法在其他平台…

openEuler Developer Day 2023成功召开!发布嵌入式商业版本及多项成果

【中国,上海,2023年4月21日】openEuler Developer Day 2023于4月20-21日在线上和线下同步举办。本次大会由开放原子开源基金会指导,中国软件行业协会、openEuler社区、边缘计算产业联盟共同主办,以“万涓汇流,奔涌向前…

3DEXPERIENCE MODSIM产品前期概念结构快速开发方案(下) | 达索系统百世慧®

基于3DEXPERIENCE单一数据源、实时多专业协同平台、附加全新CATIA建模方法与MODSIM建模仿真一体化技术,助力产品设计与仿真效率提升,产品多学科性能提升,产品轻量化减重等,全方位赋能产品前期概念结构高效高质开发。 目录 达索系…

利用css实现视差滚动和抖动效果

背景: 前端的设计效果,越来越炫酷,而这些炫酷的效果,利用css3的动画效果和js就可以实现,简单的代码就能实现非常炫酷的效果。 原理: 利用 js监控scrollTop的位置,通过 top定位图片的位置&#x…

halcon灰度积分投影/垂直积分投影

简介:关于灰度投影积分可以用到的场合很多,例如分割字符,分割尺子上的刻度等,适用于有规律的变化这些内容的检测。本文复现了论文《基于深度学习和灰度纹理特征的铁路接触网绝缘子状态检测》中灰度积分投影实现了对绝缘子缺陷位置的检测。见(图1)灰度积分垂直方向投影获得…

JAVAWeb09-WEB 工程路径专题

1. 工程路径问题 先看一个问题 index.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>base 标签</title> </head> <body> <h1>注册用户~~</h1> <!--解读:1.…

创建型模式-建造者模式

建造者模式 概述 将一个复杂对象的构建与表示分离&#xff0c;使得同样的构建过程可以创建不同的表示 这个模式适用于&#xff1a;某个对象的构建过程复杂的情况 将部件的构造与装配分离&#xff0c;由 Builder 负责构造&#xff0c;Director 进行装配&#xff0c;实现了构…

LeetCode——新手村

目录 前言 一、一维数组的动态和 1、题目 2、代码 二、将数字变成 0 的操作次数 1、题目 2、代码 三、最富有客户的资产总量 1、题目 2、代码 四、Fizz Buzz 1、题目 2、代码 五、链表的中间结点 1、题目 2、代码 六、赎金信 1、题目 2、代码 前言 注册了一个LeetCode的…

10、Mysql常见面试题

Mysql常见面试题 文章目录 Mysql常见面试题一 Mysql索引001 Mysql如何实现的索引机制&#xff1f;002 InnoDB索引与MyISAM索引实现的区别是什么&#xff1f;003 一个表中如果没有创建索引&#xff0c;那么还会创建B树吗&#xff1f; 004 说一下B树索引实现原理&#xff08;数据…

毕业5年的同学突然告诉我,他已经是年薪30W的自动化测试工程师....

作为一名程序员&#xff0c;都会对自己未来的职业发展而焦虑。一方面是因为IT作为知识密集型的行业&#xff0c;知识体系复杂且知识更新速度非常快&#xff0c;“一日不学就会落后”。 另外一方面&#xff0c;IT又是劳动密集型的行业&#xff0c;不仅业人员多&#xff0c;而且个…

8个你可能不知道的令人震惊的 HTML 技巧

大厂面试题分享 面试题库 前后端面试题库 &#xff08;面试必备&#xff09; 推荐&#xff1a;★★★★★ 地址&#xff1a;前端面试题库 web前端面试题库 VS java后端面试题库大全 1. 捕获属性打开你的设备摄像头 正如 input 标记具有 email、 text 和 password 属性一样&…

Unity音量滑块沿弧形移动

一、音量滑块的移动 1、滑块在滑动的时候&#xff0c;其运动轨迹沿着大圆的弧边展开 2、滑块不能无限滑动&#xff0c;而是两端各有一个挡块&#xff0c;移动到挡块位置&#xff0c;则不能往下移动&#xff0c;但可以折回 3、鼠标悬停滑块时&#xff0c;给出音量值和操作提示 …

JMeter 获取登录接口的token

1、登录接口为POST请求方式&#xff0c;添加请求登录接口的消息体数据 添加HTTP信息头管理器&#xff0c;配置content-type值为application/json 2、给登录接口“添加监听器-查看结果树”和“后置处理器-正则表达式处理器” 先运行一次登录接口&#xff0c;通过查看结果树返回内…

C++三大特性—继承 “访问控制”

本文主要阐述关于C继承中基类与派生类之间的访问关系 继承方式与访问方式 继承定义格式&#xff1a; 派生类可以继承定义在基类的成员&#xff0c;但是派生类的成员函数不一定有权访问从基类继承来的成员    访问限定符的作用&#xff1a;控制派生类从基类继承而来的成员是否…