AI大模型探索之路-训练篇13:大语言模型Transformer库-Evaluate组件实践

news2024/11/17 20:29:42

系列篇章💥

AI大模型探索之路-训练篇1:大语言模型微调基础认知
AI大模型探索之路-训练篇2:大语言模型预训练基础认知
AI大模型探索之路-训练篇3:大语言模型全景解读
AI大模型探索之路-训练篇4:大语言模型训练数据集概览
AI大模型探索之路-训练篇5:大语言模型预训练数据准备-词元化
AI大模型探索之路-训练篇6:大语言模型预训练数据准备-预处理
AI大模型探索之路-训练篇7:大语言模型Transformer库之HuggingFace介绍
AI大模型探索之路-训练篇8:大语言模型Transformer库-预训练流程编码体验
AI大模型探索之路-训练篇9:大语言模型Transformer库-Pipeline组件实践
AI大模型探索之路-训练篇10:大语言模型Transformer库-Tokenizer组件实践
AI大模型探索之路-训练篇11:大语言模型Transformer库-Model组件实践
AI大模型探索之路-训练篇12:语言模型Transformer库-Datasets组件实践


目录

  • 系列篇章💥
  • 前言
  • 一、Evaluate组件介绍
  • 二、Evaluate API用法
    • 1、安装依赖
    • 2、查看支持的评估函数
    • 3、加载评估函数
    • 4、查看函数说明
    • 5、评估指标计算——全局计算
    • 6、评估指标计算——迭代计算
    • 7、多个评估指标计算
    • 8、评估结果对比可视化
  • 三、改造预训练代码
  • 总结


前言

在自然语言处理(NLP)技术的迅猛发展过程中,基于深度学习的模型逐渐成为了研究和工业界解决语言问题的主流工具。特别是Transformer模型,以其独特的自注意力机制和对长距离依赖的有效捕捉能力,在多个NLP任务中取得了革命性的突破。然而,随着模型变得越来越复杂,如何准确评估模型的性能,理解模型的优势与局限,以及指导进一步的模型优化,成为了一个挑战。
在这里插入图片描述

官网的API:https://huggingface.co/docs/evaluate/index
为了简化这一过程并推广到更广泛的开发者群体,大语言模型Transformer库被设计出来,它不仅提供了训练和部署模型的工具,还包含了Evaluate组件。Evaluate组件的出现,正是为了填补这一空白,使得评估模型性能变得更加直接、高效和标准化。它为研究人员和工程师们提供了一套系统的评价方案,从而能够更加深入地理解和改进他们的模型。接下来,我们将深入探讨Evaluate组件的功能及其在大语言模型Transformer库中的实践应用。

一、Evaluate组件介绍

在自然语言处理的深度学习模型开发流程中,评估(Evaluation)阶段是至关重要的一环。它不仅帮助我们量化模型的性能,还为模型的进一步优化提供了依据。Evaluate组件正是为此而设计,旨在提供一套标准化的评估流程和工具,帮助开发者和研究人员更加系统地分析模型表现。

Evaluate组件在大语言模型Transformer库中扮演着核心角色。它与训练和推理(Inference)组件紧密集成,确保了从模型开发到部署的每一个环节都能被有效监控和评价。以下是Evaluate组件的关键特性:
1)标准化指标:Evaluate组件提供了一系列标准化的评价指标,如准确率(Accuracy)、召回率(Recall)、精确度(Precision)和F1分数(F1 Score),以及更复杂的度量如混淆矩阵(Confusion Matrix)、编辑距离(Edit Distance)等。这些指标为模型性能提供了全面的量化分析。
2)模型-数据接口:该组件设计了明确的接口用于连接模型和评估数据。开发者可以轻松指定待评估的模型以及用于评估的数据集,使得评估过程既灵活又高效。
3)评估报告:通过evaluate函数完成评估后,Evaluate组件能够生成详细的评估报告。报告中包括各项指标的得分,以及模型在不同类别上的表现,甚至可以展现模型预测的实例结果,帮助用户直观理解模型的优势和不足。
4)可视化工具:除了文本形式的报告外,Evaluate组件还提供可视化工具,如混淆矩阵的图表展示。这些图表帮助开发者快速识别模型在特定类型的输入上的表现情况,从而针对性地进行改进。
5)错误分析:对于分类任务,Evaluate组件还能够进行错误分析,指出模型在哪些样本上出错,以及出错的可能原因,这对于调试和改善模型具有重要意义。
6)易于集成和扩展:Evaluate组件的设计允许它轻松集成到现有的模型开发流程中,并且可以根据具体任务需求进行扩展,比如添加新的评价指标或定制报告格式。

