torchmetrics,一个无敌的 Python 库!

news2025/2/25 17:22:31

21380339fdd5986086353c6c0e81b880.png

更多Python学习内容:ipengtao.com

大家好,今天为大家分享一个无敌的 Python 库 - torchmetrics。

Github地址:https://github.com/Lightning-AI/torchmetrics


在深度学习和机器学习项目中,模型评估是一个至关重要的环节。为了准确地评估模型的性能,开发者通常需要计算各种指标(metrics),如准确率、精确率、召回率、F1 分数等。torchmetrics 是一个用于 PyTorch 的开源库,提供了一组方便且高效的评估指标计算工具。本文将详细介绍 torchmetrics 库,包括其安装方法、主要特性、基本和高级功能,以及实际应用场景,帮助全面了解并掌握该库的使用。

安装

要使用 torchmetrics 库,首先需要安装它。可以通过 pip 工具方便地进行安装。

以下是安装步骤:

pip install torchmetrics

安装完成后,可以通过导入 torchmetrics 库来验证是否安装成功:

import torchmetrics
print("torchmetrics 库安装成功!")

特性

  1. 广泛的指标支持:提供多种评估指标,包括分类、回归、图像处理和生成模型等领域的常用指标。

  2. 模块化设计:指标可以像模块一样轻松集成到 PyTorch Lightning 或任何 PyTorch 项目中。

  3. GPU 加速:支持 GPU 加速,能够高效处理大规模数据。

  4. 易于扩展:用户可以自定义指标并轻松集成到现有项目中。

  5. 高效计算:优化的计算方法,确保在训练过程中实时计算指标,性能开销最小。

基本功能

计算准确率

使用 torchmetrics 库,可以方便地计算分类任务的准确率。

import torch
import torchmetrics

# 创建 Accuracy 指标
accuracy = torchmetrics.Accuracy()

# 模拟预测和真实标签
preds = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])

# 计算准确率
acc = accuracy(preds, target)
print(f"准确率:{acc}")

计算精确率和召回率

torchmetrics 库可以计算分类任务的精确率和召回率。

import torch
import torchmetrics

# 创建 Precision 和 Recall 指标
precision = torchmetrics.Precision(num_classes=4)
recall = torchmetrics.Recall(num_classes=4)

# 模拟预测和真实标签
preds = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])

# 计算精确率和召回率
prec = precision(preds, target)
rec = recall(preds, target)
print(f"精确率:{prec}")
print(f"召回率:{rec}")

计算 F1 分数

torchmetrics 库还可以计算分类任务的 F1 分数。

import torch
import torchmetrics

# 创建 F1 指标
f1 = torchmetrics.F1(num_classes=4)

# 模拟预测和真实标签
preds = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])

# 计算 F1 分数
f1_score = f1(preds, target)
print(f"F1 分数:{f1_score}")

高级功能

自定义指标

torchmetrics 库允许用户自定义指标,以满足特定需求。

import torch
import torchmetrics

class CustomMetric(torchmetrics.Metric):
    def __init__(self):
        super().__init__()
        self.add_state("sum", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        self.sum += torch.sum(preds == target)
        self.count += target.numel()

    def compute(self):
        return self.sum.float() / self.count

# 创建自定义指标
custom_metric = CustomMetric()

# 模拟预测和真实标签
preds = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])

# 计算自定义指标
result = custom_metric(preds, target)
print(f"自定义指标结果:{result}")

与 PyTorch Lightning 集成

torchmetrics 库可以无缝集成到 PyTorch Lightning 中,简化指标计算流程。

import torch
import torchmetrics
import pytorch_lightning as pl
from torch import nn

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(10, 4)
        self.accuracy = torchmetrics.Accuracy()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = nn.functional.cross_entropy(preds, y)
        acc = self.accuracy(preds, y)
        self.log('train_acc', acc)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

# 示例数据
train_data = torch.utils.data.TensorDataset(torch.randn(100, 10), torch.randint(0, 4, (100,)))
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32)

# 训练模型
model = LitModel()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, train_loader)

GPU 加速

torchmetrics 库支持 GPU 加速,可以在 GPU 上高效地计算指标。

