MINICPM-V2_6之图像embedding的resampler-代码解读

news2025/1/12 20:02:19

目的

基于上一篇MINICPM-V2_6图像得到embedding-代码解读将图片patch找到对应的embedding(包括位置embedding和像素embedding),embedding经过多层attention后会得到vision_embedding,vision_embedding的长度对应的是patch的个数,这个长度是不固定的,有长有短,那要怎么做到长度统一呢?
这篇从vision_embedding入手,了解如何解决patch长度不统一的问题

Perceiver Resampler

Flamingo
resampler
Perceiver Resampler解决的事将变长的patch信息转化为固定大小长度的特征,否则过长的patch会大大加大后续LLM的计算负担(毕竟每个token都占用计算)。
Perceiver Resampler采用一个可学习的queries作为交叉注意力中的Q,而将patch进行特征提取后的表示x_f,和Q拼接起来作为交叉注意力中的K和V,通过这种方法将变长的patch特征就规整为了固定大小的特征,方便了后续的处理。

代码

基本变量

import torch
from torch import nn
from torch import Tensor
from functools import partial
import numpy as np

num_queries = 64# query数量
embed_dim = 3584# 向量维度
num_heads = 3584//128# attention头数28
adaptive = True# 
max_size = (70,70)# 最大尺寸
kv_dim = 1152# kv向量维度

query = nn.Parameter(torch.zeros(num_queries, embed_dim))# 64,3584
kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)# 1152,3584

norm_layer=partial(nn.LayerNorm, eps=1e-6)
ln_q = norm_layer(embed_dim)# LayerNorm((3584,), eps=1e-06, elementwise_affine=True)
ln_kv = norm_layer(embed_dim)# LayerNorm((3584,), eps=1e-06, elementwise_affine=True)

ln_post = norm_layer(embed_dim)# LayerNorm((3584,), eps=1e-06, elementwise_affine=True)
proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))# 3584,3584

这里定义了几个LN,都是在attention前后使用的

函数定义

既然是attention,那其中必然有位置embedding,这里使用的是ROPE,只是因为是2D,所以这里也要处理一下得到2D的位置embedding

def get_2d_sincos_pos_embed(embed_dim, image_size):
    """
    输入:
        embed_dim: 向量维度
        image_size:(H,W)
    输出:
        pos_embed:(H, W, embed_dim)
    demo:
        embed_dim = 8
        image_size = (3,5)
        pos_embed = get_2d_sincos_pos_embed(embed_dim, image_size)# 3,5,8
    """
    if isinstance(image_size, int):
        grid_h_size, grid_w_size = image_size, image_size
    else:
        grid_h_size, grid_w_size = image_size[0], image_size[1]# 70,70
    grid_h = np.arange(grid_h_size, dtype=np.float32)# 0,1,2,..,69
    grid_w = np.arange(grid_w_size, dtype=np.float32)# 0,1,2,..,69
    grid = np.meshgrid(grid_w, grid_h)  # 生成网格,但是这里是w在前;torch.meshgrid是h在前
    # [[ 0.,  1.,  2., ..., 67., 68., 69.], [ 0.,  1.,  2., ..., 67., 68., 69.],
    # [[ 0.,  0.,  0., ...,  0.,  0.,  0.], [ 1.,  1.,  1., ...,  1.,  1.,  1.],
    grid = np.stack(grid, axis=0)# 在第0维拼接
    # [[[ 0.,  1.,  2., ..., 67., 68., 69.], [ 0.,  1.,  2., ..., 67., 68., 69.],
    # [[ 0.,  0.,  0., ...,  0.,  0.,  0.], [ 1.,  1.,  1., ...,  1.,  1.,  1.], ]]   
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)# 70,70,3584
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    输入:
        embed_dim: 向量维度
        grid:行位置和列位置[(H,W),(H,W)]
    输出:
        emb:(H, W, embed_dim)
    demo:
        embed_dim = 8
        grid = [[np.arange(70, dtype=np.float32)],[np.arange(70, dtype=np.float32)]]
        emb = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)# 1,70,8
    """
    assert embed_dim % 2 == 0
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[0])  # (H, W, D/2) H维度是一样的 W维度是不一样的 对应的是左半部分
    emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[1])  # (H, W, D/2) W维度是一样的 H维度是不一样的 对应的是右半部分
    emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
    """
    输入:
        embed_dim: 向量维度
        pos: 位置 (H, W)
    输出:
        emb:得到POS对应的ROPE位置向量 (H, W, D)
    demo:
    embed_dim = 8
    pos = [np.arange(70, dtype=np.float32)]# [[0,1,2,..,69]]
    emb = get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos)# 1,70,8
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)# 0,1,2,3
    omega /= embed_dim / 2.# [0,1,2,3]/4=[0,1/4,2/4,3/4]
    omega = 1. / 10000 ** omega  # (D/2,) ROPE位置编码中的\theta = 1/10000 **(2i/d)
    out = np.einsum('hw,d->hwd', pos, omega)  # (H, W, D/2), outer product m*\theta
    emb_sin = np.sin(out)  # (H, W, D/2) sin(m*\theta)
    emb_cos = np.cos(out)  # (H, W, D/2) cos(m*\theta)
    emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (H, W, D)
    return emb

