【PyTorch】模型选择、欠拟合和过拟合

news2025/1/12 20:38:18

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现
      • 2.2.1. 完整代码
      • 2.2.2. 输出结果

1. 理论介绍

  • 将模型在训练数据上拟合的比在潜在分布中更接近的现象称为过拟合, 用于对抗过拟合的技术称为正则化
  • 训练误差和验证误差都很严重, 但它们之间差距很小。 如果模型不能降低训练误差,这可能意味着模型过于简单(即表达能力不足),无法捕获试图学习的模式。 这种现象被称为欠拟合
  • 训练误差是指模型在训练数据集上计算得到的误差。
  • 泛化误差是指模型应用在同样从原始样本的分布中抽取的无限多数据样本时,模型误差的期望。我们永远不能准确地计算出泛化误差,在实际中,我们只能通过将模型应用于一个独立的测试集来估计泛化误差, 该测试集由随机选取的、未曾在训练集中出现的数据样本构成。
  • 影响模型泛化的因素
    • 可调整参数的数量。当可调整参数的数量(有时称为自由度)很大时,模型往往更容易过拟合。
    • 参数采用的值。当权重的取值范围较大时,模型可能更容易过拟合。
    • 训练样本的数量。即使模型很简单,也很容易过拟合只包含一两个样本的数据集,而过拟合一个有数百万个样本的数据集则需要一个极其灵活的模型。
  • 在机器学习中,我们通常在评估几个候选模型后选择最终的模型。 这个过程叫做模型选择。候选模型可能在本质上不同,也可能是不同的超参数设置下的同一类模型。
  • 为了确定候选模型中的最佳模型,我们通常会使用验证集。验证集与测试集十分相似,唯一的区别是验证集是用于确定最佳模型,测试集是用于评估最终模型的性能
  • K K K折交叉验证:当训练数据稀缺时,将原始训练数据分成 K K K个不重叠的子集。 然后执行 K K K次模型训练和验证,每次在 ( K − 1 ) (K-1) (K1)个子集上进行训练, 并在剩余的一个子集(在该轮中没有用于训练的子集)上进行验证。 最后,通过对 K K K次实验的结果取平均来估计训练和验证误差。
  • 引起过拟合的因素
    • 模型复杂度
      模型复杂度
    • 数据集大小
      • 训练数据集中的样本越少,我们就越有可能(且更严重地)过拟合。
      • 给出更多的数据,拟合更复杂的模型可能是有益的; 如果没有足够的数据,简单的模型可能更有用。

2. 实例解析

2.1. 实例描述

使用以下三阶多项式来生成训练和测试数据 y = 5 + 1.2 x − 3.4 x 2 2 ! + 5.6 x 3 3 ! + ϵ  where  ϵ ∼ N ( 0 , 0. 1 2 ) . y = 5 + 1.2x - 3.4\frac{x^2}{2!} + 5.6 \frac{x^3}{3!} + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.1^2). y=5+1.2x3.42!x2+5.63!x3+ϵ where ϵN(0,0.12).并用1阶(线性模型)、3阶、20阶多项式拟合。

2.2. 代码实现

2.2.1. 完整代码

import os
import numpy as np
import math, torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tensorboardX import SummaryWriter
from rich.progress import track

def evaluate_loss(dataloader, net, criterion):
    """评估模型在指定数据集上的损失"""
    num_examples = 0
    loss_sum = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.cuda(), y.cuda()
            loss = criterion(net(X), y)
            num_examples += y.shape[0]
            loss_sum += loss.sum()
        return loss_sum / num_examples

def load_dataset(*tensors):
    """加载数据集"""
    dataset = TensorDataset(*tensors)
    return DataLoader(dataset, batch_size, shuffle=True)


