【可解释性机器学习】解释基于Scikit-learn进行文本分类的pipeline及结果

news2024/11/16 4:20:04

使用Scikit-learn进行文本分类pipeline

  • 1. 基线模型
  • 2. 基线模型,改进的数据
  • 3. Pipeline改进
  • 4. 基于字符的pipeline
  • 5. 调试HashingVectorizer
  • 参考资料

scikit-learn文档提供了一个很好的文本分类教程。确保先阅读它。 本文中,我们将做类似的事情,同时更详细地研究分类器权重和预测结果。

1. 基线模型

首先,让加载 20 个新闻组数据,只保留 4 个类别:

from sklearn.datasets import fetch_20newsgroups

categories = ['alt.atheism', 'soc.religion.christian',
              'comp.graphics', 'sci.med']
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42
)
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42
)

一个基本的文本处理pipeline——bag of words特征和Logistic Regression作为分类器:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.pipeline import make_pipeline

vec = CountVectorizer()
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target);

这里使用 LogisticRegressionCV 来自动调整正则化参数 C它允许比较不同的向量化器——最佳 C 值对于不同的输入特征可能不同(例如,对于双字母组或字符级输入)。 另一种方法是使用 GridSearchCV 或 RandomizedSearchCV

让我们检查一下这条管道的质量:

from sklearn import metrics

def print_report(pipe):
    y_test = twenty_test.target
    y_pred = pipe.predict(twenty_test.data)
    report = metrics.classification_report(y_test, y_pred,
        target_names=twenty_test.target_names)
    print(report)
    print("accuracy: {:0.3f}".format(metrics.accuracy_score(y_test, y_pred)))

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.91      0.81      0.86       319
         comp.graphics       0.86      0.94      0.90       389
               sci.med       0.92      0.81      0.86       396
soc.religion.christian       0.88      0.98      0.93       398

              accuracy                           0.89      1502
             macro avg       0.89      0.89      0.89      1502
          weighted avg       0.89      0.89      0.89      1502

accuracy: 0.889

不错。 可以尝试其他分类器和预处理方法,但先检查模型使用eli5.show_weights() 函数学到了什么

import eli5
eli5.show_weights(clf, top=10)

输出结果
上表没有任何意义; 问题是eli5无法单独从分类器对象中获取特征和类名。 可以明确提供功能和目标名称:

eli5.show_weights(clf,feature_names=vec.get_feature_names(),target_names=twenty_test.target_names)

输出结果
上面的代码有效,但更好的方法是提供vectorizer,让eli5自动计算细节:

eli5.show_weights(clf, vec=vec, top=10, target_names=twenty_test.target_names)

输出结果
这开始变得更有意义了。 Columns 是目标类。 在每一列中都有特征及其权重。 Intercept (偏差)特征在同一张表中显示为 <BIAS>。 我们可以检查特征和权重,因为我们使用的是词袋向量化器和线性分类器(因此单个词和分类器系数之间存在直接映射)。 对于其他分类器,特征可能更难检查。

有些功能看起来不错,但有些则不然。 不过,模型似乎学习了一些特定于数据集的名称(电子邮件部分等),而不是学习特定主题的单词。 让我们检查一个例子的预测结果:

eli5.show_prediction(clf, twenty_test.data[0], vec=vec, target_names=twenty_test.target_names)

输出结果
可以在文本中突出显示内容。 还有一个单独的表格用于无法在文本中突出显示的功能 - 在本例中为<BIAS>。 如果将鼠标悬停在突出显示的词上,它会显示该词在标题中的权重。 单词根据其权重着色。

2. 基线模型,改进的数据

从上面的突出显示可以看出分类器确实学到了一些不感兴趣的东西,例如 它记住了部分电子邮件地址。 我们可能应该首先清理数据以使其更有趣; 改进模型(尝试不同的分类器等)在这一点上没有意义——它可能只是学会更好地利用这些电子邮件地址。

实际上,我们必须自己清理数据; 在此示例中,20 个新闻组数据集提供了一个选项,用于从消息中删除页脚和标题。让我们清理数据并重新训练分类器。

twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=['headers', 'footers'],
)
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=['headers', 'footers'],
)

vec = CountVectorizer()
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target);

我们只是让分类器的任务变得更难、更现实。

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.81      0.76      0.79       319
         comp.graphics       0.82      0.93      0.87       389
               sci.med       0.87      0.78      0.82       396
