w04_nlp大模型训练·中文分词

news2024/12/26 2:46:12

一、基于pytorch的网络编写一个分词模型

#coding:utf8

import torch
import torch.nn as nn
import jieba
import numpy as np
import random
import json
from torch.utils.data import DataLoader

"""
基于pytorch的网络编写一个分词模型
我们使用jieba分词的结果作为训练数据
看看是否可以得到一个效果接近的神经网络模型
"""

# TorchModel 类定义了一个包含嵌入层、RNN 层和线性分类层的神经网络
# 使用 nn.Embedding 将字符映射到高维空间,通过 nn.RNN 处理序列信息,使用 nn.Linear 进行分类。
class TorchModel(nn.Module):
    def __init__(self, input_dim, hidden_size, num_rnn_layers, vocab):
        super(TorchModel, self).__init__()
        self.embedding = nn.Embedding(len(vocab)+1, input_dim, padding_idx=0)
        self.rnn_layer = nn.RNN(input_size=input_dim,
                          hidden_size=hidden_size,
                          batch_first=True,
                          num_layers=num_rnn_layers
                          )
        self.classify = nn.Linear(hidden_size, 2)
        self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, x, y=None):
        x = self.embedding(x)  #input shape: (batch_size, sen_len), output shape:(batch_size, sen_len, input_dim)
        x, _ = self.rnn_layer(x)  #output shape:(batch_size, sen_len, hidden_size)
        y_pred = self.classify(x)   #output shape:(batch_size, sen_len, 2) -> y_pred.view(-1, 2) (batch_size*sen_len, 2)
        if y is not None:
            return self.loss_func(y_pred.view(-1, 2), y.view(-1))
        else:
            return y_pred


# Dataset 类用于处理语料库数据,从文件中读取句子
# 使用 sentence_to_sequence 将句子转换为数字序列,使用 sequence_to_label 生成标记序列,使用 padding 方法将序列和标签填充到固定长度。
class Dataset:
    def __init__(self, corpus_path, vocab, max_length):
        self.vocab = vocab
        self.corpus_path = corpus_path
        self.max_length = max_length
        self.load()

    def load(self):
        self.data = []
        with open(self.corpus_path, encoding="utf8") as f:
            for line in f:
                sequence = sentence_to_sequence(line, self.vocab)
                label = sequence_to_label(line)
                sequence, label = self.padding(sequence, label)
                sequence = torch.LongTensor(sequence)
                label = torch.LongTensor(label)
                self.data.append([sequence, label])
                #使用部分数据做展示,使用全部数据训练时间会相应变长
                if len(self.data) > 10000:
                    break

    #将文本截断或补齐到固定长度
    def padding(self, sequence, label):
        sequence = sequence[:self.max_length]
        sequence += [0] * (self.max_length - len(sequence))
        label = label[:self.max_length]
        label += [-100] * (self.max_length - len(label))
        return sequence, label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

#文本转化为数字序列,为embedding做准备
def sentence_to_sequence(sentence, vocab):
    sequence = [vocab.get(char, vocab['unk']) for char in sentence]
    return sequence

#基于结巴生成分级结果的标注
def sequence_to_label(sentence):
    words = jieba.lcut(sentence)
    label = [0] * len(sentence)
    pointer = 0
    for word in words:
        pointer += len(word)
        label[pointer - 1] = 1
    return label

#加载字表
def build_vocab(vocab_path):
    vocab = {}
    with open(vocab_path, "r", encoding="utf8") as f:
        for index, line in enumerate(f):
            char = line.strip()
            vocab[char] = index + 1   #每个字对应一个序号
    vocab['unk'] = len(vocab) + 1
    return vocab

#建立数据集
def build_dataset(corpus_path, vocab, max_length, batch_size):
    dataset = Dataset(corpus_path, vocab, max_length) #diy __len__ __getitem__
    data_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size) #torch
    return data_loader


