深度学习_GPT2Block详解(casual attention)

news2024/11/8 14:39:17

一、GTP2Block 整体结构

1.1 block准备

import torch 
from torch import nn
from transformers import GPT2Model, GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

cfg = GPT2Config()
print(cfg.add_cross_attention)
blk = GPT2Block(cfg, layer_idx=0)
hidden_states = torch.randn(10, 1024, 768)

1.2 block架构

经典的preNorm TFDecoder架构

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D()
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D()
    (c_proj): Conv1D()
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

1.3 forward-preNorm

y = attn(ln_1(x)) + x
O = mlp(ln_2(y)) + y

在这里插入图片描述
在这里插入图片描述

二、GPT2Attention

  1. hidden 拆分成 q k v: query, key, value = gpt2_att.c_attn(hidden_states).split(split_size, dim=2)
  2. q k v 拆分成多头
query = gpt2_att._split_heads(query, gpt2_att.num_heads, gpt2_att.head_dim)
key = gpt2_att._split_heads(key, gpt2_att.num_heads, gpt2_att.head_dim)
value = gpt2_att._split_heads(value, gpt2_att.num_heads, gpt2_att.head_dim)
print(f'{query.shape=}') # [batch, n_head, len, head_emb] 
  1. 计算attention
    1. A ^ = Q K T K d i m \hat{A}=\frac{QK^T}{\sqrt{K_{dim}}} A^=Kdim QKT 代码中用的是 V d i m \sqrt{V_{dim}} Vdim
    2. casual attention: 对原始attn进行mask
    3. 计算mask后的attention: A = s o f t m a x ( A ^ , d i m = − 1 ) A=softmax(\hat{A}, dim=-1) A=softmax(A^,dim=1)
    4. O = A V O=AV O=AV
# 3- attention 
#  3.1 A = QK^T
attn_weights = torch.matmul(query, key.transpose(-1, -2)) / torch.full([], value.size(-1) ** 0.5)
#  3.2 mask 
max_positions = 1024
causal_mask = torch.tril(
    torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions)
mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
# where mask
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
#  3.3 A = softmax(A)
attn_weights = nn.functional.softmax(attn_weights, dim=-1) # [batch, n_head, len, len] 
#  3.4  O = AV
attn_output = torch.matmul(attn_weights, value)            # [batch, n_head, len, head_emb] 
# 4- q k v -> merge head -> attn_out # [batch, len, head_emb*n_head] 
attn_output = gpt2_att._merge_heads(attn_output, gpt2_att.num_heads, gpt2_att.head_dim)
  1. 多头合并 [batch, n_head, len, head_emb] =>> [batch, len, head_emb*n_head]
    1. attn_output = gpt2_att._merge_heads(attn_output, gpt2_att.num_heads, gpt2_att.head_dim)

pic-attn_weights mask前后

三、GPT2MLP

结构比较简单 O = d r o p O u t ( σ ( X W 1 ) W 2 ) O=dropOut(\sigma (XW_1)W_2) O=dropOut(σ(XW1)W2),主要是激活函数 NewGELU

GPT2MLP(
  (c_fc): Conv1D()
  (c_proj): Conv1D()
  (act): NewGELUActivation()
  (dropout): Dropout(p=0.1, inplace=False)
)

class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: Tensor) -> Tensor:
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))

在这里插入图片描述

NewGELUActivation 它是高斯误差线性单元(Gaussian Error Linear Unit,简称 GELU)的一种变体。GELU 激活函数在近年来的深度学习模型中越来越受欢迎,尤其是在自然语言处理(NLP)领域,如 BERT 和 GPT 等模型中。

GELU 激活函数的数学定义是输入值 x 乘以标准正态分布的累积分布函数(CDF)在该点的值。具体来说,GELU 的表达式为:
G E L U ( x ) = x Φ ( x ) GELU(x)=x \Phi(x) GELU(x)=xΦ(x)