soc.religion.christian       0.86      0.88      0.87       398

              accuracy                           0.84      1502
             macro avg       0.84      0.84      0.84      1502
          weighted avg       0.84      0.84      0.84      1502

accuracy: 0.840

一个很好的结果——我们只是让质量变差了! 这是否意味着管道现在更糟? 不,可能它对看不见的消息有更好的质量。 现在是比较公正的评价。 检查分类器使用的特征让我们注意到数据的问题并做出了很好的改变,尽管数字告诉我们不要这样做。

我们可以直接改进评估设置,而不是删除页眉和页脚,例如使用 来自 scikit-learn 的 GroupKFold。 那么旧模型的质量就会下降,我们可以删除页眉/页脚并提高准确性,所以数字会告诉我们删除页眉和页脚。 但是,如何拆分数据以及使用 GroupKFold 的哪些组并不明显。

那么,更新后的分类器学到了什么? (输出不那么冗长,因为只显示了类的一个子集 - 请参阅“targets”参数):

eli5.show_prediction(clf, twenty_test.data[0], vec=vec, target_names=twenty_test.target_names, targets=['sci.med'])

输出结果
它不再使用电子邮件地址,但它看起来仍然不太好:分类器将高权重分配给看似无关的词,如“do”或“my”。 这些词出现在许多文本中,因此分类器可能将它们用作偏差的代理。 或者也许其中一些在某些类别中更常见。

3. Pipeline改进

为了帮助分类器,我们可以过滤掉停用词:

vec = CountVectorizer(stop_words='english')
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.86      0.76      0.81       319
         comp.graphics       0.85      0.94      0.89       389
               sci.med       0.92      0.85      0.88       396
soc.religion.christian       0.86      0.89      0.87       398

              accuracy                           0.87      1502
             macro avg       0.87      0.86      0.86      1502
          weighted avg       0.87      0.87      0.87      1502

accuracy: 0.868
eli5.show_prediction(clf, twenty_test.data[0], vec=vec, target_names=twenty_test.target_names, targets=['sci.med'])

输出结果
从结果上看起来更好了。

或者,可以使用 TF*IDF 方案; 它应该会产生类似的效果。

请注意,在这里交叉验证LogisticRegression正则化参数,就像在其他示例中一样(LogisticRegressionCV,而不是LogisticRegression)。 TF*IDF值不同于单词计数值,因此最佳C值可能不同。 如果使用具有固定正则化强度的分类器,我们可能会得出错误的结论——所选择的C值可能对一种数据更有效。

from sklearn.feature_extraction.text import TfidfVectorizer

vec = TfidfVectorizer()
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.89      0.81      0.85       319
         comp.graphics       0.87      0.95      0.91       389
               sci.med       0.94      0.88      0.91       396
soc.religion.christian       0.89      0.92      0.90       398

              accuracy                           0.89      1502
             macro avg       0.90      0.89      0.89      1502
          weighted avg       0.90      0.89      0.89      1502

accuracy: 0.895
eli5.show_prediction(clf, twenty_test.data[0], vec=vec, target_names=twenty_test.target_names, targets=['sci.med'])

输出结果
它有所帮助,但没有完全相同的效果。 为什么不两者都做?

vec = TfidfVectorizer(stop_words='english')
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.92      0.80      0.86       319
         comp.graphics       0.90      0.96      0.93       389
               sci.med       0.95      0.92      0.93       396
soc.religion.christian       0.89      0.94      0.91       398

              accuracy                           0.91      1502
             macro avg       0.91      0.91      0.91      1502
          weighted avg       0.91      0.91      0.91      1502

accuracy: 0.911
eli5.show_prediction(clf, twenty_test.data[0], vec=vec, target_names=twenty_test.target_names, targets=['sci.med'])

输出结果
这开始看起来不错!

4. 基于字符的pipeline

也许可以通过选择不同的分类器来获得更好的质量,但我们暂时跳过它。 让我们试试其他分析器——使用 char n-grams 代替单词:

vec = TfidfVectorizer(stop_words='english', analyzer='char',
                      ngram_range=(3,5))
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.91      0.80      0.85       319
         comp.graphics       0.86      0.96      0.90       389
               sci.med       0.93      0.88      0.91       396