二、Evaluate API用法

1、安装依赖

pip install evaluate

2、查看支持的评估函数

import evaluate
evaluate.list_evaluation_modules()

输出

['lvwerra/test',
 'jordyvl/ece',
 'angelina-wang/directional_bias_amplification',
 'cpllab/syntaxgym',
 'lvwerra/bary_score',
 'hack/test_metric',
 'yzha/ctc_eval',
 'codeparrot/apps_metric',
 'mfumanelli/geometric_mean',
 'daiyizheng/valid',
 'erntkn/dice_coefficient',
 'mgfrantz/roc_auc_macro',
 'Vlasta/pr_auc',
 'gorkaartola/metric_for_tp_fp_samples',
 'idsedykh/metric',
 'idsedykh/codebleu2',
 'idsedykh/codebleu',
 'idsedykh/megaglue',
 'cakiki/ndcg',
 'Vertaix/vendiscore',
 'GMFTBY/dailydialogevaluate',
 'GMFTBY/dailydialog_evaluate',
 'jzm-mailchimp/joshs_second_test_metric',
 'ola13/precision_at_k',
 'yulong-me/yl_metric',
 'abidlabs/mean_iou',
 'abidlabs/mean_iou2',
 'KevinSpaghetti/accuracyk',
 'NimaBoscarino/weat',
 'ronaldahmed/nwentfaithfulness',
 'Viona/infolm',
 'kyokote/my_metric2',
 'kashif/mape',
 'Ochiroo/rouge_mn',
 'giulio98/code_eval_outputs',
 'leslyarun/fbeta_score',
 'giulio98/codebleu',
 'anz2/iliauniiccocrevaluation',
 'zbeloki/m2',
 'xu1998hz/sescore',
 'dvitel/codebleu',
 'NCSOFT/harim_plus',
 'JP-SystemsX/nDCG',
 'sportlosos/sescore',
 'Drunper/metrica_tesi',
 'jpxkqx/peak_signal_to_noise_ratio',
 'jpxkqx/signal_to_reconstruction_error',
 'hpi-dhc/FairEval',
 'lvwerra/accuracy_score',
 'ybelkada/cocoevaluate',
 'harshhpareek/bertscore',
 'posicube/mean_reciprocal_rank',
 'bstrai/classification_report',
 'omidf/squad_precision_recall',
 'Josh98/nl2bash_m',
 'BucketHeadP65/confusion_matrix',
 'BucketHeadP65/roc_curve',
 'yonting/average_precision_score',
 'transZ/test_parascore',
 'transZ/sbert_cosine',
 'hynky/sklearn_proxy',
 'xu1998hz/sescore_english_mt',
 'xu1998hz/sescore_german_mt',
 'xu1998hz/sescore_english_coco',
 'xu1998hz/sescore_english_webnlg',
 'unnati/kendall_tau_distance',
 'Viona/fuzzy_reordering',
 'Viona/kendall_tau',
 'lhy/hamming_loss',
 'lhy/ranking_loss',
 'Muennighoff/code_eval_octopack',
 'yuyijiong/quad_match_score',
 'Splend1dchan/cosine_similarity',
 'AlhitawiMohammed22/CER_Hu-Evaluation-Metrics',
 'Yeshwant123/mcc',
 'transformersegmentation/segmentation_scores',
 'sma2023/wil',
 'chanelcolgate/average_precision',
 'ckb/unigram',
 'Felipehonorato/eer',
 'manueldeprada/beer',
 'tialaeMceryu/unigram',
 'shunzh/apps_metric',
 'He-Xingwei/sari_metric',
 'langdonholmes/cohen_weighted_kappa',
 'fschlatt/ner_eval',
 'hyperml/balanced_accuracy',
 'brian920128/doc_retrieve_metrics',
 'guydav/restrictedpython_code_eval',
 'k4black/codebleu',
 'Natooz/ece',
 'ingyu/klue_mrc',
 'Vipitis/shadermatch',
 'unitxt/metric',
 'gabeorlanski/bc_eval',
 'jjkim0807/code_eval',
 'vichyt/metric-codebleu',
 'repllabs/mean_reciprocal_rank',
 'repllabs/mean_average_precision',
 'mtc/fragments',
 'DarrenChensformer/eval_keyphrase',
 'kedudzic/charmatch',
 'Vallp/ter',
 'DarrenChensformer/relation_extraction',
 'Ikala-allen/relation_extraction',
 'danieldux/hierarchical_softmax_loss',
 'nlpln/tst',
 'bdsaglam/jer',
 'fnvls/bleu1234',
 'fnvls/bleu_1234',
 'nevikw39/specificity',
 'yqsong/execution_accuracy',
 'shalakasatheesh/squad_v2',
 'arthurvqin/pr_auc',
 'd-matrix/dmx_perplexity',
 'lvwerra/test',
 'jordyvl/ece',
 'angelina-wang/directional_bias_amplification',
 'cpllab/syntaxgym',
 'lvwerra/bary_score',
 'hack/test_metric',
 'yzha/ctc_eval',
 'codeparrot/apps_metric',
 'mfumanelli/geometric_mean',
 'daiyizheng/valid',
 'erntkn/dice_coefficient',
 'mgfrantz/roc_auc_macro',
 'Vlasta/pr_auc',
 'gorkaartola/metric_for_tp_fp_samples',
 'idsedykh/metric',
 'idsedykh/codebleu2',
 'idsedykh/codebleu',
 'idsedykh/megaglue',
 'cakiki/ndcg',
 'Vertaix/vendiscore',
 'GMFTBY/dailydialogevaluate',
 'GMFTBY/dailydialog_evaluate',
 'jzm-mailchimp/joshs_second_test_metric',
 'ola13/precision_at_k',
 'yulong-me/yl_metric',
 'abidlabs/mean_iou',
 'abidlabs/mean_iou2',
 'KevinSpaghetti/accuracyk',
 'NimaBoscarino/weat',
 'ronaldahmed/nwentfaithfulness',
 'Viona/infolm',
 'kyokote/my_metric2',
 'kashif/mape',
 'Ochiroo/rouge_mn',
 'giulio98/code_eval_outputs',
 'leslyarun/fbeta_score',
 'giulio98/codebleu',
 'anz2/iliauniiccocrevaluation',
 'zbeloki/m2',
 'xu1998hz/sescore',
 'dvitel/codebleu',
 'NCSOFT/harim_plus',
 'JP-SystemsX/nDCG',
 'sportlosos/sescore',
 'Drunper/metrica_tesi',
 'jpxkqx/peak_signal_to_noise_ratio',
 'jpxkqx/signal_to_reconstruction_error',
 'hpi-dhc/FairEval',
 'lvwerra/accuracy_score',
 'ybelkada/cocoevaluate',
 'harshhpareek/bertscore',
 'posicube/mean_reciprocal_rank',
 'bstrai/classification_report',
 'omidf/squad_precision_recall',
 'Josh98/nl2bash_m',
 'BucketHeadP65/confusion_matrix',
 'BucketHeadP65/roc_curve',
 'yonting/average_precision_score',
 'transZ/test_parascore',
 'transZ/sbert_cosine',
 'hynky/sklearn_proxy',
 'xu1998hz/sescore_english_mt',
 'xu1998hz/sescore_german_mt',
 'xu1998hz/sescore_english_coco',
 'xu1998hz/sescore_english_webnlg',
 'unnati/kendall_tau_distance',
 'Viona/fuzzy_reordering',
 'Viona/kendall_tau',
 'lhy/hamming_loss',
 'lhy/ranking_loss',
 'Muennighoff/code_eval_octopack',
 'yuyijiong/quad_match_score',
 'Splend1dchan/cosine_similarity',
 'AlhitawiMohammed22/CER_Hu-Evaluation-Metrics',
 'Yeshwant123/mcc',
 'transformersegmentation/segmentation_scores',
 'sma2023/wil',
 'chanelcolgate/average_precision',
 'ckb/unigram',
 'Felipehonorato/eer',
 'manueldeprada/beer',
 'tialaeMceryu/unigram',
 'shunzh/apps_metric',
 'He-Xingwei/sari_metric',
 'langdonholmes/cohen_weighted_kappa',
 'fschlatt/ner_eval',
 'hyperml/balanced_accuracy',
 'brian920128/doc_retrieve_metrics',
 'guydav/restrictedpython_code_eval',
 'k4black/codebleu',
 'Natooz/ece',
 'ingyu/klue_mrc',
 'Vipitis/shadermatch',
 'unitxt/metric',
 'gabeorlanski/bc_eval',
 'jjkim0807/code_eval',
 'vichyt/metric-codebleu',
 'repllabs/mean_reciprocal_rank',
 'repllabs/mean_average_precision',
 'mtc/fragments',
 'DarrenChensformer/eval_keyphrase',
 'kedudzic/charmatch',
 'Vallp/ter',
 'DarrenChensformer/relation_extraction',
 'Ikala-allen/relation_extraction',
 'danieldux/hierarchical_softmax_loss',
 'nlpln/tst',
 'bdsaglam/jer',
 'fnvls/bleu1234',
 'fnvls/bleu_1234',
 'nevikw39/specificity',
 'yqsong/execution_accuracy',
 'shalakasatheesh/squad_v2',
 'arthurvqin/pr_auc',
 'd-matrix/dmx_perplexity',
 'lvwerra/test',
 'jordyvl/ece',
 'angelina-wang/directional_bias_amplification',
 'cpllab/syntaxgym',
 'lvwerra/bary_score',
 'hack/test_metric',
 'yzha/ctc_eval',
 'codeparrot/apps_metric',
 'mfumanelli/geometric_mean',
 'daiyizheng/valid',
 'erntkn/dice_coefficient',
 'mgfrantz/roc_auc_macro',
 'Vlasta/pr_auc',
 'gorkaartola/metric_for_tp_fp_samples',
 'idsedykh/metric',
 'idsedykh/codebleu2',
 'idsedykh/codebleu',
 'idsedykh/megaglue',
 'cakiki/ndcg',
 'Vertaix/vendiscore',
 'GMFTBY/dailydialogevaluate',
 'GMFTBY/dailydialog_evaluate',
 'jzm-mailchimp/joshs_second_test_metric',
 'ola13/precision_at_k',
 'yulong-me/yl_metric',
 'abidlabs/mean_iou',
 'abidlabs/mean_iou2',
 'KevinSpaghetti/accuracyk',
 'NimaBoscarino/weat',
 'ronaldahmed/nwentfaithfulness',
 'Viona/infolm',
 'kyokote/my_metric2',
 'kashif/mape',
 'Ochiroo/rouge_mn',
 'giulio98/code_eval_outputs',
 'leslyarun/fbeta_score',
 'giulio98/codebleu',
 'anz2/iliauniiccocrevaluation',
 'zbeloki/m2',
 'xu1998hz/sescore',
 'dvitel/codebleu',
 'NCSOFT/harim_plus',
 'JP-SystemsX/nDCG',
 'sportlosos/sescore',
 'Drunper/metrica_tesi',
 'jpxkqx/peak_signal_to_noise_ratio',
 'jpxkqx/signal_to_reconstruction_error',
 'hpi-dhc/FairEval',
 'lvwerra/accuracy_score',
 'ybelkada/cocoevaluate',
 'harshhpareek/bertscore',
 'posicube/mean_reciprocal_rank',
 'bstrai/classification_report',
 'omidf/squad_precision_recall',
 'Josh98/nl2bash_m',
 'BucketHeadP65/confusion_matrix',
 'BucketHeadP65/roc_curve',
 'yonting/average_precision_score',
 'transZ/test_parascore',
 'transZ/sbert_cosine',
 'hynky/sklearn_proxy',
 'xu1998hz/sescore_english_mt',
 'xu1998hz/sescore_german_mt',
 'xu1998hz/sescore_english_coco',
 'xu1998hz/sescore_english_webnlg',
 'unnati/kendall_tau_distance',
 'Viona/fuzzy_reordering',
 'Viona/kendall_tau',
 'lhy/hamming_loss',
 'lhy/ranking_loss',
 'Muennighoff/code_eval_octopack',
 'yuyijiong/quad_match_score',
 'Splend1dchan/cosine_similarity',
 'AlhitawiMohammed22/CER_Hu-Evaluation-Metrics',
 'Yeshwant123/mcc',
 'transformersegmentation/segmentation_scores',
 'sma2023/wil',
 'chanelcolgate/average_precision',
 'ckb/unigram',
 'Felipehonorato/eer',
 'manueldeprada/beer',
 'tialaeMceryu/unigram',
 'shunzh/apps_metric',
 'He-Xingwei/sari_metric',
 'langdonholmes/cohen_weighted_kappa',
 'fschlatt/ner_eval',
 'hyperml/balanced_accuracy',
 'brian920128/doc_retrieve_metrics',
 'guydav/restrictedpython_code_eval',
 'k4black/codebleu',
 'Natooz/ece',
 'ingyu/klue_mrc',
 'Vipitis/shadermatch',
 'unitxt/metric',
 'gabeorlanski/bc_eval',
 'jjkim0807/code_eval',
 'vichyt/metric-codebleu',
 'repllabs/mean_reciprocal_rank',
 'repllabs/mean_average_precision',
 'mtc/fragments',
 'DarrenChensformer/eval_keyphrase',
 'kedudzic/charmatch',
 'Vallp/ter',
 'DarrenChensformer/relation_extraction',
 'Ikala-allen/relation_extraction',
 'danieldux/hierarchical_softmax_loss',
 'nlpln/tst',
 'bdsaglam/jer',
 'fnvls/bleu1234',
 'fnvls/bleu_1234',
 'nevikw39/specificity',
 'yqsong/execution_accuracy',
 'shalakasatheesh/squad_v2',
 'arthurvqin/pr_auc',
 'd-matrix/dmx_perplexity']

