DETR纯代码分享(八)position_encoding.py(models)

news2024/12/23 14:28:17

一、导入一些Python库和模块

import math
import torch
from torch import nn

from util.misc import NestedTensor

上面的代码段主要是Python代码,用于导入一些Python库和模块,以下是对每行代码的详细解释:

  1. import math: 这一行代码导入了Python的math模块,该模块提供了各种数学函数和常数,例如三角函数(sincostan)、对数函数(loglog10)以及数学常数如圆周率(math.pi)。您可以使用这些函数和常数来进行各种数学计算。

  2. import torch: 这一行代码导入了PyTorch库,PyTorch是一种流行的深度学习框架。PyTorch通常用于开发神经网络和进行机器学习研究。它提供了创建和训练神经网络、处理张量(多维数组)等功能。

  3. from torch import nn: 这一行代码从PyTorch中导入了nn模块。nn模块提供了各种神经网络层和操作,用于构建神经网络架构。例如,您可以使用nn.Linear创建一个全连接层,nn.Conv2d创建一个卷积层,以及nn.ReLU来应用修正线性单元(ReLU)激活函数。

  4. from util.misc import NestedTensor: 这一行代码从自定义模块util.misc中导入了NestedTensor类。NestedTensor不是标准的PyTorch类,它的功能取决于在util.misc模块中如何定义。

二、 PositionEmbeddingSine

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

这段代码定义了一个名为 PositionEmbeddingSine 的PyTorch模块,用于计算位置嵌入(Position Embedding)。位置嵌入通常用于将位置信息引入神经网络模型中,特别是在处理序列数据或图像数据时。

1、初始化__init__()
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats #位置特征的数量
        self.temperature = temperature #温度参数,控制嵌入的缩放
        self.normalize = normalize #是否进行位置嵌入的归一化
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi #默认的嵌入缩放尺度
        self.scale = scale

这个构造函数的主要目的是为 PositionEmbeddingSine 类的实例设置初始属性值

  • num_pos_feats: 这是一个整数,默认为64,表示要生成的位置特征的数量。位置特征是用来表示输入数据中的位置信息的向量。

  • temperature: 这是一个浮点数,默认为10000,用于控制位置嵌入的缩放。较高的温度值会导致更大的嵌入值,而较低的温度值会导致更小的嵌入值。

  • normalize: 这是一个布尔值,默认为False。如果设置为True,位置嵌入将被归一化,以确保它们在一定范围内,通常是[0, 2π]。如果设置为False,则不进行归一化。

  • scale: 这是一个浮点数,默认为None。如果未提供scale参数,它将被设置为2π。scale用于控制位置嵌入的缩放范围。如果normalize为True,那么scale将用于归一化位置嵌入的范围。

最后,如果用户提供了scale参数但未将normalize设置为True,代码会引发ValueError,以防止不一致的参数设置。

2、前向传播方法forward()
    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors #输入张量
        mask = tensor_list.mask #掩码,用于指示输入张量中哪些位置是有效的
        assert mask is not None
        not_mask = ~mask #掩码取反,用于标记哪些位置是无效的
        #计算行方向和列方向上的累计位置信息
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        #如果设置了归一化标志,对位置信息进行归一化处理        
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        #计算位置嵌入
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        #使用正弦和余弦函数来计算位置嵌入的x分量和y分量
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        #拼接位置嵌入的x分量和y分量,并将通道维度移动到正确的位置
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

这个前向传播方法接受一个名为 tensor_list 的输入参数,其中包含了输入张量 x 和掩码 mask