其中 Φ ( x ) \Phi(x) Φ(x) 是标准正态分布的 CDF,可以通过误差函数(error function,记为 erf)来计算:
Φ ( x ) = 1 2 ( 1 + e r f ( x 2 ) ) \Phi(x)=\frac{1}{2}(1+erf(\frac{x}{\sqrt 2})) Φ(x)=21(1+erf(2 x))
GPT2中用了近似公式:
σ ( x ) = 0.5 x [ 1 + t a n h ( 2 π ( x + 0.044715 x 3 ) ) ] \sigma(x) = 0.5x [1+ tanh(\sqrt{\frac{2}{\pi}} (x + 0.044715 x^3))] σ(x)=0.5x[1+tanh(π2 (x+0.044715x3))]

GELU 激活函数的优点包括:

  • 平滑性:GELU 在整个实数域上都是平滑的,这有助于梯度的传播,减少了梯度消失或爆炸的问题
  • 非单调性:GELU 函数是非单调的,这意味着它能够捕捉数据中的更复杂模式
  • 改善性能:在某些任务中,使用 GELU 激活函数的模型性能优于使用传统的 ReLU 或其他激活函数的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2129921.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

“汉语新解” Prompt新高度,火爆的李继刚

“汉语新解” prompt 是由李继刚设计的一个用于启发人工智能模型进行创意性文本生成的指令模板。这个 prompt 的设计初衷是为了让AI能够以一种独特的方式解析和重新诠释常见的中文词汇,从而产生出具有深刻洞察力和幽默感的文本内容,仿佛是由鲁迅或林语堂…

Linux线程同步:深度解析条件变量接口

🍑个人主页:Jupiter. 🚀 所属专栏:Linux从入门到进阶 欢迎大家点赞收藏评论😊 目录 🍑Linux线程同步🐉条件变量---实现线程同步💧同步概念与竞态条件🐆条件变量接口*初始…

sqli-labs靶场自动化利用工具——第13关

文章目录 概要整体架构流程技术细节执行效果小结 概要 Sqli-Labs靶场对于网安专业的学生或正在学习网安的朋友来说并不陌生,或者说已经很熟悉。那有没有朋友想过自己开发一个测试脚本能实现自动化化测试sqli-labs呢?可能有些人会说不是有sqlmap&#…

每日OJ_牛客_马戏团(模拟最长上升子序列)

目录 牛客_马戏团(模拟最长上升子序列) 解析代码 牛客_马戏团(模拟最长上升子序列) 马戏团__牛客网 搜狐员工小王最近利用假期在外地旅游,在某个小镇碰到一个马戏团表演,精彩的表演结束后发现团长正和大…

《基于深度半监督学习的目标检测综述》泛读

基于深度半监督学习的目标检测方法分为 1、生成式方法 2、一致性正则化方法 3、基于图的方法 4、伪标记方法和混合方法 然后基于常用数据集 对典型方法进行了性能对比,最后分析了其挑战和发展趋势,旨在为相关研究提供参考 收获就是: 1…

Redis -- 全记录(面试)

目录 All : 缓存穿透 缓存击穿 互斥锁 逻辑过期 比较 : 缓存雪崩 redis怎么和数据库保持一致 双写一致性 : 延迟双删 : 保证强一致性 : 允许一定的延迟 基于mq的异步通知 基于Canal的异步通知 总结 Redis的持久化 RDB AOF 总结 Redis数据过期策略 惰性删除…

【算法专题】搜索算法

二叉树剪枝 LCR 047. 二叉树剪枝 - 力扣(LeetCode) 本题要求我们将全部为0的二叉树去掉,也就是剪枝,当我们举一个具体的例子进行模拟时,会发现,只关注于对其中一个子树的根节点进行剪枝,由于我…

Docker部署MySQL8.0.39报错解决方案

