如何构建LSTM神经网络模型

news2025/1/24 13:17:48

一、了解LSTM

1. 核心思想

        首先,LSTM 是 RNN(循环神经网络)的变体。它通过引入细胞状态 C(t) 贯穿于整个网络模型,达到长久记忆的效果,进而解决了 RNN 的长期依赖问题。

2. 思维导图

        每个LSTM层次都有三个重要的门结构,从前往后依次是遗忘门(forget gate layer)、输入门(input gate layer)、输出门(output gate layer)。

        还有两个重要的状态,分别是细胞状态(cell state)、隐藏状态(hidden state),即图示中的 C(t) 和 h(t) 。其中细胞状态不仅记忆某个时间步的信息,而是对整个时间序列保持较为稳定的记忆,是一种长期 “记忆信息” 。对于隐藏状态来说,它更多地关注当前时间步以及上一个时间步的输出,是一种短期 “记忆信息”

        具体内容如下面思维导图所示:


二、利用pytorch构建LSTM

1. 构造神经网络模型

1.1 LSTM层
self.lstm = nn.LSTM(
    input_size=28,  # 每次输入特征数量为28
    hidden_size=64,  # 表示每个时间步的输出会有 64 个特征
    num_layers=1,  # LSTM隐藏层的层数
    batch_first=True  # 输入数据的格式是“批次在第一位”
)
  • input_size: 这告诉模型,每次输入的数据有多少个特征(比如一张28x28像素的图像,每一行就是一个时间步)。也就是图示中的 x(t) 。
  • hidden_size:这是模型的“记忆”大小。即细胞状态C(t) 和隐藏状态 h(t) 的容量。
  • num_layers:等于1则代表只使用一层 LSTM 网络。
  • batch_first:这个参数表示输入数据的维度格式是(批次,时间步、特征数),即批次在第一维。
1.2 全连接层
self.out = nn.Linear(
    in_features=64,
    out_features=10  # 将LSTM层提取到的64个特征进一步转化为10个输出(0~9)
)
  • in_features:全连接层的输入大小,来自LSTM的输出,每个时间步的特征数是64(即 hidden_size )
  • out_features:全连接层的输出大小是10,通常表示有10个类别。
 1.3 Softmax层
self.softmax = nn.Softmax(dim=1)

        这一层主要是将全连接层的输出转化为概率分布。如果使用的是交叉熵代价函数(CrossEntropyLoss),可以不加这层。

2. 前向传播

  1. 在前面LSTM层中batch_first参数设置了输入数据的维度格式,即(批次,时间步、特征数)。所以首先要做的就是调整输入的维度格式。这里每个样本是 28 个时间步,每个时间步有 28 个特征(像是一个28x28的图像)

    x = x.view(-1, 28, 28)
  2. 让输入数据通过LSTM层,并最终输出三个信息,分别是 output,h_n 和 c_n。output 包含了每个时间步的输出信息(理解为LSTM分析每个时间步得到的结果)。h_n 是最后一个时间步的隐藏状态,c_n 是记忆状态。我们重点关注 h_n,因为它代表了 LSTM 在处理完所有时间步后的总结。

    output, (h_n, c_n) = self.lstm(x)
    
  3. 接下来从隐藏状态中拿到最后一个时间步 h_n 的输出 output_in_last_timestep。可以理解为,LSTM看完了所有时间步之后,得到了它对整个序列的理解。

    output_in_last_timestep = h_n[-1, :, :]
    
  4. 最后LSTM的输出被送到全连接层,转化成10个数字,这些数字代表模型对每个类别的预测分数。并通过Softmax转化为概率。

    x = self.out(output_in_last_timestep)
    x = self.softmax(x)
    

        构造好的LSTM神经网络模型代码如下所示: 

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(
            input_size=28,  # 每次输入特征数量
            hidden_size=64,  
            num_layers=1,  # LSTM隐藏层的层数
            batch_first=True  
        )
        self.out = nn.Linear(
            in_features=64,
            out_features=10  # 将LSTM层提取到的64个特征进一步转化为10个输出(0~9)
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 28, 28)  # 将输入调整成一个 (批次大小, 时间步数, 特征数) 的形式
        output, (h_n, c_n) = self.lstm(x)
        output_in_last_timestep = h_n[-1, :, :]  # 从隐藏状态中拿到最后一个时间步的输出
        x = self.out(output_in_last_timestep)  # LSTM的输出被送到全连接层,转化成10个数字
        x = self.softmax(x)  
        return x