import torch
import torchmetrics

# 创建 Accuracy 指标并移动到 GPU
accuracy = torchmetrics.Accuracy().cuda()

# 模拟预测和真实标签并移动到 GPU
preds = torch.tensor([0, 2, 1, 3]).cuda()
target = torch.tensor([0, 1, 2, 3]).cuda()

# 计算准确率
acc = accuracy(preds, target)
print(f"准确率:{acc}")

实际应用场景

图像分类任务中的指标计算

在图像分类任务中,需要计算各种评估指标,如准确率、精确率、召回率等。

import torch
import torchmetrics
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# 加载数据
transform = transforms.Compose([transforms.ToTensor()])
train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# 创建模型和指标
model = models.resnet18(num_classes=10)
accuracy = torchmetrics.Accuracy()

# 训练模型并计算准确率
for inputs, targets in train_loader:
    outputs = model(inputs)
    acc = accuracy(outputs, targets)
    print(f"批次准确率:{acc}")

文本分类任务中的指标计算

在文本分类任务中,需要计算评估指标,如 F1 分数。

import torch
import torchmetrics
from transformers import BertTokenizer, BertForSequenceClassification

# 加载模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# 示例数据
texts = ["I love this!", "This is bad."]
labels = torch.tensor([1, 0])

# 预处理数据
inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
outputs = model(**inputs)

# 创建 F1 指标
f1 = torchmetrics.F1(num_classes=2)

# 计算 F1 分数
preds = torch.argmax(outputs.logits, dim=1)
f1_score = f1(preds, labels)
print(f"F1 分数:{f1_score}")

生成对抗网络(GAN)中的指标计算

在生成对抗网络(GAN)的训练中,需要计算生成图片的质量指标,如 Frechet Inception Distance(FID)。

import torch
import torchmetrics
from torchvision.models import inception_v3
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, TensorDataset

# 创建生成对抗网络(GAN)的生成器模型
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = torch.nn.Linear(100, 128 * 7 * 7)
        self.deconv = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),
            torch.nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x).view(-1, 128, 7, 7)
        return self.deconv(x)

# 创建生成器模型
generator = Generator()

# 创建 FID 指标
fid = torchmetrics.image.fid.FrechetInceptionDistance(feature=64)

# 模拟生成图片和真实图片
latent_vectors = torch.randn(100, 100)
generated_images = generator(latent_vectors)
real_images = torch.randn(100, 1, 28, 28)

# 转换图片为 Inception V3 输入格式
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
generated_images = transform(generated_images)
real_images = transform(real_images)

# 创建 DataLoader
generated_loader = DataLoader(TensorDataset(generated_images), batch_size=32)
real_loader = DataLoader(TensorDataset(real_images), batch_size=32)

# 计算 FID
for gen_batch, real_batch in zip(generated_loader, real_loader):
    fid.update(real_batch[0], gen_batch[0])

fid_value = fid.compute()
print(f"FID 分数:{fid_value}")

总结

torchmetrics 库是一个功能强大且易于使用的评估指标计算工具,能够帮助开发者在深度学习和机器学习项目中高效地计算各种评估指标。通过支持广泛的指标、多种计算模式、GPU 加速和自定义扩展,torchmetrics 库能够满足各种复杂的评估需求。本文详细介绍了 torchmetrics 库的安装方法、主要特性、基本和高级功能,以及实际应用场景。希望本文能帮助大家全面掌握 torchmetrics 库的使用,并在实际项目中发挥其优势。

如果你觉得文章还不错,请大家 点赞、分享、留言 下,因为这将是我持续输出更多优质文章的最强动力!


如果想要系统学习Python、Python问题咨询,或者考虑做一些工作以外的副业,都可以扫描二维码添加微信,围观朋友圈一起交流学习。

b6e9de34fef3fa667b741f36ce203968.gif

我们还为大家准备了Python资料和副业项目合集,感兴趣的小伙伴快来找我领取一起交流学习哦!

a49ac0374b3473b26fe3180dd7267e58.jpeg

往期推荐

历时一个月整理的 Python 爬虫学习手册全集PDF(免费开放下载)

