LoRA 引领多模态模型革命,大模型的微调方案

news2024/9/21 22:52:47


基于LoRA微调多模态大模型

d0165a19269636e3b2c942a72b2fc319.jpeg  

随着 ChatGPT 的火爆,大模型时代降临,普通人难以进行全量微调。

参数高效微调技术应运而生,为科研人员和开发者提供了微调大模型的机会。

用 LoRA 微调 2.7B 参数的 blip2-opt 模型,提升图生文能力,带来高效的视觉语言对齐。

数据集和模型准备


该虚拟数据集提供 6 位足球运动员的图像,附有详细文字说明,可用于微调图像描述模型。
访问数据集:huggingface.co/datasets/ybelkada/football-dataset

72fc28761147fb2587ae09ffd1af7f57.jpeg

由 OPT-2.7B 训练的 BLIP-2 模型,包含三个强大组件:
* 视觉 Transformer:提取图像特征
* 语言模型:生成丰富描述
* 联合嵌入器:关联视觉和语言
此模型已在 Hugging Face 上提供,可通过以下链接下载:/huggingface.co/Salesforce/blip2-opt-2.7b

BLIP-2 简介

BLIP-2 是一种多模态 AI 模型,凭借预训练优势,在视觉和语言任务上表现卓越。
该模型的架构包括一个图像编码器(提取视觉特征),一个大型语言模型(生成语言)和一个可学习的 Q-Former(融合视觉和语言表征)。
BLIP-2 融合了视觉和语言理解,提供强大的多模态能力。它的预训练机制降低了训练成本,提升了模型效果,使其在各种任务中大显身手。

8cb0e085e8d52d39fb9786be0c511d8b.jpeg

  • Image Encoder:负责从输入图片中提取视觉特征。
  • Large Language Model:负责文本生成。
  • Q-Former巧妙地融合了视觉和语言模态,通过共享自注意力层,Image Transformer 和 Text Transformer两个子模块实现了跨模态交互。这种创新架构有效缩小了两种模态之间的鸿沟,实现了跨模态理解的显著提升。
    • 借助图像编码器,Image Transformer 提取视觉特征。可学习 Query 通过自注意力层交互,并通过交叉注意力层与冻结图像特征融合。它还可以通过共享的自注意力层与文本交互,实现图像和文本的深度融合,从而提升图像理解能力。
    • Text Transformer兼具文本编码和解码功能,其自注意力层与Image Transformer共享。根据预训练任务的不同,采用不同的自注意力掩码,精准控制Query与文本的交互方式。

1dc1b669d6b4056e05f0ea30aa79210c.jpeg


BLIP-2 使用了一种两阶段预训练方法,称为 Q-Former,以克服冻结预训练模型导致视觉和语言特征不一致的问题。在表示学习阶段,模型学习对齐视觉和文本特征的表示。接续的生成学习阶段利用这些表示来生成文本描述,从而加强跨模态对齐。

表示学习阶段

Q-Former模型创新性地将冻结的图像编码器连接到学习阶段,利用图像-文本对进行训练。通过优化预训练目标,该模型采用差异化的注意力掩码策略,控制图像和文本 Transformer 之间的交互。这种方法使 Q-Former 能够更有效地捕捉图像和文本之间的联系,从而在学习阶段取得卓越的性能。

生成学习阶段

Q-Former 通过连接到 LLM,利用其语言生成能力。预训练阶段中,全连接层将视觉表示投影到与 LLM 文本嵌入相同的维度,并将其添加到输入文本嵌入中。Q-Former 的预训练使其能够过滤视觉信息,提取与语言相关的关键特征,充当信息瓶颈。这简化了 LLM 学习视觉语言对齐,提高了模型效率。

39dfab43e72057e6b91f7e439dfc8e06.jpeg

先预先准备Processor、模型和图像输入。

from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
) # doctest: +IGNORE_RESULT

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

对于图像描述生成任务示例如下:

inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)

generated_ids = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

# two cats laying on a couch

对于视觉问答任务(VQA)示例如下:

prompt = "Question: how many cats are there? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16)

generated_ids = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
# two

LoRA 简介

LoRA 技术巧妙利用低秩分解模拟模型参数变更,以极少参数量间接训练大模型。详细技术原理和实战教程请参考过往文章。

  • 大模型参数高效微调技术原理综述(五)-LoRA、AdaLoRA、QLoRA
  • 大模型参数高效微调技术实战(五)-LoRA

模型微调

优化代码,改进精度!
我们的微调代码现可从 GitHub 的 llm-action 项目获取,具体位于 blip2_lora_int8_fine_tune.py 文件中。关键步骤包括:
- 将模型精度调整为 8 位整数。
- 微调模型以提升性能。
访问 GitHub 即可获取详细代码并提升您的模型表现。

第一步,加载预训练Blip-2模型以及processor。

from transformers import AutoModelForVision2Seq, AutoProcessor

# We load our model and processor using `transformers`
model = AutoModelForVision2Seq.from_pretrained(pretrain_model_path, load_in_8bit=True)
processor = AutoProcessor.from_pretrained(pretrain_model_path)

