用ResNet50+Qwen2-VL-2B-Instruct+LoRA模仿Diffusion-VLA的论文思路,在3090显卡上训练和测试成功

news2025/1/8 15:02:35

想一步步的实现Diffusion VLA论文的思路,不过论文的图像的输入用DINOv2进行特征提取的,我先把这个部分换成ResNet50。

老铁们,直接上代码:

from PIL import Image
import torch
import torchvision.models as models
from torch import nn
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.transformers import SwanLabCallback
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
import swanlab
import json
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.models as models

class CustomResNet(nn.Module):
    def __init__(self, output_size=(256, 1176)):
        super(CustomResNet, self).__init__()
        
        # 预训练的 ResNet 模型
        resnet = models.resnet50(pretrained=True)
        
        # 去掉 ResNet 的最后全连接层和池化层
        self.features = nn.Sequential(*list(resnet.children())[:-2])  # 去掉最后的FC层和AvgPool层
        
        # 自定义的卷积层,调整步幅和padding来控制尺寸
        self.conv1 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1)  # 保持大小
        self.conv2 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1)  # 保持大小
        self.conv3 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1)  # 保持大小
        
        # 上采样层,用于增加特征图的尺寸
        self.upconv1 = nn.ConvTranspose2d(2048, 2048, kernel_size=4, stride=4, padding=0)  # 上采样
        self.upconv2 = nn.ConvTranspose2d(2048, 2048, kernel_size=4, stride=4, padding=0)  # 上采样
        
        # 最终卷积层将特征图变为单通道输出(灰度图)
        self.final_conv = nn.Conv2d(2048, 1, kernel_size=1)  # 输出单通道

    def forward(self, x):
        # 获取ResNet的特征图
        x = self.features(x)
        
        # 经过卷积层
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        
        # 上采样阶段:增加特征图的尺寸
        x = self.upconv1(x)  # 上采样1
        x = self.upconv2(x)  # 上采样2
        
        # 使用插值进行微调输出尺寸
        x = F.interpolate(x, size=(256, 1176), mode='bilinear', align_corners=False)
        
        # 通过最后的卷积层输出(单通道)
        x = self.final_conv(x)  # 通过最后的卷积层输出
        
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

# 创建模型并移动到设备上
model_ResNet = CustomResNet(output_size=(256, 1176)).to(device)

# 定义图像预处理过程
image_transform = transforms.Compose([
    transforms.Resize((800, 800)),  # 确保图像大小一致(通常为224x224)
    transforms.ToTensor(),  # 转换为Tensor并标准化
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])

def extract_resnet_features(image_path):
    """
    使用ResNet提取图像特征
    """
    image = Image.open(image_path).convert("RGB")  # 加载图像并转换为RGB
    image_tensor = image_transform(image).unsqueeze(0).to('cuda')  # 添加batch维度并转换为cuda Tensor
    # features = resnet_extractor(image_tensor)  # 从ResNet提取特征    
    features = model_ResNet(image_tensor)

    return features

