【从零开始学习深度学习】12. 什么是模型的训练误差?基于三阶多项式的欠拟合与过拟合训练过程演示

news2024/9/20 1:05:20

目录

  • 前言
    • 1.1 训练误差和泛化误差
    • 1.2 模型选择
    • 1.3 欠拟合和过拟合
      • 1.3.1 模型复杂度
      • 1.3.2 训练数据集大小
    • 1.4 多项式函数拟合示例
      • 1.4.1 生成数据集
      • 1.4.2 定义、训练和测试模型
      • 1.4.3 三阶多项式函数拟合(正常)
      • 1.4.4 线性函数拟合(欠拟合)
      • 1.4.5 训练样本不足(过拟合)
    • 总结

前言

前几篇文章基于Fashion-MNIST数据集的实验中,我们评价了机器学习模型在训练数据集和测试数据集上的表现。如果你改变过实验中的模型结构或者超参数,你也许发现了:当模型在训练数据集上更准确时,它在测试数据集上却不一定更准确。这是为什么呢?

1.1 训练误差和泛化误差

通俗来讲,训练误差(training error)指模型在训练数据集上表现出的误差和泛化误差(generalization error)指模型在任意一个测试数据样本上表现出的误差的期望,并常常通过测试数据集上的误差来近似。计算训练误差和泛化误差可以使用之前介绍过的损失函数,例如线性回归用到的平方损失函数和softmax回归用到的交叉熵损失函数。

在机器学习里,我们通常假设训练数据集(训练题)和测试数据集(测试题)里的每一个样本都是从同一个概率分布中相互独立地生成的。基于该独立同分布假设,给定任意一个机器学习模型(含参数),它的训练误差的期望和泛化误差都是一样的。例如,如果我们将模型参数设成随机值,那么训练误差和泛化误差会非常相近。但我们从前面几节中已经了解到,模型的参数是通过在训练数据集上训练模型而学习出的,参数的选择依据了最小化训练误差。所以,训练误差的期望小于或等于泛化误差。也就是说,一般情况下,由训练数据集学到的模型参数会使模型在训练数据集上的表现优于或等于在测试数据集上的表现。由于无法从训练误差估计泛化误差,一味地降低训练误差并不意味着泛化误差一定会降低

机器学习模型应关注降低泛化误差

1.2 模型选择

在机器学习中,通常需要评估若干候选模型的表现并从中选择模型。这一过程称为模型选择(model selection)。可供选择的候选模型可以是有着不同超参数的同类模型。以多层感知机为例,我们可以选择隐藏层的个数,以及每个隐藏层中隐藏单元个数和激活函数。为了得到有效的模型,我们通常要在模型选择上下一番功夫。下面,我们来描述模型选择中经常使用的验证数据集(validation data set)。

1.2.1 验证数据集

从严格意义上讲,测试集只能在所有超参数和模型参数选定后使用一次。不可以使用测试数据选择模型,如调参。由于无法从训练误差估计泛化误差,因此也不应只依赖训练数据选择模型。鉴于此,我们可以预留一部分在训练数据集和测试数据集以外的数据来进行模型选择。这部分数据被称为验证数据集,简称验证集(validation set)。例如,我们可以从给定的训练集中随机选取一小部分作为验证集,而将剩余部分作为真正的训练集。

1.2.3 K K K折交叉验证

由于验证数据集不参与模型训练,当训练数据不够用时,预留大量的验证数据显得太奢侈。一种改善的方法是 K K K折交叉验证( K K K-fold cross-validation)。在 K K K折交叉验证中,我们把原始训练数据集分割成 K K K个不重合的子数据集,然后我们做 K K K次模型训练和验证。每一次,我们使用一个子数据集验证模型,并使用其他 K − 1 K-1 K1个子数据集来训练模型。在这 K K K次训练和验证中,每次用来验证模型的子数据集都不同。最后,我们对这K次训练误差和验证误差分别求平均

