【随机种子初始化】一个神经网络模型初始化的大坑

news2024/10/5 17:43:03

1 问题起因和经过

半年前写了一个模型,取得了不错的效果(简称项目文件1),于是整理了一番代码,保存为了一个新的项目(简称项目文件2)。半年后的今天,我重新训练这个整理过的模型,即项目文件2,没有修改任何的超参数,并且保持完全一致的随机种子,但是始终无法完全复现出半年前项目文件1跑出来的结果(按道理来说,随机种子控制好后,整个训练过程都应该能够复现,第一个epoch的accuracy就应该对上)。我找到项目文件1,跑了跑,能复现之前的训练结果。并且,分别训练项目文件1和2的模型,都能重复自己的训练结果,而两个项目文件的结果无法对上。
花了半天的时间反复仔细检查数据集和训练超参数的设置后,也没能看出项目文件2有什么毛病。非常奇怪,为什么同样的模型、配置、服务器、随机种子,在不同的项目文件中出现不同的结果?!实在想不通。
于是,决定动手debug看看问题出在了哪里。
我发现第一个epoch的结果就对不上,所以我猜测问题出在了模型的初始化上,那初始化与什么相关呢?很自然地,我把问题聚焦在了随机种子上,是不是没有有效固定住随机性?所以我将两个项目文件构建model后的参数打印出来看了看,发现,完全不同!

在这里插入图片描述
在这里插入图片描述

项目文件1中的部分打印结果:
在这里插入图片描述

项目文件2中的部分打印结果:
在这里插入图片描述

很明显地说明了一件事:同样的随机种子,在这两个项目文件中,产生了完全不一样的初始化值! 这个结果是违背我的常识的,为什么会出现这样的情况?
于是,我 猜测是不是因为两个项目在同一服务器上发生了未知的冲突,所以我copy了一份项目文件1为项目文件3,然后跑项目文件3的初始化结果,发现和项目文件1的初始化结果一致,居然没问题!?那这个项目文件2怎么回事,凭空出现了不同的初始化值?
排除了项目冲突这个猜想后,我把视野放在了模型本身上。我试着print(model)进行观察,发现项目文件2相比项目文件1的模型架构多了一些参数,这些参数是我当初在整理代码并补充新算法时补充定义的(比如:self.gamma = Parameter(torch.randn((1, self.num_heads, 1, 1)))),但是后面并没有真正用上这个参数。
于是,我又有了一个新的猜想:是不是因为多出来的这些新定义的参数,导致在同一随机种子的设置下,仍然出现不一致的初始化行为?
顺着这个思路,我给项目文件1做了如下简单的尝试:直接给模型架构多定义一个模块,但无需使用它,看看初始化是否受影响。 这里我简单加了个nn.Linear()进去。
代码解释如下,原本的架构为:

class Attention(Module):
    """
    Obtained from timm: github.com:rwightman/pytorch-image-models
    """

    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // self.num_heads
        self.scale = head_dim ** -0.5

        self.qkv = Linear(dim, dim * 3, bias=False)
        self.attn_drop = Dropout(attention_dropout)
        self.proj = Linear(dim, dim)
        self.proj_drop = Dropout(projection_dropout)
 
        self.relu = ReLU()
        self.eps = 1e-8
        self.alpha = Parameter(torch.ones(1, self.num_heads, 1, 1), requires_grad=False)
        self.alpha.data.fill_(1.0)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        ...

我在init函数中多定义一行线性层后为:

class Attention(Module):
    """
    Obtained from timm: github.com:rwightman/pytorch-image-models
    """

    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // self.num_heads
        self.scale = head_dim ** -0.5

        self.qkv = Linear(dim, dim * 3, bias=False)
        self.attn_drop = Dropout(attention_dropout)
        self.proj = Linear(dim, dim)
        self.proj_drop = Dropout(projection_dropout)
 
        self.relu = ReLU()
        self.eps = 1e-8
        self.alpha = Parameter(torch.ones(1, self.num_heads, 1, 1), requires_grad=False)
        self.alpha.data.fill_(1.0)
        
        self.linear = Linear(dim, dim)  # 这里是新加的模块,但是无需使用

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        ...

于是,打印初始化参数得到:
在这里插入图片描述

初始化的结果仍然改变了!!! 看来果然是模型架构的不一致性导致了随机种子“失效”。

2 总结上述问题