Python基础学习常见的100个问题.pdf(附答案)

学习 数据结构与算法,这是我见过最友好的教程!(PDF免费下载)

Python办公自动化完全指南(免费PDF)

Python Web 开发常见的100个问题.PDF

肝了一周,整理了Python 从0到1学习路线(附思维导图和PDF下载)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1810492.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Springboot结合redis实现关注推送

关注推送 Feed流的模式 Timeline:不做内容筛选,简单的按照内容发布时间排序。常用于好友与关注。例如朋友圈的时间发布排序。 优点:信息全面,不会有缺失。并且实现也相对简单 缺点:信息噪音较多,用户不一定感兴趣,内容获取效率…

打造精细化运维新玩法(三)

实践SLO,概括下就是在相对标准、统一的框架下指导和推动服务质量的数字化建设,形成对组织有价值的数据资产和流程规范。借用在人工智能和机器学习领域的观点,算法的上限受限于数据质量的好坏,所以从源头上建设高质量的数据非常重要…

【电赛】STM32-PID直流减速电机小车【寻迹+避障+跟随】【更新ing】

一.需求分析 1.主控:STM32C8T6(没什么好说的哈哈) 2.电机:JAG25-370电机 【问】为什么要用直流减速电机?? PID控制器需要依靠精确的反馈信号来调整其输出,确保电机按照预定的速度和位置运行…

独立游戏之路:Tap篇 -- Unity 集成 TapTap 广告详细步骤

Unity 集成 TapADN 广告详细步骤 前言一、TapTap 广告介绍二、集成 TapTap 广告的步骤2.1 进入广告后台2.2 创建广告计划2.3 选择广告类型三、代码集成3.1 下载SDK3.2 工程配置3.3 源码分享四、常见问题4.1 有展现量没有预估收益 /eCPM 波动大?4.2 新建正式媒体找不到预约游戏…

介绍Linux

目录 1.什么是操作系统 2.现实生活中的操作系统 3.操作系统的发展史 4.操作系统的发展 Linux的不同版本以及应用领域 1.Linux内核及发行版介绍 <1>Linux内核版本 <2>Linux发行版本 2.应用领域 个⼈桌⾯领域的应⽤ 服务器领域 嵌⼊式领域 3.文件和目录 …

HDFS 读写数据流程

优质博文&#xff1a;IT-BLOG-CN 一、HDFS 写数据流程 HDFS 文件写入流程图如下&#xff1a;三个模块&#xff08;客户端、NameNode、DataNode&#xff09; 【1】校验&#xff1a; 客户端通过 DistributedFileSystem 模块向 NameNode 请求上传文件&#xff0c;NameNode 会检…

Vue 面试通杀秘籍

理论篇&#xff1a; 1. 说说对 Vue 渐进式框架的理解&#xff08;腾讯医典&#xff09; a) 渐进式的含义&#xff1a; 主张最少, 没有多做职责之外的事 b) Vue 有些方面是不如 React&#xff0c;不如 Angular.但它是渐进的&#xff0c;没有强主张&#xff0c; 你可以在原有…

Java面向对象-Object类的toString方法、equals方法

Java面向对象-Object类的toString方法、equals方法 一、toString二、equals三、总结 一、toString Object的toString方法。 方法的原理&#xff1a; 现在使用toString方法的时候&#xff0c;打印出来的内容不友好。 现在想要知道对象的信息。 出现的问题&#xff1a;子类Stu…

SAP Build 2 PDF数据提取与决策树(未完成)

0. 安装desktop agent 在后续过程中发现要预先安装desktop agent&#xff0c;否则没法运行自动化流程… 0.1 agent下载 参考官方文档说明 https://help.sap.com/docs/build-process-automation/sap-build-process-automation/create-user-in-rbsc-download-repository?loca…

AI办公自动化:用Kimi批量在Excel文件名中加入日期

工作任务&#xff1a;在一个文件夹中所有的Excel文件后面加上一个日期 在Kimi中输入提示词&#xff1a; 你是一个Python编程专家&#xff0c;写一个Python脚本&#xff0c;具体步骤如下&#xff1a; 打开文件夹&#xff1a;F:\AI自媒体内容\AI行业数据分析\投融资 读取里面所…