3、加载评估函数

"""
准确率(accuracy)是一种评估分类模型性能的指标。它是正确预测的数量与总预测数量的比值。
对于二分类问题,其计算公式为(TP+TN) / (TP+TN+FP+FN),其中:

TP(True Positive):真正,实际为正样本,预测也为正样本
TN(True Negative):真负,实际为负样本,预测也为负样本
FP(False Positive):假正,实际为负样本,预测却为正样本
FN(False Negative):假负,实际为正样本,预测却为负样本

"""

accuracy = evaluate.load("accuracy")

在这里插入图片描述

4、查看函数说明

1)打印出准确率(accuracy)的描述信息

print(accuracy.description)

输出

Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with:
Accuracy = (TP + TN) / (TP + TN + FP + FN)
 Where:
TP: True positive
TN: True negative
FP: False positive
FN: False negative

2)打印出准确率(accuracy)的输入描述信息

print(accuracy.inputs_description)

输出

Args:
    predictions (`list` of `int`): Predicted labels.
    references (`list` of `int`): Ground truth labels.
    normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True.
    sample_weight (`list` of `float`): Sample weights Defaults to None.

Returns:
    accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.. A higher score means higher accuracy.

Examples:

    Example 1-A simple example
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
        >>> print(results)
        {'accuracy': 0.5}

    Example 2-The same as Example 1, except with `normalize` set to `False`.
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False)
        >>> print(results)
        {'accuracy': 3.0}

    Example 3-The same as Example 1, except with `sample_weight` set.
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4])
        >>> print(results)
        {'accuracy': 0.8778625954198473}

