【深度学习笔记】3_11 模型选择、欠拟合和过拟合

news2025/1/12 1:09:09

注:本文为《动手学深度学习》开源内容,做了部分个人理解标注,仅为个人学习记录,无抄袭搬运意图

3.11 模型选择、欠拟合和过拟合

在前几节基于Fashion-MNIST数据集的实验中,我们评价了机器学习模型在训练数据集和测试数据集上的表现。如果你改变过实验中的模型结构或者超参数,你也许发现了:当模型在训练数据集上更准确时,它在测试数据集上却不一定更准确。这是为什么呢?

3.11.1 训练误差和泛化误差

在解释上述现象之前,我们需要区分训练误差(training error)和泛化误差(generalization error)。通俗来讲,前者指模型在训练数据集上表现出的误差,后者指模型在任意一个测试数据样本上表现出的误差的期望,并常常通过测试数据集上的误差来近似。计算训练误差和泛化误差可以使用之前介绍过的损失函数,例如线性回归用到的平方损失函数和softmax回归用到的交叉熵损失函数。

让我们以高考为例来直观地解释训练误差和泛化误差这两个概念。训练误差可以认为是做往年高考试题(训练题)时的错误率,泛化误差则可以通过真正参加高考(测试题)时的答题错误率来近似。假设训练题和测试题都随机采样于一个未知的依照相同考纲的巨大试题库。如果让一名未学习中学知识的小学生去答题,那么测试题和训练题的答题错误率可能很相近。但如果换成一名反复练习训练题的高三备考生答题,即使在训练题上做到了错误率为0,也不代表真实的高考成绩会如此。

在机器学习里,我们通常假设训练数据集(训练题)和测试数据集(测试题)里的每一个样本都是从同一个概率分布中相互独立地生成的。基于该独立同分布假设,给定任意一个机器学习模型(含参数),它的训练误差的期望和泛化误差都是一样的。例如,如果我们将模型参数设成随机值(小学生),那么训练误差和泛化误差会非常相近。但我们从前面几节中已经了解到,模型的参数是通过在训练数据集上训练模型而学习出的,参数的选择依据了最小化训练误差(高三备考生)。所以,训练误差的期望小于或等于泛化误差。也就是说,一般情况下,由训练数据集学到的模型参数会使模型在训练数据集上的表现优于或等于在测试数据集上的表现。由于无法从训练误差估计泛化误差,一味地降低训练误差并不意味着泛化误差一定会降低。

机器学习模型应关注降低泛化误差。

3.11.2 模型选择

在机器学习中,通常需要评估若干候选模型的表现并从中选择模型。这一过程称为模型选择(model selection)。可供选择的候选模型可以是有着不同超参数的同类模型。以多层感知机为例,我们可以选择隐藏层的个数,以及每个隐藏层中隐藏单元个数和激活函数。为了得到有效的模型,我们通常要在模型选择上下一番功夫。下面,我们来描述模型选择中经常使用的验证数据集(validation data set)。

3.11.2.1 验证数据集

从严格意义上讲,测试集只能在所有超参数和模型参数选定后使用一次。不可以使用测试数据选择模型,如调参。由于无法从训练误差估计泛化误差,因此也不应只依赖训练数据选择模型。鉴于此,我们可以预留一部分在训练数据集和测试数据集以外的数据来进行模型选择。这部分数据被称为验证数据集,简称验证集(validation set)。例如,我们可以从给定的训练集中随机选取一小部分作为验证集,而将剩余部分作为真正的训练集。

然而在实际应用中,由于数据不容易获取,测试数据极少只使用一次就丢弃。因此,实践中验证数据集和测试数据集的界限可能比较模糊。从严格意义上讲,除非明确说明,否则本书中实验所使用的测试集应为验证集,实验报告的测试结果(如测试准确率)应为验证结果(如验证准确率)。

3.11.2.3 K K K折交叉验证

由于验证数据集不参与模型训练,当训练数据不够用时,预留大量的验证数据显得太奢侈。一种改善的方法是 K K K折交叉验证( K K K-fold cross-validation)。在 K K K折交叉验证中,我们把原始训练数据集分割成 K K K个不重合的子数据集,然后我们做 K K K次模型训练和验证。每一次,我们使用一个子数据集验证模型,并使用其他 K − 1 K-1 K1个子数据集来训练模型。在这 K K K次训练和验证中,每次用来验证模型的子数据集都不同。最后,我们对这 K K K次训练误差和验证误差分别求平均。这里没看明白的可以参考【人工智能概论】 K折交叉验证辅助理解。

