基于huggingface加载openai/clip-vit-large-patch14-336视觉模型demo

news2024/12/28 5:14:40

文章目录

  • 引言
  • 一、模型加载
  • 二、huggingface梯度更新使用
  • 三、图像处理
  • 四、模型推理
  • 五、整体代码
  • 总结


引言

本文介绍如何使用huggingface加载视觉模型openai/clip-vit-large-patch14-336,我之所以记录此方法源于现有大模型基本采用huggingface库来加载视觉模型和大语言模型,我也是在做LLava模型等模型。基于此,本节将介绍如何huggingface如何加载vit视觉模型。

一、模型加载

使用huggingface模型加载是非常简单,其代码如下:

from transformers import CLIPVisionModel, CLIPImageProcessor


if __name__ == '__main__':
    vit_path='D:/clip-vit-large-patch14-336'
    img_path='dogs.jpg'

    image_processor = CLIPImageProcessor.from_pretrained(vit_path)  # 加载图像预处理
    vision_tower = CLIPVisionModel.from_pretrained(vit_path)  # 加载图像模型
    vision_tower.requires_grad_(False)  # 模型冻结

    for name, param in vision_tower.named_parameters():
        print(name, param.requires_grad)

二、huggingface梯度更新使用

一般视觉模型需要冻结,使用lora训练,那么我们需要如何关闭视觉模型梯度。为此,我继续探讨梯度设置方法,其代码如下:

    vision_tower.requires_grad_(False)  # 模型冻结

    for name, param in vision_tower.named_parameters():
        print(name, param.requires_grad)

以上代码第一句话是视觉模型梯度冻结方法,下面2句是验证梯度是否冻结。如果设置’‘vision_tower.requires_grad_(False)’'表示冻结梯度,如果不设置表示需要梯度传播。我将不在介绍了,若你想详细了解,只要执行以上
代码便可知晓。

三、图像处理

在输入模型前,我们需要对图像进行预处理,然huggingface也很人性的自带了对应视觉模型的图像处理,我们只需使用PIL实现图像处理,其代码如下:

    image = Image.open(img_path).convert('RGB')  # PIL读取图像
    def expand2square(pil_img, background_color):
        width, height = pil_img.size  # 获得图像宽高
        if width == height:  # 相等直接返回不用重搞
            return pil_img
        elif width > height:  # w大构建w尺寸图
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))  # w最大,以坐标x=0,y=(width - height) // 2位置粘贴原图
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result


    image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
    image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

四、模型推理

最后就是模型推理,其代码如下:

    image_forward_out = vision_tower(image.unsqueeze(0),    output_hidden_states=True)

    feature = image_forward_out['last_hidden_state']

    print(feature.shape)

但是,我想说输出包含很多内容,其中hidden_states有25个列表,表示每个block输出结果,而hidden_states的最后一层与last_hidden_states值相同,结果分别如下图:

在这里插入图片描述

hidden_states与last_hidden_states对比如下:
在这里插入图片描述

五、整体代码

from transformers import CLIPVisionModel, CLIPImageProcessor

from PIL import Image


if __name__ == '__main__':
    vit_path='E:/clip-vit-large-patch14-336'
    img_path='dogs.jpg'

    image_processor = CLIPImageProcessor.from_pretrained(vit_path)  # 加载图像预处理
    vision_tower = CLIPVisionModel.from_pretrained(vit_path)  # 加载图像模型
    vision_tower.requires_grad_(False)  # 模型冻结

    for name, param in vision_tower.named_parameters():
        print(name, param.requires_grad)

    image = Image.open(img_path).convert('RGB')  # PIL读取图像

    def expand2square(pil_img, background_color):
        width, height = pil_img.size  # 获得图像宽高
        if width == height:  # 相等直接返回不用重搞
            return pil_img
        elif width > height:  # w大构建w尺寸图
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))  # w最大,以坐标x=0,y=(width - height) // 2位置粘贴原图
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result


    image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
    image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

    image_forward_out = vision_tower(image.unsqueeze(0),    output_hidden_states=True)

    feature = image_forward_out['last_hidden_state']

    print(feature.shape)