3)查看:accuracy

accuracy

输出

EvaluationModule(name: "accuracy", module_type: "metric", features: {'predictions': Value(dtype='int32', id=None), 'references': Value(dtype='int32', id=None)}, usage: """
Args:
    predictions (`list` of `int`): Predicted labels.
    references (`list` of `int`): Ground truth labels.
    normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True.
    sample_weight (`list` of `float`): Sample weights Defaults to None.

Returns:
    accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.. A higher score means higher accuracy.

Examples:

    Example 1-A simple example
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
        >>> print(results)
        {'accuracy': 0.5}

    Example 2-The same as Example 1, except with `normalize` set to `False`.
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False)
        >>> print(results)
        {'accuracy': 3.0}

    Example 3-The same as Example 1, except with `sample_weight` set.
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4])
        >>> print(results)
        {'accuracy': 0.8778625954198473}
""", stored examples: 0)

5、评估指标计算——全局计算

一次性计算准确率(accuracy)。具体来说,它首先加载了一个名为"accuracy"的评估对象,然后使用compute方法计算准确率。在调用compute方法时,传入了两个列表作为参数:references和predictions。其中,references表示参考值,predictions表示预测值。最后,将计算结果存储在results变量中并返回。

accuracy = evaluate.load("accuracy")
results = accuracy.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
results

