用deepseek学大模型08-长短时记忆网络 (LSTM)

news2025/2/21 5:16:49

deepseek.com 从入门到精通长短时记忆网络(LSTM),着重介绍的目标函数,损失函数,梯度下降 标量和矩阵形式的数学推导,pytorch真实能跑的代码案例以及模型,数据, 模型应用场景和优缺点,及如何改进解决及改进方法数据推导。

从入门到精通长短时记忆网络 (LSTM)

参考:长短时记忆网络(LSTM)在序列数据处理中的优缺点分析
LSTM


1. LSTM 核心机制

LSTM 通过门控机制(遗忘门、输入门、输出门)和细胞状态(Cell State)解决 RNN 的梯度消失问题。

核心公式(时间步 t t t):

  1. 遗忘门(Forget Gate):
    f t = σ ( W f [ h t − 1 , x t ] + b f ) \mathbf{f}_t = \sigma\left( \mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f \right) ft=σ(Wf[ht1,xt]+bf)
  2. 输入门(Input Gate):
    i t = σ ( W i [ h t − 1 , x t ] + b i ) \mathbf{i}_t = \sigma\left( \mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i \right) it=σ(Wi[ht1,xt]+bi)
    C ~ t = tanh ⁡ ( W C [ h t − 1 , x t ] + b C ) \tilde{\mathbf{C}}_t = \tanh\left( \mathbf{W}_C [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_C \right) C~t=tanh(WC[ht1,xt]+bC)
  3. 细胞状态更新
    C t = f t ⊙ C t − 1 + i t ⊙ C ~ t \mathbf{C}_t = \mathbf{f}_t \odot \mathbf{C}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{C}}_t Ct=ftCt1+itC~t
  4. 输出门(Output Gate):
    o t = σ ( W o [ h t − 1 , x t ] + b o ) \mathbf{o}_t = \sigma\left( \mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o \right) ot=σ(Wo[ht1,xt]+bo)
    h t = o t ⊙ tanh ⁡ ( C t ) \mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{C}_t) ht=ottanh(Ct)

2. 目标函数与损失函数
  • 目标函数:最小化预测与真实值的差异(监督学习)。
  • 损失函数(以分类任务交叉熵为例):
    L = − 1 T ∑ t = 1 T ∑ c = 1 C y ^ t , c log ⁡ ( y t , c ) L = -\frac{1}{T} \sum_{t=1}^T \sum_{c=1}^C \mathbf{\hat{y}}_{t,c} \log(\mathbf{y}_{t,c}) L=T1t=1Tc=1Cy^t,clog(yt,c)
    其中 C C C为类别数, y ^ \mathbf{\hat{y}} y^为真实标签的 one-hot 编码。

3. 梯度下降与数学推导

LSTM 的梯度反向传播通过细胞状态 C t \mathbf{C}_t Ct和门控机制稳定梯度流动。

标量形式推导(以遗忘门 f t \mathbf{f}_t ft为例):
∂ L ∂ f t = ∂ L ∂ h t ⋅ ∂ h t ∂ C t ⋅ ∂ C t ∂ f t \frac{\partial L}{\partial \mathbf{f}_t} = \frac{\partial L}{\partial \mathbf{h}_t} \cdot \frac{\partial \mathbf{h}_t}{\partial \mathbf{C}_t} \cdot \frac{\partial \mathbf{C}_t}{\partial \mathbf{f}_t} ftL=htLCthtftCt
其中:
∂ C t ∂ f t = C t − 1 ⊙ f t ⊙ ( 1 − f t ) \frac{\partial \mathbf{C}_t}{\partial \mathbf{f}_t} = \mathbf{C}_{t-1} \odot \mathbf{f}_t \odot (1 - \mathbf{f}_t) ftCt=Ct1ft(1ft)

矩阵形式推导(链式法则):
∂ L ∂ W f = ∑ t = 1 T ( δ f , t ⋅ [ h t − 1 , x t ] T ) \frac{\partial L}{\partial \mathbf{W}_f} = \sum_{t=1}^T \left( \delta_{f,t} \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t]^T \right) WfL=t=1T(δf,t[ht1,xt]T)
其中 δ f , t \delta_{f,t} δf,t为遗忘门的梯度误差:
δ f , t = ∂ L ∂ f t ⊙ σ ′ ( ⋅ ) \delta_{f,t} = \frac{\partial L}{\partial \mathbf{f}_t} \odot \sigma'(\cdot) δf,t=ftLσ()