soc.religion.christian       0.89      0.92      0.91       398

              accuracy                           0.89      1502
             macro avg       0.90      0.89      0.89      1502
          weighted avg       0.90      0.89      0.89      1502

accuracy: 0.895
eli5.show_prediction(clf, twenty_test.data[0], vec=vec, target_names=twenty_test.target_names)

输出结果
它有效,但质量有点差。 此外,训练需要很长时间。
看起来stop_words现在没有效果——事实上,这在scikit-learn文档中有记录,所以我们的stop_words=‘english’ 没有用。 但至少现在更明显的是文本对于基于char ngram的分类器的样子。 稍等片刻,看看char_wb长什么样:

vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(3,5))
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.92      0.79      0.85       319
         comp.graphics       0.89      0.96      0.92       389
               sci.med       0.92      0.89      0.91       396
soc.religion.christian       0.88      0.93      0.91       398

              accuracy                           0.90      1502
             macro avg       0.90      0.90      0.90      1502
          weighted avg       0.90      0.90      0.90      1502

accuracy: 0.900
eli5.show_prediction(clf, twenty_test.data[0], vec=vec, target_names=twenty_test.target_names)

输出结果
结果是相似的,有一些小的变化。 质量更好,原因不明; 也许交叉词依赖并不那么重要。

5. 调试HashingVectorizer

为了检查我们是否可以尝试拟合word n-gram 而不是 char n-gram。 但让我们首先处理效率问题。 为了处理大词汇表,我们可以使用scikit-learn中的HashingVectorizer; 为了加快训练速度,我们可以使用SGDCLassifier:

from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.linear_model import SGDClassifier

vec = HashingVectorizer(stop_words='english', ngram_range=(1,2))
clf = SGDClassifier(max_iter=10, random_state=42)
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.90      0.80      0.85       319
         comp.graphics       0.87      0.96      0.91       389
               sci.med       0.93      0.89      0.91       396
soc.religion.christian       0.89      0.92      0.90       398

              accuracy                           0.90      1502
             macro avg       0.90      0.89      0.89      1502
          weighted avg       0.90      0.90      0.90      1502

accuracy: 0.897

超级快! 不过,我们没有使用交叉验证来选择正则化参数。 让我们检查一下模型学到了什么:

eli5.show_prediction(clf, twenty_test.data[0], vec=vec, target_names=twenty_test.target_names, targets=['sci.med'])

输出结果
结果看起来类似于CountVectorizer。 但是使用HashingVectorizer我们甚至没有词汇! 为什么它有效?

eli5.show_weights(clf, vec=vec, top=10, target_names=twenty_test.target_names)

输出结果
好的,我们没有词汇表,所以我们没有特征名称。 我们运气不好吗? 不,eli5 有一个答案:InvertableHashingVectorizer。 它可用于获取 HahshingVectorizer 的特征名称,而无需使用大量词汇。 它仍然需要一些数据来学习单词 -> 哈希映射; 我们可以使用随机数据子集来拟合它。

from eli5.sklearn import InvertableHashingVectorizer
import numpy as np

ivec = InvertableHashingVectorizer(vec)
sample_size = len(twenty_train.data) // 10
X_sample = np.random.choice(twenty_train.data, size=sample_size)
ivec.fit(X_sample);

eli5.show_weights(clf, vec=ivec, top=20, target_names=twenty_test.target_names)

输出结果
存在collisions (将鼠标悬停在带有“…”的特征上),并且存在随机样本中未看到的重要特征(FEATURE[…]),但总体而言它看起来不错。
“rutgers edu”的二元组特征很可疑,它看起来像是 URL 的一部分。

rutgers_example = [x for x in twenty_train.data if 'rutgers' in x.lower()][0]
print(rutgers_example)
In article <Apr.8.00.57.41.1993.28246@athos.rutgers.edu> REXLEX@fnal.gov writes:
>In article <Apr.7.01.56.56.1993.22824@athos.rutgers.edu> shrum@hpfcso.fc.hp.com
>Matt. 22:9-14 'Go therefore to the main highways, and as many as you find
>there, invite to the wedding feast.'...

>hmmmmmm.  Sounds like your theology and Christ's are at odds. Which one am I 
>to believe?

是的,看起来模型学习了这个地址,而不是学习了一些有用的东西。

eli5.show_prediction(clf, rutgers_example, vec=vec, target_names=twenty_test.target_names, targets=['soc.religion.christian'])

