LISA是一个很好的Reason Segmentation的baseline, 其利用特殊的token [SEG]来微调多模态LLM和SAM的decoder来实现复杂逻辑下的prompt的推理分割. 其整体框图如下, 本篇文章精度此代码并作简单复现.
1. 推理流程
流程如下:
1.1 加载Tokenizer与模型
首先利用transformers库的AutoTokenizer从config文件中加载Tokenizer. LISA使用的Tokenizer是LLaMa相同的Tokenizer,填充方式是在序列右侧填充,最大长度是512:
tokenizer = AutoTokenizer.from_pretrained(
args.version,
cache_dir=None,
model_max_length=args.model_max_length, # 512
padding_side="right",
use_fast=False, # 如果选择true 则模型自动进行填充 选择False的流程是首先对输入进行encode 然后再填充
)
然后设定填充的unknown token以及本工作提出的代表分割掩码的[SEG] token:
tokenizer.pad_token = tokenizer.unk_token # unknown token为 <unk>
args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] # 在codebook中 [SEG]的id是32000
# tokenizer返回一个dict, 有两个key, input_ids表示在词汇表中的索引, attention_mask表示是否为填充的token 用于在注意力计算的时候看哪些token应该被注意
然后根据tokenizer的设定,设置model的句子开始、结束以及填充的token id:
model.config.eos_token_id = tokenizer.eos_token_id # </s>
model.config.bos_token_id = tokenizer.bos_token_id # <s>
model.config.pad_token_id = tokenizer.pad_token_id # <unk>
随后,加载CLIP预训练的ViTal-large模型,作为LLaVA中的vision encoder:
model.get_model().initialize_vision_modules(model.get_model().config)
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(dtype=torch_dtype)
加载CLIP的图像预处理类,包含resize、crop等,以及SAM中用到的resize类,其按照图像的最长边进行等比例resize:
clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
transform = ResizeLongestSide(args.image_size)
1.2 推理主要过程
按照LLaVA规定的格式先实例化一个Conversation类,这个类是自定义的一个数据类,用以保存所有的对话历史:
conv = conversation_lib.conv_templates[args.conv_type].copy()
"""
规定一开始的system设定 角色(用户和bot) 对话历史 分隔的token等
Conversation(system="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=('USER', 'ASSISTANT'), messages=(), offset=0, sep_style=<SeparatorStyle.TWO: 2>, sep=' ', sep2='</s>', version='v1', skip_next=False)
"""
conv.messages = [] # 初始化
读取文本prompt, 例如who is the oldest person? Please output segmentation mask.
读取文本后,要在文本prompt之前加入给图像预留的token <image>
, 并且在<image>
前后加入起止符:
prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) # <im_start><image><im_end>\nwho is the oldest person? Please output segmentation mask.
随后组合成完整的对话prompt:
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <im_start><image><im_end>\nwho is the oldest person? Please output segmentation mask. ASSISTANT:
读取图片,并用CLIP的image preprocess将图像缩放为224x224
, 并遵循SAM的预处理将图像长边缩放至1024
, 并填充至1024x1024
. 两种预处理分别对应两个vision encoder.
随后将文本prompt进行tokenize, 得到input_ids:
input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).cuda() # [bs, length]
"""
tensor([[ 1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116,
21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892,
322, 1248, 568, 6089, 304, 278, 5199, 29915, 29879, 5155,
29889, 3148, 1001, 29901, 32001, -200, 32002, 1058, 338, 278,
23947, 2022, 29973, 3529, 1962, 10768, 362, 11105, 29889, 319,
1799, 9047, 13566, 29901]], device='cuda:0')
"""
将input_ids输入LISA模型, 实际上走的是LLaVA, 并产生文本输出, 类似于训练时预规定的"Sure, It is [SEG]":
with torch.no_grad():
outputs = self.generate( # transformers库的方法 采用greedy生成 即每次选logits最大的token输出
images=images_clip,
input_ids=input_ids,
max_new_tokens=max_new_tokens,
num_beams=1,
output_hidden_states=True, # 输出hidden state 是因为我们要取[SEG]对应的embedding来decode分割的mask
return_dict_in_generate=True,
)
"""
输出的句子(outputs.sequence)为:
tensor([[ 1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116,
21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892,
322, 1248, 568, 6089, 304, 278, 5199, 29915, 29879, 5155,
29889, 3148, 1001, 29901, 32001, -200, 32002, 1058, 338, 278,
23947, 2022, 29973, 3529, 1962, 10768, 362, 11105, 29889, 319,
1799, 9047, 13566, 29901, 18585, 29892, 32000, 869, 2]],
device='cuda:0')
翻译过来就是:
<s>A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <im_start> <im_end> who is the oldest person? Please output segmentation mask. ASSISTANT: Sure, [SEG] .</s>
"""
output_hidden_states = outputs.hidden_states[-1] # 取最后一层的hidden state [bs, 313, 5120]
output_ids = outputs.sequences
之后是重要的步骤: 把[SEG]
在输出中对应的位置取出来(得到一个mask, 只有[SEG]
在的位置是1), 并在前面填充255个False
, 原因是:在LLaVA的推理过程中, vision encoder会将224x224的图像切分成14x14大小的patch, 共256个, 所以等价的输出也是256个, 因此在真实的LLaVA输出中, 长度比原本多了256 - 1(<image>符)
.
eg_token_mask = output_ids[:, 1:] == self.seg_token_idx
# hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
seg_token_mask = torch.cat(
[
torch.zeros((seg_token_mask.shape[0], 255)).bool().cuda(),
seg_token_mask,
],
dim=1,
)
取出[SEG]
对应的hidden state, 并经过一个MLP层用以对齐:
hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states))
last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
pred_embeddings = last_hidden_state[seg_token_mask]
然后是SAM阶段, 对SAM预处理的image
经过encoder, 然后将[SEG]
对应的hidden state作为prompt, 也进行encode:
image_embeddings = self.get_visual_embs(images) # [1, 256, 64, 64] 其中 64 = 1024 / 16 patch大小为16 x 16
multimask_output = False
pred_masks = []
for i in range(len(pred_embeddings)):
(
sparse_embeddings, # [1, 1, 256]
dense_embeddings, # [1, 256, 64, 64]
) = self.model.visual_model.prompt_encoder(
points=None,
boxes=None,
masks=None,
text_embeds=pred_embeddings[i].unsqueeze(1),
)
随后执行常规的SAM decoder过程, 得到分割掩码. 最后, 对output_id
进行解码得到文本输出, 以及对mask进行可视化与保存即可.
low_res_masks, iou_predictions = self.model.visual_model.mask_decoder(
image_embeddings=image_embeddings[i].unsqueeze(0),
image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(), # 图像的位置编码
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output, # False, LISA只默认分割一个物体
)
pred_mask = self.model.visual_model.postprocess_masks( # 将低分辨率mask [224, 224]后处理回原本的分辨率
low_res_masks,
input_size=resize_list[i],
original_size=original_size_list[i],
)
text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
1.3 运行结果
可以看出来,它并没有很正确地分割出来最“老”的人(应该是画面最右侧的), 而是倾向于分割所有的人, 说明LISA可能存在对特定文本忽视的现象, 这在一些生成的工作中有人关注过.
2. 训练流程
2.1 数据准备与读取
LISA的训练数据包括四个任务:
- 语义分割
- 指令分割
- VQA
- 推理分割
在训练的时候, 将四种任务的若干数据集混合, 封装成统一的HybridDataset
类, 在每次迭代的时候, 都随机从四个任务中挑选一个任务, 再从挑选的任务中随机选一个数据集, 再从这个数据集中随机选一个样本
, 代码如下:
# HybridDataset:
def __getitem__(self, idx):
ind = np.random.choice(list(range(len(self.datasets))), p=self.sample_rate)
data = self.all_datasets[ind]
inference = False
return *data[0], inference
# 传到单个任务的数据集类中索引是0 但是也是随机选一个数据集之后随机选一个样本 例如对应Refferring Segment:
def __getitem__(self, idx):
ds = random.randint(0, len(self.refer_seg_ds_list) - 1)
ds = self.refer_seg_ds_list[ds]
refer_seg_ds = self.refer_seg_data[ds]
images = refer_seg_ds["images"]
annotations = refer_seg_ds["annotations"]
img2refs = refer_seg_ds["img2refs"]
idx = random.randint(0, len(images) - 1)
接下来看一下每个任务的数据集是如何构建以及读取的.
2.2.1 HyBridDataset初始化
在训练的主函数(train_ds.py
)中, 对混合数据集进行如下初始化:
train_dataset = HybridDataset(
args.dataset_dir, # 根目录 存放所有任务的所有数据集
tokenizer, # 采用LLaVA的tokenizer
args.vision_tower, # CLIP的ViT-large
samples_per_epoch=args.batch_size # 一个epoch的样本数 = bs * 梯度积累步数 * 一个epoch的步数 * 显卡数
* args.grad_accumulation_steps # 默认10 梯度累积的主要目的是在显存有限的情况下, 模拟大bs的训练效果
* args.steps_per_epoch # 默认500
* world_size,
precision=args.precision, # 推理时候的精度 fp16/fp32
image_size=args.image_size, # 默认1024 SAM的输入分辨率
num_classes_per_sample=args.num_classes_per_sample, # 一个样本的标注中最多看几个类别 默认为3
exclude_val=args.exclude_val, # 是否排除验证集
dataset=args.dataset, # 默认四个任务都进行
sample_rate=[float(x) for x in args.sample_rates.split(",")], # 对每个任务的采样频率
# 默认语义分割: 指令分割: VQA: 因果分割 = 9: 3: 3: 1 可以看出是保住分割能力 并防止复杂prompt的过拟合
sem_seg_data=args.sem_seg_data, # 具体的语义分割的数据集名称
refer_seg_data=args.refer_seg_data, # 具体的指令分割的数据集名称
vqa_data=args.vqa_data, # 具体的VQA的数据集名称
reason_seg_data=args.reason_seg_data, # 具体的推理分割的数据集名称
explanatory=args.explanatory, # 这个参数是对ReasonSeg而言的, 问题是要求解释的问题("例如Please segment.. and explain why")的比例, 默认是0.1. 加入VQA数据集训练的目的也是保障模型回答问题的能力.
)
2.2.2 语义分割
在初始化函数中, 分别定义图像预处理方法, 问题模板, 回答模板以及每个数据集的样本的class名称, 图像以及label:
self.transform = ResizeLongestSide(image_size)
self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
self.short_question_list = SHORT_QUESTION_LIST # 例如 <image> + "\n" + "Can you segment the {class_name} in this image?"
self.answer_list = ANSWER_LIST # 例如 "It is [SEG]." "Sure, [SEG]."等
self.data2list = {} # key: 数据集 value: (images, labels), images 和 labels为长度为N的列表, 里面存储路径
self.data2classes = {} # key: 数据集 value: class名称的np.ndarray
# 存储每一个数据集
self.sem_seg_datas = sem_seg_data.split("||")
for ds in self.sem_seg_datas:
classes, images, labels = eval("init_{}".format(ds))(base_image_dir)
self.data2list[ds] = (images, labels)
self.data2classes[ds] = classes
在__getitem__
中, 首先随机选择数据集, 然后再从数据集中随机选择一个样本进行图像和label的读取和resize. 此外, 读取对应的类别(如果超过规定的数目, 就随机抽取args.num_classes_per_sample
个), 这部分不再赘述.
之后, 构建问题和答案:
questions = []
answers = []
class_ids = [] # 为样本中的每一个类别创建一组问答
for sampled_cls in sampled_classes:
text = sampled_cls
assert len(text.split("||")) == 1
question_template = random.choice(self.short_question_list) # 按照模板构建问题
questions.append(question_template.format(class_name=text.lower()))
answers.append(random.choice(self.answer_list)) # 随机选择答案模板
if ds in ["paco_lvis", "pascal_part"]: # 这两个数据集是single class 特殊处理
continue
class_id = self.data2classes[ds].tolist().index(sampled_cls)
class_ids.append(class_id)
# 转换为标准的prompt 即 system: A chat... Human: XXX Assistant: XXX
conversations = []
conv = conversation_lib.default_conversation.copy()
i = 0
while i < len(questions):
conv.messages = []
conv.append_message(conv.roles[0], questions[i])
conv.append_message(conv.roles[1], answers[i])
conversations.append(conv.get_prompt())
i += 1
随后读取label中的mask, 这部分省略, 返回值是如下的格式, 其余数据集也遵循:
return (
image_path, # 图像路径
image, # 用于SAM的resize图像 应该是1024x1024
image_clip, # 用于CLIP的resize图像 应该是224x224
conversations, # 真值conversation
masks, # 真值masks shape: [n, h, w] n是对应的物体类别数 [h, w]是原大小
label, # [h, w], 原始分割标签
resize, # [1024, 1024]
questions, # 问题
sampled_classes, # 抽取的类别名称 list
)
2.2.3 指令分割
基本流程和语义分割是相似的, 只不过class name需要从annotation的referring中读出来:
img2refs = refer_seg_ds["img2refs"]
refs = img2refs[image_id] # 得到图像对应的referrings
# 读取对应所有referring的文本 当然后面要根据args.num_classes_per_sample作筛选
sents = []
ann_ids = []
for ref in refs:
for sent in ref["sentences"]:
text = sent["sent"]
sents.append(text)
ann_ids.append(ref["ann_id"])
# 因此
sampled_classes = sampled_sents
# 后面读取mask也类似, 要根据抽取出的referring找到对应的mask
2.2.4 VQA
VQA比较特殊. 直接从数据集中读取数据即可构建conversation, 对于mask和label, 则将mask置为全0, label都置为ignore_label(255):
conversations = []
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]] # 直接从数据集读取
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
questions = conversations
sampled_classes = conversations
image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
masks = torch.rand(0, *ori_size) # 全0
label = torch.ones(ori_size) * self.ignore_label # 全255
2.2.5 ReasonSeg
ReasonSeg最大的不同就是要处理长问话以及解释性的问话,
首先读取当前随机抽取样本的mask, 问话以及是否为一个句子:
mask, sents, is_sentence = get_mask_from_json(json_path, image)
随后看是否为需要解释的样本, 如果是的话就构建对应的问话. 其中的choice
是在这种情况下, 进一步控制是解释性问题的比例, 其实是让解释性问题的占比进一步降低了.
if is_sentence:
question_template = random.choice(self.long_question_list)
questions.append(question_template.format(sent=text))
else:
question_template = random.choice(self.short_question_list)
questions.append(question_template.format(class_name=text.lower()))
# add explanation if applicable
img_name = image_path.split("/")[-1]
if self.explanatory != -1 and img_name in self.img_to_explanation:
if choice == 0: # [SEG] token # 最简单的回答
answers.append(random.choice(self.answer_list))
elif choice == 1: # [SEG] token + text answer # 否则加入解释性的提问
image_name = image_path.split("/")[-1]
answer = self.img_to_explanation[image_name]["outputs"]
answer = random.choice(self.answer_list) + " {}".format(answer)
questions[-1] = (
DEFAULT_IMAGE_TOKEN
+ "\n"
+ text
+ " {}".format(random.choice(self.explanatory_question_list))
)
answers.append(answer)
elif choice == 2: # vanilla text answer # 不加入
image_name = image_path.split("/")[-1]
answer = self.img_to_explanation[image_name]["outputs"]
questions[-1] = DEFAULT_IMAGE_TOKEN + "\n" + text
answers.append(answer)
else:
raise ValueError("Not implemented yet.")
else:
answers.append(random.choice(self.answer_list))
2.2 单步训练流程与损失计算
2.2.1 模型载入
模型载入和推理过程基本是相似的. 但是训练过程中需要用LoRA来微调LLaVA (要训练生成固定的答案模板), 具体做法如下:
lora_r = args.lora_r # 降到秩为多少 默认为8
if lora_r > 0:
# 查找模型中符合条件的线性层 并保存下来
def find_linear_layers(model, lora_target_modules):
cls = torch.nn.Linear
lora_module_names = set()
for name, module in model.named_modules():
if (
isinstance(module, cls)
and all(
[
x not in name
for x in [
"visual_model",
"vision_tower",
"mm_projector",
"text_hidden_fcs",
]
]
)
and any([x in name for x in lora_target_modules])
):
lora_module_names.add(name)
return sorted(list(lora_module_names))
lora_alpha = args.lora_alpha # lora超参
lora_dropout = args.lora_dropout # lora超参
lora_target_modules = find_linear_layers(
model, args.lora_target_modules.split(",")
)
# 配置config 利用peft库实现lora
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
# 根据lora要改变的线性层更新模型
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
2.2.2 前向传播
在每次迭代中, 首先用SAM的image encoder得到图像特征, 以及, 得到[SEG]
在tokenize的输入中的位置:
image_embeddings = self.get_visual_embs(images) # [bs, c, h, w]
batch_size = image_embeddings.shape[0]
assert batch_size == len(offset) - 1
seg_token_mask = input_ids[:, 1:] == self.seg_token_idx # [bs, N]
seg_token_mask = torch.cat(
[
seg_token_mask,
torch.zeros((seg_token_mask.shape[0], 1)).bool().cuda(), # [bs, 1]
],
dim=1,
) # [bs, N + 1]
# hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
# 补齐255个0的理由和推理节介绍的相同
seg_token_mask = torch.cat(
[torch.zeros((seg_token_mask.shape[0], 255)).bool().cuda(), seg_token_mask],
dim=1,
)
然后输入LLaVA:
images_clip_list = []
for i in range(len(offset) - 1):
start_i, end_i = offset[i], offset[i + 1] # 该样本具有多少annotation, 即问答对的起始和最终的idx
images_clip_i = (
images_clip[i]
.unsqueeze(0)
.expand(end_i - start_i, -1, -1, -1) # 就重复这么多遍
.contiguous()
)
images_clip_list.append(images_clip_i)
images_clip = torch.cat(images_clip_list, dim=0)
output = super().forward( # 得到LLaVA的结果
images=images_clip,
attention_mask=attention_masks,
input_ids=input_ids,
labels=labels,
output_hidden_states=True,
)
output_hidden_states = output.hidden_states
得到最后一层的各个batch中[SEG]
的embedding:
hidden_states = []
assert len(self.model.text_hidden_fcs) == 1
hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states[-1])) # 输入FC层对齐SAM和LLavA
last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
pred_embeddings = last_hidden_state[seg_token_mask]
seg_token_counts = seg_token_mask.int().sum(-1) # [bs, ]
# 得到每个batch中seg token的起始位置 并获得相应的embeddings
seg_token_offset = seg_token_counts.cumsum(-1)
seg_token_offset = torch.cat(
[torch.zeros(1).long().cuda(), seg_token_offset], dim=0
)
seg_token_offset = seg_token_offset[offset]
pred_embeddings_ = []
for i in range(len(seg_token_offset) - 1):
start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
pred_embeddings_.append(pred_embeddings[start_i:end_i])
pred_embeddings = pred_embeddings_
遍历每个embeddings, 用SAM的decoder得到mask, 这部分和推理过程相似, 不再赘述. 最后计算loss. 一个是训练LLaVA用的交叉熵损失, 用于文本输出和模板一致; 另外就是分割常用的bce和dice loss. 注意对于VQA任务的样本, mask loss理应是0.