BiFormer 实验记录

2024/10/7 4:35:00




1、Patch Embedding

2、BiFormer Block




1、Patch Embedding


2、BiFormer Block


x = x + self.pos_embed(x)


x = x + self.drop_path(self.attn(self.norm1(x)))




        else:  # True
            N, H, W, C = x.size()  # 1,56,56,64
            assert H%self.n_win == 0 and W%self.n_win == 0

其中的 self.n_win 就是论文中的 S：特征图被划分为 S×S 个互不重叠的区域，所以 H 和 W 都必须能被 S 整除（即上面的 assert）。可对照论文中的相关描述，

以及论文中的 Algorithm 1：

# patchify input (H, W, C) -> (Sˆ2, HW/Sˆ2, C)
x = patchify(input, patch_size=H//S)


# patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)  # (1,49,8,8,64)



# linear projection of query, key, value
query, key, value = linear_qkv(x).chunk(3, dim=-1)


q, kv = self.qkv(x)  # to 103  q (1,49,8,8,64) kv (1,49,8,8,128)

# pixel-wise qkv
# q_pix: (n, p^2, w^2, c_qk)
# kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')  # (1,49,64,64)
kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))  # (49,128,8,8)
kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)  # (1,49,64,128)

只不过这里将 k 和 v 放在同一个张量（kv）中一起计算了。



q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3])  # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk) q_win (1,49,64) k_win (1,49,64)  这里是k_min所以截至到0:self.dim

这里的 mean([2, 3]) 不是在通道维上求平均，而是对每个区域内的 8×8 个像素向量求平均（即在 h、w 两个维度上做平均池化），得到区域级的 q_win 和 k_win。


lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())  # (1,64,56,56)
lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)  # (1,56,56,64)


下面对应论文中的公式4 和公式5：


 r_weight, r_idx = self.router(q_win, k_win)  # both are (n, p^2, topk) tensors to 51 (1,49,1) (1,49,1)


    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor, Tensor]:  # q (1,49,64), k (1,49,64)
        """Route every query region to its top-k most related key regions.

        Args:
            query, key: (n, p^2, c) region-level tensors (here (1, 49, 64)).

        Returns:
            r_weight: (n, p^2, topk) routing weights, i.e. ``self.routing_act``
                applied to the top-k region-affinity logits.
            topk_index: (n, p^2, topk) indices of the selected key regions.
        """
        # Without differentiable routing, no gradient flows back through the router.
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        # Region embedding; it is Identity in the printed model, so presumably a
        # no-op unless configured otherwise.
        query_hat, key_hat = self.emb(query), self.emb(key)  # (n, p^2, c)
        # Region-to-region affinity: (n, p^2, c) @ (n, c, p^2) -> (n, p^2, p^2)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)
        # Keep only the k strongest key regions per query region: (n, p^2, k) each.
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)
        r_weight = self.routing_act(topk_attn_logit)  # (n, p^2, k)
        return r_weight, topk_index

self.emb 为 Identity（恒等映射），随后按公式4计算区域之间的 attention 并取出 top-k 个，self.routing_act 为 Softmax 激活函数。从这个过程可以看到，attention 的输出形状为 (1,49,49)：第一个 49 表示区域数，第二个 49 表示每个区域与所有区域之间的亲和度（affinity graph）分数；然后在最后一维（dim=-1）上取出 top-k 个最大分数及其索引。



kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)  #(n, p^2, topk, h_kv*w_kv, c_qk+c_v)  (1,49,1,64,128)
k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)  # (1,49,1,64,64)  (1,49,1,64,64)


然后进行多头 self attention,对应论文中的公式7

k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads)  # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?  (49,2,32,64)
v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)  (49,2,64,32)
q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m)  (49,2,64,32)

# param-free multihead attention
attn_weight = (q_pix * self.scale) @ k_pix_sel # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)  (49,2,64,64)
attn_weight = self.attn_act(attn_weight)  # (49,2,64,64)
out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)  (49,2,64,32)
out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H//self.n_win, w=W//self.n_win)  # (1,56,56,64)

out = out + lepe  # (1,56,56,64)