(1)计算掩码mask
    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors # 输入张量
        mask = tensor_list.mask # 掩码,用于指示输入张量中哪些位置是有效的
        assert mask is not None
        not_mask = ~mask # 掩码取反,用于标记哪些位置是无效的
  1. x = tensor_list.tensors: 获取输入参数 tensor_list 中的张量数据,通常是图像数据

  2. mask = tensor_list.mask: 获取 tensor_list 中的掩码信息,掩码指示了哪些位置是有效的(True)和哪些位置是无效的(False)。

  3. assert mask is not None: 确保掩码信息存在。mask 是必需的,因为它用于确定哪些位置需要计算位置嵌入。

  4. not_mask = ~mask: 使用~操作符对掩码取反,创建一个not_mask,用于标记哪些位置是无效的。not_mask 中,True 表示无效的位置,False 表示有效的位置。

(2)计算行方向/列方向上的累计位置信息
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
  1. y_embed = not_mask.cumsum(1, dtype=torch.float32): 计算行方向上的累积位置信息 y_embed。这是通过对not_mask在维度1上进行累积操作实现的,数据类型为torch.float32

  2. x_embed = not_mask.cumsum(2, dtype=torch.float32): 计算列方向上的累积位置信息 x_embed。这是通过对not_mask在维度2上进行累积操作实现的,数据类型为torch.float32

示例:用来计算列方向上的累积位置信息 x_embed,并且使用 dtype=torch.float32 指定数据类型为 32 位浮点数。让我们通过一个简单的例子来说明它的实现。假设我们有一个输入张量 x,它是一个3x4的二维张量,同时有一个掩码 mask 用来指示哪些位置是有效的(True)和哪些位置是无效的(False):

import torch

x = torch.tensor([[1, 2, 3, 4],
                           [5, 6, 7, 8],
                           [9, 10, 11, 12]], dtype=torch.float32)

mask = torch.tensor([[True, True, False, True],
                                  [False, True, True, False],
                                  [True, False, True, True]], dtype=torch.bool)

现在,我们来解释如何使用 not_mask.cumsum(2, dtype=torch.float32) 来计算列方向上的累积位置信息:

  1. not_maskmask 的取反,即标记了哪些位置是无效的(False)。not_mask 现在如下所示:

not_mask = torch.tensor([[False, False,  True, False],
                                         [ True, False, False,  True],
                                         [False,  True, False, False]], dtype=torch.bool)

  1. cumsum(2) 表示在维度2上进行累积操作。维度2是列的维度,所以我们将在每一列上执行累积操作。cumsum 是累积求和的函数。当你在一个张量上应用cumsum时,它会计算该张量中每个元素在指定维度上的累积和。在这个情况下,指定的维度是维度2,也就是列方向。

  2. cumsum 操作会计算每个位置的累积和,从左到右依次累积。得到的 x_embed 张量如下所示:

x_embed = torch.tensor( [ 0.,  0.,  1.,  0.],
                                        [ 1.,  0.,  1.,  1.],
                                        [ 1.,  1.,  1.,  1.], dtype=torch.float32)

在这个示例中,x_embed 是一个与输入张量 x 相同大小的张量,其中每个位置的值表示从该位置的列开始的累积和。

(3)归一化
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
  1. eps = 1e-6: 这一行定义了一个小的正数 eps,它是一个极小的值,通常用于数值稳定性。在计算中,它将被添加到分母中,以防止除以零的情况。

  2. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale: 这一行对行方向上的累积位置信息 y_embed 进行归一化操作。具体步骤如下:

    • y_embed[:, -1:, :] 选择每个批次中的最后一行的累积位置信息。结果形状为 (batch_size, 1, num_columns)
    • (y_embed[:, -1:, :] + eps) 在分母中将最后一行的累积位置信息与小的正数 eps 相加,以防止零除法
    • y_embed / (y_embed[:, -1:, :] + eps) 执行元素级除法,将每个位置的值除以最后一行的值(加上 eps)进行归一化。
    • * self.scale 乘以缩放因子 self.scale,以将归一化后的位置信息缩放到所需的范围。
  3. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale: 这一行对列方向上的累积位置信息 x_embed 进行类似的归一化操作,步骤与上述操作相似:

    • x_embed[:, :, -1:] 选择每个批次中的最后一列的累积位置信息。结果形状为 (batch_size, num_rows, 1)
    • (x_embed[:, :, -1:] + eps) 在分母中将最后一列的累积位置信息与小的正数 eps 相加,以防止零除法。
    • x_embed / (x_embed[:, :, -1:] + eps) 执行元素级除法,将每个位置的值除以最后一列的值(加上 eps)进行归一化。
    • * self.scale 乘以缩放因子 self.scale,以将归一化后的位置信息缩放到所需的范围。

