【NLP】多标签分类【上】

news2024/10/5 14:01:57

简介

《【NLP】多标签分类》主要介绍利用三种机器学习方法和一种序列生成方法来解决多标签分类问题(包含实验与对应代码)。共分为上下两篇,上篇聚焦三种机器学习方法,分别是:Binary Relevance (BR)、Classifier Chains (CC)、Label Powerset (LP),下篇聚焦利用序列生成解决多标签分类方法,将使用Transformer完成该任务。

本文共分为5节,第一节介绍实验数据来源、任务说明;第二节介绍BR、CC、LP各自原理以及优缺点;第三节介绍本文使用的多标签分类评估标准;第四节介绍实验环境、实验步骤、实验评估以及相关代码;第五节为全文总结。

相关链接

本文相关代码和数据集已上传github: issey_Kaggle/MultiLabelClassification at main · iceissey/issey_Kaggle (github.com)

本文代码(Notebook)已公布至kaggle: XLNET embedding and machine learning(BR、CC、LP) | Kaggle

博主个人博客链接:issey的博客 - 愿无岁月可回首

1 实验数据与任务说明

数据来源:Multi-Label Classification Dataset (kaggle.com)

任务说明:

  • 背景:NLP——多标签分类数据集。
  • 内容:该数据集包含6个不同的标签(计算机科学、物理学、数学、统计学、定量生物学、定量金融),用于根据摘要和标题对研究论文进行分类。 标签列中的值1表示该标签属于该论文,每篇论文可以有多个标签为1。

2 多标签分类任务与相关算法

2.1 多标签分类任务简介

多标签分类(Multi-label Classification) 是一种机器学习任务,其中每个输入样本可以分配给多个类别标签,而不是只能分配给一个单一的类别标签。与传统的单标签分类不同,多标签分类允许一个样本同时属于多个类别,这更符合现实世界中许多复杂问题的性质。

2.2 相关算法

多标签分类方法主要分为两大类,分别是问题转换方法算法适应方法,本篇主要集中于问题转换方法中的前三种。

问题转换方法:这些方法通过转换问题使其适用于标准的单标签分类算法。主要包括以下几种:

  • 二元相关性(Binary Relevance, BR):这种方法将多标签问题分解成多个独立的二分类问题,每个标签都被视为一个独立的二分类问题。
    • 优点:
      1. 简单易实现: BR方法的实现相对简单直接,因为它将复杂的多标签问题分解为多个标准的二分类问题。
      2. 灵活性: 由于BR方法在每个标签上独立训练分类器,因此可以针对不同的标签选择最适合的分类算法。
      3. 可扩展性: 在新标签加入时,只需增加相应的二分类器,而无需修改或重新训练其他分类器。
      4. 高效: 由于每个标签都独立处理,可以并行训练和预测,提高了处理速度。
    • 缺点:
      1. 忽略标签依赖性:BR方法的主要缺点是它忽略了标签之间的相关性。在实际应用中,标签往往不是完全独立的,它们之间的关联可能对分类结果有重要影响。
      2. 预测性能问题:由于不考虑标签间的依赖关系,BR方法在某些复杂的多标签问题上的预测性能可能不如那些能够考虑标签依赖性的方法。
  • 标签幂集(Label Powerset, LP):在这种方法中,每一种标签组合都被视为一个独立的类别,从而将多标签问题转换为单标签多类别问题。
    • 优点
      1. 考虑标签之间的依赖性:LP方法能够捕捉和利用标签之间的相关性。这在标签彼此之间存在强烈依赖性的情况下特别有用。
      2. 简化模型训练:与需要为每个标签单独训练一个分类器的二元相关方法相比,LP只需训练一个模型,这可以简化训练过程。
      3. 直接预测标签集合:LP方法直接预测整个标签集合,避免了将标签预测作为独立事件处理时可能出现的问题。
    • 缺点:
      1. 组合爆炸:当标签数量增多时,可能的标签组合数会指数级增长,导致计算和存储需求急剧增加。由于组合爆炸的问题,标签幂集无法处理标签种类较多的问题。
      2. 数据稀疏问题:对于一些罕见的标签组合,可能没有足够的训练数据,这会导致模型性能下降。
      3. 效率问题:尽管只需训练一个模型,但模型可能变得非常复杂,特别是当存在大量的标签组合时。
  • 分类器链(Classifier Chains, CC):这种方法通过构建一个分类器链来解决标签之间的依赖问题。每个分类器在链中负责一个标签,并将前面分类器的预测结果作为额外的输入。
    • 优点:
      1. 考虑标签间的依赖性:分类器链通过序列化的方式考虑标签间的依赖关系,这在标签相关性显著的情况下特别有用。
      2. 可扩展性:相比于标签幂集方法,分类器链在处理大量标签时更为高效,因为它避免了组合爆炸问题。
      3. 较好的泛化能力:相对于二元相关方法,分类器链通常能够提供更好的泛化能力,尤其是在标签之间存在依赖关系时。
    • 缺点:
      1. 链的顺序敏感性:分类器链的性能可能受到链中分类器顺序的影响。不同的标签顺序可能导致不同的性能表现。
      2. 错误传播:链中早期分类器的错误可能会传播到链的后面部分,影响整体性能。
  • 随机k标签子集(Random k-Labelsets, RAkEL):这种方法是通过随机选择标签子集并对每个子集应用LP方法,然后综合这些模型的预测结果。由于本文涉及的实验总共标签总类也才6种,所以没有使用这种方法而直接选择了LP
    • 优点:
      1. 缓解组合爆炸问题:通过在较小的标签子集上应用LP方法,RAkEL减少了可能的标签组合数量,从而缓解了标签幂集法中的组合爆炸问题。
      2. 考虑标签间的依赖性:与二元相关方法相比,RAkEL能够捕捉标签子集内部的依赖关系,提高了模型的准确性。
      3. 更好的泛化能力:由于模型在多个随机选择的标签子集上训练,这可以增加模型的泛化能力。
    • 缺点:
      1. 随机性:标签子集的随机选择可能导致模型性能的不稳定性。
      2. 可能忽略某些标签关系:如果某些相关标签从不在同一个子集中出现,那么它们之间的关系可能不会被模型捕捉到。
      3. 计算复杂度:虽然RAkEL缓解了组合爆炸问题,但仍需要训练多个LP模型,这可能比单一的分类器链或二元相关方法更耗时。
      4. 预测一致性问题:不同的标签子集模型可能对相同的标签做出不同的预测,需要有效的机制来整合这些预测。
      5. 参数选择:选择合适的子集大小(k值)和子集数量是RAkEL方法的关键,这可能需要根据具体的数据集进行调整。

