【bug】Transformer输出张量的值全部相同?!

news2024/9/29 18:14:01

【bug】Transformer输出张量的值全部相同?!

  • 现象
  • 原因
  • 解决

现象

输入经过TransformerEncoderLayer之后,基本所有输出都相同了。
核心代码如下,

from torch.nn import TransformerEncoderLayer
self.trans = TransformerEncoderLayer(d_model=2,
								     nhead=2,
								     batch_first=True,
								     norm_first=True)
...
x = torch.randn(2, 8, 2)
print("x before transformer", x, x.shape)
x = self.trans(x)		# Transformer Encoder Layers
print("x after transformer", x, x.shape)

输出:

x before transformer tensor([[[ 0.2244, -1.9497],
         [ 0.4710, -0.7532],
         [-1.4016,  0.5266],
         [-1.1386, -2.5170],
         [-0.0733,  0.0240],
         [-0.9647, -0.9760],
         [ 2.4195, -0.0135],
         [-0.3929,  1.2231]],

        [[ 0.1451, -1.2050],
         [-1.1139, -1.7213],
         [ 0.5105,  0.4111],
         [ 2.1308,  2.5476],
         [ 1.2611, -0.7307],
         [-2.0910,  0.1941],
         [-0.3903,  1.3022],
         [-0.2442,  0.5787]]]) torch.Size([2, 8, 2])
x after transformer tensor([[[ 1.0000, -1.0000],
         [ 1.0000, -1.0000],
         [-1.0000,  1.0000],
         [ 1.0000, -1.0000],
         [-1.0000,  1.0000],
         [ 1.0000, -1.0000],
         [ 1.0000, -1.0000],
         [-1.0000,  1.0000]],

        [[ 1.0000, -1.0000],
         [ 1.0000, -1.0000],
         [ 1.0000, -1.0000],
         [-1.0000,  1.0000],
         [ 1.0000, -1.0000],
         [-1.0000,  1.0000],
         [-1.0000,  1.0000],
         [-1.0000,  1.0000]]], grad_fn=<NativeLayerNormBackward0>) torch.Size([2, 8, 2])

原因

在询问过全知全能的New Bing之后,找到一篇文章。

简化Transformer模型训练技术简介

Understand the difficulty of training transformer
时间:2020
引用:124
期刊会议:EMNLP 2020
代码:https://github.com/LiyuanLucasLiu/Transformer-Clinic

在这里插入图片描述

Transformer的Layer Norm的位置很关键。

如果我们使用Post-LN,模型可能对参数不稳定,导致训练的失败。 而Pre-LN却不会。

原始Transformer论文中为Post-LN。一般来说,Post-LN会比Pre-LN的效果好。

针对这点,Understand the difficulty of training transformer文中提出使用Admin初始化。在训练稳定的前提下,拥有Post-LN的性能。

在这里插入图片描述

解决

这里我们使用Pre-LN。

torch.nn.TransformerEncodelayer就提供了norm_frist的选项。

self.trans = TransformerEncoderLayer(d_model=2,
								     nhead=2,
								     batch_first=True,
								     norm_first=True)

修改后,输出:

x before transformer tensor([[[ 0.5373,  0.9244],
         [ 0.6239, -1.0643],
         [-0.5129, -1.1713],
         [ 0.5635, -0.7778],
         [ 0.4507, -0.0937],
         [ 0.2720,  0.7870],
         [-0.5518,  0.8583],
         [ 1.5244,  0.5447]],

        [[ 0.3450, -1.9995],
         [ 0.0530, -0.9778],
         [ 0.8687, -0.6834],
         [-1.6290,  1.6586],
         [ 1.2630,  0.4155],
         [-2.0108,  0.9131],
         [-0.0511, -0.8622],
         [ 1.5726, -0.7042]]]) torch.Size([2, 8, 2])
x after transformer tensor([[[ 0.5587,  0.9392],
         [ 0.5943, -1.0631],
         [-0.5196, -1.1681],
         [ 0.5635, -0.7765],
         [ 0.4341, -0.0819],
         [ 0.2943,  0.7998],
         [-0.5329,  0.8661],
         [ 1.5166,  0.5528]],

        [[ 0.3450, -1.9860],
         [ 0.0273, -0.9603],
         [ 0.8415, -0.6682],
         [-1.6297,  1.6686],
         [ 1.2261,  0.4175],
         [-2.0205,  0.9314],
         [-0.0595, -0.8421],
         [ 1.5567, -0.6847]]], grad_fn=<AddBackward0>) torch.Size([2, 8, 2])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/370488.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SPARC体系下硬浮点编译故障分析