这个归一化操作的目的是确保位置信息的范围适应模型的需求,以便模型能够更好地理解不同位置的输入数据。归一化有助于确保不同位置的位置嵌入在相似的尺度上,并提高模型的性能和泛化能力。

(4)计算位置嵌入

        #计算位置嵌入
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
  1. dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device): 这一行创建一个名为 dim_t 的张量,用于表示位置嵌入的维度。具体解释如下:

    • self.num_pos_feats: 这是一个类的属性,它指定了要使用的位置嵌入的维度数。在这里,它代表位置编码的特征维度数。例如,如果 self.num_pos_feats 设置为 64,则将生成一个包含 64 个不同特征的位置编码。

    • torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device): 这一部分使用 PyTorch 的 torch.arange 函数创建了一个张量,它包含从 0 到 self.num_pos_feats - 1 的一系列数字。这些数字将用作位置编码的特征索引。

    • dtype=torch.float32, device=x.device: 通过指定数据类型为 torch.float32 和设备为 x.device,确保 dim_t 张量的数据类型与输入张量 x 的数据类型和设备一致。

  2. dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats): 这一行计算了用于位置嵌入的温度参数 dim_t,具体步骤如下:

    • 2 * (dim_t // 2): 这一部分首先将 dim_t 中的每个元素除以 2,然后乘以 2。这样做的目的是将 dim_t 中的所有奇数索引位置的元素都设置为零,而偶数索引位置的元素保持不变。这是因为位置编码通常采用正弦和余弦函数来构建,其中奇数索引位置的元素对应于正弦函数,而偶数索引位置的元素对应于余弦函数。

    • (2 * (dim_t // 2) / self.num_pos_feats): 接着,将上一步计算的结果除以 self.num_pos_feats。这一步将确保温度参数 dim_t 在不同的位置嵌入特征之间共享,并且其值在一个合适的范围内,以适应模型的需求。

  3. pos_x = x_embed[:, :, :, None] / dim_t: 这一行计算位置嵌入的 x 分量。具体步骤如下:

    • x_embed: 这是之前计算的列方向上的累积位置信息,它的形状为 (batch_size, num_rows, num_columns)

    • x_embed[:, :, :, None]: 通过添加一个额外的维度 None,将 x_embed 的形状从 (batch_size, num_rows, num_columns) 扩展为 (batch_size, num_rows, num_columns, 1)这是为了在接下来的操作中可以对 x_embed 的每个位置进行元素级别的除法。

    • / dim_t: 执行元素级别的除法操作,将 x_embed 的每个位置的值除以对应位置的 dim_t 值。这将对位置信息进行缩放,以适应模型的需求。这将对两个张量进行广播操作,使 dim_t 在最后一个维度上被复制以匹配 x_embed 的形状。因此,pos_x 的形状将与 x_embed 保持一致,即 (batch_size, num_rows, num_columns, num_pos_feats)

  4. pos_y = y_embed[:, :, :, None] / dim_t: 这一行计算位置嵌入的 y 分量,步骤与计算 pos_x 相似,只是使用了行方向上的累积位置信息 y_embed

总之,这两行代码将原始的位置信息 x_embedy_embed 进行了归一化和缩放,得到了位置嵌入的 x 和 y 分量。这些位置嵌入将被用于表示输入数据的位置信息,并与输入数据相结合,以帮助模型更好地理解不同位置的输入信息。

(5)将位置嵌入的 x 和 y 分量进行正弦和余弦变换
        #使用正弦和余弦函数来计算位置嵌入的x分量和y分量
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)

这两行代码的目的是将位置嵌入的 x 和 y 分量进行正弦和余弦变换,并将它们合并成一个更高维度的位置嵌入。现在逐行解释:

  1. pos_x[:, :, :, 0::2]: 这一部分使用切片操作 0::2 选择 pos_x 张量的第 0、2、4、6、... 等位置的元素。这些元素对应于位置嵌入的 x 分量的正弦部分。

  2. pos_x[:, :, :, 1::2]: 同样,这一部分使用切片操作 1::2 选择 pos_x 张量的第 1、3、5、7、... 等位置的元素。这些元素对应于位置嵌入的 x 分量的余弦部分。

  3. .sin(): 对于选定的元素,应用正弦函数,将正弦变换应用于 x 分量的部分,得到一个新的张量。

  4. .cos(): 对于另一组选定的元素,应用余弦函数,将余弦变换应用于 x 分量的部分,得到另一个新的张量。

  5. torch.stack(...): 这一部分将正弦和余弦变换的结果在一个新的维度(维度 4)上堆叠在一起,创建一个新的张量。具体来说,它将正弦和余弦部分按维度 4 进行堆叠,以便后续的处理。

  6. .flatten(3): 最后,这一部分将张量在维度 3 上展平,将正弦和余弦部分合并为一个维度,得到最终的位置嵌入。

这个过程实际上是将位置嵌入的 x 和 y 分量变换为正交的正弦和余弦分量,以更好地表示位置信息。这种正弦和余弦变换常用于位置编码,有助于模型更好地捕捉序列数据中的位置关系。同样的操作也适用于 pos_y,用于计算位置嵌入的 y 分量。

(6)最终位置嵌入
        #拼接位置嵌入的x分量和y分量,并将通道维度移动到正确的位置
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

这段代码的目的是将位置嵌入的 x 分量和 y 分量拼接在一起,并重新排列维度,以获得最终的位置嵌入。让我们逐行解释:

  1. torch.cat((pos_y, pos_x), dim=3): 这一部分使用 torch.cat 函数将 pos_ypos_x 张量在维度 3 上进行拼接。因为在前面的步骤中,pos_ypos_x 表示了位置嵌入的 y 分量和 x 分量,它们的形状都是 (batch_size, num_rows, num_columns, num_pos_feats),所以在维度 3 上拼接将它们合并成一个形状为 (batch_size, num_rows, num_columns, num_pos_feats*2) 的张量。

  2. permute(0, 3, 1, 2): 接着,使用 .permute 函数重新排列维度,将维度 0(批大小)、3(通道维度)、1(行数)和2(列数)重新排列,以获得最终的位置嵌入。这个操作确保位置嵌入的维度排列与模型的期望输入一致。

最终,pos 张量将包含位置嵌入的所有信息,其形状为 (batch_size, num_pos_feats*2, num_rows, num_columns),其中 num_pos_feats 是位置编码的特征维度数,而 num_rowsnum_columns 分别是输入数据的行数和列数。这个位置嵌入张量可以与输入数据相结合,以帮助模型更好地理解输入数据中的位置关系。

(自己理解)这里图片的feature维度为256,pos 张量的维度为128(因为分了x和y方向)

三、PositionEmbeddingLearned类

class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos

这段代码实现了一个名为 PositionEmbeddingLearned 的类,用于生成学习到的绝对位置嵌入。PositionEmbeddingLearned提供了一种通过学习得到的绝对位置嵌入的方式,这些嵌入可以与输入数据结合使用,以帮助模型理解输入数据中的位置信息。

1、创建位置嵌入__init__()
class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

这段代码实现了 PositionEmbeddingLearned 类的构造函数,用于初始化位置嵌入模块。让我们逐行详细解释代码的实现:

  1. def __init__(self, num_pos_feats=256):

    • 这是构造函数的定义,它接受一个可选参数 num_pos_feats,用于指定位置嵌入的特征维度数,默认为 256
  2. super().__init__()

    • 调用了父类 nn.Module 的构造函数,确保正确初始化了 PositionEmbeddingLearned 类。
  3. self.row_embed = nn.Embedding(50, num_pos_feats)

    • 创建了一个名为 row_embed 的属性,它是一个 Embedding 层(嵌入层)。
    • nn.Embedding(50, num_pos_feats) 创建了一个 Embedding 层,该层将 50 个离散的整数作为输入,并将它们映射到一个具有 num_pos_feats 个特征维度的连续空间中。这个层将用于表示行的位置嵌入。
  4. self.col_embed = nn.Embedding(50, num_pos_feats)

    • 创建了一个名为 col_embed 的属性,它也是一个 Embedding 层,与 row_embed 类似,但用于表示列的位置嵌入。
  5. self.reset_parameters()

    • 调用了 reset_parameters 方法,用于初始化 Embedding 层的权重。

总结: 这段代码的主要功能是创建 PositionEmbeddingLearned 类的实例,并初始化两个 Embedding 层 (row_embedcol_embed) 用于表示行和列的位置嵌入。这些位置嵌入的特征维度数由构造函数的参数 num_pos_feats 控制,默认为 256。这些 Embedding 层将在后续的 forward 方法中用于获取位置嵌入的值。

2、reset_parameters()
    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

这段代码实现了 PositionEmbeddingLearned 类中的 reset_parameters 方法,该方法用于初始化 Embedding 层的权重。让我们逐行详细解释代码的实现:

  1. def reset_parameters(self):

    • 这是 reset_parameters 方法的定义,它属于 PositionEmbeddingLearned 类。
  2. nn.init.uniform_(self.row_embed.weight)

    • 这一行代码使用 PyTorch 的 nn.init 模块中的 uniform_ 函数来初始化 row_embed Embedding 层的权重。
    • self.row_embed.weight 是一个张量,表示 row_embed 层的权重矩阵uniform_ 函数会将这个权重矩阵的值初始化为均匀分布中的随机值。
    • 这样,每个行位置嵌入的权重都会以随机的初始值开始,模型在训练过程中会学习到适合任务的最佳权重值。
  3. nn.init.uniform_(self.col_embed.weight)

    • 这一行代码与上面的行类似,但是针对的是 col_embed Embedding 层的权重矩阵。
    • self.col_embed.weight 是表示 col_embed 层的权重矩阵的张量,它也会被初始化为均匀分布中的随机值。

总结: reset_parameters 方法的作用是在创建 PositionEmbeddingLearned 类的对象时初始化 row_embedcol_embed Embedding 层的权重,以确保它们有合适的初始值,模型可以在训练过程中逐渐调整这些权重以适应特定任务。这种初始化策略有助于模型的收敛和性能提升。

3、forward()
    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos

这段代码实现了 PositionEmbeddingLearned 类的 forward 方法,该方法用于计算学习到的绝对位置嵌入。

  1. x = tensor_list.tensors:

    • 获取输入 tensor_list 中的张量 x,这是要为其计算绝对位置嵌入的输入张量。形状为[batch_size,num_channels,height,width]
  2. h, w = x.shape[-2:]:

    • 获取输入张量 x 的高度和宽度,其中 x.shape[-2:] 表示取张量的倒数第二和倒数第一维度的尺寸。
  3. i = torch.arange(w, device=x.device)j = torch.arange(h, device=x.device):

    • 创建列索引 i 和行索引 j,它们分别包含了从 0 到 w-1 和从 0 到 h-1 的整数值。这些索引用于获取列和行的位置嵌入。
  4. x_emb = self.col_embed(i)y_emb = self.row_embed(j):

    • 使用 Embedding 层 self.col_embedself.row_embed 分别获取列和行的位置嵌入 x_emby_emb。这些位置嵌入是模型学习到的表示。[w,num_pos_feats]和[h,num_pos_feats ]
    • unsqueeze(0) 操作在维度 0 上添加一个维度,将 x_emb 的形状从 (w, num_pos_feats) 变为 (1, w, num_pos_feats)
    • x_emb.unsqueeze(0).repeat(h, 1, 1),它的形状是 (h, w, num_pos_feats)
    • unsqueeze(1) 操作在维度 1 上添加一个维度,将 y_emb 的形状从 (h, num_pos_feats) 变为 (h, 1, num_pos_feats)
    • y_emb.unsqueeze(1).repeat(1, w, 1),repeat(1, w, 1) 操作会沿着维度 1 复制 y_emb,重复 w 次。因此,形状将变为 (h, w, num_pos_feats),其中每个列都是相同的。

    torch.cat([...], dim=-1):

    • torch.cat([ x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1), ], dim=-1), 形状为,(h, w, num_pos_feats * 2)
    • torch.cat([ x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1), ], dim=-1).permute(2, 0, 1),形状为(num_pos_feats*2,h, w)
    • torch.cat([ x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1), ], dim=-1).permute(2, 0, 1).unsqueeze(0),形状为(1,num_pos_feats*2,h, w)
    • x.shape[0]是batch_size
    • torch.cat([ x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1), ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1),使用 .repeat 函数将位置嵌入复制多次,以匹配输入张量 x 的批大小。形状为(batch_size,num_pos_feats*2,h, w)
  5. 最后,返回计算得到的绝对位置嵌入张量 pos,它包含了输入张量 x 中每个位置的位置编码信息。

