代码解读 | Hybrid Transformers for Music Source Separation[05]

news2025/1/8 12:45:15


        0、Hybrid Transformer 论文解读

        1、代码复现|Demucs Music Source Separation_demucs架构原理-CSDN博客

        2、Hybrid Transformer 各个模块对应的代码具体在工程的哪个地方

        3、Hybrid Transformer 各个模块的底层到底是个啥(初步感受)?

        4、Hybrid Transformer 各个模块处理后,数据的维度大小是咋变换的?

        5、Hybrid Transformer 拆解STFT模块

        从模块上划分，Hybrid Transformer Demucs 共包含 7 个模块：STFT 模块、时域编码模块、频域编码模块、Cross-Domain Transformer Encoder 模块、时域解码模块、频域解码模块和 ISTFT 模块。其中，时域编码模块与频域编码模块共用同一个编码层实现 HEncLayer，下面对其代码进行解读。




class HEncLayer(nn.Module):
    """Encoder layer shared by the time branch and the frequency branch.

    ``forward`` applies: strided conv -> norm1 -> GELU -> optional DConv
    residual branches -> optional "rewrite" conv -> norm2 -> GLU.
    With ``empty=True`` only the strided conv is built and applied.
    """

    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
                 rewrite=True):
        """Encoder layer. This is used both by the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            kernel_size: kernel size of the main strided conv (applied along
                the frequency axis when ``freq`` is True, along time otherwise).
            stride: stride of the main conv (same axis as ``kernel_size``).
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. This is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies (Conv2d with a
                ``(kernel_size, 1)`` kernel instead of Conv1d).
            dconv: insert DConv residual branches.
            norm: use GroupNorm; otherwise ``nn.Identity`` is used.
            context: context size for the rewrite conv, whose kernel size is
                ``1 + 2 * context`` (a 1x1 conv when ``context == 0``).
            dconv_kw: kwargs forwarded to the DConv class. It is only read,
                never mutated, so the shared ``{}`` default is harmless.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add the rewrite conv (+ norm2 + GLU) at the end of the layer.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        # `pad` arrives as a bool flag and is turned into the integer padding.
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            # Convolve along frequency only; the time axis is left untouched.
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            # An "empty" layer is just the strided conv; forward() returns
            # right after it, so no further submodules are needed.
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            # Produces 2 * chout channels so that the GLU in forward() brings
            # the channel count back to chout.
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """Run the encoder layer.

        Args:
            x: input tensor. The frequency branch unpacks the conv output as
                ``(B, C, Fr, T)``. The time branch accepts 3-D input, or 4-D
                input which is flattened to ``(B, C * Fr, T)``.
            inject: optional tensor added to the conv output. It is used to
                inject the result from the time branch into the frequency
                branch, when both have the same stride. A 3-D ``inject`` is
                broadcast over the frequency axis of a 4-D output.

        Returns:
            The raw conv output when the layer is ``empty``. Otherwise the
            processed tensor with ``chout`` channels.
        """
        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            # Right-pad the time axis to a multiple of the stride so the
            # strided conv covers every sample.
            le = x.shape[-1]
            if le % self.stride != 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                # Add a frequency axis so the time-branch tensor broadcasts.
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv is not None:
            if self.freq:
                # DConv works on 3-D tensors: fold frequency into the batch,
                # (B, C, Fr, T) -> (B * Fr, C, T), then unfold afterwards.
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite is not None:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z



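        编码层（HEncLayer）按 forward 中的执行顺序可分为三部分：

        1. 主卷积：Conv2d + Norm1 + GELU，其中 Norm1 为 Identity()。主卷积的 padding = kernel_size // 4 = 2，与下方打印结果中的 padding=(2, 0) 一致。

        2. DConv 残差分支：(Conv1d + GroupNorm + GELU + Conv1d + GroupNorm + GLU + LayerScale())。下方打印结果中，每层包含两个这样的分支，dilation 分别为 1 和 2。

        3. 重写卷积（rewrite）：Conv2d + Norm2 + GLU，其中 Norm2 为 Identity()。

        备注：Identity 可以理解为“直通”，即输入原样输出，不做任何处理。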

Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 6, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 96, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    (1): Sequential(
      (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 6, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 96, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))

  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 12, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 192, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    (1): Sequential(
      (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 12, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 192, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))

  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 24, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 384, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    (1): Sequential(
      (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 24, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 384, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))

  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 48, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 768, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    (1): Sequential(
      (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 48, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 768, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))


        没有所谓天生的大佬，如果有，那么我愿称他/她为圣人。我相信，能读到这儿的人都会成为大佬~ Believe in yourself — one day, you will be somebody.