算法适应方法:这些方法通过修改现有的学习算法使其能够直接处理多标签数据。主要包括以下几种:适应决策树(Adapted Decision Trees)、适应神经网络(Adapted Neural Networks)、适应支持向量机(Adapted Support Vector Machines)、k最近邻修改版(k-Nearest Neighbors Adaptation)。

除问题转换方法和算法适应方法外,深度学习方法也在多标签分类中表现出色。在本文的下篇中,会介绍将多标签分类转换为多标签序列生成任务的方法。

3 多标签分类评估方法

3.1 准确率(Accuracy)

  • 定义: 准确率是正确预测的样本数与总样本数的比例。在多标签分类中,如果所有的标签都被准确预测,则一个样本的预测被认为是正确的。
  • 实现: 使用sklearn.metricsaccuracy_score方法实现。
  • 备注: 由于只有当某样本所有标签全预测正确,才能算该样本预测正确,导致这种方式计算出的Acc结果普遍偏低。在下篇中,会介绍另一种计算Acc的方式,即先计算每一个label的Acc,然后在取平均值。

3.2 精确度(Precision)- 微观平均(Micro-average)

  • 定义: 精确度是模型正确预测为正的实例(真正例)占模型预测为正的所有实例(真正例和假正例)的比例。
  • 计算方法: 微观平均精确度是通过汇总所有类别的真正例和假正例的数量,然后计算总体精确度得到的。在多标签设置中,这意味着考虑所有标签的预测结果,而不是单独考虑每个标签。
  • 实现: 使用sklearn.metricsprecision_score方法实现。

3.3 召回率(Recall)- 微观平均(Micro-average)

  • 定义: 召回率是模型正确预测为正的实例占实际为正的所有实例(真正例和假负例)的比例。
  • 计算方法: 微观平均召回率是通过汇总所有类别的真正例和假负例的数量,然后计算总体召回率得到的。它反映了模型在所有标签上的总体能力,来正确地识别正类实例。
  • 实现: 使用sklearn.metricsrecall_score方法实现。

3.4 F1 分数(F1 Score)- 微观平均(Micro-average)

  • 定义: F1 分数是精确度和召回率的调和平均值,用于平衡这两个指标。
  • 计算方法: 微观平均 F1 分数是基于微观平均精确度和召回率计算得到的。它是这两个指标的调和平均值,因此在精确度和召回率都重要时,提供了一个综合性能度量。
  • 实现: 使用sklearn.metricsf1_score方法实现。

