1. 论文基本信息
发布于: TMLR 2022
2. 创新点
- 使用元学习将编码时间减少了两个数量级以上,将编码共享结构进行编码,并对该网络应用调制来编码实例特定信息。
- 量化和熵编码调制。虽然我们的方法在压缩和速度方面都大大超过了 COIN,但它仅部分缩小了与 SOTA 编解码器在经过充分研究的模式(例如图像)上的差距。然而,COIN++ 适用于传统方法难以使用的广泛数据模式,使其成为非标准域中神经压缩的一种有前途的工具。
3. 背景
COIN++通过优化将广泛的数据模式转换为神经网络,然后将这些神经网络的参数存储为数据的压缩代码。可以通过简单地改变神经网络的输入和输出维度来压缩不同的数据模式。
确定了 COIN 的以下问题:
- 编码很慢:压缩单个图像可能需要长达一个小时
- 缺乏共享结构:由于每个图像都是独立压缩的,网络之间没有共享信息
- 性能远低于最先进的 (SOTA) 图像编解码器。
4. Pipeline
数据表示
数据以坐标集(X)和特征集(Y)的形式表示。例如:
- 图像:二维平面中的像素位置 (x, y) 及其RGB颜色值 (r, g, b) 。
- MRI扫描:三维空间中的位置 (x, y, z) 和强度值。
每个数据点是坐标和特征对的集合,表示为
。
4.1. 元学习的应用
- 基础网络与调制机制:在这种方法中,首先训练一个所有数据共享的基础网络,通常是一个多层感知机(MLP)。这个基础网络不直接对每个独立的数据进行特定的学习,而是提供了一个通用的特征提取框架。
- FiLM层的应用:为了使基础网络能够适应每个具体的数据点,引入了 FiLM(Feature-wise Linear Modulation,特征线性调制)层。FiLM 层的核心思想是通过调制操作来修改网络中的隐藏特征。具体来说,FiLM 层会对隐藏层特征 h∈Rd 应用逐元素(elementwise)的缩放(scales, γ∈Rd)和平移(shifts, β∈Rd)操作:FiLM(h)=γ⋅h+β这里的 γ 和 β 是调制参数,它们可以根据每个数据点的具体需求进行调整,从而使得基础网络能够灵活地适应不同的数据特性。
COIN++架构。潜在调制 φ(绿色)通过超网络映射到调制(蓝色),这些调制被添加到基础网络 fθ(白色)的激活中,以参数化可以在坐标 x 处评估的单个函数以获得特征 y。
在 COIN++ 架构中,不同的图像或数据点对应的调制参数 φ 各不相同,这是因为每个数据或图像可能有其独特的结构和特征需求,通过使用不同的调制参数,可以使得同一个基础网络 fθ 能够适应和生成这些不同的数据。此外,这种方法允许我们在基础网络中存储共享信息,并在调制中存储实例特定信息。例如,对于自然图像,基础网络编码自然图像中常见的结构,而调制存储重建单个图像所需的信息。
4.2. 元学习调制
- 问题描述:
- 传统的COIN方法需要在训练时学习模型的权重(θ),这是一个耗时的过程。
- 通过COIN的编码方法(公式6),对单个数据点进行编码也很慢,需要进行许多次梯度下降迭代。
- 解决方案 - MAML:
- MAML是一种元学习方法,它的目标是找到一个好的模型初始化参数θ∗,这样我们可以很快地在新数据上进行少量梯度下降迭代来学习。
- COIN++的元学习应用:
- 内循环(Inner Loop):对于单个数据点,我们使用它来快速更新调制参数φ,保持θ不变。
- 外循环(Outer Loop):在整个数据集上,我们同时更新θ和φ。
- 优点:
- 快速适应:COIN++可以快速地适应新的数据点,只需几次梯度下降迭代。
- 共享参数:COIN++只需要存储每个数据点的调制参数φ,而不是整个模型的权重,这减少了存储和计算成本。
- 简化初始化:只需要元学习共享参数θ的初始化,而不需要单独元学习φ。
4.3. 针对调制的量化和熵编码
(左)从随机初始化θ开始,元学习基础网络的参数θ *(训练进度显示为实线),这样调制φ可以很容易地适应几个梯度步骤(拟合如虚线所示)。(右)在训练期间,随机抽取补丁,而在测试时,将数据点划分为补丁并将调制拟合到每个补丁中。
5. 💎实验成果展示
(左)CIFAR10 上的速率失真图。 COIN++ 优于 COIN、JPEG 和 JPEG2000,同时部分缩小了与最先进的编解码器的差距。(右)具有相似重建质量的模型压缩伪影的定性比较。COIN++在3.29 bpp处达到32.4dB,而BPG在1.88 bppp处达到31.9dB。
(左)柯达的速率失真图。虽然 COIN++ 的性能略好于 COIN,但补丁的使用降低了压缩性能。(右)柯达上的 COIN++ 压缩伪影。
6. 源码环境配置:
源码地址:GitHub - EmilienDupont/coinpp: Pytorch implementation of COIN++ 🍁