如何使用Pytorch-Metric-Learning?

news2024/9/20 12:36:26

文章目录

  • 如何使用Pytorch-Metric-Learning?
    • 1.Pytorch-Metric-Learning库9个模块的功能
      • 1.1 Sampler模块
      • 1.2 Miner模块
      • 1.3 Loss模块
      • 1.4 Reducer模块
      • 1.5 Distance模块
      • 1.6 Regularizer模块
      • 1.7 Trainer模块
      • 1.8 Tester模块
      • 1.9 Utils模块
    • 2.如何使用PyTorch Metric Learning库中的Loss?
      • 2.1 使用方法
      • 2.2 调用原理
    • 3.根据PyTorch Metric Learning库自定义Loss
      • 3.1 通用Loss模板(简单版本)
      • 3.2 通用Loss模板(使用distances和reducers版本)
      • 3.3 使用`indices_tuple`注意项
    • 参考

如何使用Pytorch-Metric-Learning?

PyTorch-Metric-Learning 是一个基于 PyTorch 的开源库,专门用于度量学习(Metric Learning)的实现和研究。度量学习是一类机器学习任务,旨在学习一个距离函数,使得相似的样本在特征空间中靠得更近,而不相似的样本更远。该库包含9个模块(可用模块概览,点击查看),每个模块都可在现有的代码库中独立使用,或者组合起来完成完整的训练和测试工作流。

1.Pytorch-Metric-Learning库9个模块的功能

high_level_module_overview

通常流程:Your Data --> Sampler --> Miner --> Loss --> Reducer --> Final loss value

1.1 Sampler模块

采样器只是torch.utils.data.Sampler类的扩展,即它们被传递给PyTorch数据加载器(特别是作为采样器参数,除非另有说明)。用于确定批次应如何形成

1.2 Miner模块

挖掘函数(Mining functions)接收一个包含 n 个嵌入的批次,并返回 k pairs/triplets,用于计算损失

  • Pair miners输出一个大小为 4 的元组:(anchors, positives, anchors, negatives)。
  • Triplet miners输出一个大小为 3 的元组:(anchors, positives, negatives)。
  • 如果没有使用元组挖掘函数,损失函数将默认使用批次中的所有可能的对/三元组。
from pytorch_metric_learning import miners, losses
miner_func = miners.SomeMiner()
loss_func = losses.SomeLoss()
miner_output = miner_func(embeddings, labels)
losses = loss_func(embeddings, labels, miner_output)

1.3 Loss模块

Loss模块中,包含了许多的loss函数。loss函数的用法如下:

from pytorch_metric_learning import losses
loss_func = losses.SomeLoss()  # 实例化想要的loss
loss = loss_func(embeddings, labels)  # 根据loss的compute_loss传入相对应的参数来计算损失

loss与miner结合使用的用法如下:

from pytorch_metric_learning import miners
miner_func = miners.SomeMiner()  # 实例化miner
loss_func = losses.SomeLoss()  # 实例化loss
miner_output = miner_func(embeddings, labels)  # 计算损失
loss = loss_func(embeddings, labels, miner_output)  # 计算损失

对于某些损失,如果已经传入了pair/triplet索引,则不需要传入标签:

loss = loss_func(embeddings, indices_tuple=pairs)
# 也适用于ref_emb
loss = loss_func(embeddings, indices_tuple=pairs, ref_emb=ref_emb)

也可以使用reducer指定如何将loss减少到单个值:

from pytorch_metric_learning import reducers
reducer = reducers.SomeReducer()
loss_func = losses.SomeLoss(reducer=reducer)
loss = loss_func(embeddings, labels)

对于元组损失,可以将锚(anchors)的来源和正/负(positives/negatives)分开:

loss_func = losses.SomeLoss()
# anchors will come from embeddings
# positives/negatives will come from ref_emb
loss = loss_func(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)

1.4 Reducer模块