两个项目文件,看似模型、超参数、数据集、随机种子、服务器完全一致时,发现训练时两者无法保持完全一致,并且进一步发现两个项目文件在初始化模型参数时就不一致了。
而单独地重复训练每一个项目文件,都能重复自身的结果。
直观感觉就是随机种子并没有有效地作用到另一个项目上,是一个很奇怪的问题,有点违背常识。

3 总结解决方案

训练过程不一致是表象,实际上是模型初始化就不一致了。
如果想要用随机种子控制模型初始化参数完全一致,就必须保证模型的架构完全一致! 但凡在model类的init函数里多定义一个无用参数比如Linear,都会改变整个初始化结果,从而影响后面的训练进程(应该是很微小的影响,但是对于我们复现项目时扣细节来说,会放大这个影响)。

4 可能的解释

咨询了一些遇到过这个问题的同学,大概有如下的可能的解释:在模型的定义中加入了新的模块后,不管是否真正使用,都会影响初始化(已控制了随机种子)。因为加入了新的模块后,整个初始化的顺序会发生改变,于是就乱套了。随机种子只能保证你调用后生成的随机数列是一样的,而在构建模型时的调用顺序,是会随着模型架构的改变而改变的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/661102.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C数据结构】带头双向循环链表_HDList

目录 带头双向循环链表_HDList 【1】链表概念 【2】链表分类 【3】带头双向循环链表 【3.1】带头双向循环链表数据结构与接口定义 【3.2】带头双向循环链表初始化 【3.3】带头双向循环链表开辟节点空间 【3.4】带头双向循环链表销毁 【3.5】带头双向循环链表头插 【3…

【C数据结构】带头单向非循环链表_HList

目录 带头单向非循环链表_HList 【1】链表概念 【2】链表分类 【3】有头单向非循环链表 【3.1】非循环链表数据结构与接口定义 【3.2】带头单向非循环链表初始化 【3.3】带头单向非循环链表释放空间 【3.4】带头单向非循环链表创建节点 【3.5】带头单向非循环链表头插…

HTML学习(二)

视频 <video width"320" height"240" controls> <source src"movie.mp4" type"video/mp4"> <source src"movie.ogg" type"video/ogg"> </video> 音频 <audio controls> <…

C++【AVL树】

✨个人主页&#xff1a; 北 海 &#x1f389;所属专栏&#xff1a; C修行之路 &#x1f383;操作环境&#xff1a; Visual Studio 2019 版本 16.11.17 文章目录 &#x1f307;前言&#x1f3d9;️正文1、认识AVL树1.1、AVL树的定义 2、AVL树的插入操作2.1、抽象图2.2、插入流程…

控制层调用接口的http请求封装

目录 0.碎碎念1.controller层2.util层3.测试3.1中间层调用GET请求3.2中间层调用POST请求 0.碎碎念 因为只是为了写这个帮助类&#xff0c;解耦&#xff0c;不敢拿已经写了一堆的代码改&#xff0c;就单独拆了个项目出来&#xff0c;持久层全是mybatisplus生成的。     所以…

Kafka源码解析之索引

Kafka源码解析之索引 索引结构 Kafka有两种类型的索引&#xff1a; TimeIndex: 根据时间戳索引&#xff0c;可以通过时间查找偏移量所在位置&#xff0c;目录下以.timeindex结尾Index: 根据偏移量索引&#xff0c;.index结尾 构建索引时机 由log.index.interval.bytes 参…

3. redis cluster集群运维与核心原理剖析

分布式缓存技术Redis 1. Redis集群方案比较2. Redis高可用集群搭建 本文是按照自己的理解进行笔记总结&#xff0c;如有不正确的地方&#xff0c;还望大佬多多指点纠正&#xff0c;勿喷。 课程内容&#xff1a; 1、哨兵集群与Redis Cluster架构异同 2、Redis高可用集群快速实…

2023/6/18总结

JS 在document.querySelectorAll(CSS选择器) 选到的集合并没有pop()和push()等数组的方法。是一个伪数组。 如果想要得到里面的每一个对象&#xff0c;需要用for遍历获得 document.getElementById(id名称) 根据id获取一个元素 document.getElementsByTagName(标签名字) 根…

Css面试题:css文字隐藏

文章目录 文字隐藏单行文字隐藏多行文字隐藏基于高度设置多行文字隐藏基于行数设置多行文字隐藏 文字隐藏 单行文字隐藏 主要是通过overflow&#xff0c;text-overflow&#xff0c;white-space三个属性实现。 overflow&#xff1a;visible|hidden|auto|scroll|inherit&#…

