深度学习——LSTM解决分类问题

news2024/11/28 0:33:46

RNN基本介绍

概述

循环神经网络(Recurrent Neural Network,RNN)是一种深度学习模型,主要用于处理序列数据,如文本、语音、时间序列等具有时序关系的数据。

核心思想

RNN的关键思想是引入了循环结构,允许信息在网络内部进行传递。与传统的前馈神经网络(Feedforward Neural Network)不同,RNN在处理序列数据时会保留并利用先前的信息来影响后续的输出。

基本结构

RNN的基本结构是一个被称为“循环单元”(recurrent unit)的模块,它接收输入和先前的隐藏状态,并生成输出和新的隐藏状态。循环单元中的权重参数在时间步之间是共享的,这意味着它可以对序列中的不同位置应用相同的操作。

计算过程

RNN在每个时间步的计算过程如下:
1.接收当前时间步的输入和先前时间步的隐藏状态。
2.使用这些输入和隐藏状态计算当前时间步的输出。
3.更新隐藏状态,以便在下一个时间步使用。

优点

由于RNN具有循环结构,它可以在处理序列数据时保持记忆,并捕捉到序列中的长期依赖关系。这使得RNN在许多任务中表现出色,例如语言建模、机器翻译、语音识别、情感分析等。

缺点

然而,传统的RNN在处理长期依赖时存在梯度消失或梯度爆炸的问题,导致难以捕捉到远距离的依赖关系。

LSTM基本介绍

概述

LSTM(Long Short-Term Memory,长短期记忆网络)是一种循环神经网络(RNN)的改进型结构,用于解决传统RNN中的长期依赖问题。相比于传统的RNN,LSTM引入了门控机制,能够更好地捕捉和处理序列数据中的长期依赖关系。

核心思想

LSTM的核心思想是引入了三个门控单元:输入门(Input Gate)、遗忘门(Forget Gate)和输出门(Output Gate)。这些门控单元允许LSTM网络选择性地保留或丢弃信息,并且在传递信息时能够有效地控制梯度的流动。

基本结构

以下是LSTM中各个门控单元的功能:
1.输入门(Input Gate):决定当前时间步的输入信息中哪些部分需要被记忆。它使用sigmoid函数来产生一个0到1之间的值,描述了每个输入的重要性。
2.遗忘门(Forget Gate):决定之前的隐藏状态中哪些信息需要被遗忘。通过使用sigmoid函数,遗忘门可以控制先前的隐藏状态在当前时间步的重要性。
3.输出门(Output Gate):根据当前时间步的输入和之前的隐藏状态,决定应该输出多少信息到下一个时间步。输出门使用sigmoid函数来控制隐藏状态中的信息量,并使用tanh函数来生成当前时间步的输出。

优点

通过使用这些门控单元,LSTM网络能够在处理序列数据时灵活地控制信息的流动和记忆的保留。这使得LSTM能够更好地处理长期依赖关系,并在各种序列建模任务中表现出色,例如机器翻译、语音识别、文本生成等。

代码与详细注释

import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# 可复现
# torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 1               # train the training data n times, to save time, we just train 1 epoch
# 批大小
BATCH_SIZE = 64
TIME_STEP = 28          # rnn time step / image height
INPUT_SIZE = 28         # rnn input size / image width
LR = 0.01               # learning rate
DOWNLOAD_MNIST = True   # set to True if haven't download the data


# Mnist digital dataset
train_data = dsets.MNIST(
    root='./mnist/',
    train=True,                         # this is training data
    transform=transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                        # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
    download=DOWNLOAD_MNIST,            # download it if you don't have it
)

# plot one example
print(train_data.train_data.size())     # (60000, 28, 28)
print(train_data.train_labels.size())   # (60000)
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show()

# Data Loader for easy mini-batch return in training
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# convert test data into Variable, pick 2000 samples to speed up testing
test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
test_x = test_data.test_data.type(torch.FloatTensor)[:2000]/255.   # shape (2000, 28, 28) value in range(0,1)
test_y = test_data.test_labels.numpy()[:2000]    # covert to numpy array


