RNN预测下一句文本简单示例

news2025/1/20 3:56:29

根据句子前半句的内容推理出后半部分的内容,这样的任务可以使用循环的方式来实现。

RNN(Recurrent Neural Network,循环神经网络)是一种用于处理序列数据的强大神经网络模型。与传统的前馈神经网络不同,RNN能够通过其循环结构捕获序列内部的时间依赖性或顺序信息。

在RNN中,每个时间步(timestep)的隐藏状态不仅取决于当前输入,还与上一时间步的隐藏状态有关。这种递归特性使得网络能记忆过去的信息,并将其与当前输入相结合以做出决策或生成输出。

由于存在“梯度消失”和“梯度爆炸”的问题,在长序列建模时原始RNN可能效果不佳。因此,发展出了更复杂的变体,如LSTM(Long Short-Term Memory)和GRU(Gated Recurrent Units),它们通过门控机制更好地保留长期依赖信息。这些改进后的循环神经网络广泛应用于语音识别、自然语言处理(NLP)、机器翻译、视频分析等多种领域。

training_file = 'wordstest.txt' 在里面随便写入一些文章,当做数据,

具体代码如下,写了注释

import torch
import torch.nn.functional as F
import time
import random
import numpy as np
from collections import Counter

# 确保每次结果可复现
RANDOM_SEED = 123
torch.manual_seed(RANDOM_SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def elapsed(sec):
    if sec<60:
        return str(sec) + " sec"
    elif sec<(60*60):
        return str(sec/60) + " min"
    else:
        return str(sec/(60*60)) + " hr"


#中文多文件
def readalltxt(txt_files):
    labels = []
    for txt_file in txt_files:

        target = get_ch_lable(txt_file)
        labels.append(target)
    return labels


def get_ch_lable(txt_file):
    """
    读取数据
    :param txt_file:
    :return:
    """
    labels = ""
    with open(txt_file, 'rb') as f:
        for label in f:
            labels += label.decode('utf-8')
            #labels += label.decode('gb2312')

    return labels


def get_ch_lable_v(txt_file, word_num_map, txt_label=None):
    """
    字符转向量
    :param txt_file:
    :param word_num_map:
    :param txt_label:
    :return:
    """
    words_size = len(word_num_map)
    to_num = lambda word: word_num_map.get(word, words_size)
    if txt_file != None:
        txt_label = get_ch_lable(txt_file)

    labels_vector = list(map(to_num, txt_label))
    return labels_vector


# 文本预处理,生成词向量
training_file = 'wordstest.txt'
training_data = get_ch_lable(training_file)
print("Loaded training data...")
print('样本长度:', len(training_data))
counter = Counter(training_data)
words = sorted(counter)
words_size= len(words)
word_num_map = dict(zip(words, range(words_size)))  # 给每个字构建索引,通过索引来处理计算预测每一个字
print('字表大小:', words_size)
wordlabel = get_ch_lable_v(training_file, word_num_map)


'''
GRU 构建 RNN 模型
1、将输入的文字索引转为词嵌入
2、将词嵌入结果输入用 GRU 所形成的网络层
3、对步骤 2 的输出结果做全连接处理,得到维度为【字表长度】的预测结果,这个结果代表的是每个文字的频率
'''
class GRURNN(torch.nn.Module):
    def __init__(self, word_size, embed_dim,
                 hidden_dim, output_size, num_layers):
        super(GRURNN, self).__init__()

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.embed = torch.nn.Embedding(word_size, embed_dim)
        self.gru = torch.nn.GRU(input_size=embed_dim,
                                hidden_size=hidden_dim,
                                num_layers=num_layers, bidirectional=True)
        # bidirectional=True 代表网络是双向的,从前往后,从后往前
        # hidden_dim*2 代表包含了两个维度的层数
        # 全连接层(线性层),它将接收前面双向GRU输出的隐藏状态作为输入
        self.fc = torch.nn.Linear(hidden_dim*2, output_size)

    def forward(self, features, hidden):
        embedded = self.embed(features.view(1, -1))
        output, hidden = self.gru(embedded.view(1, 1, -1), hidden)
        output = self.fc(output.view(1, -1))
        return output, hidden

    def init_zero_state(self):
        """
        一个初始化隐藏状态的方法,主要用于循环神经网络(RNN)类的实例。这个方法的作用是为RNN创建一组全零初始隐藏状态。
        self.num_layers * 2: 表示双向RNN时的层数(如果模型是双向的,即参数bidirectional=True),
            因为每个方向都会有一个隐藏层,所以总共有num_layers * 2个隐藏层。
        1: 表示批量大小(batch size),在这里初始化的是单个样本的隐藏状态,因此设置为1。若需要处理批量数据,则应根据实际批量大小调整。
        self.hidden_dim: 表示隐藏层的维度(hidden dimension),也就是每个隐藏单元的特征数量。
        :return:
        """
        init_hidden = torch.zeros(self.num_layers * 2, 1, self.hidden_dim).to(DEVICE)
        return init_hidden


EMBEDDING_DIM = 10  # 向量的维度或者说长度
HIDDEN_DIM = 20  # 每一个隐藏层的神经元数量
NUM_LAYERS = 1  # 隐藏层数量

model = GRURNN(words_size, EMBEDDING_DIM, HIDDEN_DIM, words_size, NUM_LAYERS)
model = model.to(DEVICE)  # 将模型移动到指定设备上进行计算
# model.parameters():获取模型中所有需要优化的参数。
# Adam:是优化算法的一种,它基于梯度下降法,并结合了动量项(Momentum)和自适应学习率调整策略(RMSProp)。Adam通常在很多深度学习任务中表现良好,因为它能够自动调整学习率并减少对初始化学习率的敏感性。
# lr=0.005:表示设置学习率为0.005,这是Adam算法中的一个重要超参数,决定了每次更新参数时步伐的大小。
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)