【c语言】-- 操作符汇总

&#x1f4d5;博主介绍&#xff1a;目前大一正在学习c语言&#xff0c;数据结构&#xff0c;计算机网络。 c语言学习&#xff0c;是为了更好的学习其他的编程语言&#xff0c;C语言是母体语言&#xff0c;是人机交互接近底层的桥梁。 本章来学习数组。 让我们开启c语言学习之旅…

简单认识web与http协议

文章目录 web基础域名概述DNS&#xff08;Domain Name System域名系统&#xff09; 域名空间结构 域名实际用法 2. 网页的概念2.1 网页&#xff08;HTTP/HTTPS&#xff09;HTML 概述HTML超文本标记语言 HTML文档的结构头标签中常用标签内容标签中常用标签Web概述具体组成web的主…

chatgpt赋能python:Python如何创建窗口——从入门到精通

Python如何创建窗口——从入门到精通 Python是一种高级编程语言&#xff0c;它的易读性和清晰简洁的语法使它成为许多人喜欢学习的编程语言之一。Python的一个主要特色是其丰富的库和模块。在本文中&#xff0c;我们将讨论如何使用Python创建一个窗口&#xff0c;并在其中添加…

【力扣刷题 | 第十一天】

前言&#xff1a; 我将会利用几天把树的经典例题都刷完&#xff0c;希望我可以坚持下去。 226. 翻转二叉树 - 力扣&#xff08;LeetCode&#xff09; 给你一棵二叉树的根节点 root &#xff0c;翻转这棵二叉树&#xff0c;并返回其根节点。 解题思路&#xff1a;我们交换每一…

C语言之运算符用法(补充前面运算符中的不足)

设定&#xff1a;int X20,Y10 1、算术运算符 注&#xff1a;自增和自减运算符只能用于变量&#xff0c;不可用于常量或表达式。另&#xff0c;X与X是不同的(–亦同)。以语句a[x]100;为例&#xff1a; a[X]100;执行之后得到&#xff1a;a[20] 100、X 21。//即&#xff0c;先执行…

Windows10下超详细Mysql安装

目录 0. 前言1. 下载mysql2. 开始安装3. 验证安装4. 环境变量配置 0. 前言 Mysql简介&#xff1a; MySQL是一种开源的关系型数据库管理系统&#xff08;RDBMS&#xff09;&#xff0c;它使用SQL&#xff08;结构化查询语言&#xff09;语言进行数据的存储和访问。MySQL的设计…

git版本管理入门(本地/远程仓库,常用命令)

目录 git简介 安装git 配置SSH key Linux环境下需要命令生成ssh key 本地git管理 多人协作流程 追加 重新提交 git命令 git commit本地和git push远程 git stash和git stash pop暂存 git status查看修改哪些了文件​ git diff 查看修改前后的差异 git log查看提交…

Centos7安装配置Docker

1. 什么是Docker 在开篇之前考虑到阅读人群,我觉得有必要向各位读者朋友简单介绍一下Docker是什么,它解决了什么问题&#xff1f;Docker是基于Go语言实现的云开源项目。它对此给出了一个标准化的解决方案-----系统平滑移植&#xff0c;容器虚拟化技术。让开发者可以打包他们的…

从加密到签名:如何使用Java实现高效、安全的RSA加解密算法?

目录 1. 接下来让小编给您们编写实现代码&#xff01;请躺好 ☺ 1.1 配置application.yml文件 1.2 RSA算法签名工具类 1.3 RSA算法生成签名以及效验签名测试 1.4 RSA算法生成公钥私钥、加密、解密工具类 1.5 RSA算法加解密测试 我们为什么要使用RSA算法来进行加解密&…

React之state详解

目录 执行过程 异步 React18与自动批处理 setState 推荐用法 ()>{return }&#xff0c;this.state. 生命周期 数据没改变时​不渲染 shouldComponentUpdate PureComponent自动&#xff08;推荐&#xff09; 你真的理解setState吗&#xff1f; - 掘金 组件的私有…

《Nature Aging》: 揭示皮肤衰老的分子机制

一个人衰老最直接的体现就是皮肤衰老。人体的皮肤一般从25&#xff5e;30岁以后即随着年龄的增长而逐渐衰老&#xff0c;大约在35&#xff5e;40岁后逐渐出现比较明显的衰老变化。但是&#xff0c;我们的皮肤为什么会衰老呢&#xff1f;要回答这个问题&#xff0c;我们首先要了…