总结

本文是一个huggingface加载视觉模型的方法,另一个重点是梯度冻结。然而,我只代表VIT模型是如此使用,其它模型还未验证,不做任何说明。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1472119.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【生成式AI】ChatGPT 原理解析(2/3)- 预训练 Pre-train

Hung-yi Lee 课件整理 预训练得到的模型我们叫自监督学习模型(Self-supervised Learning),也叫基石模型(foundation modle)。 文章目录 机器是怎么学习的ChatGPT里面的监督学习GPT-2GPT-3和GPT-3.5GPTChatGPT支持多语言…

nginx(二)

nginx的验证模块 输入用户名和密码 第一步先下载httpd 这个安装包 第二步编辑子配置文件 然后去网页访问192.168.68.3/admin/ 连接之后,会出现404,404出现是因为没给网页写页面 如果要写页面,则在/opt/html,建立一个admin&#x…

算法沉淀——动态规划之路径问题(leetcode真题剖析)

算法沉淀——动态规划之路径问题 01.不同路径02.不同路径 II03.珠宝的最高价值04.下降路径最小和05.最小路径和06.地下城游戏 01.不同路径 题目链接:https://leetcode.cn/problems/unique-paths/ 一个机器人位于一个 m x n 网格的左上角 (起始点在下图…

探索C语言位段的秘密

位段 1. 什么是位段2. 位段的内存分配3. 位段的跨平台问题4. 位段的应用4. 使用位段的注意事项 1. 什么是位段 我们使用结构体实现位段,位段的声明和结构体是类似的,有两个不同: 位段的成员必须是int,unsigned int,或…

IO进程线程作业day7

信号灯集共享内存 自定义头文件 #ifndef SEM_H_ #define SEM_H_ //创建信号灯集, int creat_t(int number); //申请释放资源 int P(int semid,int semno); //申请释放资源 int V(int semid,int semno); //删除信号灯集 int del(int semid); #endif信号灯集函数集合 #include…

多维时序 | Matlab实现GRU-MATT门控循环单元融合多头注意力多变量时间序列预测模型

多维时序 | Matlab实现GRU-MATT门控循环单元融合多头注意力多变量时间序列预测模型 目录 多维时序 | Matlab实现GRU-MATT门控循环单元融合多头注意力多变量时间序列预测模型预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.多维时序 | Matlab实现GRU-MATT门控循环单元融…

wcf 简单实践 数据绑定 数据更新ui

1.概要 2.代码 2.1 xaml <Window x:Class"WpfApp3.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"http://schemas.microsoft.com/expr…

《数据治理简易速速上手小册》第2章 数据治理框架的建立(2024 最新版)

文章目录 2.1 确立数据治理目标2.1.1 基础知识2.1.2 重点案例&#xff1a;提升零售业数据质量2.1.3 拓展案例 1&#xff1a;银行的数据安全加强2.1.4 拓展案例 2&#xff1a;医疗机构的合规性数据报告 2.2 数据治理框架的关键要素2.2.1 基础知识2.2.2 重点案例&#xff1a;制药…

VM-UNet:视觉Mamba UNet用来医学图像分割 论文及代码解读

VM-UNet 期刊分析摘要贡献方法整体框架1. Vision Mamba UNet (VM-UNet)2. VSS block3. Loss function 实验1.对比实验2.消融实验 可借鉴参考代码使用 期刊分析 期刊名&#xff1a; Arxiv 论文主页&#xff1a; VM-UNet 代码&#xff1a; Code 摘要 在医学图像分割领域&#x…

JANGOW: 1.0.1

kali:192.168.223.128 主机发现 nmap -sP 192.168.223.0/24 端口扫描 nmap -p- 192.168.223.154 开启了21 80端口 web看一下&#xff0c;有个busque.php参数是buscar,但是不知道输入什么&#xff0c;尝试文件包含失败 扫描目录 dirsearch -u http://192.168.223.154 dirse…

mysql 分表实战