class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.LSTM(         # if use nn.RNN(), it hardly learns
            input_size=INPUT_SIZE,
            hidden_size=64,         # rnn hidden unit
            num_layers=1,           # number of rnn layer
            batch_first=True,       # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)
        )

        self.out = nn.Linear(64, 10)

    def forward(self, x):
        # 输入向量的形状
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size)
        # h_c shape (n_layers, batch, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)   # None represents zero initial hidden state

        # choose r_out at the last time step
        # 选择输出最后一步的r_out
        out = self.out(r_out[:, -1, :])
        return out


rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()                       # the target label is not one-hotted

# training and testing
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):        # gives batch data
        b_x = b_x.view(-1, 28, 28)              # reshape x to (batch, time_step, input_size)

        output = rnn(b_x)                               # rnn output
        loss = loss_func(output, b_y)                   # cross entropy loss
        optimizer.zero_grad()                           # clear gradients for this training step
        loss.backward()                                 # backpropagation, compute gradients
        optimizer.step()                                # apply gradients

        # 每训练50步之后,测试一下准确度
        if step % 50 == 0:
            test_output = rnn(test_x)                   # (samples, time_step, input_size)
            pred_y = torch.max(test_output, 1)[1].data.numpy()
            accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)

# print 10 predictions from test data
test_output = rnn(test_x[:10].view(-1, 28, 28))
pred_y = torch.max(test_output, 1)[1].data.numpy()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')



运行结果

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/759603.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaBeans

Code eamples ① Product.java (JavaBean Class) ② Bean.java (Servlet) ③ complie javac -encoding utf-8 -d ..\classes -sourcepath . chapter15\Bean.java ④ Tomcat ⑤ http://localhost:8080/book/chapter15/bean

flink水位线传播及任务事件时间

背景 本文来讲解一下flink的水位线传播及对其对任务事件时间的影响 水位线 首先flink是通过从源头生成水位线记录的方式来实现水位线传播的,也就是说水位线是嵌入在正常的记录流中的特殊记录,携带者水位线的时间戳,以下我们就通过图片的方…

Docker常用命令(三)

1、镜像命令 1、列出本地主机上的镜像 docker images [options]optiins说明: -a:列出本地所有的镜像(包含历史映像层) -q:只显示镜像ID2、搜索某个镜像信息 docker search [options] 镜像名字3、下载镜像 docker …

Kafka第二课-代码实战、参数配置详解、设计原理详解

一、代码实战 一、普通java程序实战 引入依赖 <dependencies><dependency><groupId>org.apache.kafka</groupId><artifactId>kafka-clients</artifactId><version>2.4.1</version></dependency><dependency>&l…

windows环境hadoop报错‘D:\Program‘ 不是内部或外部命令,也不是可运行的程序 或批处理文件。

Hadoop版本为2.7.3&#xff0c;在环境配置好后&#xff0c;检查hadoop安装版本&#xff0c;报如标题所示错误&#xff0c;尝试网上主流的几种方法均无效。 错误&#xff1a;windows环境hadoop报错’D:\Program’ 不是内部或外部命令,也不是可运行的程序 或批处理文件。 错误方…

Jackson使用

导入依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0…

静态输出调节

1.理论部分 15. SISO反馈控制器设计 (6)&#xff1a;输出调节-静态反馈 Output Regulation - 知乎 (zhihu.com) 上述博客已经写的很好了&#xff0c;过多描述 2.仿真实验 3.参考理论 首先通过18式求解出X和U&#xff0c;然后设计一个让原系统初始稳定的控制律Kx&#xff0c;…

第二周习题

2.创建类MyDate,year属性和month属性,编写一个方法totalDays,该方法通过年份和月份判断该月一共有多少天,在主函数中接受用户输入年和月&#xff0c;调用该方法测试它. 这里考虑平年和闰年 “平年2月有28天。闰年的2月有29天 那么就有区别了 只要判断这一点就行了&#xff01;&…

框架开发使用注解处理器APT优雅提效

目录 概述1.什么是注解处理器APT2.应用场景3.如何使用3.1 创建注解API模块3.2 创建注解处理器模块3.3 使用注解 概述 在现在的很多开源框架中&#xff0c;我们经常能在源码中看到注解处理器的影子&#xff0c;比如我们熟悉的阿里的ARouter,Android开发中的替代findViewById神器…

【git】git以及可视化界面下载安装