def main():
    epoch_num = 5        #训练轮数
    batch_size = 20       #每次训练样本个数
    char_dim = 50         #每个字的维度
    hidden_size = 100     #隐含层维度
    num_rnn_layers = 1    #rnn层数
    max_length = 20       #样本最大长度
    learning_rate = 1e-3  #学习率
    vocab_path = "chars.txt"  #字表文件路径
    corpus_path = "../corpus.txt"  #语料文件路径
    vocab = build_vocab(vocab_path)       #建立字表
    data_loader = build_dataset(corpus_path, vocab, max_length, batch_size)  #建立数据集
    model = TorchModel(char_dim, hidden_size, num_rnn_layers, vocab)   #建立模型
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)     #建立优化器
    #训练开始
    for epoch in range(epoch_num):
        model.train()
        watch_loss = []
        for x, y in data_loader:
            optim.zero_grad()    #梯度归零
            loss = model.forward(x, y)   #计算loss
            loss.backward()      #计算梯度
            optim.step()         #更新权重
            watch_loss.append(loss.item())
        print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))
    #保存模型
    torch.save(model.state_dict(), "model.pth")
    return

#最终预测
def predict(model_path, vocab_path, input_strings):
    #配置保持和训练时一致
    char_dim = 50  # 每个字的维度
    hidden_size = 100  # 隐含层维度
    num_rnn_layers = 1  # rnn层数
    vocab = build_vocab(vocab_path)       #建立字表
    model = TorchModel(char_dim, hidden_size, num_rnn_layers, vocab)   #建立模型
    model.load_state_dict(torch.load(model_path))   #加载训练好的模型权重
    model.eval()
    for input_string in input_strings:
        #逐条预测
        x = sentence_to_sequence(input_string, vocab)
        with torch.no_grad():
            result = model.forward(torch.LongTensor([x]))[0]
            result = torch.argmax(result, dim=-1)  #预测出的01序列
            #在预测为1的地方切分,将切分后文本打印出来
            for index, p in enumerate(result):
                if p == 1:
                    print(input_string[index], end=" ")
                else:
                    print(input_string[index], end="")
            print()



if __name__ == "__main__":
    # main()
    input_strings = ["同时国内有望出台新汽车刺激方案",
                     "沪胶后市有望延续强势",
                     "经过两个交易日的强势调整后",
                     "昨日上海天然橡胶期货价格再度大幅上扬"]
    
    predict("model.pth", "chars.txt", input_strings)

模型分析

  • 模型定义
    • TorchModel 类定义了一个包含嵌入层、RNN 层和线性分类层的神经网络,使用 nn.Embedding 将字符映射到高维空间,通过 nn.RNN 处理序列信息,使用 nn.Linear 进行分类。
    • 数据集处理
      • Dataset 类用于处理语料库数据,从文件中读取句子,使用 sentence_to_sequence 将句子转换为数字序列,使用 sequence_to_label 生成标记序列,使用 padding 方法将序列和标签填充到固定长度。
      • build_vocab 函数从文件中构建词汇表,build_dataset 函数使用 Dataset 类和 DataLoader 进行批处理。
    • 训练部分
      • main 函数中,设置超参数,构建模型和优化器,进行多轮训练,计算损失、反向传播和更新参数,保存训练好的模型。
    • 预测部分
      • predict 函数加载训练好的模型和词汇表,对输入句子进行分词预测,将预测为 1 的位置进行分词,将结果打印输出。
  • 数据预处理
    • sentence_to_sequence 函数将输入的句子中的字符根据词汇表转换为数字序列,未在词汇表中的字符使用 unk 的索引。
    • sequence_to_label 函数利用结巴分词的结果,将分词结束位置标记为 1,其余为 0,生成标记序列。
    • Dataset 类的 padding 方法确保所有序列和标签具有相同的长度,便于批处理。
  • 模型架构
    • TorchModel 类的 embedding 层将输入的数字序列映射到高维空间,rnn_layer 处理序列信息,classify 层将 RNN 的输出映射到 2 个类别(分词或不分词)。
    • 训练时使用 CrossEntropyLoss 计算损失,预测时使用 torch.argmax 找到最可能的类别。
  • 训练和预测流程
    • main 函数设置训练的超参数,创建数据集和模型,使用 Adam 优化器进行优化,保存训练好的模型。
    • predict 函数加载训练好的模型,对输入句子进行分词预测并输出结果。

