手搓多模态-03 顶层和嵌入层的搭建

news2025/4/8 2:18:49
声明:本代码非原创,是博主跟着国外大佬的视频教程编写的,本博客主要为记录学习成果所用。

我们首先开始编写视觉模型这一部分这一部分主要功能接收一个batch图像将其转化上下文相关嵌入向量这一阶段我们需要做的事情以下这些

  • 编写一个全局通用视觉配置
  • 编写用户模型调用
  • 输入图像嵌入向量
  • 通过transformer编码图像嵌入向量进行编码使其上下文相关

我们一个个实现这些代码

视觉模型配置

视觉模型配置主要如下

class SiglipVisionConfig:

	def __init__(
			self,
			hidden_size=768,
			num_hidden_layers=12,
			num_attention_heads=12,
			intermediate_size=3072,
			num_channels=3,
			image_size=224,
			patch_size=16,
			attention_dropout=0.0,
			layer_norm_eps=1e-6,
			num_image_tokens: int = None,
			**kwargs
		):
		super().__init__(**kwargs)
		self.hidden_size = hidden_size ## embedding 的维度
		self.num_hidden_layers = num_hidden_layers ## 隐藏层的数量
		self.num_attention_heads = num_attention_heads ## 注意力头数量
		self.intermediate_size = intermediate_size ## 线性层的维度 
		self.num_channels = num_channels ##图像的RGB通道
		self.image_size = image_size ## 图像尺寸,任何图像size都会被缩放到这个尺寸
		self.patch_size = patch_size ## 每个patch的尺寸
		self.attention_dropout = attention_dropout ## 注意力层dropout
		self.layer_norm_eps = layer_norm_eps ## 层归一化epsilon
		self.num_image_tokens = num_image_tokens ## 图像token数量,它实际上是一个固定值

为了各个变量涵义更加浅显易懂博主增加中文注释

顶层模型

随后用户模型用户只需要一个batch图像传入调用forward函数即可返回这些图像上下文相关embeddings

代码如下

class SiglipVisionModel(nn.Module): ## 最顶层的视觉模型,它负责顶层的传入和编码的输出
	def __init__(self, config:SiglipVisionConfig):
		super().__init__()
		self.config = config
		self.vision_model = SiglipVisionTransformer(config)

	def forward(self, pixel_values) -> Tuple:
		# [Batch_size,Channels,Height,Width] ->  [Batch_size,Num_Patches(num_image_token),Embedding_size(Hidden_size)]
		return self.vision_model(pixel_values = pixel_values)	

其中输入形状 [ Batch_size, Channels, Height, Width ],对应一个图像batch各自RGB通道像素输出 [ Batch_size, Num_Patches, Embedding_size ], 对应各个图像每个分割Patch嵌入结果

模型拆分

我们需要内部一次视觉模型调用拆分两个模型各自的调用也就是拆分嵌入模型Transformer编码这里我们创建一个SiglipVisionTransformer将其分成两个模型调用

代码如下

class SiglipVisionTransformer(nn.Module): ##视觉模型的第二层,将模型的调用分为了图像嵌入模型和transformer编码器模型的调用
	def __init__(self, config:SiglipVisionConfig):
		super().__init__()
		self.config = config
		self.embed_dim = config.hidden_size
		self.embeddings = SiglipVisionEmbeddings(config) ## 负责将图像嵌入成向量
		self.encoder = SiglipEncoder(config) ## 负责将向量编码成注意力相关的向量
		self.post_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) ## 层归一化

	def forward(self, pixel_values:torch.Tensor) -> torch.Tensor:
		"""
		pixel_values: [Batch_size,Channels,Height,Width]
		"""
		## [ Batch_size,Channels,Height,Width] -> [Batch_size,Num_Patches,Embedding_size] 
		hidden_states = self.embeddings(pixel_values) ## 将图像嵌入成向量

		# [Batch_size,Num_Patches,Embedding_size] -> [Batch_size,Num_Patches,Embedding_size]
		last_hidden_state = self.encoder(hidden_states) ## 将向量编码成注意力相关的向量

		# [Batch_size,Num_Patches,Embedding_size] -> [Batch_size,Num_Patches,Embedding_size]
		last_hidden_state = self.post_layer_norm(last_hidden_state)

		return last_hidden_state

嵌入模型

嵌入模型初始图像像素初步转换patch编码向量list, 同时阶段我们使用位置编码位置编码形式多种这里我们采用自学习嵌入向量每个位置创建一个可以学习参数向量形成位置矩阵使用时候根据indices位置矩阵抽取对应位置向量即可

代码

