人名分类器(nlp)

news2024/11/27 15:48:47
# coding: utf-8
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# 导入torch工具
import json

import torch
# 导入nn准备构建模型
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# 导入torch的数据源 数据迭代器工具包
from torch.utils.data import Dataset, DataLoader
# 用于获得常见字母及字符规范化
import string
# 导入时间工具包
import time
# 引入制图工具包
import matplotlib.pyplot as plt
# 从io中导入文件打开方法
from io import open

# 1 获取常用的字符 标点,把每个char字符作为一个token,用onehot编码表示token
# 因此我们的词表就是 char表 (字符表) 57个char
all_letters = string.ascii_letters + " ,.;'"
print(all_letters)

n_letter = len(all_letters)  # 词表的大小
print('字符表的长度:', n_letter)

# 2 获取国家的类别种数
# 国家名 种类数
categorys = ['Italian', 'English', 'Arabic', 'Spanish', 'Scottish', 'Irish', 'Chinese', 'Vietnamese', 'Japanese',
             'French', 'Greek', 'Dutch', 'Korean', 'Polish', 'Portuguese', 'Russian', 'Czech', 'German']
# 国家名 个数,就是模型的 (linear输出维度) 分类数
categorynum = len(categorys)
print('categorys--->', categorys)


# 3 读取数据
def read_data(filename):
    # 3.1 初始化空列表两个
    my_list_x, my_list_y = [], []

    # 3.2 读取文件内容
    with open(filename, 'r', encoding='utf-8') as fr:
        for line in fr.readlines():
            # 异常点判断:改行长度<=5,说明这是异常样本,直接跳到下一行
            if len(line) <= 5:
                continue
            x, y = line.strip().split('\t')
            my_list_x.append(x)
            my_list_y.append(y)

    # 3.3 返回两个列表
    return my_list_x, my_list_y


# 4 构建数据集
class NameClsDataset(Dataset):
    def __init__(self, mylist_x, mylist_y):
        self.mylist_x = mylist_x
        self.mylist_y = mylist_y

    def __len__(self):
        return len(self.mylist_x)

    def __getitem__(self, item):
        # 01 item 异常值出处理
        index = min(max(item, 0), len(self.mylist_x) - 1)

        # 02 根据idx拿到人名 国家名
        x = self.mylist_x[index]
        y = self.mylist_y[index]

        # 03 完成onehot
        tensor_x = torch.zeros(len(x), n_letter)
        for idx, letter in enumerate(x):
            tensor_x[idx][all_letters.find(letter)] = 1

        # 04 获得标签
        tensor_y = torch.tensor(categorys.index(y), dtype=torch.long)

        return tensor_x, tensor_y


# 5 构建dataloader
def get_dataloader():
    filename = './data/name_classfication.txt'
    my_list_x, my_list_y = read_data(filename)
    mydataset = NameClsDataset(my_list_x, my_list_y)
    my_dataloader = DataLoader(
        mydataset,
        batch_size=1,
        shuffle=True,  # 打乱顺序
        # drop_last=True,  # 是否丢弃最后那个不足一个batch_size的数据组
        # collate_fn=collate_fn,  # 处理一个batch的数据为整齐的维度
    )
    x, y = next(iter(my_dataloader))
    # print(x)
    # print(x.shape)
    # print(y)
    return my_dataloader


# 6 创建rnn模型
class MyRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers

        self.rnn = nn.RNN(self.input_size, self.hidden_size,
                          self.num_layers, batch_first=True)
        # self.linear = nn.Linear(self.hidden_size, self.hidden_size)
        self.linear = nn.Linear(self.hidden_size, self.output_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input):
        # input.shape = (1, 9, 57)
        # hidden.shape = (1, 1, 128)
        # rnn_output.shape = (1, 9, 128)
        # rnn_hn.shape = (1, 1, 128)
        # rnn_output, _ = self.rnn(input)
        rnn_output, rnn_hn = self.rnn(input)

        # temp.shape = (1, 128)
        # temp = rnn_output[0][-1].unsqueeze(0)
        temp = rnn_hn[0]

        # output.shape=(1,18)
        # self.softmax(output) (2, 18)
        output = self.linear(temp)  # 可以接受三维数据
        return self.softmax(output), rnn_hn


