代码解读 | Hybrid Transformers for Music Source Separation[05]

news2025/1/8 12:45:15

一、背景

        0、Hybrid Transformer 论文解读

        1、代码复现|Demucs Music Source Separation_demucs架构原理-CSDN博客

        2、Hybrid Transformer 各个模块对应的代码具体在工程的哪个地方

        3、Hybrid Transformer 各个模块的底层到底是个啥(初步感受)?

        4、Hybrid Transformer 各个模块处理后,数据的维度大小是咋变换的?

        5、Hybrid Transformer 拆解STFT模块


        从模块上划分,Hybrid Transformer Demucs 共包含 (STFT模块、时域编码模块、频域编码模块、Cross-Domain Transformer Encoder模块、时域解码模块、频域解码模块、ISTFT模块)7个模块。

        本篇目标:拆解频域编码模块的底层

        时域编码和频域编码原理类似(后续不再拆解时域编码模块)。

二、频域编码模块


class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
                 rewrite=True):
        """Encoder layer. This used both by the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: list of kwargs for the DConv class.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """
        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z

        核心代码如上所示。

        使用print函数打印出各个关键节点的信息,可以得到频域编解码模块的全景图。

        编码层:Conv2d+Norm1+GELU,  Norm1:Identity()

        残差连接:(Conv1d+GroupNorm+GELU +Conv1d+GroupNorm+GLU+LayerScale())

        +(Conv2d+Norm2+GLU),Norm2:Identity() ,备注:Identity可以理解成直通

#上图均是自己读完代码绘制的。相信自己也可以。
#具体的,编码层1-4的Conv2d分别是:
Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
#残差连接1
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 6, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 96, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 6, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 96, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))

#残差连接2
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 12, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 192, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 12, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 192, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))

#残差连接3
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 24, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 384, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 24, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 384, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))

#残差连接4
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 48, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 768, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 48, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 768, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))

        关于,各个卷积模块输出数据的shape计算,可以读这篇文章。

        没有所谓天生的大佬,如果有那么我愿称他/她为圣人。我相信,能读到这儿的都会成为大佬~。Believe yourself,one day,you will be somebody.


         感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1820656.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

工厂环境中ESD防静电系统对静电灾害的预防与控制

静电在工厂环境中可能造成严重的危害,包括火灾、爆炸和设备损坏等。因此,对于工厂环境中的静电灾害,采取预防和控制措施是非常必要的。ESD防静电系统是一种用来预防和控制静电灾害的重要解决方案,它可以有效地降低静电危害发生的可…

MCGS仿真教学1:单个变量与博途进行通讯

目录 一、博途配置1.1、博途通讯常用配置1.2、博途测试程序 二、MCGS配置2.1、工程配置2.2、设备组态2.3、添加单个变量2.4、添加画面2.4.1、按钮2.4.2、指示灯 三、下载测试 一、博途配置 1.1、博途通讯常用配置 1.2、博途测试程序 二、MCGS配置 2.1、工程配置 选择自己所购…

ubuntu18.04离线源制作

给客户部署有时需要纯内网环境,那这样就连不了网络。 一些包就下载不下来,而大家都知道用deb离线安装是非常麻烦的,各种依赖让你装不出来。 这里教大家打包源。 我准备2台机器,42和41 42可以联网,41不能联网。我想在…

python的a[:2]、a[:] 和a [::] 的区别

一、a[:2] 数据准备 import numpy as np X np.array([[0,1],[2,3],[4,5],[6,7],[8,9],[10,11],[12,13],[14,15],[16,17],[18,19]]) print(X)形成矩阵 print (“X[: 2]:”, X[: 2]) ### :表示索引 0至1行; 二、a[:]和a [::] 在 Python 中,[:] 和 [::…

修改版的VectorDBBench更好用

原版本VectorDBBench的几个问题 在这里就不介绍VectorDBBench是干什么的了,上官网即可。 1.并发数设置的太少 2.测试时长30秒太长 3.连接milvus无用户和密码框,这个是最大的问题 4.修改了一下其它参数 由于很多网友发私信问一些milvus的相关技术问…

钓鱼小助手 —— 借助文心智能体平台打造钓鱼佬神器

前言 🚀前方高能 🚀钓鱼小助手上线了 有没有喜欢钓鱼的程序猿呀,福利来了,特意整了一个钓鱼的小助手,以后钓鱼小技巧都可以咨询了。 走过路过,不要错过,快快体验吧~ 🚀&#x1…

【Three.js】知识梳理二十三:Three.js与其他WebGL库与框架对比

在WebGL开发中,Three.js是一个非常流行的库,它简化了3D图形的创建和渲染过程。然而,市场上还有许多其他的WebGL库,如 Babylon.js、PlayCanvas、PIXI.js 和 Cesium,它们也有各自的特点和优势。本文将对Three.js 与这些常…

翻译《The Old New Thing》- The case of the exception that a catch (…) didn’t catch

The case of the exception that a catch (...) didnt catch - The Old New Thing (microsoft.com)https://devblogs.microsoft.com/oldnewthing/20240405-00/?p109621 Raymond Chen 2024年04月05日 一位客户认为他们修复了一个bug,但他们仍然因为这个bug而崩溃。…

