深度学习竞赛进阶技巧 - BLIP使用说明与实战

news2024/9/20 1:17:18

BLIP-2: 图像到文本的生成器

请添加图片描述

  • BLIP-2: Scalable Pre-training of Multimodal Foundation Models for the World’s First Open-source Multimodal Chatbot

1论文摘要

由于大规模模型的端到端的训练,视觉与语言的预训练模型的成本越来越高。本文提出了BLIP-2,这是一种通用的有效的预训练策略,它从现成的冷冻预训练图像编码器与大型的语言模型中引导视觉语言预训练。BLIP-2通过一个轻量级的查询transformer弥补了模态差距,该transformer分为两个阶段进行预训练:第一个阶段从冷冻的图像编码器中引导视觉语言representation learning。第二阶段从一个固定的语言模型中引导视觉到语言的生成学习。
优点:BLIP-2 achieves state-of-the-art performance on various vision-language tasks, despite having significantly fewer trainable parameters than existing methods. For example, our model outperforms Flamingo80B by 8.7% on zero-shot VQAv2 with 54x fewer trainable parameters. We also demonstrate the model’s emerging capabilities of zero-shot image-to-text generation that can follow natural language instructions

  • paper:BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models
  • huggingface :huggingface
  • github:salesforce/LAVIS
  • blip-2

2一些任务的总结对应支持的模型

请添加图片描述

3BLIP- example

在这里插入图片描述

4How BLIP-2 works

llm要理解视觉内容,关键是要弥合视觉语言的情态鸿沟。由于llm在自然语言预训练期间没有见过任何图像,因此弥合模态差距具有挑战性,特别是当llm仍然处于冻结状态时。为此,我们提出了一个用新的两阶段预训练策略预训练的查询转换器(Q-Former)。如下图所示,Q-Former经过预训练后,可以有效地充当冻结的图像编码器和冻结的LLM之间的桥梁,从而缩小了模态差距。
Overview of BLIP-2 two-stage pre-training strategy
Overview of BLIP-2 two-stage pre-training strategy

  • 第一个阶段是视觉和语言表征学习。在这个阶段,我们将Q-Former连接到一个冻结的图像编码器,并用图像-文本对进行预训练。Q-Former学习提取与相应文本最相关的图像特征。我们从BLIP (https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)中重新设计了用于视觉和语言表示学习的预训练目标。

Overview of Q-Former and the first stage of vision-language representation learning in BLIP-2
Overview of Q-Former and the first stage of vision-language representation learning in BLIP-2

  • 第二阶段是视觉-语言生成学习。在这一阶段,我们将Q-Former的输出连接到冻结的LLM。我们预先训练Q-Former,这样它的输出特征就可以被LLM解释,从而生成相应的文本。我们实验了基于解码器的LLMs(例如OPT)和基于编码器-解码器的LLMs(例如FlanT5)
    verview of the second stage of vision-to-language generative learning in BLIP-2
    overview of the second stage of vision-to-language generative learning in BLIP-2

在推理过程中,我们只需将文本指令附加在Q-Former的输出之后,作为LLM的输入。我们已经对各种图像编码器和LLM进行了实验,并得出了一个有希望的观察结果:更强的图像编码器和更强的LLM都会导致BLIP-2的更好性能。这一观察结果表明,BLIP-2是一种通用的视觉语言预训练方法,可以有效地收集视觉和自然语言社区的快速进展。BLIP-2是构建多模态对话AI代理的重要突破性技术。

5 BLIP demo

install

pip install salesforce-lavis

library

import torch
from PIL import Image
import requests
from lavis.models import load_model_and_preprocess

示例图像展示

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png' 
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')   
display(raw_image.resize((596, 437)))

请添加图片描述

device


# setup device to use
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

Load pretrained/finetuned BLIP2 captioning model

# we associate a model with its preprocessors to make it easier for inference.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=device
)
vis_processors.keys()
# dict_keys(['train', 'eval'])

将图像处理为模型输入的格式

image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

生成标题 (using beam search)

model.generate({"image": image})

输出 :# ‘singapore’

generate multiple captions using nucleus sampling

## 由于核采样的不确定性,你可能会得到不同的标题。
model.generate({"image": image}, use_nucleus_sampling=True, num_captions=3)

instructed zero-shot vision-to-language generation

Ask the model to explain its answer.

model.generate({"image": image, "prompt": "Question: which city is this? Answer:"})

[‘singapore’]

