勾八头歌之RNN

news2024/11/17 17:28:47

一、RNN快速入门

1.学习单步的RNN:RNNCell

# -*- coding: utf-8 -*-
import tensorflow as tf

# 参数 a 是 BasicRNNCell所含的神经元数, 参数 b 是 batch_size, 参数 c 是单个 input 的维数,shape = [ b , c ]
def creatRNNCell(a,b,c):
    # 请在此添加代码 完成本关任务
    # ********** Begin *********#
    x1=tf.placeholder(tf.float32,[b,c])

    cell=tf.nn.rnn_cell.BasicRNNCell(num_units=a)
    h0=cell.zero_state(batch_size=b,dtype=tf.float32)

    output,h1=cell.__call__(x1,h0)

    print(cell.state_size)
    print(h1)
    # ********** End **********#

2.探幽入微LSTM

# -*- coding: utf-8 -*-
import tensorflow as tf

# 参数 a 是 BasicLSTMCell所含的神经元数, 参数 b 是 batch_size, 参数 c 是单个 input 的维数,shape = [ b , c ]
def creatLSTMCell(a,b,c):
    # 请在此添加代码 完成本关任务
    # ********** Begin *********#

    x1=tf.placeholder(tf.float32,[b,c])

    cell=tf.nn.rnn_cell.BasicLSTMCell(num_units=a)
    h0=cell.zero_state(batch_size=b,dtype=tf.float32)

    output,h1=cell.__call__(x1,h0)
    
    print(h1.h)
    print(h1.c)
    # ********** End **********#

3.进阶RNN:学习一次执行多步以及堆叠RNN

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np

# 参数 a 是RNN的层数, 参数 b 是每个BasicRNNCell包含的神经元数即state_size
# 参数 c 是输入序列的批量大小即batch_size,参数 d 是时间序列的步长即time_steps,参数 e 是单个输入input的维数即input_size
def MultiRNNCell_dynamic_call(a,b,c,d,e):
    # 用tf.nn.rnn_cell MultiRNNCell创建a层RNN,并调用tf.nn.dynamic_rnn
    # 请在此添加代码 完成本关任务
    # ********** Begin *********#
    cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicRNNCell(num_units=b) for _ in range(a)]) # a层RNN
    inputs = tf.placeholder(np.float32, shape=(c, d, e)) # a 是 batch_size,d 是time_steps, e 是input_size
    h0=cell.zero_state(batch_size=c,dtype=tf.float32)
    output, h1 = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0)
    print(output)
    # ********** End **********#

二、RNN循环神经网络

1.Attention注意力机制(A  ABC  B  C  A)

2.Seq2Seq

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

dtype = torch.FloatTensor
char_list = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
char_dic = {n: i for i, n in enumerate(char_list)}
seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]
seq_len = 8
n_hidden = 128
n_class = len(char_list)
batch_size = len(seq_data)

##########Begin##########
#对数据进行编码部分
##########End##########
def make_batch(seq_data):
    batch_size = len(seq_data)
    input_batch, output_batch, target_batch = [], [], []
    for seq in seq_data:
        for i in range(2):
            seq[i] += 'P' * (seq_len - len(seq[i]))
        input = [char_dic[n] for n in seq[0]]
        output = [char_dic[n] for n in ('S' + seq[1])]
        target = [char_dic[n] for n in (seq[1] + 'E')]
        input_batch.append(np.eye(n_class)[input])
        output_batch.append(np.eye(n_class)[output])
        target_batch.append(target)
    return Variable(torch.Tensor(input_batch)), Variable(torch.Tensor(output_batch)), Variable(torch.LongTensor(target_batch))


##########Begin##########
#模型类定义
input_batch, output_batch, target_batch = make_batch(seq_data)
class Seq2Seq(nn.Module):
    def __init__(self):
        super(Seq2Seq, self).__init__()
        self.encoder = nn.RNN(input_size=n_class, hidden_size=n_hidden)
        self.decoder = nn.RNN(input_size=n_class, hidden_size=n_hidden)
        self.fc = nn.Linear(n_hidden, n_class)
    def forward(self, enc_input, enc_hidden, dec_input):
        enc_input = enc_input.transpose(0, 1)
        dec_input = dec_input.transpose(0, 1)
        _, h_states = self.encoder(enc_input, enc_hidden)
        outputs, _ = self.decoder(dec_input, h_states)
        outputs = self.fc(outputs)
        return outputs
