RNN经典案例——构建人名分类器

news2024/11/25 22:59:38

RNN经典案例——人名分类器

  • 一、数据处理
    • 1.1 去掉语言中的重音标记
    • 1.2 读取数据
    • 1.3 构建人名类别与人名对应关系字典
    • 1.4 将人名转换为对应的onehot张量
  • 二、构建RNN模型
    • 2.1 构建传统RNN模型
    • 2.2 构建LSTM模型
    • 2.3 构建GRU模型
  • 三、构建训练函数并进行训练
    • 3.1 从输出结果中获得指定类别函数
    • 3.2 随机生成训练数据
    • 3.3 构建传统的RNN训练函数
    • 3.4 构建LSTM训练函数
    • 3.5 构建GRU训练函数
    • 3.6 构建时间计算函数
    • 3.7 构建训练过程的日志打印函数
    • 3.8 调用train函数, 进行模型的训练
  • 四、构建评估模型并预测
    • 4.1 构建传统RNN评估函数
    • 4.2 构建LSTM评估函数
    • 4.3 构建GRU评估函数
    • 4.4 构建预测函数

一、数据处理

from io import open
import glob
import os
import string 
import unicodedata
import random
import time
import math
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

1.1 去掉语言中的重音标记

# 获取常用字符数量和常用标点
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)
print("all_letters:",all_letters)
print("n_letters:",n_letters)

在这里插入图片描述

# 去掉一些语言中的重音标记
# 如: Ślusàrski ---> Slusarski
def unicodeToAscii(s):
    ascii = ''.join(
        # NFD会将每个字符分解为其基本字符和组合标记,Ś会拆分为音掉和S
        #'Mn'这类字符通常用于表示重音符号、音调符号等
        c for c in unicodedata.normalize('NFD',s) 
        if unicodedata.category(c) != 'Mn' and c in all_letters
    )
    return ascii

1.2 读取数据

# 读取数据
data_path = "./data/names/"

def readLines(filename):
    lines = open(filename,encoding='utf-8').read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]
# 调试
filename = data_path + "Chinese.txt"
lines = readLines(filename)
print(lines)

在这里插入图片描述

1.3 构建人名类别与人名对应关系字典

# 类别名字列表
category_lines = {}
# 类别名称
all_category = []

