神经网络学习小记录74——Pytorch 设置随机种子Seed来保证训练结果唯一

news2024/11/16 11:29:13

神经网络学习小记录74——Pytorch 设置随机种子Seed来保证训练结果唯一

  • 学习前言
  • 为什么每次训练结果不同
  • 什么是随机种子
  • 训练中设置随机种子

学习前言

好多同学每次训练结果不同,最大的指标可能会差到3-4%这样,这是因为随机种子没有设定导致的,我们一起看看怎么设定吧。
在这里插入图片描述

为什么每次训练结果不同

模型训练中存在很多随机值,最常见的有:
1、随机权重,网络有些部分的权重没有预训练,它的值则是随机初始化的,每次随机初始化不同会导致结果不同。
2、随机数据增强,一般来讲网络训练会进行数据增强,特别是少量数据的情况下,数据增强一般会随机变化光照、对比度、扭曲等,也会导致结果不同。
3、随机数据读取,喂入训练数据的顺序也会影响结果。
……
应该还有别的随机值,这里不一一列出,这些随机都很容易影响网络的训练结果。

如果能够固定权重、固定数据增强情况、固定数据读取顺序,网络理论上每一次独立训练的结果都是一样的。

什么是随机种子

随机种子(Random Seed)是计算机专业术语。一般计算机的随机数都是伪随机数,以一个真随机数(种子)作为初始条件,然后用一定的算法不停迭代产生随机数。

按照这个理解,我们如果可以设置最初的 真随机数(种子),那么后面出现的随机数将会是固定序列。

以random库为例,我们使用如下的代码,前两次为随机生成,后两次为设置随机数生成器种子后生成。

import random

# 生成随机整数
print("第一次随机生成")
print(random.randint(1,100))
print(random.randint(1,100))

# 生成随机整数
print("第二次随机生成")
print(random.randint(1,100))
print(random.randint(1,100))

# 设置随机数生成器种子
random.seed(11)

# 生成随机整数
print("第一次设定种子后随机生成")
print(random.randint(1,100))
print(random.randint(1,100))

# 重置随机数生成器种子
random.seed(11)

# 生成随机整数
print("第二次设定种子后随机生成")
print(random.randint(1,100))
print(random.randint(1,100))

结果如下,前两次随机生成的序列不同,后两次设定种子后随机生成的序列相同:

第一次随机生成
66
37
第二次随机生成
93
56
第一次设定种子后随机生成
58
72
第二次设定种子后随机生成
58
72

训练中设置随机种子

一般训练会用到多个库包含有关random的内容。

在pytorch构建的网络中,一般都是使用下面三个库来获得随机数,我们需要对三个库都设置随机种子:
1、torch库;
2、numpy库;
3、random库。

在这里写了一个函数:

#---------------------------------------------------#
#   设置种子
#---------------------------------------------------#
def seed_everything(seed=11):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

这里面写到了cuda、cudnn这类gpu才会用到的东西,实测发现cpu版本的pytorch也可以正常运行。

torch.backends.cudnn.deterministic=True用于保证CUDA 卷积运算的结果确定。
torch.backends.cudnn.benchmark=False是用于保证数据变化的情况下,减少网络效率的变化。为True的话容易降低网络效率。

只需要在所有初始化前,调用该seed初始化函数即可。

另外,Pytorch一般使用Dataloader来加载数据,Dataloader一般会使用多worker加载多进程来加载数据,此时我们需要使用Dataloader自带的worker_init_fn函数初始化Dataloader启动的多进程,这样才能保证多进程数据加载时数据的确定性。

#---------------------------------------------------#
#   设置Dataloader的种子
#---------------------------------------------------#
def worker_init_fn(worker_id, rank, seed):
    worker_seed = rank + seed
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/734840.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电源频率检测器/采用555时基电路的过流检测器电路设计

电源频率检测器 对于某些电子仪器和电气设备,对见六电源的频率有着一定的要求,电源频率高于或低于 50Hz,都会影响设备的正常工作,甚至造成仪器和设备的损坏。因此,对于此类设备需要装设电源频率检测装置,当…

Linux开发工具之【vim】

Linux开发工具之【vim】 文章目录: Linux开发工具之【vim】1. Linux软件包管理器yum1.1 查看软件1.2. 下载软件1.3 卸载软件 2. vim编辑器的使用2.1 vim常用模式2.2 vim基本操作2.3 vim命令模式命令集2.3.1 移动光标2.3.2 删除文字2.3.3 复制文本内容2.3.4 替换文本…

Openpose原理总结

Openpose是一种开源的实时多人姿态估计库,由卡耐基梅隆大学开发。它通过分析图像或视频中的人体关键点来估计人体的姿态,识别身体的各个部分,并推断出人体的姿势信息。 Openpose能够同时检测和跟踪多个人的姿态,可以用于人机交互、…

二.《泽诺尼亚》明文CALL

了解发包函数 1.一款网络游戏,必定是会发包的,对于PC端而言,想要进行网络通讯,就拿最简单的CS架构而言势必会调用win32 API函数或底层函数 2.这里列举出常用的API如:send sendto wsasend wspsend 3.有一些正向开发经验的同学,肯定是见过这几个函数的 4.接下来我们来看看这款…

JavaEE——介绍并简单使用线程池

文章目录 一、 什么是线程池二、Java中线程池的运用1. 创建线程池中的问题2. 标准库中线程池的使用 三、自主实现一个简单的线程池 一、 什么是线程池 所谓线程池,其实和字符串常量池,数据库连接池十分相似,就是设定一块区域,提前…

打印机常见故障解决参考方法

