pytorch实现长短期记忆网络 (LSTM)

news2025/2/3 7:42:56

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

LSTM 通过 记忆单元(cell)三个门控机制(遗忘门、输入门、输出门)来控制信息流:

 记忆单元(Cell State)

  • 负责存储长期信息,并通过门控机制决定保留或丢弃信息。

 遗忘门(Forget Gate, ftf_tft​)

 输入门(Input Gate, iti_tit​)

 输出门(Output Gate, oto_tot​)

特性

传统 RNNLSTM
记忆能力短期记忆长短期记忆
计算复杂度
解决梯度消失
适用场景短序列数据长序列数据

LSTM 应用场景

  • 自然语言处理(NLP):文本生成、情感分析、机器翻译
  • 时间序列预测:股票预测、天气预报、传感器数据分析
  • 语音识别:自动字幕生成、语音转文字(ASR)
  • 机器人与控制系统:智能体决策、自动驾驶

例子:

下面例子实现了一个 基于 LSTM 的强化学习智能体,在 1D 网格环境 里移动,并找到最优路径。
最终,我们 绘制 5 条测试路径,并高亮显示最佳路径(红色)

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


# ========== 1. 定义 LSTM 策略网络 ==========
class LSTMPolicy(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(LSTMPolicy, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, hidden_state):
        batch_size = x.size(0)

        # 确保 hidden_state 维度正确
        if hidden_state[0].dim() == 2:
            hidden_state = (hidden_state[0].unsqueeze(1).repeat(1, batch_size, 1),
                            hidden_state[1].unsqueeze(1).repeat(1, batch_size, 1))

        out, hidden_state = self.lstm(x, hidden_state)
        out = self.fc(out[:, -1, :])  # 取最后时间步的输出
        action_prob = self.softmax(out)  # 归一化输出,作为策略
        return action_prob, hidden_state

    def init_hidden(self, batch_size=1):
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size))


# ========== 2. 创建网格环境 ==========
class GridWorld:
    def __init__(self, grid_size=10, goal_position=9):
        self.grid_size = grid_size
        self.goal_position = goal_position
        self.reset()

    def reset(self):
        self.position = 0
        return self.position

    def step(self, action):
        if action == 0:
            self.position = max(0, self.position - 1)
        elif action == 1:
            self.position = min(self.grid_size - 1, self.position + 1)

        reward = 1 if self.position == self.goal_position else -0.1
        done = self.position == self.goal_position
        return self.position, reward, done


# ========== 3. 训练智能体 ==========
def train(num_episodes=500, max_steps=50):
    env = GridWorld()
    input_size = 1
    hidden_size = 64
    output_size = 2
    num_layers = 1

    policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)
    optimizer = optim.Adam(policy.parameters(), lr=0.01)
    gamma = 0.99

    for episode in range(num_episodes):
        state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)
        hidden_state = policy.init_hidden(batch_size=1)

        log_probs = []
        rewards = []

        for step in range(max_steps):
            action_probs, hidden_state = policy(state, hidden_state)
            action = torch.multinomial(action_probs, 1).item()
            log_prob = torch.log(action_probs.squeeze(0)[action])
            log_probs.append(log_prob)

            next_state, reward, done = env.step(action)
            rewards.append(reward)

            if done:
                break

            state = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)

        # 计算回报并更新策略
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)

        returns = torch.tensor(returns, dtype=torch.float32)
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)

        loss = sum([-log_prob * R for log_prob, R in zip(log_probs, returns)])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (episode + 1) % 50 == 0:
            print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {sum(rewards)}")

    torch.save(policy.state_dict(), "policy.pth")


# 训练智能体
train(500)


# ========== 4. 测试智能体并绘制最佳路径 ==========
def test(num_episodes=5):
    env = GridWorld()
    input_size = 1
    hidden_size = 64
    output_size = 2
    num_layers = 1

    policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)
    policy.load_state_dict(torch.load("policy.pth"))

    plt.figure(figsize=(10, 5))
    best_path = None
    best_steps = float('inf')

    for episode in range(num_episodes):
        state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)
        hidden_state = policy.init_hidden(batch_size=1)
        positions = [env.position]  # 记录位置变化

        while True:
            action_probs, hidden_state = policy(state, hidden_state)
            action = torch.argmax(action_probs, dim=-1).item()
            next_state, reward, done = env.step(action)
            positions.append(next_state)

            if done:
                break

            state = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)

        # 记录最佳路径(最短步数)
        if len(positions) < best_steps:
            best_steps = len(positions)
            best_path = positions

        # 绘制普通路径(蓝色)
        plt.plot(range(len(positions)), positions, marker='o', linestyle='-', color='blue', alpha=0.6,
                 label=f'Episode {episode + 1}' if episode == 0 else "")

    # 绘制最佳路径(红色)
    if best_path:
        plt.plot(range(len(best_path)), best_path, marker='o', linestyle='-', color='red', linewidth=2,
                 label="Best Path")

    # 打印最佳路径
    print(f"Best Path (steps={best_steps}): {best_path}")

    plt.xlabel("Time Steps")
    plt.ylabel("Agent Position")
    plt.title("Agent's Movement Path (Best Path in Red)")
    plt.legend()
    plt.grid(True)
    plt.show()