二、DAG(有向无环图)法做分词

import jieba

#词典,每个词后方存储的是其词频,仅为示例,也可自行添加
Dict = {"经常":0.1,
        "经":0.05,
        "有":0.1,
        "常":0.001,
        "有意见":0.1,
        "歧":0.001,
        "意见":0.2,
        "分歧":0.2,
        "见":0.05,
        "意":0.05,
        "见分歧":0.05,
        "分":0.1}

#根据上方词典,对于输入文本,构造一个存储有所有切分方式的信息字典
#学术叫法为有向无环图,DAG(Directed Acyclic Graph),不理解也不用纠结,只当是个专属名词就好
#这段代码直接来自于jieba分词
# jieba.cut
def calc_dag(sentence):
    DAG = {}
    n = len(sentence)
    for k in range(n):
        i = k
        tmplist = []
        while i < n:
            frag = sentence[k: i+1]
            if frag in Dict:
                tmplist.append(i)
            i += 1
        if not tmplist:
            tmplist = [k]
        DAG[k] = tmplist
    return DAG

sentence = "经常有意见分歧"
print(calc_dag(sentence))
#结果应该为{0: [0, 1], 1: [1], 2: [2, 4], 3: [3, 4], 4: [4, 6], 5: [5, 6], 6: [6]}
#0:[0,1]代表句子中的第0个字,可以单独成词,或与第1个字一起成词
#2:[2,4]代表句子中的第2个字,可以单独成词,或第2-4个字一起成词
#依次类推
#这个字典中实际上就存储了所有可能的切分方式的信息


#将DAG中的信息解码(还原)出来,用文本展示出所有切分方式
class DAGDecode:
    #通过两个队列来实现
    def __init__(self, sentence):
        self.sentence = sentence
        self.DAG = calc_dag(sentence)  #使用了上方的函数
        self.length = len(sentence)
        self.unfinish_path = [[]]   #保存待解码序列的队列
        self.finish_path = []  #保存解码完成的序列的队列

    #对于每一个序列,检查是否需要继续解码
    #不需要继续解码的,放入解码完成队列
    #需要继续解码的,将生成的新队列,放入待解码队列
    #path形如:["经常", "有", "意见"]
    def decode_next(self, path):
        path_length = len("".join(path))
        if path_length == self.length:  #已完成解码
            self.finish_path.append(path)
            return
        candidates = self.DAG[path_length]
        new_paths = []
        for candidate in candidates:
            new_paths.append(path + [self.sentence[path_length:candidate+1]])
        self.unfinish_path += new_paths  #放入待解码对列
        return

    #递归调用序列解码过程
    def decode(self):
        while self.unfinish_path != []:
            path = self.unfinish_path.pop(0) #从待解码队列中取出一个序列
            self.decode_next(path)     #使用该序列进行解码


sentence = "经常有意见分歧"
dd = DAGDecode(sentence)
dd.decode()
print(dd.finish_path)

代码分析