问题说明 之前extension版的app工程都是用的软浮点编译的&#xff0c;在增加姿控算法库后&#xff0c;统一改用硬浮点运行&#xff0c;发现之前一个浮点数解析不对了&#xff0c;排查发现和工程编译选项有关&#xff0c;为软浮点时正常&#xff0c;硬浮点时异常。该问题脱离业…

【华为OD机试模拟题】用 C++ 实现 - VLAN 资源池(2023.Q1)

最近更新的博客 华为OD机试 - 入栈出栈(C++) | 附带编码思路 【2023】 华为OD机试 - 箱子之形摆放(C++) | 附带编码思路 【2023】 华为OD机试 - 简易内存池 2(C++) | 附带编码思路 【2023】 华为OD机试 - 第 N 个排列(C++) | 附带编码思路 【2023】 华为OD机试 - 考古…

Stream操作流 练习

基础数据&#xff1a;Data AllArgsConstructor NoArgsConstructor public class User {private String name;private int age;private String sex;private String city;private Integer money; static List<User> users new ArrayList<>();public static void m…

【计算机三级网络技术】 第一篇 网络系统结构与系统设计的基本原则

网络系统结构与系统设计的基本原则 文章目录网络系统结构与系统设计的基本原则一、计算机网络的基本结构二、计算机网络分类及其互联方式1.局域网2.城域网3.广域网4.计算机网络的互联方式三、局域网技术四、城域网技术1.城域网的概念2.宽带城域网建设产生的影响3.推动城域网快速…

HTML - 扫盲

文章目录1. 前言2. HTML2.1 下载 vscode3 HTML 常见标签3.1 注释标签3.2 标题标签3.3 段落标签3.4 换行标签3.5 格式化标签1. 加粗2. 倾斜3. 下划线3.6 图片标签3.7 超链接标签3.8 表格标签3.9 列表标签4. 表单标签4.1 from 标签4.2 input 标签4.3 select 标签4.4 textarea标签…

webgl渲染优化——深度缓冲区、多边形缓冲机制

文章目录前言深度缓冲区多边形缓冲机制总结前言 webgl在渲染三维场景时&#xff0c;按照Z坐标的值决定前后关系&#xff0c;但是在默认状态下它并未开启深度检测&#xff0c;而是将后绘制的物体放在前面&#xff1b;当两个物体Z坐标相差无几时&#xff0c;会产生深度冲突&…

【Redis】线程模型:Redis是单线程还是多线程?

【Redis】线程模型&#xff1a;Redis是单线程还是多线程&#xff1f; 文章目录【Redis】线程模型&#xff1a;Redis是单线程还是多线程&#xff1f;Redis 是单线程吗&#xff1f;Redis 单线程模式是怎样的&#xff1f;Redis 采用单线程为什么还这么快&#xff1f;Redis 6.0 之前…

高端装备的AC主轴头结构

加工机器人的AC主轴头和位置相关动力学特性1. 位置依赖动态特性及其复杂性2. AC主轴头2.1 常见主轴头摆角结构2.2 摆动机构3. 加装AC主轴头的作用和局限性4. 切削机器人的减速器类型5. 其他并联结构形式参考文献资料1. 位置依赖动态特性及其复杂性 However, FRF measurements …

JS学习第3天——Web APIs之DOM(什么是DOM,相关API)

目录一、Web APIs介绍1、API2、Web API二、DOM1、DOM树2、获取元素3、事件基础4、操作元素属性5、节点&#xff08;node&#xff09;操作三、以上内容总结四、小案例一、Web APIs介绍 JS的组成&#xff1a;ECMAScript&#xff08;基础语法&#xff09;、DOM&#xff08;页面文…

CTFer成长之路之反序列化漏洞

反序列化漏洞CTF 1.访问url&#xff1a; http://91a5ef16-ff14-4e0d-a687-32bdb4f61ecf.node3.buuoj.cn/ 点击下载源码 本地搭建环境并访问url&#xff1a; http://127.0.0.1/www/public/ 构造payload&#xff1a; ?sindex/index/hello&ethanwhoamiPOST的参数&#…

