Pytorch 多项式拟合

news2024/10/6 12:30:04

目录

1、训练误差和泛化误差

2、独立同分布假设

3、欠拟合和过拟合

4、多项式回归


1、训练误差和泛化误差

训练误差(training error)是指, 模型在训练数据集上计算得到的误差。 泛化误差(generalization error)是指, 模型应用在同样从原始样本的分布中抽取的无限多数据样本时,模型误差的期望。

一般来说泛化误差无法进行准确计算,实际上我们只能通过将模型应用于一个独立测试集来估计泛化误差,该测试集由随机选取的、未曾在训练集中出现的数据样本构成。

2、独立同分布假设

在深度学习应用过程中,我们假设训练数据和测试数据都是从相同的分布中独立提取的。这通常被称为独立同分布假设,这意味着对数据采样过程中没有进行记忆,也就是说,抽取第2个样本和第3个样本的相关性,并不比抽取第2个样本和第200万个样本的相关性强。

3、欠拟合和过拟合

训练误差和验证误差都很严重, 但它们之间仅有一点差距。 如果模型不能降低训练误差,这可能意味着模型过于简单(即表达能力不足), 无法捕获试图学习的模式。 此外,由于我们的训练和验证误差之间的泛化误差很小, 我们有理由相信可以用一个更复杂的模型降低训练误差。 这种现象被称为欠拟合(underfitting)。

当我们的训练误差明显低于验证误差时要小心, 这表明严重的过拟合(overfitting)

值得注意的是,过拟合并不总是一件坏事。特别是在深度学习领域,最好的预测模型在训练数据上的表现往往比在验证数据上好得多。最终我们更关心验证误差,而不是训练误差和验证误差之间的差距。

4、多项式回归

我们使用多项式回归来验证上述概念

导入所需的库

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

使用下述三阶多项式来生成训练和测试数据的标签:

 

噪声项∈服从均值为0且标准差为0.1的正态分布。 在优化的过程中,我们通常希望避免非常大的梯度值或损失值。  我们将为训练集和测试集各生成100个样本。

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6]) #多项式的前几项

features = np.random.normal(size=(n_train + n_test, 1)) # x (200,1)
np.random.shuffle(features)# 打乱x
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1)) #分别求x的0-19次方
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)! 伽玛函数
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w) # 进行乘法,每一行都是一个数据
labels += np.random.normal(scale=0.1, size=labels.shape) #加上噪声项

将ndarray转换为tensor

# NumPy ndarray转换为tensor
true_w, features, poly_features, labels = [torch.tensor(x, dtype=torch.float32) for x in [true_w, features, poly_features, labels]]

features[:2], poly_features[:2, :], labels[:2]

自定义函数评估模型在给定数据集上的损失

def evaluate_loss(net, data_iter, loss):  #@save
    """评估给定数据集上模型的损失"""
    metric = d2l.Accumulator(2)  # 损失的总和,样本数量
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]

定义模型训练函数

def train(train_features, test_features, train_labels, test_labels,num_epochs=400):
    loss = nn.MSELoss() # 均方误差损失函数
    input_shape = train_features.shape[-1]
    # 不设置偏置,因为我们已经在多项式中实现了它 第一项 5
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False)) #一个线性层
    batch_size = min(10, train_labels.shape[0])
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    for epoch in range(num_epochs):
        d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
                                     evaluate_loss(net, test_iter, loss)))
    print('weight:', net[0].weight.data.numpy())

①正常情况

使用四个参数去拟合三阶多项式, 结果表明,该模型能有效降低训练损失和测试损失。 学习到的模型参数也接近真实值w=[5,1.2,−3.4,5.6]。

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])

②欠拟合

 使用两个参数去拟合三阶多项式

# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])

经历过多轮迭代损失函数仍然很高,且结果并不好,模型欠拟合。

③过拟合

使用全部参数去拟合三阶多项式,虽然训练损失可以有效地降低,但测试损失仍然很高。 结果表明,复杂模型对数据造成了过拟合。

# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],
      labels[:n_train], labels[n_train:], num_epochs=3000)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/149461.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【OpenCV 例程 300篇】255.OpenCV 实现图像拼接

『youcans 的 OpenCV 例程300篇 - 总目录』 【youcans 的 OpenCV 例程 300篇】255.OpenCV 实现图像拼接 6.2 OpenCV 实现图像拼接 OpenCV中图像的数据结构是Numpy数组,使用切片方法可以实现图像的裁剪,使用数组堆叠方法可以实现图像的拼接。 Numpy 函数…

K8S Replication Controller 示例

K8S Replication Controller Replication Controller可确保在任何时间运行指定数量的pod副本。换句话说,ReplicationController确保一个pod或一组同类pod始终处于可用状态。 可以理解为Replication Controller (RC)是基于 K8S Pod对象之上的…

PR采购申请启用灵活工作流配置

目录 1. 检查系统基础工作流配置 2. 激活标准场景模版 WS02000458 3. 维护收件箱操作按钮文本 4. 维护任务代理及激活事件 5. 配置Inbox审批页面 6. PR灵活工作流启动配置 7. 为流程设置代理人 8. 配置工作流场景 9. 可扩展部分 官方Help文档路径: SAP He…

算法_杨氏矩阵_杨氏矩阵算法_剑指offer

目录 一、问题描述 二、问题分析 三、算法设计 ​四、代码实现 一、问题描述 有一个数字矩阵,矩阵的每行从左到右是递增的,矩阵从上到下是递增的,请编写程序在这样的矩阵中查找某个数字是否存在。 要求:时间复杂度小于O(N);…