通过 LoRA 微调策略创建自定义配置,并利用 get_peft_model 方法扩展基础 Transformer 模型,提升微调性能,量化模型参数,优化模型大小。

from peft import LoraConfig, get_peft_model

# Let's define the LoraConfig
config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
)

# Get our peft model and print the number of trainable parameters
model = get_peft_model(model, config)
model.print_trainable_parameters()

第三步,进行模型微调。

# 设置优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.train()
for epoch in range(11):
print("Epoch:", epoch)
for idx, batch in enumerate(train_dataloader):
input_ids = batch.pop("input_ids").to(device)
pixel_values = batch.pop("pixel_values").to(device, torch.float16)

outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)
loss = outputs.loss
print("Loss:", loss.item())
loss.backward()
optimizer.step()
optimizer.zero_grad()

if idx % 10 == 0:
# 根据图像生成文本
generated_output = model.generate(pixel_values=pixel_values)
# 解码
print(processor.batch_decode(generated_output, skip_special_tokens=True))

最后,保存训练的Adapter模型权重及配置文件。

model.save_pretrained(peft_model_id)

模型推理

使用 LLM-Action,轻松实现图生文!
只需运行 CUDA_VISIBLE_DEVICES=0 python blip2_lora_inference.py,即可使用先进的 LLM-Action 模型进行图生文,为您提供准确而有创意的文本描述。
代码详情请访问 GitHub 上的 llm-action 项目,了解 blip2_lora_inference.py 文件中的详细说明。

结语

基于 LoRA 微调,优化 BLIP-2 多模态大模型的文本生成能力,提升文本流畅性和信息丰富度,助力文本创作提质增效。

 

-对此,您有什么看法见解?-

-欢迎在评论区留言探讨和分享。-

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1631090.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式学习——C语言基础——day11

1. 字符型数组和字符串的传参 1.1 常量和变量的区别&#xff08;难点&#xff09; 一般常量不能被修改&#xff0c;变量才能被修改 #include <stdio.h> int main(void) { char str[] {"hello world"};//定义数组&#xff0c;数组名为指针常量 char …

数据结构和算法:贪心

贪心算法 贪心算法是一种常见的解决优化问题的算法&#xff0c;其基本思想是在问题的每个决策阶段&#xff0c;都选择当前看起来最优的选择&#xff0c;即贪心地做出局部最优的决策&#xff0c;以期获得全局最优解。 贪心算法和动态规划都常用于解决优化问题。它们之间存在一…

干货整理:好用的文件加密软件有哪些

说到文件加密&#xff0c;想必大家都很熟悉&#xff0c;文件加密已经普遍应用&#xff0c;文件加密是一种重要的安全措施&#xff0c;可以确保数据的机密性、完整性和可用性&#xff0c;降低因数据泄露或丢失带来的风险。 下面小编给大家分享几款常用的加密软件&#xff0c;大…

纯血鸿蒙APP实战开发——评论组件案例实现

介绍 评论组件在目前市面上的短视频app中是一种很常见的场景&#xff0c;本案例使用全局状态保留能力弹窗来实现评论组件。点击评论按钮弹出评论组件&#xff0c;点击空白处隐藏该组件&#xff0c;再次点击评论按钮则会恢复上一次浏览的组件状态。 效果图预览 使用说明 点击…

DDP示例

https://zhuanlan.zhihu.com/p/602305591 https://zhuanlan.zhihu.com/p/178402798 关于模型保存与加载 &#xff1a; 其实分为保存 有module和无module2种 &#xff1b; &#xff08;上面知乎这篇文章说带时带module) 关于2种带与不带的说明&#xff1a; https://blog.csdn.…

Oracle中rman使用记录

最近在项目中&#xff0c;遇到使用RMAN的操作来恢复数据库中某个时间归档日志&#xff0c;RMAN的原理和理解&#xff0c;网友们百度了解一下。我重点将实操部分了。直接上实验环节&#xff0c;让网友更懂。&#xff08;特别提醒&#xff1a;我是1:1用VMware克隆数据库进行RMAN还…

构建高效智能的理赔业务系统:保险科技的未来

随着保险行业的发展和科技的不断进步&#xff0c;理赔业务作为保险服务的重要环节&#xff0c;也在不断演进和改进。传统的理赔流程可能存在效率低下、信息不透明等问题&#xff0c;而现代化的理赔业务系统则能够通过数字化、智能化等手段提升理赔服务的质量和效率&#xff0c;…

Java集成结巴中文分词器、Springboot项目整合jieba分词,实现语句最精确的切分、自定义拆词

文章目录 一、jieba介绍二、集成三、原理四、自定义拆词4.1、方式一&#xff1a;在源码的dict.txt中修改然后重新打包(推荐)4.2、新建文件自定义拆词 五、其他问题 一、jieba介绍 jieba是一个分词器&#xff0c;可以实现智能拆词&#xff0c;最早是提供了python包&#xff0c;…