1.3 欠拟合和过拟合

模型训练中经常出现的两类典型问题:一类是模型无法得到较低的训练误差,我们将这一现象称作欠拟合(underfitting);另一类是模型的训练误差远小于它在测试数据集上的误差,我们称该现象为过拟合(overfitting)。在实践中,我们要尽可能同时应对欠拟合和过拟合。虽然有很多因素可能导致这两种拟合问题,在这里我们重点讨论两个因素:模型复杂度和训练数据集大小。

1.3.1 模型复杂度

为了解释模型复杂度,我们以多项式函数拟合为例。给定一个由标量数据特征 x x x和对应的标量标签 y y y组成的训练数据集,多项式函数拟合的目标是找一个 K K K阶多项式函数

y ^ = b + ∑ k = 1 K x k w k \hat{y} = b + \sum_{k=1}^K x^k w_k y^=b+k=1Kxkwk

来近似 y y y。在上式中, w k w_k wk是模型的权重参数, b b b是偏差参数。与线性回归相同,多项式函数拟合也使用平方损失函数。特别地,一阶多项式函数拟合又叫线性函数拟合。

因为高阶多项式函数模型参数更多,模型函数的选择空间更大,所以高阶多项式函数比低阶多项式函数的复杂度更高。因此,高阶多项式函数比低阶多项式函数更容易在相同的训练数据集上得到更低的训练误差。给定训练数据集,模型复杂度和误差之间的关系通常如下图所示。给定训练数据集,如果模型的复杂度过低,很容易出现欠拟合;如果模型复杂度过高,很容易出现过拟合。应对欠拟合和过拟合的一个办法是针对数据集选择合适复杂度的模型

在这里插入图片描述

1.3.2 训练数据集大小

影响欠拟合和过拟合的另一个重要因素是训练数据集的大小。一般来说,如果训练数据集中样本数过少,特别是比模型参数数量(按元素计)更少时,过拟合更容易发生。此外,泛化误差不会随训练数据集里样本数量增加而增大。因此,在计算资源允许的范围之内,我们通常希望训练数据集大一些,特别是在模型复杂度较高时,例如层数较多的深度学习模型。

1.4 多项式函数拟合示例

为了理解模型复杂度和训练数据集大小对欠拟合和过拟合的影响,下面我们以多项式函数拟合为例来实验。首先导入实验需要的包或模块。

%matplotlib inline
import torch
import numpy as np
import sys
import d2lzh_pytorch as d2l

1.4.1 生成数据集

我们将生成一个人工数据集。在训练数据集和测试数据集中,给定样本特征 x x x,我们使用如下的三阶多项式函数来生成该样本的标签:

y = 1.2 x − 3.4 x 2 + 5.6 x 3 + 5 + ϵ , y = 1.2x - 3.4x^2 + 5.6x^3 + 5 + \epsilon, y=1.2x3.4x2+5.6x3+5+ϵ,

其中噪声项 ϵ \epsilon ϵ服从均值为0、标准差为0.01的正态分布。训练数据集和测试数据集的样本数都设为100。

n_train, n_test, true_w, true_b = 100, 100, [1.2, -3.4, 5.6], 5
# torch.randn:用来生成随机数字的tensor,这些随机数字满足标准正态分布(0~1)
features = torch.randn((n_train + n_test, 1))
# torch.cat()把多个tensor进行拼接,poly_features.shape=200*3
poly_features = torch.cat((features, torch.pow(features, 2), torch.pow(features, 3)), 1) 
labels = (true_w[0] * poly_features[:, 0] + true_w[1] * poly_features[:, 1]
          + true_w[2] * poly_features[:, 2] + true_b)
# np.random.normal()的意思是一个正态分布,噪声
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)

看一看生成的数据集的前两个样本。

features[:2], poly_features[:2], labels[:2]

输出:

