pytorch实现循环神经网络

news2025/2/2 5:15:02

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

PyTorch 提供三种主要的 RNN 变体:

  • nn.RNN:最基本的循环神经网络,适用于短时依赖任务。
  • nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决梯度消失问题。
  • nn.GRU:门控循环单元,比 LSTM 计算更高效,适用于大部分任务。
网络类型优势适用场景
RNN计算简单,适用于短时序列语音、文本处理(短序列)
LSTM适用于长序列,能记忆长期信息机器翻译、语音识别、股票预测
GRU比 LSTM 计算更高效,效果相似语音处理、文本生成

例子:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


# 1. 生成正弦波数据(仅使用 PyTorch)
def generate_sine_wave(seq_length=10, num_samples=1000):
    x = torch.linspace(0, 100, num_samples)  # 生成 1000 个等间距数据点
    y = torch.sin(x)  # 计算正弦值

    X_data, Y_data = [], []
    for i in range(len(y) - seq_length):
        X_data.append(y[i:i + seq_length].unsqueeze(-1))  # 过去 seq_length 作为输入
        Y_data.append(y[i + seq_length])  # 预测下一个点

    return torch.stack(X_data), torch.tensor(Y_data).unsqueeze(-1)


# 生成数据
seq_length = 10  # 序列长度
X, Y = generate_sine_wave(seq_length)

# 划分训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
Y_train, Y_test = Y[:train_size], Y[train_size:]


# 2. 定义 RNN 模型
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)  # 初始化隐藏状态
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])  # 取最后一个时间步的输出
        return out


# 3. 训练模型
# 超参数
input_size = 1
hidden_size = 32
output_size = 1
num_layers = 1
num_epochs = 100
learning_rate = 0.001