def process_func(example):
    """
    将数据集进行预处理,加入ResNet特征提取
    """
    MAX_LENGTH = 8192
    input_ids, attention_mask, labels = [], [], []
    conversation = example["conversations"]
    input_content = conversation[0]["value"]
    output_content = conversation[1]["value"]
    file_path = input_content.split("<|vision_start|>")[1].split("<|vision_end|>")[0]  # 获取图像路径
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"{file_path}",
                    "resized_height": 224,  # 确保图像尺寸为224x224
                    "resized_width": 224,
                },
                {"type": "text", "text": "COCO Yes:"},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )  # 获取文本
    image_inputs, video_inputs = process_vision_info(messages)  # 获取数据数据(预处理过)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # print("inputs['pixel_values'] shape: ", inputs['pixel_values'].shape)

    # 提取图像特征
    image_tensor = extract_resnet_features(file_path)  # 从图像路径提取特征
    # print("image_tensor shape: ", image_tensor.shape)
    inputs['pixel_values'] = image_tensor[0,0,:,:]  # 替换图像特征为ResNet特征

    inputs = {key: value.tolist() for key, value in inputs.items()}  # tensor -> list,为了方便拼接
    instruction = inputs

    response = tokenizer(f"{output_content}", add_special_tokens=False)


    input_ids = (
            instruction["input_ids"][0] + response["input_ids"] + [tokenizer.pad_token_id]
    )

    attention_mask = instruction["attention_mask"][0] + response["attention_mask"] + [1]
    labels = (
            [-100] * len(instruction["input_ids"][0])
            + response["input_ids"]
            + [tokenizer.pad_token_id]
    )
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]

    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    labels = torch.tensor(labels)
    inputs['pixel_values'] = torch.tensor(inputs['pixel_values'])
    inputs['image_grid_thw'] = torch.tensor(inputs['image_grid_thw']).squeeze(0)  # 由(1,h,w)变换为(h,w)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels,
            "pixel_values": inputs['pixel_values'], "image_grid_thw": inputs['image_grid_thw']}


def predict(messages, model):
    # 准备推理
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # 生成输出
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    
    return output_text[0]


# 在modelscope上下载Qwen2-VL模型到本地目录下
model_dir = snapshot_download("Qwen/Qwen2-VL-2B-Instruct", cache_dir="./", revision="master")

# 使用Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", use_fast=False, trust_remote_code=True)
processor = AutoProcessor.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct")

# 加载模型
model = Qwen2VLForConditionalGeneration.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True,)
model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法
model.config.use_cache = False

# 处理数据集:读取json文件
# 拆分成训练集和测试集,保存为data_vl_train.json和data_vl_test.json
train_json_path = "data_vl.json"
with open(train_json_path, 'r') as f:
    data = json.load(f)
    train_data = data[:-4]
    test_data = data[-4:]

with open("data_vl_train.json", "w") as f:
    json.dump(train_data, f)

with open("data_vl_test.json", "w") as f:
    json.dump(test_data, f)

train_ds = Dataset.from_json("data_vl_train.json")
train_dataset = train_ds.map(process_func)

# 配置LoRA
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,  # 训练模式
    r=4, #64,  # Lora 秩
    lora_alpha= 1, #16,  # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.05,  # Dropout 比例
    bias="none",
)

# 获取LoRA模型
peft_model = get_peft_model(model, config)

# 配置训练参数
args = TrainingArguments(
    output_dir="./output/Qwen2-VL-2B",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    logging_steps=10,
    logging_first_step=5,
    num_train_epochs=2,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to="none",
)

# 配置Trainer
trainer = Trainer(
    model=peft_model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

# 开启模型训练
trainer.train()


# ====================测试模式===================
# 配置测试参数
val_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=True,  # 训练模式
    r=4,#64,  # Lora 秩
    lora_alpha=1,#16,  # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.05,  # Dropout 比例
    bias="none",
)

# 获取测试模型
val_peft_model = PeftModel.from_pretrained(model, model_id="./output/Qwen2-VL-2B/checkpoint-992", config=val_config)

# 读取测试数据
with open("data_vl_test.json", "r") as f:
    test_dataset = json.load(f)

test_image_list = []
for item in test_dataset:
    input_image_prompt = item["conversations"][0]["value"]
    # 去掉前后的<|vision_start|>和<|vision_end|>
    origin_image_path = input_image_prompt.split("<|vision_start|>")[1].split("<|vision_end|>")[0]
    
    messages = [{
        "role": "user", 
        "content": [
            {
            "type": "image", 
            "image": origin_image_path
            },
            {
            "type": "text",
            "text": "COCO Yes:"
            }
        ]}]
    
    response = predict(messages, val_peft_model)
    messages.append({"role": "assistant", "content": f"{response}"})
    print(messages[-1])

    test_image_list.append(swanlab.Image(origin_image_path, caption=response))

我在3090显卡(24G显存)运行的结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2272616.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

常见中间件漏洞(tomcat,weblogic,jboss,apache)