(tensor([[-1.0613],
         [-0.8386]]), tensor([[-1.0613,  1.1264, -1.1954],
         [-0.8386,  0.7032, -0.5897]]), tensor([-6.8037, -1.7054]))

1.4.2 定义、训练和测试模型

我们先定义作图函数semilogy,其中 y y y 轴使用了对数尺度。

def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None,
             legend=None, figsize=(3.5, 2.5)):
    d2l.set_figsize(figsize)
    d2l.plt.xlabel(x_label)
    d2l.plt.ylabel(y_label)
    d2l.plt.semilogy(x_vals, y_vals)
    if x2_vals and y2_vals:
        d2l.plt.semilogy(x2_vals, y2_vals, linestyle=':')
        d2l.plt.legend(legend)

和线性回归一样,多项式函数拟合也使用平方损失函数。因为我们将尝试使用不同复杂度的模型来拟合生成的数据集,所以我们把模型定义部分放在fit_and_plot函数中。

num_epochs, loss = 100, torch.nn.MSELoss()

def fit_and_plot(train_features, test_features, train_labels, test_labels):
    net = torch.nn.Linear(train_features.shape[-1], 1)
    # 通过Linear文档可知,pytorch已经将参数初始化了,所以我们这里就不手动初始化了
    
    batch_size = min(10, train_labels.shape[0])    
    
    dataset = torch.utils.data.TensorDataset(train_features, train_labels)
    train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    
    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for X, y in train_iter:
            l = loss(net(X), y.view(-1, 1))
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
        train_labels = train_labels.view(-1, 1)
        test_labels = test_labels.view(-1, 1)
        train_ls.append(loss(net(train_features), train_labels).item())
        test_ls.append(loss(net(test_features), test_labels).item())
    
    print('final epoch: train loss', train_ls[-1], 'test loss', test_ls[-1])
    semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
             range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('weight:', net.weight.data,
          '\nbias:', net.bias.data)

1.4.3 三阶多项式函数拟合(正常)

我们先使用与数据生成函数同阶的三阶多项式函数拟合。实验表明,这个模型的训练误差和在测试数据集的误差都较低。训练出的模型参数也接近真实值: w 1 = 1.2 , w 2 = − 3.4 , w 3 = 5.6 , b = 5 w_1 = 1.2, w_2=-3.4, w_3=5.6, b = 5 w1=1.2,w2=3.4,w3=5.6,b=5

fit_and_plot(poly_features[:n_train, :], poly_features[n_train:, :], 
            labels[:n_train], labels[n_train:])

输出:

final epoch: train loss 0.00010175639908993617 test loss 9.790256444830447e-05
weight: tensor([[ 1.1982, -3.3992,  5.6002]]) 
bias: tensor([5.0014])

在这里插入图片描述

1.4.4 线性函数拟合(欠拟合)

我们再试试线性函数拟合。很明显,该模型的训练误差在迭代早期下降后便很难继续降低。在完成最后一次迭代周期后,训练误差依旧很高。线性模型在非线性模型(如三阶多项式函数)生成的数据集上容易欠拟合。

fit_and_plot(features[:n_train, :], features[n_train:, :], labels[:n_train],
             labels[n_train:])

输出:

final epoch: train loss 249.35157775878906 test loss 168.37705993652344
weight: tensor([[19.4123]]) 
bias: tensor([0.5805])

在这里插入图片描述

1.4.5 训练样本不足(过拟合)

事实上,即便使用与数据生成模型同阶的三阶多项式函数模型,如果训练样本不足,该模型依然容易过拟合。让我们只使用两个样本来训练模型。显然,训练样本过少了,甚至少于模型参数的数量。这使模型显得过于复杂,以至于容易被训练数据中的噪声影响。在迭代过程中,尽管训练误差较低,但是测试数据集上的误差却很高。这是典型的过拟合现象。

fit_and_plot(poly_features[0:2, :], poly_features[n_train:, :], labels[0:2],
             labels[n_train:])

