前言:近日需要用到 BLIP 微调下游任务,搜索发觉如今并无 BLIP 微调教程,下面就以 Image-Text Captioning 任务为例,演示如何完成 BLIP 模型在自己数据集上的微调。
目录
- 1. BLIP 介绍
- 2. 关键代码定位
- 3. 关键参数赋值
- 4. 模型定义&使用
1. BLIP 介绍
相关论文:BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation (ICML, 2022)
演示地址:https://huggingface.co/spaces/Salesforce/BLIP
开源代码:https://github.com/salesforce/BLIP
在开源代码的 README.md 介绍中,可以看到
可知 BLIP 可以完成 Image-Text Captioning、VQA 以及 NLVR2 这几个下游任务。
2. 关键代码定位
下图是 github 仓库中的文件构成:
首先通过 https://github.com/salesforce/BLIP/blob/main/train_caption.py 文件,了解 BLIP 如何用于 captioning 场景中。
相关参数定义在 https://github.com/salesforce/BLIP/blob/main/configs/caption_coco.yaml 文件中。
发现模型通过如下方式定义:
根据 from models.blip import blip_decoder
的得知,blip_decoder 函数定义于 models.blip 文件中,于是转到https://github.com/salesforce/BLIP/blob/main/models/blip.py 文件。微调过程中,主要使用的是该文件中的 blip_decoder() 函数以及 BLIP_Decoder 类。
-
blip_decoder() 函数定义如下,参数列表包括
pretrained
模型的地址以及 BLIP_Decoder 类的参数列表。def blip_decoder(pretrained='',**kwargs): model = BLIP_Decoder(**kwargs) if pretrained: model,msg = load_checkpoint(model,pretrained) assert(len(msg.missing_keys)==0) return model
-
BLIP_Decoder 类的初始化函数参数列表如下:
class BLIP_Decoder(nn.Module): def __init__(self, med_config = 'configs/med_config.json', image_size = 384, vit = 'base', vit_grad_ckpt = False, vit_ckpt_layer = 0, prompt = 'a picture of ', ):
其中,
med_config
对应的 json 文件路径为https://github.com/salesforce/BLIP/blob/main/configs/med_config.json
,image_size 为模型接收到的图像尺寸,vit 为 image encoder 的规模,可选 base 或 large;vit_grad_ckpt 与 vit_ckpt_layer 为初始化 vit 的相关参数,无需修改;prompt 为 BLIP 使用的提示文本,以字符串形式的自然语言文本给出。
3. 关键参数赋值
- blip_decoder() 的
pretrained
这一参数的值是在https://github.com/salesforce/BLIP/blob/main/configs/caption_coco.yaml
文件中找到的; - image_size 的值要与自己的数据集中图像尺寸对应,我这里将其修改为 224;
- prompt 修改为自己需要的自然语言提示文本,注意以字符串形式给出;
- 其他参数保持不变即可。
4. 模型定义&使用
将 https://github.com/salesforce/BLIP/blob/main/models/ 目录下的 blip.py, vit.py 以及 med.py 文件下载到自己项目的同一目录下。
在主文件中,通过使用如下命令使用 BLIP model:
BLIPModel = blip_decoder(pretrained=args['pretrained'], image_size=args['image_size'], vit=args['vit'], vit_grad_ckpt=args['vit_grad_ckpt'], vit_ckpt_layer=args['vit_ckpt_layer'], prompt=args['prompt']).to(device)
-
训练时,使用代码
loss = BLIPModel(imgs, texts)
,调用 Blip_Decoder 类中的 forward() 函数,会得到当前 batch 数据对应的 loss,然后按照关惯常操作进行反向传播; -
测试时,使用下述代码
generated_texts = BLIPModel.generate(imgs, sample=True, num_beams=3, max_length=30, min_length=5, top_p=0.95, repetition_penalty=1.0)
调用 Blip_Decoder 类中的 generate() 函数,会得到模型对当前 batch 数据生成的自然语言文本。
参考资料
- https://github.com/salesforce/BLIP