def evaluate(model, prime_str, predict_len, temperature=0.8):
    """
    评估函数
    :param model:
    :param prime_str: 一个表示起始序列的整数列表,每个整数代表词汇表中的索引
    :param predict_len: 指定要预测的字符或单词数量
    :param temperature: 控制生成文本时随机性的一个超参数,较小的值会让模型更倾向于生成概率最高的结果,较大的值则会增加多样性
    :return:
    """
    hidden = model.init_zero_state().to(DEVICE)
    predicted = ''

    # 处理输入语义
    # 将生成的字符添加到预测结果字符串 predicted 中
    for p in range(len(prime_str) - 1):
        _, hidden = model(prime_str[p], hidden)
        predicted += words[prime_str[p]]
    # 用最后一个输入字符开始进行预测
    inp = prime_str[-1]
    predicted += words[inp]

    for p in range(predict_len):
        output, hidden = model(inp, hidden)

        #从多项式分布中采样
        # 将模型输出转换为分布形式,通过除以温度 temperature 并求指数得到softmax分布
        output_dist = output.data.view(-1).div(temperature).exp()
        # 根据调整后的分布采样下一个字符的索引
        inp = torch.multinomial(output_dist, 1)[0]

        predicted += words[inp]

    return predicted


#定义参数训练模型
training_iters = 5000
display_step = 1000
n_input = 4
step = 0
offset = random.randint(0, n_input+1)  # 每次迭代结束时,将偏移值向后移动 n_input+1 个距离,保证输入样本的相对均匀
end_offset = n_input + 1