总结: forward 方法的主要任务是根据输入张量的高度和宽度,以及通过 Embedding 学习到的位置嵌入,计算并返回绝对位置嵌入。这些位置嵌入可以与输入数据结合使用,以帮助模型理解输入数据中的位置信息。

(自己理解)这里图片的feature维度为256,pos 张量的维度为256

四、build_position_encoding()函数

def build_position_encoding(args):
    N_steps = args.hidden_dim // 2  #N_steps = 128,输入是256维的向量
    #本文中分为了x方向上的编码和y方向上的编码(区分图像和词),前128维代表x方向的位置编码,后128维代表y方向的位置编码
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding

这段代码是用于构建位置编码(Position Encoding)的函数 build_position_encoding,根据输入的参数 args 中的配置选择不同的位置编码方式。让我们逐步解释代码的实现:

  1. N_steps = args.hidden_dim // 2

    • N_steps 是一个整数,表示位置编码的步数。它的值被设置为 args.hidden_dim 的一半,其中 args.hidden_dim 表示输入向量的维度,假设它为 256,因此 N_steps 将等于 128。
  2. if args.position_embedding in ('v2', 'sine'):

    • 这个条件语句检查 args.position_embedding 的值是否为 'v2''sine'
    • 如果条件成立,表示要使用正弦位置编码方式。
  3. position_embedding = PositionEmbeddingSine(N_steps, normalize=True)

    • 如果选择使用正弦位置编码,那么会创建一个 PositionEmbeddingSine 类的实例,并传递 N_steps(128)作为位置编码的特征数。
    • normalize=True 表示要对位置编码进行归一化处理。这将在位置编码中应用归一化。
  4. elif args.position_embedding in ('v3', 'learned'):

    • 如果条件不成立,即 args.position_embedding 的值为 'v3''learned',表示要使用学习到的位置编码方式。
  5. position_embedding = PositionEmbeddingLearned(N_steps)

    • 如果选择使用学习到的位置编码,那么会创建一个 PositionEmbeddingLearned 类的实例,并传递 N_steps(128)作为位置编码的特征数。
  6. else

    • 如果 args.position_embedding 的值既不是 'v2' 也不是 'v3',则会引发一个值错误(ValueError),表示不支持该位置编码方式。
  7. 最后,函数返回选定的位置编码器 position_embedding,它可以根据输入数据计算位置编码,用于模型中。