三、测试 LSTM 神经网络模型

        用MNIST数据集测试代码如下:

# 训练集
train_dataset = datasets.MNIST(root='./',
                               train=True,
                               transform=transforms.ToTensor(),  # 数据转换为张量格式
                               download=True)
# 测试集
test_dataset = datasets.MNIST(root='./',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# 批次大小
batch_size = 100
# 装载训练集
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,  # 每次加载多少条数据
                          shuffle=True)  # 生成数据前打乱数据

# 装载测试集
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True)

LR = 0.001  # 学习率
model = LSTM()  # 模型
crossEntropy_loss = nn.CrossEntropyLoss()  # 交叉熵代价函数
optimizer = optim.Adam(model.parameters(), LR)


def train():
    model.train()
    for i, data in enumerate(train_loader):
        inputs, labels = data  # 获得一个批次的数据和标签
        out = model(inputs)  # 获得模型预测输出(64张图像,10个数字的概率)
        loss = crossEntropy_loss(out, labels)  # 使用交叉熵损失函数时,可以直接使用整型标签,无须独热编码
        optimizer.zero_grad()  # 梯度清0
        loss.backward()  # 计算梯度
        optimizer.step()  # 修改权值


def test():
    model.eval()
    correct = 0
    for i, data in enumerate(test_loader):
        inputs, labels = data  # 获得一个批次的数据和标签
        out = model(inputs)  # 获得模型预测结构(64,10)
        _, predicted = torch.max(out, 1)  # 获得最大值,以及最大值所在位置
        correct += (predicted == labels).sum()  # 判断64个值有多少是正确的
    print("测试集正确率:{}\n".format(correct.item() / len(test_loader)))


# 训练20个周期
for epoch in range(20):
    print("Epoch:{}".format(epoch))
    train()
    test()

        测试结果: 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2192574.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

贝尔曼公式

为什么return 非常重要 在选择哪个策略更好的时候,此时需要使用到return,比如下面三个策略的返回值。 策略1: 策略2:策略3:涉及到两个policys path How to calculate return 定义 上图定义了不同的起点下的return value 递推…

优化销售漏斗建立高效潜在客户生成策略的技巧

如何建立有效的潜在客户生成策略?建立有效潜在客户生成策略需要准确定义目标受众,利用内容营销、SEO、社交媒体、邮件营销和定向广告吸引客户,参加行业会议并跟踪分析数据。借助Zoho CRM系统,企业能够更加高效地管理客户信息&…

Windows上 minGW64 编译 libssh2库

下载libssh2库:https://libssh2.org/download/libssh2-1.11.0.zip 继续下载OpenSSL库: https://codeload.github.com/openssl/openssl/zip/refs/heads/OpenSSL_1_0_2-stable

算法讲解—最小生成树(Kruskal 算法)

算法讲解—最小生成树(Kruskal 算法) 简介 根据度娘的解释我们可以知道,最小生成树(Minimum Spanning Tree, MST)就是:一个有 n n n 个结点的连通图的生成树是原图的极小连通子图,且包含原图中的所有 n n n 个结点…

【Diffusion分割】CTS:基于一致性的医学图像分割模型

CTS: A Consistency-Based Medical Image Segmentation Model 摘要: 在医学图像分割任务中,扩散模型已显示出巨大的潜力。然而,主流的扩散模型存在采样次数多、预测结果慢等缺点。最近,作为独立生成网络的一致性模型解决了这一问…

【Python】数据可视化之聚类图

目录 clustermap 主要参数 参考实现 clustermap sns.clustermap是Seaborn库中用于创建聚类热图的函数,该函数能够将数据集中的样本按照相似性进行聚类,并将聚类结果以矩阵的形式展示出来。 sns.clustermap主要用于绘制聚类热图,该热图通…

云计算第四阶段 CLOUD2周目 01-03

国庆假期前,给小伙伴们更行完了云计算CLOUD第一周目的内容,现在为大家更行云计算CLOUD二周目内容,内容涉及K8S组件的添加与使用,K8S集群的搭建。最重要的主体还是资源文件的编写。 (*^▽^*) 环境准备: 主机清单 主机…

CUDNN下载配置