这里将 n 和 S×S 合并到了第一维，也就是把 batch 维和区域数维合并在一起；最后再将输出 reshape 回 (H, W)。


接下来进入 mlp模块,没什么好说的。



  (downsample_layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU()
      (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (stages): ModuleList(
    (0): Sequential(
      (0): Block(
        (pos_embed): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=64, out_features=192, bias=True)
          (wo): Linear(in_features=64, out_features=64, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=64, out_features=192, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=192, out_features=64, bias=True)
        (drop_path): Identity()
      (1): Block(
        (pos_embed): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=64, out_features=192, bias=True)
          (wo): Linear(in_features=64, out_features=64, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=64, out_features=192, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=192, out_features=64, bias=True)
        (drop_path): Identity()
    (1): Sequential(
      (0): Block(
        (pos_embed): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
        (norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=128)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=128, out_features=384, bias=True)
          (wo): Linear(in_features=128, out_features=128, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=128, out_features=384, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=384, out_features=128, bias=True)
        (drop_path): Identity()
      (1): Block(
        (pos_embed): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
        (norm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=128)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=128, out_features=384, bias=True)
          (wo): Linear(in_features=128, out_features=128, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=128, out_features=384, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=384, out_features=128, bias=True)
        (drop_path): Identity()
    (2): Sequential(
      (0): Block(
        (pos_embed): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=256)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=256, out_features=768, bias=True)
          (wo): Linear(in_features=256, out_features=256, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=768, out_features=256, bias=True)
        (drop_path): Identity()
      (1): Block(
        (pos_embed): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=256)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=256, out_features=768, bias=True)
          (wo): Linear(in_features=256, out_features=256, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=768, out_features=256, bias=True)
        (drop_path): Identity()
      (2): Block(
        (pos_embed): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=256)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=256, out_features=768, bias=True)
          (wo): Linear(in_features=256, out_features=256, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=768, out_features=256, bias=True)
        (drop_path): Identity()
      (3): Block(
        (pos_embed): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=256)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=256, out_features=768, bias=True)
          (wo): Linear(in_features=256, out_features=256, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=768, out_features=256, bias=True)
        (drop_path): Identity()
      (4): Block(
        (pos_embed): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=256)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=256, out_features=768, bias=True)
          (wo): Linear(in_features=256, out_features=256, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=768, out_features=256, bias=True)
        (drop_path): Identity()
      (5): Block(
        (pos_embed): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=256)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=256, out_features=768, bias=True)
          (wo): Linear(in_features=256, out_features=256, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=768, out_features=256, bias=True)
        (drop_path): Identity()
      (6): Block(
        (pos_embed): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=256)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=256, out_features=768, bias=True)
          (wo): Linear(in_features=256, out_features=256, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=768, out_features=256, bias=True)
        (drop_path): Identity()
      (7): Block(
        (pos_embed): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (attn): BiLevelRoutingAttention(
          (lepe): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=256)
          (router): TopkRouting(
            (emb): Identity()
            (routing_act): Softmax(dim=-1)
          (kv_gather): KVGather()
          (qkv): QKVLinear(
            (qkv): Linear(in_features=256, out_features=768, bias=True)
          (wo): Linear(in_features=256, out_features=256, bias=True)
          (kv_down): Identity()
          (attn_act): Softmax(dim=-1)
        (norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=768, out_features=256, bias=True)
        (drop_path): Identity()
    (3): Sequential(
      (0): Block(
        (pos_embed): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
        (norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (attn): AttentionLePE(
          (qkv): Linear(in_features=512, out_features=1536, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (lepe): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512)
        (norm2): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=1536, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=1536, out_features=512, bias=True)
        (drop_path): Identity()
      (1): Block(
        (pos_embed): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
        (norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (attn): AttentionLePE(
          (qkv): Linear(in_features=512, out_features=1536, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (lepe): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512)
        (norm2): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=1536, bias=True)
          (1): Identity()
          (2): GELU()
          (3): Linear(in_features=1536, out_features=512, bias=True)
        (drop_path): Identity()
  (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pre_logits): Identity()
  (head): Linear(in_features=512, out_features=1000, bias=True)





