【Pytorch】模型的可复现性

news2025/1/23 14:50:06

背景

在做研究的时候,通常我们希望同样的样本,同样的代码能够得到同样的实验效果,但由于代码中存在一些随机性,导致虽然是同样的样本和程序,但是得到的结果不一致。在pytorch的官方文档中为此提供了一些建议,原文档:REPRODUCIBILITY。下面我们来看看看具体的内容。

程序包的随机性

pytorch中的随机性

pytorch在一些操作具有随机性,如:torch.svd_lowrank(),我们可以使用torch.manual_seed()设置随机数种子来使得所有的设备(CPU和GPU)的随机性一致(本质来说现有的随机函数都是伪随机,都是通过随机数种子确定)。如:

import torch
torch.manual_seed(0)

python中的随机性

当然有时候我们的程序中可能还会使用python内建函数random,在程序中设置对应的随机数种子也是需要做的。即:

import random
random.seed(0)

python常用的第三方包-numpy

在数据处理中我们可能也会使用numpy,numpy中也存在响应的随机函数,使用方法如下:

import numpy as np
np.random.seed(0)

需要补充的是,可能在实际的编程中还会使用其他的具有随机操作的包,我们需要根据对应的包设置响应的随机数种子,以确保随机数种子的固定。源头就是设置随机数生成器(random number generator,rng)的随机数种子.

cuda的卷积操作

CUDA卷积操作所使用的cuDNN库,可能是一个应用程序多次执行的非确定性的来源。当用一组新的尺寸参数调用cuDNN卷积时,一个可选的功能可以运行多种卷积算法,对它们进行基准测试以找到最快的算法。然后,在剩下的过程中,最快的算法将被持续用于相应的尺寸参数集。由于基准测试的噪音和不同的硬件,该基准在随后的运行中可能会选择不同的算法,即使是在同一台机器上。

torch.backends.cudnn.benchmark = False 禁用基准测试功能会导致 cuDNN 确定性地选择一种算法,可能以降低性能为代价。

避免非决定性的算法

cuda卷积基准

torch.use_deterministic_algorithms()可以让您将PyTorch配置为使用确定性的算法,而不是使用非确定性的算法,如果已知某个操作是非确定性的(并且没有确定性的替代方法),则会抛出一个错误。这一点我们需要查看 torch.use_deterministic_algorithms()文档了解受影响操作的完整列表。
此外,如果你使用CUDA张量,并且你的CUDA版本是10.2或更高,你应该根据CUDA文档设置环境变量CUBLAS_WORKSPACE_CONFIG:https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
.我们可以使用如下代码进行设置:

torch.backends.cudnn.deterministic = True
# 或
torch.use_deterministic_algorithms(True) 

两者的区别是:后者可以使pytorch操作变为确定性,而前者只是控制函数的行为,例如CUDA convolution benchmarking,后者将其处理成确定性的算法,而前者只是让其每次选择同样的基准算法,而选择基准算法的算法依然是一个非确定行的算法。

cuda中的rnn和lstm

在CUDA的某些版本中,RNN和LSTM网络可能有非确定性的行为。例如在torch.nn.RNN的文档中介绍:
在这里插入图片描述
具体的处理方法上面也进行了介绍。

数据加载的随机性

DataLoader中有设置多线程处理数据的参数,由于多进程处理数据也会有先后完成的随机性可能会导致模型训练使用的batch语料不同影响着最终的模型结果,我们可以使用worker_init_fn()和生成器来保持可重复性。例如:

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

如果不想麻烦,在训练语料较少的时候可以考虑使用一个线程去处理数据也是可行的。

其他

当然网上也有博主设置hash操作的环境变量,固定hash的随机性。

os.environ['PYTHONHASHSEED'] = str(seed)

总结

总得来说,在pytorch中存在着一些随机性的操作这个隐藏的比较深,当然官方文档也进行了介绍,根据实际情况进行调整应该可以确保pytorch中的代码能够确定下来,其他的就可能是cuda的情况、数据加载以及使用的第三方包导致的随机性,需要根据情况意义排查了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/59077.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

复现 MMDetection

文章目录MMDetection 复现一、环境配置服务器信息安装CUDA下载并安装CUDA配置环境变量多个Cuda版本切换 (可选)安装CUDNN安装Anaconda搭建虚拟环境新建虚拟环境安装pytorchPycharm 远程连接代码同步配置服务器解释器二、训练和推理自制COCO格式数据集训练修改数据集相关参数修改…

Problem C: 算法10-10,10-11:堆排序

Problem Description 堆排序是一种利用堆结构进行排序的方法,它只需要一个记录大小的辅助空间,每个待排序的记录仅需要占用一个存储空间。 首先建立小根堆或大根堆,然后通过利用堆的性质即堆顶的元素是最小或最大值,从而依次得出…

TMS FixInsight代码评估工具

TMS FixInsight代码评估工具 TMS Fix Insight被认为是Delphi程序员的代码评估工具,它也能够在Delphi的源代码中发现问题。它被认为是一个代码分析工具,用于划分过程以及问题的位置以及Delphi的应用。TMS Fix Insight基本上是一个静态的代码列表&#xff…

Spring - SmartInstantiationAwareBeanPostProcessor扩展接口

文章目录Preorg.springframework.beans.factory.config.SmartInstantiationAwareBeanPostProcessor类关系SmartInstantiationAwareBeanPostProcessor接口方法扩展示例Pre Spring Boot - 扩展接口一览 org.springframework.beans.factory.config.SmartInstantiationAwareBeanPo…

HTML5期末大作业:基于HTML+CSS+JavaScript仿蘑菇街购物商城设计毕业论文源码