总结: 该函数根据输入参数 args 中的配置选择位置编码方式,可以是正弦位置编码或学习到的位置编码,并返回相应的位置编码器实例。位置编码用于将位置信息引入模型,以帮助模型理解输入数据的空间结构。选择合适的位置编码方式取决于具体的应用需求。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1031406.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Node.js VS Python:程序员该选择哪个作为爬虫语言?

对于程序员来说,选择合适的语言作为爬虫开发工具很重要。在这篇文章中,我们将探讨使用Node.js和Python进行爬虫开发的优势和劣势,帮助你做出明智的选择,并提供一些实际操作价值的建议。 一、Node.js的优势与劣势 1、优势&#xf…

《Playing repeated games with Large Language Models》全文翻译

《Playing repeated games with Large Language Models》- 使用大型语言模型玩重复游戏 论文信息摘要1. 介绍2. 相关工作3. 一般方法4. 分析不同游戏系列的行为5. 囚徒困境5.1 性别之战 6. 讨论 论文信息 题目:《Playing repeated games with Large Language Model…

勇立潮头!高品质SFT语音数据实现Zero-Shot语音复刻大模型

文本到语音合成(Text to Speech,TTS)作为生成式人工智能(Generative AI 或 AIGC)的重要课题,在近年来取得了飞速发展。为了实现高效合成既自然又高质量的人类语音,有不少机构及企业都进行了相关…