reducer指定如何从多个loss值变为单个loss值。例如,ContrastiveLoss计算批次中每个正对和负对的损失。reducer将获取所有这些每对loss,并将它们减少到单个值。reducer的使用是将其传入到损失函数中,如下所示:

from pytorch_metric_learning import losses, reducers
reducer = reducers.SomeReducer()
loss_func = losses.SomeLoss(reducer=reducer)
loss = loss_func(embeddings, labels) 

原理:在内部,loss函数创建一个包含loss和其他信息的字典。reducer接受这个字典,执行reducer,并返回一个可以调用.backward()的值。

1.5 Distance模块

Distance用于计算输入嵌入之间的成对距离/相似性。下面以TripletMarginLoss损失为例,解释其功能与用途:

from pytorch_metric_learning.losses import TripletMarginLoss
loss_func = TripletMarginLoss(margin=0.2)

该损失函数试图最小化 [ d a p − d a n + m a r g i n ] + [\mathrm{d_{ap}-d_{an}+margin}]_{+} [dapdan+margin]+。通常, d a p d_{ap} dap d a n d_{an} dan表示欧几里得或L2距离。但是如果我们想使用平方L2距离,或者非归一化的L1距离,或者像信噪比这样完全不同的距离度量呢?使用Distance模块,可以轻松尝试这些想法:

### TripletMarginLoss with squared L2 distance ###
from pytorch_metric_learning.distances import LpDistance
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(power=2))

### TripletMarginLoss with unnormalized L1 distance ###
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(normalize_embeddings=False, p=1))

### TripletMarginLoss with signal-to-noise ratio###
from pytorch_metric_learning.distances import SNRDistance
loss_func = TripletMarginLoss(margin=0.2, distance=SNRDistance())

### TripletMarginLoss with cosine similarity##
from pytorch_metric_learning.distances import CosineSimilarity
loss_func = TripletMarginLoss(margin=0.2, distance=CosineSimilarity())

所有losses, miners, 和 regularizers都接受Distance参数。

1.6 Regularizer模块

Regularizer应用于权重和嵌入,而不需要标签或元组。 下面是一个将权重正则化器传递给损失函数的示例。

from pytorch_metric_learning import losses, regularizers
R = regularizers.RegularFaceRegularizer()
loss = losses.ArcFaceLoss(margin=30, num_classes=100, embedding_size=128, weight_regularizer=R)

1.7 Trainer模块

Trainer存在于这个库中,因为一些度量学习算法不仅仅是损失或挖掘函数。一些算法需要额外的网络、数据扩充、学习速率计划等。Trainer模块的目标是提供对这些类型的度量学习算法的访问。Trainer的使用如下:

from pytorch_metric_learning import trainers
t = trainers.SomeTrainingFunction(*args, **kwargs)
t.train(num_epochs=10)

1.8 Tester模块

Tester采用你的模型和数据集,并计算基于最近邻的准确性指标。请注意,Tester需要faiss包。Tester的使用如下:

from pytorch_metric_learning import testers
t = testers.SomeTestingFunction(*args, **kwargs)
dataset_dict = {"train": train_dataset, "val": val_dataset}
all_accuracies = tester.test(dataset_dict, epoch, model)

1.9 Utils模块

utils模块中包含了许多的工具包,具体请查阅此处。

2.如何使用PyTorch Metric Learning库中的Loss?

2.1 使用方法

以TripletMarginLoss为例进行讲解,以下是TripletMarginLoss的源代码:

import torch