目录 简介 下载 配置 简介 cuDNN(CUDA Deep Neural Network library)是NVIDIA开发的一个深度学习GPU加速库,旨在提供高效、标准化的原语(基本操作)来加速深度学习框架(如TensorFlow、PyTorch等&#xf…

Rust 快速入门(一)

Rust安装信息解释 cargo:Rust的编译管理器、包管理器、通用工具。可以用Cargo启动新的项目,构建和运行程序,并管理代码所依赖的所有外部库。 Rustc:Rust的编译器。通常Cargo会替我们调用此编译器。 Rustdoc:是Rust的…

Java 面向对象设计一口气讲完![]~( ̄▽ ̄)~*(上)

目录 Java 类实例 Java面向对象设计 - Java类实例 null引用类型 访问类的字段的点表示法 字段的默认初始化 Java 访问级别 Java面向对象设计 - Java访问级别 Java 导入 Java面向对象设计 - Java导入 单类型导入声明 按需导入声明 静态导入声明 例子 Java 方法 J…

decltype推导规则

decltype推导规则 当用decltype(e)来获取类型时,编译器将依序判断以下四规则: 1.如果e是一个没有带括号的标记符表达式(id-expression)或者类成员访问表达式,那么decltype(e)就是e所命名的实体的类型。此外,如果e是一个被重载的函…

k8s 之安装metrics-server

作者:程序那点事儿 日期:2024/01/29 18:25 metrics-server可帮助我们查看pod的cpu和内存占用情况 kubectl top po nginx-deploy-56696fbb5-mzsgg # 报错,需要Metrics API 下载 Metrics 解决 wget https://github.com/kubernetes-sigs/metri…

基于auth2的单点登录原理理解

创作背景:基于auth2实现企业门户与业务系统的单点登录跳转。 架构组成:4A统一认证中心,门户系统,业务系统,用户; 实现目标:用户登录门户系统后,可通过点击业务系统菜单&#xff0c…

字符串数学专题

粗心的小可 题目描述 小可非常粗心,打字的时候将手放到了比正确位置偏右的一个位置,因此,Q打成了W,E打成了R,H打成了J等等。键盘如下所示 现在给你若干行小可打字的结果,请你还原成正确的文本。 输入描述…

嵌入式面试八股文(五)·一文带你详细了解程序内存分区中的堆与栈的区别

目录 1. 栈的工作原理 1.1 内存分配 1.2 地址生长方向 1.3 生命周期 2. 堆的工作原理 2.1 动态内存分配 2.1.1 malloc函数 2.1.2 calloc函数 2.1.3 realloc函数 2.1.4 free函数 2.2 生命周期管理 2.3 地址生长方向 3. 堆与栈区别 3.1 管理方式不同…

海南聚广众达电子商务咨询有限公司助力商家业绩飙升

在这个短视频与直播风靡的时代,抖音电商无疑成为了众多商家竞相追逐的新风口。作为电商服务领域的佼佼者,海南聚广众达电子商务咨询有限公司凭借其专业的团队、创新的策略与丰富的实战经验,正引领着一批又一批商家在抖音平台上破浪前行&#…

顺序表及其代码实现

目录 前言1.顺序表1.1 顺序表介绍1.2 顺序表基本操作代码实现 总结 前言 顺序表一般不会用来单独存储数据,但自身的优势,很多时候不得不使用顺序表。 1.顺序表 1.1 顺序表介绍 顺序表是物理结构连续的线性表,支持随机存取(底层…

Leetcode—139. 单词拆分【中等】

2024每日刷题&#xff08;173&#xff09; Leetcode—139. 单词拆分 dp实现代码 class Solution { public:bool wordBreak(string s, vector<string>& wordDict) {int n s.size();unordered_set<string> ust(wordDict.begin(), wordDict.end());vector<b…

探索基于基于人工智能进行的漏洞评估的前景

根据2023年的一份报告 网络安全企业据估计&#xff0c;到 10.5 年&#xff0c;网络犯罪每年将给世界造成 2025 万亿美元的损失。每年记录在案的网络犯罪数量都会创下新高。这要求对传统的安全测试流程进行重大改变。这就是漏洞评估发挥作用的地方。 漏洞评估对于识别系统中的弱…

双指针_有效三角形个数三数之和四数之和

有效三角形个数 思路&#xff1a; 我们可以通过暴力枚举&#xff0c;三重for循环来算但&#xff0c;时间复杂度过高。 有没有效率更高的算法呢&#xff1f; 我们知道如果两条较短的边小于最长的一条边&#xff0c;那么就可以构成三角形。 如果这个数组是升序的&#xff0c;两…