安防监控视频AI智能分析网关:人流量统计算法的应用场景汇总

TSINGSEE青犀人流量检测算法是内置在智能分析网关中的一种能够通过AI分析和计算人群数量以及密度的算法技术,在提升城市管理效率、改善用户体验和增加安全性方面发挥着重要作用。人流量检测算法在许多领域都有广泛的应用,如智慧城市、智慧交通、智慧景区…

计算机网络运输层网络层补充

1 CDMA是码分多路复用技术 和CMSA不是一个东西 UPD是只确保发送 但是接收端收到之后(使用检验和校验 除了检验的部分相加 对比检验和是否相等。如果不相同就丢弃。 复用和分用是发生在上层和下层的问题。通过比如时分多路复用 频分多路复用等。TCP IP 应用层的IO多路复用。网…

微软宣布推广数字助理 Copilot;GPT 应用开发和思考

🦉 AI新闻 🚀 微软宣布推广基于生成式人工智能的数字助理 Copilot 摘要:微软宣布将基于生成式人工智能的数字助理 Copilot 推广到更多软件产品中。新的 AI 助理 Microsoft Copilot 将在 Windows 中无缝可用,包括 Windows 11 桌面…

【校招VIP】专业课考点之TCP连接

考点介绍: 在TCP/IP中,TCP协议通过三次握手来建立连接,从而提供可靠的连接服务。本专题主要介绍一线互联网大厂面试关于TCP连接的相关问题。 专业课考点之TCP连接-相关题目及解析内容可点击文章末尾链接查看! 一、考…

软件测试/测试开发丨利用人工智能ChatGPT自动生成架构图

点此获取更多相关资料 简介 架构图通过图形化的表达方式,用于呈现系统、软件的结构、组件、关系和交互方式。一个明确的架构图可以更好地辅助业务分析、技术架构分析的工作。架构图的设计是一个有难度的任务,设计者必须要对业务、相关技术栈都非常清晰…

蓝桥杯打卡第14天

文章目录 最短路径最短路径 一、最短路径OJ链接 本题思路:本题是一道简单 的图论题,用floyd算法还是比较简单的,因为代码很短,这里需要用一个backup用来保存未删除边时的情况。当走完一次floyd之后,拷贝给dist数组来进行删除边的…

轻松搞定Spring集成缓存,让你的应用程序飞起来!

Spring集成缓存 缓存接口开启注解缓存注解使用CacheableCachePutCacheEvictCachingCacheConfig 缓存存储使用 ConcurrentHashMap 作为缓存使用 Ehcache 作为缓存使用 Caffeine 作为缓存 主页传送门:📀 传送 Spring 提供了对缓存的支持,允许你…

威联通NAS安装Openwrt旁路由教程

Hello大家好,有一段时间没有折腾NAS了 ,最近搞了一台威联通的TS-464C2,平时用来存储一下数据什么的,感觉有点浪费,刚好威联通自带有虚拟机的软件,直接拿来装个软路系统岂不是美滋滋。 首先说一下这个机器…

Python经典练习题(一)

文章目录 🍀第一题🍀第二题🍀第三题🍀第四题🍀第五题 🍀第一题 有四个数字:1、2、3、4,能组成多少个互不相同且无重复数字的三位数?各是多少? 这里我们使用…

【湖科大教书匠】计算机网络随堂笔记第1章(计算机网络概述)

目录 1.1、计算机网络在信息时代的作用 我国互联网发展状况 1.2、因特网概述 1、网络、互连网(互联网)和因特网 2、因特网发展的三个阶段 因特网服务提供者ISP(Internet Service Provider) 基于ISP的三层结构的因特网 3、因特网的标准化工作 4、因特网的…

基于PHP语言研发的抖音矩阵系统源代码开发部署技术文档分享

一、概述 本技术文档旨在介绍抖音SEO矩阵系统源代码的开发部署流程,以便开发者能够高效地开发、测试和部署基于PHP语言的开源系统。通过本文档的指引,您将能够掌握抖音SEO矩阵系统的开发环境和部署方案,从而快速地构建出稳定、可靠的短视频S…

如何解决 Spring Boot Actuator 的未授权访问漏洞

Spring Boot Actuator 的作用是提供了一组管理和监控端点,允许你查看应用程序的运行时信息,例如健康状态、应用程序信息、性能指标等。这些端点对于开发、测试 和运维团队来说都非常有用,可以帮助快速诊断问题、监控应用程序的性能&#xff0…

红 黑 树

文章目录 一、红黑树的概念二、红黑树的实现1. 红黑树的存储结构2. 红黑树的插入 一、红黑树的概念 在 AVL 树中删除一个结点,旋转可能要持续到根结点,此时效率较低 红黑树也是一种二叉搜索树,通过在每个结点中增加一个位置来存储红色或黑色…

软件测试缺陷报告详解

【软件测试行业现状】2023年了你还敢学软件测试?未来已寄..测试人该何去何从?【自动化测试、测试开发、性能测试】 缺陷报告是描述软件缺陷现象和重现步骤地集合。软件缺陷报告Software Bug Report(SBR)或软件问题报告Software Pr…

【开发篇】二、属性绑定与校验

文章目录 1、ConfigurationProperties自定义Bean属性绑定2、EnableConfigurationProperties注解3、ConfigurationProperties第三方Bean属性绑定4、松散绑定5、常用计量单位6、数据校验7、yaml绑定值的坑--关于进制 1、ConfigurationProperties自定义Bean属性绑定 前面读取yaml…

gateway之过滤器(Filter)详解

文章目录 什么是过滤器过滤器的种类局部过滤器代码示例全局过滤器代码示例 总结 什么是过滤器 在Spring Cloud中,过滤器(Filter)是一种关键的组件,用于在微服务架构中处理和转换传入请求以及传出响应。过滤器位于服务网关或代理中…

CRM客户管理系统主要用途

对于大多数企业而言业绩就是生命线,因此销售环节在企业管理过程中意义重大。面对愈发内卷的市场竞争企业就要借助CRM销售管理系统改善各个环节存在的漏洞,占据优势。那么,销售管理系统的用途有哪些,接下来我们从下面3个功能来介绍…