embed_dim = 8
max_size = (3,5)
pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, max_size)).float()

tgt_sizes = (2,3)
tgt_h, tgt_w = tgt_sizes
a = pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1))# tgt_h * tgt_w, 8

我一直想象不出来这里得到的pos_embed和a是什么样子的,简单举个例子,看着清楚明白
h=3,w=5,embed_dim=8,最后形成的pos_embed的维度是(3,5,8),相同颜色的是一样的值
position_embedding
当一个图片patch对应的tgt_sizes=(2,3)从pos_embed中找到的位置embedding是什么样子的呢?
将h对应的坐标当作i,将w对应的坐标当作j,从pos_embed找到pos_embed[:h,:w,:]就是下面这个样子的了
result

def MultiheadAttention(embed_dim, num_heads, query ,key ,value, key_padding_mask):
    """
    输入:
        embed_dim: 向量维度
        num_heads: 头数
    输出:
        output:输出向量 (query_num, batch_size, embed_dim)
    demo:
    这里和attention是一样的,后面找个机会写
    """    
    output = torch.randn(64, 3, embed_dim)
    return output

函数调用

x = torch.randn(3,1036,1152)# 上一步得到的图片patch对应的embedding
tgt_sizes = torch.tensor([[28, 37],
        [39, 26],
        [39, 26]])# 上一步得到的图片patch对应的尺寸
assert x.shape[0] == tgt_sizes.shape[0]
bs = x.shape[0]

embed_dim = 3584# 向量维度
max_size = (70,70)# 最大尺寸
pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, max_size)).float()# 70,70,3584

patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]# [1036, 1014, 1014]
max_patch_len = torch.max(patch_len)# 1036
key_padding_mask = torch.zeros((bs, max_patch_len))

position_embed = []# 通过从pos_embed中根据图片patch大小得到对应的位置向量,且要更新key_padding_mask,后面的attention要用的
for i in range(bs):
    tgt_h, tgt_w = tgt_sizes[i]
    position_embed.append(pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)))  # [[1036, 3584],[1014, 3584],[1014, 3584]]
    key_padding_mask[i, patch_len[i]:] = True

position_embed = torch.nn.utils.rnn.pad_sequence(
    position_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2)  # BLD => L * B * D 1036*3*3584

x = kv_proj(x)  # 将x映射得到kv的过程 [3, 1036, 3584]
x = ln_kv(x).permute(1, 0, 2)  # 将x经过ln [1036, 3, 3584]

q = ln_q(query)  # 将query经过ln 64*3584

def repeat(query, N: int):
    return query.unsqueeze(1).repeat(1, N, 1)


out = MultiheadAttention(embed_dim, num_heads,
    repeat(q, bs),  # Q * B * D query
    x + position_embed,  # L * B * D +  L * B * D key
    x,# L * B * D  value
    key_padding_mask=key_padding_mask)
#  out: Q * B * D 在这一步完成了映射
# attn就是正常的attention操作了
# key、value的值和原始论文flamingo中有些不一样,论文中是将q和x拼接在一起作为key、value进行attention,这里只使用了x
# 位置编码的使用和论文也是不一样的,论文中是将位置编码直接加在了原始输入上,所以value、key都是有位置编码的
# 这里是只将位置编码加在了key上


x = out.permute(1, 0, 2)  # B * Q * D

x = ln_post(x)# ln
x = x @ proj# 映射
return x