# 7 测试RNN
def ceshiRNN():
    # 1 拿到数据
    my_dataloader = get_dataloader()

    # 2 实例化模型
    input_size = n_letter  # 字符表的大小 (词表的大小)
    hidden_size = 128  # 超参数 768,rnn输出维度
    output_size = len(categorys)  # 18,分类总数
    my_rnn = MyRNN(input_size, hidden_size, output_size)

    # 3 将数据送入到模型
    x, y = next(iter(my_dataloader))
    output, hn = my_rnn(x)  # output.shape = (1, 18)
    print(output.shape)
    print(hn.shape)


# 8 训练RNN
def train_my_rnn():
    epochs = 1
    my_lr = 1e-3
    # 1 读取数据
    my_list_x, my_list_y = read_data('./data/name_classfication.txt')
    # 2 定义dataset
    myDataset = NameClsDataset(my_list_x, my_list_y)
    # 3 实例化dataloader
    my_dataloader = DataLoader(myDataset, batch_size=1, shuffle=True)
    # 4 实例化RNN模型
    input_size = 57
    hidden_size = 128
    output_size = 18
    my_rnn = MyRNN(input_size, hidden_size, output_size)

    # 5 损失函数
    my_crossentropy = nn.NLLLoss()
    # 6 优化器
    my_optimizer = optim.Adam(my_rnn.parameters(), lr=my_lr)

    # 7 日志
    start_time = time.time()
    total_iter_num = 0  # 已经训练好的样本数
    total_loss = 0  # 总的loss
    total_loss_list = []  # 每隔多少步存储loss-avg
    total_acc_num = 0
    total_acc_list = []  # 存储间隔准确率acc-avg

    # 8 开始训练
    # 8.1 外部循环
    for epoch_idx in range(epochs):
        # 8.2 batch循环
        for i, (x, y) in enumerate(my_dataloader):
            # 8.3 将x送入到模型 一轮模型训练
            output, hn = my_rnn(x)
            my_loss = my_crossentropy(output, y)
            my_optimizer.zero_grad()
            my_loss.backward()
            my_optimizer.step()

            total_iter_num += 1
            total_loss += my_loss.item()
            item1 = 1 if torch.argmax(output, dim=-1).item() == y.item() else 0
            total_acc_num += item1

            # 每隔 100 步存储avg-loss acc-avg
            if total_iter_num % 100 == 0:
                # 保存一下平均损失
                loss_avg = total_loss / total_iter_num
                total_loss_list.append(loss_avg)

                # acc-avg
                acc_avg = total_acc_num / total_iter_num
                total_acc_list.append(acc_avg)

            if total_iter_num % 1000 == 0:
                loss_avg = total_loss / total_iter_num
                acc_avg = total_acc_num / total_iter_num
                end_time = time.time()
                use_time = end_time - start_time
                print(
                    '当前的训练批次:%d, 平均损失:%.5f, 训练时间:%.3f, 准确率:%.2f' % (
                        epoch_idx + 1,
                        loss_avg,
                        use_time,
                        acc_avg
                    )
                )

        # 9 保存模型
        torch.save(my_rnn.state_dict(), './model/my_rnn.bin')

    # 10 结束
    all_time = time.time() - start_time
    print('总耗时:', all_time)
    return total_loss_list, total_acc_list, all_time


# 9 将模型结果进行保存,方便进行读取
def save_rnn_res():
    # 1 训练模型,得到需要的结果
    total_loss_list, total_acc_list, all_time = train_my_rnn()
    # 2 定义一个字典
    dict1 = {
        'loss': total_loss_list,
        'time': all_time,
        'acc': total_acc_list
    }
    # 3 保存成json
    with open('./data/rnn_result.json', 'w') as fw:
        fw.write(json.dumps(dict1))


