深度学习入门(十) 模型选择、过拟合和欠拟合

news2025/1/27 6:39:59

深度学习入门(十) 模型选择、过拟合和欠拟合

  • 前言
  • 模型选择
    • 例子:预测谁会偿还贷款?
    • 训练误差和泛化误差
    • 验证数据集和测试数据集
    • K-则交叉验证
    • 总结
  • 过拟合和欠拟合
    • 模型容量
      • 模型容量的影响
      • 估计模型容量
    • VC维
      • 线性分类器的VC维
      • VC维的用处
    • 数据复杂度
    • 总结
  • 代码展示
  • QA:

前言

核心内容来自博客链接1博客连接2希望大家多多支持作者
本文记录用,防止遗忘

模型选择

例子:预测谁会偿还贷款?

银行雇你来调查谁会偿还贷款

  • 你得到了100个申请人的信息
  • 其中五个人在3年内违约了

发现:
你发现所有的5个人在面试的时候都穿了蓝色衬衫
你的模型也发现了这个强信号
这会有什么问题?

训练误差和泛化误差

训练误差:模型在训练数据上的误差
泛化误差:模型在新数据上的误差
例子:根据摸考成绩来预测未来考试分数

  • 在过去的考试中表现很好(训练误差)不代表未来考试—定会好(泛化误差)
  • 学生A通过背书在摸考中拿到很好成绩
  • 学生B知道答案后面的原因

验证数据集和测试数据集

验证数据集:一个用来评估模型好坏的数据集

  • 例如拿出50%的训练数据
  • 不要跟训练数据混在一起(常犯错误)
测试数据集:只用一次的数据集。例如
  • 未来的考试
  • 我出价的房子的实际成交价
  • 用在Kaggle私有排行榜中的数据集

K-则交叉验证

在没有足够多数据时使用(这是常态)
算法:

  • 将训练数据分割成K块
  • For i = 1,...,K
  • 使用第i块作为验证数据集,其余的作为训练数据集·报告K个验证集误差的平均
  • 常用:K=5或10

总结

  • 训练数据集:训练模型参数
  • 验证数据集:选择模型超参数
  • 非大数据集上通常使用k-则交叉验证

过拟合和欠拟合

在这里插入图片描述
在这里插入图片描述

模型容量

拟合各种函数的能力:

  • 低容量的模型难以拟合训练数据
  • 高容量的模型可以记住所有的训练数据

模型容量的影响

在这里插入图片描述

估计模型容量

难以在不同的种类算法之间比较

  • 例如数模型和神经网络

给定一个模型种类,将有两个主要因素

  • 参数的个数
  • 参数值的选择范围

VC维

  • 统计学习理论的一个核心思想
  • 对于一个分类模型,VC等于一个最大的数据集的大小,不管如何给定标号,都存在一个模型来对它进行完美分类

线性分类器的VC维

2维输入的感知机,VC维=3
能够分类任何三个点,但不是4个(xor)
在这里插入图片描述支持N维输入的感知机的VC维是N+1
一些多层感知机的VC维 O ( N l o g 2 N ) O(Nlog_2N) O(Nlog2N)

VC维的用处

提供为什么一个模型好的理论依据

  • 它可以衡量训练误差和泛化误差之间的间隔
但深度学习中很少使用
  • 衡量不是很准确
  • 计算深度学习模型的VC维很困难

数据复杂度

多个重要因素

  • 样本个数
  • 每个样本的元素个数
  • 时间、空间结构
  • 多样性

总结

  • 模型容量需要匹配数据复杂度,否则可能导致欠拟合和过拟合
  • 统计机器学习提供数学工具来衡量模型复杂度
  • 实际中一般靠观察训练误差和验证误差

代码展示

通过多项式拟合来交互地探索这些概念

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

使用以下三阶多项式来生成训练和测试数据的标签:

在这里插入图片描述

max_degree = 20
n_train, n_test = 100, 100
true_w = np.zeros(max_degree)
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)

看一下前2个样本

true_w, features, poly_features, labels = [
    torch.tensor(x, dtype=torch.float32)
    for x in [true_w, features, poly_features, labels]]

features[:2], poly_features[:2, :], labels[:2]

输出:

