参考官方文档。
官方文档中只给了第一种用法。根据条件condition,从input,other中选择元素f返回。如果满足条件,则返回input元素。若不满足,返回other元素。
还有一种用法是通过where返回张量中满足条件condition的坐标,以二维张量为例。
代码如下:
import torch
nums = torch.tensor([
[1, 2, 3], [4, 5, 6], [7, 8, 9]
])
x_loc, y_loc = torch.where(nums>5)
print('x_loc: ‘, x_loc)
print('y_loc: ‘, y_loc)
z = torch.where(nums>5, 10, 1)
print('z: ‘, z)
输出结果如下:
x_loc: tensor([1, 2, 2, 2])
y_loc: tensor([2, 0, 1, 2])
z: tensor([[ 1, 1, 1],
[ 1, 1, 10],
[10, 10, 10]])