一、函数和类的功能分析:

  • calc_dag(sentence)函数:

    • 功能:
      • 该函数的主要目的是根据输入的句子和预定义的词典 Dict 构建一个有向无环图(DAG),用于存储句子中所有可能的词切分信息。
    • 实现步骤:
      1. 首先,初始化一个空字典 DAG 用于存储结果。
      2. 获取输入句子的长度 n
      3. 遍历句子中的每个字符,从当前字符开始,通过不断增加子串长度,检查子串是否在 Dict 中。
      4. 若子串在 Dict 中,将该子串结束字符的索引添加到 tmplist 中。
      5. 若 tmplist 为空,说明当前字符没有可切分的词,将当前字符索引添加到 tmplist
      6. 最后将 k 作为键,tmplist 作为值存储在 DAG 中。
  • DAGDecode类:

    • __init__(self, sentence)方法:
      • 功能:
        • 对输入的句子进行初始化操作,为后续的解码操作准备所需的数据结构。
      • 实现步骤:
        1. 存储输入的句子。
        2. 调用 calc_dag(sentence) 函数生成有向无环图,并存储在 self.DAG 中。
        3. 存储句子的长度。
        4. 初始化两个队列:self.unfinish_path 存储待解码的序列,初始化为只包含一个空列表的列表;self.finish_path 存储已完成解码的序列,初始化为空列表。
    • decode_next(self, path)方法:
      • 功能:
        • 对于给定的部分解码路径,判断是否完成解码,若未完成则根据 self.DAG 生成新的待解码路径并添加到 self.unfinish_path 中,若完成则添加到 self.finish_path 中。
      • 实现步骤:
        1. 计算当前 path 所代表的字符串的长度。
        2. 若长度等于句子长度,说明解码完成,将 path 加入 self.finish_path
        3. 若未完成,根据 self.DAG 中存储的信息,找出可能的下一个词的结束位置,生成新的解码路径并添加到 self.unfinish_path 中。
    • decode(self)方法:
      • 功能:
        • 循环从 self.unfinish_path 中取出路径,调用 decode_next 方法进行解码,直到 self.unfinish_path 为空。
      • 实现步骤:
        1. 只要 self.unfinish_path 不为空,就取出其中的一个元素。
        2. 调用 decode_next 方法对该元素进行解码。

二、代码逻辑总结:

  • 首先,使用 calc_dag(sentence) 函数对输入的句子构建一个有向无环图,该图以字典的形式存储了从每个字符开始的所有可能的词切分信息。例如对于输入 "经常有意见分歧",会得到 {0: [0, 1], 1: [1], 2: [2, 4], 3: [3, 4], 4: [4, 6], 5: [5, 6], 6: [6]}
  • 然后,DAGDecode 类利用这个有向无环图进行解码操作:
    • 在 __init__ 阶段,存储句子、有向无环图、句子长度,并初始化待解码和已完成解码的队列。
    • decode_next 方法会根据当前的部分解码结果判断是否继续解码,若继续解码,会根据 DAG 生成新的可能路径添加到待解码队列中,若完成则添加到已完成队列中。
    • decode 方法通过不断调用 decode_next 方法处理待解码队列中的元素,最终将所有可能的句子切分方式存储在 finish_path 中。

三、代码解释示例:

  • 以输入句子 "经常有意见分歧" 为例:
    • 在 calc_dag 函数中:
      • 从 k = 0 开始,"经" 在 Dict 中,"经常" 也在 Dict 中,所以 DAG[0] = [0, 1]
      • 对于 k = 1,只有 "常" 在 Dict 中,所以 DAG[1] = [1]
      • 对于 k = 2"有" 在 Dict 中,"有意见" 也在 Dict 中,所以 DAG[2] = [2, 4]
      • 以此类推,最终得到完整的 DAG
    • 在 DAGDecode 类中:
      • 初始化时,unfinish_path = [[]]finish_path = []
      • 第一次调用 decode_next 对 [] 进行处理,会根据 DAG[0] 生成 ["经"] 和 ["经常"] 等新路径添加到 unfinish_path
      • 不断循环调用 decode_next,直到 unfinish_path 为空,最终得到所有可能的句子切分方式存储在 finish_path 中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2265588.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iOS 苹果开发者账号: 查看和添加设备UUID 及设备数量

参考链接&#xff1a;苹果开发者账号下添加新设备UUID - 简书 如果要添加新设备到 Profiles 证书里&#xff1a; 1.登录开发者中心 Sign In - Apple 2.找到证书设置&#xff1a; Certificate&#xff0c;Identifiers&Profiles > Profiles > 选择对应证书 edit &g…

汽车IVI中控开发入门及进阶(47):CarPlay开发