git 以及可视化界面下载安装 git下载安装测试功能 sourceTree下载安装 git 下载安装 下载地址 git官网上有多个版本&#xff0c;点击“Click here to download” &#xff0c;下载下来之后&#xff0c;一直下一步安装即可 测试功能 在任意文件夹中右击&#xff0c;看到图中…

Linux三剑客

前言 关于bash&#xff1a; bash&#xff1a;命令处理器&#xff0c;运行在文本窗口&#xff0c;能够执行用户输入的命令。 脚本&#xff1a;从linux文件中读取命令&#xff0c;被称为脚本。 1 命令&#xff1a;alias&#xff1a;起别名 2 快捷键操作&#xff1a; ctrla&#…

浅谈如何提高自动化测试的稳定性和可维护性

目录 前言&#xff1a; 装饰器与出错重试机制 什么是装饰器&#xff1f; 编写一个出错重试装饰器 pytest 里的出错重试机制实现 Allure 里的测试用例分层 为什么要采用分层机制&#xff1f; allure 的装饰器step 前言&#xff1a; 自动化测试在软件开发中扮演着重要的…

Fortran lapack求数组的特征值,特征向量

call zgeev(V, V, n, arr, lda, w, vl, ldvl, vr, ldvr, work, lwork, rwork, info) 这个函数是求矩阵的特征值&#xff0c;且结果是双精度复数的情况&#xff0c;具体可以查MKL的官方文档。 如果是单精度复数就要用cgeev&#xff0c;其中的参数也是将双精度改为单精度即可。…

Hive,FineBI-30W聊天数据分析及可视化-B站黑马学习记录

2023B站黑马Hadoop、Hive、云平台实战项目 目录 1. 清洗数据 2. 计算各指标&#xff0c;并创建表存储结果 3.FineBI连接Hive数据库&#xff0c;将指标结果可视化 1. 清洗数据 1&#xff09;部分数据缺失地理位置信息&#xff08;sender_gps&#xff09;&#xff0c;需要剔…

Linux下的调试器——gdb使用指南

文章目录 一.序二.安装gdb调试器三.进入调试四.调试相关指令 前言&#xff1b; 在VS环境下&#xff0c;我们不仅可以写代码、编译、运行可执行程序&#xff0c;还可以对生成的可执行程序进行调试。本章我们就来学习如何在Linux环境下进行调试。 一.序 要进行调试&#xff0c;首…

EPICS一个示例数据库实例详解

以下是一个示例数据库图表&#xff1a; 以上记录的数据库文件如下&#xff1a; record(ao, "$(P):SET") {field(FLNK, "$(P):ACTIVATE")field(VAL, "2")field(OUT, "$(P):RUN")field(DRVH, "40")field(DRVL, "5"…

摩尔投票算法(Moore‘s Voting Algorithm)及例题

摩尔投票算法&#xff08;Moores Voting Algorithm&#xff09;及例题 摩尔投票算法简介摩尔投票算法算法思想摩尔投票算法经典题目169. 多数元素229. 多数元素 II6927. 合法分割的最小下标 上午打力扣第 354 场周赛最后十五分钟用摩尔投票算法顺利 AC 第三题&#xff0c;以前没…

ViewRootImpl简析

ViewRootImpl简析 如何实现视图和wms沟通桥梁的作用Session的创建获取画布如何实现事件分发的桥梁作用 The top of a view hierarchy, implementing the needed protocol between View and the WindowManager. This is for the most part an internal implementation detail of…

快速搭建Python(Django)+Vue环境并实现页面登录注册功能

文章目录 一. 创建vue项目及环境搭建1. 创建vue项目2. 配置axios3. 创建vue组件login和register4. 设置并引用路由vue-router5. 完成login&#xff0c;register组件代码6. 完成App.vue的代码 二. 创建django项目及环境搭建1. 创建django项目2.配置mysql数据库3. 创建应用app4.创…

学习babylon.js --- [4] 体验WebVR

本文基于babylonjs来创建一个简单的WebVR工程&#xff0c;来体验一下babylonjs带来的VR效果&#xff0c;由于没有VR头显&#xff0c;所以只能使用Win10自带的混合现实模拟器&#xff0c;开启模拟器请参考这篇文章 一 简单工程 本文基于第三篇文章中的工程进行修改&#xff0c;…