pytorch实现门控循环单元 (GRU)

news2025/2/3 7:18:06

 人工智能例子汇总:AI常见的算法和例子-CSDN博客  

特性GRULSTM
计算效率更快,参数更少相对较慢,参数更多
结构复杂度只有两个门(更新门和重置门)三个门(输入门、遗忘门、输出门)
处理长时依赖一般适用于中等长度依赖更适合处理超长时序依赖
训练速度训练更快,梯度更稳定训练较慢,占用更多内存

例子:

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt

# 🏁 迷宫环境(5×5)
class MazeEnv:
    def __init__(self, size=5):
        self.size = size
        self.state = (0, 0)  # 起点
        self.goal = (size-1, size-1)  # 终点
        self.actions = [(0,1), (0,-1), (1,0), (-1,0)]  # 右、左、下、上
    
    def reset(self):
        self.state = (0, 0)  # 重置起点
        return self.state

    def step(self, action):
        dx, dy = self.actions[action]
        x, y = self.state
        nx, ny = max(0, min(self.size-1, x+dx)), max(0, min(self.size-1, y+dy))
        
        reward = 1 if (nx, ny) == self.goal else -0.1
        done = (nx, ny) == self.goal
        
        self.state = (nx, ny)
        return (nx, ny), reward, done

# 🤖 GRU 策略网络
class GRUPolicy(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUPolicy, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        out, hidden = self.gru(x, hidden)
        out = self.fc(out[:, -1, :])  # 只取最后时间步
        return out, hidden

# 🎯 训练参数
env = MazeEnv(size=5)
policy = GRUPolicy(input_size=2, hidden_size=16, output_size=4)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# 🎓 训练
num_episodes = 500
epsilon = 1.0  # 初始的ε值,控制探索的概率
epsilon_min = 0.01  # 最小ε值
epsilon_decay = 0.995  # ε衰减率
best_path = []  # 用于存储最佳路径

for episode in range(num_episodes):
    state = env.reset()
    hidden = torch.zeros(1, 1, 16)  # GRU 初始状态
    states, actions, rewards = [], [], []
    logits_list = []  

    for _ in range(20):  # 最多 20 步
        state_tensor = torch.tensor([[state[0], state[1]]], dtype=torch.float32).unsqueeze(0)
        logits, hidden = policy(state_tensor, hidden)
        logits_list.append(logits)

        # ε-greedy 策略
        if random.random() < epsilon:
            action = random.choice(range(4))  # 随机选择动作
        else:
            action = torch.argmax(logits, dim=1).item()  # 选择最大值对应的动作

        next_state, reward, done = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)

        if done:
            print(f"Episode {episode} - Reached Goal!")
            # 找到最优路径
            best_path = states + [next_state]  # 当前 episode 的路径
            break
        state = next_state

    # 计算损失
    logits = torch.cat(logits_list, dim=0)  # (T, 4)
    action_tensor = torch.tensor(actions, dtype=torch.long)  # (T,)
    loss = loss_fn(logits, action_tensor)  

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 衰减 ε
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    if episode % 100 == 0:
        print(f"Episode {episode}, Loss: {loss.item():.4f}, Epsilon: {epsilon:.4f}")

# 🧐 确保 best_path 已经记录
if len(best_path) == 0:
    print("No path found during training.")
else:
    print(f"Best path: {best_path}")

# 🚀 测试路径(只绘制最佳路径)
fig, ax = plt.subplots(figsize=(6,6))

# 初始化迷宫图
maze = [[0 for _ in range(5)] for _ in range(5)]  # 5×5 迷宫
ax.imshow(maze, cmap="coolwarm", origin="upper")

# 画网格
ax.set_xticks(range(5))
ax.set_yticks(range(5))
ax.grid(True, color="black", linewidth=0.5)

# 画出最佳路径(红色)
for (x, y) in best_path:
    ax.add_patch(plt.Rectangle((y, x), 1, 1, color="red", alpha=0.8))

# 画起点和终点
ax.text(0, 0, "S", ha="center", va="center", fontsize=14, color="white", fontweight="bold")
ax.text(4, 4, "G", ha="center", va="center", fontsize=14, color="white", fontweight="bold")

