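The All Things ViTs lecture series takes the attention mechanism of ViT vision models as its starting point. This post records my notes on the part that visualizes the mean attention distance (MAD).
Course videos and slides: https://all-things-vits.github.io/atv/
Code: https://colab.research.google.com/github/all-things-vits/code-samples/blob/main/probing/mean_attention_distance.ipynb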
Reference: An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale
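1. Overview
When I first read the ViT paper I did not fully understand the MAD part, that is, what MAD actually is (see the figure below). Stepping through the corresponding code in a debugger gave me a much deeper understanding of ViT's attention mechanism.
Fig 1 MAD visualization for vit-base-patch16-224
2. Key code walkthrough
2.1 Obtaining the attention scores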
import torch
from PIL import Image

def perform_inference(image: Image.Image, model: torch.nn.Module, processor):
    """Performs inference given an image, a model, and its processor."""
    inputs = processor(image, return_tensors="pt")  # pixel_values: [1, 3, 224, 224]
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    print(type(outputs))
    # model predicts one of the 1000 ImageNet classes
    predicted_label = outputs.logits.argmax(-1).item()
    print(model.config.id2label[predicted_label])
    return outputs.attentions  # tuple of 12 tensors, each [1, 12, 197, 197]
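This code feeds the image into the ViT network and obtains the output logits, the predicted class, and the attention scores of every head in every ViT block (outputs.attentions). ViT can be viewed as a Transformer encoder; one of its blocks is shown in Fig 2:
Fig 2 A ViT block
outputs.attentions is a tuple of 12 tensors, one per block, each of shape [1, 12, 197, 197]. Here 1 is the batch size, 12 is the number of heads, and 197 is the number of tokens, so each 197 x 197 slice holds the attention score between every pair of tokens. The 197 tokens consist of 196 image tokens and one CLS token. MAD is computed only over the image tokens.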
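To make these shapes concrete, here is a small NumPy sketch that mimics one block's attention tensor with random values. It does not run the real model: the random logits and all variable names are assumptions made for illustration. It checks that every row sums to 1 after the softmax and shows how dropping the CLS token leaves a 196 x 196 map per head.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for one entry of outputs.attentions:
# (batch, num_heads, num_tokens, num_tokens) = (1, 12, 197, 197)
logits = rng.normal(size=(1, 12, 197, 197))

# Softmax over the last axis, as in self-attention
logits = logits - logits.max(axis=-1, keepdims=True)
attn = np.exp(logits)
attn = attn / attn.sum(axis=-1, keepdims=True)

# Each query token spreads a total weight of 1 over all 197 key tokens
row_sums = attn.sum(axis=-1)

# Token 0 is assumed to be the CLS token; keep only patch-to-patch attention
patch_attn = attn[..., 1:, 1:]
grid = int(np.sqrt(patch_attn.shape[-1]))  # patches per side: 224 / 16 = 14

print(attn.shape, patch_attn.shape, grid)
```

The 196 image tokens therefore correspond to a 14 x 14 grid of 16 x 16 pixel patches, which is what the distance computation in the next section relies on.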
2.2 Computing MAD
def gather_mads(attention_scores, patch_size: int = 16):
    """Computes the per-head MAD for every block."""
    all_mean_distances = {
        f"block_{i}_mean_dist": compute_mean_attention_dist(
            patch_size=patch_size, attention_weights=attention_weight.numpy()
        )
        for i, attention_weight in enumerate(attention_scores)
    }
    return all_mean_distances
This code loops over the blocks and computes the MAD for each one.
import numpy as np

def compute_mean_attention_dist(patch_size, attention_weights, num_cls_tokens=1):
    # num_cls_tokens was undefined in the excerpt; a default of 1 (a single CLS token) is assumed here
    # The attention_weights shape = (batch, num_heads, num_tokens, num_tokens)
    attention_weights = attention_weights[
        ..., num_cls_tokens:, num_cls_tokens:
    ]  # Removing the CLS token, [1, 12, 196, 196]
    num_patches = attention_weights.shape[-1]
    length = int(np.sqrt(num_patches))
    assert length**2 == num_patches, "Num patches is not perfect square"

    distance_matrix = compute_distance_matrix(patch_size, num_patches, length)  # [196, 196]
    h, w = distance_matrix.shape
    # [1, 1, 196, 196], pixel distance between every pair of patches in the image
    distance_matrix = distance_matrix.reshape((1, 1, h, w))

    # The attention_weights along the last axis adds to 1
    # this is due to the fact that they are softmax of the raw logits
    # summation of the (attention_weights * distance_matrix)
    # should result in an average distance per token
    mean_distances = attention_weights * distance_matrix  # [1, 12, 196, 196]
    mean_distances = np.sum(
        mean_distances, axis=-1
    )  # sum along last axis to get average distance per token, [1, 12, 196]
    mean_distances = np.mean(
        mean_distances, axis=-1
    )  # now average across all the tokens, [1, 12]
    return mean_distances
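This code performs the actual MAD computation. It first computes the distance between patches (Fig 1 illustrates what a patch is); a ViT token can be understood as the encoding of one patch. The distance between patches is computed as follows: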
def compute_distance_matrix(patch_size, num_patches, length):
    """Helper function to compute distance matrix."""
    distance_matrix = np.zeros((num_patches, num_patches))
    for i in range(num_patches):
        for j in range(num_patches):
            if i == j:  # zero distance
                continue
            # (row, column) position of each patch on the length x length grid
            xi, yi = i // length, i % length
            xj, yj = j // length, j % length
            distance_matrix[i, j] = patch_size * np.linalg.norm([xi - xj, yi - yj])
    return distance_matrix
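The double loop above is easy to read but does 196 x 196 Python iterations. As a sketch only, the same matrix could be built with NumPy broadcasting. The function name compute_distance_matrix_vectorized is hypothetical and is not part of the notebook.

```python
import numpy as np

def compute_distance_matrix_vectorized(patch_size, length):
    """Pairwise pixel distances between patch grid positions (hypothetical helper)."""
    # Row and column index of every patch on the length x length grid
    rows, cols = np.divmod(np.arange(length * length), length)
    coords = np.stack([rows, cols], axis=-1).astype(float)  # (num_patches, 2)
    # Broadcast to all pairs, then take the Euclidean norm of each offset
    diff = coords[:, None, :] - coords[None, :, :]  # (num_patches, num_patches, 2)
    return patch_size * np.linalg.norm(diff, axis=-1)

dist = compute_distance_matrix_vectorized(patch_size=16, length=14)
print(dist.shape)   # one row per query patch, one column per key patch
print(dist[0, 1])   # horizontal neighbours are one patch (16 pixels) apart
print(dist[0, 15])  # diagonal neighbours are 16 * sqrt(2) pixels apart
```

The result is symmetric with a zero diagonal, matching the `if i == j: continue` branch of the loop version.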
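The distance between patches is simply their spatial distance in the image, measured in pixels. The core line of the MAD computation is:
mean_distances = attention_weights * distance_matrix
After that, the result is summed over the key patches and averaged over all tokens within each head. MAD measures, for each patch, a combined distance to all other patches: it takes the actual spatial distance between patches and weights it by the attention scores. My understanding is that MAD is a learned way of modeling a discrete set of image patches, one that accounts both for the spatial relationship between patches and for how strongly they are actually related (the attention scores). This distance can be used to probe the range each head attends to, similar to the receptive field in a CNN.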
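Written as a formula, the computation in compute_mean_attention_dist can be summarized as follows. The notation is my own and is not taken from the notebook: $A^{(h)}_{ij}$ is the attention weight from query patch $i$ to key patch $j$ in head $h$ (with the CLS row and column removed), $(r_i, c_i)$ is the grid position of patch $i$, $P$ is the patch size, and $N$ is the number of patches.

```latex
d_{ij} = P \sqrt{(r_i - r_j)^2 + (c_i - c_j)^2},
\qquad
\mathrm{MAD}^{(h)} = \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{N} A^{(h)}_{ij} \, d_{ij}
```

One detail worth noting: the softmax is taken over all 197 tokens, so once the CLS column is dropped each row of $A^{(h)}$ sums to at most 1. The inner sum is therefore a weighted sum that approximates, but does not exactly equal, a weighted average of the distances.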
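3. Summary
Now let us return to Fig 1 and look again at what it shows. The horizontal axis is the block index, covering 12 blocks, and the vertical axis is the MAD of each head. Even in the shallow blocks, some heads already attend globally (heads with large MAD) while others attend locally (heads with small MAD). This differs from a CNN, whose shallow layers mostly see local context and whose deep layers see global context. As depth increases, ViT gradually shifts toward global attention. Compared with a CNN, ViT is a more general way of modeling images. This helps it express more complex spatial relationships but also makes it harder to train, so it is generally believed that ViT shows its advantage only when a large amount of data is available.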
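To build some intuition for "local" versus "global" heads, here is a toy sketch on a 4 x 4 patch grid with three hand-made attention patterns. The patterns and the helper name mad are assumptions for illustration and are not taken from a trained model.

```python
import numpy as np

length, patch_size = 4, 16  # toy 4 x 4 grid instead of 14 x 14
num_patches = length * length

# Pixel distance between every pair of patches
rows, cols = np.divmod(np.arange(num_patches), length)
coords = np.stack([rows, cols], axis=-1).astype(float)
dist = patch_size * np.linalg.norm(coords[:, None] - coords[None, :], axis=-1)

def mad(attn):
    """Attention-weighted distance per query patch, averaged over patches."""
    return float((attn * dist).sum(axis=-1).mean())

# Each patch attends only to itself
self_only = np.eye(num_patches)

# Each patch attends equally to itself and its directly adjacent patches
neighbours = (dist <= patch_size).astype(float)
neighbours /= neighbours.sum(axis=-1, keepdims=True)

# Each patch attends equally to every patch (fully global)
uniform = np.full((num_patches, num_patches), 1.0 / num_patches)

print(mad(self_only), mad(neighbours), mad(uniform))
```

The self-only pattern gives a MAD of zero, the neighbourhood pattern gives a small MAD, and uniform attention gives the largest of the three, which is the ordering the vertical axis of the MAD plot is meant to capture.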