3.11.3 欠拟合和过拟合

接下来,我们将探究模型训练中经常出现的两类典型问题:一类是模型无法得到较低的训练误差,我们将这一现象称作欠拟合(underfitting);另一类是模型的训练误差远小于它在测试数据集上的误差,我们称该现象为过拟合(overfitting)。在实践中,我们要尽可能同时应对欠拟合和过拟合。虽然有很多因素可能导致这两种拟合问题,在这里我们重点讨论两个因素:模型复杂度和训练数据集大小。这里建议参考机器学习知识总结 —— 6. 什么是过拟合和欠拟合辅助理解,个人感觉讲得很通俗易懂,适合小白
在这里插入图片描述

关于模型复杂度和训练集大小对学习的影响的详细理论分析可参见笔记原作者这篇博客。

3.11.3.1 模型复杂度

为了解释模型复杂度,我们以多项式函数拟合为例。给定一个由标量数据特征 x x x和对应的标量标签 y y y组成的训练数据集,多项式函数拟合的目标是找一个 K K K阶多项式函数

y ^ = b + ∑ k = 1 K x k w k \hat{y} = b + \sum_{k=1}^K x^k w_k y^=b+k=1Kxkwk

来近似 y y y。在上式中, w k w_k wk是模型的权重参数, b b b是偏差参数。与线性回归相同,多项式函数拟合也使用平方损失函数。特别地,一阶多项式函数拟合又叫线性函数拟合。

因为高阶多项式函数模型参数更多,模型函数的选择空间更大,所以高阶多项式函数比低阶多项式函数的复杂度更高。因此,高阶多项式函数比低阶多项式函数更容易在相同的训练数据集上得到更低的训练误差。给定训练数据集,模型复杂度和误差之间的关系通常如图3.4所示。给定训练数据集,如果模型的复杂度过低,很容易出现欠拟合;如果模型复杂度过高,很容易出现过拟合。应对欠拟合和过拟合的一个办法是针对数据集选择合适复杂度的模型。

在这里插入图片描述

3.11.3.2 训练数据集大小

影响欠拟合和过拟合的另一个重要因素是训练数据集的大小。一般来说,如果训练数据集中样本数过少,特别是比模型参数数量(按元素计)更少时,过拟合更容易发生。此外,泛化误差不会随训练数据集里样本数量增加而增大。因此,在计算资源允许的范围之内,我们通常希望训练数据集大一些,特别是在模型复杂度较高时,例如层数较多的深度学习模型。(关于模型复杂度的计算可参考理解深度学习模型复杂度评估辅助理解,建议先知道有这么些内容,后面学到的文中提到的内容能在此基础上快速消化)

3.11.4 多项式函数拟合实验

为了理解模型复杂度和训练数据集大小对欠拟合和过拟合的影响,下面我们以多项式函数拟合为例来实验。首先导入实验需要的包或模块。

%matplotlib inline
import torch
import numpy as np
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l

3.11.4.1 生成数据集

我们将生成一个人工数据集。在训练数据集和测试数据集中,给定样本特征 x x x,我们使用如下的三阶多项式函数来生成该样本的标签:

y = 1.2 x − 3.4 x 2 + 5.6 x 3 + 5 + ϵ , y = 1.2x - 3.4x^2 + 5.6x^3 + 5 + \epsilon, y=1.2x3.4x2+5.6x3+5+ϵ,

其中噪声项 ϵ \epsilon ϵ服从均值为0、标准差为0.01的正态分布。训练数据集和测试数据集的样本数都设为100。

n_train, n_test, true_w, true_b = 100, 100, [1.2, -3.4, 5.6], 5
features = torch.randn((n_train + n_test, 1))
poly_features = torch.cat((features, torch.pow(features, 2), torch.pow(features, 3)), 1) 
labels = (true_w[0] * poly_features[:, 0] + true_w[1] * poly_features[:, 1]
          + true_w[2] * poly_features[:, 2] + true_b)
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)

