在多类别分类任务中生成的热力图(例如语义分割或图像分类任务中的类别概率图),通常每个像素会对应多个类别的概率。假设已经经过 softmax
层处理,那么每个像素点的值是某个类别的概率分布,取值范围在 [0, 1]
之间,并且所有类别的概率之和为 1。热力图颜色深浅的含义取决于你选择显示的类别的概率。
如何理解多类别的热力图
-
单类别热力图:
- 如果你只想观察某个类别的分布,可以从
softmax
输出中选择该类别的概率矩阵生成热力图。 - 在这种情况下,颜色越深(假设使用的是深色调的颜色映射),表示该类别的概率越高,即模型越“确信”该像素属于该类别。
- 比如,选择类别 A 的概率图生成热力图,则颜色深浅代表类别 A 的概率强弱,越深说明该像素越可能属于类别 A。
- 如果你只想观察某个类别的分布,可以从
-
多类别综合热力图:
- 在某些情况下,可能希望在一张热力图中展示多类别信息。这种情况下通常会采用不同颜色表示不同类别,然后颜色的深浅表示该类别的概率强度。
- 例如,类别 A 用红色,类别 B 用蓝色,类别 C 用绿色,那么每个像素点的颜色会是各类别颜色按其概率强度混合的结果。颜色越深的部分,表示模型在该类别上的预测越“确定”。
颜色越深的含义
- 颜色越深表示更高的概率值:在单类别热力图中,经过
softmax
后,颜色越深的区域表示该类别的概率越接近 1,说明模型认为这些像素很可能属于该类别。 - 低概率表示浅色:对于较浅的颜色(接近白色或透明),表示模型认为该类别的可能性较低,或更倾向于其他类别。
示例说明
假设我们有一个经过 softmax
的输出张量,尺寸为 [3, 256, 256]
,代表三个类别,每个通道对应一个类别的概率图。
- 类别 1 热力图:可以取张量的第一个通道
[256, 256]
来绘制类别 1 的概率分布热力图,颜色越深表示越高的概率。 - 类别 2 热力图:同样,可以取张量的第二个通道
[256, 256]
来绘制类别 2 的概率分布热力图。
代码示例
import matplotlib.pyplot as plt
import torch
def plot_heatmap(data, category, cmap='viridis'):
"""
绘制指定类别的概率热力图
参数:
- data: Softmax 后的多类别概率张量,例如形状为 [3, 256, 256]
- category: 要绘制的类别索引,例如 0 表示第一个类别
- cmap: 颜色映射,默认 'viridis'
"""
if isinstance(data, torch.Tensor):
data = data[category].numpy()
elif isinstance(data, np.ndarray):
data = data[category]
else:
raise TypeError("data must be a torch Tensor or numpy array")
plt.imshow(data, cmap=cmap)
plt.colorbar()
plt.title(f"Category {category} Probability Heatmap")
plt.show()
# 假设 softmax 后的张量,形状为 [3, 256, 256]
softmax_output = torch.rand(3, 256, 256) # 示例数据
# 绘制类别 0 的概率热力图
plot_heatmap(softmax_output, category=0, cmap='hot')
总结
- 颜色越深:表示该类别的概率越高,即模型对该类别的置信度越高。
- 不同类别可以分别绘制热力图:每个类别的概率图可以单独绘制成一张热力图,以便观察该类别在图像中的分布情况。
- 多类别混合显示:可以通过颜色映射的方式,将多个类别的概率分布叠加在一起,以不同颜色表示不同类别的置信度分布。