4 实验

4.1 实验环境

本实验是在以下配置的环境中进行的:

  • 编程语言和版本
    • Python 3.9:一个广泛使用的高级编程语言,适用于数据科学和机器学习项目。
  • 主要库和框架
    • NumPy 1.23.3:用于高性能科学计算和数据分析的基础包。
    • Pandas 1.4.4:提供高效的数据结构和数据分析工具。
    • Matplotlib 3.5.3:用于数据可视化的绘图库。
    • PyTorch 1.13.0:一个灵活的深度学习框架,适用于研究和生产。
    • PyTorch CUDA 11.6:用于在NVIDIA GPU上加速PyTorch运算的CUDA支持库。
  • 机器学习和深度学习库
    • Transformers 4.18.0:由Hugging Face提供的,用于自然语言处理的预训练模型和转换器。
    • scikit-learn 1.2.2:提供简单有效的数据挖掘和数据分析工具。
    • scikit-multilearn 0.2.0:用于多标签分类的机器学习库。

4.2 实验步骤

本篇的实验步骤主要包括:1)数据观察与预处理阶段。2)词嵌入阶段。3)模型训练与测试阶段。4)进一步探索。

4.2.1 数据观察与预处理

数据观察
  • 单词数量统计: 在本实验中,我们专注于观察数据集中每个文本项的单词数量。通过统计信息,我们可以了解数据集中文本的长度分布。
数据预处理
  • 最小化预处理:由于本实验在后续词嵌入时使用XL-NET模型,且与使用传统文本分类方法相比,使用XL-NET等先进的预训练模型时,常规的文本预处理步骤(如去除特殊符号、停用词移除、词形还原)并不是必要的。这些模型的分词器能够有效处理原始文本中的复杂词汇结构,同时保留对上下文理解至关重要的词汇和语法特征。
  • 实验步骤完整性:虽然在本实验中不需要传统的预处理步骤,但为了保持实验步骤的完整性和系统性,我们仍然包含了这一部分。这有助于清晰地展示实验流程,并为可能需要适当预处理的后续研究提供参考。

代码部分
  • 准备工作

导入相关库

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import h5py
from tqdm import tqdm
from transformers import XLNetTokenizer, XLNetModel
import os

检查GPU是否可用。在上篇的实验中,如果GPU不可用问题也不大,直接用CPU跑即可,因为上篇使用GPU的地方只有embedding。不过在下篇时GPU是必要的,如果本地环境不支持,建议放到云服务器(如kaggle)上跑。

# Set the device to GPU (if available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
Using device: cuda
  • 准备数据集

由于题目要求使用TITLE和ABSTRACT共同参与预测,所以简单做一下拼接。

"""Prepare the data"""
input_csv = "/kaggle/input/multilabel-classification-dataset/train.csv"
data = pd.read_csv(input_csv)  
# data = data[:20]  # Test
print(len(data))
data['combined_text'] = data['TITLE'] + " " + data['ABSTRACT'] 
print(data['combined_text'].head())
20972
0    Reconstructing Subject-Specific Effect Maps   ...
1    Rotation Invariance Neural Network   Rotation ...
2    Spherical polyharmonics and Poisson kernels fo...
3    A finite element approximation for the stochas...
4    Comparative study of Discrete Wavelet Transfor...
Name: combined_text, dtype: object
  • 统计combined_text单词分布

检查combined_text最长、最小、平均单词长度。

"""View the distribution of word counts"""
# Split the text using spaces and calculate the number of words
data['word_count'] = data['combined_text'].apply(lambda x: len(str(x).split()))

# Print statistical information about the number of words
print("Word count statistics:")
print("Maximum word count:", data['word_count'].max())
print("Minimum word count:", data['word_count'].min())
print("Average word count:", data['word_count'].mean())
Word count statistics:
Maximum word count: 462
Minimum word count: 5
Average word count: 157.9198455082968

绘制单词分布柱状图。

plt.figure(figsize=(10, 6))
plt.hist(data['word_count'], bins=50, alpha=0.75, color='b', edgecolor='k')
plt.xlabel('Word count')
plt.ylabel('Frequency')
plt.title('Word count distribution')
plt.show()

4.2.2 词嵌入阶段

XL-NET嵌入