输出结果
引用的文本使模型很容易对某些消息进行分类; 这不会推广到新消息。 因此,为了改进模型,下一步可能是进一步处理数据,例如 删除引用的文本或用特殊标记替换电子邮件地址。

查看特征有助于理解分类器的工作原理。 也许更重要的是,它有助于注意到预处理错误、数据泄漏、任务规范问题——所有这些你在现实世界中遇到的讨厌问题。

参考资料

[1] [Debugging scikit-learn text classification pipeline](Debugging scikit-learn text classification pipeline)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/184851.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学习笔记-----通道

加粗样式# system v共享内存 进程通信的前提条件是&#xff1a;让不同进程看见同一份资源。 共享内存&#xff1a;其实就是进程获取共享区里面的地址&#xff0c;该地址为物理内存中某块我所需要资源的地址(该内存是创建的共享内存处在共享区里)&#xff0c;地址通过页表映射到…

项目工时管理遇难题?看看这套工时管理系统解决方案

随着社会化大生产的发展以及市场竞争的日趋激烈&#xff0c;现代企业的规模在不断扩大。对于项目企业来说&#xff0c;人力资源的成本就是项目的主要成本&#xff0c;而工时是项目中人工成本的重要依据&#xff0c;因此&#xff0c;管理好员工工时是项目管理过程中最重要的任务…

计算机图形学 第3章 圆的扫描转换-第三章结束

书用的是 书名:计算机图形学基础教程&#xff08;VisualC版&#xff09;&#xff08;第二版&#xff09; 定价&#xff1a;44.5元 作者:孔令德 出版社&#xff1a;清华大学出版社 出版日期&#xff1a;2013-03-01 ISBN&#xff1a;9787302297529 目录习题3&#xff08;续&#…

【数据结构】8.3 交换排序

文章目录1. 冒泡排序冒泡排序算法冒泡排序算法分析2. 快速排序快速排序算法快速排序算法分析基本思想 每两个元素之间互相比较&#xff0c;如果发现大小关系相反&#xff0c;则将他们交换过来&#xff0c;直到所有记录都排好序为止。假设希望是从小到大来排序&#xff0c;结果…

Nginx-反向代理配置学习总结

Nginx-反向代理配置学习总结 正向代理&#xff1a;指的是通过代理服务器 代理浏览器/客户端去重定向请求访问到目标服务器 的一种代理服务&#xff0c;正向代理服务的特点是代理服务器 代理的对象是浏览器/客户端&#xff0c;也就是对于目标服务器 来说浏览器/客户端是隐藏的。…

文件的IO

一、文件的定义狭隘的文件:指你的硬盘上的文件和目录.广义的文件:泛指计算机中的硬件资源,操作系统中,把很多硬件设备和软件资源都抽象成了文件,按照文件的形式统一管理.比如网卡,操作系统也是把网卡抽象成了文件资源,所以说操作网卡其实和操作文件的方式是基本一样的.而我们本…

初识流计算框架Spark

Spark简介 Spark最初由美国加州伯克利大学&#xff08;UCBerkeley&#xff09;的AMP&#xff08;Algorithms, Machines and People&#xff09;实验室于2009年开发&#xff0c;是基于内存计算的大数据并行计算框架&#xff0c;可用于构建大型的、低延迟的数据分析应用程序。Sp…

一刷代码随想录——链表

1.理论基础链表节点的定义&#xff1a;struct ListNode {int val;ListNode* next;ListNode() : val(0), next(nullptr) {}ListNode(int x) : val(x), next(nullptr) {}ListNode(int x, ListNode* next) : val(x), next(next) {} };根据卡哥提示&#xff0c;由于力扣中已经给出如…

C++中拷贝构造函数、拷贝赋值运算符、析构函数、移动构造函数、移动赋值运算符(三/五法则)

1、介绍 三五法则是针对C中类的成员和类对象的操作函数。 三法则是指&#xff1a;拷贝构造函数、拷贝赋值运算符、析构函数。 五法则是在三法则的基础上增加了&#xff1a;移动构造函数、移动赋值运算符。 2、拷贝构造函数 定义&#xff1a;如果构造函数的第一个参数是自身…

Postman前置脚本

位置&#xff1a;作用&#xff1a;调用脚本之前需要执行的代码片段一、产生随机数字生成0-1之间的随机数&#xff0c;包括0&#xff0c;不包括1&#xff1b;var random Math.random();console.log("随机数",random);获取最小值到最大值之前的整数随机数function Get…