常见网页设计作业题材有 个人、 美食、 公司、 学校、 旅游、 电商、 宠物、 电器、 茶叶、 家居、 酒店、 舞蹈、 动漫、 服装、 体育、 化妆品、 物流、 环保、 书籍、 婚纱、 游戏、 节日、 戒烟、 电影、 摄影、 文化、 家乡、 鲜花、 礼品、 汽车、 其他等网页设计题目, A…

jdk11新特性——官方的更新列表

目录一、官方的更新列表二、JEP (JDK Enhancement Proposal 特性增强提议)一、官方的更新列表 二、JEP (JDK Enhancement Proposal 特性增强提议) JShell——(java9开始支持)Dynamic Class-File Constants类文件新添的一种结构局部变量类型推断(var关键字&#xff…

开荒手册3——构思一篇小论文

0 写在前面 又过了一个gap week,总算想清楚了之前遇到的一些问题,现在需要把之前画的大饼们一个一个消化掉。跳出来就会知道,总有一些something is wrong的人喜欢散播点焦虑,你要做的不是惩戒他们,而是赶紧远离&#…

windows下安装ubuntu linux子系统

windows下安装ubuntu linux子系统一、win10下安装ubuntu linux子系统二、下载ubuntu子系统三、启动ubuntu子系统四、配置ubuntu子系统一、win10下安装ubuntu linux子系统 但我们现在自己的主机上跑linux时,有几种选择 同时安装多个操作系统,每次重启电…

js 代码的运行机制

前言: 自己从一开始学习 javaScript 的时候,踩过很多很多坑,初学之路上也问过很多大佬许多为什么...现在回过头感叹,当时问的某些问题确实是有一丢丢幼稚。但是作为一个过来者,我深知这些问题的对于很多“后来者”来说…

tensorflow的模型持久化

参考 tensorflow的模型持久化 - 云社区 - 腾讯云 目录 1、持久化代码实现 2、持久化原理及数据格式 1、meta_info_def属性 2、graph_def属性 3、saver_def属性 4、collection_def属性 1、持久化代码实现 tensorflow提供了一个非常简单的API来保存和还原一个神经网络模型…

自主式模块化无人机设计

目 录 摘 要 I Abstract II 1 绪论 1 1.1 研究背景与意义 1 1.2 国内外研究现状 1 1.3 主要研究内容 2 2自主式模块化无人机的总体结构设计 3 2.1结构形式 3 2.2工作原理 3 2.3机架及桨叶的选择 5 2.3.1 单个桨叶空气动力分析及桨叶的选择 5 2.3.2材料的选择 6 2.3.3机架结构分…

【教学类-20-01】20221203《世界杯16强国旗》(大班)

展示效果: 单人使用样式: 多页打印样式 ​ 背景需求: 做《蒙德里安》格子画时,我把A4纸分割为正方形画框和长条纸支撑。活动中幼儿询问:为什么我的画站不起来?(底边剪的不平整、提手太重、画…

知识直播:时代乐见搜狐的长期主义选择

国内著名商业咨询顾问刘润说:“所有伟大的机会都源自于巨大的结构性改变。大成就背后,一定有涌动的、因商业逻辑巨变而释放出来的红利。” 这话用在当前的互联网行业身上再好不过。面对重重不确定性,如何拨开迷雾,看懂市场趋势&a…

HTTP到底是什么?

文章目录HTTP简介HTTP协议的特点1) 简单快速2) 灵活3) 无连接4) 无状态HTTP协议的发展历程1) HTTP/0.92) HTTP/1.03) HTTP/1.14) HTTP/2.0HTTP的工作流程HTTP简介 HTTP 全称为 Hypertext Transfer Protocol,翻译为中文是“超文本传输协议”的意思,它是互…

Java并发编程—volatile

文章目录volatile的应用volatile的定义与实现原理专业术语:volatile是如何来保证可见性的呢?volatile的原理:volatile的两条实现原则:(物理上如何实施)volatile的内存语义volatile的特性例:下面…

SpringBoot -集成Druid

文章目录Druid概述使用问题解决Spring监控不生效方式1:修改yml的配置写法方式2:参考DruidSpringAopConfiguration自行注入Bean,灵活(更建议)Druid 概述 官网: https://github.com/alibaba/druid   文档&a…

校园论坛(Java)—— 用户管理系统模块

校园论坛(Java)—— 用户管理系统模块 文章目录校园论坛(Java)—— 用户管理系统模块[toc]1、写在前面2、系统结构设计2.1 各个页面之间的调用关系2.2. 用户管理系统模块各层的设计3、管理员管理用户功能3.1 管理员查看普通用户的…

微服务框架 SpringCloud微服务架构 10 使用Docker 10.1 镜像命令

微服务框架 【SpringCloudRabbitMQDockerRedis搜索分布式,系统详解springcloud微服务技术栈课程|黑马程序员Java微服务】 SpringCloud微服务架构 文章目录微服务框架SpringCloud微服务架构10 使用Docker10.1 镜像命令10.1.1 镜像相关命令10.1.2 镜像操作命令10.1.…

SpringBoot_整合PageHelper

分页插件/PageHelper插件 我们在正常的查询业务之中,只需要加上一行代码就可以实现分页的数据的封装处理 实现原理 PageHelper方法使用了静态的ThreadLocal参数,分页参数和线程是绑定的。内部流程是ThreadLocal中设置了分页参数(pageIndex&#xff0c…

TypeScript21(装饰器Decorator)

Decorator 装饰器是一项实验性特性,在未来的版本中可能会发生改变 不仅增加了代码的可读性,清晰地表达了意图,而且提供一种方便的手段,增加或修改类的功能; 若要启用实验性的装饰器特性,你必须在命令行或…