4. PyTorch 代码案例
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# 数据生成:正弦波 + 噪声
time = torch.arange(0, 100, 0.1)
data = torch.sin(time) + 0.1 * torch.randn(len(time))

# 转换为序列数据(窗口长度=20)
def create_sequences(data, seq_length=20):
    X, y = [], []
    for i in range(len(data)-seq_length):
        X.append(data[i:i+seq_length])
        y.append(data[i+seq_length])
    return torch.stack(X).unsqueeze(-1), torch.stack(y).unsqueeze(-1)

X, y = create_sequences(data)
X_train, y_train = X[:800], y[:800]  # 划分训练集和测试集
X_test, y_test = X[800:], y[800:]

# 定义 LSTM 模型
class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, output_size=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        out, (h_n, c_n) = self.lstm(x)  # out: (batch, seq_len, hidden_size)
        out = self.fc(out[:, -1, :])    # 取最后一个时间步
        return out

model = LSTMModel()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练
epochs = 100
train_loss = []
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # 梯度裁剪
    optimizer.step()
    train_loss.append(loss.item())

# 可视化训练损失
plt.plot(train_loss)
plt.title("Training Loss")
plt.show()

# 预测
model.eval()
with torch.no_grad():
    train_pred = model(X_train)
    test_pred = model(X_test)

# 绘制结果
plt.figure(figsize=(12, 5))
plt.plot(data.numpy(), label="True Data")
plt.plot(range(20, 820), train_pred.numpy(), label="Train Predictions")
plt.plot(range(820, len(data)), test_pred.numpy(), label="Test Predictions")
plt.legend()
plt.show()

5. 应用场景与优缺点
  • 应用场景
    • 时间序列预测(股票价格、天气)
    • 自然语言处理(文本生成、机器翻译)
    • 语音识别
  • 优点
    • 解决长程依赖问题
    • 通过门控机制稳定梯度流动
    • 可处理变长序列
  • 缺点
    • 计算复杂度高(参数多)
    • 对短序列可能过拟合
    • 训练时间较长

6. 改进方法及数学推导
  1. GRU(门控循环单元)
    简化 LSTM,合并遗忘门和输入门:
    z t = σ ( W z [ h t − 1 , x t ] ) \mathbf{z}_t = \sigma(\mathbf{W}_z [\mathbf{h}_{t-1}, \mathbf{x}_t]) zt=σ(Wz[ht1,xt])
    r t = σ ( W r [ h t − 1 , x t ] ) \mathbf{r}_t = \sigma(\mathbf{W}_r [\mathbf{h}_{t-1}, \mathbf{x}_t]) rt=σ(Wr[ht1,xt])
    h ~ t = tanh ⁡ ( W [ r t ⊙ h t − 1 , x t ] ) \tilde{\mathbf{h}}_t = \tanh(\mathbf{W} [\mathbf{r}_t \odot \mathbf{h}_{t-1}, \mathbf{x}_t]) h~t=tanh(W[rtht1,xt])
    h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t \mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t ht=(1zt)ht1+zth~t

  2. 双向 LSTM(Bi-LSTM)
    同时捕捉前向和后向依赖:
    h t → = LSTM ( x t , h t − 1 → ) \overrightarrow{\mathbf{h}_t} = \text{LSTM}(\mathbf{x}_t, \overrightarrow{\mathbf{h}_{t-1}}) ht =LSTM(xt,ht1 )
    h t ← = LSTM ( x t , h t + 1 ← ) \overleftarrow{\mathbf{h}_t} = \text{LSTM}(\mathbf{x}_t, \overleftarrow{\mathbf{h}_{t+1}}) ht =LSTM(xt,ht+1 )
    h t = [ h t → , h t ← ] \mathbf{h}_t = [\overrightarrow{\mathbf{h}_t}, \overleftarrow{\mathbf{h}_t}] ht=[ht ,ht ]

  3. 注意力机制
    增强对关键时间步的关注:
    α t = softmax ( v T tanh ⁡ ( W h h t + W s s ) ) \alpha_t = \text{softmax}(\mathbf{v}^T \tanh(\mathbf{W}_h \mathbf{h}_t + \mathbf{W}_s \mathbf{s})) αt=softmax(vTtanh(Whht+Wss))
    c = ∑ t = 1 T α t h t \mathbf{c} = \sum_{t=1}^T \alpha_t \mathbf{h}_t c=t=1Tαtht