plt.title("GRU RL Agent - Best Path")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2291145.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

有没有个性化的UML图例

绿萝小绿萝 (53****338) 2012-05-10 11:55:45 各位大虾&#xff0c;有没有个性化的UML图例 绿萝小绿萝 (53****338) 2012-05-10 11:56:03 例如部署图或时序图的图例 潘加宇 (35***47) 2012-05-10 12:24:31 "个性化"指的是&#xff1f; 你的意思使用你自己的图标&…

Vue3.0实战:大数据平台可视化

文章目录 创建vue3.0项目项目初始化项目分辨率响应式设置项目顶部信息条创建页面主体创建全局引入echarts和axios后台接口创建express销售总量图实现完整项目下载项目任何问题都可在评论区,或者直接私信即可。 创建vue3.0项目 创建项目: vue create vueecharts选择第三项:…

洛谷 P1130 红牌 C语言

题目描述 某地临时居民想获得长期居住权就必须申请拿到红牌。获得红牌的过程是相当复杂&#xff0c;一共包括 N 个步骤。每一步骤都由政府的某个工作人员负责检查你所提交的材料是否符合条件。为了加快进程&#xff0c;每一步政府都派了 M 个工作人员来检查材料。不幸的是&…

语音识别播报人工智能分类垃圾桶(论文+源码)

2.1 需求分析 本次语音识别播报人工智能分类垃圾桶&#xff0c;设计功能要求如下∶ 1、具有四种垃圾桶&#xff0c;分别为用来回收厨余垃圾&#xff0c;有害垃圾&#xff0c;可回收垃圾&#xff0c;其他垃圾。 2、当用户语音说出“旧报纸”&#xff0c;“剩菜”等特定词语时…

MVC、MVP和MVVM模式

MVC模式中&#xff0c;视图和模型之间直接交互&#xff0c;而MVP模式下&#xff0c;视图与模型通过Presenter进行通信&#xff0c;MVVM则采用双向绑定&#xff0c;减少手动同步视图和模型的工作。每种模式都有其优缺点&#xff0c;适合不同规模和类型的项目。 ### MVVM 与 MVP…

shiro学习五:使用springboot整合shiro。在前面学习四的基础上,增加shiro的缓存机制,源码讲解:认证缓存、授权缓存。

文章目录 前言1. 直接上代码最后在讲解1.1 新增的pom依赖1.2 RedisCache.java1.3 RedisCacheManager.java1.4 jwt的三个类1.5 ShiroConfig.java新增Bean 2. 源码讲解。2.1 shiro 缓存的代码流程。2.2 缓存流程2.2.1 认证和授权简述2.2.2 AuthenticatingRealm.getAuthentication…

属性编程与权限编程

问题 如何获取文件的大小&#xff0c;时间戳以及类型等信息&#xff1f; 再论 inode 文件的物理载体是硬盘&#xff0c;硬盘的最小存储单元是扇区 (每个扇区 512 字节) 文件系统以 块 为单位(每个块 8 个扇区) 管理文件数据 文件元信息 (创建者、创建日期、文件大小&#x…

用 HTML、CSS 和 JavaScript 实现抽奖转盘效果

顺序抽奖 前言 这段代码实现了一个简单的抽奖转盘效果。页面上有一个九宫格布局的抽奖区域&#xff0c;周围八个格子分别放置了不同的奖品名称&#xff0c;中间是一个 “开始抽奖” 的按钮。点击按钮后&#xff0c;抽奖区域的格子会快速滚动&#xff0c;颜色不断变化&#xf…

R语言绘制有向无环图(DAG)

有向无环图&#xff08;Directed Acyclic Graph&#xff0c;简称DAG&#xff09;是一种特殊的有向图&#xff0c;它由一系列顶点和有方向的边组成&#xff0c;其中不存在任何环路。这意味着从任一顶点出发&#xff0c;沿着箭头方向移动&#xff0c;你永远无法回到起始点。 从流…

Spring Web MVC基础第一篇