while step < training_iters:
    start_time = time.time()

    # 随机取一个位置偏移
    if offset > (len(training_data) - end_offset):
        offset = random.randint(0, n_input+1)

    # 取出偏移量为 4 的数据长度,因为文本时序列数据
    inwords = wordlabel[offset:offset + n_input]
    # [n_input, -1, 1] 表示重塑后的三维形状:
    # 第一维是序列长度(即每个序列有 n_input 个元素),
    # 第二维 -1 表示自动计算以适应原始数据大小,
    # 第三维为通道数(这里设为1,通常用于表示一维特征)
    inwords = np.reshape(np.array(inwords), [n_input, -1,  1])
    # 编码
    out_onehot = wordlabel[offset+1:offset+n_input+1]
    # 初始化隐藏层
    hidden = model.init_zero_state()
    '''
    模型完成一次前向传播计算并得到损失(loss)后,在反向传播(backpropagation)之前,需要调用这个函数来清零所有可训练参数的梯度。
    在开始新一轮的前向传播和反向传播之前,使用 optimizer.zero_grad() 来清零所有参数的梯度是至关重要的,确保每次优化步骤只基于当前批次数据计算出的梯度来进行参数更新
    '''
    optimizer.zero_grad()

    '''
    模型训练
    '''
    loss = 0.
    # 将输入数据 inwords 和目标数据 out_onehot 转换为PyTorch张量
    inputs, targets = torch.LongTensor(inwords).to(DEVICE), torch.LongTensor(out_onehot).to(DEVICE)
    for c in range(n_input):
        # 当前时间步的输入和前一时间步的隐藏状态运行模型,得到输出 (outputs) 和新的隐藏状态 (hidden)。
        outputs, hidden = model(inputs[c], hidden)
        # 计算当前时间步的交叉熵损失(Cross Entropy Loss),将模型预测的输出与实际的目标标签比较
        loss += F.cross_entropy(outputs, targets[c].view(1))
    # 所有时间步完成后,平均损失值
    loss /= n_input
    # 反向传播计算梯度:调用 .backward() 函数来计算关于损失函数关于模型参数的梯度
    loss.backward()
    # 使用优化器(在这里是 optimizer)根据计算出的梯度更新模型参数
    optimizer.step()

    #输出日志
    # with torch.set_grad_enabled(False): 这一上下文管理器用于在计算过程中暂时禁用梯度计算。这样,在打印损失、评估模型性能等操作时,
    # 不会占用额外的内存来存储中间计算的梯度,同时避免不必要的反向传播计算。
    with torch.set_grad_enabled(False):
        if (step+1) % display_step == 0:
            print(f'Time elapsed: {(time.time() - start_time)/60:.4f} min')
            print(f'step {step+1} | Loss {loss.item():.2f}\n\n')
            # torch.no_grad() 上下文管理器再次禁用梯度计算,以便于高效地进行模型评估,并且不影响之前或之后的梯度计算状态。
            with torch.no_grad():
                print(evaluate(model, inputs, 32), '\n')
            print(50*'=')
    step += 1
    offset += (n_input+1)#中间隔了一个,作为预测

print("Finished!")


# 使用模型
while True:
    prompt = "请输入几个字,最好是%s个: " % n_input
    sentence = input(prompt)
    inputword = sentence.strip()

    try:
        inputword = get_ch_lable_v(None, word_num_map, inputword)
        keys = np.reshape(np.array(inputword), [len(inputword), -1, 1])
        '''
        调用 model.eval() 方法将模型设置为评估模式。在评估模式下,模型中的批量归一化层(如果有)会使用经过训练时平均的移动统计量,
        并且不会更新模型参数(梯度计算被禁用)。
        接下来,通过 with torch.no_grad(): 语句创建了一个临时上下文,在此上下文中执行所有操作时都不会累积梯度。
        这对于生成任务非常关键,因为在这种情况下我们并不关心反向传播以更新模型权重,而是要利用当前模型状态来生成文本。
        '''
        model.eval()
        with torch.no_grad():
            sentence = evaluate(model, torch.LongTensor(keys).to(DEVICE), 32)

        print(sentence)
    except:
        print("该字我还没学会")

运行结果类似:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1416126.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度推荐模型之DeepFM

一、FM 背景&#xff1a;主要解决大规模稀疏数据下的特征组合遇到的问题&#xff1a;1. 二阶特征参数数据呈指数增长 怎么做的&#xff1a;对每个特征引入大小为k的隐向量&#xff0c;两两特征的权重值通过计算对应特征的隐向量内积 而特征之间计算点积的复杂度原本为 实际应…