这段代码生成了一个具有 200 个数据点(100 个训练点和 100 个测试点)的数据集,其中每个数据点都包含一个特征值以及对应的三次多项式特征(原始特征、平方特征、立方特征),以及一个由这些特征和真实权重、偏置计算得出的标签值(加上一些随机噪声)。这样的数据集可以用于训练一个模型来拟合这个三次多项式函数。

看一看生成的数据集的前两个样本。

features[:2], poly_features[:2], labels[:2]

输出:

(tensor([[-1.0613],
         [-0.8386]]), tensor([[-1.0613,  1.1264, -1.1954],
         [-0.8386,  0.7032, -0.5897]]), tensor([-6.8037, -1.7054]))

3.11.4.2 定义、训练和测试模型

我们先定义作图函数semilogy,其中 y y y 轴使用了对数尺度。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None,
             legend=None, figsize=(3.5, 2.5)):
    d2l.set_figsize(figsize)
    d2l.plt.xlabel(x_label)
    d2l.plt.ylabel(y_label)
    d2l.plt.semilogy(x_vals, y_vals)
    if x2_vals and y2_vals:
        d2l.plt.semilogy(x2_vals, y2_vals, linestyle=':')
        d2l.plt.legend(legend)

和线性回归一样,多项式函数拟合也使用平方损失函数。因为我们将尝试使用不同复杂度的模型来拟合生成的数据集,所以我们把模型定义部分放在fit_and_plot函数中。多项式函数拟合的训练和测试步骤与3.6节(softmax回归的从零开始实现)介绍的softmax回归中的相关步骤类似。

num_epochs, loss = 100, torch.nn.MSELoss()

def fit_and_plot(train_features, test_features, train_labels, test_labels):
    net = torch.nn.Linear(train_features.shape[-1], 1)
    # 通过Linear文档可知,pytorch已经将参数初始化了,所以我们这里就不手动初始化了
    
    batch_size = min(10, train_labels.shape[0])    
    dataset = torch.utils.data.TensorDataset(train_features, train_labels)
    train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)
    
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for X, y in train_iter:
            l = loss(net(X), y.view(-1, 1))
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
        train_labels = train_labels.view(-1, 1)
        test_labels = test_labels.view(-1, 1)
        train_ls.append(loss(net(train_features), train_labels).item())
        test_ls.append(loss(net(test_features), test_labels).item())
    print('final epoch: train loss', train_ls[-1], 'test loss', test_ls[-1])
    semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
             range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('weight:', net.weight.data,
          '\nbias:', net.bias.data)

3.11.4.3 三阶多项式函数拟合(正常)

我们先使用与数据生成函数同阶的三阶多项式函数拟合。实验表明,这个模型的训练误差和在测试数据集的误差都较低。训练出的模型参数也接近真实值: w 1 = 1.2 , w 2 = − 3.4 , w 3 = 5.6 , b = 5 w_1 = 1.2, w_2=-3.4, w_3=5.6, b = 5 w1=1.2,w2=3.4,w3=5.6,b=5

fit_and_plot(poly_features[:n_train, :], poly_features[n_train:, :], 
            labels[:n_train], labels[n_train:])

输出:

final epoch: train loss 0.00010175639908993617 test loss 9.790256444830447e-05
weight: tensor([[ 1.1982, -3.3992,  5.6002]]) 
bias: tensor([5.0014])

在这里插入图片描述

3.11.4.4 线性函数拟合(欠拟合)

我们再试试线性函数拟合。很明显,该模型的训练误差在迭代早期下降后便很难继续降低。在完成最后一次迭代周期后,训练误差依旧很高。线性模型在非线性模型(如三阶多项式函数)生成的数据集上容易欠拟合。

fit_and_plot(features[:n_train, :], features[n_train:, :], labels[:n_train],
             labels[n_train:])

输出:

final epoch: train loss 249.35157775878906 test loss 168.37705993652344
weight: tensor([[19.4123]]) 
bias: tensor([0.5805])

在这里插入图片描述

3.11.4.5 训练样本不足(过拟合)