目录 1.什么是Spring Web MVC&#xff1f; 2.创建Spring Web MVC项目 3.注解使用 3.1RequestMapping&#xff08;路由映射&#xff09; 3.2一般参数传递 3.3RequestParam&#xff08;参数重命名&#xff09; 3.4RequestBody&#xff08;传递JSON数据&#xff09; 3.5Pa…

129.求根节点到叶节点数字之和(遍历思想)

Problem: 129.求根节点到叶节点数字之和 文章目录 题目描述思路复杂度Code 题目描述 思路 遍历思想(利用二叉树的先序遍历) 直接利用二叉树的先序遍历&#xff0c;将遍历过程中的节点值先利用字符串拼接起来遇到根节点时再转为数字并累加起来&#xff0c;在归的过程中&#xf…

unity中的动画混合树

为什么需要动画混合树&#xff0c;动画混合树有什么作用&#xff1f; 在Unity中&#xff0c;动画混合树&#xff08;Animation Blend Tree&#xff09;是一种用于管理和混合多个动画状态的工具&#xff0c;包括1D和2D两种类型&#xff0c;以下是其作用及使用必要性的介绍&…

MySQL存储过程和存储函数_mysql 存储过 call proc_stat_data(3,null)

2&#xff09;很难调试存储过程。只有少数数据库管理系统允许调试存储过程。不幸的是&#xff0c;MySQL不提供调试存储过程的功能。 1.2 数据准备 创建数据库&#xff1a; DEFAULT CHARACTER SET utf8; use test;这里记得设置编码&#xff01; 创建测试表&#xff1a; DROP…

Flink2支持提交StreamGraph到Flink集群

最近研究Flink源码的时候&#xff0c;发现Flink已经支持提交StreamGraph到集群了&#xff0c;替换掉了原来的提交JobGraph。 新增ExecutionPlan接口&#xff0c;将JobGraph和StreamGraph作为实现。 Flink集群Dispatcher也进行了修改&#xff0c;从JobGraph改成了接口Executio…

Vue 入门到实战 七

第7章 渲染函数 目录 7.1 DOM树 7.2 什么是渲染函数 7.3 h()函数 7.3.1 基本参数 7.3.2 约束 7.3.3 使用JavaScript代替模板功能 7.1 DOM树 7.2 什么是渲染函数 在多数情况下&#xff0c;Vue推荐使用模板template来创建HTML。然而在一些应用场景中&#xff0c;需要使用J…

系统学习算法: 专题八 二叉树中的深搜

深搜其实就是深度优先遍历&#xff08;dfs&#xff09;&#xff0c;与此相对的还有宽度优先遍历&#xff08;bfs&#xff09; 如果学完数据结构有点忘记&#xff0c;如下图&#xff0c;左边是dfs&#xff0c;右边是bfs 而二叉树的前序&#xff0c;中序&#xff0c;后序遍历都可…

进程、线程、内存和IO模型的概念详解

进程、线程、内存和IO模型的概念详解 1 进程与线程1.1 进程1.1.1 进程分类1.1.2 进程的状态和转换1.1.3 僵尸进程和孤儿进程的区别1.1.4 进程之间的通信1.1.5 用户态和内核态1.1.6 用户空间和内核空间 1.2 线程1.2.1 线程的状态和转换1.2.2 进程与线程的区别 1.3 多进程和多线程…

Labelme转Voc、Coco

Q&#xff1a;在github找的cv代码基本都是根据现有且流行的公共数据集格式组织的训练数据集&#xff0c;这导致我使用labelme标注好之后需要我们重新组织数据集 labelme2coco #!/usr/bin/env pythonimport argparse import collections import datetime import glob import j…

JVM方法区

一、栈、堆、方法区的交互关系 二、方法区的理解: 尽管所有的方法区在逻辑上属于堆的一部分&#xff0c;但是一些简单的实现可能不会去进行垃圾收集或者进行压缩&#xff0c;方法区可以看作是一块独立于Java堆的内存空间。 方法区(Method Area)与Java堆一样&#xff0c;是各个…

【Python】第七弹---Python基础进阶:深入字典操作与文件处理技巧

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】【MySQL】【Python】 目录 1、字典 1.1、字典是什么 1.2、创建字典 1.3、查找 key 1.4、新增/修改元素 1.5、删除元素 1.6、遍历…