{‘accuracy’: 0.5}

6、评估指标计算——迭代计算

1)采用循环迭代计算准确率(accuracy)

accuracy = evaluate.load("accuracy")
for ref, pred in zip([0,1,0,1], [1,0,0,1]):
    accuracy.add(references=ref, predictions=pred)
accuracy.compute()

{‘accuracy’: 0.5}

2)使用add_batch方法将每个元组添加到评估对象中。最后,调用compute方法计算准确率并返回结果

accuracy = evaluate.load("accuracy")
for refs, preds in zip([[0,1],[0,1]], [[1,0],[0,1]]):
    accuracy.add_batch(references=refs, predictions=preds)
accuracy.compute()

{‘accuracy’: 0.5}

7、多个评估指标计算

"""
F1分数(F1 Score)是精确率(Precision)和召回率(Recall)的调和平均值。它综合考虑了模型的精确率和召回率,用于评估模型的整体性能。
其数学公式为:2 * (Precision * Recall) / (Precision + Recall)。
举个例子来解释一下:
假设我们有一个用来预测疾病的模型,我们测试了100个人,其中实际有病的人有50人,没有病的人有50人。
这个模型预测出有病的人有60人,其中实际有病的人有40人(即真正例TP),预测病人中实际健康的人有20人(即假正例FP);
它预测的健康人有40人,其中实际健康的人有30人(即真负例TN),预测健康中实际有病的人有10人(即假负例FN)。
此时,精确率Precision(预测为正的样本中实际为正的比例)为TP / (TP + FP) = 40 / (40 + 20) = 0.67,
召回率Recall(实际为正的样本中预测为正的比例)为TP / (TP + FN) = 40 / (40 + 10) = 0.8。
将精确率和召回率代入F1分数的公式,我们得到F1 = 2 * (0.67 * 0.8) / (0.67 + 0.8) = 0.73。这个值越接近1,表示模型的性能越好。

"""
clf_metrics = evaluate.combine(["accuracy", "f1", "recall", "precision"])
clf_metrics

