Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

news2024/9/22 5:41:17

1 model.train() 和 model.eval()用法和区别

1.1 model.train()

model.train()的作用是启用 Batch Normalization 和 Dropout

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

1.2 model.eval()

model.eval()的作用是不启用Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。

在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。

1.3 分析原因

使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval。model.eval()时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!

# 定义一个网络
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # 实例化这个网络
    Model = Net()
     
    # 训练模式使用.train()
    Model.train(mode=True)
    
    # 测试模型使用.eval()
    Model.eval()

为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层(以dropout为例)。这项技术在训练中随机去除神经元。
在这里插入图片描述
想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。

2.model.eval()和torch.no_grad()的区别

1.在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:

主要用于通知dropout层和BN层在train和validation/test模式间切换:
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
2. 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。

而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1132290.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

时尚女童卫衣—Get麻麻同款和宝贝一起穿

分享女儿的时尚穿搭—卫衣 不知道各位姐妹那边的天气如何 我们这边现在穿卫衣刚刚好 这件卫衣百搭圆领,经典的额版哦 加绒卫衣特别保暖 卡通鹅的图案宝贝特别喜欢 和宝贝一起穿亲子装真的很幸福!

记录一次“超出内存限制”的原因

问题: 问题的来源是力扣的这一条题目:LCR 048. 二叉树的序列化与反序列化 - 力扣(LeetCode) 我寻思着也没啥,就前序遍历呗,时间和空间复杂度都是O(n),应该能把题目K掉。 问题代码: /*** Defini…

InCopy 2024 v19.0.0.151(文案撰写与编辑软件)

InCopy 2024 for Mac是一款文案撰写与编辑软件,适用于在Mac上进行文字编辑和排版工作。它是Adobe Creative Cloud套件中的一部分,主要面向文字编辑人员、作家和出版团队,提供了一系列强大的功能,以协同工作和改进工作流程。 InCo…

学校巡课工具,这次真包教包会!

随着教育领域的不断发展和科技的快速进步,在线巡课系统已经成为一种强大的工具,为学校、教育机构和教育者提供了全新的教学监督和评估方式。 客户案例 中学巡课项目 成都某中学面临着监督和改进教学质量的挑战。泛地缘科技推出的在线巡课系统&#xff0c…

欧拉图相关的生成与计数问题探究

最近学了一波国家集训队2018论文的最后一个专题。顺便带上了一些我的注解。 先放一波这个论文 1.基本概念 欧拉图问题是图论中的一类特殊的问题。在本文的介绍过程中,我们将会使用一些图 论术语。为了使本文叙述准确,本节将给出一些术语的定义。 定义…

嵌入式软件找实习需要掌握哪些知识?

嵌入式软件找实习需要掌握哪些知识? 嵌入式软件实习暂时不用刷力扣。具体需要掌握哪些知识每个公司的要求都不太一样,因为嵌入式也有很多细分领域,单片机、C语言是需要熟悉的,最近很多小伙伴找我,说想要一些嵌入式资料…

基于springboot实现篮球论坛管理系统项目【项目源码+论文说明】

基于springboot实现篮球论坛管理系统演示 摘要 首先,论文一开始便是清楚的论述了系统的研究内容。其次,剖析系统需求分析,弄明白“做什么”,分析包括业务分析和业务流程的分析以及用例分析,更进一步明确系统的需求。然后在明白了系统的需求基础上需要进一步地设计系统,主要包罗…

前端应用发布到nodejs server后浏览器刷新404问题

现象:Angular、Vue和React项目SPA应用如果路由模式为history,部署到服务器后,点击浏览器刷新按钮会出现404。 原因:当路由模式为history的时候,服务器端会根据浏览器中的请求地址去匹配资源,此时服务器端没…

安卓恶意应用识别(四)(特征处理与分类模型构建)——终结

前言 前面三章将数据初步整理出来: 1.安卓恶意应用识别(一)(Python批量爬取下载安卓应用) 2.安卓恶意应用识别(二)(安卓APK反编译) 3.安卓恶意应用识别(三&a…

一张Lambda大数据架构架构图

架构图 大数据的架构包括了Lambda架构和Kappa架构,Lambda架构分解为三层:即批处理层、加速层和服务层;Kappa架构不同于Lambda同时计算流计算和批计算并合并视图,Kappa只会通过流计算一条的数据链路计算并产生视图。 该系统的大数…

1024程序员节|是时候,展示真正的实力了!

“代码改变世界,开源创造未来” 一年一度的程序员节如期而至! 在由“0”和“1”构建的二进制世界里 运行程序的硬件进制基础是“1024” 因此,将每一年的10月24日定为“程序员节” 而作为码出每一个1024的程序员 则是这个“二进制世界”里…

*Django中的Ajax jq的书写样式1

导入插件,导入jquery,json是添加的json文件 Ajax的get请求与post请求 urls.py path(in3/,views.in3), views.py def in3(request):return render(request,07.html) 要返回数据的path没有写,html就是下面图片中控制台的内容,记得传递参数…

图片gif怎么做?这一招分分钟制作

在现代的通讯工具中,gif动态表情已经是人们日常必不可少的一种交流工具了。当我们想要自己制作gif动画图片的时候该怎么办呢?这时候,只需要使用GIF中文网的gif图片制作(https://www.gif.cn/)功能,上传多张静…

MySQL恢复不小心误删的数据记录(binlog)-生产实操

同事操作删除生产数据时,未及时备份导致误删除部份数据。现通过mysql的binlog日志回滚数据。 1.由于我们第一时间找到相应的操作时间的大概在12点~13点之间操作删除了product和product_item表的部份数据。因此我们只要找到这期间的binlog文件就行了, 脚…

【字符串】【字符串和字符数组的相互转化】Leetcode 541 反转字符串 II

【字符串】【toCharArray & String st new String(ch)】Leetcode 541 反转字符串 II 字符串和字符数组的相互转化:star:解法1 字符串和字符数组的相互转化⭐️ 解法1 时间复杂度O(N) 这个解法的时间复杂度是O(N),其中N是字符串的长度。…

美颜SDK集成指南:为应用添加视频美颜功能

随着社交媒体和直播应用的兴起,视频美颜功能已成为用户追求的一项热门特性。用户希望能够在拍摄照片或进行实时视频直播时,使用美颜功能来增强其外观。为了满足这一需求,开发者可以考虑集成美颜SDK,为其应用增加这一吸引人的功能。…

花生供应链中的沙门氏菌、大肠杆菌和肠杆菌: 从农场到餐桌

1.1Title:Salmonella, Escherichia coli and Enterobacteriaceae in the peanut supply chain: From farm to table 1.2 Author: Nascimento M. S. 1.3 机构:Campinas University 1.4 分区/影响因子:Q1/6.475 1.5 期刊:FOOD RESEARCH INT…

linux,windows命令行输出控制指令,带颜色的信息,多行刷新,进度条效果,golang

一、带颜色的信息 linux 颜色及模式编号 // 前景 背景 颜色 // --------------------------------------- // 30 40 黑色 // 31 41 红色 // 32 42 绿色 // 33 43 黄色 // 34 44 蓝色 // 35 45 紫红色 // 36 46 青蓝色 // 37 47 白色 // // 模式代码 意义 //…

接口测试到底怎么做,5分钟时间看完这篇文章彻底搞清楚!

01、通用的项目架构 02、什么是接口 接口:服务端程序对外提供的一种统一的访问方式,通常采用HTTP协议,通过不同的url,不同的请求类型(GET、POST),不同的参数,来执行不同的业务逻辑。…

SpringBoot中对Spring AOP的实现

文章目录 SpringBoot中对Spring AOP的实现AOP简介引入依赖AOP体系与概念编写AOP切面类启动SpringBoot项目然后访问controller控制器对环绕通知放行execution表达式的含义通过注解方式定义切点 SpringBoot中对Spring AOP的实现 AOP简介 AOP (Aspect Oriented Programming), 面…