# 10 读取模型结果json
def read_json(json_path):
    with open(json_path, 'r') as fr:
        # '{a:1, b:2,,,}'  --> json.loads()
        # json.load() 加载json文件
        res = json.load(fr)
    return res


# 11 绘图
def plt_RNN():
    # 1 拿到数据
    rnn_results = read_json('./data/rnn_result-epoch3.json')
    total_loss_list_rnn, all_time_rnn, total_acc_list_rnn = rnn_results['loss'], rnn_results['time'], rnn_results['acc']
    lstm_results = read_json('./data/lstm_result-epoch3.json')
    total_loss_list_lstm, all_time_lstm, total_acc_list_lstm = lstm_results['loss'], lstm_results['time'], lstm_results[
        'acc']
    gru_results = read_json('./data/gru_result-epoch3.json')
    total_loss_list_gru, all_time_gru, total_acc_list_gru = gru_results['loss'], gru_results['time'], gru_results['acc']

    # 2 绘制loss对比曲线图
    plt.figure(0)
    plt.plot(total_loss_list_rnn, label='RNN')
    plt.plot(total_loss_list_lstm, label='LSTM', color='red')
    plt.plot(total_loss_list_gru, label='GRU', color='orange')
    plt.legend(loc='upper right')
    plt.savefig('./picture/loss.png')
    plt.show()

    # 3 绘制耗时柱状图
    plt.figure(1)
    x_data = ['RNN', 'LSTM', 'GRU']
    y_data = [all_time_rnn, all_time_lstm, all_time_gru]
    plt.bar(range(len(x_data)), y_data, tick_label=x_data)
    plt.savefig('./picture/use_time.png')
    plt.show()

    # 4 绘制acc曲线图
    plt.figure(2)
    plt.plot(total_acc_list_rnn, label='RNN')
    plt.plot(total_acc_list_lstm, label='LSTM', color='red')
    plt.plot(total_acc_list_gru, label='GRU', color='orange')
    plt.legend(loc='upper right')
    plt.savefig('./picture/acc.png')
    plt.show()


# 12 定义预测输入的x --》 tensor_x
def line2tensor(x):
    tensor_x = torch.zeros(len(x), n_letter)
    for li, letter in enumerate(x):
        tensor_x[li][all_letters.find(letter)] = 1
    return tensor_x


# 13 预测主函数
def rnn_predict(x):
    # 1 x --》 tensor_x
    tensor_x = line2tensor(x)
    # 2 实力化模型
    my_rnn = MyRNN(input_size=57, hidden_size=128, output_size=18)
    my_rnn.load_state_dict(torch.load('./model/my_rnn.bin'))

    # 3 预测
    with torch.no_grad():  # 预测时不去计算梯度
        input0 = tensor_x.unsqueeze(0)  # input0 是三维的,rnn需要
        output, hn = my_rnn(input0)
        topv, topi = output.topk(3, 1, True)
        print('人名是', x)

        # 4 打印topk个
        for i in range(3):
            value = topv[0][i]
            index = topi[0][i]
            cate = categorys[index]
            print('国家名是:', cate)


if __name__ == '__main__':
    # filename = './data/name_classfication.txt'
    # x, y = read_data(filename)
    # print(x)
    # print(y)
    # get_dataloader()
    # ceshiRNN()
    # train_my_rnn()
    # plt_RNN()

    rnn_predict('zhang')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2248500.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UEFI 中的 Protocol

Protocol 在 UEFI 内核中的表示 typedef VOID *EFI_HANDLE;EFI_HANDLE是指向某种对象的指针&#xff0c;UEFI 用它来表示某个对象。 UEFI 扫描总线后&#xff0c;会为每个设备建立一个 Controller 对象&#xff0c;用于控制设备&#xff0c;所有该设备的驱动以 Protocol 的形式…