额外说几句

# 模型MINICPM-V2_6中的resampler部分结构
  (resampler): Resampler(
    (kv_proj): Linear(in_features=1152, out_features=3584, bias=False)
    (attn): MultiheadAttention(
      (out_proj): Linear(in_features=3584, out_features=3584, bias=True)
    )
    (ln_q): LayerNorm((3584,), eps=1e-06, elementwise_affine=True)
    (ln_kv): LayerNorm((3584,), eps=1e-06, elementwise_affine=True)
    (ln_post): LayerNorm((3584,), eps=1e-06, elementwise_affine=True)
  )

这里主要是为了将变长的vision embedding转化为固定长度,代码结构也比较简单
代码比较简单,能想通2D position embeeding就好。
挖了一个坑,要看看attention的代码了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2131091.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

超链接/列表/多媒体/表格标记

1.超链接标记&#xff0c; 要将两个前端网页连接起来用什么标记呢&#xff1f; 答案是a标记&#xff0c;也就是超链接 下图就是两个html建立了超链接 效果是点击我是1号会跳转到我是2号那里 2.列表标记分为有序列表ol和无序列表ul, 每一列用li标签 <hr color"yell…

华为项目管理认证HCIA-PM认证 |课程大纲

大家想要往上升的&#xff0c;或多或少都要懂点技术&#xff0c;但这并不是让你们对技术的认知层面要做到面面俱到&#xff0c;每个细节都清楚&#xff0c;而是只要知道产品开发的流程和研发所需的资源设备就可以了。 如何才能在短时间内掌握这些技术&#xff1f;最直接有效的…

PHP智能收银精准管理收银服务系统小程序源码

智能收银&#xff0c;精准管理 —— 解锁收银服务新境界 &#x1f389; 开篇&#xff1a;告别传统&#xff0c;迎接智能收银新时代 在快节奏的现代生活中&#xff0c;每一次购物体验都值得我们追求更加高效与便捷。传统的收银方式已难以满足商家与顾客日益增长的需求&#xff…

【车载以太网】【SOME/IP】Wireshark 解析

Wireshark 下载链接:Wireshark Go DeepSOMEIP插件介绍:https://www.wireshark.org/docs/dfref/s/someip.html官方插件 Wireshark从3.2版本开始支持SOME/IP,启用相应的插件即可以使用Wireshark解析解析并查看SOME/IP数据。 相关代码: 代码仓库:https://github.com/wiresh…

UTF-8与UTF-8MB4编码的异同与应用场景

前言 想象一下&#xff0c;你正在网上冲浪&#xff0c;突然看到一个超有趣的表情符号&#xff0c;或者是一个外国朋友发来了一条消息&#xff0c;里面包含了一些特殊字符。这时候&#xff0c;如果你的电脑或者手机使用的编码方式不够强大&#xff0c;那些酷炫的表情或者特殊文字…

OpenCV结构分析与形状描述符(12)椭圆拟合函数fitEllipseAMS()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 围绕一组2D点拟合一个椭圆。 该函数计算出一个椭圆&#xff0c;该椭圆拟合一组2D点。它返回一个内切于该椭圆的旋转矩形。使用了由[260]提出的近…

微调大模型:提高其代码修复能力的尝试

目录 一、作品背景&#xff1a; 二、作品目标&#xff1a; 三、作品技术方案&#xff1a; (1)标记化 (2)量化 (3) LoRA&#xff08;低秩自适应&#xff09;配置 (4)训练配置 (6)模型保存 四、作品效果&#xff1a; 一、作品背景&#xff1a; 随着大型模型技术的日益成…

Java毕业设计之基于SSM框架的正安县吉他线上销售系统

&#x1f6a9;毕设中如何选题&#xff1f; 对于项目设计中如何选题、让题目的难度在可控范围&#xff0c;以及如何在选题过程以及整个毕设过程中如何与老师沟通&#xff0c;有疑问不清晰的可以联系我&#xff0c;详细为你解答 &#x1f6ad;如何快速熟悉一个项目&#xff1f;这…

java异步发送邮件:如何实现高效邮件发送?

java异步发送邮件性能调优&#xff1f;如何设计java异步发邮件&#xff1f; 传统的同步邮件发送方式在处理大量邮件时可能会导致系统响应变慢&#xff0c;甚至阻塞其他关键业务流程。AokSend将深入探讨如何通过Java异步发送邮件来实现高效邮件发送&#xff0c;从而提升应用程序…