概述: 车载信息娱乐(IVI)系统已经从仅仅播放音乐的设备发展成为现代车辆的核心部件。除了播放音乐,IVI系统还为驾驶员提供导航、通信、空调、电源配置、油耗性能、剩余行驶里程、节能建议和许多其他功能。 ​ 驾驶座逐渐变成了你家和工作场所之外的额外生活空间。2014年,…

Oracle、ACCSEE与TDMS的区别

Oracle、ACCSEE和TDMS都是不同类型的数据管理和存储工具&#xff0c;它们各自有独特的用途、结构和复杂性。Oracle是一个功能强大的关系型数据库管理系统&#xff0c;适用于大规模企业级应用&#xff0c;支持复杂查询和事务管理。ACCSEE主要应用于实时数据采集和过程监控&#…

商场消防电气控制系统设计(论文+源码)

1系统的功能及方案设计 如图2.1所示为本次设计的整体框图&#xff0c;其中单片机部分采用ST89C52来负责协调各个模块&#xff1b;液晶选择LCD1602液晶屏来显示信息;温度传感器选择PT1000进行温度的检测&#xff1b;烟雾传检测选择MQ2烟雾传感器&#xff1b;CO2检测选择CCS811模…

7. petalinux 根文件系统配置(package group)

根文件系统配置&#xff08;Petalinux package group&#xff09; 当使能某个软件包组的时候&#xff0c;依赖的包也会相应被使能&#xff0c;解决依赖问题&#xff0c;在配置页面的help选项可以查看需要安装的包 每个软件包组的功能: packagegroup-petalinux-audio包含与音…

2024年12月一区SCI-加权平均优化算法Weighted average algorithm-附Matlab免费代码

引言 本期介绍了一种基于加权平均位置概念的元启发式优化算法&#xff0c;称为加权平均优化算法Weighted average algorithm&#xff0c;WAA。该成果于2024年12月最新发表在中JCR1区、 中科院1区 SCI期刊 Knowledge-Based Systems。 在WAA算法中&#xff0c;加权平均位置代表当…

操作系统(23)外存的存储空间的管理

一、外存的基本概念与特点 定义&#xff1a;外存&#xff0c;也称为辅助存储器&#xff0c;是计算机系统中用于长期存储数据的设备&#xff0c;如硬盘、光盘、U盘等。与内存相比&#xff0c;外存的存储容量大、成本低&#xff0c;但访问速度相对较慢。特点&#xff1a;外存能够…

【202】仓库管理系统

-- 基于springboot仓库管理系统设计与实现 开发技术栈: 开发语言 : Java 开发软件 : Eclipse/MyEclipse/IDEA JDK版本 : JDK8 后端技术 : SpringBoot 前端技术 : Vue、Element、HTML、JS、CsS、JQuery 服务器 : Tomcat8/9 管理包 : Maven 数据库 : MySQL5.x/8 数据库工具 : …

iDP3复现代码数据预处理全流程(二)——vis_dataset.py

vis_dataset.py 主要作用在于点云数据的可视化&#xff0c;并可以做一些简单的预处理 关键参数基本都在 vis_dataset.sh 中定义了&#xff0c;需要改动的仅以下两点&#xff1a; 1. 点云图像保存位置&#xff0c;因为 dataset_path 被设置为了绝对路径&#xff0c;因此需要相…

重温设计模式--1、组合模式

文章目录 1 、组合模式&#xff08;Composite Pattern&#xff09;概述2. 组合模式的结构3. C 代码示例4. C示例代码25 .应用场景 1 、组合模式&#xff08;Composite Pattern&#xff09;概述 定义&#xff1a;组合模式是一种结构型设计模式&#xff0c;它允许你将对象组合成…

精通Redis

目录 1.NoSQL 非关系型数据库 2.Redis 3.Redis的java客户端 4.Jedis 4.1Jedis快速入门 4.2Jedis连接池及使用 5.SpringDataRedis和RedisTemplate 6.SpringDataRedis快速入门 7.RedisSerializer 1.NoSQL 非关系型数据库 基础篇-02.初始Redis-认识NoSQL_哔哩哔哩_bilib…

