unet中的attn_processor的修改(用于设计新的注意力模块)

news2024/11/4 21:45:39

参考资料

文章目录

  • unet中的一些变量的数据情况
    • attn_processor
    • unet.config
    • unet_sd
  • 自己定义自己的attn Processor ,对原始的attn Processor进行修改

IP-adapter中设置attn的方法
参考的代码: 腾讯ailabipadapter 的官方训练代码

unet中的一些变量的数据情况

# init adapter modules
	#用来存储自己重构后的注意力处理器字典
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
    	#如果是自注意力注意力attn1,那么设置为空,否则设置为交叉注意力的维度
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        #这里记录此时这个快的通道式
        if name.startswith("mid_block"):
        #'block_out_channels', [320, 640, 1280, 1280]
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
        #name中的,up_block.的后一个位置就是表示是第几个上块
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
            #这里是从unet_sd当中把这个交叉注意力层的原始kv权重拷贝一份出来
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            #然后这里将新构建的字典里面的attn_processor给替换为自己定义的IPAttnProcessor
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            #这里将新构建的attn模型的权重初始化为原来的SD的uent中的crossattn的权重
            attn_procs[name].load_state_dict(weights)
    #最后这里将unet的注意力处理器设置为自己重构后的注意力字典
    unet.set_attn_processor(attn_procs)

attn_processor

unet中的unet.state_dict()存储了所有attn_processor的字典
我们要做修改的话,重构一个类似的字典,然后把其中我们需要修改的模块的attn_processor的类型进行替换

我们来看一下unet.attn_processors是什么样子的
在这里插入图片描述
unet.attn_processors是一个字典,包含32个元素
它的 key 是每个处理类所在位置,并结合unet的结构以及其中中crossattn块的个数(总共2,2,2,1,3,3,3(16个块)(每个块分别有一个自注意力和一个交叉注意力模块,所以总共有32个注意力块)),
我们知道了每块的名称的命名的含义:
比如:

'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor'
down_blocks.0.(可以是0,12,)(有3个下块)
代表第一个下块

attentions.0.(可以是0,1)(每个下块有2个transformer块)
代表第一个下块中的第一个transformer块

transformer_blocks.0.
这里都是0

attn1.processor(每个transformer块有2和注意快,一个交叉注意力,一个自注意力)
代表是自注意力还是交叉注意力(attn2.代表交叉注意力层,attn1代表自注意力层)

unet.config

unet.config 是unet配置的参数

FrozenDict([('sample_size', 64),
 ('in_channels', 4), 
('out_channels', 4), 
('center_input_sample', False), ('flip_sin_to_cos', True), ('freq_shift', 0), ('down_block_types', ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']), ('mid_block_type', 'UNetMidBlock2DCrossAttn'), ('up_block_types', ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D']), 
('only_cross_attention', False), 
('block_out_channels', [320, 640, 1280, 1280]),
 ('layers_per_block', 2), ('downsample_padding', 1), ('mid_block_scale_factor', 1), ('dropout', 0.0), ('act_fn', 'silu'), ('norm_num_groups', 32), 
 ('norm_eps', 1e-05), ('cross_attention_dim', 768), ('transformer_layers_per_block', 1), ('reverse_transformer_layers_per_block', None),
  ('encoder_hid_dim', None), ('encoder_hid_dim_type', None), ('attention_head_dim', 8), ('num_attention_heads', None), ('dual_cross_attention', False), ('use_linear_projection', False), ('class_embed_type', None), ('addition_embed_type', None), ('addition_time_embed_dim', None), ('num_class_embeds', None), ('upcast_attention', False), ('resnet_time_scale_shift', 'default'), ('resnet_skip_time_act', False), ('resnet_out_scale_factor', 1.0), ('time_embedding_type', 'positional'), ('time_embedding_dim', None), ('time_embedding_act_fn', None), ('timestep_post_act', None), ('time_cond_proj_dim', None), ('conv_in_kernel', 3), ('conv_out_kernel', 3), ('projection_class_embeddings_input_dim', None), ('attention_type', 'default'), ('class_embeddings_concat', False), ('mid_block_only_cross_attention', None), ('cross_attention_norm', None), ('addition_embed_type_num_heads', 64), ('_use_default_values', ['addition_embed_type', 'encoder_hid_dim', 'transformer_layers_per_block', 'addition_embed_type_num_heads', 'upcast_attention', 'conv_in_kernel', 'attention_type', 'resnet_out_scale_factor', 'time_embedding_dim', 'time_embedding_act_fn', 'conv_out_kernel', 'reverse_transformer_layers_per_block', 'mid_block_type', 'class_embeddings_concat', 'time_embedding_type', 'use_linear_projection', 'class_embed_type', 'only_cross_attention', 'resnet_time_scale_shift', 'encoder_hid_dim_type', 'projection_class_embeddings_input_dim', 'dual_cross_attention', 'addition_time_embed_dim', 'cross_attention_norm', 'dropout', 'timestep_post_act', 'resnet_skip_time_act', 'num_attention_heads', 'time_cond_proj_dim', 'mid_block_only_cross_attention', 'num_class_embeds']), ('_class_name', 'UNet2DConditionModel'), ('_diffusers_version', '0.6.0'), ('_name_or_path', '/media/dell/DATA/RK/pretrained_model/stable-diffusion-v1-5')])