在这里插入图片描述

#计算预测结果的准确率、F1分数、召回率和精确度
clf_metrics.compute(predictions=[0, 1, 0], references=[0, 1, 1])

输出

{'accuracy': 0.6666666666666666,
 'f1': 0.6666666666666666,
 'recall': 0.5,
 'precision': 1.0}

8、评估结果对比可视化

from evaluate.visualization import radar_plot

data = [
   {"accuracy": 0.99, "precision": 0.8, "f1": 0.95, "latency_in_seconds": 33.6},
   {"accuracy": 0.98, "precision": 0.87, "f1": 0.91, "latency_in_seconds": 11.2},
   {"accuracy": 0.98, "precision": 0.78, "f1": 0.88, "latency_in_seconds": 87.6}, 
   {"accuracy": 0.88, "precision": 0.78, "f1": 0.81, "latency_in_seconds": 101.6}
   ]
model_names = ["Model 1", "Model 2", "Model 3", "Model 4"]

plot = radar_plot(data=data, model_names=model_names)

在这里插入图片描述

三、改造预训练代码

改造前面篇章《AI大模型探索之路-训练篇8:大语言模型Transformer库-预训练流程编码体验》中的第七步中的评估的代码;原来代码如下:

def evaluate():
    ## 将模型设置为评估模式
    model.eval()
    acc_num=0
    #将训练模型转化为推理模型,模型将使用转换后的推理模式进行评估
    with torch.inference_mode():
        for batch in validloader:
            ## 检查是否有可用的GPU,如果有,则将数据批次转移到GPU上进行加速
            if torch.cuda.is_available():
                batch = {k: v.cuda() for k,v in batch.items()}
            ##对数据批次进行前向传播,得到模型的输出
            output = model(**batch)
            ## 对模型输出进行预测,通过torch.argmax选择概率最高的类别。
            pred = torch.argmax(output.logits,dim=-1)
            ## 计算正确预测的数量,将预测值与标签进行比较,并使用.float()将比较结果转换为浮点数,使用.sum()进行求和操作
            acc_num += (pred.long() == batch["labels"].long()).float().sum()
    ## 返回正确预测数量与验证集样本数量的比值,这表示模型在验证集上的准确率
    return acc_num / len(validset)

原来是通过计算正确预测的数量与验证集样本数量的比值来得到模型在验证集上的准确率
改造后代码如下:改造后的代码使用了evaluate库来组合多个评估指标,并计算它们的值。