在本实验中,我们使用了预训练的XL-NET模型来生成文本嵌入,这是一个关键步骤,旨在将文本转换为能被机器学习模型有效处理的数值形式。

  • 模型和分词器加载:我们首先加载了XLNet的基础模型(xlnet-base-cased)和对应的分词器。这个分词器将负责将原始文本转换成模型可以理解的令牌序列。
  • 设定批处理大小:考虑到计算效率和内存限制,我们设定了一个合适的批处理大小(batch_size = 32)。这意味着每次向模型输入32个文本样本进行处理。
嵌入生成过程
  • 文本准备和处理:我们将数据集中的文本转换为字符串列表,并按批次处理。每个批次的文本被分词器编码,其中包括截断和填充操作以确保文本长度一致。
  • 嵌入计算:对于每个批次,我们将编码后的文本输入XL-NET模型。通过模型,我们获取每个文本的嵌入表示,这些表示捕捉了文本中的语义信息。
  • 处理和存储嵌入:得到的嵌入被转换为NumPy数组,并被收集在一起。最终,所有的嵌入被存储在HDF5文件格式中,方便后续的机器学习任务使用。
不进行微调的决定
  • 一次性嵌入过程:本实验选择不对XL-NET模型进行微调,而是直接使用预训练模型一次性生成所有文本的嵌入。这种方法简化了实验流程,同时允许我们充分利用XL-NET预训练模型的强大语义捕捉能力。
  • 效率和实用性:将所有文本的嵌入预先计算并存储起来,提高了后续实验步骤的效率。

代码部分
  • 加载分词器和预训练模型
"""Load the XLNet tokenizer and model"""
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetModel.from_pretrained('xlnet-base-cased')
model.to(device)
batch_size = 32  # Determine the batch size
all_embeddings = []
# token = tokenizer.convert_ids_to_tokens(5)
# print(token)
""""Choose not to fine-tune the embedding layer, so embed all texts at once into vectors"""
texts = data['combined_text'].astype(str).tolist()
  • embedding,并将嵌入好的向量一次性存储下来

tqdm是一个可视化进度条的库,可以方便的查看处理进度。

这里解释一下如何从XL-NET模型的输出中提取嵌入(embedding)。

  • 模型输出理解:当我们将输入文本通过XL-NET模型处理时,outputs 对象包含了多个不同的输出组件。其中,last_hidden_state 是一个多维张量,其维度通常是 [批处理大小, 序列长度, 隐藏单元数]。这个张量包含了模型对每个输入令牌的最后一层隐藏状态的表示。
  • 选择特定令牌的嵌入:在XL-NET和类似的变压器模型中,每个输入令牌都有一个对应的输出向量。在这里,outputs.last_hidden_state[:, 0, :] 表示我们选择了每个序列的第一个令牌(通常是特殊的分类令牌,如BERT中的[CLS])的输出向量。这个向量被认为是整个输入序列的聚合表示,并经常用于分类任务。

还记得我在今年早些的时候做的那个Bert+Bilstm的任务【NLP实战】基于Bert和双向LSTM的情感分类【中篇】-CSDN博客,当时我在embeding后直接取的last_hidden_state,也就是个三维向量,接着用Bilstm得到最终的二维隐藏层(只保留了最后的隐藏状态),现在想来当时对Bert的理解还是不到位。然而这两种方法都是有效的,不过一个是词维度的嵌入,一个是句维度的嵌入,本文上篇使用的embedding就是句维度的嵌入。

# Specify the directory path
directory_path = '/kaggle/working/multilabel-classification-dataset/'

# Create the directory if it doesn't exist
if not os.path.exists(directory_path):
    os.makedirs(directory_path)
    
for start_index in tqdm(range(0, len(texts), batch_size)):
    # Encode the text
    batch_texts = texts[start_index:start_index + batch_size]
    encoded_inputs = tokenizer(batch_texts, return_tensors='pt', max_length=512, truncation=True, padding='max_length')
    # get embeddings
    input_ids = encoded_inputs['input_ids'].to(device)
    attention_mask = encoded_inputs['attention_mask'].to(device)
    #  calculate embeddings
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    #  move the results back to CPU and convert to numpy arrays
    embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
    # print(embeddings.shape)
    
    all_embeddings.extend(embeddings)
# Convert all embeddings to numpy arrays
all_embeddings = np.array(all_embeddings)
print(all_embeddings.shape)
# Store embedding vectors to an HDF5 file
hdf5_filename = '/kaggle/working/multilabel-classification-dataset/embeddings.h5'
with h5py.File(hdf5_filename, 'w') as hdf5_file:
    hdf5_file.create_dataset('embeddings', data=all_embeddings)

