torch.where
是 PyTorch 中用于条件选择的函数。它可以根据一个布尔条件在两个张量中选择元素,从而生成一个新的张量。
函数定义
torch.where(condition, input, other)
参数说明:
condition
- 一个布尔张量,表示条件判断结果。
- 形状可以与
input
和other
相同,或者可以广播到相同的形状。
input
- 满足条件时的值来源张量。
other
- 不满足条件时的值来源张量。
返回值:
- 返回一个与
condition
、input
和other
形状兼容的张量。 - 如果
condition
的某个位置为True
,返回input
中对应位置的值;否则返回other
中对应位置的值。
示例
1. 基本用法
import torch
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([10, 20, 30, 40, 50])
condition = x > 3
print(condition) # tensor([False, False, False, True, True])
result = torch.where(condition, x, y)
print(result) # 输出: tensor([10, 20, 30, 4, 5])
2. 多维张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[10, 20], [30, 40]])
condition = a < 3
result = torch.where(condition, a, b)
print(result)
# 输出:
# tensor([[ 1, 2],
# [30, 40]])
3. 标量支持
input
和 other
可以是标量,而不是张量。
x = torch.tensor([1, 2, 3, 4, 5])
condition = x > 3
result = torch.where(condition, x, 0)
print(result) # 输出: tensor([0, 0, 0, 4, 5])
解释:
- 如果满足条件
x > 3
,取x
的值;否则取标量0
。
4. 广播机制
例1:
如果 condition
、input
和 other
的形状不同,PyTorch 会自动广播使其兼容。
x = torch.tensor([[1, 2], [3, 4]])
condition = torch.tensor([True, False])
result = torch.where(condition, x, 0)
print(result)
# 输出:
# tensor([[1, 2],
# [0, 0]])
解释:
condition
只有两列,通过广播扩展为形状[2, 2]
。
例2:(不同维度的广播机制)
1 .
import torch
x = torch.tensor([[1, 2, 3], [3, 4, 5]])
print(x)
condition = torch.tensor([True, False, True])
result = torch.where(condition, x, 0)
print(result)
x:
tensor([[1, 2, 3],
[3, 4, 5]])
result:
tensor([[1, 0, 3],
[3, 0, 5]])
2 .
import torch
x = torch.tensor([[1, 2, 3], [3, 4, 5]])
print(x)
condition = torch.tensor([[True], [False]])
print(condition)
result = torch.where(condition, x, 0)
print(result)
x:
tensor([[1, 2, 3],
[3, 4, 5]])
condition:
tensor([[ True],
[False]])
result:
tensor([[1, 2, 3],
[0, 0, 0]])
常见用途
1. 替换特定值
将张量中大于某值的元素替换为某个固定值:
x = torch.tensor([1, 5, 10, 15])
x_clipped = torch.where(x > 10, 10, x)
print(x_clipped) # 输出: tensor([ 1, 5, 10, 10])
2. 创建条件张量
使用条件逻辑生成一个新张量:
x = torch.linspace(-1, 1, 5)
y = torch.where(x > 0, 1, -1)
print(x) # tensor([-1.0000, -0.5000, 0.0000, 0.5000, 1.0000])
print(y) # 输出: tensor([-1, -1, -1, 1, 1])
注意事项
-
数据类型一致性
input
和other
必须具有相同的数据类型,否则会抛出错误。x = torch.tensor([1.0, 2.0]) y = torch.tensor([1, 2]) torch.where(x > 1, x, y) # 会报错,因为 x 是浮点型,y 是整型
2. 广播机制
当使用广播时,确保张量可以广播到相同的形状。