一、背景
今天对Hybrid Transformer Demucs代码进行解读,目标:对图c中各个模块对应的代码进行解读(时域编码、频域编码、时域解码、频域解码、Cross-Domain Transformer Encoder模块、STFT模块、ISTFT模块)。解读的代码是开源工程中的htdemucs.py、hdemucs.py、 transformer.py、spec.py。
具体的paper解读看这篇文章。
关于Hybrid Transformer Demucs算法各个模块对应的代码具体在工程的哪个地方看这篇文章。
二、解读
2.1 时域编码和频域编码
时域编码和频域编码复用同一个类,HEncLayer。(要像看图文博客一样看代码,先整体再局部,主观感受就会好很多)。此部分对应工程中的hdemucs.py文件。
总体上来看,时域和频域的编码模块主要由二维卷积【self.conv(x)】+ 归一化【self.norm1】+激活函数构成【F.gelu()】构成。
2.2 Cross-Domain Transformer Encoder模块
在解读Cross-Domain Transformer Encoder模块前,先贴论文原理图。此部分对应工程中的transformer.py文件。
可以看到时域和频域进行了相同的位置编码和归一化处理,只是在细节处略有差异。
Transformer Encoder还是哪个大家常见的形式:多头注意力+前馈。
2.3 时域解码和频域解码
总的来说,解码模块和编码模块类似,主要由二维转置卷积【self.conv_tr()】+归一化【 self.norm2()】组成。
2.4 STFT和ISTFT模块
此部分对应工程中的spec.py文件。
STFT模块最底层是调用了torch.stft模块。
ISTFT模块最底层是调用了torch.istft模块。
感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)