橘子学ES实战操作01之集群模式如何实现快照备份

我们知道ES中通过副本在一定意义上实现了数据的备份和高可用。但是我们说万一副本数据丢失了&#xff0c;不小心被rm -f了&#xff0c;你就说逆天不逆天吧&#xff0c;此时要实现数据真正意义上的备份就要使用到快照机制&#xff0c;来把数据持久化备份起来&#xff0c;万一数据…

CAD-autolisp(三)——文件、对话框

目录 一、文件操作1.1 写文件1.2 读文件 二、对话框DCL2.1 初识对话框2.2 常用对话框界面2.2.1 复选框、列表框2.2.2 下拉框2.2.3 文字输入框、单选点框 2.3 Lisp对dcl的驱动2.4 对话框按钮实现拾取2.5 对话框加载图片2.5.1 幻灯片图片制作2.5.1 代码部分 一、文件操作 1.1 写…

TCP 三次握手 四次挥手以及滑动窗口

TCP 三次握手 简介&#xff1a; TCP 是一种面向连接的单播协议&#xff0c;在发送数据前&#xff0c;通信双方必须在彼此间建立一条连接。所谓的 “ 连接” &#xff0c;其实是客户端和服务器的内存里保存的一份关于对方的信息&#xff0c;如 IP 地址、端口号等。 TCP 可以…

CDSP认证:引领数据安全领域的权威之巅!

随着数据安全法和个人信息保护的施行&#xff0c;数据安全领域越来越受到重视。市场上涌现出众多数据安全相关的证书&#xff0c;而数据安全相关职位也成为了2023年最热门的职业之一。 &#x1f4a1;对于想要入门数据安全领域的小伙伴&#xff0c;我强烈推荐CDSP认证&#xff0…

【笔试常见编程题02】字符串中找出连续最长的数字串、数组中出现次数超过一半的数字、计算糖果、进制转换

1. 字符串中找出连续最长的数字串 读入一个字符串str&#xff0c;输出字符串str中的连续最长的数字串 输入描述 个测试输入包含1个测试用例&#xff0c;一个字符串str&#xff0c;长度不超过255。 输出描述 在一行内输出str中里连续最长的数字串。 示例 1 输入 abcd12345ed125s…

【前端】尚硅谷Node.js零基础视频教程笔记

文章目录 1.基础1.1. 基础命令1.2. 注意事项 2. Buffer&#xff08;缓冲器&#xff09;介绍与创建2.1 概念2.2 特点2.3 使用 3. fs模块(file system)3.1 文件写入3.2 文件读取 【前端目录贴】 参考视频: 尚硅谷Node.js零基础视频教程&#xff0c;nodejs新手到高手 1.基础 1.1.…

【MATLAB第92期】基于MATLAB的集成聚合多输入单输出回归预测方法(LSBoost、Bag)含自动优化超参数和特征敏感性分析功能

【MATLAB第92期】基于MATLAB的集成聚合多输入单输出回归预测方法&#xff08;LSBoost、Bag&#xff09;含自动优化超参数和特征敏感性分析功能 本文展示多种非常用多输入单输出回归预测模型效果。 注&#xff1a;每次运行数据训练集测试集为随机&#xff0c;故对比不严谨&…

搜维尔科技:【简报】元宇宙数字人赛道,《救食有道》!

在这个快速发展的数位时代里&#xff0c;本组相信透过制作融合虚拟人物 与 AI 智慧的创新宣传影片&#xff0c;定能为食物银行提高曝光率并让更多人 投身参与并落实减少食物浪费的行动&#xff0c;并与本组共同在生活中宣传食 物银行的理念 学校&#xff1a; 桃园市立中场商业高…

BTC的数据结构Merkle Tree和Hash pointer

