(表征学习论文阅读)FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE

news2024/11/18 18:48:35

1. 前言

向量量化(Vector Quantization)或称为矢量量化最早在1984年由Gray提出,主要应用于数据压缩、检索领域,具体的阐述可以参考我写的另一篇关于VQ算法的文章。随着基于神经网络的离散表征学习模型的兴起,VQ技术也开始重新被重视。它在图像、音频等表征学习中体现出了优秀的性能,并且有希望成为多模态大语言模型的重要组件。

在AI领域,最为知名应该是VQ-VAE(Vector Quantized-Variational Autoencoder)了,它的思想是将图像 x x x映射为表征 z k × d z^{k \times d} zk×d,其中 z k × d z^{k \times d} zk×d由一组维度为 d d d的特征向量构成,VQ-VAE引入了一个codebook记为 C n × d C^{n \times d} Cn×d z k × d z^{k \times d} zk×d会和 C n × d C^{n \times d} Cn×d中的向量进行距离计算,可以是欧式距离也可以是余弦相似度,用 C n × d C^{n \times d} Cn×d中距离最近或者最相似的向量来表示 z k × d z^{k \times d} zk×d中的向量。这种量化操作往往不可微,因此VQ-VAE使用了一个非常简单的技巧straight through estimator (STE)来解决,具体的实现可以看代码。

VQ-VAE的损失函数主要由三个部分组成,以确保模型能够有效地学习到有用的离散表征,并同时保持输入数据的重建质量:
L = L recon + α L quant + β L commit L = L_{\text{recon}} + \alpha L_{\text{quant}} + \beta L_{\text{commit}} L=Lrecon+αLquant+βLcommit

  • 重建损失(Reconstruction
    Loss):这部分的损失计算了模型重建的输出与原始输入之间的差异。目标是最小化这一差异,以确保重建的数据尽可能接近原数据。常见的重建损失包括均方误差(MSE)或交叉熵损失,具体取决于输入数据的类型。
  • 量化损失(Quantization Loss)或 码本损失(Codebook Loss):在训练过程中,当输入数据通过编码器被编码到潜在空间后,每个潜在表示会被量化为最近的码本向量。量化损失计算潜在表示与其对应的最近码本向量之间的距离。通过最小化量化损失,模型优化码本向量的位置,使其更好地代表输入数据的潜在表示。这有助于模型更准确地量化潜在空间,并提高重建质量。
  • 提交损失(Commitment Loss):提交损失主要用于稳定训练过程,它鼓励编码器生成的潜在表示靠近选中的码本向量。这样做可以防止码本向量在训练过程中出现较大的变动,从而确保模型的稳定性。提交损失通过计算编码器输出的潜在表示与选中的码本向量之间的距离来实现其目标。因此,提交损失主要影响编码器的参数更新,帮助编码器学习生成与码本向量更接近的潜在表示。

虽然VQ-VAE的效果比传统的VAE要好,但是它使用的codebook中的大部分向量并未被利用到,造成了存储和计算的大量浪费,此外,它额外引入的两项损失即codebook loss和commitment loss也带来些许复杂性。

FSQ(FINITE SCALAR QUANTIZATION: VQ-VAE MADE SIMPLE)这篇文章的目的就是优化以上两个问题。

2. 方法

作者发现,传统的编码器所得到的表征向量 z z z中的每一个元素(标量)的值并没有一个明确的边界,也就是说 z z z在特征空间中不受任何约束。那么,作者就想到了为 z z z中的每个标量都设定好取值的范围和能够取值的个数。
在这里插入图片描述
假设有一个d维特征向量 z z z,将每个标量 z i z_i zi都限制只能取 L L L个值,将 z i → ⌊ L / 2 ⌋ t a n h ( z i ) z_i \rightarrow \left\lfloor L/2 \right\rfloor tanh(z_i) ziL/2tanh(zi)然后四舍五入为一个整数值。例如图中所示,取d=3,L=3,代表codebook C = { ( − 1 , − 1 , − 1 ) , ( − 1 , − 1 , 0 ) , . . . , ( 1 , 1 , 1 ) } C=\left\{(-1, -1, -1), (-1, -1, 0), ..., (1, 1, 1)\right\} C={(1,1,1),(1,1,0),...,(1,1,1)},一共有27种组合,即一个3维向量的每个标量都有三种值的取法。值得一提的是,FSQ中的codebook不像VQ-VAE那样是显式存在的,而是隐式的,编码器直接输出量化后的特征向量 z ^ \hat{z} z^。因此,FSQ也就没有了VQ-VAE损失的后两项了。
在这里插入图片描述

3. 代码实现

from typing import List, Tuple, Optional
import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from torch.cuda.amp import autocast

from einops import rearrange, pack, unpack

# helper functions

def exists(v):
    return v is not None

def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# tensor helpers

def round_ste(z: Tensor) -> Tensor:
    """Round with straight through gradients."""
    zhat = z.round()  # round操作是将z中的元素四舍五入到最接近的整数
    return z + (zhat - z).detach()

