一、前言
在实现DQN的过程中,torch.gather()这个方法引起了我的注意,原因有二:1)这个函数在我硕士期间很少遇见,用到的次数更是少之又少;2)torch.gather()这个方法是如何使用的呢,以为它的逻辑是怎样的?带着这个疑问,我查阅了PyTorch的官方文档,并在此进行记录,以备不时之需!同时,也希望能够帮助到更多的同学!
二、方法解析
首先,先要从字面意思进行理解,gather的英文释义为:搜集、收集等等;按照逻辑,我们自己可能会这么想:torch.gather()方法应该是在按照某个方式来进行数据的搜集。
其次,我们看一下torch.gather()方法的官方解释(https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather):
可以发现,torch.gather()方法有三个必须的参数:
- 🥇input:类型为Tensor,很好理解,就是gather方法要操作的对象
- 🥈dim:int类型,可以理解为按照dim所指定的维度,进行搜集数据的操作
- 🥉index:类型为LongTensor,可以理解为gather方法搜集数据的方式
值得注意的是:
- torch.gather()的返回值为Tensor类型
- Input和index必须有相同的维度
三、案例分析
为对比dim=0和dim=1的区别,我们使用相同的Input tensor和Index tensor,仅仅dim不同。计算示意图如下:
下面我们来分解一下torch.gather()方法的详细过程,因为是二维矩阵,所以dim只能有2个取值,下列分情况进行讨论:
3.1 dim=0
dim=0指:Index tensor中的元素的含义为第几行,如下图:例如Index tensor中的(0,0)位置的元素为2,代表第2行,2所在的列在第0列,对应Input tensor中(2,0)位置的元素,即Output中(0,0)位置的元素的值为7!
3.2 dim=1
dim=1指:Index tensor中的元素的含义为第几列。如下图:例如Index tensor中的(0,0)位置的元素为2,代表第2列,2所在的行在第0行,对应Input tensor中(0,2)位置的元素,即Output中(0,0)位置的元素的值为3!
🥇🥇🥇那么,你学会了吗?
参考文献
- https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather
- https://blog.csdn.net/weixin_42899627/article/details/122816250