9 从0开始学PyTorch | 过拟合欠拟合、训练集验证集、关闭自动求导

news2025/1/7 6:33:37

这一小节在开始搞神经网络之前,我们先熟悉几个概念,主要还是把模型训练的流程打通。

过拟合和欠拟合

我们在日常的工作中,训练好的模型往往是要去评价它的准确率的,通过此来判断我们的模型是否符合我的要求。
几个可能的方案是,对我们训练使用的数据再输入到训练好的模型中,查看输出的结果是否跟预期的结果是一致的,当然这个在我们的线性模型上跟训练过程没有区别。另外一个比较靠谱的方案是把一部分在训练的时候没有用过的数据放进模型里,看预测结果是否和预期结果一致。

过拟合(overfitting):对于上述两个方案获得的结果,一种情况是在训练用的数据上表现良好,但是对于新数据预测的结果比较差,这时候就是过拟合了,模型学到了训练数据上太多的细节,导致模型的泛化能力变差。
欠拟合(underfitting):另外一个可能的情况是,不光在新数据上表现不好,就在训练数据上表现也不好,这种情况就是欠拟合,连训练数据的特点都没学好。
如下图中画的,左边的模型算是比较好的,中间的模型就是欠拟合,只学到了上半部分数据的特征,而右边那副图就是过拟合。对于处理过拟合和欠拟合问题,有很多解决方案,比如说增加数据,增加迭代轮次,调整参数,增加噪声,随机丢弃等等,这里我们先不纠缠这个问题。

image.png

训练集和验证集

关于上面提到的两份数据,我们就可以称为训练集和验证集,当然有些时候还有一个叫测试集,有时候认为测试集介于训练集和验证集之间,也就是拿训练集去训练模型,使用测试集测试并进行调整,最后用验证集确定最终的效果。在这本书上只写了训练集和验证集,所以我们这里也先按照这个思路来介绍。

image.png

正如上图绘制的那样,在原始数据到来的时候,把它分成两份,一份是训练集,一份是验证集。训练集用来训练模型,当模型迭代到一定程度的时候,我们使用验证集输入到训练好的模型里,评估模型的表现。

torch.randperm方法:将0~n-1(包括0和n-1)随机打乱后获得的数字序列,函数名是random permutation缩写
下面用代码来实现一下

n_samples = t_u.shape[0] #获取样本数量
n_val = int(0.2 * n_samples) #验证集的数量,取全集的20%,这里是2个

shuffled_indices = torch.randperm(n_samples) #打乱顺序

train_indices = shuffled_indices[:-n_val] #训练集位置信息
val_indices = shuffled_indices[-n_val:] #验证集位置信息

train_indices, val_indices  
outs:(tensor([2, 5, 9, 8, 6, 1, 4, 3, 7]), tensor([10,  0]))

紧接着是获取训练数据和验证数据

train_t_u = t_u[train_indices]
train_t_c = t_c[train_indices]

val_t_u = t_u[val_indices]
val_t_c = t_c[val_indices]

train_t_un = 0.1 * train_t_u
val_t_un = 0.1 * val_t_u

定义训练方法,这些跟之前的都差不多

def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u,
                  train_t_c, val_t_c):
    for epoch in range(1, n_epochs + 1):
        train_t_p = model(train_t_u, *params) # <1>
        train_loss = loss_fn(train_t_p, train_t_c)
                             
        val_t_p = model(val_t_u, *params) # <1>
        val_loss = loss_fn(val_t_p, val_t_c)
        
        optimizer.zero_grad()
        train_loss.backward() 
        optimizer.step()

        if epoch <= 3 or epoch % 500 == 0:
            print(f"Epoch {epoch}, Training loss {train_loss.item():.4f},"
                  f" Validation loss {val_loss.item():.4f}")
            
    return params

params = torch.tensor([1.0, 0.0], requires_grad=True)
learning_rate = 1e-2
optimizer = optim.SGD([params], lr=learning_rate)

training_loop(
    n_epochs = 3000, 
    optimizer = optimizer,
    params = params,
    train_t_u = train_t_un, # <1> 
    val_t_u = val_t_un, # <1> 
    train_t_c = train_t_c,
    val_t_c = val_t_c)
out:
Epoch 1, Training loss 91.7660, Validation loss 29.0568
Epoch 2, Training loss 43.7766, Validation loss 2.3025
Epoch 3, Training loss 36.0900, Validation loss 3.5195
Epoch 500, Training loss 7.0920, Validation loss 4.6118
Epoch 1000, Training loss 3.4116, Validation loss 4.0901
Epoch 1500, Training loss 2.9273, Validation loss 3.9970
Epoch 2000, Training loss 2.8636, Validation loss 3.9759
Epoch 2500, Training loss 2.8552, Validation loss 3.9700
Epoch 3000, Training loss 2.8541, Validation loss 3.9680
tensor([  5.4240, -17.2490], requires_grad=True)

