【深度学习笔记】08 欠拟合和过拟合

news2024/11/28 14:48:25

08 欠拟合和过拟合

      • 生成数据集
      • 对模型进行训练和测试
      • 三阶多项式函数拟合(正常)
      • 线性函数拟合(欠拟合)
      • 高阶多项式函数拟合(过拟合)

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

生成数据集

给定 x x x,我们将使用以下三阶多项式来生成训练和测试数据的标签:

y = 5 + 1.2 x − 3.4 x 2 2 ! + 5.6 x 3 3 ! + ϵ  where  ϵ ∼ N ( 0 , 0. 1 2 ) . y = 5 + 1.2x - 3.4\frac{x^2}{2!} + 5.6 \frac{x^3}{3!} + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.1^2). y=5+1.2x3.42!x2+5.63!x3+ϵ where ϵN(0,0.12).

噪声项 ϵ \epsilon ϵ服从均值为0且标准差为0.1的正态分布。
在优化的过程中,我们通常希望避免非常大的梯度值或损失值。
这就是我们将特征从 x i x^i xi调整为 x i i ! \frac{x^i}{i!} i!xi的原因,
这样可以避免很大的 i i i带来的特别大的指数值。
我们将为训练集和测试集各生成100个样本。

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集的大小
true_w = np.zeros(max_degree)  # 分配大量空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

features = np.random.normal(size = (n_train + n_test, 1))
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train + n_test)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale = 0.1, size = labels.shape)

存储在poly_features中的单项式由gamma函数重新缩放,
其中 Γ ( n ) = ( n − 1 ) ! \Gamma(n)=(n-1)! Γ(n)=(n1)!
从生成的数据集中查看一下前2个样本,
第一个值是与偏置相对应的常量特征。

# NumPy ndarray转换为tensor
true_w, features, poly_features, labels = [torch.tensor(x, dtype=
    torch.float32) for x in [true_w, features, poly_features, labels]]
features[:2], poly_features[:2, :], labels[:2]
(tensor([[-1.2565],
         [-2.2676]]),
 tensor([[ 1.0000e+00, -1.2565e+00,  7.8936e-01, -3.3060e-01,  1.0385e-01,
          -2.6096e-02,  5.4648e-03, -9.8091e-04,  1.5406e-04, -2.1508e-05,
           2.7024e-06, -3.0868e-07,  3.2321e-08, -3.1238e-09,  2.8036e-10,
          -2.3484e-11,  1.8442e-12, -1.3630e-13,  9.5145e-15, -6.2919e-16],
         [ 1.0000e+00, -2.2676e+00,  2.5709e+00, -1.9433e+00,  1.1016e+00,
          -4.9960e-01,  1.8881e-01, -6.1164e-02,  1.7337e-02, -4.3681e-03,
           9.9049e-04, -2.0418e-04,  3.8583e-05, -6.7300e-06,  1.0901e-06,
          -1.6479e-07,  2.3354e-08, -3.1151e-09,  3.9243e-10, -4.6835e-11]]),
 tensor([ -1.1486, -17.5782]))

对模型进行训练和测试

首先实现一个函数来评估模型在给定数据集上的损失

def evaluate_loss(net, data_iter, loss):  #@save
    """评估给定数据集上模型的损失"""
    metric = d2l.Accumulator(2)  # 损失的总和,样本数量
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]

定义训练函数

def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss()
    input_shape = train_features.shape[-1]
    # 不设置偏置,因为我们已经在多项式中实现了它
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    for epoch in range(num_epochs):
        d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
                                     evaluate_loss(net, test_iter, loss)))
    print('weight:', net[0].weight.data.numpy())

三阶多项式函数拟合(正常)

首先使用三阶多项式函数,它与数据生成函数的阶数相同。
结果表明,该模型能有效降低训练损失和测试损失。
学习到的模型参数也接近真实值 w = [ 5 , 1.2 , − 3.4 , 5.6 ] w = [5, 1.2, -3.4, 5.6] w=[5,1.2,3.4,5.6]

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])
weight: [[ 5.016911   1.2140008 -3.4260864  5.581189 ]]

在这里插入图片描述

线性函数拟合(欠拟合)

当用于拟合非线性模式(如这里的三阶多项式函数)时,线性模型容易欠拟合。

# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],
      labels[:n_train], labels[n_train:])
