多模态大语言模型(MLLM)-Blip3/xGen-MM

news2024/10/24 18:18:52

论文链接:https://www.arxiv.org/abs/2408.08872
代码链接:https://github.com/salesforce/LAVIS/tree/xgen-mm

本次解读xGen-MM (BLIP-3): A Family of Open Large Multimodal Models
可以看作是
[1] Blip: Bootstrapping language-image pre-training for unified vision-language understanding and generation
[2] BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models
的后继版本

前言

在这里插入图片描述
没看到Blip和Blip2的一作Junnan Li,不知道为啥不参与Blip3
整体pipeline服从工业界的一贯做法,加数据,加显卡,模型、训练方式简单,疯狂scale up

创新点

  • 开源模型在模型权重、训练数据、训练方法上做的不好
  • Blip2用的数据不够多、质量不够高;Blip2用的Q-Former、训练Loss不方便scale up;Blip2仅支持单图输入,不支持多图输入
  • Blip3收集超大规模数据集,并且用相对简单的训练方式,实现多图、文本的交互。
  • 开放两个数据集:BLIP3-OCR-200M(大规模OCR标注数据集),BLIP3-GROUNDING-50M(大规模visual grounding数据集)

具体细节

模型结构

在这里插入图片描述
整体结构非常简单

  • 图像经过ViT得到patch embedding,再经过token sampler得到vision token。(先经过Token Sampler,得到视觉embedding,而后经过VL connector,得到vision token)
  • 文本通过tokenizer获得text token
  • 文本、图像输入均送到LLM中,并且仅对本文加next prediction loss
  • 注意:ViT参数冻结,其他参数可训练
  • 注意:支持图像和文本交替输入,支持多图,任意分辨率图像
  • ViT:所用模型有DFN、SigLIP,在不同任务上,效果不同,如下:
    在这里插入图片描述
  • LLM:所用模型为phi3-mini
  • 模型结构代码见https://github.com/salesforce/LAVIS/blob/xgen-mm/open_flamingo/src/factory.py
  • token Sampler代码见https://github.com/salesforce/LAVIS/blob/xgen-mm/open_flamingo/src/vlm.py
  • VL connector代码见https://github.com/salesforce/LAVIS/blob/xgen-mm/open_flamingo/src/helpers.py

Token Sampler

详见博客https://blog.csdn.net/weixin_40779727/article/details/142019977,就不赘述了

VL Connector

整体结构如下:

class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, vision_attn_masks=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latent (torch.Tensor): latent features
                shape (b, T, n2, D)
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.
        if vision_attn_masks is not None:
            vision_attn_masks = torch.cat((vision_attn_masks, 
                                            torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),
                                            dim=-1)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
        q = q * self.scale

        # attention
        sim = einsum("... i d, ... j d  -> ... i j", q, k)
        # Apply vision attention mask here.
        # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
        if vision_attn_masks is not None:
            attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)
            vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
            attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
            sim += attn_bias

        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
        return self.to_out(out)


