一、前言
在深挖ML4CO的代码过程中,遇到了torch.take_along_dim()这个方法,影响到我后续的代码阅读;加之在上网搜索资料的过程中,网络上对此函数的介绍文章少之又少,即使有,也是对torch官网文档中的解释进行英译汉或对函数的轻描淡写,看完解析后,对该函数的认识并没有因此而深刻。故在此谈一下自己的理解。
温馨提示:由于torch.take_along_dim函数与torch.gather函数的功能大同小异,因此在阅读这篇文章前,建议先阅读torch.gather()函数的解析PyTorch基础(16)-- torch.gather()。
二、方法解析
首先,我们需要看一下torch官方对于该函数的解释。白话解释为:我有一个tensor为input,沿着给定的input的维度(dim),给定取值的索引(indices),从input取出我想要的值。嗯……跟torch.gather方法不能说一模一样,只能说完全相同啊。
这个函数如何使用,直接通过案例来解释。
三、案例分析
3.1 案例1(不指定dim)
值得注意的是:不指定dim时,该方法会将input展开为一个一维tensor,然后根据indices进行取值。
import torch
t = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
index = torch.tensor([[2,0,2],[1,2,0],[1,1,1]])
t1 = torch.take_along_dim(t, index)
t1
输出为:
3.2 案例2(dim=0)
import torch
t = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
index = torch.tensor([[2,0,2],[1,2,0],[1,1,1]])
t1 = torch.take_along_dim(t, index, dim=0)
t1
3.3 案例3(dim=1)
import torch
t = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
index = torch.tensor([[2,0,2],[1,2,0],[1,1,1]])
t1 = torch.take_along_dim(t, index, dim=1)
t1
参考文献
- https://pytorch.org/docs/stable/generated/torch.take_along_dim.html#torch.take_along_dim
- https://blog.csdn.net/dongjinkun/article/details/132299818?spm=1001.2014.3001.5501