(tensor([[-1.9729],
         [-0.1230]]),
 tensor([[ 1.0000e+00, -1.9729e+00,  1.9462e+00, -1.2799e+00,  6.3127e-01,
          -2.4909e-01,  8.1905e-02, -2.3084e-02,  5.6929e-03, -1.2480e-03,
           2.4621e-04, -4.4159e-05,  7.2602e-06, -1.1018e-06,  1.5527e-07,
          -2.0422e-08,  2.5182e-09, -2.9225e-10,  3.2032e-11, -3.3262e-12],
         [ 1.0000e+00, -1.2304e-01,  7.5698e-03, -3.1047e-04,  9.5502e-06,
          -2.3502e-07,  4.8195e-09, -8.4715e-11,  1.3030e-12, -1.7813e-14,
           2.1918e-16, -2.4517e-18,  2.5138e-20, -2.3793e-22,  2.0911e-24,
          -1.7153e-26,  1.3191e-28, -9.5474e-31,  6.5263e-33, -4.2264e-35]]),
 tensor([-11.1878,   4.9593]))

实现一个函数来评估模型在给定数据集上的损失

def evaluate_loss(net, data_iter, loss):  
    """评估给定数据集上模型的损失。"""
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]

定义训练函数

def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss()
    input_shape = train_features.shape[-1]
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    for epoch in range(num_epochs):
        d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(
                net, train_iter, loss), evaluate_loss(net, test_iter, loss)))
    print('weight:', net[0].weight.data.numpy())

三阶多项式函数拟合(正态)

train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])

输出:
在这里插入图片描述
线性函数拟合(欠拟合)

train(poly_features[:n_train, :2], poly_features[n_train:, :2],
      labels[:n_train], labels[n_train:])

输出:
在这里插入图片描述
高阶多项式函数拟合(过拟合)

train(poly_features[:n_train, :], poly_features[n_train:, :],
      labels[:n_train], labels[n_train:], num_epochs=1500)

在这里插入图片描述

QA:

1、SVM应用于分类想比神经网络的缺点?
SVM对于数据样本大的情况不好。SVM可以调的东西不多,比较平滑。
神经网络优点在于它是一个“语言”,神经网络很灵活,可编程线强。虽然SVM数学解释好,但是可解决的问题少。

2、K则交叉验证在大数据集的深度学习中应用不多,因为训练成本高。一般都是在数据量不够的情况下才用。

3、k则交叉验证中,k的确定主要取决的能够承受的计算成本。

4、模型参数≠超参数
模型参数:W,b
超参数:可选的模型参数之外的参数

5、
方案1:k则交叉验证确定超参数,再在整个数据集上训练一次
方案2:用k则交叉验证中的最好的参数
方案3:用k则交叉验证中的k个数据值在结果的均值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1509.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[云原生之k8s] Kubernetes原理

引言 单机容器编排:docker-compose 容器集群编排:docker swarm、mesosmarathon、kubernetes 应用编排:ansible 一、Kubernetes是什么? Kubernetes的缩写为:K8S,这个缩写是因为k和s之间有八个字符的关系…

线段树模板

好文分享:【数据结构】线段树(Segment Tree) - 小仙女本仙 - 博客园 线段树和树状数组的基本功能都是在某一满足结合律的操作(比如加法,乘法,最大值,最小值)下,O(logn)的时间复杂度内修改单个元…

Python回归预测建模实战-支持向量机预测房价(附源码和实现效果)

机器学习在预测方面的应用,根据预测值变量的类型可以分为分类问题(预测值是离散型)和回归问题(预测值是连续型),前面我们介绍了机器学习建模处理了分类问题(具体见之前的文章)&#…

x86 --- 任务隔离特权级保护

程序是记录在载体上的数据和指令。 程序正在执行时的一个副本叫做任务 所有段描述符都放在GDT --> 不做区分。 内核程序(任务)所占段在GDT中,用户程序(任务)所占段在LDT中 --> 做区分。 每个任务都有自己独立的…

【无标题】

第1章 概述 本章主要内容: 互联网的概念(标准化)、组成、发展历程;电路交换的基本概念、分组交换的原理;计算机网络的分类、性能指标及两种体系结构。 重点掌握: 在计算机网络分层模型中,网…

7、GC日志详解

目录如何分析GC日志参数配置程序运行GC日志打印解析GC日志数据分析指定其他垃圾收集器CMSG1GC分析工具JVM参数汇总查看命令如何分析GC日志 参数配置 对于java应用我们可以通过一些配置把程序运行过程中的gc日志全部打印出来,然后分析gc日志得到关键性指标&#xff…

目标检测算法——遥感影像数据集资源汇总(附下载链接)

关注”PandaCVer“公众号 深度学习资料,第一时间送达 目录 一、用于 2-5 分类问题 1.UCAS-AOD 遥感影像数据集 2.Inria Aerial Image Labeling Dataset 3.RSOD-Dataset 物体检测数据集 二、用于 5-10 分类问题 1.RSSCN7 DataSet 遥感图像数据集 2.NWPU…

孙宇晨接受韩国媒体专访:熊市受宏观经济的不确定性影响

