【NLP 37、实践 ⑨ NER 命名实体识别任务 LSTM + CRF 实现】

news2025/3/20 3:55:42

难过的事情我要反复咀嚼,嚼到它再也不能困扰我半分

                                                                        —— 25.3.13

数据文件:

通过网盘分享的文件:Ner命名实体识别任务
链接: https://pan.baidu.com/s/1fUiin2um4PCS5i91V9dJFA?pwd=yc6u 提取码: yc6u 
--来自百度网盘超级会员v3的分享

一、配置文件 config.py

1.模型与数据路径

model_path:模型训练完成后保存的位置。例如:保存最终的模型权重文件。

schema_path:数据结构定义文件,通常用于描述数据的格式(如字段名、标签类型)。
在NER任务中,可能定义实体类别(如 {"PERSON": "人名", "ORG": "组织"})。

train_data_path:训练数据集路径,通常为标注好的文本文件(如 train.txt 或 JSON 格式)。

valid_data_path: 验证数据集路径,用于模型训练时的性能评估和超参数调优。

vocab_path:​字符词汇表文件,记录模型中使用的字符集(如中文字符、字母、数字等)。


2.模型架构

max_length:输入文本的最大序列长度。超过此长度的文本会被截断或填充(如用 [PAD])。

hidden_size:模型隐藏层神经元的数量,影响模型容量和计算复杂度。

num_layers:模型的堆叠层数(如LSTM、Transformer的编码器/解码器层数)。

class_num:任务类别总数。例如:NER任务中可能有9种实体类型。


3.训练配置

epoch:训练轮数。每轮遍历整个训练数据集一次。

batch_size:每次梯度更新所使用的样本数量。较小的批次可能更适合内存受限的环境。

optimizer:优化器类型,用于调整模型参数。Adam是常用优化器,结合动量梯度下降。

learning_rate:学习率,控制参数更新的步长。值过小可能导致训练缓慢,过大易过拟合。

use_crf:是否启用条件随机场(CRF)​层。在序列标注任务(如NER)中,CRF可捕捉标签间的依赖关系,提升准确性。


4.预训练模型

bert_path:预训练BERT模型的路径。BERT是一种强大的预训练语言模型,此处可能用于微调或特征提取。

# -*- coding: utf-8 -*-

"""
配置参数信息
"""

Config = {
    "model_path": "model_output",
    "schema_path": "ner_data/schema.json",
    "train_data_path": "ner_data/train",
    "valid_data_path": "ner_data/test",
    "vocab_path":"chars.txt",
    "max_length": 100,
    "hidden_size": 256,
    "num_layers": 2,
    "epoch": 20,
    "batch_size": 16,
    "optimizer": "adam",
    "learning_rate": 1e-3,
    "use_crf": True,
    "class_num": 9,
    "bert_path": r"F:\人工智能NLP\\NLP资料\week6 语言模型\bert-base-chinese"
}


二、数据加载 loader.py

1.代码运行流程

输入文本 → 分词/分字 → 序列编码 → 标签对齐 → 数据填充 → 批量加载:
          │
          ├── 文本解析 → 实体标签映射 → 序列截断 → 生成张量
          │
          └── DataLoader封装 → 批次迭代

2.初始化数据加载类

data_path:数据文件存储路径

config:包含训练 / 数据配置的字典

self.config:保存包含训练 / 数据配置的字典

self.path:保存数据文件存储路径

self.vocab:加载字表 / 词表文件存储路径

self.sentences:初始化句子列表

self.schema:加载实体标签与索引的映射关系表

self.load:调用 load() 方法从 data_path 加载原始数据,进行分词、编码、填充/截断等预处理。

    def __init__(self, data_path, config):
        self.config = config
        self.path = data_path
        self.vocab = load_vocab(config["vocab_path"])
        self.config["vocab_size"] = len(self.vocab)
        self.sentences = []
        self.schema = self.load_schema(config["schema_path"])
        self.load()

3.加载数据并预处理

① 初始化数据容器:初始化一个空列表 self.data,用于存储处理后的数据样本

文件读取与分段:按段落分割原始数据。

逐行解析:提取字符和标签。

编码转换:将字符转换为词汇表索引序列。

序列标准化:调整序列长度至模型要求。​

⑥ 数据存储:保存为张量列表,供训练使用。

self.path:数据文件的存储路径(如 train.txt),由类初始化时传入的 data_path 参数赋值。

f:文件对象,用于读取 self.path 指向的原始数据文件。

segments:是按双换行符分隔的段落列表,每个段落对应一个样本(如一个句子及其标注序列)。

segment:遍历 segments 时的单个样本段落,进一步按行分割处理为字符和标签

labels:存储当前样本的标签序列,[8]可能表示 [CLS] 标记的 ID,用于序列起始符,之后将每个字符的标签转换为ID。

char:当前行的字符(如 "中"),属于句子中的一个基本单元。

lable:当前行的原始标签字符串(如 "B-LOC"),​尚未映射为 ID

input_ids:将字符序列编码为模型输入所需的 ID 序列(如 BERT 分词后的 Token ID)

self.data:列表,存储预处理后的数据样本,每个样本由输入张量和标签张量组成

sentence:由字符列表拼接而成的完整句子(如 "中国科技大学"),存入 self.sentences 供后续可视化或调试。

open():打开文件并返回文件对象,支持读/写/追加等模式。