print(f"Embeddings have been stored in the {hdf5_filename} file.")
100%|██████████| 656/656 [15:11<00:00,  1.39s/it]
(20972, 768)
Embeddings have been stored in the /kaggle/working/multilabel-classification-dataset/embeddings.h5 file.

可以看到,现在我们的数据集中的text(也就是'combined_text'),被编译为了一个768维度的向量。一共有20972行text,所以嵌入矩阵为(20972, 768)

4.2.3 模型训练与测试阶段

数据准备
  • 数据集加载与分割:我们从CSV文件中加载了数据集,并提取了标签列。接着,使用XL-NET生成的嵌入向量作为特征,将数据集分割为训练集和测试集,保证了模型训练和评估的有效性和公正性。
多标签分类方法

我们采用了三种不同的多标签分类方法:二元相关(Binary Relevance, BR)、分类器链(Classifier Chains, CC)和标签幂集(Label Powerset, LP)。每种方法都使用了随机森林分类器作为基学习器。

  • 二元相关(Binary Relevance):这种方法将多标签问题分解为多个独立的二分类问题。我们首先训练了BR模型,并记录了训练时间。接着,我们在测试集上进行预测,并计算了准确度、精确度、召回率和F1分数(微观平均)。
  • 分类器链(Classifier Chains):这种方法通过构建一个分类器链,使每个分类器在预测时考虑到之前分类器的输出。同样,我们训练了CC模型,记录了训练时间,并在测试集上进行了评估。
  • 标签幂集(Label Powerset):LP方法将多标签问题转换为单标签多类别问题。我们训练了LP模型,并对其进行了测试集上的性能评估。
性能评估
  • 评估指标:为了全面评估每种方法的性能,我们计算了准确度、精确度、召回率和F1分数(均采用微观平均),评估指标详细说明如第三节所示。这些指标帮助我们理解不同方法在处理多标签分类任务时的效果和局限。

  • 训练时间和性能:每种方法的训练时间都被记录下来,以评估其在实际应用中的可行性。


代码部分
  • 导入相关库
from skmultilearn.problem_transform import BinaryRelevance, ClassifierChain, LabelPowerset
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import pandas as pd
import h5py
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score
import time
  • 准备数据

提取标签。

data_path = "/kaggle/input/multilabel-classification-dataset/train.csv"
data = pd.read_csv(data_path)

label_columns = data.columns[-6:]  # Extract the 'labels' column
y = data[label_columns].values
print(y.shape)
(20972, 6)

加载经过XL-NET嵌入后的隐向量。

# Load embedding vectors
with h5py.File('/kaggle/input/xlnet-embedding-for-multilabel-classification/embeddings.h5', 'r') as f:
    embeddings = np.array(f['embeddings'])
print(embeddings.shape)
# 确保标签和嵌入向量的行数相同
assert embeddings.shape[0] == y.shape[0]
(20972, 768)

用于后续测试,如果要让模型快速运行就把注释打开。

# TEST
# embeddings = embeddings[:1000]
# y = y[:1000]

分割数据集。

# Split the dataset into a training set and a test set.
X_train, X_test, y_train, y_test = train_test_split(embeddings, y, test_size=0.2, random_state=10)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
(16777, 768) (4195, 768) (16777, 6) (4195, 6)
  • 模型训练与测试
# Binary Relevance
start_time = time.time()
br_classifier = BinaryRelevance(RandomForestClassifier())
br_classifier.fit(X_train, y_train)
br_training_time = time.time() - start_time
br_predictions = br_classifier.predict(X_test)
br_precision = precision_score(y_test, br_predictions, average='micro')
br_recall = recall_score(y_test, br_predictions, average='micro')
br_f1 = f1_score(y_test, br_predictions, average='micro')
print("===================================")
print("BR Training Time:", br_training_time)
print("BR Accuracy =", accuracy_score(y_test, br_predictions))
print("BR Precision (micro-average) =", br_precision)
print("BR Recall (micro-average) =", br_recall)
print("BR F1 Score (micro-average) =", br_f1)


# Classifier Chains
start_time = time.time()
cc_classifier = ClassifierChain(RandomForestClassifier())
cc_classifier.fit(X_train, y_train)
cc_training_time = time.time() - start_time
cc_predictions = cc_classifier.predict(X_test)
cc_precision = precision_score(y_test, cc_predictions, average='micro')
cc_recall = recall_score(y_test, cc_predictions, average='micro')
cc_f1 = f1_score(y_test, cc_predictions, average='micro')
print("===================================")
print("CC Training Time:", cc_training_time)
print("CC Accuracy =", accuracy_score(y_test, cc_predictions))
print("CC Precision (micro-average) =", cc_precision)
print("CC Recall (micro-average) =", cc_recall)
print("CC F1 Score (micro-average) =", cc_f1)