量子安全与经典密码学:一些现实方面的讨论

量子安全与经典密码学 背景&#xff1a;量子安全与经典密码学量子计算对传统密码学的威胁 安全性分析经典密码学的数学复杂性假设**量子密码学的物理不可克隆性假设** **性能与实现难度**后量子算法在经典计算机上的运行效率**量子通信设备的技术要求与成本** **可扩展性与适用…

【大模型】LLaMA-Factory的环境配置、微调模型与测试

前言 【一些闲扯】 时常和朋友闲聊&#xff0c;时代发展这么快&#xff0c;在时代的洪流下&#xff0c;我们个人能抓住些什么呢。我问了大模型&#xff0c;文心一言是这样回答的&#xff1a; 在快速发展的时代背景下&#xff0c;个人确实面临着诸多挑战&#xff0c;但同时也充满…

PostgreSQL的学习心得和知识总结(一百五十八)|在线调优工具pgtune的实现原理和源码解析

目录结构 注&#xff1a;提前言明 本文借鉴了以下博主、书籍或网站的内容&#xff0c;其列表如下&#xff1a; 1、参考书籍&#xff1a;《PostgreSQL数据库内核分析》 2、参考书籍&#xff1a;《数据库事务处理的艺术&#xff1a;事务管理与并发控制》 3、PostgreSQL数据库仓库…

汽车渲染领域:Blender 和 UE5 哪款更适用?两者区别?

在汽车渲染领域&#xff0c;选择合适的工具对于实现高质量的视觉效果至关重要。Blender和UE5&#xff08;Unreal Engine 5&#xff09;作为两大主流3D软件&#xff0c;各自在渲染动画方面有着显著的差异。本文将从核心定位与用途、工作流程、渲染技术和灵活性、后期处理与合成四…

机器学习—迁移学习:使用其他任务中的数据

对于一个没有那么多数据的应用程序&#xff0c;迁移学习是一种奇妙的技术&#xff0c;它允许你使用来自不同任务的数据来帮助你的应用程序&#xff0c;迁移学习是如何工作的&#xff1f; 以下是迁移学习的工作原理&#xff0c;假设你想识别手写的数字0到9&#xff0c;但是你没…

LeetCode 3206.交替组 I:遍历

【LetMeFly】3206.交替组 I&#xff1a;遍历 力扣题目链接&#xff1a;https://leetcode.cn/problems/alternating-groups-i/ 给你一个整数数组 colors &#xff0c;它表示一个由红色和蓝色瓷砖组成的环&#xff0c;第 i 块瓷砖的颜色为 colors[i] &#xff1a; colors[i] …

如何通过高效的缓存策略无缝加速湖仓查询

引言 本文将探讨如何利用开源项目 StarRocks 的缓存策略来加速湖仓查询&#xff0c;为企业提供更快速、更灵活的数据分析能力。作为 StarRocks 社区的主要贡献者和商业化公司&#xff0c;镜舟科技深度参与 StarRocks 项目开发&#xff0c;也为企业着手构建湖仓架构提供更多参考…

25A物联网微型断路器 智慧空开1P 2P 3P 4P-安科瑞黄安南

微型断路器&#xff0c;作为现代电气系统中不可或缺的重要组件&#xff0c;在保障电路安全与稳定运行方面发挥着关键作用。从其工作原理来看&#xff0c;微型断路器通过感知电流的异常变化来迅速作出响应。当电路中的电流超过预设的安全阈值时&#xff0c;其内部的电磁感应装置…

目标检测,图像分割,超分辨率重建

目标检测和图像分割 目标检测和图像分割是计算机视觉中的两个不同任务&#xff0c;它们的输出形式也有所不同。下面我将分别介绍这两个任务的输出。图像分割又可以分为&#xff1a;语义分割、实例分割、全景分割。 语义分割&#xff08;Semantic Segmentation&#xff09;&…