【目标检测】Cascade RCNN中的两个常见问题

一、关于mismatch问题 在training阶段和inference阶段使用不同的阈值很容易导致mismatch,什么意思呢? 在training阶段,由于给定了GT,所以可以把与GT的IoU大于阈值的proposals作为正样本,这些正样本参与之后的bbox回归…

【C语言】深入浅出讲解函数栈帧(超详解 | 建议收藏)

🚀write in front🚀 📝个人主页:认真写博客的夏目浅石. 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝 📣系列专栏:凡人修C传 💬总结:希望你看完之后&…

基于机器学习与协同过滤的图书管理推荐系统

基于机器学习与协同过滤的图书推荐系统 一、系统结构图 二、Demo示例 完整源码可联系博主微信【1257309054】 点我跳转 三、K-means聚类机器学习推荐算法 1、原理 从数据库中 1、首先获取书籍类别 2、获取用户注册时勾选喜欢的类别,勾选的为1,否则…

让 Java Agent 在 Dragonwell 上更好用

本文是《容器中的Java》系列文章之 3/n,欢迎关注后续连载 😃 。1 背景 随着越来越多的云原生微服务应用的大规模部署,大家对微服务治理的能力需求越来越强。 Java Agent技术能够让业务专注于业务逻辑,与此同时,中间…

kali渗透测试到底该如何学?

1、渗透测试是什么? 渗透测试,是为了证明网络防御按照预期计划正常运行而提供的一种机制。渗透测试是通过各种手段对目标进行一次渗透(攻击),通过渗透来测试目标的安全防护能力和安全防护意识。打个比方:比…

手写spring12(把aop动态代理整合到spring生命周期)

文章目录目标设计项目结构四、实现1、定义Advice拦截器链2、定义Advisor访问者3、方法前置拦截器——MethodBeforeAdviceInterceptor4、代理工厂——ProxyFactory5、融入Bean生命周期的自动代理创建者——InstantiationAwareBeanPostProcessor 、DefaultAdvisorAutoProxyCreato…

什么是3dMax中的“块”?如何在3dMax中使用“块”?

3dMax 块简介 3dMax 是一款用于设计 3d 模型的软件,在 3d 建模图形专业人士中最受欢迎。我们可以在此软件中导入不同类型的预设计 3d 模型,以简化我们的工作。块是不同类型的 3d 模型,您可以从互联网上下载,然后将它们导入到 3dMax 软件的项目工作中,以节省您的时间。在本…

Python和OpenCV创建超快的“for”像素循环

这篇博客将介绍如何使用Python和OpenCV创建超快的“for”像素循环(逐像素循环),即Cython快速优化for循环; 使用Python和OpenCV逐像素循环图像是一个非常缓慢的操作,即使图像在内部由NumPy数组表示。 为什么会这样&am…

C语言-指针进阶-qsort函数的学习与模拟实现(9.3)

目录 思维导图: 回调函数 qsort函数介绍 模拟实现qsort 写在最后: 思维导图: 回调函数 什么是回调函数? 回调函数是一个通过函数指针调用的函数。 将一个函数指针作为参数传递给一个函数,当这个指针被用来调用…

57、JDBC和连接池

目录 一、JDBC基本介绍 二、JDBC快速入门 三、JDBC API 1、ResultSet [结果集] 2、Statement 3、PreparedStatement 4、DriverManager 四、封闭JDBCUtils 五、事务 六、批处理 七、数据库连接池 4、数据库连接池种类 (1) c3p0数据库连接池&…

MacBookPro 遇到pip: command not found问题的解决

学习Pyhton的时候,需要安装第三方插件pyecharts,执行以下命令: pip install pyecharts总是报错 pip: command not found 我很郁闷,于是上网搜索尝试各种命令, 命令1:curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py 命…

Codeforces Round 507(div. 1) C(分类讨论,并查集)

题目链接: Problem - C - Codeforceshttps://codeforces.com/contest/1039/problem/C 题意: 计算机网络由个服务器组成,每个服务器有到范围内的加密秘钥。设是分配给第i台服务器的加密密钥。对服务器通过数据通信通道直接连接。由于加密算…

高性能分布式缓存Redis-第二篇章

高性能分布式缓存Redis-第二篇章一、持久化原理1.1、持久化流程(落盘)1.2、RDB详解1.2.1、介绍1.2.2、触发&原理1.2.3、实现1.2.4、RDB总结1.3、AOF详解1.3.1、概念1.3.2、AOF 持久化的实现1.3.2、开启1.3.4、命令追加1.3.5、文件写入和同步&#xf…

SQL 别名

通过使用 SQL,可以为表名称或列名称指定别名。 SQL 别名 通过使用 SQL,可以为表名称或列名称指定别名。 基本上,创建别名是为了让列名称的可读性更强。 列的 SQL 别名语法 SELECT column_name AS alias_name FROM table_name; 表的 SQL …

dubbo学习笔记4(小d课堂)

dubbo高级特性 服务分组及其配置 我们再来创建一个实现类: 接下来我们在xml中去进行配置: 现在我们去运行看是否会有错误呢? 我们有两个服务实现类,那运行的时候到底执行哪个呢? 但是我们想的是可以指定执行哪个实现…

设计模式——访问者模式

访问者模式一、基本思想二、结构图一、基本思想 将作用于某种数据结构中的各元素的操作分离出来封装成独立的类,使其在不改变数据结构的前提下可以添加作用于这些元素的新的操作,为数据结构中的每个元素提供多种访问方式。它将对数据的操作与数据结构进…