10月27日至10月29日,韩国釜山备受关注的大型区块链活动 2022 釜山区块链周(BWB 2022)在釜山会展中心(BEXCO)举行。韩国区块链媒体TokenPost 对出席活动的波场TRON创始人孙宇晨进行了专访。10月28日,该媒体发…

Nginx快速入门部署前端项目

目录 一,Nginx简介 1.1 负载均衡 演示 1.1.2 安装nginx 再复制一份一样的tomcat并修改端口号 打开两个tomcat的服务 打开防火墙中的8081端口 修改Nginx配置 重启Nginx服务,让配置生效 1.2 反向代理 Nginx项目部署 1.确保前端项目能用 2.将前台项目…

看过来,Windows 11 Insider Preview 25231.1000推送啦!

微软于近日凌晨发布新的Windows 11内部预览版系统,版本号为25231.1000,该系统对平板任务栏体验进行了改进,修复了系统托盘、设置等问题。下面一起来看看完整的更新内容。 更新日志 TL;速度三角形定位法(dead reckoning…

【ASM】字节码操作 转换已有的类 清空方法体

1.概述 在文章:【ASM】字节码操作 转换已有的类 移除Instruction 移除NOP 中我们学会了如何移除NOP。 本章我们将学习如何清空方法体。 1.1 如何清空方法体 在有些情况下,我们可能想清空整个方法体的内容,那该怎么做呢?其实,有两个思路。 ●第一种思路,就是将instructi…

Spring中事务的传播机制以及REQUIRED、REQUIRES_NEW、NESTED区别以及代码演示

​📒个人主页:热爱生活的李📒 ​❤️感谢大家阅读本文,同时欢迎访问本人主页查看更多文章​❤️ 🙏本人也在学习阶段,如若发现问题,请告知,非常感谢🙏 事务隔离级别demo理…

[计算机网络]第一章 概述 -- 1.1 计算机网络在信息时代中的作用 1.2 互联网概述

文章目录1.1 计算机网络在信息时代中的作用1.2 互联网概述1.2.1 网络的网络1.2.2 互联网基础结构发展的三个阶段第一阶段第二阶段第三阶段1.2.3 互联网标准化工作1.1 计算机网络在信息时代中的作用 21世纪是以网络为核心的信息时代,21世纪的重要重要特征&#xff1a…

小侃设计模式(二)-单例模式

1.概述 设计模式在粒度和抽象层次上各不相同,因此从不同的角度,分类形式也不同,目前存在两种较为经典的划分方式,即根据模式作用的范围、模式的目的来划分。根据模式主要是用于类还是用于对象,可将其划分为类模式和对…

【JavaWeb】Tomcat

1.JavaWeb是指所有通过java语言编写可以通过浏览器访问的程序的总称 请求是指客户端给服务器发送数据 响应是指服务器给客户端回传数据 2.Web资源按实现的技术和呈现的效果的不同,又分为静态资源和动态资源两种. 静态资源:html css js txt mp4视频 jpg图片 动态资源:jsp页面 se…

前端工程化基建探索:从内部机制和核心原理了解npm

大厂技术 坚持周更 精选好文 前言 本文【前端工程化基建探索】的第2篇,上一篇 前端工程化基建探索(1)前端大佬,你好! 当我们拉取一个前端工程化项目,都会通过npm/Yarn/pnpm 管理工具来安装项目的依赖&am…

大学解惑06 - 要求输入框内只能输入2位以内小数,怎么做?

请听题:有一个输入框,准备用于计算使用,要求点击“校验”按钮的时候进行验证,必须输入数字,并且只能是2位以内的小数,如果输入不合法,请给出提示,如果输入合法通过验证,则…

又是一篇教你摸鱼的文章,用Python实现自动发送周报给老板

前言 有没有哪个同志跟我一样,每周都要写工作周报 像我这种记性不好的,一个月四周忘记三次 索性就用Python写个小工具,让它每周帮我给老板发周报~ Github: Weekday 小工具 提出目标 源码.资料.素材.点击领取即可 想有一个工具能发邮件 目…

ARM 汇编基础

一、ARM架构 ARM芯片属于精简指令集计算机(RISC:Reduced Instruction Set Computing),它所用的指令比较简单,有如下特点: 对内存只有读、写指令对于数据的运算是在CPU内部实现使用RISC指令的CPU复杂度小一点,易于设计…

WebShell箱子简介与原理

今天继续给大家介绍渗透测试相关知识,本文主要内容是WebShell箱子简介与原理。 免责声明: 本文所介绍的内容仅做学习交流使用,严禁利用文中技术进行非法行为,否则造成一切严重后果自负! 再次强调:严禁对未授…