weight: [[3.6582568 3.967709 ]]

在这里插入图片描述

高阶多项式函数拟合(过拟合)

尝试使用一个阶数过高的多项式来训练模型。在这种情况下,没有足够的数据用于学习高阶系数应该具有接近于零的值。因此,这个过于复杂的模型会轻易受到训练数据中噪声的影响。虽然训练损失可以有效地降低,但测试损失仍然很高。结果表明,复杂模型对数据造成了过拟合。

# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],
      labels[:n_train], labels[n_train:], num_epochs=1500)
weight: [[ 5.0074177   1.2689897  -3.3731365   5.1744385  -0.28380862  1.3880066
   0.28433597  0.23886798  0.10456859 -0.01008706 -0.13444044 -0.00965116
  -0.09757714 -0.09527045  0.21342376  0.13767722 -0.08204057 -0.10666441
   0.12475184  0.21017507]]

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1286533.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

动态类型语言与静态类型语言的对比与比较

编程语言可以根据类型系统和类型检查时机分为动态编程语言和静态编程语言两大类,它们在运行时的代码检查方式、变量类型的使用方式等方面有很大的区别。这一块你知道吗? 本文将为您详细讲解两种编程语言的优缺点,以及它们的应用场景。 动态编…

洗地机好用吗?口碑好的洗地机有哪些?

自从洗地机开始引入市场以来,它一直受到人们的关注。它在解放家庭清洁劳动力和提供快速方便的清洁方面表现出色,超越了多年来传统的拖把清洁方式。越来越多的人选择使用洗地机来完成家庭清洁任务。如果你也对洗地机产生了浓厚的兴趣,并想购买…

【精选】ATKCK红队评估实战靶场二 (超详细过程思路)

🍬 博主介绍👨‍🎓 博主介绍:大家好,我是 hacker-routing ,很高兴认识大家~ ✨主攻领域:【渗透领域】【应急响应】 【ATK&CK红队评估实战靶场】 【VulnHub靶场复现】【面试分析】 &#x1f…

智能优化算法应用:基于人工电场算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于人工电场算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于人工电场算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.人工电场算法4.实验参数设定5.算法结果6.参考…

IO / day03 作业

1. 使用文件IO完成对图像的读写操作 代码 #include<myhead.h>int main(int argc, const char *argv[]) {int fd-1;if((fd open("./bird.bmp", O_RDWR)) -1){perror("fopen error");return -1;}//读取该图片的大小&#xff0c;需要将光标向后偏移…

写论文焦虑?No,免费AI写作大师来帮你

先来看1分钟的视频&#xff0c;对于要写论文的你来说&#xff0c;绝对有所值&#xff01; 还在为写论文焦虑&#xff1f;免费AI写作大师来帮你三步搞定 第一步&#xff1a;输入关键信息 第二步&#xff1a;生成大纲 稍等片刻后&#xff0c;专业大纲生成&#xff08;由于举例&am…

python 堆与栈

【一】堆与栈 【 1 】简介 栈&#xff08;stack&#xff09;&#xff0c;有些地方称为堆栈&#xff0c;是一种容器&#xff0c;可存入数据元素、访问元素、删除元素&#xff0c;它的特点在于只能允许在容器的一端&#xff08;称为栈顶端指标&#xff0c;英语&#xff1a;top&a…

智能无人售货奶柜,引领便捷时代

智能无人售货奶柜&#xff0c;引领便捷时代 在当今快节奏的生活中&#xff0c;人们对购物的需求变得越来越迫切。同时&#xff0c;随着科技的进步&#xff0c;无人售货柜作为一种创新的销售模式逐渐受到人们的关注和喜爱。其中&#xff0c;智能无人售货奶柜以其便捷高效的特点成…

OGG实现Oracle19C到postgreSQL14的实时同步

&#x1f4e2;&#x1f4e2;&#x1f4e2;&#x1f4e3;&#x1f4e3;&#x1f4e3; 哈喽&#xff01;大家好&#xff0c;我是【IT邦德】&#xff0c;江湖人称jeames007&#xff0c;10余年DBA及大数据工作经验 一位上进心十足的【大数据领域博主】&#xff01;&#x1f61c;&am…

TP5使用Composer安装phpoffice/phpspreadsheet,导出Excel文件