LogicFlow 学习笔记——4. LogicFlow 基础 边 Edge

边 Edge 和节点一样&#xff0c;LogicFlow 也内置一些基础的边。LogicFlow 的内置边包括&#xff1a; 直线 - line直角折现 - polyline贝塞尔曲线 - bezier 新建 src/views/Example/LogicFlow/Example08.vue 并编写如下代码&#xff1a; <script setup lang"ts&quo…

c#中上传超过30mb的文件,接口一直报404,小于30mb的却可以上传成功

在一次前端实现上传视频文件时,超过30mb的文件上传,访问接口一直报404,但是在Swagger中直接访问接口确是正常的,且在后端控制器中添加了限制特性,如下 但是却仍然报404,在apifox中请求接口也是报404, 网上说: 在ASP.NET Core中,配置请求过来的文件上传的大小限制通常…

FPGA - 全局时钟资源

全局时钟资源是指FPGA内部为实现系统时钟到达FPGA内部各 CLB、IOB&#xff0c;以及BSRAM&#xff08;Block Select RAM&#xff0c;选择性BRAM&#xff09;等基本逻辑单元的延时和抖动最小化&#xff0c;采用全铜层工艺设计和实现的专用缓冲与驱动结构。 由于全局时钟资源的布线…

电子制造业数字化整体解决方案

电子制造行业有特殊的着重点&#xff1a; 高精度要求&#xff1a;电子制造需要极高的精度和质量控制&#xff0c;因为电子组件和电路板的尺寸通常非常小&#xff0c;且对错误和缺陷非常敏感。 快速技术迭代&#xff1a;电子行业的技术迅速发展&#xff0c;产品生命周期短&…

RabbitMQ概述

RabbitMQ RabbitMQ概述 RabbitMQ是一个开源的消息代理&#xff08;message broker&#xff09;系统&#xff0c;最初由Rabbit Technologies Ltd开发&#xff0c;并在开源社区的支持下不断发展和完善。它提供了强大的消息传递机制&#xff0c;被广泛应用于构建分布式系统和应用…

vue3 proxy对象转为原始对象

https://cn.vuejs.org/api/reactivity-advanced.html#toraw import { toRaw } from "vue";const foo {} const reactiveFoo reactive(foo)console.log(toRaw(reactiveFoo) foo) // true 人工智能学习网站 https://chat.xutongbao.top

基于LangChain-Chatchat实现的RAG-本地知识库的问答应用[1]-最新版快速实践并部署(检索增强生成RAG大模型)

基于LangChain-Chatchat实现的RAG-本地知识库的问答应用[1]-最新版快速实践并部署(检索增强生成RAG大模型) 基于 ChatGLM 等大语言模型与 Langchain 等应用框架实现,开源、可离线部署的检索增强生成(RAG)大模型知识库项目。 1.介绍 一种利用 langchain思想实现的基于本地知…

『原型资源』Axure自带图标库不够用,第三方经典图标库来袭

​今天小编为大家带来第三方经典图标库&#xff0c;己确认内容可用现推荐给大家。直接上手就可不用自己画哈~ 获取原型文档请与班主任联系&#xff01; 先睹为快&#xff0c;合适再拿走不谢&#xff1a; 图标太多&#xff0c;截取部分给大家参考o(*&#xffe3;︶&#xffe3;*…

毕业了!给学计算机朋友的 10 条血泪建议

大家好&#xff0c;我是程序员鱼皮。最近高考结束了&#xff0c;也有很多同学毕业了&#xff0c;首先祝福这些朋友在人生的新阶段一帆风顺。 刚参加完高考的朋友&#xff0c;面临的最大问题就是选专业&#xff0c;这段时间也有一些家长向我咨询&#xff1a;还能不能选计算机啦…

2024 年 19 种最佳大型语言模型

大型语言模型是 2023 年生成式人工智能热潮背后的推动力。然而&#xff0c;它们已经存在了一段时间了。 LLM是黑盒 AI 系统&#xff0c;它使用深度学习对超大数据集进行处理&#xff0c;以理解和生成新文本。现代 LLM 开始成型于 2014 年&#xff0c;当时一篇题为“通过联合学…

鸿蒙轻内核A核源码分析系列七 进程管理 (1)

本文开始继续分析OpenHarmony LiteOS-A内核的源代码&#xff0c;接下来会分析进程和任务管理模块。本文中所涉及的源码&#xff0c;以OpenHarmony LiteOS-A内核为例&#xff0c;均可以在开源站点 https://gitee.com/openharmony/kernel_liteos_a 获取。如果涉及开发板&#xff…

排名前五的 Android 数据恢复软件

正在寻找数据恢复软件来从 Android 设备恢复数据&#xff1f;本指南将为您提供 5 款最佳 Android 数据恢复软件。浏览这些软件&#xff0c;然后选择您喜欢的一款来恢复 Android 数据。 ndroid 设备上的数据丢失可能是一种令人沮丧的经历&#xff0c;无论是由于意外删除、系统崩…