基于BILSTM及其他RNN序列模型的人名分类器

news2024/11/24 16:11:57

数据集Kaggle链接

NameNationalLanguage | Kaggle

数据集分布:

第一列为人名,第二列为国家标签

代码开源地址

Kaggle代码链接

https://www.kaggle.com/code/houjijin/name-nationality-classification

Gitee码云链接

人名国籍分类 Name Nation classification: using BILSTM to predict individual's nationality by their name

github链接

GitHub - Foxbabe1q/Name-Nation-classification: Use BILSTM to do the classification of individuals by their names

RNN序列模型类编写

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F



device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.output_size = 18
        self.rnn = nn.RNN(input_size, hidden_size, num_layers = num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x, hidden):
        output, hidden = self.rnn(x, hidden)
        output = output[:, -1, :]
        output = self.fc(output)
        return output, hidden

    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        return hidden

class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(SimpleLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.output_size = 18
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x, hidden, c):
        output, (hidden, c) = self.rnn(x, (hidden, c))
        output = output[:, -1, :]
        output = self.fc(output)
        return output, hidden, c

    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        return hidden, c0


class SimpleBILSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(SimpleBILSTM, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.output_size = 18
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(self.hidden_size*2, self.output_size)

    def forward(self, x, hidden, c):
        output, (hidden, c) = self.rnn(x, (hidden, c))
        output = output[:, -1, :]
        output = self.fc(output)
        return output, hidden, c

    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.num_layers*2, batch_size, self.hidden_size, device=device)
        c0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size, device=device)
        return hidden, c0



class SimpleGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(SimpleGRU, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.output_size = 18
        self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x, hidden):
        output, hidden = self.rnn(x, hidden)
        output = output[:, -1, :]
        output = self.fc(output)
        return output, hidden

    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        return hidden

注意这里BILSTM类中,由于双向lstm会使用两个lstm模型分别处理前向序列和反向序列,所以在初始化隐藏层和记忆细胞层的时候要设置num_layers为2.

导包

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from RNN_Series1 import SimpleRNN, SimpleLSTM, SimpleGRU, SimpleBILSTM
from torch.utils.data import Dataset, DataLoader
import string
from sklearn.preprocessing import LabelEncoder
import time

字符序列及device定义

letters = string.ascii_letters + " .,;'"
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')

数据读取及标签列编码

def load_data():
    data = pd.read_csv('name_classfication.txt', sep='\t', names = ['name', 'country'])
    X = data[['name']]
    lb = LabelEncoder()
    y = data['country']
    y = lb.fit_transform(y)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
    return X_train, X_test, y_train, y_test

数据集定义

class create_dataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.length = len(self.X)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data = torch.zeros(10, len(letters), dtype = torch.float, device=device)
        for i, letter in enumerate(self.X.iloc[idx,0]):
            if i==10:
                break
            data[i,letters.index(letter)] = 1
        label = torch.tensor(self.y[idx], dtype = torch.long, device=device)
        return data, label

这里使用字符序列进行独热编码,并且由于名字长度不一,所以经过序列长度分布,选取了10作为截断长度.

使用RNN训练