本文主要介绍基于range分区的相关 1、业务需求&#xff0c;每日160w数据&#xff0c;每月2000w;解决大表数据读写性能问题。 2、数据库mysql 8.0.34&#xff0c;默认innerDB;mysql自带的逻辑分表 3、分表的目的:解决大表性能差&#xff0c;小表缩小查询单位的特点(其实优化的精…

人事|人事管理系统|基于Springboot的人事管理系统设计与实现(源码+数据库+文档)

人事管理系统目录 目录 基于Springboot的人事管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员登录 2、员工管理 3、公告信息管理 4、公告类型管理 5、培训管理 6、培训类型管理 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、…

2024年环境安全科学、材料工程与制造国际学术会议(ESSMEM2024)

【EI检索】2024年环境安全科学、材料工程与制造国际学术会议&#xff08;ESSMEM2024) 会议简介 我们很高兴邀请您参加将在三亚举行的2024年环境安全科学、材料工程和制造国际学术会议&#xff08;ESSMEM 2024&#xff09;。 ESSMEM2024将汇集世界各国和地区的研究人员&…

CrackRTF1

由此可知为六位密码且<100000自动退出 发现是哈希加密&#xff0c;经过网上查找后了解到 0x8004u 为 SHA1 加密&#xff0c;故从网上寻找脚本并进行爆破 密码为123321 找到一个取巧的网站 MD5免费在线解密破解_MD5在线加密-SOMD5 Flag{N0_M0re_Free_Bugs}

Mysql索引讲解及创建

索引介绍 1.什么是索引&#xff1f; 索引是存储引擎中一种数据结构&#xff0c;或者说数据的组织方式&#xff0c;又称之为键key&#xff0c;是存储引擎用于快速找到记录的 一种数据结构。 为数据建立索引就好比是为书建目录&#xff0c;或者说是为字典创建音序表&#xff0c;…

视频号视频下载教程:如何把微信视频号的视频下载下来

视频号下载相信不少人都多少有一些了解&#xff0c;但今天我们就来细说一下关于视频号视频下载的相关疑问&#xff0c;以及大家经常会问到底如何把微信视频号的视频下载下来&#xff1f; 视频号视频下载教程 视频号链接提取器详细使用指南&#xff0c;教你轻松下载号视频&…

LabVIEW燃料电池船舶电力推进监控系统

LabVIEW燃料电池船舶电力推进监控系统 随着全球经济一体化的推进&#xff0c;航运业的发展显得尤为重要&#xff0c;大约80%的世界贸易依靠海上运输实现。传统的船舶推进系统主要依赖于柴油机&#xff0c;这不仅耗能高&#xff0c;而且排放严重&#xff0c;对资源和环境的影响…

uniapp播放mp4省流方案

背景&#xff1a; 因为项目要播放一个宣传和讲解视频&#xff0c;视频文件过大&#xff0c;同时还为了节省存储流量&#xff0c;想到了一个方案&#xff0c;用m3u8切片替代mp4。 m3u8&#xff1a;切片播放&#xff0c;可以理解为一个1G的视频文件&#xff0c;自行设置文…

Antd 嵌套子表格 defaultExpandAllRows 默认展开不生效

问题描述 在使用 antd 嵌套子表格时&#xff0c;想要默认展开所有子列表&#xff0c;设置属性:defaultExpandAllRows“true”&#xff0c;但是子列表没有展开 原因分析 defaultExpandAllRows 属性官网定义&#xff1a; 从官网定义可知&#xff0c;defaultExpandAllRows 属性…

P2680 [NOIP2015 提高组] 运输计划 第一个测试点信息 || 被卡常,链式前向星应该解决的是vector的push_back频繁扩容的耗时

[NOIP2015 提高组] 运输计划 - 洛谷 目录 测试点信息 P2680_1.in P2680_1.out 图&#xff1a; 50分参考代码&#xff08;开了n^2的数组&#xff0c;MLE了&#xff09;&#xff1a; 测试点信息 Subtask #0 #1 P2680_1.in 100 1 7 97 4 89 2 0 40 91 1 70 84 1 36 92 3 …