7. 关键改进的数学验证(以 GRU 为例)
  • 梯度稳定性
    GRU 的更新门 z t \mathbf{z}_t zt控制历史信息的保留比例,梯度可沿两条路径传播:
    ∂ h t ∂ h t − 1 = ( 1 − z t ) + z t ⊙ ∂ h ~ t ∂ h t − 1 \frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-1}} = (1 - \mathbf{z}_t) + \mathbf{z}_t \odot \frac{\partial \tilde{\mathbf{h}}_t}{\partial \mathbf{h}_{t-1}} ht1ht=(1zt)+ztht1h~t
    避免传统 RNN 的连乘梯度。

通过上述内容,您可全面掌握 LSTM 的理论基础、实际实现及优化方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2301094.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(蓝桥杯——10. 小郑做志愿者)洛斯里克城志愿者问题详解

题目背景 小郑是一名大学生,她决定通过做志愿者来增加自己的综合分。她的任务是帮助游客解决交通困难的问题。洛斯里克城是一个六朝古都,拥有 N 个区域和古老的地铁系统。地铁线路覆盖了树形结构上的某些路径,游客会询问两个区域是否可以通过某条地铁线路直达,以及有多少条…

小胡说技书博客分类(部分目录):服务治理、数据治理与安全治理对比表格

文章目录 一、对比表格二、目录2.1 服务2.2 数据2.3 安全 一、对比表格 下表从多个维度对服务治理、数据治理和安全治理进行详细对比,为读者提供一个直观而全面的参考框架。 维度服务治理数据治理安全治理定义对软件开发全流程、应用交付及API和接口管理进行规范化…

开源模型应用落地-DeepSeek-R1-Distill-Qwen-7B-LoRA微调-LLaMA-Factory-单机单卡-V100(一)

一、前言 如今,大语言模型领域热闹非凡,各种模型不断涌现。DeepSeek-R1-Distill-Qwen-7B 模型凭借其出色的效果和性能,吸引了众多开发者的目光。而 LLaMa-Factory 作为强大的微调工具,能让模型更好地满足个性化需求。 在本篇中&am…

uni-app发起网络请求的三种方式

uni.request(OBJECT) 发起网络请求 具体参数可查看官方文档uni-app data:请求的参数; header:设置请求的 header,header 中不能设置 Referer; method:请求方法; timeout:超时时间,单位 ms&a…

EasyRTC:智能硬件适配,实现多端音视频互动新突破

一、智能硬件全面支持,轻松跨越平台障碍 EasyRTC 采用前沿的智能硬件适配技术,无缝对接 Windows、macOS、Linux、Android、iOS 等主流操作系统,并全面拥抱 WebRTC 标准。这一特性确保了“一次开发,多端运行”的便捷性&#xff0c…

LeetCode1287