class SiglipVisionEmbeddings(nn.Module):

	def __init__(self, config:SiglipVisionConfig):
		self.config = config
		self.patch_size = config.patch_size
		self.image_size = config.image_size
		self.embed_dim = config.hidden_size

		self.patch_embedding = nn.Conv2d(
			in_channels = config.num_channels,
			out_channels = self.embed_dim,
			kernel_size = self.patch_size,
			stride = self.patch_size,
			padding = 'valid', ##不加padding
		)

		self.num_patches = (self.image_size // self.patch_size) ** 2 ## 图像的patch数量 (224 // 16) ** 2 = 196
		self.num_positions = self.num_patches

		self.position_embeddings = nn.Embedding(self.num_positions, self.embed_dim)

		self.register_buffer(
			"position_ids",
			torch.arange(self.num_positions).expand((1, -1)), ## 这里expand是为了保持和patch_embeds的维度一致,以便可以直接与之相加
			persistent=False,
		)

	def forward(self, pixel_values:torch.FloatTensor) -> torch.Tensor:
		"""
		pixel_values: [Batch_size,Channels,Height,Width]
		"""
		_ , _ , height, width = pixel_values.shape

		## 卷积,3通道转embedding_size通道
		patch_embeds = torch.FloatTensor(self.patch_embedding(pixel_values)) ## [Batch_size,Channel,Height,Width] -> [Batch_size,Embedding_size,Num_Patches_Height,Num_Patches_Width]

		## flatten
		patch_embeds = patch_embeds.flatten(2) # [Batch_size,Embedding_size,Num_Patches_Height,Num_Patches_Width] -> [Batch_size,Embedding_size,Num_Patches]

		## transpose
		patch_embeds = patch_embeds.transpose(1,2) ## [Batch_size,Embedding_size,Num_Patches] -> [Batch_size,Num_Patches,Embedding_size] 

		## positon_encoding
		patch_embeds = patch_embeds + self.position_embeddings(self.position_ids) ## [Batch_size,Num_Patches,Embedding_size]  自学习的位置编码

		return patch_embeds

上面卷积配置表示我们希望卷积结果patch_size * embedding_size维度为了方便大家理解

这里简单介绍一下torch卷积

pytorch2D卷积

卷积通过卷积图像特征提取出来卷积操作可以如图所示

卷积操作本质输入区域展平向量同时卷积核展平向量做一次内积得到输出位置

torch卷积公式:

这里N_i 第i个batchCout_j 是指j输出通道输出星号代表二维区域weight权重input区域做一次卷积

这里可以看到多出一个通道概念其实对于图像来说输入通道就是RGB通道输出通道你希望一个卷积的图像区域多少特征

用图展示如下

这里彩色方块是1*1的卷积核我们希望一个三个输入通道输入卷积得到三个输出通道输出这样对于每个通道conv2D都会为其生成三个卷积核每个通道卷积结果卷积核顺序对应相加比如第一个输出通道的结果等于三个输入通道各自第一个卷积核卷积结果进行相加得到

由此再来这个公式

j输出通道结果等于所有输入通道j输出通道卷积卷积结果相加加上一个偏置矩阵得到

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2330164.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【经验分享】将qt的ui文件转换为py文件

🌟 嗨,我是命运之光! 🌍 2024,每日百字,记录时光,感谢有你一路同行。 🚀 携手启航,探索未知,激发潜能,每一步都意义非凡。 首先简单的设计一个U…

探秘JVM内部

在我们编写Java代码,点击运行后,会发生什么事呢? 首先,Java源代码会经过Java编译器将其编译成字节码,放在.class文件中 然后这些字节码文件就会被加载到jvm中,然后jvm会读取这些文件,调用相关…

在HarmonyOS NEXT 开发中,如何指定一个号码,拉起系统拨号页面

大家好,我是 V 哥。 《鸿蒙 HarmonyOS 开发之路 卷1 ArkTS篇》已经出版上市了哈,有需要的朋友可以关注一下,卷2应用开发篇也马上要出版了,V 哥正在紧锣密鼓的写鸿蒙开发实战卷3的教材,卷3主要以项目实战为主&#xff0…

利用空间-运动-回波稀疏性进行5D图像重建,以实现自由呼吸状态下肝脏定量磁共振成像(MRI)的加速采集|文献速递--深度学习医疗AI最新文献

Title 题目 5D image reconstruction exploiting space-motion-echo sparsity foraccelerated free-breathing quantitative liver MRI 利用空间-运动-回波稀疏性进行5D图像重建,以实现自由呼吸状态下肝脏定量磁共振成像(MRI)的加速采集 …

ZKmall开源商城B2B2C电商用户隐私信息保护策略:数据脱敏全链路实践

随着业务的不断拓展和用户规模的持续扩大,用户隐私信息的保护也面临着前所未有的挑战。下面将深入探讨ZKmall开源商城在数据脱敏方面的实践,以及针对B2B2C电商用户隐私信息的具体保护策略。 数据脱敏,又称数据去标识化或数据匿名化&#xff0…

Pgvector的安装

Pgvector的安装 向量化数据的存储,可以为 PostgreSQL 安装 vector 扩展来存储向量化数据 注意:在安装vector扩展之前,请先安装Postgres数据库 vector 扩展的步骤 1、下载vs_BuildTools 下载地址: https://visualstudio.microso…

Django接入 免费的 AI 大模型——讯飞星火(2025年4月最新!!!)

上文有介绍deepseek接入,但是需要 付费,虽然 sliconflow 可以白嫖 token,但是毕竟是有限的,本文将介绍一款完全免费的 API——讯飞星火 目录 接入讯飞星火(免费) 测试对话 接入Django 扩展建议 接入讯飞星火…

路由器学习

路由器原理 可以理解成把不同的网络打通,实现通信的设备。比如家里的路由器,他就是把家里的内网和互联网(外网)打通。 分类 1.(按应用场景分类) 路由器分为家用的,企业级的,运营…

UE5学习记录part14

第17节 enemy behavior 173 making enemies move: AI Pawn Navigation 按P查看体积 So its very important that our nav mesh bounds volume encompasses all of the area that wed like our 因此,我们的导航网格边界体积必须包含我们希望 AI to navigate in and …

Docker的备份与恢复

一、两种基本方式 docker export / import 在服务器上导出容器docker export container_name > container_backup.tar这里使用 > 重定向时默认保存路径为当前运行命令的路径,可以自行指定绝对路径来保存,后续加载时也使用对应的路径即可。 恢复为…

DAPP实战篇:规划下我们的开发线路

前言 在DApp实战篇:先用前端起个项目一文中我们起了一个前端项目,在后续开发中笔者将带领大家一步步完成这个DAPP,为了方便后续讲解,本篇将完整说明后续我们要进行的开发和思路。 主打前端 实际上一个完整的DAPP是由前端和智能…

【Elasticsearch】开启大数据分析的探索与预处理之旅

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/literature?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,…

状态机思想编程练习

状态机实现LED流水灯 本次实验,我们将利用状态机的思想来进行Verilog编程实现一个LED流水灯,并通过Modelsim来进行模拟仿真,再到DE2-115开发板上进行验证。 ​ 首先进行主要代码的编写。 module led (input sys_clk,input sys_…

前端新增数据,但数据库里没有新增的数据

先看情况: 1.前端,可以进行删查改,但是新增数据之后,显示保存成功,也增加了空白的一行,但是数据没有显示出来。 2.后端接收到了数据,但返回结果的列表里面是空的;同时数据库里面没…

httpx模块的使用

在使用requests模块发起请求时,报以下错误,表示服务器有可能使用的是http2.0协议版本,导致requests无法爬取。 此时就可以使用httpx模块爬取。 先下载httpx模块: pip install httpx[http2]然后用httpx发起请求: impo…

论文阅读10——解开碳排放与碳足迹之间的关系:文献回顾和可持续交通框架

原文地址: Unraveling the relation between carbon emission and carbon footprint: A literature review and framework for sustainable transportation | npj Sustainable Mobility and TransportTransportation decarbonization has drawn enormous attention globally,…

新一代AI架构实践:数字大脑AI+智能调度MCP+领域执行APP的黄金金字塔体系

新一代AI架构实践:数字大脑智能调度领域执行的黄金金字塔体系 一、架构本质的三层穿透性认知 1.1 核心范式转变(CPS理论升级) 传统算法架构:数据驱动 → 特征工程 → 模型训练 → 业务应用 新一代AI架构:物理规律建…

Winform MQTT客户端连接方式

项目中使用到Winform的数据转发服务,所以记录下使用到的方法。 一.创建单例模板 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace ConsoleApp.Scripts {public class SingleTon&…

Linux Bash 脚本实战:自动监控域名证书过期并发送邮件告警

在日常运维工作中,SSL 证书的管理是一个非常重要的环节,尤其对于线上业务来说,证书到期会直接导致服务不可用。为了避免证书到期带来的风险,我们可以编写一个 Bash 脚本来自动检测域名的 SSL 证书过期时间,并在证书即将到期时发送告警邮件。 目录 脚本功能概述 代码实现…

【模型量化】GPTQ 与 AutoGPTQ

GPTQ是一种用于类GPT线性最小二乘法的量化方法,它使用基于近似二阶信息的一次加权量化。 本文中也展示了如何使用量化模型以及如何量化自己的模型AutoGPTQ。 AutoGPTQ:一个易于使用的LLM量化包,带有用户友好的API,基于GPTQ算法(仅…