# 测试并绘制智能体移动路径
test(5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2291157.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Games104——引擎工具链高级概念与应用

世界编辑器 其实是一个平台&#xff08;hub&#xff09;&#xff0c;集合了所有能够制作地形世界的逻辑 editor viewport&#xff1a;可以说是游戏引擎的特殊视角&#xff0c;会有部分editor only的代码&#xff08;不小心开放就会变成外挂入口&#xff09;Editable Object&…

消息队列应用示例MessageQueues-STM32CubeMX-FreeRTOS《嵌入式系统设计》P343-P347

消息队列 使用信号量、事件标志组和线标志进行任务同步时&#xff0c;只能提供同步的时刻信息&#xff0c;无法在任务之间进行数据传输。要实现任务间的数据传输&#xff0c;一般使用两种方式&#xff1a; 1. 全局变量 在 RTOS 中使用全局变量时&#xff0c;必须保证每个任务…

网络攻防实战指北专栏讲解大纲与网络安全法

专栏 本专栏为网络攻防实战指北&#xff0c;大纲如下所示 进度&#xff1a;目前已更完准备篇、HTML基础 计划&#xff1a;所谓基础不牢&#xff0c;地动山摇。所以下一步将持续更新基础篇内容 讲解信息安全时&#xff0c;结合《中华人民共和国网络安全法》&#xff08;以下简…

Spark的基本概念

个人博客地址&#xff1a;Spark的基本概念 | 一张假钞的真实世界 编程接口 RDD&#xff1a;弹性分布式数据集&#xff08;Resilient Distributed Dataset &#xff09;。Spark2.0之前的编程接口。Spark2.0之后以不再推荐使用&#xff0c;而是被Dataset替代。Dataset&#xff…

效用曲线的三个实例

效用曲线的三个实例 文章目录 效用曲线的三个实例什么是效用曲线风险与回报&#xff1a;投资决策消费选择&#xff1a;价格与质量的平衡程序员绩效评估&#xff1a;准时与程序正确性 分析- 风险与回报&#xff1a;投资决策分析- 消费选择&#xff1a;价格与质量的平衡- 程序员绩…

neo4j-community-5.26.0 create new database

1.edit neo4j.conf 把 # The name of the default database initial.dbms.default_databasehonglouneo4j # 写上自己的数据库名称 和 # Name of the service #5.0 server.windows_service_nameneo4j #4.0 dbms.default_databaseneo4j #dbms.default_databaseneo4jwind serve…

pytorch实现门控循环单元 (GRU)

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 特性GRULSTM计算效率更快&#xff0c;参数更少相对较慢&#xff0c;参数更多结构复杂度只有两个门&#xff08;更新门和重置门&#xff09;三个门&#xff08;输入门、遗忘门、输出门&#xff09;处理长时依赖一般适…

有没有个性化的UML图例

绿萝小绿萝 (53****338) 2012-05-10 11:55:45 各位大虾&#xff0c;有没有个性化的UML图例 绿萝小绿萝 (53****338) 2012-05-10 11:56:03 例如部署图或时序图的图例 潘加宇 (35***47) 2012-05-10 12:24:31 "个性化"指的是&#xff1f; 你的意思使用你自己的图标&…

Vue3.0实战:大数据平台可视化

文章目录 创建vue3.0项目项目初始化项目分辨率响应式设置项目顶部信息条创建页面主体创建全局引入echarts和axios后台接口创建express销售总量图实现完整项目下载项目任何问题都可在评论区,或者直接私信即可。 创建vue3.0项目 创建项目: vue create vueecharts选择第三项:…

洛谷 P1130 红牌 C语言

题目描述 某地临时居民想获得长期居住权就必须申请拿到红牌。获得红牌的过程是相当复杂&#xff0c;一共包括 N 个步骤。每一步骤都由政府的某个工作人员负责检查你所提交的材料是否符合条件。为了加快进程&#xff0c;每一步政府都派了 M 个工作人员来检查材料。不幸的是&…

语音识别播报人工智能分类垃圾桶(论文+源码)

2.1 需求分析 本次语音识别播报人工智能分类垃圾桶&#xff0c;设计功能要求如下∶ 1、具有四种垃圾桶&#xff0c;分别为用来回收厨余垃圾&#xff0c;有害垃圾&#xff0c;可回收垃圾&#xff0c;其他垃圾。 2、当用户语音说出“旧报纸”&#xff0c;“剩菜”等特定词语时…

MVC、MVP和MVVM模式

MVC模式中&#xff0c;视图和模型之间直接交互&#xff0c;而MVP模式下&#xff0c;视图与模型通过Presenter进行通信&#xff0c;MVVM则采用双向绑定&#xff0c;减少手动同步视图和模型的工作。每种模式都有其优缺点&#xff0c;适合不同规模和类型的项目。 ### MVVM 与 MVP…

shiro学习五:使用springboot整合shiro。在前面学习四的基础上,增加shiro的缓存机制,源码讲解:认证缓存、授权缓存。

文章目录 前言1. 直接上代码最后在讲解1.1 新增的pom依赖1.2 RedisCache.java1.3 RedisCacheManager.java1.4 jwt的三个类1.5 ShiroConfig.java新增Bean 2. 源码讲解。2.1 shiro 缓存的代码流程。2.2 缓存流程2.2.1 认证和授权简述2.2.2 AuthenticatingRealm.getAuthentication…

属性编程与权限编程

问题 如何获取文件的大小&#xff0c;时间戳以及类型等信息&#xff1f; 再论 inode 文件的物理载体是硬盘&#xff0c;硬盘的最小存储单元是扇区 (每个扇区 512 字节) 文件系统以 块 为单位(每个块 8 个扇区) 管理文件数据 文件元信息 (创建者、创建日期、文件大小&#x…

用 HTML、CSS 和 JavaScript 实现抽奖转盘效果

顺序抽奖 前言 这段代码实现了一个简单的抽奖转盘效果。页面上有一个九宫格布局的抽奖区域&#xff0c;周围八个格子分别放置了不同的奖品名称&#xff0c;中间是一个 “开始抽奖” 的按钮。点击按钮后&#xff0c;抽奖区域的格子会快速滚动&#xff0c;颜色不断变化&#xf…

R语言绘制有向无环图(DAG)

有向无环图&#xff08;Directed Acyclic Graph&#xff0c;简称DAG&#xff09;是一种特殊的有向图&#xff0c;它由一系列顶点和有方向的边组成&#xff0c;其中不存在任何环路。这意味着从任一顶点出发&#xff0c;沿着箭头方向移动&#xff0c;你永远无法回到起始点。 从流…

Spring Web MVC基础第一篇

目录 1.什么是Spring Web MVC&#xff1f; 2.创建Spring Web MVC项目 3.注解使用 3.1RequestMapping&#xff08;路由映射&#xff09; 3.2一般参数传递 3.3RequestParam&#xff08;参数重命名&#xff09; 3.4RequestBody&#xff08;传递JSON数据&#xff09; 3.5Pa…

129.求根节点到叶节点数字之和(遍历思想)

Problem: 129.求根节点到叶节点数字之和 文章目录 题目描述思路复杂度Code 题目描述 思路 遍历思想(利用二叉树的先序遍历) 直接利用二叉树的先序遍历&#xff0c;将遍历过程中的节点值先利用字符串拼接起来遇到根节点时再转为数字并累加起来&#xff0c;在归的过程中&#xf…

unity中的动画混合树

为什么需要动画混合树&#xff0c;动画混合树有什么作用&#xff1f; 在Unity中&#xff0c;动画混合树&#xff08;Animation Blend Tree&#xff09;是一种用于管理和混合多个动画状态的工具&#xff0c;包括1D和2D两种类型&#xff0c;以下是其作用及使用必要性的介绍&…

MySQL存储过程和存储函数_mysql 存储过 call proc_stat_data(3,null)

2&#xff09;很难调试存储过程。只有少数数据库管理系统允许调试存储过程。不幸的是&#xff0c;MySQL不提供调试存储过程的功能。 1.2 数据准备 创建数据库&#xff1a; DEFAULT CHARACTER SET utf8; use test;这里记得设置编码&#xff01; 创建测试表&#xff1a; DROP…