【人工智能基础】RNN实验

news2024/9/26 5:14:28

一、RNN特性

权重共享

wordi · weight + bais

持久记忆单元

wordi · weightword + baisword + hi · weighth + baish

二、公式化表达

ht</sub = f(ht - 1, xt)
ht = tanh(Whhht - 1 + Wxhxt)
yt = Whyht

三、RNN网络正弦波波形预测

环境准备

import numpy as np
import torch
from torch import nn,optim
from matplotlib import pyplot as plt

# 时间轴采样数
num_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr = 0.01

RNN类

class Net(nn.Module):
    def __init__(self,):
        super(Net, self).__init__()
        self.rnn = nn.RNN(
            input_size = input_size, 
            hidden_size = hidden_size, 
            num_layers = 1,
            # 格式为[batch, seq, feature]
            batch_first = True
        )
        for p in self.rnn.parameters():
            nn.init.normal_(p,mean=0.0, std=0.001)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_prev):
        out, hidden_prev = self.rnn(x, hidden_prev)
        # [1, seq, h] => [seq, h]
        out = out.view(-1,hidden_size)
        # [seq, h] => [seq, 1]
        out = self.linear(out)
        # [seq, 1] => [1, seq, 1], 需要和y做均方差
        out = out.unsqueeze(dim=0)
        return out, hidden_prev.clone()

正弦数据构建函数

def create_image():
    start = np.random.randint(3, size=1)[0]
    time_steps = np.linspace(start, start + 10, num_time_steps)
    data = np.sin(time_steps)
    data = data.reshape(num_time_steps, 1)
    x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
    y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
    return time_steps,x, y

训练模型


model = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr)

hidden_prev = torch.zeros(1,1, hidden_size)
for iter in range(6000):
    time_steps,x, y = create_image()
    output, hidden_prev = model(x, hidden_prev)
    hidden_prev = hidden_prev.detach()

    loss = criterion(output,y)
    model.zero_grad()
    loss.backward()
    for p in model.parameters():
        torch.nn.utils.clip_grad_norm_(p,10)
    optimizer.step()

    if iter % 1000 == 0:
        plt.plot(time_steps[:-1], x.ravel(), c = 'b')
        plt.plot(time_steps[:-1], y.ravel(), c= 'r')
        plt.plot(time_steps[:-1], output.detach().numpy().ravel(), c= 'g')
        plt.show()
        print('Iteration:{} loss {}'.format(iter, loss.item()))

可以看到第二次绘制图像的时候,输出曲线基本拟合了目标曲线

未训练图像

训练后图像

图像预测

time_steps,x, y = create_image()

predictions = []
# input = x[:, 0, :]
for i in range(x.shape[1]):
    input = x[:, i, :].view(1, 1, 1)
    (pred, hiden_prev) = model(input, hidden_prev)
    input = pred
    predictions.append(pred.detach().numpy().ravel()[0])

x = x.data.numpy().ravel()

y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())

plt.scatter(time_steps[1:],predictions)
plt.show()
    

输出的预测曲线基本与目标曲线相同

预测结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1644342.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何快速找出文件夹里的全部带有中文纯中文的文件

首先&#xff0c;需要用到的这个工具YTool&#xff1a; 度娘网盘 提取码&#xff1a;qwu2 蓝奏云 提取码&#xff1a;2r1z 步骤 1、打开工具&#xff0c;切换到批量复制文件 2、鼠标移到右侧&#xff0c;点击搜索添加 3、设定查找范围、指定为文件、勾选 包含全部子文件夹&…

macOS DOSBox 汇编环境搭建

正文 一、安装DOSBox 首先前往DOSBox的官网下载并安装最新版本的DOSBox。 二、下载必备的工具包 在用户目录下新建一个文件夹&#xff0c;比如 dosbox: mkdir dosbox然后下载一些常用的工具。下载好了后&#xff0c;将这些工具解压&#xff0c;重新放在 dosbox 这个文件夹…

微服务---feign调用服务

目录 Feign简介 Feign的作用 Feign的使用步骤 引入依赖 具体业务逻辑 配置日志 在其它服务中使用接口 接着上一篇博客&#xff0c;我们讲过了nacos的基础使用&#xff0c;知道它是注册服务用的&#xff0c;接下来我们我们思考如果一个服务需要调用另一个服务的接口信息&…

ICode国际青少年编程竞赛- Python-1级训练场-识别循环规律1