unet_sd

这里面是一个字典,包含了所有层的各个小模块的权重
在这里插入图片描述
这里是从unet_sd当中把这个交叉注意力层的原始kv权重拷贝一份出来,用于初始化自己设计的注意力处理器

 weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }

查看修改后unet的attn_processors

这里将unet.attn_processors的所有values()转化为list

adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

这里是IPadapter替换后的attn processor 的情况

ModuleList(
  (0): AttnProcessor2_0()
  (1): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (2): AttnProcessor2_0()
  (3): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (4): AttnProcessor2_0()
  (5): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (6): AttnProcessor2_0()
  (7): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (8): AttnProcessor2_0()
  (9): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (10): AttnProcessor2_0()
  (11): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (12): AttnProcessor2_0()
  (13): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (14): AttnProcessor2_0()
  (15): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (16): AttnProcessor2_0()
  (17): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (18): AttnProcessor2_0()
  (19): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (20): AttnProcessor2_0()
  (21): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (22): AttnProcessor2_0()
  (23): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (24): AttnProcessor2_0()
  (25): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (26): AttnProcessor2_0()
  (27): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (28): AttnProcessor2_0()
  (29): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (30): AttnProcessor2_0()
  (31): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
)

自己定义自己的attn Processor ,对原始的attn Processor进行修改

在原始的attention_processor.py 文件中定义新的attn processor类

原始的attention_processor中的attn processor

class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

#3 ipadapter 新定义的
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapater.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        self.attn_map = ip_attention_probs
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

自己定义两个新的,然后也放如这个文件里面

class StyleAttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class LayoutAttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

然后导入这两个attn processor

from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor, \
    LayoutAttnProcessor, StyleAttnProcessor

替换后的结果如下

这里是将第三个下块,和第1个上块分别替换为layout attn 和 style attn

    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        # 第三个下块的名称开头是这个
        elif name.startswith("down_blocks.2.attentions"):
            attn_procs[name] = LayoutAttnProcessor()
        #第一个上块的名称开头是这个
        elif name.startswith("up_blocks.1.attentions"):
            attn_procs[name] = StyleAttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            attn_procs[name].load_state_dict(weights)

修改后 attn_processors 如下

{
'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
),

##  可以看到,这里的attn替换为了我们自己定义的layout  attn
 'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor': LayoutAttnProcessor(), 'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor': LayoutAttnProcessor(), 

## 可以看到,这里的attn替换为了我们自己定义的style  attn
'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': StyleAttnProcessor(), 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': StyleAttnProcessor(), 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor': StyleAttnProcessor(), 



'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'mid_block.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'mid_block.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
)
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2231224.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

客户端时间 与 服务器时间

对客户端时间和服务器有概念,但从来没有这么直观地观察过。直到有一天打开了长久未使用的mac,第一次对时间有了直观的概念: 打开之后就有了上面这样的提示“您的时钟慢了”… 我看了下电脑的时间,然后打开F12获取了下时间&#x…

VLAN高级特性:VLAN聚合

一、VLAN聚合的概述 在一般的三层交换机中,通常是采用一个VLAN对应一个VLANIF接口实现广播域之间的互通,这导致了在一些情况下造成了IP地址的浪费。 因为一个VLAN对应的子网中,子网号,子网广播地址、子网网关地址不能用作VLAN内…

Rust 力扣 - 2653. 滑动子数组的美丽值

文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 我们遍历长度为k的的窗口 因为数据范围比较小,所以我们可以通过计数排序找到窗口中第k小的数 如果小于0,则该窗口的美丽值为第k小的数如果大于等于0,则该窗口的美丽值为0 题…

2024网鼎杯青龙组wp:Crypto1