Docker部署MySQL8.0.39报错解决方案 2024-09-11T06:09:09.317582Z 0 [Warning] [MY-010139] [Server] Changed limits: max_open_files: 1024 (requested 8161) 2024-09-11T06:09:09.317586Z 0 [Warning] [MY-010142] [Server] Changed limits: table_open_cache: 431 (reques…

李彦宏内部讲话曝光,谈大模型三大认知误区:智能体还是非共识

“外界对大模型有相当多的误解,”近日据媒体报道,李彦宏的一则内部讲话曝光。在最近一次和员工交流中,李彦宏谈及三个大模型认知误区,涵盖了大模型竞争、开源模型效率、智能体趋势等热点话题。 李彦宏认为未来大模型之间的差距可…

【Axure教程】高级搜索

高级搜索可以通过使用精确的关键词或短语,帮助用户找到特定的内容。尤其在面对大量搜索结果时,通过过滤条件缩小范围,能够节省时间。他允许用户使用多个条件进行组合(例如条件匹配、模糊搜索、区间筛选等)来精准获取相…

购物车装载状态检测系统源码分享

购物车装载状态检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Comput…

瑞幸卖奶茶,霸王茶姬不慌

瑞幸和霸王茶姬,打不起来。 转载定焦(dingjiaoone)原创 作者 | 苏琦 编辑 | 魏佳 最近,瑞幸因为联名游戏大作《黑神话:悟空》无法核销套餐被骂上热搜,但业内人士更关注的,是它不久前推出的轻乳…

【网络安全】-rce漏洞-pikachu

rce漏洞包含命令执行漏洞与代码执行漏洞 文章目录 前言 什么是rce漏洞? 1.rce漏洞产生原因: 2.rce的分类: 命令执行漏洞: 命令拼接符: 常用函数: 代码执行漏洞: 常用函数: 分类&…

AI算力池化技术助力运营商打造智算生态

数字经济时代,算力已成为国民经济发展的重要基础设施。随着数字化转型的不断深入和人工智能技术的广泛应用,构建以新型智算中心为核心的智能算力生态体系正驱动着数字经济快速发展,成为人工智能赋能千行百业的重中之重。 2022年2月&#xff…

Vulnhub-RickdiculouslyEasy靶场(9个flag)

flag1 端口9090有一个flag flag2 13337端口 flag3 使用dirb进行扫描网站的80端口,发现一些敏感文件 访问80端口,没有发现有效信息 访问passwords目录 访问FLAG.txt 再返回访问passwords.html文件 查看页面源代码发现一个密码 flag4 之前扫描到了robo…

书接上文,介绍下Quartz Java体系结构

体系结构总结 JobDetail 我们创建一个实现 Job 接口的类,使用 JobBuilder 包装成 JobDetail,它可以携带 KV 的数据。 Trigger 定义任务的触发规律,Trigger,使用 TriggerBuilder 来构建。JobDetail 跟 Trigger 是 1:N 的关系。思…

智慧物流系统小程序的设计

管理员账户功能包括:系统首页,个人中心,车辆管理,商品管理,物流信息管理,论坛管理,公告信息管理 微信端账号功能包括:系统首页,商品,论坛,我的 …

磁盘无法访问:深度解析与高效数据恢复策略

在数字化时代,磁盘作为数据存储的核心载体,其稳定性和可访问性直接关系到用户数据的安全与完整性。然而,当遇到“磁盘无法访问”的突发状况时,用户往往会陷入焦虑与无助之中。本文将深入探讨磁盘无法访问的原因,并详细…

最全面IO流介绍

1.字符集介绍 标准ASCII字符集:使用1个字节存储一个字符,首尾是0,总可以表示128个字符。是美国信息交换标准代码,包含英文、符号等等。 GBK汉字编码字符集,包含2万多个汉字等字符,GBK中一个中文字符编码成…

(Java企业 / 公司项目)点赞业务系统设计-批量查询点赞状态(二)

接着上一篇文章来搞,批量查询点赞状态。这个接口提供给其他的微服务调用所以这里会用到FeignClient 直接上接口 1. 接口信息 这里是查询多个业务的点赞状态,因此请求参数自然是业务id的集合。由于是查询当前用戶的点赞状态,因此无需传递用戶信息。当前用户指的是登录用户 …