事实上,即便使用与数据生成模型同阶的三阶多项式函数模型,如果训练样本不足,该模型依然容易过拟合。让我们只使用两个样本来训练模型。显然,训练样本过少了,甚至少于模型参数的数量。这使模型显得过于复杂,以至于容易被训练数据中的噪声影响。在迭代过程中,尽管训练误差较低,但是测试数据集上的误差却很高。这是典型的过拟合现象。

fit_and_plot(poly_features[0:2, :], poly_features[n_train:, :], labels[0:2],
             labels[n_train:])

输出:

final epoch: train loss 1.198514699935913 test loss 166.037109375
weight: tensor([[1.4741, 2.1198, 2.5674]]) 
bias: tensor([3.1207])

在这里插入图片描述

我们将在接下来的两个小节继续讨论过拟合问题以及应对过拟合的方法。

小结

  • 由于无法从训练误差估计泛化误差,一味地降低训练误差并不意味着泛化误差一定会降低。机器学习模型应关注降低泛化误差。
  • 可以使用验证数据集来进行模型选择。
  • 欠拟合指模型无法得到较低的训练误差,过拟合指模型的训练误差远小于它在测试数据集上的误差。
  • 应选择复杂度合适的模型并避免使用过少的训练样本。

注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1471409.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenAI-Sora:最新文生视频教程-Sora怎么用(新手小白)

Sora 是什么? Open AI 宣布推出全新的生成式人工智能模型“Sora”。据了解,通过文本指令,Sora 可以直接输出长达 60 秒的视频,并且包含高度细致的背景、复杂的多角度镜头,以及富有情感的多个角色。 - 继文本、图像之后…

荣获国家高新技术企业认证,苹芯科技领航AI芯片产业发展

北京苹芯科技有限公司(以下简称“苹芯科技”)凭借在存算一体芯片领域的卓越表现,荣获国家高新技术企业认证。这一荣誉不仅是对苹芯科技多年来在科技创新、产品研发等方面所取得成果的肯定,更是对其未来发展潜力的认可。 苹芯科技…

解析OOM的三大场景,原因及实战解决方案

目录 一、什么是OOM 二、堆内存溢出(Heap OOM) 三、方法区内存溢出(Metaspace OOM) 四、栈内存溢出(Stack OOM) 一、什么是OOM OOM 是 Out Of Memory 的缩写,意思是内存耗尽。在计算机领域…

力扣思路题:丑数

此题的思路非常奇妙,可以借鉴一下 bool isUgly(int num){if(num0)return false;while(num%20)num/2;while(num%30)num/3;while(num%50)num/5;return num1; }

树-王道-复试

树 1.度: 树中孩子节点个数,所有结点的度最大值为 树的度 2.有序树: 逻辑上看,树中结点的各子树从左至右是有次序的,不能互换。 **3.**树的根节点没有前驱,其他节点只有一个前驱 **4.**所有节点可有零个或…

[Linux]文件基础-如何管理文件

回顾C语言之 - 文件如何被写入 fopen fwrite fread fclose fseek … 这一系列函数都是C语言中对文件进行的操作: int main() {FILE* fpfopen("text","w");char str[20]"write into text";fputs(str,fp);fclose(fp);return 0; }而上…

全域增长方法论:帮助品牌实现科学经营,助力长效生意增长

前两年由于疫情反复、供给需求收缩等条件制约,品牌业务均受到不同程度的影响。以双十一和618电商大促为例,就相比往年颇显“惨淡”,大多品牌营销都无法达到理想预期。 随着市场环境不断开放,2023年营销行业开始从低迷期走上了高速…

最新IE跳转Edge浏览器解决办法(2024.2.26)

最新IE跳转Edge浏览器解决办法(2024.2.26) 1. IE跳转原因1.1. 原先解决办法1.2. 最新解决办法1.3. 最后 1. IE跳转原因 关于IE跳转问题是由于在2023年2月14日,微软正式告别IE浏览器,导致很多使用Windows10系统的电脑在打开IE浏览…

【计算机科学引论 Computing Essentials 2021】【名词术语】【第7章】

Computing Essentials Chapter 7: Secondary Storage 二级存储 MATCHING Match each numbered item with the most closely related lettered item. Write your answers in the spaces provided. Choices a. DVD (Digital Versatile Disc) b. file compression c. hi-def…

pytorch -- torch.nn下的常用损失函数