Composer安装 如果你尚未安装Composer&#xff0c;请先安装 Composer。Composer是PHP的依赖管理工具&#xff0c;它可以方便地安装和管理项目中的第三方库。 安装phpoffice/phpspreadsheet&#xff1a; 触发控制器里面的方法 wdjzdc() 在控制中引入 use PhpOffice\PhpSpread…

从零开始的Spring Cloud Gateway指南:构建强大微服务架构

目录 一、 什么是Gateway&#xff1f;1. 网关的由来2. 网关的作用3. 网关的技术实现 二、如何搭建一个简易网关服务1. 引入依赖2. 配置yml文件 三、进阶话题&#xff1a;过滤器和路由配置1. gateway的执行原理2. 路由断言工厂: Predicate Factory3. 网关过滤器&#xff1a;Gate…

Cannot resolve symbol ‘ActivityResultLauncher‘ 报错处理方法

修改 app/build.gradle implementation ‘androidx.appcompat:appcompat:1.2.0’ 为 implementation ‘androidx.appcompat:appcompat:1.4.0’

【模型报错记录】‘PromptForGeneration‘ object has no attribute ‘can_generate‘

通过这个连接中的方法解决&#xff1a; “PromptForGeneration”对象没有属性“can_generate” 期刊 #277 thunlp/OpenPrompt GitHub的 问题描述&#xff1a;在使用model.generate() 的时候报错&#xff1a;PromptForGeneration object has no attribute can_generate 解决方法…

在Spring Cloud中使用组件Zuul网关,并注册到Eureka中去

在上一篇中&#xff0c;我们搭建了Spring Cloud的父子模块&#xff0c;并实现了一个Eureka子模块的启动&#xff0c;可以通过浏览器地址去访问Eureka主页了&#xff0c;相信了解过的童鞋应该看到&#xff0c;主页上并未有任何服务去注册&#xff0c;那么我们就在这篇&#xff0…

【1】一文读懂PyQt简介和环境搭建

目录 1. PyQt简介 1.1. Qt 1.2. PyQt 1.3. 关于PyQt和PySide 2. 通过pip安装PyQt5 3. 无法运行处理 4. VSCode配置PYQT插件 PyQt官网:Riverbank Computing | Introduction 1. PyQt简介 PyQt是一套Python的GUI开发框架,即图形用户界面开发框架. Python中经常使用的GUI…

NXT : 十年源头代码的回溯与展望

11月28日,区块链领域的开山鼻祖之一NXT上线十周年。十年弹指一挥间,从2013年最初基于密码学证明思想构建的NXT,到今日仍在不断拓展与优化的Ardor(NXT 2.0)生态。我们不妨利用这个契机,回溯一下这个源头级公链的创新之路。 2013年,在区块链意味着复制比特币代码并调整挖矿算法的…

Verilog基础:$time、$stime和$realtime系统函数的使用

相关阅读 Verilog基础https://blog.csdn.net/weixin_45791458/category_12263729.html $time、 $stime和$realtime这三个系统函数提供了返回当前仿真时间方法。注意&#xff0c;这里的仿真时间的最小分辨能力是由仿真时间精度决定的&#xff0c;简单来说&#xff0c;可以理解为…

在线测量大尺寸管材的测径仪有哪些?

工业高速发展的背后&#xff0c;离不开与之匹配的高端设备作为科研的支撑。品质检测仪器也在随着现代科技的发展而不断变化&#xff0c;随着科技的进步&#xff0c;各种大口径的管材、管道被生产制造出来&#xff0c;而对其外径尺寸的检测则因口径范围大而使得很少有仪器能进行…

js选中起始时间使用标准时间毫秒值计算一年后的当前少一天的日期(并考虑闰年)

js选中起始时间使用标准时间毫秒值计算一年后的当前少一天的日期 实际代码里面带入默认日期’20230301’这个特殊日期&#xff0c;因为下一年的当前日期少一天为闰年的2月会有29天&#xff0c;使用特殊值校验代码效果图 HTML部分代码 <el-button click"chengTime()&q…

多线程--11--ConcurrentHashMap

ConcurrentHashMap与HashMap等的区别 HashMap线程不安全 我们知道HashMap是线程不安全的&#xff0c;在多线程环境下&#xff0c;使用Hashmap进行put操作会引起死循环&#xff0c;导致CPU利用率接近100%&#xff0c;所以在并发情况下不能使用HashMap。 ConcurrentHashMap 主…