Qt | 窗口的显示及可见性|标题、透明度、启用/禁用|窗口标志、设置其他属性|获取窗口部件、设置父部件|鼠标光标

​显示事件:QEvent::show,处理函数为 showEvent(QShowEvent*) 隐藏事件:QEvent::hide,处理函数为 hideEvent(QHideEvent* ) 01 QWidget 类中与可见性有关的属性 visible:bool 访问函数: bool isVisible() const; virtual void setVisible(bool visible); 02 QWid…

高频面试题:在浏览器搜索框中输入一个URL的完整请求过程?

相信很多小伙伴在校招或者社招面试中都遇到过这个问题 面试官&#xff1a;小伙子&#xff0c;了解 在浏览器搜索框中输入一个URL的完整请求过程吗&#xff1f;详细说说我&#xff1a;eeemm&#xff0c;不太清出具体的过程。整体过程应该是HTTP请求的过程。 如果在面试中不能很…

【C++】---STL容器适配器之底层deque浅析

【C】---STL容器适配器之底层deque浅析 一、deque的使用二、deque的原理1、deque的结构2、deque的底层结构&#xff08;1&#xff09;deque的底层空间&#xff08;2&#xff09;deque如何支持随机访问、deque迭代器 3、deque的优缺点&#xff08;1&#xff09;deque的优势&…

【golang学习之旅】报错:a declared but not used

目录 报错原因解决方法参考 报错 代码很简单&#xff0c;如下所示。可以发现a和b都飙红了&#xff1a; 运行后就会出现报错&#xff1a; 报错翻译过来就是a已经声明但未使用。当时我很疑惑&#xff0c;在其他语言中从来没有这种情况。况且这里的b不是赋值了吗&#xff0c;怎…

Sarcasm detection论文解析 | 通过阅读进行讽刺推理-Reasoning with sarcasm by reading in-between

论文地址 论文地址&#xff1a;[1805.02856] Reasoning with Sarcasm by Reading In-between (arxiv.org) 论文首页 笔记大纲 通过阅读进行讽刺推理论文笔记 &#x1f4c5;出版年份:2018&#x1f4d6;出版期刊:&#x1f4c8;影响因子:&#x1f9d1;文章作者:Tay Yi,Luu Anh…

制作一个RISC-V的操作系统十六-系统调用

文章目录 用户态和内核态mstatus设置模式切换核心流程封装代码背景解释代码示例解析解释目的 用户态和内核态 mstatus设置 此时UIE设置为1和MPIE为1&#xff0c;MPP设置为0 代表当前权限允许UIE中断发生&#xff0c;并且在第一个mret后将权限恢复为用户态&#xff0c;同时MIE也…

17 大数据定制篇-shell编程

第 17 章大数据定制篇-Shell 编程 17.1 为什么要学习 Shell 编程 Linux 运维工程师在进行服务器集群管理时&#xff0c;需要编写 Shell 程序来进行服务器管理。 对于 JavaEE 和 Python 程序员来说&#xff0c;工作的需要&#xff0c;你的老大会要求你编写一些 Shell 脚本进行…

ERP系统和SRM系统有什么关系?

一、什么是ERP系统和SRM系统&#xff1f; ERP系统是一种集成化的管理软件&#xff0c;能够帮助企业实现资源的优化配置&#xff0c;提高运营效率。ERP系统涵盖了企业的各个方面&#xff0c;包括财务、采购、库存、生产、销售、人力资源等&#xff0c;通过对这些方面的管理&…

MMSeg搭建自己的网络

配置结构 首先&#xff0c;我们知道MMSeg矿机的配置文件很多&#xff0c;主要结构如下图所示。 在configs/_base_下是模型配置、数据集配置、以及一些其他的常规配置和运行配置&#xff0c;四类。 configs/all_config目录下存放&#xff0c;即是将四种配置聚合在一起的一个总…

Android优化RecyclerView图片展示:Glide成堆加载批量Bitmap在RecyclerView成片绘制Canvas,Kotlin(b)

Android优化RecyclerView图片展示&#xff1a;Glide成堆加载批量Bitmap在RecyclerView成片绘制Canvas&#xff0c;Kotlin&#xff08;b&#xff09; 对 Android GridLayoutManager Glide批量加载Bitmap绘制Canvas画在RecyclerView&#xff0c;Kotlin&#xff08;a&#xff09;-…

【调研分析】目标在不同焦距和距离下与画面的比例(2.8-3.6-4.0)

之前在做项目中需要极度优化效果和代码运行速度 为此测试了同一个目标在不同焦距和距离下与画面的比例&#xff0c;从而可以方便在指定大小情况下搜索目标 NOTE: 这是早期滑窗检测做目标检测下的工作

分布式与一致性协议之Raft算法(一)

Raft算法 概述 Raft算法属于Multi-Paxos算法&#xff0c;它在兰伯特Multi-Paxos思想的基础上做了一些简化和限制&#xff0c;比如日志必须是连续的&#xff0c;只支持领导者(Leader)、跟随者(Follwer)和候选人(Candidate)3种状态。在理解和算法实现上&#xff0c;Raft算法相对…