18.2 HTTP服务器-处理函数、响应404错误

1. 处理函数 处理来自客户端的请求&#xff0c;并回之以特定的响应&#xff0c;这是处理函数的主要任务。在处理函数中&#xff0c;我们通常会完成如下工作&#xff1a; 验证请求路径 http.Request.URL.Pathhttp.NotFound(...) 当请求没有对应的处理函数时&#xff0c;返回4…

机器学习笔记:label smoothing

在传统的分类任务中&#xff0c;我们通常使用硬标签&#xff08;hard labels&#xff09; 即如果一个样本属于某个类别&#xff0c;其对应的标签就是一个全0的向量&#xff0c;除了表示这个类别的位置为1。例如&#xff0c;在一个3类分类任务中&#xff0c;某个样本的标签可能是…

【Vue】购物车案例-构建项目

脚手架新建项目 (注意&#xff1a;勾选vuex) 版本说明&#xff1a; vue2 vue-router3 vuex3 vue3 vue-router4 vuex4/pinia vue create vue-cart-demo需要勾选上vuex&#xff0c;由于这个项目只有一个页面&#xff0c;vuex可勾可不勾 将原本src内容清空&#xff0c;替换成教学…

缓存更新策略中级总结

背景 看到好些人在写更新缓存数据代码时&#xff0c;先删除缓存&#xff0c;然后再更新数据库&#xff0c;而后续的操作会把数据再装载的缓存中。然而&#xff0c;这个是逻辑是错误的。试想&#xff0c;两个并发操作&#xff0c;一个是更新操作&#xff0c;另一个是查询操作…

数据结构(常见的排序算法)

1.插入排序 1.1直接插入排序 在[0 end]区间上有序&#xff0c;然后将&#xff08;end1&#xff09;的数据与前面有序的数据进行比较&#xff0c;将&#xff08;end1&#xff09;的数据插入&#xff0c;这样[0 end1]区间上就是有序的&#xff0c;然后再向后进行比较。 例如&a…

VXLAN技术

VXLAN技术 一、VXLAN简介 1、定义 VXLAN&#xff08;Virtual eXtensible Local Area Network&#xff09;&#xff1a;采用MAC in UDP&#xff08;User Datagram Protocol&#xff09;封装方式&#xff0c;是NVO3&#xff08;Network Virtualization over Layer 3&#xff09…

机器学习算法 —— 贝叶斯分类之模拟离散数据集

&#x1f31f;欢迎来到 我的博客 —— 探索技术的无限可能&#xff01; &#x1f31f;博客的简介&#xff08;文章目录&#xff09; 目录 实战&#xff08;贝叶斯分类&#xff09;莺尾花数据模拟离散数据集库函数导入数据导入和分析模型训练和预测 总结 实战&#xff08;贝叶斯…

C语言 | Leetcode C语言题解之第144题二叉树的前序遍历

题目&#xff1a; 题解&#xff1a; int* preorderTraversal(struct TreeNode* root, int* returnSize) {int* res malloc(sizeof(int) * 2000);*returnSize 0;if (root NULL) {return res;}struct TreeNode *p1 root, *p2 NULL;while (p1 ! NULL) {p2 p1->left;if (…

一道Delphi的For循环题目

起因 事情是这样的&#xff1a; 俺在一个Delphi交流QQ群&#xff0c;有点冷场&#xff0c;俺想热一下场子就发了下面这个段子。其实这是之前俺带新人时的一道题目。 第一个回答 第一个网友给的答案是 i:i-1; 俺说这个答案是不对的&#xff0c;因为 Delphi在编译时是不允许…

探索智慧机场运营中心解决方案的价值与应用

随着全球航空业的不断发展&#xff0c;机场运营中心的作用日益凸显。智慧机场运营中心解决方案以其高效的管理和智能化的运营模式&#xff0c;成为优化机场运营、提升服务水平的重要工具。本文将深入探讨智慧机场运营中心解决方案的价值与应用&#xff0c;揭示其在机场管理中的…