ICode国际青少年编程竞赛- Python-1级训练场-识别循环规律1 1、 for i in range(4):Dev.step(6)Dev.turnLeft()2、 for i in range(3):Dev.turnLeft()Dev.step(2)Dev.turnRight()Dev.step(2)3、 for i in range(3):Spaceship.step(5)Spaceship.turnLeft()Spaceship.step(…

MySQL: Buffer Pool概念整理

一. 简介 MySQL中的Buffer Pool是InnoDB存储引擎用来缓存表数据和索引的内存区域。这是InnoDB性能优化中最关键的部分之一。通过在内存中缓存这些数据&#xff0c;InnoDB可以极大减少对磁盘I/O的需求&#xff0c;因为从内存中读取数据远比从磁盘读取要快得多。因此&#xff0c…

如何修复连接失败出现的错误651?这里提供修复方法

错误651消息在Windows 7到Windows 11上很常见&#xff0c;通常会出现在一个小的弹出窗口中。实际文本略有不同&#xff0c;具体取决于连接问题的原因&#xff0c;但始终包括文本“错误651”。 虽然很烦人&#xff0c;但错误651是一个相对较小的问题&#xff0c;不应该导致计算…

01-JDK安装(Window环境和Linux环境)

1. Windows环境安装JDK 1.1 Oracle官网下载需要版本的JDK 官网传送门https://www.oracle.com/java/technologies/downloads/#java8-windows下载完成之后 以管理员身份&#xff08;管理员&#xff01;管理员&#xff01;&#xff09;运行下载的exe文件 期间修改需要安装的路径…

【多模态】29、OCRBench | 为大型多模态模型提供一个 OCR 任务测评基准

文章目录 一、背景二、实验2.1 测评标准和结果2.1.1 文本识别 Text Recognition2.1.2 场景文本中心的视觉问答 Scene Text-Centric VQA2.1.3 文档导向的视觉问答 Document-Oriented VQA2.1.4 关键信息提取 Key Information Extraction2.1.5 手写数学公式识别 Handwritten Mathe…

【学一点儿前端】Bad value with message: unexpected token `.`. 问题及解决方法

问题 今天从vue3的项目copy一段代码到vue2项目&#xff0c;编译后访问页面报错了 Bad value with message: unexpected token ..注意到错误字符‘.’&#xff0c;这个错误通常发生在处理 JavaScript 或者 HTML 中的动态表达式中&#xff0c;日常使用二分法不断缩小报错代码范…

AI新突破:多标签预测技术助力语言模型提速3倍

DeepVisionary 每日深度学习前沿科技推送&顶会论文分享&#xff0c;与你一起了解前沿深度学习信息&#xff01; 引言&#xff1a;多标签预测的新视角 在人工智能领域&#xff0c;尤其是在自然语言处理&#xff08;NLP&#xff09;中&#xff0c;预测模型的训练方法一直在…

每天五分钟深度学习:数学中常见函数中的导数

本文重点 导数是微积分学中的一个核心概念,它描述了函数在某一点附近的变化率。在物理学、工程学、经济学等众多领域中,导数都发挥着极其重要的作用。本文旨在详细介绍数学中常见函数的导数,以期为读者提供一个全面而深入的理解。 数学中常见的导数 常数函数的导数 对于常数…

利用策略模式+模板方法实现项目中运维功能

前段时间项目中有个需求&#xff1a;实现某业务的运维功能&#xff0c;主要是对10张数据库表的增删改查&#xff0c;没有复杂的业务逻辑&#xff0c;只是满足运维人员的基本需要&#xff0c;方便他们快速分析定位问题。这里简单记录分享下实现方案&#xff0c;仅供参考。 一、…

对C语言符号的一些冷门知识运用的剖析和总结

符号 目录* 符号 注释 - 奇怪的注释 - C风格的注释无法嵌套 - 一些特殊的注释 - 注释的规则建议 反斜杠’’ - 反斜杠有续行的作用&#xff0c;但要注意续行后不能添加空格 * 回车也能起到换行的作用&#xff0c;那续行符的意义在哪&#xff1f; - 反斜杠的转义功能 单引号…

HCIP的学习(11)

OSPF的LSA详解 LSA头部信息 ​ [r2]display ospf lsdb router 1.1.1.1----查看OSPF某一条LSA的详细信息&#xff0c;类型以及LS ID参数。 链路状态老化时间 指一条LSA的老化时间&#xff0c;即存在了多长时间。当一条LSA被始发路由器产生时&#xff0c;该参数值被设定为0之后…

电机控制系列模块解析(16)—— 电流环

一、FOC为什么使用串联控制器 在此说明&#xff0c;串联形式&#xff08;内外环形式&#xff0c;速度环和电流环控制器串联&#xff09;并不是必须的&#xff0c;但是对于线性控制系统来说&#xff0c;电机属于非线性控制对象&#xff0c;早期工程师们为了处理电机的非线性&am…

备考2024年小学生古诗文大会:吃透10道历年真题和知识点(持续)

对上海小学生的小升初和各种评优争章来说&#xff0c;语文、数学、英语的含金量较高的证书还是很有价值和帮助的。对于语文类的竞赛&#xff0c;小学生古诗文大会和汉字小达人通常是必不可少的&#xff0c;因为这两个针对性强&#xff0c;而且具有很强的上海本地特色。 今天我…

Python | Leetcode Python题解之第69题x的平方根

题目&#xff1a; 题解&#xff1a; class Solution:def mySqrt(self, x: int) -> int:if x 0:return 0C, x0 float(x), float(x)while True:xi 0.5 * (x0 C / x0)if abs(x0 - xi) < 1e-7:breakx0 xireturn int(x0)

kotlin语法快速入门--(完整版)

Kotlin语法入门 文章目录 Kotlin语法入门一、变量声明1、整型2、字符型3、集合3.1、创建array数组3.2、创建list集合3.3、不可变类型数组3.4、Set集合--不重复添加元素3.5、键值对集合Map 4、kotlin特有的数据类型和集合4.1、Any、Nothing4.2、二元组--Pair4.3、三元组--Triple…

ue引擎游戏开发笔记(31)——对角色移动进行优化:角色滑步处理

1.需求分析&#xff1a; 角色的移动与动画不匹配&#xff0c;角色移动起来像是在滑行。。。适当进行优化。 2.操作实现&#xff1a; 这个问题本质是角色的运动速度并没有匹配世界动画的运行速度&#xff0c;不论世界动画快慢于角色移动速度&#xff0c;都会感到有滑步感。所以…

VMware worksation 17 简易安装Centos8.2、Redhat8.2、Ubuntu16.04

系列文章目录 文章目录 系列文章目录前言一、VMware worksation 17 安装二、安装Centos8.2三、安装RHEL8.2四、安装Ubuntu16.04总结 前言 傻瓜式按照Linux系统&#xff0c;如果觉得简单&#xff0c;可以自定义设置&#xff0c;特别是配置一下磁盘空间大小&#xff0c;对以后排…