##########End##########

model = Seq2Seq()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

##########Begin##########
#模型训练过程
for epoch in range(5001):
    hidden = Variable(torch.zeros(1, batch_size, n_hidden))
    optimizer.zero_grad()
    outputs = model(input_batch, hidden, output_batch)
    outputs = outputs.transpose(0, 1)
    loss = 0
    for i in range(batch_size):
        loss += criterion(outputs[i], target_batch[i])
    loss.backward()
    optimizer.step()
##########End##########


##########Begin##########
#模型验证过程函数
def translated(word):
    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])
    hidden = Variable(torch.zeros(1, 1, n_hidden))
    outputs = model(input_batch, hidden, output_batch)
    predict = outputs.data.max(2, keepdim=True)[1]
    decode = [char_list[i] for i in predict]
    end = decode.index('P')
    translated = ''.join(decode[:end])
    print(translated)
##########End##########

translated('highh')
translated('kingh')

三、RNN和LSTM

1.循环神经网络简介

import torch

def rnn(input,state,params):
    """
    循环神经网络的前向传播
    :param input: 输入,形状为 [ batch_size,num_inputs ]
    :param state: 上一时刻循环神经网络的状态,形状为 [ batch_size,num_hiddens ]
    :param params: 循环神经网络的所使用的权重以及偏置
    :return: 输出结果和此时刻网络的状态
    """
    W_xh,W_hh,b_h,W_hq,b_q = params
    """
    W_xh : 输入层到隐藏层的权重
    W_hh : 上一时刻状态隐藏层到当前时刻的权重
    b_h : 隐藏层偏置
    W_hq : 隐藏层到输出层的权重
    b_q : 输出层偏置
    """
    H = state
    # 输入层到隐藏层
    H = torch.matmul(input, W_xh) + torch.matmul(H, W_hh) + b_h
    H = torch.tanh(H)
    # 隐藏层到输出层
    Y = torch.matmul(H, W_hq) + b_q
    return Y,H

def init_rnn_state(num_inputs,num_hiddens):
    """
    循环神经网络的初始状态的初始化
    :param num_inputs: 输入层中神经元的个数
    :param num_hiddens: 隐藏层中神经元的个数
    :return: 循环神经网络初始状态
    """
    init_state = torch.zeros((num_inputs,num_hiddens),dtype=torch.float32)
    return init_state

2.长短时记忆网络

import torch

def lstm(X,state,params):
    """
    LSTM
    :param X: 输入
    :param state: 上一时刻的单元状态和输出
    :param params: LSTM 中所有的权值矩阵以及偏置
    :return: 当前时刻的单元状态和输出
    """
    W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params
    """
    W_xi,W_hi,b_i : 输入门中计算i的权值矩阵和偏置
    W_xf,W_hf,b_f : 遗忘门的权值矩阵和偏置
    W_xo,W_ho,b_o : 输出门的权值矩阵和偏置
    W_xc,W_hc,b_c : 输入门中计算c_tilde的权值矩阵和偏置
    W_hq,b_q : 输出层的权值矩阵和偏置
    """
    #上一时刻的输出 H 和 单元状态 C。
    (H,C) = state
    # 遗忘门
    F = torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f
    F = torch.sigmoid(F)
    # 输入门
    I = torch.sigmoid(torch.matmul(X,W_xi)+torch.matmul(H,W_hi) + b_i)
    C_tilde = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)
    C = F * C + I * C_tilde
    # 输出门
    O = torch.sigmoid(torch.matmul(X,W_xo)+torch.matmul(H,W_ho) + b_o)
    H = O * C.tanh()
    # 输出层
    Y = torch.matmul(H,W_hq) + b_q
    return Y,(H,C)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1619754.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