【渗透测试学习】—记录一次自测试渗透实战

写在前面 本文是作者入门web安全后的第一次完整的授权渗透测试实战&#xff0c;因为最近在总结自己学习与挖掘到的漏&#xff0c;无意中翻到了这篇渗透测试报告&#xff0c;想当初我的这篇渗透测试报告是被评为优秀渗透测试报告的&#xff0c;故在此重新整了一下&#xff0c;分…

创客匠人直播:构建公域到私域的用户增长模型

进入知识付费直播带货时代&#xff0c;很多拥有知识技能经验的老师和培训机构吃到了流量红利。通过知识付费直播&#xff0c;老师们可以轻松实现引流、变现&#xff0c;还可以突破时间、地域的限制&#xff0c;为全国各地的学员带来优质的教学服务&#xff0c;因此越来越受到教…

【Linux】-- 多线程安全

目录 进程互斥 计算 -> 时序问题 加锁保护 pthread_mutex_lock pthread_mutex_unlock 使用init与destory pthread_mutex_init phtread_mutex_destory 锁的实现原理 图 可重入VS线程安全 死锁 Linux线程同步 条件变量 系统调用 进程互斥 进程线程间的互斥相关…

【C语言经典例题】打印菱形

目录 一、题目要求 二、解题思路 上半部分三角形 打印空格 打印星号* 下半部分三角形 打印空格 打印星号* 三、完整代码 代码 运行截图&#xff1a; 一、题目要求 输入一个整数n&#xff08;n为奇数&#xff09;&#xff0c;n为菱形的高&#xff0c;打印出该菱形 例&a…

【模拟集成电路】鉴频鉴相器设计(Phase Frequency Detector,PFD)

鉴频鉴相器设计&#xff08;Phase Frequency Detector&#xff0c;PFD&#xff09;前言一、 PFD的工作原理二、 PFD电路设计&#xff08;1&#xff09;PFD电路图&#xff08;2&#xff09;D触发器电路图&#xff08;3&#xff09;与非门&#xff08;NAND&#xff09;电路图&…

【死磕数据库专栏】MySQL对数据库增删改查的基本操作

前言 本文是专栏【死磕数据库专栏】的第二篇文章&#xff0c;主要讲解MySQL语句最常用的增删改查操作。我一直觉得这个世界就是个程序&#xff0c;每天都在执行增删改查。 MySQL 中我们最常用的增删改查&#xff0c;对应SQL语句就是 insert 、delete、update、select&#xf…

亚马逊侵权了怎么办?不要恐慌,这套申诉方法教你解决

侵权&#xff0c;在亚马逊可是大忌&#xff01;在亚马逊平台上&#xff0c;卖家侵权行为被认为是极为严重的违规行为。亚马逊采取的对待侵权的措施通常相当严厉&#xff0c;从轻者的产品下架到重者直接被禁售。所以如果你的产品涉嫌侵犯知识产权&#xff0c;那么想要在亚马逊上…

软件质量保证与测试(测试部分)

第九章、软件测试过程 9.1 计算机软件的可靠性要素 9.2 软件测试的目的和原则 9.3 软件测试过程 9.4 软件测试与软件开发的关系 9.7 测试工具选择 9.7.1 白盒测试工具 9.7.2 黑盒测试工具 第十章、黑盒测试 10.1 黑盒测试的基本概念 10.2 等价类划分 10.2.2 划分等价类的方法…

MinGW编译log4cpp

log4cpp的官网和下载地址 https://log4cpp.sourceforge.net/ https://sourceforge.net/projects/log4cpp/files/ 使用MinGW编译log4cpp 进入到log4cpp的源码目录 cd F:\3rdParty\Log\log4cpp\log4cpp-1.1.3\log4cpp 创建文件夹 mkdir build && mkdir outcd build …

死磕Spring,什么是SPI机制,对SpringBoot自动装配有什么帮助

文章目录如果没时间看的话&#xff0c;在这里直接看总结一、Java SPI的概念和术语二、看看Java SPI是如何诞生的三、Java SPI应该如何应用四、从0开始&#xff0c;手撸一个SPI的应用实例五、SpringBoot自动装配六、Spring SPI机制与Spring Factories机制做对比七、这里是给我自…