if __name__ == '__main__':
    # 全局参数设置
    num_epochs = 400
    batch_size = 10
    learning_rate = 0.01

    # 创建记录器
    def log_dir():
        root = "runs"
        if not os.path.exists(root):
            os.mkdir(root)
        order = len(os.listdir(root)) + 1
        return f'{root}/exp{order}'
    writer = SummaryWriter(log_dir())

    # 生成数据集
    max_degree = 20             # 多项式最高阶数
    n_train, n_test = 100, 100  # 训练集和测试集大小

    true_w = np.zeros(max_degree+1)
    true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

    features = np.random.normal(size=(n_train + n_test, 1))
    np.random.shuffle(features)
    poly_features = np.power(features, np.arange(max_degree+1).reshape(1, -1))
    for i in range(max_degree+1):
        poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
    labels = np.dot(poly_features, true_w)
    labels += np.random.normal(scale=0.1, size=labels.shape)    # 加高斯噪声服从N(0, 0.01)

    poly_features, labels = [
        torch.as_tensor(x, dtype=torch.float32) for x in [
            poly_features, labels.reshape(-1, 1)]]
    
    def loop(model_degree):
        # 创建模型
        net = nn.Linear(model_degree+1, 1, bias=False).cuda()
        nn.init.normal_(net.weight, mean=0, std=0.01)
        criterion = nn.MSELoss(reduction='none')
        optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

        # 加载数据集
        dataloader_train = load_dataset(poly_features[:n_train, :model_degree+1], labels[:n_train])
        dataloader_test = load_dataset(poly_features[n_train:, :model_degree+1], labels[n_train:])
        
        # 训练循环
        for epoch in track(range(num_epochs), description=f'{model_degree}-degree'):
            for X, y in dataloader_train:
                X, y = X.cuda(), y.cuda()
                loss = criterion(net(X), y)
                optimizer.zero_grad()
                loss.mean().backward()
                optimizer.step()

            writer.add_scalars(f"{model_degree}-degree", {
                "train_loss": evaluate_loss(dataloader_train, net, criterion),
                "test_loss": evaluate_loss(dataloader_test, net, criterion),
            }, epoch)
        print(f"{model_degree}-degree: weights =", net.weight.data.cpu().numpy())

    for model_degree in [1, 3, 20]:
        loop(model_degree)
    writer.close()

2.2.2. 输出结果

权重

  • 采用1阶多项式(线性模型)拟合
    1
  • 采用3阶多项式拟合
    3
  • 采用20阶多项式拟合
    20

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1291893.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DELL EMC unity 存储系统日志收集方法

对于一些非简单的硬件故障,解决故障最有效、最快速的方法就是收集日志,而不是瞎搞。常见的乱搞方法就是 1. reimage系统‘ 2. 更换控制器;3, 重启。 本文详细介绍了图形界面GUI和命令行CLI下如何收集DELL EMC Unity日志的方法和常…

0007Java程序设计-ssm基于微信小程序的在线考试系统

文章目录 **摘要**目 录系统实现开发环境 编程技术交流、源码分享、模板分享、网课分享 企鹅🐧裙:776871563 摘要 网络技术的快速发展给各行各业带来了很大的突破,也给各行各业提供了一种新的管理技术,基于微信小程序的在线考试…

功能测试,接口测试,自动化测试,压力测试,性能测试,渗透测试,安全测试,具体是干嘛的?

软件测试是一个广义的概念,他包括了多领域的测试内容,比如,很多新手可能都听说:功能测试,接口测试,自动化测试,压力测试,性能测试,渗透测试,安全测试等&#…

Goby 漏洞发布| Apache OFBiz webtools/control/xmlrpc 远程代码执行漏洞(CVE-2023-49070)

漏洞名称: Apache OFBiz webtools/control/xmlrpc 远程代码执行漏洞(CVE-2023-49070) English Name:Apache OFBiz webtools/control/xmlrpc Remote Code Execution Vulnerability (CVE-2023-49070) CVSS core: 9.8 影响资产数&…

金蝶云星空使用webapi查询单据附件的主键

文章目录 金蝶云星空使用webapi查询单据附件的主键业务需求详细操作查询单据附件查看账套单据附件查询采购价目表的单据内码和单据体内码查询单据头附件明细webapi查询json返回结果 查询单据明细附件查看账套单据明细附件查询采购价目表的单据内码和单据体内码查询单据体附件明…

phpStudy本地快速搭建网站,实现无公网IP固定地址远程访问

文章目录 [toc]使用工具1. 本地搭建web网站1.1 下载phpstudy后解压并安装1.2 打开默认站点,测试1.3 下载静态演示站点1.4 打开站点根目录1.5 复制演示站点到站网根目录1.6 在浏览器中,查看演示效果。 2. 将本地web网站发布到公网2.1 安装cpolar内网穿透2…

leetcode刷题:611.有效三角形的个数(双指针实现)

题目地址:有效三角形的个数 解决此题时,首先需要知道的是如何判断三个数字是否能够构成三角形。 我们知道,三角形任意两边之和都大于第三边。所以判断三个数字是否能构成三角形需要进行三次比较(最基础的思路) 方法一…

OLED材料市场研究:预计2029年将达到1447亿元

由于技术优势突出,近年来OLED 率先在智能手机、可穿戴等中小尺寸领域的渗透率持续提升。OLED就是有机发光显示技术,其最大特点是每个像素独立自发光,具有非常完美的黑色显示能力,在亮度、色彩、响应速度等方面远胜LCD屏幕&#xf…

