【深度学习实战(33)】训练之model.train()和model.eval()

news2025/1/12 12:11:21

一、model.train(),model.eval()作用?

model.train() 和 model.eval() 是 PyTorch 中的两个方法,用于设置模型的训练模式和评估模式。

model.train() 方法将模型设置为训练模式。在训练模式下,模型会启用 dropout 和 batch normalization 等正则化方法,并且可以计算梯度以进行参数更新,同时还可以追踪梯度计算的图。训练时,均值、方差分别是该批次内数据相应维度的均值与方差

model.eval() 方法将模型设置为评估模式。在评估模式下,模型会禁用 dropout 和 batch normalization 等正则化方法,这样可以保证每次评估的结果是确定的。评估模式下的模型通常用于模型的测试、验证或推理阶段。推理时,均值、方差是基于所有批次的期望计算所得

区分训练模式和评估模式的目的在于保证模型在不同阶段的行为一致性。例如,在训练模式下,模型需要计算并追踪梯度以进行反向传播和参数更新;而在评估模式下,模型不需要计算梯度,只需要给出确定的预测结果。

二、model.train(),model.eval()对dropout产生的影响

使用model.train():有神经元被置零,且比例符合nn.Dropout(0.5)中的0.5设定

import torch
import torch.nn as nn

model = nn.Dropout(0.5)
model.train()
input = torch.rand([3, 4])

print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述
使用model.eval():没有神经元置零,nn.Dropout(0.5)被关闭

import torch
import torch.nn as nn

model = nn.Dropout(0.5)
#model.train()
model.eval()
input = torch.rand([3, 4])

print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

不使用model.train()和model.eval():有神经元被置零,但是比例非常随机,不符合nn.Dropout(0.5)中的0.5设定
import torch
import torch.nn as nn

model = nn.Dropout(0.5)
#model.train()
#model.eval()
input = torch.rand([3, 4])

print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

三、model.train(),model.eval()对batch normalization产生的影响

使用model.eval():bn中的均值,方差,不发生改变

# 1.导入所需的库:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


# 2.定义数据集的转换方法。MNIST数据集是由28x28像素的手写数字组成的图像,将其转换为torch张量并进行标准化处理:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# 3.下载MNIST数据集并进行转换:
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)

# 4.创建数据加载器:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=0)

# 5.现在你可以使用trainloader和testloader来获取训练集和测试集的批次数据了。例如,可以使用迭代器遍历数据集中的批次:
#dataiter = iter(trainloader)
#images, labels = dataiter.next()

# 上述代码将返回一个批次的图像和对应的标签。可以使用images和labels来进行模型的训练和评估。
# 这就是使用torch库自带的MNIST数据集的基本流程。根据需要,你还可以添加其他的数据处理和增强步骤。


# 定义模型
class Model(nn.Module):
    def __init__(self, hidden_num=32, out_num=10):
        super().__init__()
        self.fc1 = nn.Linear(28*28, hidden_num)
        self.bn  = nn.BatchNorm1d(hidden_num)
        self.fc2 = nn.Linear(hidden_num, out_num)
        self.softmax = nn.Softmax()
    def forward(self, inputs, **kwargs):
        x = inputs.flatten(1)
        x = self.fc1(x)
        
        print("========= bn之前存的数据: =========")
        print(self.bn.running_mean, self.bn.running_var)
        print()
        

        print("========= 当前 Batch 的数据: =========")
        x_mean = torch.mean(x,0)
        x_variance = torch.mean((x - x_mean)*(x - x_mean),0)
        print(x_mean, x_variance)
        print()
        

        print("========= torch官方计算之后的bn新数据: =========")
        x = self.bn(x)
        print(self.bn.running_mean, self.bn.running_var)
        print()
        
        # x = self.dropout(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x
    
torch.manual_seed(1)
model = Model()
#model.train()
model.eval()
for img, label in trainloader:
    label = nn.functional.one_hot(label.flatten(), 10)
    out = model(img)
    break

在这里插入图片描述
使用model.train():bn中的均值,方差,通过滑动平均地方式发生改变,

torch.manual_seed(1)
model = Model()
model.train()
#model.eval()
for img, label in trainloader:
    label = nn.functional.one_hot(label.flatten(), 10)
    out = model(img)
    break

在这里插入图片描述
不使用model.train()和model.eval():默认bn中的均值,方差,通过滑动平均地方式发生改变,
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1648871.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SinoDB SQL管理工具之-DBeaver安装使用说明

本文介绍如何使用DBeaver管理工具连接SinoDB数据库。 1. DBeaver下载 下载地址:Download | DBeaver Community 请根据需求选择对应自己操作系统的版本进行下载。本次示例使用Windows 64位操作系统进行安装配置。下载版本为:dbeaver-ce-23.0.2-x86_64-…

什么是SOL链跟单机器人与阻击机器人?