# Label Powerset
start_time = time.time()
lp_classifier = LabelPowerset(RandomForestClassifier())
lp_classifier.fit(X_train, y_train)
lp_training_time = time.time() - start_time
lp_predictions = lp_classifier.predict(X_test)
lp_precision = precision_score(y_test, lp_predictions, average='micro')
lp_recall = recall_score(y_test, lp_predictions, average='micro')
lp_f1 = f1_score(y_test, lp_predictions, average='micro')
print("===================================")
print("LP Training Time:", lp_training_time)
print("LP Accuracy =", accuracy_score(y_test, lp_predictions))
print("LP Precision (micro-average) =", lp_precision)
print("LP Recall (micro-average) =", lp_recall)
print("LP F1 Score (micro-average) =", lp_f1)
===================================
BR Training Time: 445.8087875843048
BR Accuracy = 0.4476758045292014
BR Precision (micro-average) = 0.8038496791934006
BR Recall (micro-average) = 0.4978240302743614
BR F1 Score (micro-average) = 0.6148632858144426
===================================
CC Training Time: 410.08831691741943
CC Accuracy = 0.4786650774731824
CC Precision (micro-average) = 0.8012065498419995
CC Recall (micro-average) = 0.5277199621570482
CC F1 Score (micro-average) = 0.6363221537759525
===================================
LP Training Time: 74.27938294410706
LP Accuracy = 0.5349225268176401
LP Precision (micro-average) = 0.7178777393310265
LP Recall (micro-average) = 0.5888363292336802
LP F1 Score (micro-average) = 0.646985446985447
结果分析

可以看到,LP不仅训练时间最短,而且Acc和F1都要更好。因此,我们可以继续探究使用支持向量机(SVM)作为基分类器的效果。

4.2.4 进一步探索–使用SVM的标签幂集方法

实验设计
  • 基分类器更换:鉴于LP方法的成功,我们决定用SVM替换原先的随机森林分类器,以进一步探索不同基分类器对多标签分类任务性能的影响。
  • SVM配置:我们选择了线性核的SVM,并将其包装在OneVsRestClassifier中,以适应多类别问题。线性核是因其在处理高维数据时的有效性和计算效率而被选用。
训练和评估
  • 模型训练:使用LP方法结合SVM分类器训练模型,并记录了训练时间。

  • 性能评估:在测试集上评估了模型的准确度、精确度、召回率和F1分数(均采用微观平均)。这些指标有助于我们全面了解SVM在多标签分类任务中的表现。

  • 训练时间对比:与之前使用随机森林的LP方法相比,我们特别关注SVM版本的训练时间,以评估其在实际应用中的效率。


代码部分
from sklearn.svm import SVC
from sklearn.multiclass import OneVsRestClassifier

# Use SVM as the base classifier
svm_classifier = OneVsRestClassifier(SVC(kernel='linear'))  # The kernel function uses a linear function.

# Label Powerset with SVM
start_time = time.time()
lp_svm_classifier = LabelPowerset(svm_classifier)
lp_svm_classifier.fit(X_train, y_train)
lp_svm_training_time = time.time() - start_time
lp_svm_predictions = lp_svm_classifier.predict(X_test)
print("===================================")
print("LP-SVM Training Time:", lp_svm_training_time)
print("LP-SVM Accuracy =", accuracy_score(y_test, lp_svm_predictions))
print("LP-SVM Precision (micro-average) =", precision_score(y_test, lp_svm_predictions, average='micro'))
print("LP-SVM Recall (micro-average) =", recall_score(y_test, lp_svm_predictions, average='micro'))
print("LP-SVM F1 Score (micro-average) =", f1_score(y_test, lp_svm_predictions, average='micro'))
===================================
LP-SVM Training Time: 13640.821268558502
LP-SVM Accuracy = 0.5914183551847437
LP-SVM Precision (micro-average) = 0.7367712141620165
LP-SVM Recall (micro-average) = 0.7245033112582782
LP-SVM F1 Score (micro-average) = 0.7305857660751764
结果分析

可以看到,LP-SVM的训练时间比使用随机森林的LP长了184倍,但所有评估标准都比使用随机森林的LP好。显然,它是我们本篇中最好的模型。

5 总结