for filename in glob.glob(data_path + '*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_category.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

# 查看类别总数
n_categories = len(all_category)
print("n_categories:",n_categories)

在这里插入图片描述

1.4 将人名转换为对应的onehot张量

def lineToTensor(line):
    tensor = torch.zeros(len(line),1,n_letters)
    for i,letter in enumerate(line):
        tensor[i][0][all_letters.find(letter)] = 1
    return tensor  
# 调试
line = 'cui'
line_tensor = lineToTensor(line)
line_tensor

在这里插入图片描述

二、构建RNN模型

2.1 构建传统RNN模型

class RNN(nn.Module):
    def __init__(self,input_size,hidden_size,output_size,num_layers=1):
        super(RNN,self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # 实例化RNN
        self.rnn = nn.RNN(input_size,hidden_size,num_layers)
        # RNN 层的输出转换为最终的输出特征
        self.linear = nn.Linear(hidden_size,output_size)
        # 将全连接层的输出特征转换为概率分布
        self.softmax = nn.LogSoftmax(dim=-1)
    def forward(self,input,hidden):
        # input 形状为1*n_letters需要变换为三维张量
        input = input.unsqueeze(0)
        rr,hn = self.rnn(input,hidden)
        return self.softmax(self.linear(rr)),hn
	# 定义初始化隐藏状态
    def initHidden(self):
        return torch.zeros(self.num_layers,1,self.hidden_size)

2.2 构建LSTM模型

class LSTM(nn.Module):
    def __init__(self,input_size,hidden_size,output_size,num_layers=1):
        super(LSTM,self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size,hidden_size,num_layers)
        self.linear = nn.Linear(hidden_size,output_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self,input,hidden,c):
        input = input.unsqueeze(0)
        rr,(hn,c) = self.lstm(input,(hidden,c))
        return self.softmax(self.linear(rr)),hn,c
        
    def initHidden(self):
        hidden = c = torch.zeros(self.num_layers,1,self.hidden_size)
        return hidden,c

2.3 构建GRU模型

class GRU(nn.Module):
    def __init__(self,input_size,hidden_size,output_size,num_layers=1):
        super(GRU,self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size,hidden_size,num_layers)
        self.linear = nn.Linear(hidden_size,output_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self,input,hidden):
        input = input.unsqueeze(0)
        rr,hn = self.gru(input,hidden)
        return self.softmax(self.linear(rr)),hn
    def initHidden(self):
        return torch.zeros(self.num_layers,1,self.hidden_size)
# 调用
# 实例化参数
input_size = n_letters
n_hidden = 128
output_size = n_categories

input = lineToTensor('B').squeeze(0)
hidden = c = torch.zeros(1,1,n_hidden)

rnn = RNN(n_letters,n_hidden,n_categories)
lstm = LSTM(n_letters,n_hidden,n_categories)
gru = GRU(n_letters,n_hidden,n_categories)

rnn_output, next_hidden = rnn(input, hidden)
print("rnn:", rnn_output)
lstm_output, next_hidden, c = lstm(input, hidden, c)
print("lstm:", lstm_output)
gru_output, next_hidden = gru(input, hidden)
print("gru:", gru_output) 

在这里插入图片描述

三、构建训练函数并进行训练

3.1 从输出结果中获得指定类别函数

def categoryFromOutput(output):
    # 从输出张量中返回最大的值和索引
    top_n,top_i = output.topk(1)
    category_i = top_i[0].item()
    # 获取对应语言类别, 返回语⾔类别和索引值
    return all_category[category_i],category_i
# 调试
category, category_i = categoryFromOutput(gru_output)
print("category:", category)
print("category_i:", category_i)

在这里插入图片描述

3.2 随机生成训练数据

def randomTrainingExample():
    # 随机获取一个类别
    category = random.choice(all_category)
    # 随机获取该类别中的名字
    line = random.choice(category_lines[category])
    # 将类别索引转换为tensor张量
    category_tensor = torch.tensor([all_category.index(category)],dtype=torch.long)
    # 对名字进行onehot编码
    line_tensor = lineToTensor(line)
    return category,line,category_tensor,line_tensor
# 调试
for i in range(10):
    category,line,category_tensor,line_tensor = randomTrainingExample()
    print('category =',category,'/ line =',line,'/ category_tensor =',category_tensor,'/ line_tensor =',line_tensor)

在这里插入图片描述

3.3 构建传统的RNN训练函数

# 定义损失函数
criterion = nn.NLLLoss()
# 设置学习率为0.005
learning_rate = 0.005
import torch.optim as optim
def trainRNN(category_tensor,line_tensor):
    # 实例化对象rnn初始化隐层张量
    hidden = rnn.initHidden()
    # 梯度清零
    optimizer = optim.SGD(rnn.parameters(),lr=0.01,momentum=0.9)
    optimizer.zero_grad()
    # 前向传播
    for i in range(line_tensor.size()[0]):
        # output 是 RNN 在每个时间步的输出。每个时间步的输出是一个隐藏状态,这些隐藏状态可以用于后续的处理,例如分类、回归等任务。
        # hidden是 RNN 在最后一个时间步的隐藏状态。这些隐藏状态可以用于捕获整个序列的信息,通常用于后续的处理,例如作为下一个层的输入。
        output,hidden = rnn(line_tensor[i],hidden)
    # 计算损失
    loss = criterion(output.squeeze(0),category_tensor)
    # 反向传播
    loss.backward()
    optimizer.step()
    # 更新模型中的参数
    #for p in rnn.parameters():
        #p.data.add_(-learning_rate,p.grad.data)

    return output,loss.item()

3.4 构建LSTM训练函数

def trainLSTM(category_tensor,line_tensor):
    hidden,c = lstm.initHidden()
    lstm.zero_grad()
    for i in range(line_tensor.size()[0]):
        output,hidden,c = lstm(line_tensor[i],hidden,c)
    loss = criterion(output.squeeze(0),category_tensor)
    loss.backward()
    for p in lstm.parameters():
        p.data.add_(-learning_rate,p.grad.data)
    return output,loss.item()

3.5 构建GRU训练函数

def trainGRU(category_tensor,line_tensor):
    hidden = gru.initHidden()
    gru.zero_grad()
    for i in range(line_tensor.size()[0]):
        output,hidden = gru(line_tensor[i],hidden)
    loss = criterion(output.squeeze(0),category_tensor)
    loss.backward()
    for p in gru.parameters():
        p.data.add_(-learning_rate,p.grad.data)
    return output,loss.item()

3.6 构建时间计算函数

# 获取每次打印的训练耗时
def timeSince(since):
    # 获得当前时间
    now = time.time()
    # 获取时间差
    s = now - since
    # 将秒转换为分
    m = s // 60
    # 计算不够1分钟的秒数
    s -= m * 60
    return '%dm %ds' % (m,s)

3.7 构建训练过程的日志打印函数

# 设置训练迭代次数
n_iters= 1000
# 设置结果的打印间隔
print_every = 50
# 设置绘制损失曲线上的打印间隔
plot_every = 10
def train(train_typr_fn):
    # 保存每个间隔的损失函数
    all_losses = []
    # 获得训练开始的时间戳
    start = time.time()
    # 设置当前间隔损失为0
    current_loss = 0
    # 循环训练
    for iter in range(1,n_iters+1):
        category,line,category_tensor,line_tensor = randomTrainingExample()
        output,loss = train_typr_fn(category_tensor,line_tensor)
        # 计算打印间隔的总损失
        current_loss += loss
        
        if iter % print_every == 0:
            # 获得预测的类别和索引
            guess,guess_i = categoryFromOutput(output)
            if guess == category:
                correct = '✓'
            else:
                correct = '✗(%s)' % category
            print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters *100, timeSince(start), loss, line, guess, correct))
        
        if iter % plot_every == 0:
            all_losses.append(current_loss / plot_every)
            current_loss = 0 
            
    return all_losses,int(time.time()-start)       