从上面的结果可以看到,训练集损失持续下降,验证集损失前期波动比较大,这可能是因为我们的验证集数量太少导致的,不过在500代以后训练损失和验证损失都趋于稳定。
这里作者给出了几个对比训练损失和验证损失的图片,很有意思。其中蓝色实线是训练损失,红色虚线是验证损失。对于图A,训练损失和验证损失随着训练轮次的增长都没啥变化,表明数据并没有提供什么有价值的信息;图B中,随着训练轮次增加,训练损失逐步下降,而验证损失逐步上升,这说明出现了过拟合现象;C图中验证损失和训练损失同步下降,是一种比较理想化的模型效果;D图中验证损失和训练损失也是同步下降,但是训练损失下降幅度更大一些,这种情况显示存在一定的过拟合,但是仍在可以接受的范围内。

image.png

关闭自动求导

在上面的过程中,我们涉及到一个问题,就是对于验证损失计算完以后,我们并没有调用backward(),那是因为我们只想用验证集数据来检查模型效果,而不希望验证集数据影响我们的模型训练,不然的话就相当于验证集数据也加入了训练,那就很难判断模型是否存在过拟合了。就像下图所写的,使用模型预测和计算损失的步骤是一样的,但是只对train_loss进行反向传播。

image.png

因此在验证过程中,我们实际不需要进行自动求导,但是如果我们前面都设置了自动求导怎么办呢,这会带来大量不必要的运算开销。于是PyTorch提供了关闭自动求导的方法,就是使用torch.no_grad()。

def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u,
                  train_t_c, val_t_c):
    for epoch in range(1, n_epochs + 1):
        train_t_p = model(train_t_u, *params)
        train_loss = loss_fn(train_t_p, train_t_c)

        with torch.no_grad():  # 上下文管理器,关闭自动求导
            val_t_p = model(val_t_u, *params)
            val_loss = loss_fn(val_t_p, val_t_c)
            assert val_loss.requires_grad == False
            
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

这里还有另外一个方式,就是使用set_grad_enabled(),这个方法接收一个bool类型的参数,来设置是否自动求导。

def calc_forward(t_u, t_c, is_train): 
    with torch.set_grad_enabled(is_train): #这里传入了是否是训练这样一个bool类型来显示当前的前向传播是训练还是验证
        t_p = model(t_u, *params)
        loss = loss_fn(t_p, t_c)
    return loss

今天写的比较短,感觉轻松多了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/693051.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

国外学位论文去哪里查找下载

查找下载国外博士论文最合适的文献数据库就是ProQuest学位论文全文数据库。 ProQuest学位论文全文数据库覆盖了大部分北美地区高等院校以及世界其他地区数千个高等院校每年获得通过的博硕士论文。是将ProQuest公司PQDD文摘库&#xff08;现名PQDT&#xff09;中适合中国科研人…

Windows 11 22H2 中文版、英文版 (x64、ARM64) 下载 (updated Jun 2023)

Windows 11 绕过 TPM 方法总结&#xff0c;通用免 TPM 镜像下载 (2023 年 6 月更新) 在虚拟机、Mac 电脑和 TPM 不符合要求的旧电脑上安装 Windows 11 的通用方法总结 请访问原文链接&#xff1a;https://sysin.org/blog/windows-11-no-tpm/&#xff0c;查看最新版。原创作品…

nuxt3 多级动态路由

需求&#xff1a; 写法&#xff1a; 对应 文件目录 pages\product\[class]\[brand]\[SPU].vue pages/ --| product/ ----| [class] ------| [brand] --------| [SPU].vue script 内跳转方法 const router useRouter() const nuxtApp useNuxtApp()const jumpSPU () >…

caffeine和google-guava cache缓存使用详解和源码介绍

google-guava cache 1.pom引入其依赖 <dependency><groupId>com.google.guava</groupId><artifactId>guava</artifactId><version>20.0</version></dependency> 2.具体使用 com.google.common.cache.LoadingCache<Strin…

【selenium】问题记录

1、驱动和浏览器版本不一致 报错&#xff1a;selenium.common.exceptions.SessionNotCreatedException: Message: session not created: This version of ChromeDriver only supports Chrome version 106 问题原因&#xff1a; chrome版本114&#xff0c;Chromedriver版本106 …

机器学习之深度神经网络

目录 卷积神经网络与全连接神经网络 前向后向传播推导 通用手写体识别模型 人脸识别模型 电影评论情感分析模型 卷积神经网络与全连接神经网络 卷积神经网络&#xff08;Convolutional Neural Network&#xff0c;CNN&#xff09;和全连接神经网络&#xff08;Fully Conn…

Django学习笔记-用户名密码登录

笔记内容转载自 AcWing 的 Django 框架课讲义&#xff0c;课程链接&#xff1a;AcWing Django 框架课。 CONTENTS 1. 扩充Django数据库2. 实现获取用户信息3. 渲染登录与注册界面4. 实现登录与登出功能5. 实现注册功能6. 修改获取用户信息 1. 扩充Django数据库 首先我们先在 s…

JavaWeb学习路线(7)——文件上传

