【PyTorch】循环神经网络

news2024/11/18 12:47:10

循环神经网络是什么

Recurrent Neural Networks
RNN:循环神经网络

  • 处理不定长输入的模型
  • 常用于NLP及时间序列任务(输入数据具有前后关系

RNN网络结构

参考资料
Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs
Understanding LSTM Networks
在这里插入图片描述

RNN实现人名分类

问题定义:输入任意长度姓名(字符串),输出姓名来自哪一个国家(18类分类任务)
数据: https://download.pytorch.org/tutorial/data.zip
Jackie Chan —— 成龙
Jay Chou —— 周杰伦
Tingsong Yue —— 余霆嵩

RNN如何处理不定长输入

思考:计算机如何实现不定长字符串分类向量的映射?
Chou(字符串)→ RNN →Chinese(分类类别)

  1. 单词字符 → 数字
  2. 数字 → model
  3. 下一个字符 → 数字 → model
  4. 最后一个字符 → 数字 → model → 分类向量
# 伪代码
# Chou(字符串)→ RNN →Chinese(分类类别)
for string in [C, h, o, u]:
	1. one-hot:string → [0,0, ...., 1, ..., 0]	# 首先把每个字母转换成编码
	2. y, h = model([0,0, ...., 1, ..., 0], h)		# h就是隐藏层的状态信息

xt:时刻t的输入,shape = (1, 57)
st:时刻t的状态值,shape=(1, 128)
ot:时刻t的输出值,shape=(1, 18)
U:linear层的权重参数, shape = (128, 57)
W:linear层的权重参数, shape = (128, 128)
V:linear层的权重参数, shape = (18, 128)

代码如下:

# -*- coding: utf-8 -*-
"""
# @file name  : rnn_demo.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2019-12-09
# @brief      : rnn人名分类
"""
from io import open
import glob
import unicodedata
import string
import math
import os
import time
import torch.nn as nn
import torch
import random
import matplotlib.pyplot as plt
import torch.utils.data
import sys
# 获取路径
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)

from tools.common_tools import set_seed

set_seed(1)  # 设置随机种子
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# 选择运行设备
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")


# Read a file and split into lines
def readLines(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]


def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters)


# Find letter index from all_letters, e.g. "a" = 0
def letterToIndex(letter):
    return all_letters.find(letter)


# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
def letterToTensor(letter):
    tensor = torch.zeros(1, n_letters)
    tensor[0][letterToIndex(letter)] = 1
    return tensor


# Turn a line into a <line_length x 1 x n_letters>,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor


def categoryFromOutput(output):
    top_n, top_i = output.topk(1)
    category_i = top_i[0].item()
    return all_categories[category_i], category_i


def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]


def randomTrainingExample():
    category = randomChoice(all_categories)                 # 选类别
    line = randomChoice(category_lines[category])           # 选一个样本
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
    line_tensor = lineToTensor(line)    # str to one-hot
    return category, line, category_tensor, line_tensor


def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


# Just return an output given a line
def evaluate(line_tensor):
    hidden = rnn.initHidden()

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    return output


def predict(input_line, n_predictions=3):
    print('\n> %s' % input_line)
    with torch.no_grad():
        output = evaluate(lineToTensor(input_line))

        # Get top N categories
        topv, topi = output.topk(n_predictions, 1, True)

        for i in range(n_predictions):
            value = topv[0][i].item()
            category_index = topi[0][i].item()
            print('(%.2f) %s' % (value, all_categories[category_index]))


def get_lr(iter, learning_rate):
    lr_iter = learning_rate if iter < n_iters else learning_rate*0.1
    return lr_iter

# 定义网络结构
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.u = nn.Linear(input_size, hidden_size)
        self.w = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, output_size)

        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, hidden):

        u_x = self.u(inputs)

        hidden = self.w(hidden)
        hidden = self.tanh(hidden + u_x)

        output = self.softmax(self.v(hidden))

        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


def train(category_tensor, line_tensor):
    hidden = rnn.initHidden()

    rnn.zero_grad()

    line_tensor = line_tensor.to(device)
    hidden = hidden.to(device)
    category_tensor = category_tensor.to(device)

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()

    # Add parameters' gradients to their values, multiplied by learning rate
    for p in rnn.parameters():
        # p.data.add_(-learning_rate, p.grad.data) # 该方法已经被弃用
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item()