本篇为《【NLP】多标签分类》的上篇,本文详细细探讨了多标签分类问题,聚焦于三种机器学习方法(Binary Relevance, Classifier Chains, Label Powerset),展示了每种方法的原理、优缺点,以及具体的实验评估和代码实现。本文还探讨了如何使用XL-NET做嵌入。实验结果表明,标签幂集方法配合随机森林分类器在训练时间和性能(准确度和F1分数)上表现良好。进一步探索使用SVM作为基分类器后,虽然训练时间增长,但所有评估标准均有所提升,显示出更好的性能。文章通过详细的实验步骤和评估方法,为选择适合特定多标签分类任务的方法提供了实证依据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1372110.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实用Unity3D Log打印工具XDebug

特点 显示时间&#xff0c;精确到毫秒显示当前帧数&#xff08;在主线程中的打印才有意义&#xff0c;非主线程显示为-1&#xff09;有三种条件编译符(如下图) 注&#xff1a;要能显示线程中的当前帧数&#xff0c;要在app启动时&#xff0c;初始化mainThreadID字段条件编译符…

在App Store Connect上编辑多个用户的访问权限

作为一名编程新手&#xff0c;在App Store Connect中管理用户权限可能初听起来有些复杂&#xff0c;但实际上它是一个相对直接的过程。这里是一个步骤清晰的指南来帮助您在App Store Connect上编辑多个用户的访问权限。 App Store Connect 简介 在开始之前&#xff0c;让我们先…

爬虫网易易盾滑块及轨迹算法案例:某乎

声明&#xff1a; 该文章为学习使用&#xff0c;严禁用于商业用途和非法用途&#xff0c;违者后果自负&#xff0c;由此产生的一切后果均与作者无关 一、滑块初步分析 js运行 atob(‘aHR0cHM6Ly93d3cuemhpaHUuY29tL3NpZ25pbg’) 拿到网址&#xff0c;浏览器打开网站&#xff0…

【低照度图像增强系列(3)】EnlightenGAN算法详解与代码实现

前言 ☀️ 在低照度场景下进行目标检测任务&#xff0c;常存在图像RGB特征信息少、提取特征困难、目标识别和定位精度低等问题&#xff0c;给检测带来一定的难度。 &#x1f33b;使用图像增强模块对原始图像进行画质提升&#xff0c;恢复各类图像信息&#xff0c;再使用目标检…

如何降低成本,制作个性化电子产品宣传册呢

​随着科技的飞速发展&#xff0c;电子产品已经深入到我们生活的每一个角落。然而&#xff0c;如何让你的产品在众多竞争者中脱颖而出呢&#xff1f;制作一份个性化的宣传册&#xff0c;不仅可以吸引潜在客户&#xff0c;还能有效降低成本&#xff0c;提升销售效果。 一、明确目…

基于JAVA+SpringBoot的高校学术报告系统

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取项目下载方式&#x1f345; 一、项目背景介绍&#xff1a; 智慧高校学术报告系统…

职场日常英语口语,成人英语培训学校,柯桥学英语推荐哪里

“玩手机”用英语怎么说&#xff1f;你的第一反应是不是&#xff1a;play the phone&#xff1f; 在英语中&#xff0c;play这个动词通常表示“玩耍、娱乐、操纵”等意思&#xff0c;而手机是一种工具&#xff0c;不是玩耍的对象。 换句话说&#xff0c;我们“玩手机”&#xf…

CUDA编程:执行模型

SM 在SM中&#xff0c;共享内存和寄存器是非常重要的资源。共享内存被分配在SM上的常驻线程 块中&#xff0c;寄存器在线程中被分配。线程块中的线程通过这些资源可以进行相互的合作和通 信。 WARP CUDA采用单指令多线程&#xff08;SIMT&#xff09;架构来管理和执行线程&am…

机器学习中的隐马尔可夫模型及Python实现示例

隐马尔可夫模型&#xff08;HMM&#xff09;是一种统计模型&#xff0c;用于描述观测序列和隐藏状态序列之间的概率关系。它通常用于生成观测值的底层系统或过程未知或隐藏的情况&#xff0c;因此它被称为“隐马尔可夫模型”。 它用于根据生成数据的潜在隐藏过程来预测未来的观…

跟我学java|Stream流式编程——并行流

什么是并行流 并行流是 Java 8 Stream API 中的一个特性。它可以将一个流的操作在多个线程上并行执行&#xff0c;以提高处理大量数据时的性能。 在传统的顺序流中&#xff0c;所有的操作都是在单个线程上按照顺序执行的。而并行流则会将流的元素分成多个小块&#xff0c;并在多…

【Java集合篇】 ConcurrentHashMap在哪些地方做了并发控制