class FSQ(Module):
    def __init__(
            self,
            levels: List[int],
            dim: Optional[int] = None,
            num_codebooks=1,
            keep_num_codebooks_dim: Optional[bool] = None,
            scale: Optional[float] = None,
            allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64)
    ):
        super().__init__()
        _levels = torch.tensor(levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent=False)  #persistent=False表示不会被保存到checkpoint中

        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis, persistent=False)

        self.scale = scale

        codebook_dim = len(levels)  # codebook_dim表示每个codebook的维度
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * num_codebooks  # effective_codebook_dim表示所有codebook的维度的总和
        self.num_codebooks = num_codebooks
        self.effective_codebook_dim = effective_codebook_dim

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(dim, len(_levels) * num_codebooks)

        has_projections = self.dim != effective_codebook_dim
        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.codebook_size = self._levels.prod().item()

        implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
        self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)

        self.allowed_dtypes = allowed_dtypes

    def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
        """Bound `z`, an array of shape (..., d)."""
        half_l = (self._levels - 1) * (1 + eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()  # atanh是双曲正切函数的反函数,能够将值映射到[-1, 1]之间
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: Tensor) -> Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        # 将zhat_normalized的值映射到[0, levels]之间
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        return (zhat * self._basis).sum(dim=-1).to(int32)

    def indices_to_codes(
            self,
            indices: Tensor,
            project_out=True
    ) -> Tensor:
        """Inverse of `codes_to_indices`."""

        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        indices = rearrange(indices, '... -> ... 1')
        codes_non_centered = (indices // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    @autocast(enabled=False)
    def forward(self, z: Tensor) -> Tensor:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension
        c - number of codebook dim
        """

        orig_dtype = z.dtype
        is_img_or_video = z.ndim >= 4

        # make sure allowed dtype

        if z.dtype not in self.allowed_dtypes:
            z = z.float()

        # standardize image or video into (batch, seq, dimension)

        if is_img_or_video:
            # 将图片和视频的空间、时间维度展平
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        z = self.project_in(z)

        z = rearrange(z, 'b n (c d) -> b n c d', c=self.num_codebooks)

        codes = self.quantize(z)
        print(f"codes: {codes}")
        indices = self.codes_to_indices(codes)

        codes = rearrange(codes, 'b n c d -> b n (c d)')

        out = self.project_out(codes)

        # reconstitute image or video dimensions

        if is_img_or_video:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # cast back to original dtype

        if out.dtype != orig_dtype:
            out = out.type(orig_dtype)

        # return quantized output and indices

        return out, indices

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1567438.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Cisco ACI Simulator 6.0(5h) - ACI 模拟器

Cisco ACI Simulator 6.0(5h) - ACI 模拟器 Application Centric Infrastructure (ACI) Simulator Software 请访问原文链接:https://sysin.org/blog/cisco-acisim-6/,查看最新版。原创作品,转载请保留出处。 作者主页:sysin.o…

最新梨花带雨网页音乐播放器

源码简介 最新梨花带雨网页音乐播放器二开优化修复美化版全开源版本源码下载 梨花带雨播放器基于thinkphp6开发的XPlayerHTML5网页播放器前台控制面板,支持多音乐平台音乐解析。二开内容:修复播放器接口问题,把接口本地化,但是集成外链播放…

AcWing刷题-约数个数

约数的个数 代码 # 计数 def f(x)->int:cnt 0i 1while i * i < x:if x % i 0:cnt 1if i * i < x:cnt 1i 1return cntn int(input()) a list(map(int,input().split())) for i in a:print(f(i))

K8s Deployment 滚动更新、金丝雀发布、自定义钩子、生命周期解析

&#x1f407;明明跟你说过&#xff1a;个人主页 &#x1f3c5;个人专栏&#xff1a;《Kubernetes航线图&#xff1a;从船长到K8s掌舵者》 &#x1f3c5; &#x1f516;行路有良友&#xff0c;便是天堂&#x1f516; 目录 一、Deployment的高级特性 1、滚动更新 2、金丝雀…

Emacs之解除comment-region绑定C-c C-c快捷键(一百三十四)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

对一个时间序列中的每个元素按照指定精度位置四舍五入

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 对一个时间序列中的每个元素 按照指定精度位置四舍五入 Series.dt.round() 选择题 以下代码的输出结果中正确的是? import pandas as pd ts pd.Series(pd.date_range("2024-04-04 08:…

16 RGB-LCD 彩条显示

RGB TFT-LCD 简介 TFT-LCD 的全称是 Thin Film Transistor-Liquid Crystal Display&#xff0c;即薄膜晶体管液晶显示屏&#xff0c;它显示的每个像素点都是由集成在液晶后面的薄膜晶体管独立驱动&#xff0c;因此 TFT-LCD 具有较高的响应速度以及较好的图像质量。液晶显示器是…

使用pip安装geopandas(24.4更新)

geopandas是我们用Python进行地理分析常用的库&#xff0c;在数据处理、分析、制图等场景中有着极为广泛的应用&#xff0c;但是在安装过程中会出现各种问题。​geopandas的安装方式有很多&#xff0c;今天我们选取较为简单的pip来进行geopandas的安装。 ​首先&#xff0c;我…

动规训练2

一、最小路径和 1、题目解析 就是一个人从左上往做下走&#xff0c;每次只能往右或者往下&#xff0c;求他到终点时&#xff0c;路径上数字和最小&#xff0c;返回最小值 2、算法原理 a状态表示方程 小技巧&#xff1a;经验题目要求 用一个二维数组表示&#xff0c;创建一个…

【WEEK6】 【DAY3】MySQL函数【中文版】

2024.4.3 Wednesday 目录 5.MySQL函数5.1.常用函数5.1.1.数据函数5.1.2.字符串函数5.1.2.1.CHAR_LENGTH(str)计算字符串str长度5.1.2.2.CONCAT(str1,str2,...)拼接字符串str1 str2 ...5.1.2.3.INSERT(str,pos,len,newstr)把原文str第pos位开始长度为len的字符串替换成newstr5.…

vue3数据库中存头像图片相对路径在前端用prop只能显示路径或无法显示图片只能显示alt中内容的问题的解决

不想看前情可以直接跳到头像部分代码 前情&#xff1a; 首先我们是在数据库中存图片相对路径&#xff0c;这里我们是在vue的src下的assets专门建一个文件夹img存头像图片。 然后我们如果用prop"avatar" label"头像"是只能显示图片路径的&#xff0c;即lo…

CEF的了解

(14 封私信 / 80 条消息) CEF和Electron的区别是什么&#xff1f; - 知乎 (zhihu.com) Electron面向的开发者&#xff1a;会用JavaScript,HTML,CSS&#xff0c;不会C CEF面向的开发者&#xff1a;会用JavaScript,HTML,CSS&#xff0c;会C (14 封私信 / 80 条消息) liulun - …

代码随想录Day28:回溯算法Part4

Leetcode 93. 复原IP地址 讲解前&#xff1a; 这道题其实在做完切割回文串之后&#xff0c;学会了使用切割的方法来找到字符串的possible 子串之后&#xff0c;思路就会很快找到&#xff0c;细想一下其实无非也就是对given string然后进行切割&#xff0c;只是深度是固定的因…

【数据结构与算法】二叉搜索树和平衡二叉树

二叉搜索树 左子树的结点都比当前结点小&#xff0c;右子树的结点都比当前结点大。 构造二叉搜索树&#xff1a; let arr [3, 4, 7, 5, 2]function Node(value) {this.value valuethis.left nullthis.right null }/*** 添加结点* param root 当前结点* param num 新的结…

50道Java经典面试题总结

1、那么请谈谈 AQS 框架是怎么回事儿&#xff1f; &#xff08;1&#xff09;AQS 是 AbstractQueuedSynchronizer 的缩写&#xff0c;它提供了一个 FIFO 队列&#xff0c;可以看成是一个实现同步锁的核心组件。 AQS 是一个抽象类&#xff0c;主要通过继承的方式来使用&#x…

AI绘图:Stable Diffusion WEB UI 详细操作介绍:基础篇

接上一篇《AI绘图体验&#xff1a;Stable Diffusion本地化部署详细步骤》本地部署完了SD后&#xff0c;大家肯定想知道怎么用&#xff0c;接下来补一篇Stable Diffusion WEB UI 详细操作&#xff0c;如果大家还没有完成SD的部署&#xff0c;请参考上一篇文章进行本地化的部署。…

抽象类与接口(3)(接口部分)

❤️❤️前言~&#x1f973;&#x1f389;&#x1f389;&#x1f389; hellohello~&#xff0c;大家好&#x1f495;&#x1f495;&#xff0c;这里是E绵绵呀✋✋ &#xff0c;如果觉得这篇文章还不错的话还请点赞❤️❤️收藏&#x1f49e; &#x1f49e; 关注&#x1f4a5;&…

Spring Boot:Web开发之视图模板技术的整合

Spring Boot 前言Spring Boot 整合 JSPSpring Boot 整合 FreeMarkerSpring Boot 整合 ThymeleafThymeleaf 常用语法 前言 在 Web 开发中&#xff0c;视图模板技术&#xff08;如 JSP 、FreeMarker 、Thymeleaf 等&#xff09;用于呈现动态内容到用户界面的工具。这些技术允许开…

【css】使用display:inline-block后,元素间存在4px的间隔

问题&#xff1a;在本地项目中使用【display: inline-block】&#xff0c;元素间存在4px间隔。打包后发布到外网又不存在这个问题了。 归根结底这是一个西文排版的问题&#xff0c;英文有空格作为词分界&#xff0c;而中文则没有。 此时的元素具有文本属性&#xff0c;只要标签…

RUST语言函数的定义与调用

1.定义函数 定义一个RUST函数使用fn关键字 函数定义语法: fn 函数名(参数名:参数类型,参数名:参数类型) -> 返回类型 { //函数体 } 定义一个没有参数,没有返回类型的参数 fn add() {println!("调用了add函数!"); } 定义有一个参数的函数 fn add(a:u32)…