if __name__ == "__main__":
    print(device)

    # config
    data_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rnn_data", "names"))
    if not os.path.exists(data_dir):
        raise Exception("\n{} 不存在,请下载 08-05-数据-20200724.zip  放到\n{}  下,并解压即可".format(
            data_dir, os.path.dirname(data_dir)))

    path_txt = os.path.join(data_dir, "*.txt")
    all_letters = string.ascii_letters + " .,;'"
    n_letters = len(all_letters)    # 52 + 5 字符总数
    print_every = 5000
    plot_every = 5000
    learning_rate = 0.005
    n_iters = 200000

    # step 1 data
    # Build the category_lines dictionary, a list of names per language
    category_lines = {}
    all_categories = []
    for filename in glob.glob(path_txt):
        category = os.path.splitext(os.path.basename(filename))[0]
        all_categories.append(category)
        lines = readLines(filename)
        category_lines[category] = lines

    n_categories = len(all_categories)

    # step 2 model
    n_hidden = 128
    # rnn = RNN(n_letters, n_hidden, n_categories)
    rnn = RNN(n_letters, n_hidden, n_categories)

    rnn.to(device)

    # step 3 loss
    criterion = nn.NLLLoss()

    # step 4 optimize by hand

    # step 5 iteration
    current_loss = 0
    all_losses = []
    start = time.time()
    for iter in range(1, n_iters + 1):
        # sample
        category, line, category_tensor, line_tensor = randomTrainingExample()

        # training
        output, loss = train(category_tensor, line_tensor)

        current_loss += loss

        # Print iter number, loss, name and guess
        if iter % print_every == 0:
            guess, guess_i = categoryFromOutput(output)
            correct = '✓' if guess == category else '✗ (%s)' % category
            print('Iter: {:<7} time: {:>8s} loss: {:.4f} name: {:>10s}  pred: {:>8s} label: {:>8s}'.format(
                iter, timeSince(start), loss, line, guess, correct))

        # Add current loss avg to list of losses
        if iter % plot_every == 0:
            all_losses.append(current_loss / plot_every)
            current_loss = 0

path_model = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rnn_state_dict.pkl"))
if not os.path.exists(path_model):
    raise Exception("\n{} 不存在,请下载 08-05-数据-20200724.zip  放到\n{}  下,并解压即可".format(
        path_model, os.path.dirname(path_model)))
torch.save(rnn.state_dict(), path_model)
plt.plot(all_losses)
plt.show()

predict('Yue Tingsong')
predict('Yue tingsong')
predict('yutingsong')

predict('test your name')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2177880.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

古代帝王与啤酒的不解之缘

在历史的长河中&#xff0c;古代帝王们的生活总是充满神秘与传奇。他们掌握着无上的权力&#xff0c;享受着世间的荣华富贵。而在这些帝王的日常生活中&#xff0c;有一种饮品始终伴随着他们&#xff0c;那便是精酿啤酒。今天&#xff0c;我们就来探寻古代帝王与啤酒之间的不解…

渗透测试实战—教育攻防演练信息收集

免责声明&#xff1a;文章来源于真实渗透测试&#xff0c;已获得授权&#xff0c;且关键信息已经打码处理&#xff0c;请勿利用文章内的相关技术从事非法测试&#xff0c;由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本…

【Python】数据可视化之热力图

热力图&#xff08;Heatmap&#xff09;是一种通过颜色深浅来展示数据分布、密度和强度等信息的可视化图表。它通过对色块着色来反映数据特征&#xff0c;使用户能够直观地理解数据模式&#xff0c;发现规律&#xff0c;并作出决策。 目录 基本原理 sns.heatmap 代码实现 基…

「OC」探索 KVC 的基础与应用

「OC」KVC的初步学习 文章目录 「OC」KVC的初步学习前言介绍KVC的相关方法key和keyPath的区别KVC的工作原理KVO的setValue:forKey原理KVO的ValueforKey原理 在集合之中KVC的用法1. mutableArrayValueForKey: 和 mutableArrayValueForKeyPath:2. mutableSetValueForKey: 和 muta…

【源码+文档+调试讲解】无人超市系统python

摘 要 随着科技的不断进步&#xff0c;无人超市成为了零售行业的新兴趋势。无人超市管理系统是支撑这一新型商业模式的关键软件基础设施。该系统采用python技术和MySQL数据库技术以及Django框架进行开发。通过高度自动化和智能化的方式&#xff0c;允许消费者在没有收银员的情…

WordPress LearnPress插件 SQL注入复现(CVE-2024-8522)

0x01 产品描述&#xff1a; LearnPress 是一款功能强大的 WordPress LMS&#xff08;学习管理系统&#xff09;插件&#xff0c;适用于创建和销售在线课程。凭借其直观的界面和丰富的功能&#xff0c;无论您是否具备编程背景&#xff0c;都能轻松搭建起在线教育网站。学会如何使…

【若依RuoYi-Vue | 项目实战】帝可得后台管理系统(三)

文章目录 一、商品管理1、需求说明2、生成基础代码&#xff08;1&#xff09;创建目录菜单&#xff08;2&#xff09;配置代码生成信息&#xff08;3&#xff09;下载代码并导入项目 3、商品类型改造&#xff08;1&#xff09;基础页面 4、商品管理改造&#xff08;1&#xff0…

【YOLO目标检测车牌数据集】共10000张、已标注txt格式、有训练好的yolov5的模型

目录 说明图片示例 说明 数据集格式&#xff1a;YOLO格式 图片数量&#xff1a;10000&#xff08;2000张绿牌、8000张蓝牌&#xff09; 标注数量(txt文件个数)&#xff1a;10000 标注类别数&#xff1a;1 标注类别名称&#xff1a;licence 数据集下载&#xff1a;车牌数据…