先准备好今天要使用的木马文件 使用哥斯拉生成木马 压缩成zip文件 改名为war后缀 一&#xff1a;Tomcat 1.1CVE-2017-12615 环境搭建 cd vulhub-master/tomcat/CVE-2017-12615 docker-compose up -d 1.首页抓包&#xff0c;修改为 PUT 方式提交 发送shell.jsp 和木马内容 …

嵌入式科普(26)为什么heap通常8字节对齐

目录 一、概述 二、newlibc heap 2.1 stm32cubeide .ld heap 2.2 e2studio .ld heap 三、glibc源码 3.1 Ubuntu c heap 四、总结 一、概述 结论&#xff1a;在嵌入式c语言中&#xff0c;heap通常8字节对齐 本文主要分析这个问题的分析过程 二、newlibc heap newlibc…

大数据-240 离线数仓 - 广告业务 测试 ADS层数据加载 DataX数据导出到 MySQL

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; Java篇开始了&#xff01; 目前开始更新 MyBatis&#xff0c;一起深入浅出&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff0…

CTF杂项——[NSSRound#4 SWPU]Pixel_Signin

得到一个像素点 提取像素点 脚本 from PIL import Image result open(1.txt,w) img Image.open("Pixel_Signin.png") img img.convert(RGB) for i in range(31):for j in range(31):r,g,b img.getpixel((j,i))print(r,g,b,end ,fileresult) 运行之后得到 把它…

Harmony开发【笔记1】报错解决(字段名写错了。。)

在利用axios从网络接收请求时&#xff0c;发现返回obj的code为“-1”&#xff0c;非常不解&#xff0c;利用console.log测试&#xff0c;更加不解&#xff0c;可知抛出错误是 “ E 其他错误: userName required”。但是我在测试时&#xff0c;它并没有体现为空&#xff0c;…

springCloudGateWay使用总结

1、什么是网关 功能: ①身份认证、权限验证 ②服务器路由、负载均衡 ③请求限流 2、gateway搭建 2.1、创建一个空项目 2.2、引入依赖 2.3、加配置 3、断言工厂 4、过滤工厂 5、全局过滤器 6、跨域问题

韩国机场WebGIS可视化集合Google遥感影像分析

目录 前言 一、相关基础数据介绍 1、韩国的机场信息 2、空间数据准备 二、Leaflet叠加Google地图 1、叠加google地图 2、空间点的标记及展示 3、韩国机场空间分布 三、相关成果展示 1、务安国际机场 2、有同类问题的机场 四、总结 前言 12月29日8时57分左右务安国际机…

电子电气架构 --- 设计车载充电机的关键考虑因素

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 简单,单纯,喜欢独处,独来独往,不易合同频过着接地气的生活,除了生存温饱问题之外,没有什么过多的欲望,表面看起来很高冷,内心热情,如果你身…

python进阶06:MySQL

课后大总结 Day1 一、数据库命令总结 1.连接数据库 连接数据库进入mysql安装目录打开bin文件夹&#xff0c;输入cmd(此命令后无分号)mysql.exe -u root -ppassword命令后输入密码:root 设置密码set passwordpassword("root123"); 查看所有数据库show databases; …

php反序列化原生态 ctfshow练习 字符串逃逸

web262 拿着题审计一下 <?php error_reporting(0); class message{public $from;public $msg;public $to;public $tokenuser;public function __construct($f,$m,$t){$this->from $f;$this->msg $m;$this->to $t;} }$f $_GET[f]; $m $_GET[m]; $t $_GET[t…

探秘前沿科技:RFID 与 NFC,开启智能识别新篇

RFID&#xff08;射频识别&#xff09;与NFC&#xff08;近场通信&#xff09;作为两种基于射频技术的无线通信方式&#xff0c;在现代社会中发挥着越来越重要的作用。尽管它们都具备非接触式识别和通信的能力&#xff0c;但在工作原理、应用场景、技术细节等方面存在着显著的差…

【04】优雅草央千澈详解关于APP签名以及分发-上架完整流程-第四篇安卓APP上架之vivo商店-小米商店,oppo商店,应用宝

【04】优雅草央千澈详解关于APP签名以及分发-上架完整流程-第四篇安卓APP上架之vivo商店-小米商店&#xff0c;oppo商店&#xff0c;应用宝 背景介绍 接第三篇上架华为&#xff0c;由于华为商店较为细致&#xff0c;本篇幅介绍其他4类商店相对简要一点&#xff0c;剩下其他更…

OpenCV计算机视觉 06 图像轮廓检测(轮廓的查找、绘制、特征、近似及轮廓的最小外接圆外接矩形)

目录 图像轮廓检测 轮廓的查找 轮廓的绘制 轮廓的特征 面积 周长 根据面积显示特定轮廓 轮廓的近似 给定轮廓的最小外接圆、外接矩形 外接圆 外接矩形 图像轮廓检测 轮廓的查找 API函数 image, contours, hierarchy cv2.findContours(img, mode, method) 代入参…

ROS2 跨机话题通信问题(同一个校园网账号)

文章目录 写在前面的话校园网模式&#xff08;失败&#xff09;手机热点模式&#xff08;成功&#xff09; 我的实验细节实验验证1、ssh 用户名IP地址 终端控制2、互相 ping 通 IP3、ros2 run turtlesim turtlesim_node/turtle_teleop_key4、ros2 multicast send/receive5、从机…

web3与AI结合-Sahara AI 项目介绍

背景介绍 Sahara AI 于 2023 年创立&#xff0c;是一个 "区块链AI" 领域的项目。其项目愿景是&#xff0c;利用区块链和隐私技术将现有的 AI 商业模式去中心化&#xff0c;打造公平、透明、低门槛的 “协作 AI 经济” 体系&#xff0c;旨在重构新的利益分配机制以及…

【C++】你了解异常的用法吗?

文章目录 Ⅰ. C语言传统的处理错误的方式Ⅱ. C异常概念Ⅲ. 异常的使用1、异常的抛出和匹配原则2、在函数调用链中异常栈展开匹配原则3、异常的重新抛出4、异常安全5、异常规范 Ⅳ. 自定义异常体系Ⅴ. C标准库的异常体系Ⅵ. 异常的优缺点1、异常的优点2、异常的缺点3、总结 Ⅰ. …

Matlab仿真径向受压圆盘光弹图像

Matlab仿真径向受压圆盘光弹图像-十步相移法 主要参数 % 定义圆盘参数 R 15; % 圆盘半径&#xff0c;单位&#xff1a;mm h 5; % 圆盘厚度&#xff0c;单位&#xff1a;mm P 300; % 径向受压载荷大小&#xff0c;单位&#xff…

游戏引擎学习第75天

仓库:https://gitee.com/mrxiao_com/2d_game_2 Blackboard: 处理楼梯通行 为了实现楼梯的平滑过渡和角色的移动控制&#xff0c;需要对楼梯区域的碰撞与玩家的运动方式进行优化。具体的处理方式和遇到的问题如下&#xff1a; 楼梯区域的过渡&#xff1a; 在三维空间中&#x…

算法的学习笔记—不用常规控制语句求 1 到 n 的和

&#x1f600;前言 在算法编程中&#xff0c;有时我们会遇到一些特殊的限制条件&#xff0c;这些限制会迫使我们跳出常规思维。本文讨论的问题就是一个典型案例&#xff1a;在不能使用基本控制语句的情况下&#xff0c;如何求解 1 到 n 的和。这个问题不仅考验编程技巧&#xf…

网络协议安全的攻击手法

1.使用SYN Flood泛洪攻击&#xff1a; SYN Flood(半开放攻击)是最经典的ddos攻击之一&#xff0c;他利用了TCP协议的三次握手机制&#xff0c;攻击者通常利用工具或控制僵尸主机向服务器发送海量的变源端口的TCP SYN报文&#xff0c;服务器响应了这些报文后就会生成大量的半连…