1.基础 loss function损失函数:预测输出与实际输出 差距 越小越好 - 计算实际输出和目标之间的差距 - 为我们更新输出提供依据(反向传播) 1. L1 torch.nn.L1Loss(size_averageNone, reduceNone, reduction‘mean’) 2. 平方差(…

蓝桥杯-成绩分析

许久不敲代码&#xff0c;库名也忘了&#xff0c;精度设置还有求最大最小值都是常规题了。 #include <iostream> #include <iomanip> using namespace std; int main() { //一种不用开数组的方法 int n; cin>>n; int top0; int low100;//确定最大…

【MySQL面试复习】详细说下事务的特性

系列文章目录 在MySQL中&#xff0c;如何定位慢查询&#xff1f; 发现了某个SQL语句执行很慢&#xff0c;如何进行分析&#xff1f; 了解过索引吗&#xff1f;(索引的底层原理)/B 树和B树的区别是什么&#xff1f; 什么是聚簇索引&#xff08;聚集索引&#xff09;和非聚簇索引…

外贸获客必读!2个月内成交2单,EDM如何避免垃圾箱成功拓展新客户?

近期&#xff0c;不少伙伴反馈在使用EDM邮件开发时遇到了一些问题&#xff0c;例如邮件发不出去&#xff0c;或者直接进入垃圾箱。在这篇文章中&#xff0c;我们将分享某用户是如何解决这些问题&#xff0c;并通过实战案例成功成交了2个客户。 搜索&#xff1a;focussend.com &…

[云原生] 二进制安装K8S(中)部署网络插件和DNS

书接上文&#xff0c;我们继续部署剩余的插件 一、K8s的CNI网络插件模式 2.1 k8s的三种网络模式 K8S 中 Pod 网络通信&#xff1a; &#xff08;1&#xff09;Pod 内容器与容器之间的通信 在同一个 Pod 内的容器&#xff08;Pod 内的容器是不会跨宿主机的&#xff09;共享…

API保障——电子商务安全性与稳定性设计

在这次深入探讨中&#xff0c;我们将深入了解API设计&#xff0c;从基础知识开始&#xff0c;逐步进阶到定义出色API的最佳实践。 作为开发者&#xff0c;你可能对许多这些概念很熟悉&#xff0c;但我将提供详细的解释&#xff0c;以加深你的理解。 ​ API设计&#xff1a;电…

MySQL数据库应用与开发(全)

一、 MySQL数据库基本操作 1、创建表 #创建数据库 CREATE DATABASE teaching CHARSET utf8; #贴换teaching为当前数据库 USE teaching;二、数据表的创建 1、创建学生表 CREATE TABLE IF NOT EXISTS student (studentno CHAR(11) NOT NULL COMMENT 学号,sname CHAR(8) NOT N…

TextCNN:文本分类卷积神经网络

模型原理 1、前言2、模型结构3、示例3.1、词向量层3.2、卷积层3.3、最大池化层3.4、Fully Connected层 4、总结 1、前言 TextCNN 来源于《Convolutional Neural Networks for Sentence Classification》发表于2014年&#xff0c;是一个经典的模型&#xff0c;Yoon Kim将卷积神…

Java 中常用的数据结构类 API

目录 常用数据结构API 对应的线程安全的api 高可用衡量标准 常用数据结构API ArrayList: 实现了动态数组&#xff0c;允许快速随机访问元素。 import java.util.ArrayList; LinkedList: 实现了双向链表&#xff0c;适用于频繁插入和删除操作。 import java.util.LinkedLis…

「连载」边缘计算(十九)02-22:边缘部分源码(源码分析篇)

&#xff08;接上篇&#xff09; 从启动函数Start(&#xff09;中可以看到&#xff0c;其以go routine的方式启动很多后台处理服务&#xff0c;具体如下。 1&#xff09;初始化edged的kubeClient&#xff0c;具体如下所示。 // use self defined client to replace fake kube…

深度神经网络中的计算和内存带宽

深度神经网络中的计算和内存带宽 文章目录 深度神经网络中的计算和内存带宽来源原理介绍分析1&#xff1a;线性层分析2&#xff1a;卷积层分析3&#xff1a;循环层总结 来源 相关知识来源于这里。 原理介绍 Memory bandwidth and data re-use in deep neural network computat…