输出:

final epoch: train loss 1.198514699935913 test loss 166.037109375
weight: tensor([[1.4741, 2.1198, 2.5674]]) 
bias: tensor([3.1207])

在这里插入图片描述

总结

  • 由于无法从训练误差估计泛化误差,一味地降低训练误差并不意味着泛化误差一定会降低。机器学习模型应关注降低泛化误差。
  • 可以使用验证数据集来进行模型选择。
  • 欠拟合指模型无法得到较低的训练误差,过拟合指模型的训练误差远小于它在测试数据集上的误差。
  • 应选择复杂度合适的模型并避免使用过少的训练样本。

如果内容对你有帮助,感谢点赞+关注哦!

关注下方GZH:阿旭算法与机器学习,可获取更多干货内容~欢迎共同学习交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/68636.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Docker教程系列】Docker学习5-Docker镜像理解

通过前面几篇文章的学习,我们已经安装好了Docker,也学会使用一些常用的命令。比如启动命令、镜像命令、容器命令。常用命令分类后的第二个就是镜像命令。那么镜像是什么?拉取镜像的时候为什么是一层一层的?镜像加载的原理是什么&a…

LeetCode #2.两数相加

力扣 | 2.两数相加 题目截图 官方方法 总共有三个链表,l1,l2和最后的结果链表。l1和l2各自滑动对应链表各个位数支出相应的值,最后相加,将结果添加到新链表上。 唯一需要考虑的是数值相加后的进位。 先设置进位carry为0 结果相加的时候利…

[附源码]计算机毕业设计考试系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

如何系统全面的成为一个网络工程师?看完这个你就懂了

如何系统全面的成为一个网络工程师?总有粉丝这样问我,今天就来给大家推荐一个方法。 想要真正系统全面地成为一个网络工程师,掌握系统全面的网络技术的话,我这边建议的话还是完整的学一下华为认证HCIA、HCIP、HCIE,这三…

JavaScript -- Date对象及常用方法介绍

文章目录Date1 Date介绍2 常用方法介绍3 创建指定日期的Date对象4 日期的格式化Date Date - JavaScript 1 Date介绍 在JS中所有的和时间相关的数据都由Date对象来表示 创建Date对象 let d new Date() // 直接通过new Date()创建时间对象时,它创建的是当前的时…

双12哪些数码好物实用不吃灰?2022最值得入手的数码好物推荐

年末的双十二来了,相信不少人像我一样已经将自己想买的一些数码产品加入购物车啦。实用性强的数码产品才是最好的,才能最大程度地体现它的价值。那么,有哪些数码好物实用性强且不吃灰呢?下面,我来给大家推荐几款2022最…

电力系统的延时功率流 (CPF)的计算【 IEEE-14节点】(Matlab代码实现)

💥💥💥💞💞💞欢迎来到本博客❤️❤️❤️💥💥💥 📝目前更新:🌟🌟🌟电力系统相关知识,期刊论文&…

计算机网络-排队时延和分组丢失

有志者,事竟成 文章目录一、描述1、排队时延和分组丢失二、总结一、描述 1、排队时延和分组丢失 每台分组交换机有多条链路与之相连。对于每条相连的链路,该分组交换机具有一个输出缓存也称为输出队列,它用于存储路由器准备发往那条链路的分…

数字图像处理实验(四)|图像压缩与编码实验{JPGE编码、离散余弦变换DCT、图像分块dctmtx|blkproc}(附matlab实验代码和截图)