参数名类型说明
file字符串文件路径(绝对/相对路径)
mode字符串打开模式(如 r-只读、w-写入、a-追加)
encoding字符串文件编码(如 utf-8,文本模式需指定)
errors字符串编码错误处理方式(如 ignorereplace

文件对象.read():读取文件内容,返回字符串或字节流

参数名类型说明
size整数可选,指定读取的字节数(默认读取全部内容)

split():按分隔符分割字符串,返回子字符串列表

参数名类型说明
delimiter字符串分隔符(默认空格)
maxsplit整数可选,最大分割次数(默认-1表示全部)

strip():去除字符串首尾指定字符(默认空白字符)

参数名类型说明
chars字符串可选,指定需去除的字符集合

join():用分隔符连接可迭代对象的元素,返回新字符串

参数名类型说明
iterable可迭代对象需连接的元素集合(如列表、元组)
sep字符串分隔符(默认空字符串)

列表.append():在列表末尾添加元素

参数名类型说明
obj任意类型要添加的元素
    def load(self):
        self.data = []
        with open(self.path, encoding="utf8") as f:
            segments = f.read().split("\n\n")
            for segment in segments:
                sentenece = []
                labels = []
                for line in segment.split("\n"):
                    if line.strip() == "":
                        continue
                    char, label = line.split()
                    sentenece.append(char)
                    labels.append(self.schema[label])
                self.sentences.append("".join(sentenece))
                input_ids = self.encode_sentence(sentenece)
                labels = self.padding(labels, -1)
                self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
        return

4.加载字 / 词表 

        load_vocab 函数用于从指定路径加载词汇表文件,并将每个词汇项映射到一个从 1 开始的唯一整数索引​(索引 0 保留给 Padding 占位符)

token_dict:字典,存储词汇到索引的映射

vocab_path:字 / 词表的存储路径

open():打开文件并返回文件对象,用于读写文件内容

参数名类型默认值说明
file_namestr文件路径(需包含扩展名)
modestr'r'文件打开模式:
'r': 只读
'w': 只写(覆盖原文件)
'a': 追加写入
'b': 二进制模式
'x': 创建新文件(若存在则报错)
bufferingintNone缓冲区大小(仅二进制模式有效)
encodingstrNone文件编码(仅文本模式有效,如 'utf-8'
newlinestr'\n'行结束符(仅文本模式有效)
closefdboolTrue是否在文件关闭时自动关闭文件描述符
dir_fdint-1文件描述符(高级用法,通常忽略)
flagsint0Linux 系统下的额外标志位
modestr(重复参数,实际使用中只需指定 mode

enumerate():遍历可迭代对象时,同时返回元素的索引

参数名类型默认值说明
iterable可迭代对象需要遍历的对象(如列表、元组、字符串等)
startint0索引的起始值(可自定义,如从 1 开始)

strip():移除字符串开头和结尾的空白字符或指定字符

参数名类型默认值说明
charsstrNone需要移除的字符集合(默认为空格、换行、制表符 \t、换页符 \f、回车 \r
#加载字表或词表
def load_vocab(vocab_path):
    token_dict = {}
    with open(vocab_path, encoding="utf8") as f:
        for index, line in enumerate(f):
            token = line.strip()
            token_dict[token] = index + 1  #0留给padding位置,所以从1开始
    return token_dict

5.加载映射关系表

        加载位于指定路径的 JSON 格式的模式文件,并将其内容解析为 Python 对象以便在数据生成过程中使用。

path:映射关系表schema的存储路径

open():打开文件并返回文件对象,用于读写文件内容。

参数名类型默认值说明
file_namestr文件路径(需包含扩展名)
modestr'r'文件打开模式:
'r': 只读
'w': 只写(覆盖原文件)
'a': 追加写入
'b': 二进制模式
'x': 创建新文件(若存在则报错)
bufferingintNone缓冲区大小(仅二进制模式有效)
encodingstrNone文件编码(仅文本模式有效,如 'utf-8'
newlinestr'\n'行结束符(仅文本模式有效)
closefdboolTrue是否在文件关闭时自动关闭文件描述符
dir_fdint-1文件描述符(高级用法,通常忽略)
flagsint0Linux 系统下的额外标志位
modestr(重复参数,实际使用中只需指定 mode

json.load():从已打开的 JSON 文件对象中加载数据,并将其转换为 Python 对象(如字典、列表)。

参数名类型默认值说明
fpio.TextIO已打开的文件对象(需处于读取模式)
indentint/strNone缩进空格数(美化输出,如 4 或 " "
sort_keysboolFalse是否对 JSON 键进行排序
load_hookcallableNone自定义对象加载回调函数
object_hookcallableNone自定义对象解析回调函数
    def load_schema(self, path):
        with open(path, encoding="utf8") as f:
            return json.load(f)

6.封装数据

data_path:数据文件的路径(如 train.txt),用于初始化 DataGenerator,指向原始数据文件。

config:配置参数字典,通常包含 batch_sizebert_pathschema_path 等参数,用于控制数据加载逻辑。

dg:自定义数据集对象,继承 torch.utils.data.Dataset,负责数据加载、预处理和样本生成。

dl:封装 DataGenerator 的迭代器,实现批量加载、多进程加速等功能,直接用于模型训练。

DataLoader():PyTorch 模型训练的标配工具,通过合理的参数配置(如 batch_sizenum_workersshuffle),可以显著提升数据加载效率,尤其适用于大规模数据集和复杂预处理任务。其与 Dataset 类的配合使用,是构建高效训练管道的核心。

参数名类型默认值说明
datasetDatasetNone必须参数,自定义数据集对象(需继承 torch.utils.data.Dataset)。
batch_sizeint1每个批次的样本数量。
shuffleboolFalse是否在每个 epoch 开始时打乱数据顺序(训练时推荐设为 True)。
num_workersint0使用多线程加载数据的工人数量(需大于 0 时生效)。
pin_memoryboolFalse是否将数据存储在 pinned memory 中(加速 GPU 数据传输)。
drop_lastboolFalse如果数据集长度无法被 batch_size 整除,是否丢弃最后一个不完整的批次。
persistent_workersboolFalse是否保持工作线程在 epoch 之间持续运行(减少多线程初始化开销)。
worker_init_fncallableNone自定义工作线程初始化函数。
#用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
    dg = DataGenerator(data_path, config)
    dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
    return dl

7.对于输入文本做截断 / 填充

Ⅰ、截断过长序列​(超过预设最大长度)

Ⅱ、填充过短序列​(用 pad_token 补齐到预设最大长度)

    #补齐或截断输入的序列,使其可以在一个batch内运算
    def padding(self, input_id, pad_token=0):
        input_id = input_id[:self.config["max_length"]]
        input_id += [pad_token] * (self.config["max_length"] - len(input_id))
        return input_id

8.类内魔术方法

__len__():用于定义对象的“长度”,通过内置函数 len() 调用时返回该值。它通常用于容器类(如列表、字典、自定义数据结构),表示容器中元素的个数

__getitem__():允许对象通过索引或键值访问元素,支持 obj[index] 或 obj[key] 语法。它使对象表现得像序列(如列表)或映射(如字典)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

9.对于输入的文本编码

输入:原始文本 text(字符串),padding 标志(布尔值,决定是否填充)​

初始化:创建空列表 input_id 存储编码后的索引序列

分词/字符处理分支:

  • 词级别处理​("words.txt"):
    使用结巴分词(jieba.cut)将文本切分为词语,遍历每个词语,查询词汇表 self.vocab
    • 若词语存在 → 添加其索引
    • 若不存在 → 使用 [UNK](未知词)的索引
  • 字符级别处理​(其他情况):
    直接遍历文本的每个字符,查询词汇表 self.vocab
    • 若字符存在 → 添加其索引
    • 若不存在 → 使用 [UNK] 的索引

条件执行:若 padding=True,调用 self.padding 方法对 input_id 进行填充

返回:整数列表 input_id,表示文本的编码序列

input_id:初始化列表,存储词 / 字符的索引

jieba.cut():将中文句子分割成词语,支持三种分词模式(精确模式、全模式、搜索引擎模式)

参数名类型说明
sentence字符串需要分词的中文句子
cut_all布尔值是否采用全模式(True为全模式,False为精确模式,默认False)
HMM布尔值是否使用隐马尔可夫模型(True为使用,默认True)

列表.append():在列表末尾添加一个元素,修改原列表

参数名类型说明
obj任意类型要添加的元素(支持字符串、数字、列表等)

字典.get():安全获取字典中指定键的值,键不存在时返回默认值(默认为None

参数名类型说明
key不可变类型要查询的键
default任意类型可选,键不存在时返回的默认值(若未指定则返回None)
    def encode_sentence(self, text, padding=True):
        input_id = []
        if self.config["vocab_path"] == "words.txt":
            for word in jieba.cut(text):
                input_id.append(self.vocab.get(word, self.vocab["[UNK]"]))
        else:
            for char in text:
                input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))
        if padding:
            input_id = self.padding(input_id)
        return input_id

完整代码 

# -*- coding: utf-8 -*-

import json
import re
import os
import torch
import random
import jieba
import numpy as np
from torch.utils.data import Dataset, DataLoader

"""
数据加载
"""


class DataGenerator:
    def __init__(self, data_path, config):
        self.config = config
        self.path = data_path
        self.vocab = load_vocab(config["vocab_path"])
        self.config["vocab_size"] = len(self.vocab)
        self.sentences = []
        self.schema = self.load_schema(config["schema_path"])
        self.load()

    def load(self):
        self.data = []
        with open(self.path, encoding="utf8") as f:
            segments = f.read().split("\n\n")
            for segment in segments:
                sentenece = []
                labels = []
                for line in segment.split("\n"):
                    if line.strip() == "":
                        continue
                    char, label = line.split()
                    sentenece.append(char)
                    labels.append(self.schema[label])
                self.sentences.append("".join(sentenece))
                input_ids = self.encode_sentence(sentenece)
                labels = self.padding(labels, -1)
                self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
        return

    def encode_sentence(self, text, padding=True):
        input_id = []
        if self.config["vocab_path"] == "words.txt":
            for word in jieba.cut(text):
                input_id.append(self.vocab.get(word, self.vocab["[UNK]"]))
        else:
            for char in text:
                input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))
        if padding:
            input_id = self.padding(input_id)
        return input_id

    #补齐或截断输入的序列,使其可以在一个batch内运算
    def padding(self, input_id, pad_token=0):
        input_id = input_id[:self.config["max_length"]]
        input_id += [pad_token] * (self.config["max_length"] - len(input_id))
        return input_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def load_schema(self, path):
        with open(path, encoding="utf8") as f:
            return json.load(f)

#加载字表或词表
def load_vocab(vocab_path):
    token_dict = {}
    with open(vocab_path, encoding="utf8") as f:
        for index, line in enumerate(f):
            token = line.strip()
            token_dict[token] = index + 1  #0留给padding位置,所以从1开始
    return token_dict

#用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
    dg = DataGenerator(data_path, config)
    dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
    return dl



if __name__ == "__main__":
    from config import Config
    dg = DataGenerator("../ner_data/train.txt", Config)


三、模型建立 model.py

1.代码运行流程

输入序列 → 嵌入层 → 双向LSTM → 分类层 → 分支判断:
          │
          ├── 有标签 → CRF? → 是:计算CRF损失(序列联合概率优化)
          │                 │
          │                 └→ 否:计算交叉熵(逐位置分类损失)
          │
          └── 无标签 → CRF? → 是:维特比解码最优路径(考虑标签转移约束)
                            │
                            └→ 否:输出原始logits(分类层未归一化得分)

2.模型初始化

代码运行流程

输入配置 → 模型组件初始化 → 网络结构构建 → 损失函数选择:
          │
          ├── 嵌入层 → 双向LSTM → 分类层 → CRF条件判断
          │
          └── 损失函数配置 → 交叉熵/CRF损失切换

hidden_size:定义LSTM隐藏层的维度(即每个时间步输出的特征数量)

vocab_size:词表大小,即嵌入层(Embedding)可处理的词汇总数

max_length:输入序列的最大长度,用于数据预处理(如截断或填充)

class_num:分类任务的类别数量,决定线性层(nn.Linear)的输出维度

num_layers:堆叠的LSTM层数,用于增加模型复杂度

nn.Embedding():将离散的索引映射为稠密向量(如词嵌入)

参数名类型默认值说明
num_embeddings整数词表大小(如 vocab_size + 1
embedding_dim整数嵌入向量维度(如 hidden_size
padding_idx整数None指定填充符索引(如 0),该位置的梯度不更新

nn.LSTM():长短期记忆网络(LSTM),用于序列建模。

参数名类型默认值说明
input_size整数输入特征维度(如嵌入层输出维度 hidden_size
hidden_size整数隐藏状态维度(决定模型容量)
num_layers整数1LSTM 堆叠层数(多层时上一层的输出作为下一层的输入)
batch_first布尔值False输入张量是否为 (batch_size, seq_len, input_size) 格式
bidirectional布尔值False是否启用双向 LSTM(输出维度变为 hidden_size * 2

nn.Linear():实现全连接层的线性变换(y = xW^T + b

参数名类型默认值说明
in_features整数输入特征维度(如词向量维度 hidden_size
out_features整数输出特征维度(如分类类别数 class_num
bias布尔值True是否启用偏置项

CRF():条件随机场层,用于序列标注任务中约束标签转移逻辑。

参数名类型默认值说明
num_tags整数标签类别数(如 class_num
batch_first布尔值False输入张量是否为 (batch_size, seq_len) 格式

torch.nn.CrossEntropyLoss():计算交叉熵损失,常用于分类任务。

参数名类型默认值说明
ignore_index整数-1忽略指定索引的标签(如填充符 -1
reduction字符串mean损失聚合方式(可选 nonesummean
    def __init__(self, config):
        super(TorchModel, self).__init__()
        hidden_size = config["hidden_size"]
        vocab_size = config["vocab_size"] + 1
        max_length = config["max_length"]
        class_num = config["class_num"]
        num_layers = config["num_layers"]
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
        self.classify = nn.Linear(hidden_size * 2, class_num)
        self.crf_layer = CRF(class_num, batch_first=True)
        self.use_crf = config["use_crf"]
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  #loss采用交叉熵损失

3.前向计算

代码运行流程

输入 x → 嵌入层 → 序列层 → 分类层 → 预测值 → 分支判断:
          │
          ├── 存在 target → CRF? → 是:计算 CRF 损失(带掩码)
          │                 │
          │                 └→ 否:计算交叉熵损失(展平处理)
          │
          └── 无 target → CRF? → 是:解码最优标签序列
                            │
                            └→ 否:直接返回预测 logits

x:输入序列的 Token ID 矩阵,代表一个批次的文本数据(如 [[101, 234, ...], [103, 456, ...]])。

target:真实标签序列(如实体标注),若不为 None 表示训练阶段,需计算损失;否则为预测阶段。

predict:分类层输出的每个位置标签的未归一化分数(logits),用于后续的 CRF 或交叉熵损失计算。

mask:标记序列中有效 Token 的位置(非填充部分),target.gt(-1) 表示标签值大于 -1 的位置有效。

gt():张量的逐元素比较函数,返回布尔型张量,标记输入张量中大于指定值的元素位置。常用于生成掩码(如忽略填充符)

参数名类型默认值说明
otherTensor/标量比较的阈值或张量。若为标量,则张量中每个元素与该值比较;若为张量,需与输入张量形状相同。
outTensorNone可选输出张量,用于存储结果。

shape():返回张量的维度信息,描述各轴的大小。

view():调整张量的形状,支持自动推断维度(通过-1占位符)。常用于数据展平或维度转换。

参数名类型默认值说明
*shape可变参数目标形状的维度序列,如view(2, 3)view(-1, 28)-1表示自动计算。
    #当输入真实标签,返回loss值;无真实标签,返回预测值
    def forward(self, x, target=None):
        x = self.embedding(x)  #input shape:(batch_size, sen_len)
        x, _ = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)
        predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)

        if target is not None:
            if self.use_crf:
                mask = target.gt(-1) 
                return - self.crf_layer(predict, target, mask, reduction="mean")
            else:
                #(number, class_num), (number)
                return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
        else:
            if self.use_crf:
                return self.crf_layer.decode(predict)
            else:
                return predict

4.选择优化器

 代码运行流程

输入 config → 提取参数 → 分支判断:
          │
          ├── optimizer == "adam" → 返回 Adam 优化器实例
          │
          └── optimizer == "sgd" → 返回 SGD 优化器实例

Adam():自适应矩估计优化器(Adaptive Moment Estimation),结合动量和 RMSProp 的优点。

参数名类型默认值说明
lrfloat1e-3学习率。
betastuple(0.9, 0.999)动量系数(β₁, β₂)。
epsfloat1e-8防止除零误差。
weight_decayfloat0权重衰减率。
amsgradboolFalse是否启用 AMSGrad 优化。
foreachboolFalse是否为每个参数单独计算梯度。

SGD():随机梯度下降优化器(Stochastic Gradient Descent)

参数名类型默认值说明
lrfloat1e-3学习率。
momentumfloat0动量系数(如 momentum=0.9)。
weight_decayfloat0权重衰减率。
dampeningfloat0动力衰减系数(用于 SGD with Momentum)。
nesterovboolFalse是否启用 Nesterov 动量。
foreachboolFalse是否为每个参数单独计算梯度。

parameters():返回模型所有可训练参数的迭代器,常用于参数初始化或梯度清零。

参数名类型默认值说明
filtercallableNone过滤条件函数(如 lambda p: p.requires_grad)。默认返回所有参数。
def choose_optimizer(config, model):
    optimizer = config["optimizer"]
    learning_rate = config["learning_rate"]
    if optimizer == "adam":
        return Adam(model.parameters(), lr=learning_rate)
    elif optimizer == "sgd":
        return SGD(model.parameters(), lr=learning_rate)

5.模型建立 

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torchcrf import CRF
"""
建立网络模型结构
"""

class TorchModel(nn.Module):
    def __init__(self, config):
        super(TorchModel, self).__init__()
        hidden_size = config["hidden_size"]
        vocab_size = config["vocab_size"] + 1
        max_length = config["max_length"]
        class_num = config["class_num"]
        num_layers = config["num_layers"]
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
        self.classify = nn.Linear(hidden_size * 2, class_num)
        self.crf_layer = CRF(class_num, batch_first=True)
        self.use_crf = config["use_crf"]
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  #loss采用交叉熵损失

    #当输入真实标签,返回loss值;无真实标签,返回预测值
    def forward(self, x, target=None):
        x = self.embedding(x)  #input shape:(batch_size, sen_len)
        x, _ = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)
        predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)

        if target is not None:
            if self.use_crf:
                mask = target.gt(-1) 
                return - self.crf_layer(predict, target, mask, reduction="mean")
            else:
                #(number, class_num), (number)
                return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
        else:
            if self.use_crf:
                return self.crf_layer.decode(predict)
            else:
                return predict


def choose_optimizer(config, model):
    optimizer = config["optimizer"]
    learning_rate = config["learning_rate"]
    if optimizer == "adam":
        return Adam(model.parameters(), lr=learning_rate)
    elif optimizer == "sgd":
        return SGD(model.parameters(), lr=learning_rate)


if __name__ == "__main__":
    from config import Config
    model = TorchModel(Config)

四、模型效果评估 evaluate.py

1.代码运行流程

输入验证数据 → 模型预测 → 实体解码 → 指标统计 → 性能评估:
          │
          ├── 数据加载 → 批量预测 → 标签解码 → 实体匹配 → 计算准确率/召回率
          │
          └── 结果汇总 → 宏/微平均F1输出 → 模型效果可视化

2.初始化

Ⅰ、加载配置文件、模型及日志模块 ——>

Ⅱ、读取验证集数据(固定顺序,避免随机性干扰评估)——>

Ⅲ、初始化统计字典 stats_dict,按实体类别记录正确识别数、样本实体数等

config:存储运行时配置,例如数据路径、超参数(如批次大小 batch_size)、是否使用CRF层等。通过 config["valid_data_path"] 动态获取验证集路径。

model:待评估的模型实例,用于调用预测方法(如 model(input_id)),需提前完成训练和加载。

logger:记录运行日志,例如输出评估指标(准确率、F1值)到文件或控制台,便于调试和监控。

valid_data:验证数据集,用于模型训练时的性能评估和超参数调优。

load_data():数据加载类中,用torch自带的DataLoader类封装数据的函数

    def __init__(self, config, model, logger):
        self.config = config
        self.model = model
        self.logger = logger
        self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)

3.统计模型效果

代码运行流程

输入 labels/predictions/sentences → 数据校验 → 预测结果处理 → 样本遍历 → 实体解码 → 统计指标:
          │
          ├── 非CRF模式 → argmax提取预测标签 → 数据迁移至CPU → 标签序列转换
          │
          ├── 实体匹配 → 按类别统计:
          │             │
          │             ├── 正确识别数 → 交集实体计数
          │             │
          │             ├── 样本实体数 → 真实实体总数
          │             │
          │             └── 识别出实体数 → 预测实体总数
          │
          └── 结果累加 → 更新统计字典

labels:真实标签序列(如实体标注的整数 ID 列表),用于与预测结果对比计算评估指标

pred_results:模型预测结果,若使用 CRF,为标签序列,否则为每个位置的 logits(未归一化概率)。

sentences:原始文本句子列表(如 ["中国北京", "今天天气"]),用于解码标签序列到具体实体。

use_crf:控制是否使用 CRF 层

pred_label:单个样本的预测标签序列,若未使用 CRF,需从 logits 中提取(argmax)并转换为列表。

true_label:单个样本的真实标签序列(如 [0, 4, 4, 8]),已从 GPU 张量转换为 CPU 列表。

true_entities:解码后的真实实体字典,如 {"LOCATION": ["北京"], "PERSON": []}

pred_entities:解码后的预测实体字典,用于与真实实体对比统计正确识别数。

key:字符串,实体类别名称(如 "PERSON"),遍历四类实体以分别统计指标。

assert:Python 的 ​调试断言工具,主要用于在开发阶段验证程序内部的逻辑条件是否成立

        assert expression [, message]  

参数类型是否必填作用
expression布尔表达式需要验证的条件。若结果为 False,则触发断言失败;若为 True,程序继续执行。
message字符串(可选)断言失败时输出的自定义错误信息,用于辅助调试。若省略,则输出默认错误提示。

len():返回对象的元素数量(字符串、列表、元组、字典等)

参数名类型说明
object任意可迭代对象如字符串、列表、字典等

torch.argmax():返回张量中最大值所在的索引

参数名类型说明
inputTensor输入张量
dimint沿指定维度查找最大值
keepdimbool是否保持输出维度一致

cpu():将张量从GPU移动到CPU内存

zip():将多个可迭代对象打包成元组列表

参数名类型说明
iterables多个可迭代对象如列表、元组、字符串

.detach():从计算图中分离张量,阻止梯度传播

.tolist():将张量或数组转换为Python列表

    def write_stats(self, labels, pred_results, sentences):
        assert len(labels) == len(pred_results) == len(sentences)
        if not self.config["use_crf"]:
            pred_results = torch.argmax(pred_results, dim=-1)
        for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
            if not self.config["use_crf"]:
                pred_label = pred_label.cpu().detach().tolist()
            true_label = true_label.cpu().detach().tolist()
            true_entities = self.decode(sentence, true_label)
            pred_entities = self.decode(sentence, pred_label)
            # print("=+++++++++")
            # print(true_entities)
            # print(pred_entities)
            # print('=+++++++++')
            # 正确率 = 识别出的正确实体数 / 识别出的实体数
            # 召回率 = 识别出的正确实体数 / 样本的实体数
            for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
                self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
                self.stats_dict[key]["样本实体数"] += len(true_entities[key])
                self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
        return

4.可视化统计模型效果

代码运行流程

统计字典 → 按类别计算指标 → 宏平均计算 → 全局统计 → 微平均计算 → 结果输出:
          │
          ├── 遍历实体类别 → 计算precision/recall/F1 → 记录F1分数
          │
          └── 汇总全局统计量 → 计算micro-precision/recall/F1 → 输出评估报告

精确率 (Precision):正确预测实体数 / 总预测实体数

召回率 (Recall):正确预测实体数 / 总真实实体数​

F1值:精确率与召回率的调和平均 

F1:F1分数:准确率与召回率的调和平均数,综合衡量模型的精确性与覆盖能力。

F1_scores:存储四个实体类别的 F1 分数,用于计算宏观平均。

precision:准确率:模型预测为某类实体的结果中,正确的比例。反映模型预测的精确度。

recall:召回率:真实存在的某类实体中,被模型正确识别的比例。反映模型对实体的覆盖能力。

key:当前处理的实体类别(如 "PERSON""LOCATION")。

correct_pred:总正确识别数:所有类别中被正确识别的实体总数。

total_pred:总识别实体数:模型预测出的所有实体数量(含错误识别)。

true_enti:总样本实体数:验证数据中真实存在的所有实体数量。

micro_precision:微观准确率:全局视角下的准确率,所有实体类别的正确识别数与总识别数的比例。

micro_recall:微观召回率:全局视角下的召回率,所有实体类别的正确识别数与总样本实体数的比例。

micro_f1:微观F1分数:微观准确率与微观召回率的调和平均数。

列表.append():在列表末尾添加元素

参数名类型说明
element任意要添加的元素

logger.info():记录日志信息(需配置日志模块)

参数名类型说明
formatstr格式化字符串
*args可变参数格式化参数

sum():计算可迭代对象的元素总和

参数名类型说明
iterable可迭代对象如列表、元组
start数值(可选)初始累加值

列表推导式:通过简洁语法生成新列表,语法:[表达式 for item in iterable if 条件]

    def show_stats(self):
        F1_scores = []
        for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
            # 正确率 = 识别出的正确实体数 / 识别出的实体数
            # 召回率 = 识别出的正确实体数 / 样本的实体数
            precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
            recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
            F1 = (2 * precision * recall) / (precision + recall + 1e-5)
            F1_scores.append(F1)
            self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
        self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
        correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        micro_precision = correct_pred / (total_pred + 1e-5)
        micro_recall = correct_pred / (true_enti + 1e-5)
        micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
        self.logger.info("Micro-F1 %f" % micro_f1)
        self.logger.info("--------------------")
        return

 5.评估模型效果

 代码运行流程

输入验证数据 → 模型模式切换 → 批量处理 → 预测与统计 → 性能汇总:
          │
          ├── 初始化统计字典 → 设置评估模式 → 数据迁移至GPU → 无梯度预测 → 实体匹配统计
          │
          └── 指标计算 → 宏/微平均输出 → 日志记录结果

epoch:当前训练轮次,用于日志。

logger:记录日志的工具。

stats_dict:统计字典,记录各实体类别的指标。

valid_data:验证数据集,通常由 load_data 加载(如 config["valid_data_path"] 指定路径)

index:循环中的批次索引

batch_data:循环中的数据。

sentences:当前批次的原始句子

pred_results:模型预测结果

write_stats():写入统计信息

show_stats():显示统计结果

logger.info():记录日志信息(需配置日志模块)

参数名类型说明
formatstr格式化字符串
*args可变参数格式化参数

defaultdict():创建带有默认值工厂的字典

参数名类型说明
default_factory可调用对象如int、list、自定义函数

model.eval():将模型设置为评估模式(关闭Dropout等训练层)

enumerate():返回索引和元素组成的枚举对象

参数名类型说明
iterable可迭代对象如列表、字符串
startint(可选)起始索引,默认为0

torch.cuda.is_available():检查当前环境是否支持CUDA(GPU加速)

cuda():将张量或模型移动到GPU

参数名类型说明
deviceint/str指定GPU设备号,如"cuda:0"

torch.no_grad():禁用梯度计算,节省内存并加速推理

    def eval(self, epoch):
        self.logger.info("开始测试第%d轮模型效果:" % epoch)
        self.stats_dict = {"LOCATION": defaultdict(int),
                           "TIME": defaultdict(int),
                           "PERSON": defaultdict(int),
                           "ORGANIZATION": defaultdict(int)}
        self.model.eval()
        for index, batch_data in enumerate(self.valid_data):
            sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
            if torch.cuda.is_available():
                batch_data = [d.cuda() for d in batch_data]
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            with torch.no_grad():
                pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测
            self.write_stats(labels, pred_results, sentences)
        self.show_stats()
        return

6.根据标签将句子解码为实体

标签序列预处理:将数值标签拼接为字符串(如 [0,4,4] → "044"

正则匹配实体

   04+:B-LOCATION(0)后接多个I-LOCATION(4)

   15+:B-ORGANIZATION(1)后接I-ORGANIZATION(5)

           其他实体类别同理

索引对齐:根据匹配位置截取原始句子中的实体文本

Ⅰ、输入预处理

在原句首添加 $ 符号,通常用于对齐标签与字符位置(例如避免索引越界)

        sentence = "$" + sentence

Ⅱ、标签序列转换

将整数标签序列转换为字符串,并截取长度与 sentence 对齐

str.join():将可迭代对象中的字符串元素按指定分隔符连接成一个新字符串

参数名类型说明
iterable可迭代对象元素必须为字符串类型

str():将对象转换为字符串表示形式,支持自定义类的 __str__ 方法

参数名类型说明
object任意要转换的对象

len():返回对象的长度或元素个数(适用于字符串、列表、字典等)

参数名类型说明
object可迭代对象如字符串、列表等

列表推导式:通过简洁语法生成新列表,支持条件过滤和多层循环

        [expression for item in iterable if condition]

部分类型说明
expression表达式对 item 处理后的结果
item变量迭代变量
iterable可迭代对象如列表、range() 生成的序列
condition条件表达式 (可选)过滤不符合条件的元素
        labels = "".join([str(x) for x in labels[:len(sentence)+1]])

Ⅲ、 初始化结果容器

        创建默认值为列表的字典,存储四类实体(LOCATION、ORGANIZATION、PERSON、TIME)的识别结果

defaultdict():创建默认值字典,当键不存在时自动生成默认值(基于工厂函数)

参数名类型说明
default_factory可调用对象如 intlist 或自定义函数
        results = defaultdict(list)

Ⅳ、 正则表达式匹配

    (04+): 匹配以 0(B-LOCATION)开头,后接多个 4(I-LOCATION)的连续标签

    (15+)(26+)(37+)分别对应 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的标签模式。

re.finditer():在字符串中全局搜索正则表达式匹配项,返回一个迭代器,每个元素为 Match 对象

参数名类型说明
patternstr 或正则表达式对象要匹配的正则表达式模式
stringstr要搜索的字符串
flagsint (可选)正则匹配标志(如 re.IGNORECASE

.span():返回正则匹配的起始和结束索引(左闭右开区间)

列表.append():向列表末尾添加单个元素,直接修改原列表

参数名类型说明
element任意要添加的元素
        for location in re.finditer("(04+)", labels):
            s, e = location.span()
            results["LOCATION"].append(sentence[s:e])

Ⅴ、完整代码 

    '''
    {
      "B-LOCATION": 0,
      "B-ORGANIZATION": 1,
      "B-PERSON": 2,
      "B-TIME": 3,
      "I-LOCATION": 4,
      "I-ORGANIZATION": 5,
      "I-PERSON": 6,
      "I-TIME": 7,
      "O": 8
    }
    '''
    def decode(self, sentence, labels):
        labels = "".join([str(x) for x in labels[:len(sentence)]])
        results = defaultdict(list)
        for location in re.finditer("(04+)", labels):
            s, e = location.span()
            results["LOCATION"].append(sentence[s:e])
        for location in re.finditer("(15+)", labels):
            s, e = location.span()
            results["ORGANIZATION"].append(sentence[s:e])
        for location in re.finditer("(26+)", labels):
            s, e = location.span()
            results["PERSON"].append(sentence[s:e])
        for location in re.finditer("(37+)", labels):
            s, e = location.span()
            results["TIME"].append(sentence[s:e])
        return results

7.完整代码 

# -*- coding: utf-8 -*-
import torch
import re
import numpy as np
from collections import defaultdict
from loader import load_data

"""
模型效果测试
"""

class Evaluator:
    def __init__(self, config, model, logger):
        self.config = config
        self.model = model
        self.logger = logger
        self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)


    def eval(self, epoch):
        self.logger.info("开始测试第%d轮模型效果:" % epoch)
        self.stats_dict = {"LOCATION": defaultdict(int),
                           "TIME": defaultdict(int),
                           "PERSON": defaultdict(int),
                           "ORGANIZATION": defaultdict(int)}
        self.model.eval()
        for index, batch_data in enumerate(self.valid_data):
            sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
            if torch.cuda.is_available():
                batch_data = [d.cuda() for d in batch_data]
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            with torch.no_grad():
                pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测
            self.write_stats(labels, pred_results, sentences)
        self.show_stats()
        return

    def write_stats(self, labels, pred_results, sentences):
        assert len(labels) == len(pred_results) == len(sentences)
        if not self.config["use_crf"]:
            pred_results = torch.argmax(pred_results, dim=-1)
        for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
            if not self.config["use_crf"]:
                pred_label = pred_label.cpu().detach().tolist()
            true_label = true_label.cpu().detach().tolist()
            true_entities = self.decode(sentence, true_label)
            pred_entities = self.decode(sentence, pred_label)
            # print("=+++++++++")
            # print(true_entities)
            # print(pred_entities)
            # print('=+++++++++')
            # 正确率 = 识别出的正确实体数 / 识别出的实体数
            # 召回率 = 识别出的正确实体数 / 样本的实体数
            for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
                self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
                self.stats_dict[key]["样本实体数"] += len(true_entities[key])
                self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
        return

    def show_stats(self):
        F1_scores = []
        for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
            # 正确率 = 识别出的正确实体数 / 识别出的实体数
            # 召回率 = 识别出的正确实体数 / 样本的实体数
            precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
            recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
            F1 = (2 * precision * recall) / (precision + recall + 1e-5)
            F1_scores.append(F1)
            self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
        self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
        correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        micro_precision = correct_pred / (total_pred + 1e-5)
        micro_recall = correct_pred / (true_enti + 1e-5)
        micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
        self.logger.info("Micro-F1 %f" % micro_f1)
        self.logger.info("--------------------")
        return

    '''
    {
      "B-LOCATION": 0,
      "B-ORGANIZATION": 1,
      "B-PERSON": 2,
      "B-TIME": 3,
      "I-LOCATION": 4,
      "I-ORGANIZATION": 5,
      "I-PERSON": 6,
      "I-TIME": 7,
      "O": 8
    }
    '''
    def decode(self, sentence, labels):
        labels = "".join([str(x) for x in labels[:len(sentence)]])
        results = defaultdict(list)
        for location in re.finditer("(04+)", labels):
            s, e = location.span()
            results["LOCATION"].append(sentence[s:e])
        for location in re.finditer("(15+)", labels):
            s, e = location.span()
            results["ORGANIZATION"].append(sentence[s:e])
        for location in re.finditer("(26+)", labels):
            s, e = location.span()
            results["PERSON"].append(sentence[s:e])
        for location in re.finditer("(37+)", labels):
            s, e = location.span()
            results["TIME"].append(sentence[s:e])
        return results



五、主函数文件 main.py

1.代码运行流程

配置参数 → 创建模型目录 → 加载训练数据 → 初始化模型 → 设备检测:
          │
          ├── GPU可用 → 迁移模型至GPU
          │
          └── GPU不可用 → 保持CPU模式

→ 选择优化器 → 初始化评估器 → 进入训练循环:
          │
          ├── 当前epoch → 训练模式 → 遍历数据批次:
          │                 │
          │                 ├── 清空梯度 → 数据迁移至GPU → 前向计算 → 分支判断:
          │                 │             │
          │                 │             ├── 启用CRF → 计算CRF损失 → 反向传播 → 参数更新
          │                 │             │
          │                 │             └── 禁用CRF → 计算交叉熵损失 → 反向传播 → 参数更新
          │                 │
          │                 └── 记录批次损失 → 周期中点打印日志
          │
          └── 计算epoch平均损失 → 验证集评估 → 保存当前模型权重

2.导入文件

# -*- coding: utf-8 -*-

import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data

3.日志配置

logging.basicConfig():配置日志系统的基础参数(一次性设置,应在首次日志调用前调用)

参数名类型是否必需默认值说明
filename字符串None日志输出文件名(若指定,日志写入文件而非控制台)
filemode字符串'a'文件打开模式(如'w'覆盖,'a'追加)
format字符串基础格式日志格式模板(如'%(asctime)s - %(levelname)s - %(message)s'
datefmt字符串时间格式(如'%Y-%m-%d %H:%M:%S'
level整数WARNING日志级别(如logging.INFOlogging.DEBUG
stream对象None指定日志输出流(如sys.stderr,与filename互斥)

logging.getLogger():获取或创建指定名称的日志记录器(Logger)。若nameNone,返回根日志记录器

参数名类型是否必需默认值说明
name字符串None日志记录器名称(分层结构,如'module.sub'
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

4.主函数 main  

Ⅰ、创建模型保存目录

os.path.isdir():检查指定路径是否为目录(文件夹)

参数名类型是否必需默认值说明
path字符串要检查的路径(绝对或相对)

os.mkdir():创建单个目录(若父目录不存在会抛出异常)

参数名类型是否必需默认值说明
path字符串要创建的目录路径
mode整数0o777目录权限(八进制格式,某些系统可能忽略此参数)
    #创建保存模型的目录
    if not os.path.isdir(config["model_path"]):
        os.mkdir(config["model_path"])

Ⅱ、加载训练数据

    #加载训练数据
    train_data = load_data(config["train_data_path"], config)

Ⅲ、加载模型

    #加载模型
    model = TorchModel(config)

Ⅳ、检查GPU并迁移模型

torch.cuda.is_available():检查系统是否满足 CUDA 环境要求

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)

cuda():将张量或模型移动到GPU显存,加速计算

参数类型必须说明示例
deviceint/str指定GPU设备(如0"cuda:0"tensor.cuda(device=0)
non_blockingbool是否异步传输数据(默认False)tensor.cuda(non_blocking=True)
    # 标识是否使用gpu
    cuda_flag = torch.cuda.is_available()
    if cuda_flag:
        logger.info("gpu可以使用,迁移模型至gpu")
        model = model.cuda()

Ⅴ、加载优化器

    #加载优化器
    optimizer = choose_optimizer(config, model)

Ⅵ、加载评估器

    #加载效果测试类
    evaluator = Evaluator(config, model, logger)

Ⅶ、模型训练主流程 ⭐

① Epoch循环控制

epoch:指将整个训练数据集完整地通过神经网络进行一次前向传播和反向传播的过程。一个 epoch 确保模型已经使用所有训练数据更新了一次权重

range():Python 内置函数,用于生成一个不可变的整数序列,​核心功能是为循环控制提供高效的数值迭代支持

参数名类型默认值说明
start整数0序列起始值(包含)。若省略,则默认从 0 开始。例如 range(3) 等价于 range(0,3)
stop整数必填序列结束值(不包含)。例如 range(2, 5) 生成 2,3,4
step整数1步长(正/负):
- ​正步长需满足 start < stop,否则无输出(如 range(5, 2) 无效)。
- ​负步长需满足 start > stop,例如 range(5, 0, -1) 生成 5,4,3,2,1
​**不能为 0**​(否则触发 ValueError
for epoch in range(config["epoch"]):
    epoch += 1
② 模型设置训练模式 

train_loss:计算当前批次的损失值,通常结合损失函数(如交叉熵、均方误差)使用

model.train():设置模型为训练模式,启用Dropout、BatchNorm等层的训练行为

参数类型默认值说明示例
modeboolTrue是否启用训练模式(True)或评估模式(False)model.train(True)

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)
        model.train()
        logger.info("epoch %d begin" % epoch)
        train_loss = []

③ Batch数据遍历

enumerate():遍历可迭代对象时返回索引和元素,支持自定义起始索引

参数类型必须说明示例
iterableIterable可迭代对象(如列表、生成器)enumerate(["a", "b"])
startint索引起始值(默认0)enumerate(data, start=1)
        for index, batch_data in enumerate(train_data):

④ 梯度清零与设备切换

cuda_flag:

batch_data:

optimizer.zero_grad():清空模型参数的梯度,防止梯度累积

参数类型必须说明示例
set_to_nonebool是否将梯度置为None(高效但危险)optimizer.zero_grad(True)

cuda():将张量或模型移动到GPU显存,加速计算

参数类型必须说明示例
deviceint/str指定GPU设备(如0"cuda:0"tensor.cuda(device=0)
non_blockingbool是否异步传输数据(默认False)tensor.cuda(non_blocking=True)
            optimizer.zero_grad()
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]

⑤ 前向传播与损失计算
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            loss = model(input_id, labels)

⑥ 反向传播与参数更新

loss.backward():反向传播计算梯度,基于损失值更新模型参数的.grad属性

参数类型必须说明示例
retain_graphbool是否保留计算图(用于多次反向传播)loss.backward(retain_graph=True)

optimizer.step():根据梯度更新模型参数,执行优化算法(如SGD、Adam)

参数类型必须说明示例
closureCallable重新计算损失的闭包函数(如LBFGS)optimizer.step(closure)
            loss.backward()
            optimizer.step()

⑦ 损失记录与日志输出

列表.append():在列表末尾添加元素,直接修改原列表

参数类型必须说明示例
objectAny要添加到列表末尾的元素train_loss.append(loss.item())

int():将字符串或浮点数转换为整数,支持进制转换

参数类型必须说明示例
xstr/float待转换的值(如字符串或浮点数)int("10", base=2)(输出2进制10=2)
baseint进制(默认10)

len():返回对象(如列表、字符串)的长度或元素个数

参数类型必须说明示例
objSequence/Collection可计算长度的对象(如列表、字符串)len([1, 2, 3])(返回3)

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)
            train_loss.append(loss.item())
            if index % int(len(train_data) / 2) == 0:
                logger.info("batch loss %f" % loss)

⑧ Epoch评估与日志

item():从张量中提取标量值(仅当张量包含单个元素时可用)

列表.append():Python 列表(list)的内置方法,用于向列表的 ​末尾 添加一个元素。

参数名类型默认值说明
element任意类型要添加到列表末尾的元素。可以是单个值(如 42)、对象(如 [1, 2, 3])等。

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)
            train_loss.append(loss.item())
            if index % int(len(train_data) / 2) == 0:
                logger.info("batch loss %f" % loss)

⑨ 完整训练代码
    #训练
    for epoch in range(config["epoch"]):
        epoch += 1
        model.train()
        logger.info("epoch %d begin" % epoch)
        train_loss = []
        for index, batch_data in enumerate(train_data):
            optimizer.zero_grad()
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            loss = model(input_id, labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            if index % int(len(train_data) / 2) == 0:
                logger.info("batch loss %f" % loss)
        logger.info("epoch average loss: %f" % np.mean(train_loss))
        evaluator.eval(epoch)

Ⅷ、模型保存

os.path.join():用于跨平台路径拼接的核心函数,其核心功能是智能处理不同操作系统的路径分隔符,确保代码的可移植性和健壮性

参数名类型说明
path字符串必填参数,起始路径组件。
*paths可变参数可接受多个路径组件,按顺序拼接。

torch.save():PyTorch 中用于序列化保存模型、张量或字典等对象的核心函数,支持将数据持久化存储为 .pth 或 .pt 文件,便于后续加载和复用

参数名类型默认值说明
obj任意 PyTorch 对象必填待保存的对象,如模型、张量或字典。
fstr 或文件对象必填保存路径(如 'model.pth')或已打开的文件对象(需二进制写入模式 'wb'
pickle_protocolint2指定 pickle 协议版本(通常无需修改,高版本可能提升效率但需兼容性验证)
_use_new_zipfile_serializationboolTrue启用新版序列化格式(压缩率更高,推荐保持默认)

    model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
    # torch.save(model.state_dict(), model_path)
    return model, train_data

5.调用模型预测

# -*- coding: utf-8 -*-

import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data

logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

"""
模型训练主程序
"""

def main(config):
    #创建保存模型的目录
    if not os.path.isdir(config["model_path"]):
        os.mkdir(config["model_path"])
    #加载训练数据
    train_data = load_data(config["train_data_path"], config)
    #加载模型
    model = TorchModel(config)
    # 标识是否使用gpu
    cuda_flag = torch.cuda.is_available()
    if cuda_flag:
        logger.info("gpu可以使用,迁移模型至gpu")
        model = model.cuda()
    #加载优化器
    optimizer = choose_optimizer(config, model)
    #加载效果测试类
    evaluator = Evaluator(config, model, logger)
    #训练
    for epoch in range(config["epoch"]):
        epoch += 1
        model.train()
        logger.info("epoch %d begin" % epoch)
        train_loss = []
        for index, batch_data in enumerate(train_data):
            optimizer.zero_grad()
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            loss = model(input_id, labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            if index % int(len(train_data) / 2) == 0:
                logger.info("batch loss %f" % loss)
        logger.info("epoch average loss: %f" % np.mean(train_loss))
        evaluator.eval(epoch)
    model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
    # torch.save(model.state_dict(), model_path)
    return model, train_data

if __name__ == "__main__":
    model, train_data = main(Config)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2318120.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

再学:函数可见性、特殊函数、修饰符

目录 1.可见性 2.合约特殊函数 constructor && getter 3. receive && fallback 4.view && pure 5.payable 6.自定义函数修饰符 modifier 1.可见性 public&#xff1a;内外部 private&#xff1a;内部 external&#xff1a;外部访问 internal&…

基于Spring Boot的项目申报系统的设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导&#xff0c;欢迎高校老师/同行前辈交流合作✌。 技术范围&#xff1a;SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容&#xff1a;…

Web元件库 ElementUI元件库+后台模板页面(支持Axure9、10、11)

Axure是一款非常强大的原型设计工具&#xff0c;它允许设计师和开发者快速创建高保真原型&#xff0c;以展示应用或网站的设计和功能。通过引入各种元件库&#xff0c;如ElementUI元件库&#xff0c;可以极大地丰富Axure的原型设计能力&#xff0c;使其更加贴近实际开发中的UI组…

孜然SEO静态页面生成系统V1.0

孜然SEO静态页面生成系统&#xff0c;1秒生成上万个不同的静态单页系统&#xff0c;支持URL裂变采集&#xff0c;采集的内容不会重复&#xff0c;因为程序系统自带AI重写算法&#xff0c;AI扩写算法&#xff0c;可视化的蜘蛛池系统让您更清楚的获取到信息&#xff01; 可插入二…

Blender-MCP服务源码3-插件开发

Blender-MCP服务源码3-插件开发 Blender-MCP服务源码解读-如何进行Blender插件开发 1-核心知识点 1&#xff09;使用Blender开发框架学习如何进行Blender调试2&#xff09;学习目标1-移除所有的Blender业务-了解如何MCP到底做了什么&#xff1f;3&#xff09;学习目标2-模拟MC…

C语言和C++到底有什么关系?

C 读作“C 加加”&#xff0c;是“C Plus Plus”的简称。 顾名思义&#xff0c;C 就是在 C 语言的基础上增加了新特性&#xff0c;玩出了新花样&#xff0c;所以才说“Plus”&#xff0c;就像 Win11 和 Win10、iPhone 15 和 iPhone 15 Pro 的关系。 C 语言是 1972 年由美国贝…

【华三】路由器交换机忘记登入密码或super密码的重启操作

【华三】路由器交换机忘记登入密码或super密码的重启操作 背景步骤跳过认证设备&#xff1a;路由器重启设备翻译说明具体操作 跳过当前系统配置重启设备具体操作 背景 当console口的密码忘记&#xff0c;或者说本地用户的密码忘记&#xff0c;其实这时候是登入不了路由器的&am…

DeepSeek-prompt指令-当DeepSeek答非所问,应该如何准确的表达我们的诉求?

当DeepSeek答非所问&#xff0c;应该如何准确的表达我们的诉求&#xff1f;不同使用场景如何向DeepSeek发问&#xff1f;是否有指令公式&#xff1f; 目录 1、 扮演专家型指令2、 知识蒸馏型指令3、 颗粒度调节型指令4、 时间轴推演型指令5、 极端测试型6、 逆向思维型指令7、…

HOVER:人形机器人的多功能神经网络全身控制器

编辑&#xff1a;陈萍萍的公主一点人工一点智能 HOVER&#xff1a;人形机器人的多功能神经网络全身控制器HOVER通过策略蒸馏和统一命令空间设计&#xff0c;为人形机器人提供了通用、高效的全身控制框架。https://mp.weixin.qq.com/s/R1cw47I4BOi2UfF_m-KzWg 01 介绍 1.1 摘…

HTML中滚动加载的实现

设置div的overflow属性&#xff0c;可以使得该div具有滚动效果&#xff0c;下面以div中包含的是table来举例。 当table的元素较多&#xff0c;以至于超出div的显示范围的话&#xff0c;观察下该div元素的以下3个属性&#xff1a; clientHeight是div的显示高度&#xff0c;scrol…

Python----计算机视觉处理(Opencv:形态学变换)

一、形态学变化 形态学变换&#xff08;Morphological Transformations&#xff09;是一种基于形状的图像处理技术&#xff0c;主要处理的对象为二值化图像。 形态学变换有两个输入和一个输出&#xff1a;输入为原始图像和核&#xff08;即结构化元素&#xff09;&#xff0c;输…

opencv中stitch图像融合

openv版本: opencv249 vs &#xff1a;2010 qt : 4.85 #include "quanjing.h"#include <iostream> #include <opencv2/core/core.hpp> #include <opencv2/highgui/highgui.hpp> #include <opencv2/imgproc/imgproc.hpp> #include <open…

matlab R2024b下载教程及安装教程(附安装包)

文章目录 前言一、matlab R2024b安装包下载二、matlab R2024b安装教程 前言 为帮助大家顺利安装该版本软件&#xff0c;特准备matlab R2024b下载教程及安装教程&#xff0c;它将以简洁明了的步骤&#xff0c;指导你轻松完成安装&#xff0c;开启 MATLAB R2024 的强大功能之旅。…

游戏引擎学习第167天

回顾和今天的计划 我们不使用引擎&#xff0c;也不依赖库&#xff0c;只有我们自己和我们的小手指在敲击代码。 今天我们会继续进行一些工作。首先&#xff0c;我们会清理昨天留下的一些问题&#xff0c;这些问题我们当时没有深入探讨。除了这些&#xff0c;我觉得我们在资产…

JS逆向案例-HIKVISION-视频监控的前端密码加密分析

免责声明 本文仅为技术研究与渗透测试思路分享,旨在帮助安全从业人员更好地理解相关技术原理和防御措施。任何个人或组织不得利用本文内容从事非法活动或攻击他人系统。 如果任何人因违反法律法规或不当使用本文内容而导致任何法律后果,本文作者概不负责。 请务必遵守法律…

STM32---FreeRTOS内存管理实验

一、简介 1、FreeRTOS内存管理简介 2、FreeRTOS提供的内存管理算法 1、heap_1内存管理算法 2、heap_2内存管理算法 4、heap_4内存管理算法 5、heap_5内存管理算法 二、FreeRTOS内存管理相关API函数介绍 三、 FreeRTOS内存管理实验 1、代码 main.c #include "st…

STC89C52单片机学习——第25节: [11-1]蜂鸣器

写这个文章是用来学习的,记录一下我的学习过程。希望我能一直坚持下去,我只是一个小白,只是想好好学习,我知道这会很难&#xff0c;但我还是想去做&#xff01; 本文写于&#xff1a;2025.03.18 51单片机学习——第25节: [11-1]蜂鸣器 前言开发板说明引用解答和科普一、蜂鸣器…

音视频入门基础:RTP专题(19)——FFmpeg源码中,获取RTP的音频信息的实现(下)

本文接着《音视频入门基础&#xff1a;RTP专题&#xff08;18&#xff09;——FFmpeg源码中&#xff0c;获取RTP的音频信息的实现&#xff08;上&#xff09;》&#xff0c;继续讲解FFmpeg获取SDP描述的RTP流的音频信息到底是从哪个地方获取的。本文的一级标题从“四”开始。 四…

卷积神经网络 - 卷积的变种、数学性质

本文我们来学习卷积的变种和相关的数学性质&#xff0c;为后面学习卷积神经网络做准备&#xff0c;有些概念可能不好理解&#xff0c;可以先了解其概念&#xff0c;然后慢慢理解、逐步深入。 在卷积的标准定义基础上&#xff0c;还可以引入卷积核的滑动步长和零填充来增加卷积…

BLIP论文阅读

目录 现存的视觉语言预训练存在两个不足&#xff1a; 任务领域 数据集领域 相关研究 知识蒸馏 Method 单模态编码器&#xff1a; 基于图像的文本编码器&#xff1a; 基于图像的文本解码器&#xff1a; 三重目标优化 图像文本对比损失&#xff1a;让匹配的图像文本更加…