model.generate({
    "image": image,
    "prompt": "Question: which city is this? Answer: singapore. Question: why?"})

[‘it has a statue of a merlion’]

model.generate({
    "image": image,
    "prompt": "Question: which city is this? Answer: singapore. Question: why?"})
# 'it has a statue of a merlion'  
context = [
    ("which city is this?", "singapore"),
    ("why?", "it has a statue of a merlion"),
]
question = "where is the name merlion coming from?"
template = "Question: {} Answer: {}."


prompt = " ".join([template.format(context[i][0], context[i][1]) for i in range(len(context))]) + " Question: " + question + " Answer:"

print(prompt)
# generate model's response
model.generate({"image": image,"prompt": prompt})
# 'merlion is a portmanteau of mermaid and lion'

用于比赛中的实战

比赛介绍kaggle竞赛-Stable Diffusion数据分析与baseline

BLIP2 models are very large to load, so I use some techniques such as init_empty_weights.
And in order to submit within 9 hours, a beam width of beam search in decoder is reduced to 3.
代码参照👆github链接

环境安装

# locally downloaded salesforce-lavis
!pip install salesforce-lavis --no-index --find-links=file:///kaggle/input/lavis-pip/

# in order to load local weights files, modified version of salesforce-lavis is required. so firstly uninstall.
!pip uninstall -y salesforce-lavis

# and install modified salesforce-lavis
!pip install salesforce-lavis --no-index --find-links=file:///kaggle/input/lavis-mod-wheel/salesforce_lavis-1.0.0.dev1-py3-none-any.whl

数据在这里插入图片描述

  • 数据链接data

在kaggle环境下很难使用BLIP一类的大模型,主要原因是我们在加载权重的时候使用了两倍的显存,于是我改进为使用一倍显存。

import os
import gc
import cv2
import sys
import torch

import numpy as np
import torch.nn as nn
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt

from PIL import Image
from lavis.models import load_model, load_preprocess, load_model_and_preprocess
from lavis.processors import load_processor
from lavis.models.blip2_models.blip2_opt import Blip2OPT
from typing import Dict
from sklearn.metrics.pairwise import cosine_similarity 
from pathlib import Path
from accelerate import init_empty_weights