# 初始化模型
model = SimpleRNN(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    outputs = model(X_train)
    loss = criterion(outputs, Y_train)

    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 4. 评估与绘图
model.eval()
with torch.no_grad():
    predictions = model(X_test)

# 画图
plt.figure(figsize=(10, 5))
plt.plot(Y_test.numpy(), label="Real Data")
plt.plot(predictions.numpy(), label="Predicted Data")
plt.legend()
plt.title("RNN Sine Wave Prediction")
plt.show()

代码解析

数据生成

  • torch.linspace(0, 100, num_samples) 生成 1000 个均匀分布的数据点。
  • torch.sin(x) 计算正弦值,形成时间序列数据。
  • X过去 10 个时间步的数据,Y下一个时间步的预测目标

构建 RNN

  • nn.RNN(input_size, hidden_size, num_layers, batch_first=True) 定义循环神经网络
    • input_size=1:每个时间步只有一个输入值(正弦波)。
    • hidden_size=32:隐藏层神经元数目。
    • num_layers=1:单层 RNN。
  • self.fc = nn.Linear(hidden_size, output_size) 负责最终输出。

训练

  • 使用 MSELoss(均方误差损失) 计算预测值与真实值的误差。
  • 使用 Adam 优化器 更新模型参数。
  • 每 10 个 epoch 输出一次损失 loss

测试 & 绘图

  • 关闭梯度计算 (torch.no_grad()),执行前向传播预测测试数据。
  • Matplotlib 绘制预测曲线与真实曲线。

运行效果

如果训练成功,预测曲线(橙色)应该与真实曲线(蓝色)非常接近

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2289629.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python从零构建macOS状态栏应用(仿ollama)并集成AI同款流式聊天 API 服务(含打包为独立应用)

在本教程中,我们将一步步构建一个 macOS 状态栏应用程序,并集成一个 Flask 服务器,提供流式响应的 API 服务。 如果你手中正好持有一台 MacBook Pro,又怀揣着搭建 AI 聊天服务的想法,却不知从何处迈出第一步,那么这篇文章绝对是你的及时雨。 最终,我们将实现以下功能: …

leetcode 2080. 区间内查询数字的频率

题目如下 数据范围 示例 这题十分有意思一开始我想对每个子数组排序二分结果超时了。 转换思路:我们可以提前把每个数字出现的位置先记录下来形成集合, 然后拿着left和right利用二分查找看看left和right是不是在集合里然后做一个相减就出答案了。通过…

深入了解 SSRF 漏洞:原理、条件、危害

目录 前言 SSRF 原理 漏洞产生原因 产生条件 使用协议 使用函数 漏洞影响 防御措施 结语 前言 本文将深入剖析 SSRF(服务端请求伪造)漏洞,从原理、产生原因、条件、影响,到防御措施,为你全面梳理相关知识&am…

11.QT控件:输入类控件

1. Line Edit(单行输入框) QLineEdit表示单行输入框,用来输入一段文本,但是不能换行。 核心属性: 核心信号: 2. Text Edit(多行输入框) QTextEdit表示多行输入框,也是一个富文本 & markdown编辑器。并且能在内容超…

Cesium+Vue3教程(011):打造数字城市

文章目录 Cesium打造数字城市创建项目加载地球设置底图设置摄像头查看具体位置和方向添加纽约建筑模型并设置样式添加纽约建筑模型设置样式划分城市区域并着色地图标记显示与实现实现飞机巡城完整项目下载Cesium打造数字城市 创建项目 使用vite创建vue3项目: pnpm create v…

Windows系统本地部署deepseek 更改目录

本地部署deepseek 无论是mac还是windows系统本地部署deepseek或者其他模型的命令和步骤是一样的。 可以看: 本地部署deepsek 无论是ollama还是部署LLM时候都默认是系统磁盘,对于Windows系统,我们一般不把应用放到系统盘(C:)而是…

基于Python的药物相互作用预测模型AI构建与优化(下.代码部分)

四、特征工程 4.1 分子描述符计算 分子描述符作为量化分子性质的关键数值,能够从多维度反映药物分子的结构和化学特征,在药物相互作用预测中起着举足轻重的作用。RDKit 库凭借其强大的功能,为我们提供了丰富的分子描述符计算方法,涵盖了多个重要方面的分子性质。 分子量…

[Python学习日记-79] socket 开发中的粘包现象(解决模拟 SSH 远程执行命令代码中的粘包问题)

[Python学习日记-79] socket 开发中的粘包现象(解决模拟 SSH 远程执行命令代码中的粘包问题) 简介 粘包问题底层原理分析 粘包问题的解决 简介 在Python学习日记-78我们留下了两个问题,一个是服务器端 send() 中使用加号的问题&#xff0c…

origin如何在已经画好的图上修改数据且不改变原图像的画风和格式

例如我现在的.opju文件长这样 现在我换了数据集,我想修改这两个图表里对应的算法里的数据,但是我还想保留这图像现在的形式,可以尝试像下面这样做: 右击第一个图,出现下面,选择Book[sheet1] 选择工作簿 出…

5.3.2 软件设计原则

文章目录 抽象模块化信息隐蔽与独立性衡量 软件设计原则:抽象、模块化、信息隐蔽。 抽象 抽象是抽出事物本质的共同特性。过程抽象是指将一个明确定义功能的操作当作单个实体看待。数据抽象是对数据的类型、操作、取值范围进行定义,然后通过这些操作对数…

【ArcGIS遇上Python】批量提取多波段影像至单个波段

本案例基于ArcGIS python,将landsat影像的7个波段影像数据,批量提取至单个波段。 相关阅读:【ArcGIS微课1000例】0141:提取多波段影像中的单个波段 文章目录 一、数据准备二、效果比对二、python批处理1. 编写python代码2. 运行代码一、数据准备 实验数据及完整的python位…

Spring Security(maven项目) 3.0.2.9版本 --- 改

前言: 通过实践而发现真理,又通过实践而证实真理和发展真理。从感性认识而能动地发展到理性认识,又从理性认识而能动地指导革命实践,改造主观世界和客观世界。实践、认识、再实践、再认识,这种形式,循环往…

仿真设计|基于51单片机的温度与烟雾报警系统

目录 具体实现功能 设计介绍 51单片机简介 资料内容 仿真实现(protues8.7) 程序(Keil5) 全部内容 资料获取 具体实现功能 (1)LCD1602实时监测及显示温度值和烟雾浓度值; (2…

深入剖析 CSRF 漏洞:原理、危害案例与防护

目录 前言 漏洞介绍 漏洞原理 产生条件 产生的危害 靶场练习 post 请求csrf案例 防御措施 验证请求来源 设置 SameSite 属性 双重提交 Cookie 结语 前言 在网络安全领域,各类漏洞层出不穷,时刻威胁着用户的隐私与数据安全。跨站请求伪造&…

buuuctf_秘密文件

题目: 应该是分析流量包了,用wireshark打开 我追踪http流未果,分析下ftp流 追踪流看看 用户 “ctf” 使用密码 “ctf” 登录。 PORT命令用于为后续操作设置数据连接。 LIST命令用于列出 FTP 服务器上目录的内容,但在此日志中未…

课程设计|结构力学

课 程 设 计 第一部分 (结构力学) 2、两种结构在静力等效荷载作用下,内力有哪些不同?(分析比较) 1/2 1 1 1 1 1 1/2 1/4 11(1/2) 1/4 图1求解过程及结果: 轴力图: 内力计算 单位&…

跟李沐学AI:视频生成类论文精读(Movie Gen、HunyuanVideo)

Movie Gen:A Cast of Media Foundation Models 简介 Movie Gen是Meta公司提出的一系列内容生成模型,包含了 3.2.1 预训练数据 Movie Gen采用大约 100M 的视频-文本对和 1B 的图片-文本对进行预训练。 图片-文本对的预训练流程与Meta提出的 Emu: Enh…

keil5如何添加.h 和.c文件,以及如何添加文件夹

1.简介 在hal库的编程中我们一般会生成如下的几个文件夹,在这几个文件夹内存储着各种外设所需要的函数接口.h文件,和实现函数具体功能的.c文件,但是有时我们想要创建自己的文件夹并在这些文件夹下面创造.h .c文件来实现某些功能,…

2025-1-28-sklearn学习(47) (48) 万家灯火亮年至,一声烟花开新来。

文章目录 sklearn学习(47) & (48)sklearn学习(47) 把它们放在一起47.1 模型管道化47.2 用特征面进行人脸识别47.3 开放性问题: 股票市场结构 sklearn学习(48) 寻求帮助48.1 项目邮件列表48.2 机器学习从业者的 Q&A 社区 sklearn学习(47) & (48) 文章参考网站&…

DeepSeek 云端部署,释放无限 AI 潜力!

1.简介 目前,OpenAI、Anthropic、Google 等公司的大型语言模型(LLM)已广泛应用于商业和私人领域。自 ChatGPT 推出以来,与 AI 的对话变得司空见惯,对我而言没有 LLM 几乎无法工作。 国产模型「DeepSeek-R1」的性能与…