class PerceiverResampler(VisionTokenizer):
    def __init__(
        self,
        *,
        dim,
        dim_inner=None,
        depth=6,
        dim_head=96,
        heads=16,
        num_latents=128,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        """
        Perceiver module which takes in image features and outputs image tokens.
        Args:
            dim (int): dimension of the incoming image features
            dim_inner (int, optional): final dimension to project the incoming image features to;
                also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
            depth (int, optional): number of layers. Defaults to 6.
            dim_head (int, optional): dimension of each head. Defaults to 64.
            heads (int, optional): number of heads. Defaults to 8.
            num_latents (int, optional): number of latent tokens to use in the Perceiver;
                also corresponds to number of tokens per sequence to output. Defaults to 64.
            max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
                and keep positional embeddings for. If None, no positional embeddings are used.
            max_num_frames (int, optional): maximum number of frames to input into the Perceiver
                and keep positional embeddings for. If None, no positional embeddings are used.
            ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
        """
        if dim_inner is not None:
            projection = nn.Linear(dim, dim_inner)
        else:
            projection = None
            dim_inner = dim
        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
        self.projection = projection
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # positional embeddings
        self.frame_embs = (
            nn.Parameter(torch.randn(max_num_frames, dim))
            if exists(max_num_frames)
            else None
        )
        self.media_time_embs = (
            nn.Parameter(torch.randn(max_num_media, 1, dim))
            if exists(max_num_media)
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(
                            dim=dim, dim_head=dim_head, heads=heads
                        ),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x, vision_attn_masks):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
            vision_attn_masks (torch.Tensor): attention masks for padded visiont tokens (i.e., x)
                shape (b, v)
        Returns:
            shape (b, T, n, D) where n is self.num_latents
        """
        b, T, F, v = x.shape[:4]

        # frame and media time embeddings
        if exists(self.frame_embs):
            frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
            x = x + frame_embs
        x = rearrange(
            x, "b T F v d -> b T (F v) d"
        )  # flatten the frame and spatial dimensions
        if exists(self.media_time_embs):
            x = x + self.media_time_embs[:T]

        # blocks
        latents = self.latents
        latents = repeat(latents, "n d -> b T n d", b=b, T=T)
        for attn, ff in self.layers:
            latents = attn(x, latents, vision_attn_masks) + latents
            latents = ff(latents) + latents
        
        if exists(self.projection):
            return self.projection(self.norm(latents)) 
        else:
            return self.norm(latents)

训练及数据

预训练
  • 训练数据
    在这里插入图片描述
    用了0.1T token的多模态数据训练,和一些知名的MLLM相比,例如Qwen2VL 0.6T,还是不太够
  • 训练方式:针对文本的next token prediction方式训练,图像输入为384x384
有监督微调(SFT)
  • 训练数据:从不同领域(multi-modal conversation、 image captioning、chart/document understanding、science、math),收集一堆开源数据。从中采样1百万,包括图文指令+文本指令数据。
    训练1epoch
  • 训练方式:针对文本的next token prediction方式训练
交互式多图有监督微调(Interleaved Multi-Image Supervised Fine-tuning)
  • 训练数据:首先,收集多图指令微调数据(MANTIS和Mmdu)。为避免模型过拟合到多图数据,选择上一阶段的单图指令微调数据子集,与收集的多图指令微调数据合并,构成新的训练集合。
  • 训练方式:针对文本的next token prediction方式训练
后训练(Post-training)
DPO提升Truthfulness
part1
  • 训练数据:利用开源的VLFeedback数据集。VLFeedback数据集构造方式:输入指令,让多个VLM模型做生成,随后GPT4-v从helpfulness, visual faithfulness, ethics三个维度对生成结果打分。分值高的输出作为preferred responses,分值低的输出作为dispreferred responses。BLIP3进一步过滤掉一部分样本,最终得到62.6K数据。
  • 训练方式:DPO为训练目标,用LoRA微调LLM 2.5%参数,总共训练1 epoch
part2
  • 训练数据:根据该工作,生成一组额外responses。该responses能够捕捉LLM的内在幻觉,作为额外dispreferred responses,采用DPO训练。
  • 训练方式:同part1,再次训练1 epoch
Safety微调(Safety Fine-tuning)提升Harmlessness
  • 训练数据:用2k的VLGua数据集+随机5K SFT数据集。VLGuard包括两个部分:
    这段话可以翻译为:
    (1) 恶心图配上安全指示及安全回应
    (2) 安全图配上安全回应及不安全回应
  • 训练方式:用上述7k数据,训练目标为next token prediction,用LoRA微调LLM 2.5%参数,总共训练1 epoch

实验效果

预训练

对比类似于预训练任务的VQA、Captioning任务,效果在使用小参数量LLM的MLLM里,效果不错。
在这里插入图片描述

有监督微调(SFT)

在这里插入图片描述

交互式多图有监督微调(Interleaved Multi-Image Supervised Fine-tuning)

在这里插入图片描述

后训练(Post-training)

在这里插入图片描述

消融实验

预训练
预训练数据量

在这里插入图片描述

预训练数据配比

在这里插入图片描述

视觉backbone

在这里插入图片描述

有监督微调(SFT)
视觉Token Sampler对比

在这里插入图片描述
base resolution:直接把图片resize到目标大小
anyres-fixed-sampling (ntok=128):把所有图像patch的表征concat起来,经过perceiver resampler,得到128个vision token
anyres-fixed-sampling (ntok=256):把所有图像patch的表征concat起来,经过perceiver resampler,得到256个vision token
anyres-patch-sampling:本文采用的方法

Instruction-Aware Vision Token Sampling.

在这里插入图片描述
XGen-MM:输入图像,获取vision token
XGen-MM(instruction-aware):同时输入图像+指令,获取vision token

Quality of the Text-only Instruction Data.

在这里插入图片描述仅利用文本指令数据,训练SFT模型,对比效果


https://blog.csdn.net/weixin_40779727/article/details/142019977

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2222604.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniapp:uni.createSelectorQuery函数结合vue的watch函数使用实例

提醒 本文实例是使用uniapp进行开发演示的。 一、需求场景 在开发详情页面时,不同产品描述文案不同,有的文案比较长,需求上要求描述文案最多展示4行文案,少于4行文案,全部显示,此UI高度自动适配&#xff0c…

智慧城管综合管理系统源码,微服务架构,基于springboot、vue+element+uniapp技术开发,支持二次开发

智慧城管源码,智慧城管执法办案系统源码 智慧城管综合执法办案平台是智慧城市框架下,依托物联网、云计算、多网融合等现代化技术,运用数字基础资源、多维信息感知、协同工作处置、智能化辅助决策分析等手段,形成具备高度感知、互联…

pikachu靶场-Cross-Site Scripting(XSS)

sqli-labs靶场安装以及刷题记录-dockerpikachu靶场-Cross-Site Scripting pikachu靶场的安装刷题记录反射型xss(get)反射型xss(post)存储型xssDOM型xssDOM型xss-xxss盲打xss之过滤xss之htmlspecialcharsxss之href输出xss之js输出 pikachu靶场的安装 刷题记录 反射型xss(get) …

《什么是大模型、超大模型和 Foundation Model?》

前言 大模型旨在解决人类面临的各种问题,提高人类的生产力和生活质量。是一门涉及计算机科学、数学、哲学、心理学等多个领域的交叉学科,旨在研究如何使计算机能够像人类一样思考、学习、推理和创造。大模型的出现,让很多产业人士认为这项技术会改变信息产业格局,即基于数…

解码专业术语——应用系统开发项目中的专业词汇解读

文章目录 引言站点设置管理具体要求包括: Footer管理基于URL的权限控制利用数据连接池优化数据库操作什么是数据连接池?优化的优势 利用反射改造后端代码,AJAX反射的作用及其在后端代码中的应用AJAX 实现前后端无刷新交互 引言 创新实践项目二…

ThingsBoard规则链节点:Delete Attributes节点详解

引言 删除属性节点简介 用法 含义 应用场景 实际项目运用示例 智能家居安全系统 物流跟踪解决方案 工业自动化生产线 结论 引言 ThingsBoard是一个开源的物联网平台,它提供了设备管理、数据收集与处理以及实时监控等功能。其中,规则引擎是其核心…

Clickhouse 笔记(一) 单机版安装并将clickhouse-server定义成服务

ClickHouse 是一个高性能的列式数据库管理系统(DBMS),主要用于在线分析处理(OLAP)场景。它由俄罗斯搜索引擎公司 Yandex 开发,并在 2016 年开源。ClickHouse 以其卓越的查询性能和灵活的扩展性而闻名&#…

模拟信号采集显示器+GPS同步信号发生器制作全过程(焊接、问题、代码、电路)

1、制作最小系统板 在制作最小系统板的时候,要用USB转TTL给板子供电,留了一个电源输入的四个接口,同时又用排针引出来VCC和GND用于后续其他外设的电源供应,电源配有电源指示灯和保护电容, 当时在焊接的时候把接口处的…

云计算实验1——基于VirtualBox的Ubuntu安装和配置

实验步骤 1、VirtualBox的安装 本实验使用VirtualBox-7.0.10 进行演示。对于安装包,大家可以前往 VirtualBox官网下载页面(https :/ / www. virtualbox.org/wiki/Downloads)下载其7.0版本安装包进行安装,或者直接使用QQ群的安装包VirtualBox-7.0.10-15…

基于开源Jetlinks物联网平台协议包-MQTT自定义主题数据的编解码

目录 前言 1.下载官方协议包 2.解压 3.自定义主题 4.重写解码方法 5.以下是我解析后接收到的数据 前言 最近这段时间,一直在用开源的Jetlinks物联网平台在学习,偶尔有一次机会接触到物联网设备对接,在协议对接的时候,遇到了…

Spring面试题——第五篇

1. Spring的优点 轻量级和非侵入性:不需要引入大量的依赖和配置。面向切面编程:Spring提供了强大的面向切面编程,允许用户定义横切关注点,并将其与核心业务逻辑分离,提高了灵活性。依赖注入(DI&#xff09…

java对接钉钉发送消息(纯萌新文档解惑)

java对接钉钉(纯萌新文档解惑) 注意:不是其他直接给你个写好的钉钉工具类,但不知道它怎么来的。是以钉钉官方文档为准,流程是什么,你想要什么可以自己在文档找(所有文档都有只是萌新看着懵&…

Kafka高可用性原理深度解析

在分布式系统中,高可用(High Availability, HA)是指系统在面对硬件故障、网络分区、软件崩溃等异常情况时,仍能继续提供服务的能力。对于消息队列系统而言,高可用性尤为重要,因为它通常作为数据流通的中枢&…

SSD | (十)PCIe介绍(上)

文章目录 📚从PCIe的速度说起📚PCIe拓扑结构🐇PCI——总线型拓扑结构🐇PCIe——树形拓扑结构📚PCIe分层结构📚PCIe TLP类型📚PCIe TLP结构🐇通用结构🐇具体TLP的Header📚从PCIe的速度说起 PCIe发展至今,速度一代比一代快。 连接速度所示1、2等是指PCIe链接…

Python 打包成 EXE 的方法详解

#1024程序员节|征文# 日常开发中,python由于其便捷性成为了很多人的首选语言,但是python的环境配置也是有点麻烦的,那么我们如何让其变得更加友好呢?没错,就是打包成exe可执行文件。 一、PyInstaller 简介…

修改windows11的hosts,配置127.0.0.1域名(最清晰)

这里记录的是学习短链接项目,通过配置127.0.0.1域名,达到可以通过域名代替127.0.0.1访问127.0.0.1下的某个端口的服务,达到短链接的前缀的效果,这里展示windows11的更改过程。 一、hosts文件路径 C:\Windows\System32\drivers\e…

【Java数据结构】---哈希表

乐观学习,乐观生活,才能不断前进啊!!! 我的主页:optimistic_chen 我的专栏:c语言 ,Java 欢迎大家访问~ 创作不易,大佬们点赞鼓励下吧~ 前言 在顺序结构以及平衡树中&…

littlefs源码分析1-设计思考

1.littlefs设计目的 littlefs 最初是作为一个实验而构建的,目的是在微控制器的环境中了解文件系统设计。目的是:构建一个在不使用无限制内存的情况下对电源丢失和闪存磨损具有弹性的文件系统。 这对嵌入式文件系统littlefs提出了三个主要要求&#xff1…

【Linux】 exit 和 _exit 的区别

在Linux系统中&#xff0c;exit(int status) 和 _exit(int status) 都是用来终止进程的函数&#xff0c;都能通过参数 int status传递一个整型的退出状态码给父进程&#xff0c;但它们之间有一些重要的区别。 1. 头文件不同 exit() 函数定义在 <unistd.h> 中 _exit() 函…

【Python爬虫实战】高效解析和操作XML/HTML的实用指南

&#x1f308;个人主页&#xff1a;https://blog.csdn.net/2401_86688088?typeblog &#x1f525; 系列专栏&#xff1a;https://blog.csdn.net/2401_86688088/category_12797772.html 目录 前言 一、lxml的安装 &#xff08;一&#xff09;使用 pip 安装 &#xff08;二&…