2019-ICML-Graph U-Nets

2019-ICML-Graph U-Nets Paper: https://arxiv.org/abs/1905.05178 Code: https://github.com/HongyangGao/Graph-U-Nets 图U-Nets 作者将CNN上的U-Net运用到了图分类上&#xff0c;因为我们主题是图分类&#xff0c;就不对U-Net进行论述了&#xff0c;只对其中的gPool&#…

eureka 读写锁的一点思考

读写锁 读写锁一般实现 读读不互斥 读写互斥 写写互斥 读写锁的好处是&#xff0c;面对读多写多的场景会拥有比较好的表现 一般我们会在读操作加上读锁&#xff0c;写操作加上写锁。但是最近我发现eureka 在使用读写锁的时候是相反的&#xff0c; 也就是说在读操作加上了读锁&…

2023最值得入手的运动耳机是哪款、口碑最好的运动蓝牙耳机推荐

不知道有没有和我一样的小伙伴&#xff0c;在运动时特别喜欢听音乐&#xff0c;每次听到一首合适的音乐&#xff0c;感觉运动起来都更有激情和活力了。所以这时候就需要挑选一款舒适的耳机了。别看市面上各种各样的运动耳机很多&#xff0c;但实际上能真正适合运动的少之又少&a…

oss服务端签名后直传分析与代码实现

文章目录1.简介1.1 普通上传方式1.2 服务端签名后直传3.服务端签名后直传文档3.1 用户向应用服务器请求上传Policy和回调。3.2 应用服务器返回上传Policy和签名给用户。3.3 用户使用Post方法向OSS发送文件上传请求。4.实战开发-后端4.1 pom.xml核心配置4.2 application.yml核心…

Java两大工具库:Commons和Guava(2)

您好&#xff0c;我是湘王&#xff0c;这是我的CSDN博客。值此新春佳节&#xff0c;我给您拜年啦&#xff5e;祝您在新的一年中所求皆所愿&#xff0c;所行皆坦途&#xff0c;展宏“兔”&#xff0c;有钱“兔”&#xff0c;多喜乐&#xff0c;常安宁&#xff01;开发中有一类应…

如何在es中查询null值

文章目录1、背景2、需求3、准备数据3.1 创建mapping3.2 插入数据4、查询 name字段为null的数据5、查询address不存在或值直接为null的数据6、参考链接1、背景 在我们向es中写入数据时&#xff0c;有些时候数据写入到es中的是null&#xff0c;或者没有写入这个字段&#xff0c;…

离散数学与组合数学-08谓词逻辑

文章目录离散数学与组合数学-08谓词逻辑8.1 谓词的引入8.1.1 引入谓词逻辑8.1.2 个体词与谓词8.2 量词的引入8.2.1 量词引入8.2.2 个体域符号化8.2.3 量词真值确定8.3 谓词符号化举例8.3.1 示例一8.3.2 示例二8.3.3 示例三8.3.4 示例四8.4 谓词合式公式8.4.1 四类符号8.4.2 项8…

MySQL运维(一)MySQL中的日志、Mysql主从复制

MySQL运维(一)MySQL中的日志、Mysql主从复制 1、MySQL日志 1.1 错误日志 错误日志是 MySQL 中最重要的日志之一&#xff0c;它记录了当 mysqld 启动和停止时&#xff0c;以及服务器在运行过程中发生任何严重错误时的相关信息。当数据库出现任何故障导致无法正常使用时&#…

Elasticsearch 需要了解的都在这

ES选主过程&#xff1f;其实ES的选主过程其实没有很高深的算法加持&#xff0c;启动过程中对接点的ID进行排序&#xff0c;取ID最大节点作为Master节点&#xff0c;那么如果选出来的主节点中存储的元信息不是最新的怎么办&#xff1f;其实他是分了2个步骤做这件事&#xff0c;先…

react 项目 中使用 Dllplugin 打包优化技巧

目录 0.React和DLLPlugin 前言 使用步骤 结果截图 主要说明 0.React和DLLPlugin React 是一个用于构建用户界面的 JavaScript 库。它由 Facebook 开发&#xff0c;现在由 Facebook 和一个由个人开发者和公司组成的社区维护。React 允许开发人员构建可重用的 UI 组件并有…