3.8 调用train函数, 进行模型的训练

# 调⽤train函数, 分别进⾏RNN, LSTM, GRU模型的训练
all_losses1, period1 = train(trainRNN)
all_losses2, period2 = train(trainLSTM)
all_losses3, period3 = train(trainGRU)

在这里插入图片描述

# 创建画布0
plt.figure(0)
# 绘制损失对⽐曲线
plt.plot(all_losses1, label="RNN")
plt.plot(all_losses2, color="red", label="LSTM")
plt.plot(all_losses3, color="orange", label="GRU")
plt.legend(loc='upper left')

# 创建画布1
plt.figure(1)
x_data=["RNN", "LSTM", "GRU"]
y_data = [period1, period2, period3]
# 绘制训练耗时对⽐柱状图
plt.bar(range(len(x_data)), y_data, tick_label=x_data)

在这里插入图片描述
在这里插入图片描述

四、构建评估模型并预测

4.1 构建传统RNN评估函数

# 构建传统RNN评估函数
def evaluateRNN(line_tensor):
    hidden = rnn.initHidden()
    for i in range(line_tensor.size()[0]):
        output,hidden = rnn(line_tensor[i],hidden)
    return output.squeeze(0)

4.2 构建LSTM评估函数

# 构建LSTM评估函数
def evaluateLSTM(line_tensor):
    hidden,c = lstm.initHidden()
    for i in range(line_tensor.size()[0]):
        output,hidden,c = lstm(line_tensor[i],hidden,c)
    return output.squeeze(0)

4.3 构建GRU评估函数

# 构建GRU评估函数
def evaluateGRU(line_tensor):
    hidden = gru.initHidden()
    for i in range(line_tensor.size()[0]):
        output,hidden = gru(line_tensor[i],hidden)
    return output.squeeze(0)
# 调试
line = "Bai"
line_tensor = lineToTensor(line)

rnn_output = evaluateRNN(line_tensor)
lstm_output = evaluateLSTM(line_tensor)
gru_output = evaluateGRU(line_tensor)
print("rnn_output:", rnn_output)
print("gru_output:", lstm_output)
print("gru_output:", gru_output)

在这里插入图片描述

4.4 构建预测函数

def predict(input_line,evaluate,n_predictions=3):
    print('\n> %s' % input_line)

    with torch.no_grad():
        output = evaluate(lineToTensor(input_line))
        topv,topi = output.topk(n_predictions,1,True)
        predictions = []
        for i in range(n_predictions):
            # 从topv中取出的output值
            value = topv[0][i].item()
            # 取出索引并找到对应的类别
            category_index = topi[0][i].item()
            # 打印ouput的值, 和对应的类别
            print('(%.2f) %s' % (value, all_category[category_index]))
            # 将结果装进predictions中
            predictions.append([value, all_category[category_index]])