import evaluate
clf_metrics = evaluate.combine(["accuracy", "f1"])

def evaluate():
    model.eval()
    with torch.inference_mode():
        for batch in validloader:
            if torch.cuda.is_available():
                batch = {k: v.cuda() for k, v in batch.items()}
            output = model(**batch)
            pred = torch.argmax(output.logits, dim=-1)
            clf_metrics.add_batch(predictions=pred.long(), references=batch["labels"].long())
    return clf_metrics.compute()

总结

在本文中,我们详细介绍了大语言模型Transformer库中的Evaluate组件。Evaluate组件的出现填补了评估模型性能的空白,使得评估过程更加直接、高效和标准化。它为研究人员和工程师们提供了一套系统的评价方案,从而能够更加深入地理解和改进他们的模型。
Evaluate组件在大语言模型Transformer库中扮演着核心角色。它与训练和推理(Inference)组件紧密集成,确保了从模型开发到部署的每一个环节都能被有效监控和评价

在这里插入图片描述

🎯🔖更多专栏系列文章:AIGC-AI大模型探索之路

如果文章内容对您有所触动,别忘了点赞、⭐关注,收藏!加入我,让我们携手同行AI的探索之旅,一起开启智能时代的大门!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1642789.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2021-11-10 51蛋骗鸡单数码0-9按键条件速度逢9蜂鸣

缘由用Proteus设计单片机控制1只数码管&#xff0c;循环显示0~9。循环显示间隔可手动按钮调整。每当到9时&#xff0c;发出声音提示-编程语言-CSDN问答 #include "REG52.h" #include<intrins.h> sbit K1 P3^0; sbit K2 P3^1; sbit K3 P3^2; sbit K4 P3^3;…

论文笔记:(Security 22) 关于“二进制函数相似性检测”的调研

个人博客链接 注&#xff1a;部分内容参考自GPT生成的内容 [Security 22] 关于”二进制函数相似性检测“的调研&#xff08;个人阅读笔记&#xff09; 论文&#xff1a;《How Machine Learning Is Solving the Binary Function Similarity Problem》&#xff08;Usenix Securi…

【面试经典 150 | 数组】文本左右对齐

文章目录 写在前面Tag题目来源解题思路方法一&#xff1a;模拟 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法&#xff0c;两到三天更新一篇文章&#xff0c;欢迎催更…… 专栏内容以分析题目为主&#xff0c;并附带一些对于本题涉及到的数据结构等内容进行回顾…

Git系列:config 配置

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

自定义数据上的YOLOv9分割训练

原文地址&#xff1a;yolov9-segmentation-training-on-custom-data 2024 年 4 月 16 日 在飞速发展的计算机视觉领域&#xff0c;物体分割在从图像中提取有意义的信息方面起着举足轻重的作用。在众多分割算法中&#xff0c;YOLOv9 是一种稳健且适应性强的解决方案&#xff0…

车载电子电器架构 —— 关于bus off汇总

车载电子电器架构 —— 关于bus off汇总 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证明…

【C++题解】1660. 今天要上课吗

问题&#xff1a;1660. 今天要上课吗 类型&#xff1a;分支结构 题目描述&#xff1a; 暑假来了&#xff0c;晶晶报了自己心仪已久的游泳课&#xff0c;非常开心&#xff0c;老师告诉晶晶每周周一、周三、周五、周六四天都要上课的&#xff0c;晶晶担心自己会忘记&#xff0c…

AI-数学-高中52-离散型随机变量概念及其分布列、两点分布

原作者视频&#xff1a;【随机变量】【一数辞典】2离散型随机变量及其分布列_哔哩哔哩_bilibili 离散型随机变量分布列&#xff1a;X表示离散型随机变量可能在取值&#xff0c;P:对应分布在概率&#xff0c;P括号里X1表示事件的名称。 示例&#xff1a;

近似消息传递算法(AMP)单测量模型(SMV)

1、算法解决问题 很多人致力于解决SLM模型的求逆问题&#xff0c;即知道观测值和测量矩阵&#xff08;字典之类的&#xff09;&#xff0c;要求未知变量的值。SLM又叫做标准线性模型&#xff0c;后续又在此基础上进行升级变为广义线性模型。即SLM是yAxe&#xff0c;这里是线性…

循环神经网络完整实现(Pytorch 13)