ConcurrentHashMap在哪些地方做了并发控制 ✅典型解析✅初始化桶阶段&#x1f7e2;桶满了会自动扩容吗&#x1f7e0;自动扩容的时间频率是多少 ✅put元素阶段✅扩容阶段&#x1f7e0; 拓展知识仓&#x1f7e2;ConcurrentSkipListMap和ConcurrentHashMap有什么区别☑️简单介绍一…

2024年第九届计算机与通信系统国际会议(ICCCS2024) ,邀您相约西安!

会议官网: ICCCS2024 | Xian China 时间: 2024年4月19-22日 地点: 中国西安 会议简介&#xff1a; 近年来&#xff0c;信息通信在不断发展&#xff0c;为计算机网络的进步与发展提供了先进可靠的技术支持。随着计算机网络与通信技术的深入发展&#xff0c;计算机通信技术、数…

报错解决:RuntimeError: Error building extension ‘bias_act_plugin‘

系统&#xff1a; Ubuntu22.04&#xff0c; nvcc -V&#xff1a;11.8 &#xff0c; torch&#xff1a;2.0.0cu118 一&#xff1a;BUG内容 运行stylegan项目的train.py时遇到报错&#x1f447; Setting up PyTorch plugin "bias_act_plugin"... Failed! /home/m…

docker+jmeter实现windows作为主控机,linux作为负载机的分布式压测环境搭建

dockerjmeter实现windows作为主控机&#xff0c;linux作为负载机的分布式压测环境搭建 1、搭建环境说明2、windows主控机安装Jmeter3、linux负载机安装Jmeter3.1、安装docker环境3.2、使用docker安装jmeter 4、windows主控机分发测试任务 1、搭建环境说明 准备一台windows主机…

京东(天猫淘宝)数据分析工具-鲸参谋系统全功能解析——行业大盘、红蓝海市场、品牌分析、店铺分析、商品分析、竞品监控(区分自营和POP)

作为第三方电商数据平台&#xff0c;鲸参谋电商大数据系统能够为品牌方和商家提供包括行业趋势、热门品牌、店铺分析、单品分析在内的多个层面数据分析&#xff0c;帮助商家做出更加准确的经营决策&#xff0c;提升经营效率&#xff0c;实现精准营销。 下面&#xff0c;我们针…

YOLOv8优化策略:轻量化改进 | 超越RepVGG!浙大阿里提出OREPA:在线卷积重参数化

🚀🚀🚀本文改进:在线卷积重参数化巧妙的和YOLOV8结合,并实现轻量化 🚀🚀🚀YOLOv8改进专栏:http://t.csdnimg.cn/hGhVK 学姐带你学习YOLOv8,从入门到创新,轻轻松松搞定科研; 1.OREPA介绍 论文:https://arxiv.org/pdf/2204.00826.pdf 摘要:结构重新参数化在…

软件测试|Python Selenium 库安装使用指南

简介 Selenium 是一个用于自动化浏览器操作的强大工具&#xff0c;它可以模拟用户在浏览器中的行为&#xff0c;例如点击、填写表单、导航等。在本指南中&#xff0c;我们将详细介绍如何安装和使用 Python 的 Selenium 库。 安装 Selenium 库 使用以下命令可以通过 pip 安装…

内网渗透之CobaltStrike(CS)

目录 一、Cobalt Strike简介 二、Cobalt Strike基本用法 1、启动服务端 2、客户端连接 3、设置监听器&#xff08;Listeners&#xff09; 4、脚本管理器&#xff08;Script Manager&#xff09; 5、攻击&#xff08;最常用的是生成后门&#xff09; 6、CS上线 7、Beaco…

图神经网络 7大高效创新思路分享,附17篇最新顶会论文和代码

2024年了&#xff0c;图神经网络方向还好发论文吗&#xff1f;答案当然是能。 图神经网络在处理非欧空间数据和复杂特征方面具有明显的优势&#xff0c;且已成为了深度学习领域的热点&#xff0c;在学术界和工业界都有着广泛的研究和应用。不仅如此&#xff0c;图神经网络与CV…

如何在集简云中调用GPTs(Assistant) API

我们在OpenAI中创建了GPTs(Assistant)后&#xff0c;希望放到其它软件中使用&#xff0c;比如 抖音私信&#xff0c;抖音评论&#xff0c;微信公众号&#xff0c;钉钉&#xff0c;飞书&#xff0c;企业微信...... 要如何实现这样的功能呢&#xff1f; 您可以使用集简云的 “数…