LeetCode1287 目录 题目描述示例思路分析代码段代码逐行讲解复杂度分析总结的知识点整合总结 题目描述 给定一个非递减的整数数组 arr,其中有一个元素恰好出现超过数组长度的 25%。请你找到并返回这个元素。 示例 示例 1 输入: arr [1, 2, 2, 6, 6, 6, 6, 7,…

深度学习笔记之自然语言处理(NLP)

深度学习笔记之自然语言处理(NLP) 在行将开学之时,我将开始我的深度学习笔记的自然语言处理部分,这部分内容是在前面基础上开展学习的,且目前我的学习更加倾向于通识。自然语言处理部分将包含《动手学深度学习》这本书的第十四章&#xff0c…

自动化测试框架搭建-单次接口执行-三部曲

目的 判断接口返回值和提前设置的预期是否一致,从而判断本次测试是否通过 代码步骤设计 第一步:前端调用后端已经写好的POST接口,并传递参数 第二步:后端接收到参数,组装并请求指定接口,保存返回 第三…

DeepSeek R1生成图片总结2(虽然本身是不能直接生成图片,但是可以想办法利用别的工具一起实现)

DeepSeek官网 目前阶段,DeepSeek R1是不能直接生成图片的,但可以通过优化文本后转换为SVG或HTML代码,再保存为图片。另外,Janus-Pro是DeepSeek的多模态模型,支持文生图,但需要本地部署或者使用第三方工具。…

ESP32 ESP-IDF TFT-LCD(ST7735 128x160) LVGL基本配置和使用

ESP32 ESP-IDF TFT-LCD(ST7735 128x160) LVGL基本配置和使用 📍项目地址:https://github.com/lvgl/lv_port_esp32参考文章:https://blog.csdn.net/chentuo2000/article/details/126668088https://blog.csdn.net/p1279030826/article/details/…

【笔记】LLM|Ubuntu22服务器极简本地部署DeepSeek+联网使用方式

2025/02/18说明:2月18日~2月20日是2024年度博客之星投票时间,走过路过可以帮忙点点投票吗?我想要前一百的实体证书,经过我严密的计算只要再拿到60票就稳了。一人可能会有多票,Thanks♪(・ω・)&am…

Linux的基础指令和环境部署,项目部署实战(下)

目录 上一篇:Linxu的基础指令和环境部署,项目部署实战(上)-CSDN博客 1. 搭建Java部署环境 1.1 apt apt常用命令 列出所有的软件包 更新软件包数据库 安装软件包 移除软件包 1.2 JDK 1.2.1. 更新 1.2.2. 安装openjdk&am…

数值积分:通过复合梯形法计算

在物理学和工程学中,很多问题都可以通过数值积分来求解,特别是当我们无法得到解析解时。数值积分是通过计算积分区间内离散点的函数值来近似积分的结果。在这篇博客中,我将讨论如何使用 复合梯形法 来进行数值积分,并以一个简单的…

【Java计算机毕业设计】基于SSM+VUE保险公司管理系统数据库源代码+LW文档+开题报告+答辩稿+部署教程+代码讲解

源代码数据库LW文档(1万字以上)开题报告答辩稿 部署教程代码讲解代码时间修改教程 一、开发工具、运行环境、开发技术 开发工具 1、操作系统:Window操作系统 2、开发工具:IntelliJ IDEA或者Eclipse 3、数据库存储&#xff1a…

C#之上位机开发---------C#通信库及WPF的简单实践

〇、上位机,分层架构 界面层 要实现的功能: 展示数据 获取数据 发送数据 数据层 要实现的功能: 转换数据 打包数据 存取数据 通信层 要实现的功能: 打开连接 关闭连接 读取数据 写入数据 实体类 作用: 封装数据…

仿 Sora 之形,借物理模拟之技绘视频之彩

来自麻省理工学院、斯坦福大学、哥伦比亚大学以及康奈尔大学的研究人员携手开源了一款创新的3D交互视频模型——PhysDreamer(以下简称“PD”)。PD与OpenAI旗下的Sora相似,能够借助物理模拟技术来生成视频,这意味着PD所生成的视频蕴…

RedisTemplate存储含有特殊字符解决

ERROR信息: 案发时间: 2025-02-18 01:01 案发现场: UserServiceImpl.java 嫌疑人: stringRedisTemplate.opsForValue().set(SystemConstants.LOGIN_CODE_PREFIX phone, code, Duration.ofMinutes(3L)); // 3分钟过期作案动机: stringRedisTemplate继承了Redistemplate 使用的…

Django REST Framework (DRF) 中用于构建 API 视图类解析

Django REST Framework (DRF) 提供了丰富的视图类,用于构建 API 视图。这些视图类可以分为以下几类: 1. 基础视图类 这些是 DRF 中最基础的视图类,通常用于实现自定义逻辑。 常用类 APIView: 最基本的视图类,所有其…

Zotero PDF Translate插件配置百度翻译api

Zotero PDF Translate插件可以使用几种翻译api,虽然谷歌最好用,但是由于众所周知的原因,不稳定。而cnki有字数限制,有道有时也不行。其他的翻译需要申请密钥。本文以百度为例,进行申请 官方有申请教程: Zot…

Redis离线安装

Linux系统Centos安装部署Redis缓存插件 参考:Redis中文网: https://www.redis.net.cn/ 参考:RPM软件包下载地址: https://rpmfind.net/linux/RPM/index.html http://rpm.pbone.net/ https://mirrors.aliyun.com/centos/7/os…