from ..reducers import AvgNonZeroReducer
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class TripletMarginLoss(BaseMetricLossFunction):
    """
    Args:
        margin: The desired difference between the anchor-positive distance and the
                anchor-negative distance.
        swap: Use the positive-negative distance instead of anchor-negative distance,
              if it violates the margin more.
        smooth_loss: Use the log-exp version of the triplet loss
    """

    def __init__(
        self,
        margin=0.05,
        swap=False,
        smooth_loss=False,
        triplets_per_anchor="all",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.margin = margin
        self.swap = swap
        self.smooth_loss = smooth_loss
        self.triplets_per_anchor = triplets_per_anchor
        self.add_to_recordable_attributes(list_of_names=["margin"], is_stat=False)

    def compute_loss(self, embeddings, labels, indices_tuple):
        indices_tuple = lmu.convert_to_triplets(
            indices_tuple, labels, t_per_anchor=self.triplets_per_anchor
        )
        anchor_idx, positive_idx, negative_idx = indices_tuple
        if len(anchor_idx) == 0:
            return self.zero_losses()
        mat = self.distance(embeddings)
        ap_dists = mat[anchor_idx, positive_idx]
        an_dists = mat[anchor_idx, negative_idx]
        if self.swap:
            pn_dists = mat[positive_idx, negative_idx]
            an_dists = self.distance.smallest_dist(an_dists, pn_dists)

        current_margins = self.distance.margin(ap_dists, an_dists)
        violation = current_margins + self.margin
        if self.smooth_loss:
            loss = torch.nn.functional.softplus(violation)
        else:
            loss = torch.nn.functional.relu(violation)

        return {
            "loss": {
                "losses": loss,
                "indices": indices_tuple,
                "reduction_type": "triplet",
            }
        }

    def get_default_reducer(self):
        return AvgNonZeroReducer()

当我们想要使用TripletMarginLoss损失时,首先要初始化TripletMarginLoss。

from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss()

要计算训练循环中的损失,请传入模型计算的嵌入(embeddings)、相应的标签(labels)与索引元组(indices_tuple)。嵌入应该具有大小(N,embedding_size),标签应该具有大小(N),其中N是批量大小。索引元组为3元组(anchors, positives, negatives)或4元组(anchors, positives, anchors, negatives),该案例传入的是3元组,因为源码中的compute_loss函数。具体使用如下:

"""
自己构建三元组的示例:
"""
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    indices_tuple = (anchor_idx, positive_idx, negative_idx)  # 自己构建
    loss = loss_func(embeddings, labels, indices_tuple)  # indices_tuple可以是自己构建的,也可以是通过Miner得到的,根据具体情况对待
    loss.backward()
    optimizer.step()

"""
通过Miner得到三元组的示例:
"""
from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner()
loss_func = losses.TripletMarginLoss()

for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    hard_pairs = miner(embeddings, labels)  # 得到三元组
    loss = loss_func(embeddings, labels, hard_pairs)
    loss.backward()
    optimizer.step()

在上面的代码中,Miner找到了它认为特别困难的正负对。请注意,即使TripletMarginLoss对三元组(triplets)进行操作,仍然可以成对(pairs)传递。这是因为在必要时,库会自动将对转换为三元组,并将三元组转换为对。

2.2 调用原理

在使用库中的TripletMarginLoss函数时,我们首先需要初始化TripletMarginLoss,然后在计算TripletMarginLoss的时候,我们传入的参数与源码中的compute_loss函数一致。感觉有点像PyTorch模型的forward函数。一般这种方法的调用是通过python的特殊方法__call__函数实现的,比如:

def __call__(self, embeddings, labels, indices_tuple=None):
    return self.compute_loss(embeddings, labels, indices_tuple)

但是源码中并未定义__call__方法。如果该TripletMarginLoss继承自torch.nn.Module,并且定义了forward方法,如下。那么尽管没有显式定义 __call__ 方法,但是我们依旧可以这样使用。

class TripletMarginLoss(BaseMetricLossFunction):
    # ... 其他代码 ...

    def forward(self, embeddings, labels, indices_tuple=None):
        return self.compute_loss(embeddings, labels, indices_tuple)

但是即没有定义 __call__ 方法,也没有定义 forward 方法,那么为什么还可以loss = loss_func(embeddings, labels, hard_pairs)直接使用呢?

这是因为TripletMarginLoss继承自 BaseMetricLossFunctionTripletMarginLoss 会继承 BaseMetricLossFunction 中的所有方法和属性。而在BaseMetricLossFunction中实现了 forward 方法。所以即使 TripletMarginLoss 自己没有显式定义这些方法,作为子类,它会自动继承父类的行为,并使用父类的方法来实现损失计算逻辑。

3.根据PyTorch Metric Learning库自定义Loss

3.1 通用Loss模板(简单版本)

from pytorch_metric_learning.losses import BaseMetricLossFunction
import torch

class BarebonesLoss(BaseMetricLossFunction):
    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        # perform some calculation #
        some_loss = torch.mean(embeddings)

        # put into dictionary #
        return {
            "loss": {
                "losses": some_loss,
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }

3.2 通用Loss模板(使用distances和reducers版本)

通过添加distances和reducers来增强损失函数的功能。

from pytorch_metric_learning.losses import BaseMetricLossFunction
from pytorch_metric_learning.reducers import AvgNonZeroReducer
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
import torch

class FullFeaturedLoss(BaseMetricLossFunction):
    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        indices_tuple = lmu.convert_to_triplets(indices_tuple, labels)
        anchors, positives, negatives = indices_tuple
        if len(anchors) == 0:
            return self.zero_losses()

        mat = self.distance(embeddings)
        ap_dists = mat[anchors, positives]
        an_dists = mat[anchors, negatives]

        # perform some calculations #
        losses1 = ap_dists - an_dists
        losses2 = ap_dists * 5
        losses3 = torch.mean(embeddings)

        # put into dictionary #
        return {
            "loss1": {
                "losses": losses1,
                "indices": indices_tuple,
                "reduction_type": "triplet",
            },
            "loss2": {
                "losses": losses2,
                "indices": (anchors, positives),
                "reduction_type": "pos_pair",
            },
            "loss3": {
                "losses": losses3,
                "indices": None,
                "reduction_type": "already_reduced",
            },
        }

    def get_default_reducer(self):
        return AvgNonZeroReducer()

    def get_default_distance(self):
        return CosineSimilarity()

    def _sub_loss_names(self):
        return ["loss1", "loss2", "loss3"]
  • 该损失函数对三元组进行操作,因此convert_to_triplets用于将indices_tuple转换为三元组形式。
  • self.distance返回成对的距离矩阵。
  • 损失函数的输出是一个包含多个子损失的字典。这就是重写_sub_loss_names函数的原因。
  • 默认情况下,get_default_reducer被覆盖以使用AvgNonZeroReducer,而不是MeanReducer
  • 默认情况下,get_default_distance被覆盖以使用CosineSimilarity,而不是LpDistances(p=2)

3.3 使用indices_tuple注意项

使用indices_tuple,需要使用适当的转换函数,这样我们不需要知道将传入什么类型的indices_tuple,因为转换函数会自动处理。indices_tuple三种可能的形式:

  • None
  • 大小为4的元组,表示miner pairs 的索引(anchors, positives, anchors, negatives)
  • 大小为3的元组,表示miner triplets 的索引(anchors, positives, negatives)
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

# For a pair based loss
# After conversion, indices_tuple will be a tuple of size 4
indices_tuple = lmu.convert_to_pairs(indices_tuple, labels)

# For a triplet based loss
# After conversion, indices_tuple will be a tuple of size 3
indices_tuple = lmu.convert_to_triplets(indices_tuple, labels)

# For a classification based loss
# miner_weights.shape == labels.shape
# You can use these to weight your loss
miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=torch.float32)

参考

  • PyTorch Metric Learning官方库
  • PyTorch Metric Learning中文库
  • PyTorch Metric Learning官方文档
  • pytorch-metric-learning度量学习工具

😃😃😃

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2116531.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

传统CV算法——基于harris检测算法实现角点检测

角点 角点是图像中的一个特征点,指的是两条边缘交叉的点,这样的点在图像中通常表示一个显著的几角。在计算机视觉和图像处理中,角点是重要的特征,因为它们通常是图像中信息丰富的区域,可以用于图像分析、对象识别、3D…

JavaScript 循环控制语句-break和continue

break循环 首先i0&#xff0c;判断i是否<5,满足条件&#xff0c;判断i是否等于3&#xff0c;i不等于3&#xff0c;输出i0&#xff0c;i的值加1&#xff0c;判断i是否<5&#xff0c;判断i是否等于3&#xff0c;i不等于3&#xff0c;输出i1&#xff0c;i的值加1&#xff0c…

【H2O2|全栈】关于HTML(6)HTML基础(五 · 完结篇)

HTML基础知识 目录 HTML基础知识 前言 准备工作 标签的具体分类&#xff08;五&#xff09; 本文中的标签在什么位置中使用&#xff1f; 表单&#xff08;二&#xff09; 下拉选择菜单 文本域 案例 拓展标签 iframe框架 案例 预告和回顾 后话 前言 本系列博客介…

EasyExcel模板导出与公式计算(下)

目录 环境要求 功能预览 需求分析 导入依赖 制作模板 编写代码 格式优化 最终效果 总结 在上一篇 EasyExcel模板导出与公式计算&#xff08;上&#xff09;-CSDN博客 文章中我们知道了在若依中使用自带的Excel注解来实现表格数据的导出&#xff0c;并且通过重写相关接…

C++复习day07

一、继承 1.什么是继承&#xff1f;继承的意义是什么&#xff1f; 继承的概念 继承(inheritance)机制是面向对象程序设计使代码可以复用的最重要的手段&#xff0c;它允许程序员在保持原有类特性的基础上进行扩展&#xff0c;增加功能&#xff0c;这样产生新的类&#xff0c…

C++ STL 适配器

系列文章目录 模板特例化&#xff0c;偏特化&#xff0c;左右值引用 https://blog.csdn.net/surfaceyan/article/details/126794013 C STL 关联容器 https://blog.csdn.net/surfaceyan/article/details/127414434 C STL 序列式容器(二) https://blog.csdn.net/surfaceyan/arti…

项目实战系列三: 家居购项目 第四部分

购物车 &#x1f333;购物车&#x1f346;显示购物车&#x1f346;更改商品数量&#x1f346;清空购物车&&删除商品 &#x1f333;生成订单 &#x1f333;购物车 需求分析 1.会员登陆后, 可以添加家居到购物车 2.完成购物车的设计和实现 3.每添加一个家居,购物车的数量…

比较顺序3s1,3s2,4s1之间的关系

(A,B)---6*30*2---(0,1)(1,0) 分类A和B&#xff0c;让B全是0。当收敛误差为7e-4&#xff0c;收敛199次取迭代次数平均值&#xff0c;3s1为 3s2为 4s1为 3s1&#xff0c;3s2&#xff0c;4s1这3个顺序之间是否有什么联系 &#xff0c; 因为4s1可以按照结构加法 变换成与4s1内在…

Linux相关概念和重要知识点(2)(用户、文件和目录、inode、权限)

1.root和普通用户 在Windows里面&#xff0c;管理员Administrator是所有用户里面权限最高的&#xff0c;很多文件都会提示请使用管理员打开等。但在整个Windows系统中&#xff0c;管理员的权限并不是最大的&#xff0c;System优先级更高&#xff0c;因此我们系统中的某些文件是…

谈谈ES搜索引擎

一 ES的定义 ES 它的全称是 Elasticsearch&#xff0c;是一个建立在全文搜索引擎库Lucene基础上的一个开源搜索和分析引擎。ES 它本身具备分布式存储&#xff0c;检索速度快的特性&#xff0c;所以我们经常用它来实现全文检索功能。目前在 Elastic 官网对 ES 的定义&#xff0c…

模拟实现vector中的常见接口

insert void insert(iterator pos, const T& x) {if (_finish _endofstorage){int n pos - _start;size_t newcapacity capacity() 0 ? 2 : capacity() * 2;reserve(newcapacity);pos _start n;//防止迭代器失效}int end _finish-1;while (end > pos){*(end 1…

PMBOK® 第六版 规划进度管理

目录 读后感—PMBOK第六版 目录 规划进度管理主要关注为整个项目期间的进度管理提供指南和方向。以下是两个案例&#xff0c;展示了进度管理中的复杂性和潜在的冲突&#xff1a; 案例一&#xff1a;近期&#xff0c;一个长期合作的客户因政策要求&#xff0c;急需我们为多家医…

SQL的增删改查CRUD练习知识点(day27)

1 学习目标 重点掌握插入单条记录的语法了解全表插入记录的语法重点掌握修改记录的语法重点掌握删除记录的语法重点掌握主键约束、外键约束了解检查约束、非空约束、唯一约束 2 数据类型 MySQL支持多种数据类型&#xff0c;大致可以分类三类&#xff1a;数值、日期和字符串。…

【JavaEE初阶】多线程(3)

欢迎关注个人主页&#xff1a;逸狼 创造不易&#xff0c;可以点点赞吗~ 如有错误&#xff0c;欢迎指出~ 目录 线程状态 线程安全 代码示例 解释 总结原因 解决方案-->加锁 t1和t2都加锁 且 同一个锁对象 t1和t2中只有一个加锁了 t1和t2都加锁,但锁对象不同 加锁 与线程等待…

我给孩子请了个AI老师,省掉了1999元的报名费

大家好&#xff0c;我是凡人。 最近老婆想给儿子在线报个书法班&#xff0c;要价1999元&#xff0c;本来是个好事情&#xff0c;但一向勤俭持家的我&#xff0c;怎能让她花这个冤枉钱&#xff0c;经过我三七二十一个小时的上网&#xff0c;还真让我找出一套利用AI学习的万能命…

图片无损放大编辑PhotoZoom Pro 9.0.4多版本软件2024年最新安装包下载含安装教程

PhotoZoom Pro 9.0.4是一款非常流行的图像放大软件&#xff0c;它可以让你将低分辨率的图像放大到高分辨率的尺寸&#xff0c;同时保持高质量的图像细节和清晰度。 PhotoZoom Pro 9.0.4采用了一种称为S-Spline技术的算法&#xff0c;这是一种能够保持图像细节的高级插值算法。…

Web3 详解

1. 使用 Web3 库 Web3 是一个 JavaScript 库&#xff0c;可用于通过 RPC 通信与以太坊节点通信。 Web3 的工作方式是&#xff0c;公开已通过 RPC 启用的方法&#xff0c;这允许开发利用 Web3 库的用户界面&#xff0c;以便与部署在区块链上的合约进行交互。 一旦 Geth JavaScri…

25届计算机专业选题推荐-基于python的线上拍卖会管理系统【python-爬虫-大数据定制】

&#x1f496;&#x1f525;作者主页&#xff1a;毕设木哥 精彩专栏推荐订阅&#xff1a;在 下方专栏&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; 实战项目 文章目录 实战项目 一、基于python的线上拍卖会管理…

Window下编译OpenJDK17

本文详细介绍Window下如何编译OpenJDK17&#xff0c;包含源码路径&#xff0c;各工具下载地址&#xff0c;严格按照文章中的步骤来操作&#xff0c;你将获得一个由自己亲手编译出的jdk。 一、下载OpenJDK17源码 下载地址&#xff1a;GitHub - openjdk/jdk at jdk-1735 说明&a…

碰撞检测 | 详解矩形AABB与OBB碰撞检测算法(附ROS C++可视化)

引言 在复杂的人工智能系统和机器人应用中,碰撞检测(Collision Detection)作为一项基础技术,扮演着至关重要的角色。无论是在自动驾驶车辆中防止车祸的发生,还是在机器人导航中避免障碍物,碰撞检测的精度和效率都直接决定了系统的可靠性和安全性。在游戏开发、虚拟现实、…