一 循环神经网络的从零开始实现 从头开始基于循环神经网络实现字符级语言模型。 %matplotlib inline import math import torch from torch import nn from torch.nn import functional as F from d2l import torch as d2lbatch_size, num_steps 32, 35 train_iter, vocab …

【ARM Cortex-M3指南】3:Cortex-M3基础

文章目录 三、Cortex-M3基础3.1 寄存器3.1.1 通用目的寄存器 R0~R73.1.2 通用目的寄存器 R8~R123.1.3 栈指针 R133.1.4 链接寄存器 R143.1.5 程序计数器 R15 3.2 特殊寄存器3.2.1 程序状态寄存器3.2.2 PRIMASK、FAULTMASK和BASEPRI寄存器3.2.3 控制寄存器 3.3 操作模式3.4 异常…

缓冲流,BufferReader,BufferWriter,案例

IO流的体系 字节缓冲流的作用 提高字节流读取数据的性能 *原理&#xff1a;字节缓冲输入流自带了8Kb的缓冲池&#xff0c;字节缓冲输出流也自带了8kb的缓冲池 构造器说明public BufferedInputStream(InputStream is)把低级的字节输入流包装成一个高级的缓冲字节输入流&#…

RabbitMQ之顺序消费

什么是顺序消费 例如&#xff1a;业务上产生者发送三条消息&#xff0c; 分别是对同一条数据的增加、修改、删除操作&#xff0c; 如果没有保证顺序消费&#xff0c;执行顺序可能变成删除、修改、增加&#xff0c;这就乱了。 如何保证顺序性 一般我们讨论如何保证消息的顺序性&…

【Linux】进程exec函数族以及守护进程

一.exec函数族 1.exec函数族的应用 在shell下敲shell的命令都是在创建shell的子进程。而我们之前学的创建父进程和子进程代码内容以及通过pid与0的关系来让父子进程执行不同的代码内容都是在一个代码文件里面&#xff0c;而shell是如何做到不在一个文件里面写代码使之成为子进…

4- 29

五六月安排 5.12江苏CPC 6.2、6.16、6.30三场百度之星省赛 6月蓝桥杯国赛 7.15 睿抗编程赛道省赛 5 6月两个科创需要申请完软著。 网络技术挑战赛过了资格赛&#xff0c;下面不知道怎么搞&#xff0c;如果参加需要花费很多的时间。 1.100个英语单词一篇阅读&#xff0c;讲了文…

Docker Compose 部署若依前后端分离版

准备一台服务器 本次使用虚拟机&#xff0c;虚拟机系统 Ubuntu20.04&#xff0c;内存 4G&#xff0c;4核。 确保虚拟机能连接互联网。 Ubuntu20.04 安装 Docker 添加 Docker 的官方 GPG key&#xff1a; sudo apt-get update sudo apt-get install ca-certificates curl su…

Hibernate的QBC与HQL查询

目录 1、Hibernate的QBC查询 2、Hibernate的HQL查询 3、NatvieSQL原生查询 1、Hibernate的QBC查询 Hibernate具有一个直观的、可扩展的条件查询API public class Test { /** * param args */ public static void main(String[] args) { Session sessio…

【八股】AQS,ReentrantLock实现原理

AQS 概念 AQS 的全称是 AbstractQueuedSynchronized &#xff08;抽象队列同步器&#xff09;&#xff0c;在java.util.concurrent.locks包下面。 AQS是一个抽象类&#xff0c;主要用来构建锁和同步器&#xff0c;比如ReentrantLock, Semaphore, CountDownLatch&#xff0c;里…

安卓LayoutParams浅析

目录 前言一、使用 LayoutParams 设置宽高二、不设置 LayoutParams2.1 TextView 的 LayoutParams2.2 LinearLayout 的 LayoutParams 三、getLayoutParams 的使用四、setLayoutParams 的作用五、使用 setWidth/setHeight 设置宽高 前言 先来看一个简单的布局&#xff0c;先用 x…

Jackson-jr 对比 Jackson

关于Jackson-jr 对比 Jackson 的内容&#xff0c;有人在做了一张下面的图。 简单点来说就 Jackson-jr 是Jackson 的轻量级应用&#xff0c;因为我们在很多时候都用不到 Jackson 的很多复杂功能。 对很多应用来说&#xff0c;我们可能只需要使用简单的 JSON 读写即可。 如我们…