SOL链作为一个快速增长的区块链生态系统,为各种应用程序提供了丰富的发展机会。在SOL链上,智能合约的应用已经开始蓬勃发展,其中包括了许多与加密货币交易相关的应用。在本文中,我们将介绍在SOL链上开发的阻击机器人(S…

42.乐理基础-拍号-看懂拍号的意义

到这必然是已经知道 X、Y的意思了: 然后带入数字: 然后念拍号的时候,在国内,百分之九十的地方是从下往上念,念作四二拍,还有百分之十的地方是和国外一样,从上往下念,念作二四拍&…

DigitalOcean 应用托管平台级更新:应用端到端运行时性能大幅改进

DigitalOcean 希望可以为企业提供所需的工具和基础设施,以帮助企业客户加速云端的开发,实现业务的指数级增长。为此 DigitalOcean 在 2020 年就推出了App Platform。 App Platform(应用托管) 是一个完全托管的 PaaS 解决方案&…

如何自己快速的制作流程图?6个软件教你快速进行流程图制作

如何自己快速的制作流程图?6个软件教你快速进行流程图制作 自己制作流程图可以是项目管理、流程设计或教学展示中的重要环节。以下是六款常用的流程图制作软件,它们都提供了快速、简单的方式来制作流程图: 迅捷画图:这是一款非…

Java基础(三):Java异常机制以及底层实现原理

🌷一、异常 ☘️1.1 什么是异常 Java异常是程序发生错误的一种处理机制,异常的顶级类是Throwable,Throwable字面意思就是可抛出的,该类是所有的错误和异常的超类,只有Throwable类或者Throwable子类的实例对象才可以被…

超详细——集成学习——Adaboost实现多分类——附代码

资料参考 1.【集成学习】boosting与bagging_哔哩哔哩_bilibili 集成学习——boosting与bagging 强学习器:效果好,模型复杂 弱学习器:效果不是很好,模型简单 优点 集成学习通过将多个学习器进行结合,常可获得比单一…

Xinstall广告效果监测,助力广告主优化投放策略

在移动互联网时代,APP推广已成为企业营销的重要手段。然而,如何衡量推广效果,了解用户来源,优化投放策略,一直是广告主和开发者面临的难题。这时,Xinstall作为国内专业的App全渠道统计服务商,以…

TCP四次挥手分析

TCP四次挥手分析 概念过程分析为什么连接的时候是三次握手,关闭的时候却是四次握手?为什么要等待2MSL? 概念 四次挥手即终止TCP连接,就是指断开一个TCP连接时,需要客户端和服务端总共发送4个包以确认连接的断开。 在…

有关string的部分接口

1.迭代器与反向迭代器(iterator-) 迭代器是可以用来访问string里面的内容的,这里来记录一下使用的方法。 里面用到了一个叫做begin函数和一个end函数,这两个都是针对string使用的函数。 s1.begin()函数是指向string内容的第一个元素 而s1.end()指向的则…

「新媒体营销必备」短链接生成,让你的内容更易传播!

在信息大爆炸的今天,无论是企业还是个人都需要有一个快速有效的方式让信息传播。而短链接生成的出现,为我们带来了极大的便利。 C1N短网址(c1n.cn)是一家致力于为用户提供快速、安全的短链接服务的公司。作为专注于短链接的品牌&…

Windows Server 2019虚拟机安装

目录 第一步、准备工作 第二步、部署虚拟机 第三步、 Windows Server 2019系统启动配置 第一步、准备工作 下载Windows Server 2019系统镜像 官网下载地址:Windows Server 2019 | Microsoft Evaluation Center VMware Workstation 17下载地址: 链…

阿里云国际服(alibabacloud)介绍、注册、购买教程?

一、什么是阿里云国际版? 阿里云分为国内版和国际版。国内版仅面向中国大陆客户,国际版面向全球客户。 二、国际版与国内版有何异同? 1)异:除了目标客户不同,运营主体不同,所需遵守的法律与政…

暗区突围pc端下载教程 暗区突围pc端怎么下载

暗区突围pc端下载教程 暗区突围pc端怎么下载 《暗区突围》是一款刺激的第一人称射击游戏。目前pc版本要上线了,即将在5月正式上线。在这款游戏里,我们会在随机的时间、地点,拿着不一定的装备,跟其他玩家拼个高低,还需…

(十六)Servlet教程——Servlet文件下载

Servlet文件下载 文件下载是将服务器上的资源下载到本地,可以通过两种方式来下载服务器上的资源。第一种是使用超链接来下载,第二种是通过代码来下载。 超链接下载 在HTML或者JSP页面中使用超链接时,可以实现页面之间的跳转,但是…

开发环境虚拟环境学习记录

1、VS Code搭建python环境 下载好Visual Studio Code后,首先需要进入Visual Studio Code并安装支持python开发的插件: 2、虚拟环境 2.1、初识虚拟环境 概述:①、在使用Python语言的时候我们使用pip来安装第三方包,但是由于pip的…

Leetcode—138. 随机链表的复制【中等】

2024每日刷题(129) Leetcode—138. 随机链表的复制 实现代码 /* // Definition for a Node. class Node { public:int val;Node* next;Node* random;Node(int _val) {val _val;next NULL;random NULL;} }; */class Solution { public:Node* copyRan…

Linux动态库与静态库解析

文章目录 一、引言二、C/C源文件的编译过程三、静态库1、静态库的定义和原理2、静态库的优缺点3、静态库的创建和使用a、创建静态库b、使用静态库 四、动态库1、动态库的定义和原理2、动态库的优缺点3、动态库的创建和使用示例a、创建动态库b、使用动态库 五、动静态库的比较 一…

【Python小技巧】matplotlib不显示图像竟是numpy惹的祸

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、问题:df.plot() 显示不出图像二、尝试各种解决办法1. 增加matplotlib.use,设定GUI2. 升级matplotlib版本 三、numpy是个重要的库1. …

详解MySQL常用的数据类型

前言 MySQL是一个流行的关系型数据库管理系统,它支持多种数据类型,以满足不同数据处理和存储的需求。理解并正确使用这些数据类型对于提高数据库性能、确保数据完整性和准确性至关重要。本文将详细介绍MySQL中的数据类型,包括数值类型、字符…