for evaluate_fn in [evaluateRNN, evaluateLSTM, evaluateGRU]:
    print("-"*18)
    predict('Dovesky', evaluate_fn)
    predict('Jackson', evaluate_fn)
    predict('Satoshi', evaluate_fn)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2190470.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

字符和ACSII编码

1.字符和ASCII编码 C语言中char类型,专门用来创建字符变量,字符放在单引号中 char ch a ASCII码表 c官网,最全de c官网链接 数字字符0~9对应ASCII码十进制48~57 字符 大写字母A~Z对应ASCII码十进制65~90 字符 小写字母a~z对应ASCII码…

EtherCAT 转 EtherNet/IP, EtherCAT/Ethernet/IP/Profinet/ModbusTCP协议互转工业串口网关

EtherCAT/Ethernet/IP/Profinet/ModbusTCP协议互转工业串口网关https://item.taobao.com/item.htm?ftt&id822721028899协议转换通信网关 EtherCAT 转 EtherNet/IP GW系列型号 MS-GW12 概述 MS-GW12 是 EtherCAT 和 EtherNet/IP 协议转换网关,为用户提供两…

突发!Meta重磅发布Movie Gen入局视频生成赛道!

引言 Meta于2024年10月4日首次推出 Meta Movie Gen,号称是迄今为止最先进的媒体基础模型。Movie Gen 由 Meta 的 AI 研究团队开发,在一系列功能上获取最先进的效果,包括:文生视频、创建个性化视频、精准的视频编辑和音频创作。 …

递归--C语言

1 递归定义 函数自己调用自己的过程,称为递归。 2 递归的必要条件 1.必须要有终止条件。达到条件就停止递归,退出函数。2.每次调用自己都要越来越接近这个终止条件。 因此写函数的时候,也分两部分 第一部分:写终止条件&#x…

点击按钮提示气泡信息(Toast)

演示效果&#xff1a; 目录结构&#xff1a; activity_main.xml(布局文件)代码&#xff1a; <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"xmlns:app"http:…

【第三版 系统集成项目管理工程师】第15章 组织保障

持续更新。。。。。。。。。。。。。。。 【第三版】第十五章 组织保障 15.1信息和文档管理15.1.1 信息和文档1.信息系统信息-P5462.信息系统文档-P546 15.1.2 信息(文档)管理规则和方法1.信息(文档)编制规范-P5472.信息(文档)定级保护-P5483.信息(文档)配置管理-P549练习 15.…