sys.path.append('/kaggle/input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

节省显存使用

显存不够用?一种大模型加载时节约一半显存的方法

# these helper functions are based on the following repository. 
# https://github.com/FrancescoSaverioZuppichini/Loading-huge-PyTorch-models-with-linear-memory-consumption/blob/main/README.md
def get_keys_to_submodule(model: nn.Module) -> Dict[str, nn.Module]:
    keys_to_submodule = {}
    for submodule_name, submodule in model.named_modules():
        for param_name, param in submodule.named_parameters():
            splitted_param_name = param_name.split('.')
            is_leaf_param = len(splitted_param_name) == 1
            if is_leaf_param:
                if submodule_name != '':
                    key = f"{submodule_name}.{param_name}"
                else:
                    key = param_name
                keys_to_submodule[key] = submodule                
    return keys_to_submodule


def load_state_dict_with_low_memory(model: nn.Module, state_dict: Dict[str, torch.Tensor]):
    model.to(torch.device("meta"))
    keys_to_submodule = get_keys_to_submodule(model)
    for key, submodule in keys_to_submodule.items():
        val = state_dict.get(key)
        
        if val is not None:
            param_name = key.split('.')[-1]
            param_dtype = getattr(submodule, param_name).dtype
            val = val.to(param_dtype)
            new_val = torch.nn.Parameter(val, requires_grad=False)
            setattr(submodule, param_name, new_val)

推断部分

comp_path = Path('/kaggle/input/stable-diffusion-image-to-prompts/')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with init_empty_weights():
    my_model = Blip2OPT(opt_model="facebook/opt-2.7b")
class DictWrapper:
    def __init__(self, d):
        self.dict = d
    
    def __getattr__(self, name):
        return self.dict[name]

    def get(self, name, default_val=None):
        return self.dict.get(name, default_val)

dict_tr = {
    "name": "blip_image_train",
    "image_size": 224
}
dict_ev = {
    "name": "blip_image_eval",
    "image_size": 224
}
dict_t = {
    "name": "blip_caption"
}
config = {
    "vis_processor":{
        "train":DictWrapper(dict_tr),
        "eval":DictWrapper(dict_ev),
    },
    "text_processor":{
        "train":DictWrapper(dict_t),
        "eval":DictWrapper(dict_t)
    }
}


vis_processors = load_preprocess(config)[0]

加载模型

#低显存加载模型权重
load_state_dict_with_low_memory(my_model, torch.load("/kaggle/input/blip2-pretrained-opt27b-sdpth/blip2_pretrained_opt2.7b_sd.pth"))
my_model.eval()
gc.collect()

代码核心在此

images = os.listdir(comp_path / 'images')
pred_prompt_list = []
for image_name in images:
    image = Image.open(comp_path / 'images' / image_name).convert('RGB')
    #图像处理为模型输入格式
    image = vis_processors["eval"](image).unsqueeze(0).to(device)
    #产生标题(num_beans = 3)将可能的标题都给产出
    pred_prompt = my_model.generate({"image": image}, num_beams=3)
    #将生成的结果添加到pred_prompt_list中
    pred_prompt_list.append(pred_prompt[0])

后续将pred_prompt_list使用官方的all-MiniLM-L6-v2映射到384维度上进行提交

del my_model
gc.collect()
st_model = SentenceTransformer('/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2')
prompt_embeddings = st_model.encode(pred_prompt_list, batch_size=256).flatten()
imgIds = [i.split('.')[0] for i in images]

EMBEDDING_LENGTH = 384
eIds = list(range(EMBEDDING_LENGTH))

imgId_eId = [
    '_'.join(map(str, i)) for i in zip(
        np.repeat(imgIds, EMBEDDING_LENGTH),
        np.tile(range(EMBEDDING_LENGTH), len(imgIds)))]
submission = pd.DataFrame(
    index=imgId_eId,
    data=prompt_embeddings,
    columns=['val']
).rename_axis('imgId_eId')
submission.to_csv('submission.csv')

文章参考

SDIP BLIP2 baseline public

相关文章

图像分类竞赛进阶技能:OpenAI-CLIP使用范例

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/432792.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络原理(IP协议)

目录IP协议IP地址IP 地址由网络和主机两部分标识组成IP 地址的分类广播地址IP多播子网掩码路由IP 地址与路由控制IP 分包与组包什么是IP分片为什么要进行IP分片IP分片是如何组装的路径 MTU 发现IP协议 IP(IPv4、IPv6)相当于 OSI 参考模型中的第3层——网…

玩转ChatGPT:Auto-GPT项目部署与测评

一、Auto-GPT简介 最近,以ChatGPT为代表的超大规模语言模型火出了圈,各种二次开发项目也是层出不穷。 这周在AI圈炸街的是Auto-ChatGPT,在GitHub上已经61.4K的点赞了。 项目地址:https://github.com/Torantulino/Auto-GPT 用项…

命令设计模式(Command Pattern)[论点:概念、组成角色、相关图示、示例代码、框架中的运用、适用场景]

文章目录概念组成角色相关图示示例代码框架中的运用适用场景概念 命令设计模式(Command Pattern)是一种行为设计模式,它将请求的操作封装为一个对象,从而实现请求者和执行者之间的解耦。这样,请求者只需要知道如何发送…

Darknet19详细原理(含tensorflow版源码)

Darknet19原理 Darknet19是一个轻量级的卷积神经网络,用于图像分类和检测任务。 它是YOLOv2目标检测算法的主干网络,它的优点在于具有较少的参数和计算量,在计算速度和精度之间取得了良好的平衡,同时在训练过程中也具有较高的准确…

MobileNetV1详细原理(含torch源码)

目录 MobileNetV1原理 MobileNet V1的网络结构如下: 为什么要设计MobileNet: MobileNetV1的主要特点如下: MobileNetV1的创新点: MobileNetV1源码(torch版) 训练10个epoch的效果 MobileNetV1原理 Mo…

玩转ChatGPT:中科院ChatGPT Academic项目部署与测评

一、ChatGPT Academic简介 最近,以ChatGPT为代表的超大规模语言模型火出了圈,各种二次开发项目也是层出不穷。 比如说今天我们玩弄的这个“ChatGPT Academic”,在GitHub上已经13.7K的点赞了。 项目地址:https://github.com/bina…

因为这5大工具,同事直呼我时间管理小王子

写在前面 关于时间管理、如何做计划、如何提高执行力等等相关话题其实很早之前我就想写了,但一直拖着迟迟没有动笔。 在之前的一篇文章里我曾详细聊过自己对于时间管理,如何提高执行力,以及如何摆脱那种没有灵魂的任务计划的一些思考和做法…

【C语言】深度理解指针(中)

前言✈ 上回说到,我们学习了一些与指针相关的数据类型,如指针数组,数组指针,函数指针等等,我们还学习了转移表的基本概念,学会了如何利用转移表来实现一个简易计算器。详情请点击传送门:【C语言…

Windows 下安装和使用Redis

Redis 一般安装在Linux中, 但有时出于学习和其他目的,需要在Windows机器运行Redis, 本篇介绍如果在Windows中运行和使用Redis。 关于Redis的基本介绍可以参考: Redis介绍、安装与初体验 Windows 下Redis的下载 可…

【NestJs】日志收集

Nest 附带一个默认的内部日志记录器实现,它在实例化过程中以及在一些不同的情况下使用,比如发生异常等等(例如系统记录)。这由 nestjs/common 包中的 Logger 类实现。你可以全面控制如下的日志系统的行为: 完全禁用日…

jenkins windows安装 部署项目 前端 后端

安装 需要安装的程序: 1.下载jenkins windows版本 2.400 此版本需要jdk11 https://www.jenkins.io/ 按着提示安装即可 2.下载jdk 11 https://login.oracle.com/ 按着提示安装即可 部署pc 1.新建项目 2.源码管理 3.添加git用户 4.Build Steps 构建 初始化np…

vue2数据响应式原理(2)搭建webpack认识一下Object.defineProperty

在1中我们讲到 Object.defineProperty() 是vue2实现数据响应的关键 那么我们就来好好的看看这个方法 方法字面意思是定义属性 而他是通过Object对象调用的 所以说 他是用来控制对象的某个属性的 比较官方的解释是 object.defineProperty() 方法会直接在一个对象上定义一个新属…

单片机添加版本号的一些小技巧

平时我们写程序,通常都会备注软件版本,那么,怎么在单片机中保存版本信息呢? 方法其实有很多,但基本原理都是在指定存储区域(Flash)中写入软件版本信息。 实现方法 下面就分享一个最常用&#xf…

算法风险防控

算法风险防控是指在算法应用过程中,通过对算法应用场景、数据、模型和结果等多个方面的风险进行评估和控制,以保障算法应用的安全性、可靠性和合法性。以下是一些常见的算法风险防控措施: 数据风险防控:在算法应用中,…

【python】Python基础入门:从变量到异常处理

天池实验室代码链接:https://tianchi.aliyun.com/notebook-ai/home#notebookLabId491001 简介 Python 是一种通用编程语言,其在科学计算和机器学习领域具有广泛的应用。如果我们打算利用 Python 来执行机器学习,那么对 Python 有一些基本的了…

51单片机定时器与计数器

文章目录 51单片机定时器与计数器一、定时器与计数器的结构与功能计数功能定时功能 二、定时器与计数器的控制TMOD 工作方式寄存器TCON 定时器控制寄存器 三、仿真案例(一).8个LED 1 秒周期闪烁。(二) 产品包装生产线。 51单片机定时器与计数器 一、定时器与计数器的结构与功能…

ESP32设备驱动-BMP388气压传感器驱动

BMP388气压传感器驱动 文章目录 BMP388气压传感器驱动1、BMP388介绍2、硬件准备3、软件准备4、驱动实现1、BMP388介绍 BMP388 是一款非常小巧、低功耗和低噪声的 24 位绝对气压传感器。 它可以实现精确的高度跟踪,特别适合无人机应用。 BMP388 在 0-65C 之间的同类最佳 TCO,…

港联证券|AI概念板块无死角杀跌,主题炒作熄火后资金会流向哪些板块?

ChatGPT概念指数大跌7%,单日跌幅创历史之最。 4月10日,炒作逾月的ChatGPT概念板块团体大跌,云从科技(688327.SH)、三六零(601360.SH)、科大讯飞(002230.SZ)等热门股跌停&…

集中式版本控制工具 —— SVN

一、简介 1️⃣ SVN 是什么? 代码版本管理工具他能记住每次的修改查看所有的修改记录恢复到任何历史版本恢复已经删除的文件 2️⃣ SVN 与 Git 相比有什么优势? 使用简单、上手快目录级权限控制,企业安全必备子目录 Checkout,…

RK3568平台开发系列讲解(Linux系统篇)文件系统的读写

🚀返回专栏总目录 文章目录 一、文件IO1.1、文件 IO read()1.2、文件 IO write()二、系统调用层和虚拟文件系统层三、ext4 文件系统层沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇我们一起学习 read 和 write 调用过程。 一、文件IO 1.1、文件 IO read() rea…