def train_rnn():
    X_train, X_test, y_train, y_test = load_data()
    criterion = nn.CrossEntropyLoss(reduction='sum')
    loss_list = []
    acc_list = []
    val_acc_list = []
    val_loss_list = []
    epochs = 10
    my_dataset = create_dataset(X_train, y_train)
    val_dataset = create_dataset(X_test, y_test)
    my_dataloader = DataLoader(my_dataset, batch_size=64, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=len(y_test), shuffle=True)
    my_rnn = SimpleRNN(len(letters), 128,2)
    my_rnn.to(device)
    optimizer = torch.optim.Adam(my_rnn.parameters(), lr=0.001)
    start_time = time.time()

    for epoch in range(epochs):
        my_rnn.train()
        total_loss = 0
        total_acc = 0
        total_sample = 0
        for i, (X,y) in enumerate(my_dataloader):
            output, hidden = my_rnn(X, my_rnn.init_hidden(batch_size=len(y)))
            total_sample += len(y)
            loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            prediction = output.argmax(dim=1)
            acc_num = torch.sum(prediction == y).item()
            total_acc += acc_num
        loss_list.append(total_loss/total_sample)
        acc_list.append(total_acc/total_sample)

        my_rnn.eval()
        with torch.no_grad():
            for i, (X_val, y_val) in enumerate(val_dataloader):
                output, hidden = my_rnn(X_val, my_rnn.init_hidden(batch_size=len(y_test)))
                loss = criterion(output, y_val)
                prediction = output.argmax(dim=1)
                acc_num = torch.sum(prediction == y_val).item()
                val_acc_list.append(acc_num/len(y_val))
                val_loss_list.append(loss.item()/len(y_val))
                print(f'epoch: {epoch+1}, train_loss: {total_loss/total_sample:.2f}, train_acc: {total_acc/total_sample:.2f}, val_loss: {loss.item()/len(y_val):.2f}, val_acc: {acc_num/len(y_val):.2f}, time: {time.time() - start_time : .2f}')
    torch.save(my_rnn.state_dict(), 'rnn.pt')
    plt.plot(np.arange(1,11),loss_list,label = 'Training Loss')
    plt.plot(np.arange(1,11),val_loss_list,label = 'Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.xticks(np.arange(1,11))
    plt.title('Loss')
    plt.legend()
    plt.savefig('logg.png')
    plt.show()
    plt.plot(np.arange(1,11),acc_list,label = 'Training Accuracy')
    plt.plot(np.arange(1,11),val_acc_list,label = 'Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.xticks(np.arange(1,11))
    plt.title('Accuracy')
    plt.legend()
    plt.savefig('accuracy.png')
    plt.show()

使用BILSTM训练

def train_bilstm():
    X_train, X_test, y_train, y_test = load_data()
    criterion = nn.CrossEntropyLoss(reduction='sum')
    loss_list = []
    acc_list = []
    val_acc_list = []
    val_loss_list = []
    epochs = 10
    my_dataset = create_dataset(X_train, y_train)
    val_dataset = create_dataset(X_test, y_test)
    my_dataloader = DataLoader(my_dataset, batch_size=64, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=len(y_test), shuffle=True)
    my_rnn = SimpleBILSTM(len(letters), 128,2)
    my_rnn.to(device)
    optimizer = torch.optim.Adam(my_rnn.parameters(), lr=0.001)
    start_time = time.time()

    for epoch in range(epochs):
        my_rnn.train()
        total_loss = 0
        total_acc = 0
        total_sample = 0
        for i, (X,y) in enumerate(my_dataloader):
            hidden,c0 = my_rnn.init_hidden(batch_size=len(y))
            output, hidden,c = my_rnn(X, hidden,c0)
            total_sample += len(y)
            loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            prediction = output.argmax(dim=1)
            acc_num = torch.sum(prediction == y).item()
            total_acc += acc_num
        loss_list.append(total_loss/total_sample)
        acc_list.append(total_acc/total_sample)


        my_rnn.eval()
        with torch.no_grad():
            for i, (X_val, y_val) in enumerate(val_dataloader):
                hidden, c0 = my_rnn.init_hidden(batch_size=len(y_val))
                output, hidden ,c= my_rnn(X_val, hidden,c0)
                loss = criterion(output, y_val)
                prediction = output.argmax(dim=1)
                acc_num = torch.sum(prediction == y_val).item()
                val_acc_list.append(acc_num/len(y_val))
                val_loss_list.append(loss.item()/len(y_val))
                print(f'epoch: {epoch+1}, train_loss: {total_loss/total_sample:.2f}, train_acc: {total_acc/total_sample:.2f}, val_loss: {loss.item()/len(y_val):.2f}, val_acc: {acc_num/len(y_val):.2f}, time: {time.time() - start_time : .2f}')

    torch.save(my_rnn.state_dict(), 'bilstm.pt')
    plt.plot(np.arange(1,11),loss_list,label = 'Training Loss')
    plt.plot(np.arange(1,11),val_loss_list,label = 'Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.xticks(np.arange(1,11))
    plt.title('Loss')
    plt.legend()
    plt.savefig('loss.png')
    plt.show()
    plt.plot(np.arange(1,11),acc_list,label = 'Training Accuracy')
    plt.plot(np.arange(1,11),val_acc_list,label = 'Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.xticks(np.arange(1,11))
    plt.title('Accuracy')
    plt.legend()
    plt.savefig('accuracy.png')
    plt.show()

这里超参数设置为epochs:10,lr:1e-3,Adam优化器

epoch: 1, train_loss: 1.70, train_acc: 0.51, val_loss: 1.50, val_acc: 0.56, time:  11.83
epoch: 2, train_loss: 1.36, train_acc: 0.60, val_loss: 1.25, val_acc: 0.64, time:  22.84
epoch: 3, train_loss: 1.19, train_acc: 0.65, val_loss: 1.10, val_acc: 0.69, time:  33.76
epoch: 4, train_loss: 1.05, train_acc: 0.69, val_loss: 0.97, val_acc: 0.72, time:  44.63
epoch: 5, train_loss: 0.93, train_acc: 0.73, val_loss: 0.91, val_acc: 0.74, time:  55.49
epoch: 6, train_loss: 0.85, train_acc: 0.75, val_loss: 0.85, val_acc: 0.75, time:  66.38
epoch: 7, train_loss: 0.78, train_acc: 0.77, val_loss: 0.78, val_acc: 0.77, time:  77.38
epoch: 8, train_loss: 0.73, train_acc: 0.78, val_loss: 0.75, val_acc: 0.77, time:  88.27
epoch: 9, train_loss: 0.68, train_acc: 0.79, val_loss: 0.71, val_acc: 0.78, time:  99.44
epoch: 10, train_loss: 0.64, train_acc: 0.80, val_loss: 0.72, val_acc: 0.78, time:  110.43

完整代码的开源链接可以查询kaggle,gitee,github链接,其中gitee和github仓库中有训练好的模型权重,有需要可以在模型实例化后直接使用.

如需使用其他rnn序列模型如lstm和gru也可以直接实例化这里对应的模型类进行训练即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2236963.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

打包18款AI营销神器,批量运营项目收藏必备!

淘金的不如卖铲子的,AI工具的应用越来越普及,这也让很多原本淘金的人都来卖铲子。如果自己能有很好的铲子,自己也会淘金,就可以既能卖铲子赚钱,也能掏金赚钱。 还有两天就是双十一了,各种AI工具&#xff0…

Leetcode - 周赛422

目录 一,3340. 检查平衡字符串 二,3341. 到达最后一个房间的最少时间 I 三,3342. 到达最后一个房间的最少时间 II 四,3343. 统计平衡排列的数目 一,3340. 检查平衡字符串 本题直接暴力,定义一个变量 s&…

聚观早报 | 比亚迪腾势D9登陆泰国;苹果 iOS 18.2 将发布

聚观早报每日整理最值得关注的行业重点事件,帮助大家及时了解最新行业动态,每日读报,就读聚观365资讯简报。 整理丨Cutie 11月5日消息 比亚迪腾势D9登陆泰国 苹果 iOS 18.2 将发布 真我GT7 Pro防尘防水细节 小米15 Ultra最快明年登场 …

Pytest-Bdd-Playwright 系列教程(7):使用测试代码生成辅助工具

Pytest-Bdd-Playwright 系列教程(7):测试代码生成辅助工具的使用 前言一、代码生成辅助工具的设计思路1.1 功能概览1.2 适用人群 二、如何使用 pytest-bdd 代码生成器三、代码生成器的实际应用场景3.1 初学者的学习和实践3.2 大规模功能测试3…

动态规划 —— dp 问题-买卖股票的最佳时机含冷冻期

1. 买卖股票的最佳时机含冷冻期 题目链接: 309. 买卖股票的最佳时机含冷冻期 - 力扣(LeetCode)https://leetcode.cn/problems/best-time-to-buy-and-sell-stock-with-cooldown/description/ 2. 题目解析 3. 算法原理 状态表示:以…

大数据性能测试怎么做?看完这篇终于懂了

大数据性能测试的目的 1.大数据组件的性能回归,在版本升级的时候,进行新旧版本的性能比对。 2.在新版本/新的生产环境发布之后获取性能基线,建立可度量的参考标准,为其他测试场景或者调优过程提供对比参考。 3.在众多的发行版本…

鸿蒙开发:ArkTS如何读取图片资源

ArkTS在TS的基础上主要扩展了声明式UI能力,简化了构建和更新UI的过程。开发者可以以声明式的方式来描述UI的结构,如使用build方法中的代码块。同时,ArkTS提供了自定义组件、系统组件、属性方法、事件方法等,以构建应用UI界面。今天…

【Unity基础】Unity中如何导入字体?

在Unity中,不能像其他软件一样直接使用字体文件,需要通过FontAssetCreator将其转换成Texture的Asset文件,然后才能使用。 本文介绍了使用FontAssetCreator导入字体的过程,并对其参数设置进行了说明。 Font Asset Creator 是 Uni…

三、直流有刷电机H桥驱动原理

1、H桥简介 H桥驱动电路:是一种用于控制直流电机正反转及速度的电路,其名称来源于其电路结构类似于字母“H”。H桥驱动电路由四个开关元件(如晶体管、MOSFET等)组成,这些开关元件被配置成可以控制电机两端电流的方向&…

Unity性能优化 -- 性能分析工具

Stats窗口Profiler窗口Memory Profiler其他性能分析工具(Physica Debugger 窗口,Import Activity 窗口,Code Coverage 窗口,Profile Analyzer 窗口,IMGUI Debugger 窗口) Stats 统级数据窗口 game窗口 可…

html的week控件 获取周(星期)的第一天(周一)和最后一天(周日)

html的week控件 获取周(星期)的第一天(周一)和最后一天(周日) <input type"week" id"week" class"my-css" value"ViewBag.DefaultWeek" /><script> function PageList() { var dateStrin…

[C++11] 可变参数模板

文章目录 基本语法及原理可变参数模板的基本语法参数包的两种类型可变参数模板的定义 sizeof... 运算符可变参数模板的实例化原理可变参数模板的意义 包扩展包扩展的基本概念包扩展的实现原理编译器如何展开参数包包扩展的高级应用 emplace 系列接口emplace_back 和 emplace 的…

Axure设计之左右滚动组件教程(动态面板)

很多项目产品设计经常会遇到左右滚动的导航、图片展示、内容区域等&#xff0c;接下来我们用Axure来实现一下左右滚动的菜单导航。通过案例我们可以举一反三进行其他方式的滚动组件设计&#xff0c;如常见的上下滚动、翻页滚动等等。 一、效果展示&#xff1a; 1、点击“向左箭…

qt QListWidget详解

1、概述 QListWidget 是 Qt 框架中的一个类&#xff0c;它提供了一个基于模型的视图&#xff0c;用于显示项目的列表。QListWidget 继承自 QAbstractItemView 并为项目列表提供了一个直观的接口。与 QTreeView 和 QTableView 不同&#xff0c;QListWidget 是专门为单行或多行项…

vue--vueCLI

何为CLI ■ CLI是Command-Line Interface,俗称脚手架. ■ 使用Vue.js开发大型应用时&#xff0c;我们需要考虑代码目录结构、项目结构和部署、热加载、代码单元测试等事情。&#xff08;vue 脚手架的作用&#xff09;&#xff0c; 而通过vue-cli即可&#xff1a;vue-cli 可以…

思维,CF 1735D - Meta-set

目录 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 二、解题报告 1、思路分析 2、复杂度 3、代码详解 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 1735D - Meta-set 二、解题报告 1、思路分析 考虑一个五元组<a, b, c…

C#的6种常用集合类

一.先来说说数组的不足&#xff08;也可以说集合与数组的区别&#xff09;&#xff1a; 1.数组是固定大小的&#xff0c;不能伸缩。虽然System.Array.Resize这个泛型方法可以重置数组大小&#xff0c;但是该方法是重新创建新设置大小的数组&#xff0c;用的是旧数组的元素初始…

深度学习-神经网络基础-激活函数与参数初始化(weight, bias)

一. 神经网络介绍 神经网络概念 神经元构建 神经网络 人工神经网络是一种模仿生物神经网络结构和功能的计算模型, 由神经元构成 将神经元串联起来 -> 神经网络 输入层: 数据 输出层: 目标(加权和) 隐藏层: 加权和 激活 全连接 第N层的每个神经元和第N-1层的所有神经元…

SpringBoot框架在资产管理中的应用

3系统分析 3.1可行性分析 通过对本企业资产管理系统实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本企业资产管理系统采用Spring Boot框架&#xff0c;JAVA作…

【C#】选课程序增加、删除统计学时

文章目录 【例6-2】编写选课程序。利用利用列表框和组合框增加和删除相关课程&#xff0c;并统计学时数1. 表6-2 属性设置2. 设计窗体及页面3. 代码实现4. 运行效果 【例6-2】编写选课程序。利用利用列表框和组合框增加和删除相关课程&#xff0c;并统计学时数 分析&#xff1…