38 文件包含(标准库头文件、自定义头文件)、相对路径与绝对路径、条件编译(#if、#ifdef、#if define、#ifndef)

目录 1 文件包含 1.1 #include 指令 1.2 包含标准库头文件 1.3 包含自定义头文件 1.3.1 使用相对路径 1.3.2 使用绝对路径 2 条件编译 2.1 #if … #endif 2.1.1 语法格式 2.1.2 功能说明 2.1.3 流程分析 2.1.4 案例演示&#xff1a;#if 0 ... #endif 2.1.5 案例演…

关于懒惰学习与渴求学习的一份介绍

在这篇文章中&#xff0c;我将介绍些懒惰学习与渴求学习的算法例子&#xff0c;会介绍其概念、优缺点以及其python的运用。 一、渴求学习 1.1概念 渴求学习&#xff08;Eager Learning&#xff09;是指在训练阶段构建出复杂的模型&#xff0c;然后在预测阶段运用这个构建出的…

分布式锁--redission 最佳实践!

我们知道如果我们的项目服务不只是一个实例的时候&#xff0c;单体锁就不再适用&#xff0c;而我们自己去用redis实现分布式锁的话&#xff0c;会有比如锁误删、超时释放、锁的重入、失败重试、Redis主从一致性等等一系列的问题需要自己解决。 当然&#xff0c;上述问题并非无…

3dsMax合并FBX的时候相同的节点会被合并(重命名解决),3Ds MAX创建空物体(虚拟对象或者点)

3dsMax合并FBX的时候相同的节点会被合并 3dsamax的文档&#xff0c;但是并没有说FBX的合并如何处理 https://help.autodesk.com/view/3DSMAX/2024/CHS/?guidGUID-98146EB8-436F-4954-8682-C57D4E53262A模型节点信息&#xff0c;yase&#xff0c;Points&#xff0c;Mesh 都是点…

【优选算法】(第二十一篇)

目录 外观数列(medium) 题目解析 讲解算法原理 编写代码 数⻘蛙&#xff08;medium&#xff09; 题目解析 讲解算法原理 编写代码 外观数列(medium) 题目解析 1.题目链接&#xff1a;. - 力扣&#xff08;LeetCode&#xff09; 2.题目描述 给定⼀个正整数n&#xff0…

openpnp - 坐标文件中的元件0角度如果和编带规定的角度不一样,需要调整贴片任务中的元件旋转角度

文章目录 openpnp - 坐标文件中的元件0角度如果和编带规定的角度不一样&#xff0c;需要调整贴片任务中的元件旋转角度笔记查看自己图纸中的封装的0角度方法贴片任务的角度值范围编带规定的0角度根据编带规定的元件0角度来调整贴片的元件旋转角度如果是托盘飞达备注备注END ope…

电脑失声,一招搞定

早已习惯了Edge浏览器的“大声朗读”功能&#xff0c;今天值班&#xff0c;值班室用的两台电脑只配有耳机&#xff0c;没有音箱&#xff0c;顿时感觉不适。 先找了一个带功放的老音箱&#xff0c;发现少了电箱到功放的音频线。 一顿搜索&#xff0c;在找到音频线的同时&#…

2024年计算机视觉与艺术研讨会(CVA 2024)

目录 基本信息 大会简介 征稿主题 会议议程 参会方式 基本信息 大会官网&#xff1a;www.icadi.net&#xff08;点击了解参会投稿等信息&#xff09; 大会时间&#xff1a;2024年11月29-12月1日 大会地点&#xff1a;中国-天津 大会简介 2024年计算机视觉与艺术国际学术…

基于SpringBoot+Vue+MySQL的装修公司管理系统

系统展示 管理员后台界面 员工后台界面 系统背景 随着信息技术的快速发展&#xff0c;装修行业正面临数字化转型的关键时刻。传统的装修管理方式存在信息管理混乱、出错率高、信息安全性差等问题&#xff0c;已无法满足现代市场的需求。因此&#xff0c;开发一套高效、便捷的装…

仿《11773手游》源码/手机游戏软件下载门户网站模板/帝国CMS 7.5

帝国CMS 7.5仿《11773手游》源码&#xff0c;手机游戏软件下载门户网站模板。简洁漂亮的手游下载网站模板&#xff0c;采用帝国CMS7.5核心&#xff0c;同步刷新M端。 该模板带同步生成插件&#xff0c;整站干净大气界面漂亮&#xff0c;简单不失简约&#xff0c;模板中的典范&…

css 简单网页布局(一)

1. 三种布局方式 1.1 标准流 1.2 浮动的使用 1.3 简述浮动 1.3.1 浮动三大特性 <style>.out {border: 1px red solid;width: 1000px;height: 500px;}.one {background-color: aquamarine;width: 200px;height: 100px;}.two {background-color: blueviolet;width: 200px;h…

Chromium 中JavaScript Fetch API接口c++代码实现(二)

Chromium 中JavaScript Fetch API接口c代码实现&#xff08;一&#xff09;-CSDN博客 接着上一篇继续介绍调用&#xff0c;上函数堆栈。 1、打开http://192.168.8.1/chfs/shared/test/test02.html 此标签进程ID12484&#xff0c; 2、打开vs附加上此进程ID12484 3、点击页面测…

华为 HCIP-Datacom H12-821 题库 (31)

&#x1f423;博客最下方微信公众号回复题库,领取题库和教学资源 &#x1f424;诚挚欢迎IT交流有兴趣的公众号回复交流群 &#x1f998;公众号会持续更新网络小知识&#x1f63c; 1. 默认情况下&#xff0c;IS-IS Level-1-2 路由器会将 Level-2 区域的明细路由信息发布到Lev…

YOLOv8 基于NCNN的安卓部署

YOLOv8 NCNN安卓部署 前两节我们依次介绍了基于YOLOv8的剪枝和蒸馏 本节将上一节得到的蒸馏模型导出NCNN&#xff0c;并部署到安卓。 NCNN 导出 YOLOv8项目中提供了NCNN导出的接口&#xff0c;但是这个模型放到ncnn-android-yolov8项目中你会发现更换模型后app会闪退。原因…