docker 部署 Seatunnel 和 Seatunnel Web

docker 部署 Seatunnel 和 Seatunnel Web 说明&#xff1a; 部署方式前置条件&#xff0c;已经在宿主机上运行成功运行文件采用挂载宿主机目录的方式部署SeaTunnel Engine 采用的是混合模式集群 编写Dockerfile并打包镜像 Seatunnel FROM openjdk:8 WORKDIR /opt/seatunne…

在github上,如何只下载选中的文件?

GitHub官方不直接支持下载子目录&#xff0c;但可以使用特定的第三方工具或脚本来实现这一需求。 总而言之一句话&#xff1a;需要下载插件&#xff01;&#xff01;&#xff01;具体实操步骤如下&#xff1a; 1.打开谷歌浏览器右上角的管理扩展程序&#xff1a; 2.搜索GitZi…

NLP任务之预测最后一个词

目录 1.加载预训练模型 2 从本地加载数据集 3.数据集处理 4.下游任务模型 5.测试代码 6.训练代码 7.保存训练好的模型 8. 加载 保存的模型 1.加载预训练模型 #加载预训练模型 from transformers import AutoTokenizer#预训练模型&#xff1a;distilgpt2 #use_fast…

《无机杀手》制作团队选择Blender的原因分析

《无机杀手》&#xff08;Murder Drones&#xff09;是一部备受欢迎的动画短片&#xff0c;其制作团队选择使用Blender软件进行制作&#xff0c;这一选择背后有着多方面的原因。【成都渲染101--blender渲染农场邀请码6666提供文案参考】 开源且免费 Blender是一个开源且免费的…

什么是数字化转型?数字化转型对企业有哪些优势?

一、什么是数字化转型&#xff1f; 定义&#xff1a; 数字化转型是指企业或组织将传统业务转化为数字化业务&#xff0c;利用人工智能、大数据、云计算、区块链、5G等数字技术提升业务效率和质量的过程。通俗来说&#xff0c;就是将数字技术应用到企业的各个方面&#xff0c;…

贝锐蒲公英网盘首发,秒建私有云,高速远程访问

虽然公共网盘带来了不少便利&#xff0c;但是大家对隐私泄露和重要数据泄密的担忧也随之增加。如果想要确保数据安全&#xff0c;自建私有云似乎是一条出路&#xff0c;然而面对搭建私有云的复杂步骤&#xff0c;许多人感到力不从心&#xff0c;NAS设备的成本也往往让人望而却步…

【MySQL】数据库中的内置函数

W...Y的主页 &#x1f60a; 代码仓库分享 &#x1f495; 目录 函数 日期函数 字符串函数 数学函数 ​编辑 其它函数 MySQL数据库提供了大量的内置函数&#xff0c;这些函数可以帮助你执行各种操作&#xff0c;如字符串处理、数值计算、日期和时间处理等&#xff01; 函数…

云计算Openstack Keystone

OpenStack Keystone是OpenStack平台中的一个核心组件&#xff0c;主要负责身份认证和授权管理服务。以下是关于OpenStack Keystone的详细介绍&#xff1a; 一、作用 身份认证&#xff1a;Keystone为OpenStack平台提供统一的身份认证服务&#xff0c;管理所有用户&#xff08;…

ElasticSearch系列:【Win10环境(版本8.11.1) 】elasticsearch+kibana纪实

一、环境 安装环境&#xff1a;win10 JDK&#xff1a;1.8 elasticsearch&#xff1a;8.11.1 kibana&#xff1a;8.11.1 下载地址1&#xff08;elasticsearchkibana&#xff09;&#xff1a;Past Releases of Elastic Stack Software | Elastic i下载地址2&#xff08;k分…

RS HMP4040 直流电源

R&S HMP404 直流电源 苏州新利通仪器仪表 产品综述 单台仪器中最多四个通道 R&SHMP4000 直流电源具有三个或四个输出通道&#xff0c;每个通道的输出电流高达 10 A&#xff0c;主要设计用于工业应用&#xff0c;例如&#xff1a; -生产测试 -维护 -工程实验室 这些…

关于git分支冲突问题

什么是冲突 在Git中&#xff0c;冲突是指两个或多个开发者对同一文件统一部份进行了不同的修改&#xff0c;并且在合并这些修改时&#xff0c;Git无法自动确定应该采用哪种修改而产生的情况。 分支冲突 如何出现并解决 在一个版本时&#xff0c;有一个master分支&#xff0c…

JAVA甜蜜升级情侣专属扭蛋机游戏系统小程序源码

甜蜜升级&#xff01;情侣专属扭蛋机游戏系统&#xff0c;让爱更有趣&#x1f496; &#x1f389; 开篇&#xff1a;爱的游戏新玩法 在爱情的旅途中&#xff0c;我们总在寻找那些能让彼此心跳加速、笑容满面的瞬间。现在&#xff0c;“甜蜜升级情侣专属扭蛋机游戏系统”为你和…