题目 附件内容如下 from Crypto.Util.number import * from secret import flag from Cryptodome.PublicKey import RSAp getPrime(512) q getPrime(512) n p * q d getPrime(299) e inverse(d,(p-1)*(q-1)) m bytes_to_long(flag) c pow(m,e,n) hint1 p >> (51…

《JVM第2课》类加载子系统(类加载器、双亲委派)

类加载系统加载类时分为三个步骤,加载、链接、初始化,下面展开介绍。 文章目录 1 类加载器1.1 引导类加载器(BootStrapClassLoader)1.2 拓展类加载器(ExtClassLoader)1.3 应用类加载器(AppClas…

记住电机原理及几个重要公式,搞清楚电机so easy

电机作为电力转换设备,在现代工业、交通以及生活中发挥着无处不在的作用。无论是微型电动机还是大型发电机,它们的工作原理均基于一定的物理学和电磁学原理。 一、电机的基本原理 电机的基本原理可以概括为电能与机械能之间的相互转换。电动机通过电流在…

软件(2)

操作系统 windows、unix、linux、dos都属于操作系统 操作系统的核心部分的主要特点是【常驻内存】 【多用户分时系统】是当今计算机操作系统中最普遍使用的一类操作系统 操作系统的主要功能是【调度】、【监控】和【维护】计算机系统 负责管理计算机中各种独立的硬件&#xff0…

深度学习常用开源数据集介绍【持续更新】

DIV2K 介绍:DIV2K是一个专为 图像超分辨率(SR) 任务设计的高质量数据集,广泛应用于计算机视觉领域的研究和开发。它包含800张高分辨率(HR)训练图像和100张高分辨率验证图像,每张图像都具有极高…

计算机图形学中向量相关知识chuizhi

一、向量加法 平行四边形法则 两个向量统一起点,构成平行四边形,对角线为向量加和的结果 三角形法则 两个向量尾首相连,从a起点连接到b终点,为向量加法的结果 多向量首尾相连的加法结果为第一个向量的起点到最后一个向量的终点…

[LitCTF 2023]只需要nc一下~-好久不见6

先nc一下,连接上 ls打开查看里面有什么文件 cat 查看里面有什么内容 这个 Dockerfile 构建了一个基于 Python 3.11 的镜像,将当前目录的文件复制到镜像的 /app 目录,设置了一个环境变量 FLAG,并将其值写入 /flag.txt 文件。工作目…

软考高级之系统架构师之安全攻防技术

攻防包括攻击和防御两部分。 攻击 安全威胁 信息系统的安全威胁来自于: 物理环境:对系统所用设备的威胁,如:自然灾害,电源故障,数据库故障,设备被盗等造成数据丢失或者信息泄露通信链路&…

VLAN间通信以及ospf配置

目录 1.基础知识介绍 1.1 什么是VLAN? 1.2 VLAN有什么用? 1.3 不同VLAN如何实现通信? 1.4 什么是路由汇总? 1.4.1 路由汇总的好处: 2. 实验 2.1 网络拓扑设计 2.2 实验配置要求 2.2.1 三层交换配置&#xff…

ChatGPT变AI搜索引擎!以后还需要谷歌吗?

前言 在北京时间11月1日凌晨,正值ChatGPT两岁生日之际,OpenAI宣布推出最新的人工智能搜索体验!具备实时网络功能!与 Google 展开直接竞争。 ChatGPT搜索的推出标志着ChatGPT成功消除了即时信息这一最后的短板。 这项新功能可供 …

使用python画一颗圣诞树

具体效果: 完整代码: import random def print_christmas_tree(height): # 打印圣诞树的顶部 for i in range(height): # 打印空格,使树居中 for j in range(height - i - 1): print(" ", end"") # 打印星号&…

省级-碳排放相关数据(1990-2022年)

关键指标: 地区:数据涵盖了中国各省级行政区,为我们提供了一个全面的视角来观察不同地区的碳排放情况。年份:数据跨越了1990年至2022年,这为我们提供了一个长期的时间序列,以观察碳排放的变化趋势。总碳排…

评估 机器学习 回归模型 的性能和准确度

回归 是一种常用的预测模型,用于预测一个连续因变量和一个或多个自变量之间的关系。 那么,最后评估 回归模型 的性能和准确度非常重要,可以帮助我们判断模型是否有效并进行改进。 接下来,和大家分享如何评估 回归模型 的性能和准…

WPF+MVVM案例实战(二十)- 制作一个雷达辐射效果的按钮

文章目录 1、案例效果2、文件创建与代码实现1、创建文件2、图标资源文件3、源代码获取1、案例效果 2、文件创建与代码实现 1、创建文件 打开 Wpf_Examples 项目,在 Views 文件夹下创建窗体界面 RadarEffactWindow.xaml 。代码功能分两个部分完成,一个是样式,一个是动画。页…

小程序配置消息推送

配置以上信息后,点击提交时, 服务器需要配置GET请求,同时验证签名,签名通过后,返回参数echo_str, 切忌: 一定转化为int类型; python fastapi实现代码如下: async def callback_file(request: …

【大模型开发指南】llamaindex配置deepseek、jina embedding及chromadb实现本地RAG及知识库(win系统、CPU适配)

说一些坑,本来之前准备用milvus,但是发现win搞不了(docker都配好了)。然后转头搞chromadb。这里面还有就是embedding一般都是本地部署,但我电脑是cpu的没法玩,我就选了jina的embedding性能较优(…

C++ STL 学习指南:带你快速掌握标准模板库

🌟快来参与讨论💬,点赞👍、收藏⭐、分享📤,共创活力社区。 🌟 大家好呀!🤗 今天我们来聊一聊 C 程序员的必备神器——STL(Standard Template Library&#xf…