裸土检测算法样本标注、裸土检测、裸土算法识别

在当今快速发展的科技时代&#xff0c;裸土检测算法作为一种前沿技术&#xff0c;正逐步改变我们对土壤检测的传统观念。随着环境保护和资源管理的日益重要&#xff0c;裸土检测算法不仅在农业领域展现了巨大的潜力&#xff0c;也在环境监测、灾后恢复等多个领域发挥了至关重要…

基于SpringBoot+Vue的高考志愿智能推荐系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于JavaSpringBootVueMySQL的…

使用 nvm 管理 node 版本:如何在 macOS 和 Windows 上安装使用nvm

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 一、引言二、nvm的安装与基本使用2.1 macOS安装nvm2.1.1 使用 curl 安装2.1.2 使用 Homebrew 安装 2.2 Windows安装nvm2.2.1 下载 nvm-windows2.2.2 安装 nvm-windows 2.3 安装node2.4 切换node版本 三、常见问题及解决方案…

语音克隆神器GPT-Sovits-V2 Mac版整合包!

语音克隆神器GPT-Sovits-V2 Mac版整合包&#xff01; Mac M1/M2/M3芯片福音&#xff01;语音克隆神器GPT-Sovits-V2整合包来了&#xff01; AI语音克隆黑科技&#xff0c;Mac也能轻松玩转&#xff01; 还在羡慕别人用AI语音克隆技术&#xff1f;还在苦恼Mac配置环境的复杂&am…

Linux - iptables防火墙

目录 一、iptables概述 二、规则表与规则链结构&#xff08;四表五链&#xff09; 1.简述 2.四表&#xff08;规则表&#xff09; 3.五链&#xff08;规则链&#xff09; 三、数据链过滤的匹配流程 四、iptables命令行配置方法 1.命令格式 2.基本匹配条件 3.隐含匹配 …

Python什么时候打折?

Python收费&#xff1f; 今天有一个刚学习编程的网友&#xff0c;他找到了我&#xff0c;问了我这样一个问题&#xff0c;“我看你的文章也有一段时间了&#xff0c;对上面提到的python的强大功能非常感兴趣。现在想自己安装一个亲自体验一下。我发现&#xff0c;python的售价…

开发一款通过蓝牙连接控制水电表的微信小程序

增强软硬件交互 为了更好的解决师生生活中的实际问题&#xff0c;开发蓝牙小程序加强了和校区硬件的交互。 比如通过蓝牙连接控制水电表&#xff0c;减少实体卡片的使用。添加人脸活体检测功能&#xff0c;提高本人认证效率&#xff0c;减少师生等待时间。 蓝牙水电控展示 蓝…

go-map系统学习

map底层结构 Goland的map的底层结构使用hash实现&#xff0c;一个hash表里有多个hash表节点&#xff0c;即bucket&#xff0c;每个bucket保存了map中的一个或者一组键值对。 map结构定义&#xff1a; runtime/map.go:hmap type hmap struct {// Note: the format of the hma…

Linux系统进程的优先级

一、进程优先级的概念 进程优先级就是进程被CPU执行的先后顺序&#xff0c;优先级值越小&#xff0c;优先级别越高。 使用ps -al命令查看当前系统所有进程的优先级&#xff1a; PRI是进程的基准优先级&#xff0c;NI&#xff08;nice值&#xff09;是进程优先级修正数据&…

【C++】理解C++中的复制、复制构造函数

十、理解C中的复制、复制构造函数 拷贝就是要复制数据&#xff0c;也就是复制内存。 当我们把一个对象或一段数据从一个地方拷贝到另一个地方&#xff0c;那这个对象或数据其实是有两个副本&#xff0c;而且这个过程还是需要时间和开销的。所以如果你只是想读取数据&#xff0…

SQL使用IN进行分组统计时如何将不存在的字段显示为0

这两天被扔过来一个脏活儿&#xff1a;做一个试点运行系统的运营指标统计。 活儿之所以称为“脏”&#xff0c;是因为要统计8家单位共12个项目的指标。而每个项目有3个用户类指标&#xff0c;以及分17个功能模块&#xff0c;每个功能模块又分5个维度的指标。也就是单个项目是1…