比特币是一种基于区块链技术的加密数字货币&#xff0c;其底层数据结构被设计为分布式&#xff0c;去中心化的。它的核心数据结构是一个链式的区块&#xff0c;每个区块都包含了多笔交易记录和一个散列值。 比特币的底层数据结构使用了两个关键概念&#xff1a;hash pointer和…

【计算机网络】IP协议及动态路由算法

对应代码包传送门 IP协议及动态路由算法代码包及思科模拟器资料说明 相关文章 【计算机网络】中小型校园网构建与配置 【计算机网络】Socket通信编程与传输协议分析 【计算机网络】网络应用通信基本原理 目的&#xff1a; 1、掌握IP协议&#xff0c;IP分片&#xff0c;DH…

JCEF学习

JCEF重要概念 CEF CEF&#xff0c;全称Chromium Embedded Framework &#xff0c;它是基于Google Chromium的开源项目&#xff0c;它的目标是能够向第三方程序添加WEB浏览器功能&#xff0c;以及可以使用HTML、CSS和JS渲染界面。 CEF框架是由Marshall Greenblatt 在 2008 年创…

vue项目如何实现运行完项目就跳转到浏览器

在package.json中的启动命令中添加--open参数可以实现在Vue项目编译后自动打开浏览器的功能。 通过这样的设置&#xff0c;在运行npm run dev时&#xff0c;Vue项目编译完成后会自动打开默认浏览器并加载应用程序。

【问题解决】java-word转pdf踩坑

问题情境&#xff1a; 项目中采用word转pdf&#xff0c;最开始使用的pdf相关的apache的pdfbox和itextpdf&#xff0c;后面发现对于有图片背景的word转pdf的情景&#xff0c;word中的背景图会直接占用位置&#xff0c;导致正文不会正确落在背景图上。 解决方案&#xff1a; 采…

计算机网络——虚拟局域网+交换机基本配置实验

1.实验题目 虚拟局域网交换机基本配置实验 2.实验目的 1.了解交换机的作用 2.熟悉交换机的基本配置方法 3.熟悉Packet Tracer 7.0交换机模拟软件的使用 4.掌握在交换机上划分局域网&#xff0c;并且使用局域网与端口连接&#xff0c;检测信号传输 3.实验任务 1.了解交换…

C 变量

目录 1. C变量 2. C变量定义 2.1 变量初始化 2.2 C中的变量声明 3. C中的左值&#xff08;Lvalues&#xff09;和右值&#xff08;Rvalues&#xff09; 1. C变量 在C语言中&#xff0c;变量可以根据其类型分为以下几种基本类型&#xff1a; 整型变量&#xff1a;用…

蓝桥小白赛4 乘飞机 抽屉原理 枚举

&#x1f468;‍&#x1f3eb; 乘飞机 &#x1f437; 抽屉原理 import java.util.Scanner;public class Main {static int N 100010;static int[] a new int[N];public static void main(String[] args){Scanner sc new Scanner(System.in);int n sc.nextInt();int q s…

(南京观海微电子)——OLED驱动与调试

一、OLED DDIC分类 OLED DDIC的技术方向可以分为3类&#xff1a;带Ram【内存】的IC、Ram-less IC和TDDI【显示&触控集成的IC】 1、带Ram的OLED DDIC OLED DDIC有两个Ram&#xff0c;分别是Demura Ram和Display Ram。 1、带Ram的OLED DDIC 1-1&#xff09;Demura Ram&a…

取消Vscode在输入符号时自动补全

取消Vscode在输入符号时自动补全 取消Vscode在输入符号时自动补全问题演示解决方法 取消Vscode在输入符号时自动补全 问题演示 在此状态下输入/会直接自动补全, 如下图 笔者想要达到的效果为可以正常输入/而不进行补全, 如下图 解决方法 在设置->文本编辑器->建议, 取消…

C语言第十一弹---函数(下)

​ ✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】 函数 1、嵌套调用和链式访问 1.1、嵌套调用 1.2、链式访问 2、函数的声明和定义 2.1、单个文件 2.2、多个文件 2.3、static 和 extern 2.3.1、static…