一、概念 &#xff08;一&#xff09;文件上传概念&#xff1a; 指将本地的图片、视频、音频等文件上传到服务器&#xff0c;供其他用户浏览或下载的过程。 &#xff08;二&#xff09;前端文件上传三元素 method“post”&#xff08;form&#xff09;enctype“multipart/for…

四、Bean 的作用域,Bean 的自动装配以及通过注解实现 Bean 的自动装配

文章目录 一、Bean 的作用域二、Bean 的自动装配三、通过注解实现 Bean 的自动装配 一、Bean 的作用域 Spring 官网 Bean 的作用域讲解 单例(Singleton)作用域&#xff1a;在这种作用域下&#xff0c;容器只会创建一个Bean实例对象&#xff0c;无论该Bean被注入到多少个其它B…

Unity使用MySQL

效果&#xff1a; 问题记录&#xff1a; unity mysql “The given key ‘utf8mb4‘ was not present in the dictionary” – 我这里数据库字符集没有utf8&#xff0c;改选utf8mb4 – 这个改了&#xff0c;那么MySQL配置文件也得改了。如下&#xff1a; – 然后还报错&…

字符、字符集、编码

一、基本概念 在计算机中&#xff0c;所有的内容都是以二进制数据存储的&#xff0c;而我们在屏幕上看到的字和符号以及看不到的字符都是二进制数据转换后的结果。将字符按照某种规则转成对应的二进制数据&#xff0c;这个过程称为编码&#xff1b;而相对应的&#xff0c;将二…

Azure获取linux服务器磁盘和控制台disk的对应关系

从Azure控制台上删除/卸载服务器上不用的磁盘时&#xff0c;需要确定服务器上磁盘和控制台上显示的磁盘的对应关系。以免当有多块磁盘时&#xff0c;卸载了错误的磁盘&#xff0c;引起生产事故。 通过LUN确定磁盘对应关系 什么是LUN&#xff1f; 逻辑单元号 (LUN) 是用于标识…

Vue之事件处理(v-on)

文章目录 前言一、v-on基本使用二、使用举例1.传参和不传参使用2.$event占位代表事件对象3.函数用箭头函数时this作用域4.正常未用箭头函数的this指向&#xff08;与未用箭头函数作比较&#xff09; 总结 前言 v-on&#xff1a;事件绑定 一、v-on基本使用 格式&#xff1a;&l…

Linux安装ElasticSearch和Kibana

es官网下载地址&#xff1a;https://www.elastic.co/cn/downloads/past-releases#elasticsearch 可以去官网下载包然后放到服务器 也可以使用wget进行下载安装 如果使用wget方式下载的话需要先安装 安装wget yum install -y wgetwget下载es&#xff1a;wget https://artifacts…

B+树的设计步骤

1.节点的结构&#xff08;如下图&#xff09; &#xff08;1&#xff09;键值对--key是标识&#xff1b;value是存储的具体数据 &#xff08;2&#xff09;节点的子节点--存储的是具体的子节点 &#xff08;3&#xff09;节点的后节点--标记后一个节点 &#xff08;4&#xff0…

JSP实现自定义标签【上】

目录 一、基础概念 1、标签语言的形式或结构 2、分类 二、自定义标签的开发及步骤 三、标签生命周期 1、返回值 四、案例 1、if 2、out 一、基础概念 JSP自定义标签是一种扩展JSP标记语言的方法。通过自定义标签&#xff0c;我们可以将自定义功能封装在一个独立的标签…

# rust abc(6): 字符串的简单使用

文章目录 1. 目的2. 数据类型2.1 str 类型2.2 标准库 String 类型 3. 常用 API3.1 len() 方法3.2 is_empty() 方法3.3 starts_with() 方法3.4 find() 方法 4. References 1. 目的 学习 Rust 语言中的字符串&#xff0c; 包括数据类型&#xff0c; 常用 API。 2. 数据类型 Ru…

新手入门:从零搭建vue3+webpack实战项目模板

搭建一个 vue3 webpack5 element-plus 基本模板 &#xff08;vue3 webpack5 从零配置项目&#xff09;。 本项目结构可以作为实战项目的基本结构搭建学习&#xff0c;作为刚学习完vue还没有实战项目经验的小伙伴练习比较合适。 项目地址&#xff1a; GitHub&#xff1a;ht…

如何将手写笔记转换成电子版格式?

记笔记是一种非常有效的学习方法。它不仅可以帮助我们加深对所学内容的理解&#xff0c;还能让我们收集更多有用的信息&#xff0c;以方便后续的查看和复习。不过&#xff0c;用传统的纸质笔记本记录笔记存在一定的弊端&#xff0c;比如说不易保存、不易携带等等。所以&#xf…

Mac下的java.io.FileNotFoundException: ~/Desktop/a.sql (No such file or directory)

【问题】&#xff1a; 今天在运行一个文件读取的Demo时&#xff0c;报如下错误: java.io.FileNotFoundException: ~/Desktop/a.sql (No such file or directory)如下图所示 &#xff1a; 可是这个文件命名可以通过终端窗口访问到啊&#xff1f; 【解决方案】&#xff…