视频监控管理平台/智能监测/检测系统EasyCVR智能地铁监控方案,助力地铁高效运营

近日,关于全国44座城市开通地铁,却只有5座城市赚钱的新闻冲上热搜。地铁作为城市交通的重要枢纽,是人们出行必不可少的一种方式,但随着此篇新闻的爆出,大家也逐渐了解到城市运营的不易,那么,如何…

安装Kuboard管理K8S集群

目录 第一章.安装Kuboard管理K8S集群 1.安装kuboard 2.绑定K8S集群,完成信息设定 3.内网安装 第二章.kuboard-spray安装K8S 2.1.先拉镜像下来 2.2.之后打开后,先熟悉功能,注意版本 2.3.打开资源包管理,选择符合自己服务器…

C# WPF上位机开发(数据库sqlite编程)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面我们写过一个会员管理的软件,上面数据保存的方法是用的json保存的。如果数据量比较少,或者是数据类型也不多的时候&…

基于SpringBoot+Thymeleaf+Mybatis实现大学生创新创业管理系统(源码+数据库+项目运行指导文档)

一、项目简介 本项目是一套基于SpringBoot实现大学生创新创业管理系统,主要针对计算机相关专业的正在做bishe的学生和需要项目实战练习的Java学习者。 包含:项目源码、数据库脚本等,该项目可以直接作为bishe使用。 项目都经过严格调试&#…

想转行IT,有前途嘛?30个详细理由中会得到你想要的答案

目录 前言: 一、转行IT的前景 二、IT行业的情况 三、技能需求 四、如何准备转行IT 如果你想转行IT,以下是一些建议: 前言: 转行IT是一个颇具吸引力的选择,尤其在当前社会,IT行业的需求非常广泛。然而…

记录一下本地源码安装部署ThingsBoard可能踩到的坑

使用git下载源码后, 必须运行 mvn clean install -DskipTests这一步很重要, 有很多文件需要初始化, 如果直接放入idea可能存在各种问题, 最好是用命令行执行 初始化时, 可能报错停止, 这个一般是网络问题, 可以尝试修改maven镜像, 这是我成功构建的镜像 <!--阿里云仓库--…

[HITCON 2017]SSRFme perl语言的 GET open file 造成rce

这里记录学习一下 perl的open缺陷 这里首先本地测试一下 发现这里使用open打开 的时候 如果通过管道符 就会实现命令执行 然后这里注意的是 perl 中的get 调用了 open的参数 所以其实我们可以通过管道符实现命令执行 然后这里如果file可控那么就继续可以实现命令执行 这里就…

如何使用Net2FTP轻松部署本地Web文件管理器并远程访问管理内网资源?

文章目录 1.前言2. Net2FTP网站搭建2.1. Net2FTP下载和安装2.2. Net2FTP网页测试 3. cpolar内网穿透3.1.Cpolar云端设置3.2.Cpolar本地设置 4.公网访问测试5.结语 1.前言 文件传输可以说是互联网最主要的应用之一&#xff0c;特别是智能设备的大面积使用&#xff0c;无论是个人…

解锁全球潜力:IT外包解决跨国企业海外分支的IT需求

在全球化的浪潮中&#xff0c;跨国企业为了拓展业务辐射面&#xff0c;经常在世界各地设立海外分支。然而&#xff0c;这些分支机构面临着独特的挑战&#xff0c;其中包括解决复杂的IT需求。为了更高效地应对这些挑战&#xff0c;越来越多的企业正在转向IT外包&#xff0c;以便…

【java】Java程序员,你掌握了多线程吗?

摘要&#xff1a;互联网的每一个角落&#xff0c;无论是大型电商平台的秒杀活动&#xff0c;社交平台的实时消息推送&#xff0c;还是在线视频平台的流量洪峰&#xff0c;背后都离不开多线程技术的支持。在数字化转型的过程中&#xff0c;高并发、高性能是衡量系统性能的核心指…

uniapp实战 —— 自适配高度的可滚动区域(scroll-view的使用技巧)

自定义的顶部导航栏&#xff0c;可参考博文 https://blog.csdn.net/weixin_41192489/article/details/134852124 如图可见&#xff0c;在页面滚动过程中&#xff0c;顶部导航栏和底栏未动&#xff0c;仅中间的内容区域可滚动。 整个页面的高度设置为 100%&#xff0c;并采用 …

echarts双折线图

引用 //反应时长 durationCharts categoryCommonChart(studyBehavior.durationCharts, durationCharts) function categoryCommonChart(odata, dom){var myChart echarts.init(document.getElementById(dom));let oarr []oarr odata.series.map(function(item){let color…