文章目录一、 实验目的二、 实验原理1.图像压缩基本概念及原理(1)无损压缩编码种类(2)有损压缩编码种类(3)混合编码2. JPEG压缩编码原理(1)使用正向离散余弦变换(forward discrete cosine transform&#x…

压缩与打包

压缩文件 准备工作 命令:gzip -v anaconda-ks.cfg 压缩之后,原文件删除,只有压缩文件 解压文件 解压缩之后,压缩文件删除,生成原文件 tar命令 生成打包文件 清空/tmp/part1目录 将/etc目录拷贝到/tmp/part1…

展望未来 | Google Play 与时俱进,奔赴下一个十年

作者 / Google Play 产品总监 Alex Musil对于 Google Play 来说,今年是具有里程碑意义的重要年份。不久前,我们迎来了十周年纪念,而现在,我们将展望未来十年,思考我们的平台需要怎样发展完善,与时俱进。为此…

浏览器格式化文件代码以及浏览器调试小技巧

前端开发之浏览器格式化文件代码前言效果图1、在谷歌浏览器中查看源文件2、在谷歌浏览器中的源文件点击下面的{}就会生成相应的格式话文件3、中断调试:添加断点后,当JS代码运行到断点时会中断(对于添加了中断条件的断点在符合条件时中断&…

CSS媒体查询简介及案例

文章目录一、简介二、使用方式2.1 目标2.2 import2.3 link、source等2.4 media2.5 其他三、媒体类型(Media Types)3.1 简介3.2 常见的媒体类型值3.3 已废弃媒体类型四、媒体特性(Media Features)4.1 简介4.2 常见媒体特性简介五、媒体查询 - 逻辑操作符(Logical Operators)5.1 …

金色传说:SAP-QM-周期性检验:MSC1N/MSC2N/MSC3N下一次检验日期逻辑问题

文章目录前言一、实现方案二、收货测试1.建立生产订单2.触发周期检验前言 物料在存储过程中,通常会有定期检验,也叫周期性检验的需求,这里给大家分享下周期性检验的实现过程。 一、实现方案 1.首先要在TCode:MM02-质量管理视图中…

数据结构的起航,用C语言实现一个简约却不简单的顺序表!(零基础也能看懂)

目录 0.前言 1.线性表 2.顺序表 2.1什么是顺序表 2.2 顺序表结构体设计 2.3 顺序表的初始化 2.4 顺序表的销毁 2.5* 顺序表的尾插 2.6* 顺序表的尾删 2.7 顺序表的头插 2.8 顺序表的头删 2.9 顺序表的查找 2.10 顺序表的插入(任意位置) 2.11…

载药磷酸钙纳米粒子;载药阿奇霉素纳米粒子;载药酪蛋白纳米粒子

产品名称:载药磷酸钙纳米粒子;载药阿奇霉素纳米粒子;载药酪蛋白纳米粒子 英文名称:Drug loaded calcium phosphate nanoparticles; Azithromycin loaded nanoparticles; Drug loaded casein nanoparticles 描述: 磷…

C# --- 结构体

1.结构体与类的第一个区别是创建的时候使用的关键字不同 --- 结构体 struct ; 类是 class 2.结构体中不能够自定义无参构造函数,因为结构体中已经隐式包含了一个无参构造函数,这个不像类中自带无参构造函数可以被我们替代,它是无法替代的 3…

OSWE 考试总结

背景 今天,2022年12月6日,我收到了 OSWE 考试通过的邮件。至此,Offsensive Security 之旅告一段落。我拿到了 KLCP, OSCP, OSEP, OSED, OSWE, OSCE3。圆满完成三月份立下的 FLAG,没有打脸。 这里,小小总结一下 OSWE …

Apache服务深入学习篇(详细介绍)

什么是Apache? Apache HTTP Server(简称Apache)是Apache软件基金会的一个开放源码的网页服务器,可以在大多数计算机操作系统中运行,由于其跨平台和安全性被广泛使用,是最流行的Web服务器端软件之一。它快速…

汽车电子之功能安全介绍

功能安全介绍 1.什么是功能安全FS? 2.为什么需要功能安全? 3.认识标准《ISO26262》。 4.怎么评估ASIL 等级? 5.功能安全怎么做(措施)? 6.参考资料 1.什么是功能安全FS? (1)功能安全的发展过…