vqvae简单实战,利用vqvae来提升模型向量表达

news2025/1/16 7:45:35

最近CV领域各种大模型在图像生成领域大发异彩,比如这两年大火的dalle系列模型。在这些模型中用到一个基础模型vqvae,今天我们写个简单实现来了解一下vqvae的工作原理。vqvae原始论文连接https://arxiv.org/pdf/1711.00937.pdf

1,代码

首先我们直接来看代码实现,完整代码GitHub - Pillars-Creation/vqvae: 使用vqvae 进行用户和物品冷启动

    def vector_quantizer(self, z):
        # 将 z 的形状更改为 [batch_size, embedding_dim, 1]
        z_flat = z.view(-1, self.codebook_dim, 1)

        # 计算 z_flat 两两相乘的结果
        z_flat = torch.matmul(z_flat, z_flat.transpose(1, 2))
        z_flat = torch.sqrt(z_flat)

        # 计算z_flat中每个潜在向量与码本中所有向量之间的欧几里得距离
        distances = torch.cdist(z_flat, self.codebook)

        # 计算与每个潜在向量z最接近的码本向量的索引
        codebook_indices = torch.argmin(distances, dim=-1)

        # 使用codebook_indices从码本中检索与原始潜在向量z最接近的离散潜在向量z_q
        one_hot = F.one_hot(codebook_indices, self.codebook_size).type(z_flat.dtype)
        z_q = torch.matmul(one_hot, self.codebook)

        # 提取 z_q 的对角线元素并将它们相加以还原为形状为 [batch_size, emb] 的张量
        z_q = torch.diagonal(z_q, dim1=1, dim2=2)


        # 计算VQ损失,vq_loss为标量
        vq_loss = torch.mean(torch.square(z_q.detach() - z))
        commit_loss = torch.mean(torch.square(z.detach() - z_q))
        vq_loss += self.commitment_cost * commit_loss

        # Apply the Straight-Through Estimator (STE) trick
        z_q = z + (z_q - z).detach()

        # 计算困惑度
        avg_probs = torch.mean(one_hot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # VQ-VAE Decoder
        z_q = z_q.view(z.shape)

        return z_q, vq_loss, perplexity

2,为什么是vqvae,

要回答这个问题,我们看看vqvae论文里作者认为和传统vae模型的关键差异点。从论文可以看到作者认为关键差异点有两个一个是使用了离散编码,一个是动态的学习先验分布

离散编码

  • VAE通过在编码器中引入隐变量(通常是高斯分布的样本)来建模数据的潜在分布。这种连续性的隐空间使得VAE在生成新样本时更加灵活。通过在隐空间中进行插值或随机采样,可以生成具有连续变化的新样本。但是VAE模型存在一个问题是后验奔溃
  • 后验奔溃是指在训练过程中,编码器学到的潜在表示几乎没有包含输入数据的任何有用信息,而解码器主要依赖于其自身来生成数据。这种情况下,VAE 的生成性能会受到影响,因为潜在空间没有学到有效的数据表示。
  • 在VQ-VAE中,编码器将输入数据映射到一个离散的隐藏,将编码器的输出与一个称为码本(codebook)的离散向量集进行匹配来实现的。使用一个离散编码表来表达连续分布。这种离散的隐藏表示具有一些优势,例如更高的表示能力和更好的泛化性能。

动态的学习先验分布

这块比较直观,在传统的 VAE 中,先验分布通常是一个固定的分布,例如标准正态分布。这意味着潜在变量应该遵循这个固定的分布,这是一个静态的约束。然而,在 VQ-VAE 中,先验分布是从数据中学习的,这意味着它可以根据数据的特点自适应地改变。这个学习的先验分布是通过优化码本中的离散向量来实现的。

在训练过程中,码本中的向量会根据输入数据和重构误差进行更新,从而学习到一个更适合表示数据的离散潜在空间。因此,当我们说 VQ-VAE 中的先验是学习的而不是静态的,潜在空间(即码本)可以根据数据自适应地调整。

3,代码里几个注意点

1,Straight-Through Estimator (STE) trick 

vqvae因为要和codebook 取argmin,由于argmin不可导。所以要用STE技术。

STE是一种用于训练离散变量(例如二值变量)的神经网络的技巧。源于Benjio的论文《Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation》

Straight-Through的思想分两个部分

        前向传播的时候可以用想要的变量(哪怕不可导),

        而反向传播的时候,用你自己为它所设计的梯度。

根据这个思想,我们设计的目标函数是:

 
 

其中detach()是stop gradient的意思。这样一来,前向传播计算(求loss)的时候,就直接等价于decoder(z+zq−z)=decoder(zq),然后反向传播(求梯度)的时候,由于zq−z不提供梯度,所以它也等价于decoder(z),这个就允许我们对encoder进行优化了。

2,codebook

在cv里码本对应的encoder是卷积完的三维机构,如果我们是优化ID向量只有一维,需要做个转换把一维变成二维,这里可以用卷积,也可以把向量两两相乘变成二维结构,这样的好处是一方面方便我们把每一行当作一个向量和codebook求对应,另一方面两两相乘也可以理解为一种特征交叉,提升了向量的表达。如代码中实现

3,提取对角线元素,

因为刚刚encode的时候我们做了两两相乘生成了个二维矩阵,所以从codebook中取得映射后,也是个二维矩阵,我们对应的取对角线值,把向量还原为一维,对应代码

4.实验效果

在movilen的数据集上对物品ID做了增强,可以看到效果还是不错的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1096416.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习——奇异值分解二(特征分解+SVD纯理解)

矩阵的特征分解 特征值和特征向量的定义 抄来的:奇异值分解 困惑1:特征值和特征向量,和原矩阵是怎样的关系,需要一个栗子进行更具象的认识 困惑2:为什么多个特征向量组合成的矩阵,可以构成矩阵A的特征分解…

项目管理之实施关键步骤

项目管理已成为当代企业运营和发展过程中不可或缺的重要环节。如何实现高效、有序和可控的项目管理,一直是企业领导和项目团队追求的目标。本文将结合项目管理七招制胜内容,详细阐述项目管理实战中的具体做法。 如何分析项目 了解项目的背景和目的&…

网工记背配置命令(3)----POE配置示例

POE 供电就是通过以太网供电,这种方式仅凭借那根连接通信终端的网线就可完成为它们供电。POE提供的是-53V~0v 的直流电,供电距离最长可达 100m。PoE 款型的交换机的软件大包天然支持 POE,无需 license,通过执行 poe-enable 命令使…

【力扣1844】将所有数字用字符替换

👑专栏内容:力扣刷题⛪个人主页:子夜的星的主页💕座右铭:前路未远,步履不停 目录 一、题目描述二、题目分析 一、题目描述 给你一个下标从 0 开始的字符串 s ,它的偶数下标处为小写英文字母&am…

系列七、Redis持久化

一、是什么 将内存中的数据写入到硬盘的过程。 二、持久化方式 RDB、AOF 2.1、RDB(Redis Database) 2.1.1、概述 在指定的时间间隔,执行数据集的时间点快照。实现类似照片记录效果的方式,就是把某一时刻的数据和状态以文件的形…

vue3后台管理框架之收获

前端⼯程化概念 在学VUE和webpack打包的时候,了解到前端⼯程的基本概念,如下: 实际的前端开发,遵循四个现代化: 1.模块化(js的模块化、css的模块化、其他资源的模块化) 2.组件化(复⽤…

05-React组件的组合使用

05-React组件的组合使用 1.TodoList案例 需求:TodoList组件化实现此功能 显示所有todo列表输入文本, 点击按钮显示到列表的首位, 并清除输入的文本 1).实现: 完成TodoList组件的静态页面以及拆解组件 动态初始化列表 //App.jsx export default class …

[初始java]——java为什么这么火,java如何实现跨平台、什么是JDK/JRE/JVM

java的名言: ”一次编译、到处运行“ 一、编译语言与解释语言 编译: 是将整份源代码转换成机器码再进行下面的操作,最终形成可执行文件 解释: 是将源代码逐行转换成机器码并直接执行的过程,不需要生成目标文件 jav…

10.14~10.15verilog操作流程与Block Design

后面的那个是延时精度 verilog文件结构 文件名称与写的模板没有关系,这个文件名为P1,但模板名为andgate 但是如果是仿真文件,就需要开头的模板名和仿真文件名相同 .v是源文件,设计文件 .v在设计与sim里都有,静态共享&#xff0…

卡顿分析与布局优化

卡顿分析与布局优化 大多数用户感知到的卡顿等性能问题的最主要根源都是因为渲染性能。Android系统每隔大概16.6ms发出VSYNC信 号,触发对UI进行渲染,如果每次渲染都成功,这样就能够达到流畅的画面所需要的60fps,为了能够实现60fp…

CISA 彻底改变了恶意软件信息共享:网络安全的突破

在现代网络安全中,战术技术和程序(TTP)的共享对于防范网络事件至关重要。 因此,了解攻击向量和攻击类型之间的关联如今是让其他公司从其他公司遭受的 IT 事件中受益(吸取经验教训)的重要一步。 美国主要网…

多模态大模型升级:LLaVA→LLaVA-1.5,MiniGPT4→MiniGPT5

Overview LLaVA-1.5总览摘要1.引言2.背景3.LLaVA的改进4.讨论附录 LLaVA-1.5 总览 题目: Improved Baselines with Visual Instruction Tuning 机构:威斯康星大学麦迪逊分校,微软 论文: https://arxiv.org/pdf/2310.03744.pdf 代码: https://llava-vl.…

无法解析符号 ‘SpringBootApplication’

刚打开一个项目出现"SpringBootApplication"无法解析: 通过以下步骤,修改maven路径即可: 文件---->设置(File--->Settings) 构建、执行、部署--->构建工具--->Maven--->Maven主路经&#xf…

07-ConfigurationClassPostProces的解析

文章目录 如何解析Component,Service,Configurationd,Bean,Import等注解1. 源码描述2. 类继承结构图3. 解析流程4. 具体的注解解析 如何解析Component,Service,Configurationd,Bean,Import等注解 1. 源码描述 BeanFactoryPostProcessor used for bootstrapping processing of…

论文笔记[156]PARAFAC. tutorial and applications

原文下载:https://www.sciencedirect.com/science/article/abs/pii/S0169743997000324 摘要 本文介绍了PARAFAC的多维分解方法及其在化学计量学中的应用。PARAFAC是PCA对高阶数组的推广,但该方法的一些特性与普通的二维情况截然不同。例如,…

O2O优惠券预测

O2O优惠券预测 赛题理解赛题类型解题思路 数据探索理论知识数据可视化分布 特征工程赛题特征工程思路 模型训练与验证 赛题理解 赛题类型 本赛题要求提交的结果是预测15 天内用券的概率,这是一个连续值,但是因为用券只有用与不用两种情况,而…

IDEA中为Maven配置使用vpn工具连接的网络

IDEA中为Maven配置使用vpn工具连接的网络 在电脑上使用VPN工具连接上特定网络后,发现idea中使用maven工具还是无法访问相关的网络,这时需要在idea中进行相关配置,开启ipv4代理 -Djava.net.preferIPv4Stacktrue maven配置 这时,…

JavaScript基础知识13——运算符:一元运算符,二元运算符

哈喽,大家好,我是雷工。 JavaScript的运算符可以根据所需表达式的个数,分为一元运算符、二元运算符、三元运算符。 一、一元运算符 1、一元运算符:只需要一个表达式就可以运算的运算符。 示例:正负号 一元运算符有两…

Golang 协程 与 Java 线程池的联系

Golang 协程 与 Java 线程池的联系 引言Java 线程池缺陷Golang 协程实现思路0.x 版本1.0 版本1.1 版本Goroutine 抢占式执行基于信号的抢占式调度 队列轮转系统调用工作量窃取GOMAXPROCS设置对性能的影响 小结 引言 如何理解Golang的协程,我觉得可以用一句话概括: …

【大数据】Hive SQL语言(学习笔记)

一、DDL数据定义语言 1、建库 1)数据库结构 默认的数据库叫做default,存储于HDFS的:/user/hive/warehouse 用户自己创建的数据库存储位置:/user/hive/warehouse/database_name.db 2)创建数据库 create (database|…