typedef 定义函数指针

typdef int(*FUNC_TYPE)(int,int) FUNC_TYPE p NULL; 定义了一个函数指针 函数指针作为函数的参数的用法demon

批量提取SemEval 2014 Task 4-aspect_term的xml文件为csv

批量提取SemEval 2014 Task 4-aspect_term的xml文件为csv 数据data 格式 <sentence id"892:1"> <text>Boot time is super fast, around anywhere from 35 seconds to 1 minute.</text> <aspectTerms> <aspectTerm term"Boot time&…

《HCIP-openEuler实验指导手册》1.1Apache安装与测试

一、安装httpd 查看软件仓库中apache版本列表 dnf provides http 安装apache dnf install -y httpd 二、启动http并测试 查看apache版本号 httpd -v 检查配置文件是否正确 httpd -t 将如下97行取消注释消除报错 重新测试配置文件 httpd -t 启动并设置为开机启动 syste…

Unity系统学习笔记

文章目录 1.基础组件的认识1.0.组件继承关系图1.1.项目工程文件结构&#xff0c;各个文件夹都是做什么的&#xff1f;1.2.物体变化组件1.2.3.三维向量表示方向1.2.4.移动物体位置附录&#xff1a;使用变换组件实现物体WASD移动 1.3.游戏物体和组件的显示和禁用1.3.1.界面上的操…

HarmonyOS开发案例:【图片编辑】

介绍 本篇Codelab是基于ArkTS的声明式开发范式的样例&#xff0c;主要介绍了图片编辑实现过程。样例主要包含以下功能&#xff1a; 图片的解码。使用PixelMap进行图片编辑&#xff0c;如裁剪、旋转、亮度、透明度、饱和度等。图片的编码。 相关概念 [图片解码]&#xff1a;读…

Anon Network:基于 Ator Protocol 的 DePIN 匿名互联网

Anon Network正在以Ator Protocol为基础构建世界上最大的Web3隐私互联网生态&#xff0c;其旨在基于DePIN网络&#xff08;Ator protocol&#xff09;&#xff0c;通过激励体系构建一个自下而上、自我维持且可持续、不依赖于任何三方实体且完全匿名的完备互联体系。在该体系中&…

SOLIDWORKS Electrical 3D--精准的三维布线

相信很多工程师在实际生产的时候都会遇到线材长度不准确的问题&#xff0c;从而导致线材浪费甚至整根线材报废的问题&#xff0c;这基本都是由于人工测量长度所导致的&#xff0c;因此本次和大家简单介绍一下SOLIDWORKS Electrical 3D布线的功能&#xff0c;Electrical 3D布线能…

ArrayList 和LinkedList

目录 ArrayListadd自动扩容ArrayList的remove()方法查找 indexof LinkedListLinkedList的add方法LinkedList的remove方法查找 indexof arraylist和linkedlist的区别 ArrayList ArrayList 的底层是数组队列&#xff0c;相当于动态数组。与 Java 中的数组相比&#xff0c;它的容…

恒峰智慧科技-太阳能语音警示杆:节能环保新时代的标配!

随着社会的发展和人们对环境保护意识的不断提高&#xff0c;节能环保已经成为了新时代的标配。在这个大背景下&#xff0c;太阳能语音警示杆应运而生&#xff0c;为森林防火工作提供了有力的支持。太阳能语音警示杆是一种集太阳能发电、语音播报、红蓝警示灯于一体的多功能设备…

【解决NodeJS项目无法在IDEA中调试的问题】使用JetBrains IDEA 2023 调试nodejs项目

项目采用Ant Design Pro React&#xff0c;使用前后端分离开发方式&#xff0c;后端可以很容易的打断点调试&#xff0c;但是前端通过网页进行调试&#xff0c;在IDEA中加了调试断点&#xff0c;但是没有什么用处。 解决方案如下&#xff1a; 点击新建运行配置 新建JavaScrip…

比亚迪唐EV和唐DM-p荣耀版上市,成为新能源汽车市场中的佼佼者!

