LeMeViT: An Efficient ViT with Learnable Meta Tokens


This paper proposes formulating sparse tokens with learnable meta tokens, which effectively capture key information while improving inference speed. Technically, the meta tokens are first initialized from the image tokens via cross attention. Dual Cross Attention (DCA) is then proposed to promote information exchange between image tokens and meta tokens, where the two alternately act as query and key (value) tokens in a dual-branch structure, significantly reducing computational complexity compared with self-attention. By using DCA in the early stages, where visual tokens are dense, hierarchical LeMeViT variants of different sizes are obtained. Experimental results on classification and dense prediction tasks show that LeMeViT achieves a notable 1.7x speedup, fewer parameters, and competitive performance compared with the baseline, striking a better trade-off between efficiency and performance.

Existing methods typically use downsampling or clustering to reduce the number of image tokens within the current block, which relies on strong priors or is unfriendly to parallel computation. LeMeViT instead learns meta tokens that sparsely represent the dense image tokens. The meta tokens exchange information with the image tokens in an end-to-end manner through the computationally efficient dual cross-attention block, promoting information flow across stages.
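
To see where the savings come from, here is a back-of-the-envelope count of the attention-map cost. The sizes are illustrative: a 224x224 input through the stride-4 stem below gives N = 56*56 = 3136 stage-1 tokens, d = 64 channels, and M = 16 matches the meta-token length used later in this post:

# Self-attention builds an N x N map; DCA's two branches build an N x M and
# an M x N map over the M meta tokens (M << N).
N, M, d = 3136, 16, 64
self_attn = N * N * d         # q @ k^T multiplications in standard self-attention
dca = 2 * N * M * d           # the two cross-attention maps in DCA
print(f"{self_attn / dca:.0f}x")  # ~98x fewer attention-map multiplications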


Overall structure of LeMeViT: the model is built from three kinds of attention blocks, arranged from left to right as the cross-attention block, the dual cross-attention block, and the standard attention block. In the code below these correspond to the attn_type values "C", "D"/"D2", and "S" respectively.

Implementation in code:
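
The snippets below rely on a handful of imports plus the capability flags (has_flash_attn, has_xformers, has_torchfunc) that the attention classes test against. A minimal sketch of that preamble, assuming the usual try/except guards (the exact import paths in the official repo may differ):

import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models import register_model
from timm.models.layers import DropPath, trunc_normal_
from timm.models.vision_transformer import _cfg

try:  # optional FlashAttention kernels (CUDA, fp16/bf16 only)
    from flash_attn import flash_attn_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func
    has_flash_attn = True
except ImportError:
    has_flash_attn = False

try:  # optional xFormers memory-efficient attention
    import xformers.ops as xops
    has_xformers = True
except ImportError:
    has_xformers = False

has_torchfunc = hasattr(F, "scaled_dot_product_attention")  # torch >= 2.0

try:  # only needed when use_checkpoint_stages is non-empty
    from fairscale.nn.checkpoint import checkpoint_wrapper
except ImportError:
    checkpoint_wrapper = lambda module: module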

def scaled_dot_product_attention(q, k, v, scale=None):
    """Pure-PyTorch fallback for scaled dot-product attention.

    Expects q, k, v of shape (B, h, N, d); returns (B, h, N, d).
    """
    _, _, _, dim = q.shape
    scale = scale or dim ** (-0.5)          # default 1/sqrt(head_dim)
    attn = q @ k.transpose(-1, -2) * scale  # (B, h, N_q, N_k)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x
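
As a quick sanity check, this fallback should agree with torch's fused kernel (assuming torch >= 2.0):

q = torch.randn(1, 2, 5, 16)
k = torch.randn(1, 2, 7, 16)
v = torch.randn(1, 2, 7, 16)
assert torch.allclose(scaled_dot_product_attention(q, k, v),
                      F.scaled_dot_product_attention(q, k, v), atol=1e-5)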


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        """
        x: NHWC tensor
        """
        x = x.permute(0, 3, 1, 2) #NCHW
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) #NHWC

        return x
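
DWConv is only wired into the MLP when mlp_dwconv=True and expects channels-last input; note that with the flattened (N, HW, C) token layout used in LeMeBlock it would first need reshaping back to (N, H, W, C), and all configs in this post keep mlp_dwconv=False. A toy call with hypothetical shapes:

y = DWConv(dim=32)(torch.randn(2, 14, 14, 32))  # NHWC in, NHWC out: (2, 14, 14, 32)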

class Attention(nn.Module):
    """Patch-to-Cluster Attention Layer"""
    
    def __init__(
        self,
        dim,
        num_heads,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads

        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.attn_drop = attn_drop

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity() 

    def forward(self, x):
        if self.use_xformers:
            q = self.q(x)  # B N C
            k = self.k(x)  # B N C
            v = self.v(x)
            q = rearrange(q, "B N (h d) -> B N h d", h=self.num_heads)
            k = rearrange(k, "B N (h d) -> B N h d", h=self.num_heads)
            v = rearrange(v, "B N (h d) -> B N h d", h=self.num_heads)

            x = xops.memory_efficient_attention(q, k, v)  # B N h d
            x = rearrange(x, "B N h d -> B N (h d)")

            x = self.proj(x)
        else:
            x = rearrange(x, "B N C -> N B C")

            x, attn = F.multi_head_attention_forward(
                query=x,
                key=x,
                value=x,
                embed_dim_to_check=x.shape[-1],
                num_heads=self.num_heads,
                q_proj_weight=self.q.weight,
                k_proj_weight=self.k.weight,
                v_proj_weight=self.v.weight,
                in_proj_weight=None,
                in_proj_bias=torch.cat([self.q.bias, self.k.bias, self.v.bias]),
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=self.attn_drop,
                out_proj_weight=self.proj.weight,
                out_proj_bias=self.proj.bias,
                use_separate_proj_weight=True,
                training=self.training,
                need_weights=not self.training,  # for visualization
                average_attn_weights=False,
            )

            x = rearrange(x, "N B C -> B N C")

            if not self.training:
                attn = self.attn_viz(attn)

        x = self.proj_drop(x)

        return x

class StandardAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        scale = None,
        bias = False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        
        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc

        self.qkv = nn.Linear(dim, 3 * dim)
        self.attn_drop = attn_drop

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity() 
        self.scale = (dim // num_heads) ** -0.5  # per-head 1/sqrt(d); only used by the commented-out visualization
        
    # @get_local('attn_map')
    def forward(self, x):
        if self.use_flash_attn:
            qkv = self.qkv(x)
            qkv = rearrange(qkv, "B N (x h d) -> B N x h d", x=3, h=self.num_heads).contiguous()
            x = flash_attn_qkvpacked_func(qkv)
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj(x)
        elif self.use_xformers:
            qkv = self.qkv(x)
            qkv = rearrange(qkv, "B N (x h d) -> x B N h d", x=3, h=self.num_heads).contiguous()
            q, k, v = qkv[0], qkv[1], qkv[2]
            x = xops.memory_efficient_attention(q, k, v)  # B N h d
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj(x)
        elif self.use_torchfunc:
            qkv = self.qkv(x)
            qkv = rearrange(qkv, "B N (x h d) -> x B h N d", x=3, h=self.num_heads).contiguous()
            q, k, v = qkv[0], qkv[1], qkv[2]
            x = F.scaled_dot_product_attention(q, k, v)  # B h N d
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj(x)
        else:
            qkv = self.qkv(x)
            qkv = rearrange(qkv, "B N (x h d) -> x B h N d", x=3, h=self.num_heads).contiguous()
            q, k, v = qkv[0], qkv[1], qkv[2]
            x = scaled_dot_product_attention(q, k, v)  # B h N d
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj(x)
        # with torch.no_grad():
        #     attn = (q @ k.transpose(-2, -1)) * self.scale
        #     attn_map = attn.softmax(dim=-1)
        #     # print("Standard:", attn_map)
        x = self.proj_drop(x)
        return x
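
Note the backend dispatch order shared by the attention classes here: FlashAttention if installed, then xFormers, then torch's fused F.scaled_dot_product_attention, and finally the pure-PyTorch fallback, so the same module runs everywhere while picking the fastest available kernel.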


class DualCrossAttention(nn.Module):
    
    def __init__(
        self,
        dim,
        num_heads,
        scale = None,
        bias = False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.scale = scale or dim**(-0.5)

        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc

        self.qkv1 = nn.Linear(dim, 3 * dim)
        self.qkv2 = nn.Linear(dim, 3 * dim)
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj_x = nn.Linear(dim, dim)
        self.proj_c = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity() 
        
    # @get_local('attn_map')
    def forward(self, x, c):
        B, N, C = x.shape        
        B, M, _ = c.shape 
        scale_x = math.log(M, N) * self.scale  # temper the scale by log_N(M)
        scale_c = math.log(N, N) * self.scale  # log_N(N) == 1, so this is just self.scale
        
        if self.use_flash_attn:
            qkv1 = self.qkv1(x)
            qkv1 = rearrange(qkv1, "B N (x h d) -> B N x h d", x=3, h=self.num_heads).contiguous()
            qkv2 = self.qkv2(c)
            qkv2 = rearrange(qkv2, "B M (x h d) -> B M x h d", x=3, h=self.num_heads).contiguous()
            
            q1, kv1 = qkv1[:,:,0], qkv1[:,:,1:]
            q2, kv2 = qkv2[:,:,0], qkv2[:,:,1:]
            
            x = flash_attn_kvpacked_func(q1, kv2, softmax_scale=scale_x)
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = flash_attn_kvpacked_func(q2, kv1, softmax_scale=scale_c)
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        elif self.use_xformers:
            qkv1 = self.qkv1(x)
            qkv1 = rearrange(qkv1, "B N (x h d) -> x B N h d", x=3, h=self.num_heads).contiguous()
            qkv2 = self.qkv2(c)
            qkv2 = rearrange(qkv2, "B M (x h d) -> x B M h d", x=3, h=self.num_heads).contiguous()
            
            q1, k1, v1 = qkv1[0], qkv1[1], qkv1[2]
            q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]
            
            x = xops.memory_efficient_attention(q1, k2, v2, scale=scale_x)  # B N h d
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = xops.memory_efficient_attention(q2, k1, v1, scale=scale_c)  # B M h d
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        elif self.use_torchfunc:
            qkv1 = self.qkv1(x)
            qkv1 = rearrange(qkv1, "B N (x h d) -> x B h N d", x=3, h=self.num_heads).contiguous()
            qkv2 = self.qkv2(c)
            qkv2 = rearrange(qkv2, "B M (x h d) -> x B h M d", x=3, h=self.num_heads).contiguous()
            
            q1, k1, v1 = qkv1[0], qkv1[1], qkv1[2]
            q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]
            
            x = F.scaled_dot_product_attention(q1, k2, v2)  # B h N d
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = F.scaled_dot_product_attention(q2, k1, v1)  # B h M d
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        else:
            qkv1 = self.qkv1(x)
            qkv1 = rearrange(qkv1, "B N (x h d) -> x B h N d", x=3, h=self.num_heads).contiguous()
            qkv2 = self.qkv2(c)
            qkv2 = rearrange(qkv2, "B M (x h d) -> x B h M d", x=3, h=self.num_heads).contiguous()
            
            q1, k1, v1 = qkv1[0], qkv1[1], qkv1[2]
            q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]
            
            x = scaled_dot_product_attention(q1, k2, v2, scale=scale_x)  # B h N d
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = scaled_dot_product_attention(q2, k1, v1, scale=scale_c)  # B h M d
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        # with torch.no_grad():   
        #     # q1 = rearrange(q1, "B h M d -> B M (h d)").contiguous()
        #     # k2 = rearrange(k2, "B h M d -> B M (h d)").contiguous()
        #     attn = (q1 @ k2.transpose(-2, -1)) * scale_x
        #     attn_map = attn.softmax(dim=-1)
        #     # print("Mix:", attn_map)
        x = self.proj_drop(x)
        c = self.proj_drop(c)
        return x, c
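
A quick shape check of the dual branch (hypothetical sizes; best run without flash-attn installed or on GPU, since the FlashAttention path requires CUDA fp16/bf16):

dca = DualCrossAttention(dim=64, num_heads=2)
x = torch.randn(2, 196, 64)  # image tokens (B, N, C)
c = torch.randn(2, 16, 64)   # meta tokens  (B, M, C)
x, c = dca(x, c)
print(x.shape, c.shape)      # torch.Size([2, 196, 64]) torch.Size([2, 16, 64])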
    
class DualCrossAttention_v2(nn.Module):
    
    def __init__(
        self,
        dim,
        num_heads,
        scale = None,
        bias = False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.scale = scale or dim**(-0.5)

        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc
        
        self.qv1 = nn.Linear(dim, 2 * dim)
        self.kv2 = nn.Linear(dim, 2 * dim)
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj_x = nn.Linear(dim, dim)
        self.proj_c = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity() 

    def forward(self, x, c):
        B, N, C = x.shape        
        B, M, _ = c.shape 
        scale_x = math.log(M, N) * self.scale  # log_N(M) temperature, as in DualCrossAttention
        scale_c = math.log(N, N) * self.scale  # log_N(N) == 1, i.e. just self.scale
        
        if self.use_flash_attn:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> B N x h d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> B M x h d", x=2, h=self.num_heads).contiguous()
            
            q, v1 = qv1[:,:,0], qv1[:,:,1]
            k, v2 = kv2[:,:,0], kv2[:,:,1]
            
            x = flash_attn_func(q, k, v2, softmax_scale=scale_x)
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = flash_attn_func(k, q, v1, softmax_scale=scale_c)
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        elif self.use_xformers:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> x B N h d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> x B M h d", x=2, h=self.num_heads).contiguous()

            q, v1 = qv1[0], qv1[1]
            k, v2 = kv2[0], kv2[1]

            x = xops.memory_efficient_attention(q, k, v2, scale=scale_x)  # xFormers expects (B, N, h, d)
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = xops.memory_efficient_attention(k, q, v1, scale=scale_c)
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        elif self.use_torchfunc:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> x B h M d", x=2, h=self.num_heads).contiguous()
            
            q, v1 = qv1[0], qv1[1]
            k, v2 = kv2[0], kv2[1]

            x = F.scaled_dot_product_attention(q, k, v2)
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = F.scaled_dot_product_attention(k, q, v1)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        else:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> x B h M d", x=2, h=self.num_heads).contiguous()
            
            q, v1 = qv1[0], qv1[1]
            k, v2 = kv2[0], kv2[1]

            x = scaled_dot_product_attention(q, k, v2, scale=scale_x)
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = scaled_dot_product_attention(k, q, v1, scale=scale_c)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        x = self.proj_drop(x)
        c = self.proj_drop(c)
        return x, c
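
Compared with the first version, DualCrossAttention_v2 projects only two matrices per stream (qv1 from the image tokens, kv2 from the meta tokens) and reuses the single q/k pair in both directions, so one branch's attention scores are effectively the transpose of the other's; this trims the per-stream projection cost from 3x dim to 2x dim.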

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        scale = None,
        bias = False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads

        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2 * dim)
        self.attn_drop = attn_drop

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity() 


    def forward(self, x, c):
        B, N, C = x.shape        
        B, M, _ = c.shape 
        
        if self.use_flash_attn:
            q = self.q(c)
            kv = self.kv(x)
            q = rearrange(q, "B M (h d) -> B M h d", h=self.num_heads).contiguous()
            kv = rearrange(kv, "B N (x h d) -> B N x h d", x=2, h=self.num_heads).contiguous()
            
            c = flash_attn_kvpacked_func(q, kv)
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj(c)
        elif self.use_xformers:
            q = self.q(c)
            kv = self.kv(x)
            q = rearrange(q, "B M (h d) -> B M h d", h=self.num_heads).contiguous()
            kv = rearrange(kv, "B N (x h d) -> x B N h d", x=2, h=self.num_heads).contiguous()
            k, v = kv[0], kv[1]
            
            c = xops.memory_efficient_attention(q, k, v)
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj(c)
        elif self.use_torchfunc:
            q = self.q(c)
            kv = self.kv(x)
            q = rearrange(q, "B M (h d) -> B h M d", h=self.num_heads).contiguous()
            kv = rearrange(kv, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            k, v = kv[0], kv[1]
            
            c = F.scaled_dot_product_attention(q, k, v)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj(c)
        else:
            q = self.q(c)
            kv = self.kv(x)
            q = rearrange(q, "B M (h d) -> B h M d", h=self.num_heads).contiguous()
            kv = rearrange(kv, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            k, v = kv[0], kv[1]
            
            c = scaled_dot_product_attention(q, k, v)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj(c)
        c = self.proj_drop(c)
        return c
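
In this block only the meta tokens are updated: the queries come from c while the keys and values come from the image tokens x. This is the cross-attention used to initialize the meta tokens from the image tokens in the first stage (attn_type "C").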


class LeMeBlock(nn.Module):
    def __init__(self, dim, 
                 attn_drop, proj_drop, drop_path=0., attn_type=None,
                 layer_scale_init_value=-1, num_heads=8, qk_dim=None, mlp_ratio=4, mlp_dwconv=False,
                 cpe_ks=3, pre_norm=True):
        super().__init__()
        qk_dim = qk_dim or dim

        # modules
        if cpe_ks > 0:
            self.pos_embed = nn.Conv2d(dim, dim, kernel_size=cpe_ks, padding=cpe_ks // 2, groups=dim)
        else:
            self.pos_embed = lambda x: 0
        self.norm1 = nn.LayerNorm(dim, eps=1e-6) # important to avoid attention collapsing
        
        self.attn_type = attn_type
        if attn_type == "D":
            self.attn = DualCrossAttention(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        elif attn_type == "D2":
            self.attn = DualCrossAttention_v2(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        elif attn_type == "S" or attn_type == None:
            self.attn = StandardAttention(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        elif attn_type == "C":
            self.attn = CrossAttention(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
            
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = nn.Sequential(nn.Linear(dim, int(mlp_ratio*dim)),
                                 DWConv(int(mlp_ratio*dim)) if mlp_dwconv else nn.Identity(),
                                 nn.GELU(),
                                 nn.Linear(int(mlp_ratio*dim), dim)
                                )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # tricks: layer scale & pre_norm/post_norm
        if layer_scale_init_value > 0:
            self.use_layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones((1,1,dim)), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((1,1,dim)), requires_grad=True)
        else:
            self.use_layer_scale = False
        self.pre_norm = pre_norm
            
    def forward_with_xc(self, x, c):

        _, C, H, W = x.shape
        # conv pos embedding
        x = x + self.pos_embed(x)
        # flatten to a (N, HW, C) token sequence for attention & mlp
        x = rearrange(x, "N C H W -> N (H W) C")
        # x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)

        # attention & mlp
        if self.pre_norm:
            if self.use_layer_scale:
                _x, _c = self.attn(self.norm1(x), self.norm1(c))
                x = x + self.drop_path(self.gamma1 * _x) # (N, H, W, C)
                x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) # (N, H, W, C)
                c = c + self.drop_path(self.gamma1 * _c) # (N, H, W, C)
                c = c + self.drop_path(self.gamma2 * self.mlp(self.norm2(c))) # (N, H, W, C)
            else:
                _x, _c = self.attn(self.norm1(x), self.norm1(c))
                x = x + self.drop_path(_x) # (N, H, W, C)
                x = x + self.drop_path(self.mlp(self.norm2(x))) # (N, H, W, C)
                c = c + self.drop_path(_c) # (N, H, W, C)
                c = c + self.drop_path(self.mlp(self.norm2(c))) # (N, H, W, C)
        else: # https://kexue.fm/archives/9009
            if self.use_layer_scale:
                _x, _c = self.attn(x,c)
                x = self.norm1(x + self.drop_path(self.gamma1 * _x)) # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.gamma2 * self.mlp(x))) # (N, H, W, C)
                c = self.norm1(c + self.drop_path(self.gamma1 * _c)) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.gamma2 * self.mlp(c))) # (N, H, W, C)
            else:
                _x, _c = self.attn(x,c)
                x = self.norm1(x + self.drop_path(_x)) # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.mlp(x))) # (N, H, W, C)
                c = self.norm1(c + self.drop_path(_c)) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.mlp(c))) # (N, H, W, C)
                
        x = rearrange(x, "N (H W) C -> N C H W",H=H,W=W)
        # permute back
        # x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        return x, c

    def forward_with_c(self, x, c):
        
        _, C, H, W = x.shape
        _x = x
        # conv pos embedding
        x = x + self.pos_embed(x)
        # flatten to a (N, HW, C) token sequence for attention & mlp
        x = rearrange(x, "N C H W -> N (H W) C")
        # x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)

        # attention & mlp
        if self.pre_norm:
            if self.use_layer_scale:
                c = c + self.drop_path(self.gamma1 * self.attn(self.norm1(x), self.norm1(c))) # (N, H, W, C)
                c = c + self.drop_path(self.gamma2 * self.mlp(self.norm2(c))) # (N, H, W, C)
            else:
                c = c + self.drop_path(self.attn(self.norm1(x),self.norm1(c))) # (N, H, W, C)
                c = c + self.drop_path(self.mlp(self.norm2(c))) # (N, H, W, C)
        else: # https://kexue.fm/archives/9009
            if self.use_layer_scale:
                c = self.norm1(c + self.drop_path(self.gamma1 * self.attn(x,c))) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.gamma2 * self.mlp(c))) # (N, H, W, C)
            else:
                c = self.norm1(c + self.drop_path(self.attn(x,c))) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.mlp(c))) # (N, H, W, C)

        x = _x
        # permute back
        # x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        return x, c

    def forward_with_x(self, x, c):
        
        _, C, H, W = x.shape
        # conv pos embedding
        x = x + self.pos_embed(x)
        # flatten to a (N, HW, C) token sequence for attention & mlp
        x = rearrange(x, "N C H W -> N (H W) C")
        # x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)

        # attention & mlp
        if self.pre_norm:
            if self.use_layer_scale:
                x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) # (N, H, W, C)
                x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) # (N, H, W, C)
                c = c + self.drop_path(self.gamma1 * self.attn(self.norm1(c))) # (N, H, W, C)
                c = c + self.drop_path(self.gamma2 * self.mlp(self.norm2(c))) # (N, H, W, C)
            else:
                x = x + self.drop_path(self.attn(self.norm1(x))) # (N, H, W, C)
                x = x + self.drop_path(self.mlp(self.norm2(x))) # (N, H, W, C)
                c = c + self.drop_path(self.attn(self.norm1(c))) # (N, H, W, C)
                c = c + self.drop_path(self.mlp(self.norm2(c))) # (N, H, W, C)
        else: # https://kexue.fm/archives/9009
            if self.use_layer_scale:
                x = self.norm1(x + self.drop_path(self.gamma1 * self.attn(x))) # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.gamma2 * self.mlp(x))) # (N, H, W, C)
                c = self.norm1(c + self.drop_path(self.gamma1 * self.attn(c))) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.gamma2 * self.mlp(c))) # (N, H, W, C)
            else:
                x = self.norm1(x + self.drop_path(self.attn(x))) # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.mlp(x))) # (N, H, W, C)
                c = self.norm1(c + self.drop_path(self.attn(c))) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.mlp(c))) # (N, H, W, C)
        x = rearrange(x, "N (H W) C -> N C H W",H=H,W=W)
        # permute back
        # x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        return x, c
    
    def forward(self, x, c):
        if self.attn_type == "D" or self.attn_type == "D2":
            return self.forward_with_xc(x,c)
        elif self.attn_type == "S":
            return self.forward_with_x(x,c)
        elif self.attn_type == "C":
            return self.forward_with_c(x,c)
        else:
            raise NotImplementedError("Attention type does not exist")
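
A minimal exercise of the block dispatch (hypothetical sizes; attn_type="D" routes to forward_with_xc, which updates both streams):

blk = LeMeBlock(dim=64, attn_drop=0., proj_drop=0., attn_type="D", num_heads=2)
x = torch.randn(2, 64, 14, 14)  # image tokens as an N C H W feature map
c = torch.randn(2, 16, 64)      # meta tokens as a (B, M, C) sequence
x, c = blk(x, c)                # -> (2, 64, 14, 14), (2, 16, 64)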


class LeMeViT(nn.Module):
    def __init__(self, 
                 depth=[2, 3, 4, 8, 3], 
                 in_chans=3, 
                 num_classes=1000, 
                 embed_dim=[64, 64, 128, 320, 512], 
                 head_dim=64, 
                 mlp_ratios=[4, 4, 4, 4, 4], 
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop=0., 
                 drop_path_rate=0.,
                 # <<<------
                 attn_type=["C","D","D","S","S"],
                 queries_len=128,
                 qk_dims=None,
                 cpe_ks=3,
                 pre_norm=True,
                 mlp_dwconv=False,
                 representation_size=None,
                 layer_scale_init_value=-1,
                 use_checkpoint_stages=[],
                 # ------>>>
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        qk_dims = qk_dims or embed_dim
        
        self.num_stages = len(attn_type)
        
        ############ downsample layers (patch embeddings) ######################
        self.downsample_layers = nn.ModuleList()
        # NOTE: uniformer uses two 3*3 conv, while in many other transformers this is one 7*7 conv 
        stem = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim[0] // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(embed_dim[0] // 2),
            nn.GELU(),
            nn.Conv2d(embed_dim[0] // 2, embed_dim[0], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(embed_dim[0]),
        )

        if use_checkpoint_stages:
            stem = checkpoint_wrapper(stem)
        self.downsample_layers.append(stem)

        for i in range(self.num_stages-1):
            if attn_type[i] == "C":
                downsample_layer = nn.Identity()
            else:
                downsample_layer = nn.Sequential(
                    nn.Conv2d(embed_dim[i], embed_dim[i+1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
                    nn.BatchNorm2d(embed_dim[i+1])
                )
            if use_checkpoint_stages:
                downsample_layer = checkpoint_wrapper(downsample_layer)
            self.downsample_layers.append(downsample_layer)
        ##########################################################################


        #TODO: maybe remove last LN
        self.queries_len = queries_len
        self.meta_tokens = nn.Parameter(torch.randn(self.queries_len, embed_dim[0]), requires_grad=True)
        
        self.meta_token_downsample = nn.ModuleList()
        meta_token_downsample = nn.Sequential(
            nn.Linear(embed_dim[0], embed_dim[0] * 4),
            nn.LayerNorm(embed_dim[0] * 4),
            nn.GELU(),
            nn.Linear(embed_dim[0] * 4, embed_dim[0]),
            nn.LayerNorm(embed_dim[0])
        )
        self.meta_token_downsample.append(meta_token_downsample)
        for i in range(self.num_stages-1):
            meta_token_downsample = nn.Sequential(
                nn.Linear(embed_dim[i], embed_dim[i] * 4),
                nn.LayerNorm(embed_dim[i] * 4),
                nn.GELU(),
                nn.Linear(embed_dim[i] * 4, embed_dim[i+1]),
                nn.LayerNorm(embed_dim[i+1])
            )
            self.meta_token_downsample.append(meta_token_downsample)

        
        self.stages = nn.ModuleList() # feature-resolution stages, each consisting of multiple LeMeBlocks
        nheads= [dim // head_dim for dim in qk_dims]
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] 
        cur = 0
        for i in range(self.num_stages):
            stage = nn.ModuleList(
                [LeMeBlock(dim=embed_dim[i], 
                           attn_drop=attn_drop, proj_drop=drop_rate,
                           drop_path=dp_rates[cur + j],
                           attn_type=attn_type[i],
                           layer_scale_init_value=layer_scale_init_value,
                           num_heads=nheads[i],
                           qk_dim=qk_dims[i],
                           mlp_ratio=mlp_ratios[i],
                           mlp_dwconv=mlp_dwconv,
                           cpe_ks=cpe_ks,
                           pre_norm=pre_norm
                    ) for j in range(depth[i])],
            )
            if i in use_checkpoint_stages:
                stage = checkpoint_wrapper(stage)
            self.stages.append(stage)
            cur += depth[i]

        ##########################################################################
        self.norm = nn.BatchNorm2d(embed_dim[-1])
        self.norm_c = nn.LayerNorm(embed_dim[-1])
        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim[-1], representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, c):
        for i in range(self.num_stages): 
            x = self.downsample_layers[i](x)
            c = self.meta_token_downsample[i](c)
            for j, block in enumerate(self.stages[i]):
                x, c = block(x, c)
        x = self.norm(x)
        x = self.pre_logits(x)
        
        c = self.norm_c(c)
        c = self.pre_logits(c)

        # x = x.flatten(2).mean(-1,keepdim=True)
        # c = c.transpose(-2,-1).contiguous().mean(-1,keepdim=True)
        # x = torch.concat([x,c],dim=-1).mean(-1)

        x = x.flatten(2).mean(-1)
        c = c.transpose(-2,-1).contiguous().mean(-1)
        x = x + c

        return x

    def forward(self, x):
        B, _, H, W = x.shape 
        c = self.meta_tokens.repeat(B,1,1)
        x = self.forward_features(x, c)
        x = self.head(x)
        return x
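
Putting the pieces together, a forward-pass smoke test on random data (a sketch with a deliberately shallow depth; the widths follow the tiny config below):

model = LeMeViT(depth=[1, 1, 1, 1, 1],
                embed_dim=[64, 64, 128, 192, 320],
                head_dim=32,
                attn_type=["C", "D", "D", "S", "S"],
                queries_len=16,
                num_classes=10)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 10])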

On top of this overall architecture, the authors design three model sizes: Tiny, Small, and Base. The sizes are customized by adjusting the number of blocks and the feature dimensions at each stage, while all other configurations are shared across variants: the head dimension of every attention layer is set to 32, the MLP expansion ratio to 4, the conditional position encoding kernel size to 3, and the length of the meta tokens to 16.
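
For reference, the three registered variants below differ only in depth and stage widths:

variant   depth              embed_dim
tiny      [1, 2, 2, 8, 2]    [64, 64, 128, 192, 320]
small     [1, 2, 2, 6, 2]    [96, 96, 192, 320, 384]
base      [2, 4, 4, 18, 4]   [96, 96, 192, 384, 512]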

@register_model
def lemevit_tiny(pretrained=False, pretrained_cfg=None,
                  pretrained_cfg_overlay=None, **kwargs):
    model = LeMeViT(
        depth=[1, 2, 2, 8, 2],
        embed_dim=[64, 64, 128, 192, 320], 
        head_dim=32,
        mlp_ratios=[4, 4, 4, 4, 4],
        attn_type=["C","D","D","S","S"],
        queries_len=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.,
        qk_dims=None,
        cpe_ks=3,
        pre_norm=True,
        mlp_dwconv=False,
        representation_size=None,
        layer_scale_init_value=-1,
        use_checkpoint_stages=[],
        **kwargs)
    model.default_cfg = _cfg()

    if pretrained:
        checkpoint = torch.load(pretrained, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    return model


@register_model
def lemevit_small(pretrained=False, pretrained_cfg=None,
                  pretrained_cfg_overlay=None, **kwargs):
    model = LeMeViT(
        depth=[1, 2, 2, 6, 2],
        embed_dim=[96, 96, 192, 320, 384], 
        head_dim=32,
        mlp_ratios=[4, 4, 4, 4, 4],
        attn_type=["C","D","D","S","S"],
        queries_len=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.,
        qk_dims=None,
        cpe_ks=3,
        pre_norm=True,
        mlp_dwconv=False,
        representation_size=None,
        layer_scale_init_value=-1,
        use_checkpoint_stages=[],
        **kwargs)
    model.default_cfg = _cfg()

    if pretrained:
        checkpoint = torch.load(pretrained, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    return model


@register_model
def lemevit_base(pretrained=False, pretrained_cfg=None,
                  pretrained_cfg_overlay=None, **kwargs):
    model = LeMeViT(
        depth=[2, 4, 4, 18, 4],
        embed_dim=[96, 96, 192, 384, 512], 
        head_dim=32,
        mlp_ratios=[4, 4, 4, 4, 4],
        attn_type=["C","D","D","S","S"],
        queries_len=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.,
        qk_dims=None,
        cpe_ks=3,
        pre_norm=True,
        mlp_dwconv=False,
        representation_size=None,
        layer_scale_init_value=-1,
        use_checkpoint_stages=[],
        **kwargs)
    model.default_cfg = _cfg()

    if pretrained:
        checkpoint = torch.load(pretrained, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    return model

Performance test

Now let's briefly try the model on image classification; I pick the tiny and small variants. I use a self-built grape dataset with multiple lighting conditions: 4 classes under 3 lighting conditions, 12 subsets in total, split 8:1 into training and test sets.

My training code:

import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.utils import accuracy, AverageMeter, ModelEma
from sklearn.metrics import classification_report
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from models.lemevit import lemevit_small_v2
from torchvision import datasets
torch.backends.cudnn.benchmark = False
import warnings
warnings.filterwarnings("ignore")
os.environ['CUDA_VISIBLE_DEVICES']="0,1"
import pandas as pd
from torchvision.transforms import RandAugment
# training loop
def train(model, device, train_loader, optimizer, epoch,model_ema):
    model.train()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        samples, targets = mixup_fn(data, target)
        optimizer.zero_grad()
        if use_amp:
            with torch.cuda.amp.autocast():
                output = model(samples)  # forward pass under autocast so AMP takes effect
                loss = torch.nan_to_num(criterion_train(output, targets))
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale gradients before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(samples)
            loss = criterion_train(output, targets)
            loss.backward()
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)
        torch.cuda.synchronize()
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        loss_meter.update(loss.item(), target.size(0))
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        acc1_meter.update(acc1.item(), target.size(0))
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
    ave_loss =loss_meter.avg
    acc = acc1_meter.avg
    print('epoch:{}\tloss:{:.2f}\tacc:{:.2f}'.format(epoch, ave_loss, acc))
    return ave_loss, acc


# validation loop
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    val_list = []
    pred_list = []

    for data, target in test_loader:
        for t in target:
            val_list.append(t.data.item())
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        output = model(data)
        loss = criterion_val(output, target)
        _, pred = torch.max(output.data, 1)
        for p in pred:
            pred_list.append(p.data.item())
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))
    acc = acc1_meter.avg
    print('\nVal set: Average loss: {:.4f}\tAcc1:{:.3f}%\tAcc5:{:.3f}%\n'.format(
        loss_meter.avg, acc, acc5_meter.avg))
    if acc > Best_ACC:
        if isinstance(model, torch.nn.DataParallel):
            torch.save(model.module, file_dir + '/' + 'best.pth')
        else:
            torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    if isinstance(model, torch.nn.DataParallel):
        state = {

            'epoch': epoch,
            'state_dict': model.module.state_dict(),
            'Best_ACC': Best_ACC
        }
        if use_ema:
            state['state_dict_ema'] = model.module.state_dict()
        torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
    else:
        state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'Best_ACC': Best_ACC
        }
        if use_ema:
            state['state_dict_ema'] = model.state_dict()
        torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
    return val_list, pred_list, loss_meter.avg, acc


def seed_everything(seed=0):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    file_dir = 'checkpoints/LEMEVIT-small/'
    os.makedirs(file_dir, exist_ok=True)

    model_lr = 1e-3
    BATCH_SIZE = 16
    EPOCHS = 50
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    use_amp =True
    use_dp = True
    classes = 4
    resume =None
    CLIP_GRAD = 5.0
    Best_ACC = 0
    use_ema=False
    model_ema_decay=0.9995
    start_epoch=1
    seed=1
    seed_everything(seed)
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std= [0.5, 0.5, 0.5])

    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std= [0.5, 0.5, 0.5])
    ])
    mixup_fn = Mixup(
        mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=0.1, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=classes)

    dataset_train = datasets.ImageFolder('dataset/train', transform=transform)
    dataset_test = datasets.ImageFolder("dataset/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True,drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

    criterion_train = SoftTargetCrossEntropy()
    criterion_val = torch.nn.CrossEntropyLoss()
    model_ft = lemevit_small_v2(pretrained=False)
    num_fr=model_ft.head.in_features
    model_ft.head =nn.Linear(num_fr,classes)
    print(model_ft)
    if resume:
        model=torch.load(resume)
        print(model['state_dict'].keys())
        model_ft.load_state_dict(model['state_dict'],strict = False)
        Best_ACC=model['Best_ACC']
        start_epoch=model['epoch']+1
    model_ft.to(DEVICE)
    optimizer = optim.AdamW(model_ft.parameters(),lr=model_lr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=40, eta_min=5e-8)
    if use_amp:
        scaler = torch.cuda.amp.GradScaler()
    if torch.cuda.device_count() > 1 and use_dp:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model_ft = torch.nn.DataParallel(model_ft)
    if use_ema:
        model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device=DEVICE,
            resume=resume)
    else:
        model_ema=None

    # training and validation
    is_set_lr = False
    log_dir = {}
    train_loss_list, val_loss_list, train_acc_list, val_acc_list, epoch_list = [], [], [], [], []
    epoch_info = []
    if resume and os.path.isfile(file_dir+"result.json"):
        with open(file_dir+'result.json', 'r', encoding='utf-8') as file:
            logs = json.load(file)
            train_acc_list = logs['train_acc']
            train_loss_list = logs['train_loss']
            val_acc_list = logs['val_acc']
            val_loss_list = logs['val_loss']
            epoch_list = logs['epoch_list']
    for epoch in range(start_epoch, EPOCHS + 1):
        epoch_list.append(epoch)
        log_dir['epoch_list'] = epoch_list
        train_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch,model_ema)
        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        log_dir['train_acc'] = train_acc_list
        log_dir['train_loss'] = train_loss_list
        if use_ema:
            val_list, pred_list, val_loss, val_acc = val(model_ema.ema, DEVICE, test_loader)
        else:
            val_list, pred_list, val_loss, val_acc = val(model_ft, DEVICE, test_loader)
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)
        log_dir['val_acc'] = val_acc_list
        log_dir['val_loss'] = val_loss_list
        log_dir['best_acc'] = Best_ACC
        with open(file_dir + '/result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(log_dir))
        print(classification_report(val_list, pred_list, target_names=dataset_train.class_to_idx))
        epoch_info.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc
        })
        df = pd.DataFrame(epoch_info)
        df.to_excel(file_dir + "/epoch_info.xlsx",index=False)
        with open('epoch_info.txt', 'w') as f:
            for epoch_data in epoch_info:
                f.write(f"Epoch: {epoch_data['epoch']}\n")
                f.write(f"Train Loss: {epoch_data['train_loss']}\n")
                f.write(f"Train Acc: {epoch_data['train_acc']}\n")
                f.write(f"Val Loss: {epoch_data['val_loss']}\n")
                f.write(f"Val Acc: {epoch_data['val_acc']}\n")
                f.write("\n")
        if epoch < 600:
            cosine_schedule.step()
        else:
            if not is_set_lr:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = 1e-6
                    is_set_lr = True
        fig = plt.figure(1)
        plt.plot(epoch_list, train_loss_list, 'r-', label=u'Train Loss')
        # legend
        plt.plot(epoch_list, val_loss_list, 'b-', label=u'Val Loss')
        plt.legend(["Train Loss", "Val Loss"], loc="upper right")
        plt.xlabel(u'epoch')
        plt.ylabel(u'loss')
        plt.title('Model Loss ')
        plt.savefig(file_dir + "/loss.png")
        plt.close(1)
        fig2 = plt.figure(2)
        plt.plot(epoch_list, train_acc_list, 'g-', label=u'Train Acc')
        plt.plot(epoch_list, val_acc_list, 'y-', label=u'Val Acc')
        plt.legend(["Train Acc", "Val Acc"], loc="lower right")
        plt.title("Model Acc")
        plt.ylabel("acc")
        plt.xlabel("epoch")
        plt.savefig(file_dir + "/acc.png")
        plt.close(2)

Test code:

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.metrics import recall_score, precision_score, f1_score, accuracy_score

# check whether a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# custom dataset definition
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.image_paths = []
        self.labels = []

        for idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                self.image_paths.append(img_path)
                self.labels.append(idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# image transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# load the dataset
dataset_root = 'dataset/sunlight'
dataset = CustomDataset(root_dir=dataset_root, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

# load the model
model_path = 'checkpoints/LEMEVIT-small/best.pth'
model = torch.load(model_path)
model.to(device)
model.eval()

# prediction function
def predict(model, dataloader, device):
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return np.array(all_preds), np.array(all_labels)

# run predictions
predictions, true_labels = predict(model, dataloader, device)

# compute metrics
recall = recall_score(true_labels, predictions, average='macro')
precision = precision_score(true_labels, predictions, average='macro')
f1 = f1_score(true_labels, predictions, average='macro')
accuracy = accuracy_score(true_labels, predictions)

# class names
class_names = sorted(os.listdir(dataset_root))

print(f'Class names: {class_names}')
print(f'Recall: {recall:.4f}')
print(f'Precision: {precision:.4f}')
print(f'F1 Score: {f1:.4f}')
print(f'Accuracy: {accuracy:.4f}')

The results are as follows (the results figure from the original post is not reproduced here):

As the experiments show, LeMeViT-tiny and LeMeViT-small behave noticeably differently across lighting conditions. LeMeViT-tiny scores slightly higher than LeMeViT-small under all conditions, and is especially stable under shadow and normal lighting. Under direct sunlight both models degrade, but LeMeViT-tiny still keeps relatively high precision and recall. Overall, LeMeViT-tiny adapts better to lighting changes and is the more robust of the two. Of course, these results only hold for my dataset; feel free to run your own experiments.

Model repository: https://github.com/ViTAE-Transformer/LeMeViT

That's all for this post!