16 —— Webpack多页面打包

需求&#xff1a;把 黑马头条登陆页面-内容页面 一起引入打包使用 步骤&#xff1a; 准备源码&#xff08;html、css、js&#xff09;放入相应位置&#xff0c;并改用模块化语法导出 原始content.html代码 <!DOCTYPE html> <html lang"en"><head&…

《PH47 快速开发教程》发布

PDF 教程下载位于CSDN资源栏目&#xff08;网页版本文上方&#xff09; 或Gitee&#xff1a;document ss15/PH47 - 码云 - 开源中国

腾讯云OCR车牌识别实践:从图片上传到车牌识别

在当今智能化和自动化的浪潮中&#xff0c;车牌识别&#xff08;LPR&#xff09;技术已经广泛应用于交通管理、智能停车、自动收费等多个场景。腾讯云OCR车牌识别服务凭借其高效、精准的识别能力&#xff0c;为开发者提供了强大的技术支持。本文将介绍如何利用腾讯云OCR车牌识别…

C++ 优先算法 —— 无重复字符的最长子串(滑动窗口)

目录 题目&#xff1a; 无重复字符的最长子串 1. 题目解析 2. 算法原理 Ⅰ. 暴力枚举 Ⅱ. 滑动窗口&#xff08;同向双指针&#xff09; 3. 代码实现 Ⅰ. 暴力枚举 Ⅱ. 滑动窗口 题目&#xff1a; 无重复字符的最长子串 1. 题目解析 题目截图&#xff1a; 此题所说的…

[网安靶场] [更新中] UPLOAD LABS —— 靶场笔记合集

GitHub - c0ny1/upload-labs: 一个想帮你总结所有类型的上传漏洞的靶场一个想帮你总结所有类型的上传漏洞的靶场. Contribute to c0ny1/upload-labs development by creating an account on GitHub.https://github.com/c0ny1/upload-labs 0x01&#xff1a;UPLOAD LABS 靶场初识…

安装python拓展库pyquery相关问题

我采用的是离线whl文件安装, 从官方库地址: https://pypi.org/, 下载whl文件, 然后在本地电脑上执行pip install whl路径文件名.whl 但是在运行时报错如下图 大体看了看, 先是说了说找到了合适的 lxml>2.1, 在我的python库路径中, 然后我去看了看我的lxml版本, 是4.8.0, 对…

春秋云境 CVE 复现

CVE-2022-4230 靶标介绍 WP Statistics WordPress 插件13.2.9之前的版本不会转义参数&#xff0c;这可能允许经过身份验证的用户执行 SQL 注入攻击。默认情况下&#xff0c;具有管理选项功能 (admin) 的用户可以使用受影响的功能&#xff0c;但是该插件有一个设置允许低权限用…

图论入门编程

卡码网刷题链接&#xff1a;98. 所有可达路径 一、题目简述 二、编程demo 方法①邻接矩阵 from collections import defaultdict #简历邻接矩阵 def build_graph(): n, m map(int,input().split()) graph [[0 for _ in range(n1)] for _ in range(n1)]for _ in range(m): …

Jackson库中JsonInclude的使用

简介 JsonInclude是 Jackson 库&#xff08;Java 中用于处理 JSON 数据的流行库&#xff09;中的一个注解。它用于控制在序列化 Java 对象为 JSON 时&#xff0c;哪些属性应该被包含在 JSON 输出中。这个注解提供了多种策略来决定属性的包含与否&#xff0c;帮助减少不必要的数…

鸿蒙学习自由流转与分布式运行环境-价值与架构定义(1)

文章目录 价值与架构定义1、价值2、架构定义 随着个人设备数量越来越多&#xff0c;跨多个设备间的交互将成为常态。基于传统 OS 开发跨设备交互的应用程序时&#xff0c;需要解决设备发现、设备认证、设备连接、数据同步等技术难题&#xff0c;不但开发成本高&#xff0c;还存…