1、首先检查打印机电源线连接是否可靠或电源指示灯是否点亮,然后再次打印文件,仍不能打印,请看下一步。 2、检查打印机与计算机之间的信号传输线是否可靠连接,检查并重新连接,如果打印机仍不能打印,请看下一…

Java线程Thread类常用方法

文章目录 1. start():启动线程,使其执行run()方法中的代码。2. run():线程的执行逻辑,需要在该方法中定义线程要执行的代码。3. sleep(long millis):使当前线程暂停指定的毫秒数,进入阻塞状态。4. join()&a…

【C++】红黑树封装map和set

文章目录 一、map和set源码剖析二、红黑树的迭代器1.begin()与end()2.operator()与operator--() 三、set的模拟实现四、map的模拟实现五、完整代码实现1.RBTree.h2.set.h3.map.h5.Test.cpp 一、map和set源码剖析 我们知道,map和set的底层是红黑树,但是我…

如何用Python快速搭建一个文件传输服务

当我的朋友需要把他电脑上面的文件从他的电脑传递到我电脑上的时候,我只需要启动服务 启动服务! 他打开web界面 就能把文件传递到我电脑上(还能够实时显示进度) 文件就已经在我电脑上的uploads文件夹里面了 项目结构如下 templat…

python浮点运算不准确

1.问题 1.12.2的最后结果并不等于3.3 2. 解决方法 错误示范 引入了Decimal计算,但是计算类型是float型,依然计算不准确 正确解决方法 把类型转化为字符串引入计算

动手实战 | 使用 Transformers 包进行概率时间序列预测

最近使用深度学习进行时间序列预测而不是经典方法涌现出诸多创新。本文将为大家演示一个基于 HuggingFace Transformers 包构建的概率时间序列预测的案例。 概率预测 通常,经典方法针对数据集中的每个时间序列单独拟合。然而,当处理大量时间序列时&…

spring中的扩展点解析以及实践使用

文章目录 1、ApplicationContextInitializer2、BeanDefinitionRegistryPostProcessor3、BeanFactoryPostProcessor4、InstantiationAwareBeanPostProcessor5、SmartInstantiationAwareBeanPostProcessor6、BeanFactoryAware7、ApplicationContextAwareProcessor8、BeanNameAwar…

查找文件所在的具体位置

Linux Command 命令: find – walk a file hierarchy (遍历文件层次结构) ;DESCRIPTION 描述: 在指定目录下查找文件和目录, 可以使用不同的选项来过滤和限制查找的结果 ; Grammar Format $ find <在哪个路径下查找> <可选参数…> 常用选项 -name <pattern>…

【javaEE面试题(四)线程不安全的原因】【1. 修改共享数据 2. 操作不是原子性 3. 内存可见性 4. 代码顺序性】

4. 多线程带来的的风险-线程安全 (重点) 4.1 观察线程不安全 static class Counter {public int count 0;void increase() {count;} } public static void main(String[] args) throws InterruptedException {final Counter counter new Counter();Thread t1 new Thread(()…

VMware16.0安装教程和创建

许可证&#xff1a; ZF3R0-FHED2-M80TY-8QYGC-NPKYFYF390-0HF8P-M81RQ-2DXQE-M2UT6ZF71R-DMX85-08DQY-8YMNC-PPHV8设置网络 添加镜像 下载centos7镜像网址https://mirrors.aliyun.com/centos/7/isos/x86_64/?spma2c6h.25603864.0.0.d7724511YPrZpg win10镜像地址https://ww…

Linux+Docker+Gitlab+Jenkins+虚拟内存

最近想研究一下怎么自动化发布项目,于是找到了gitlab+jenkins这个组合,正好借机也研究一下最近很火的docker技术。本篇共分为五部分,分别为安装要求,内存虚拟化,安装docker,安装gitlab,安装jenkins。 一、 安装要求 1 Docker安装要求: 1.1 操作系统 Docker只支持64…

unittest单元测试

java的单元测试框架Junit和TestNG&#xff0c;python里面也有单元测试框架-unittest,相当于是一个python版的junit。python里面的单元测试框架除了unittest,还有一个pytest框架&#xff0c;但是用的比较少 unittest注意点&#xff1a; 导入unittest模块 类名的第一个字母大写&…

代码随想录算法学习心得 40 | 139. 单词拆分、背包问题总结...

一、单词拆分 链接&#xff1a;力扣 描述&#xff1a;给你一个字符串 s 和一个字符串列表 wordDict 作为字典。请你判断是否可以利用字典中出现的单词拼接出 s 。注意&#xff1a;不要求字典中出现的单词全部都使用&#xff0c;并且字典中的单词可以重复使用。 思路如下&…

【Linux | Shell】Linux 安全系统 —— 用户、组、文件权限 - 阅读笔记

目录 一、Linux 的安全性1.1 /etc/passwd 文件1.2 /etc/shadow 文件1.3 添加新用户 —— useradd1.4 删除用户 —— userdel1.5 修改用户 —— usermod、passwd、chpasswd 二、使用 Linux 组2.1 /etc/group 文件2.2 创建新组 —— groupadd2.3 修改组 —— groupmod 三、理解文…

Jenkins可持续集成Python自动化脚本

目录 前言 一、Jenkins搭建在Windows上 二、Jenkins搭建在Linux上 &#x1f381;更多干货 完整版文档下载方式&#xff1a; 本文讲解Jenkins如何每次定时的从SVN服务器上拉取最新的代码并执行本地库里的脚本 前言 1、本地代码库目录F:\5i5jautest内有测试文件all_tests.…