【2024】Merry Christmas!一起用Rust绘制一颗圣诞树吧

博主未授权任何人或组织机构转载博主任何原创文章&#xff0c;感谢各位对原创的支持&#xff01; 博主链接 博客内容主要围绕&#xff1a; 5G/6G协议讲解 高级C语言讲解 Rust语言讲解 文章目录 一起用Rust绘制一颗圣诞树吧一、 Rust Cargo.toml配置文件二…

查询 MySQL 默认的存储引擎(SELECT @@default_storage_engine;)

要查询 MySQL 默认的存储引擎&#xff0c;可以使用以下 SQL 查询语句&#xff1a; SELECT default_storage_engine;解释&#xff1a; SELECT: 表示你要执行一个查询。default_storage_engine: 这是一个 MySQL 系统变量&#xff0c;它存储着当前 MySQL 服务器的默认存储引擎。…

两道数组有关的OJ练习题

系列文章目录 &#x1f388; &#x1f388; 我的CSDN主页:OTWOL的主页&#xff0c;欢迎&#xff01;&#xff01;&#xff01;&#x1f44b;&#x1f3fc;&#x1f44b;&#x1f3fc; &#x1f389;&#x1f389;我的C语言初阶合集&#xff1a;C语言初阶合集&#xff0c;希望能…

clickhouse-题库

1、clickhouse介绍以及架构 clickhouse一个分布式列式存储数据库&#xff0c;主要用于在线分析查询 2、列式存储和行式存储有什么区别&#xff1f; 行式存储&#xff1a; 1&#xff09;、数据是按行存储的 2&#xff09;、没有建立索引的查询消耗很大的IO 3&#xff09;、建…

近实时”(NRT)搜索、倒排索引

近实时&#xff08;Near Real-Time, NRT&#xff09;搜索 近实时&#xff08;NRT&#xff09;搜索是 Elasticsearch 的核心特性之一&#xff0c;指的是数据在被写入到系统后&#xff0c;可以几乎立即被搜索和查询到。虽然它不像传统数据库那样完全实时&#xff0c;但它的延迟通…

springboot477基于vue技术的农业设备租赁系统(论文+源码)_kaic

摘 要 使用旧方法对农业设备租赁系统的信息进行系统化管理已经不再让人们信赖了&#xff0c;把现在的网络信息技术运用在农业设备租赁系统的管理上面可以解决许多信息管理上面的难题&#xff0c;比如处理数据时间很长&#xff0c;数据存在错误不能及时纠正等问题。这次开发的农…

vue2 升级为 vite 打包

VUE2 中使用 Webpack 打包、开发&#xff0c;每次打包时间太久&#xff0c;尤其是在开发的过程中&#xff0c;本文记录一下 VUE2 升级Vite 步骤。 安装 Vue2 Vite 依赖 dev 依赖 vitejs/plugin-vue2": "^2.3.3 vitejs/plugin-vue2-jsx": "^1.1.1 vite&…

【HarmonyOS 5.0】第十二篇-ArkUI公共属性(一)

一、公共样式类属性 ArkUI框架提供的基础组件直接或者间接的继承自 CommonMethod &#xff0c; CommonMethod 中定义的属性样式属于公共样式。下面就来学习这些样式 1.1.尺寸设置 宽高设置 设置组件的宽高&#xff0c;缺省时使用组件自身内容的宽高&#xff0c;比如充满父布…

数据库系统原理:数据库安全性与权限控制

2.1vue技术 Vue (读音 /vjuː/&#xff0c;类似于 view) 是一套用于构建用户界面的渐进式JavaScript框架。 [5] 与其它大型框架不同的是&#xff0c;Vue 被设计为可以自底向上逐层应用。Vue 的核心库只关注视图层&#xff0c;不仅易于上手&#xff0c;还便于与第三方库或既有项…