随着环保理念的深入人心&#xff0c;新能源汽车市场正迎来前所未有的发展机遇。在这个变革的浪潮中&#xff0c;唐EV和唐DM-p荣耀版的上市无疑为市场注入了新的活力。它们凭借先进的技术、卓越的性能以及豪华配置&#xff0c;成为了新能源汽车市场中的佼佼者。然而&#xff0c;…

银行买的黄金怎么卖出去?了解黄金交易的步骤和注意事项

黄金一直以来都是备受投资者关注的贵金属之一。银行提供了购买黄金的机会&#xff0c;但投资者也需要了解如何卖出银行买的黄金。 选择适合的购买方式 投资者可以通过多种途径购买黄金&#xff0c;其中包括银行提供的黄金交易服务。银行买黄金的方式可以是通过黄金交易账户、黄…

从C向C++14——STL初识及函数对象

一.STL初识 1.STL的诞生 长久以来&#xff0c;软件界一直希望建立一种可重复利用的东西C的面向对象和泛型编程思想&#xff0c;目的就是复用性的提升多情况下&#xff0c;数据结构和算法都未能有一套标准,导致被迫从事大量重复工作为了建立数据结构和算法的一套标准,诞生了ST…

MPC的横向控制与算法仿真实现

文章目录 1. 引言2. 模型预测控制&#xff08;MPC&#xff09;2.1 基础知识2.2 MPC的整体流程2.3 MPC的设计求解 3. 车辆运动学MPC设计4. 算法和仿真实现 1. 引言 随着智能交通系统和自动驾驶技术的发展&#xff0c;车辆的横向控制成为了研究的热点。横向控制指的是对车辆在行…

《html自用使用指南》--基于w3School实践

1.基础标签 文本输入时&#xff0c;在编辑器中的换行&#xff0c;多个空格&#xff0c;都被编辑器看作一个空格 <p> 这个段落 在源代码 中 包含 许多行 但是 浏览器 忽略了 它们。 </p>结果&#xff1a;这个段落 在源代码 中 包含 许多行 但是 浏览器…

机器学习中常见的数据分析,处理方式(以泰坦尼克号为例)

数据分析 读取数据查看数据各个参数信息查看有无空值如何填充空值一些特殊字段如何处理读取数据查看数据中的参数信息实操具体问题具体分析年龄问题 重新划分数据集如何删除含有空白值的行根据条件删除一些行查看特征和标签的相关性 读取数据 查看数据各个参数信息 查看有无空…

Python Tiler库:创建可视化网格布局的利器

更多Python学习内容&#xff1a;ipengtao.com Tiler是一个Python库&#xff0c;用于创建各种类型的网格布局&#xff0c;包括等宽/等高布局、自定义大小布局、响应式布局等。本文将深入介绍Tiler库的功能、用法以及示例代码&#xff0c;帮助读者全面了解并灵活应用该库。 安装和…

主打国产算力 广州市通用人工智能公共算力中心项目签约

4月9日&#xff0c;第十届广州国际投资年会期间&#xff0c;企商在线&#xff08;北京&#xff09;数据技术股份有限公司与广州市增城区政府就“广州市通用人工智能公共算力中心”项目进行签约。 该项目由广州市增城区人民政府发起&#xff0c;企商在线承建。项目拟建成中国最…

【Linux】实现一个进度条

我们之前也学了gcc/vim/make和makefile&#xff0c;那么我们就用它们实现一个进度条。 在实现这个进度条之前&#xff0c;我们要先简单了解一下缓冲区和回车和换行的区别 缓冲区其实就是一块内存空间&#xff0c;我们先看这样一段代码 它的现象是先立马打印&#xff0c;三秒后程…

centos7.6上安装mysql7.6 完整过程

安装过程&#xff1a; 参考&#xff1a;https://blog.csdn.net/qq_45103475/article/details/123151050 查找mysql [rootbogon ~]# whereis mysql mysql: /usr/lib64/mysql /usr/share/mysql 删除目录 [rootbogon ~]# rm -rf /usr/lib64/mysql [rootbogon ~]# whereis mysql m…