李沐深度学习记录3:11模型选择、欠拟合和过拟合

news2025/1/12 19:00:18

通过多项式拟合探索欠拟合与过拟合

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

#生成数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

features = np.random.normal(size=(n_train + n_test, 1))#生成均值为0,标准差为1的正态分布概率密度随机数
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)  #加上噪声

#分别对应常数系数、x值、x的次方再除以阶乘、算得的y值
true_w,features,poly_features,labels=[torch.tensor(x,dtype=torch.float32) for x in [true_w,features,poly_features,labels]]
true_w[:2],features[:2],poly_features[:2,:],labels[:2]

在这里插入图片描述

#实现函数评估模型损失
def evaluate_loss(net,data_iter,loss):
    '''评估给定数据集上模型的损失'''
    metric=d2l.Accumulator(2) #记录 损失的总和,样本数量
    for X,y in data_iter:
        out=net(X)
        y=y.reshape(out.shape)
        l=loss(out,y)
        metric.add(l.sum(),l.numel())
    return metric[0]/metric[1]

#from torch.utils import data

#定义训练函数
def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss(reduction='none')
    input_shape = train_features.shape[-1]
    # 不设置偏置,因为我们已经在多项式中实现了它
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])
    #print(train_labels.shape)   [100]
    #print(train_labels.shape[0])   100
    #print(batch_size)   10
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),
                                batch_size)
    #等价于以下两行代码
#     train_dataset=data.TensorDataset(train_features,train_labels.reshape(-1,1))
#     train_iter=data.DataLoader(train_dataset,batch_size,shuffle=True)
    
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    for epoch in range(num_epochs):
        d2l.train_epoch_ch3(net, train_iter, loss, trainer) #训练模型一个迭代周期
        if epoch == 0 or (epoch + 1) % 20 == 0:   #每20个epoch训练迭代周期,评估一次模型损失(包括训练集和测试集),画一次数据点
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),  #评估训练数据集上模型损失
                                     evaluate_loss(net, test_iter, loss)))  #评估测试数据集上模型损失
    print('weight:', net[0].weight.data.numpy())  #输出最终模型的参数权重
#使用三阶多项式拟合,与数据生成函数的阶数相同(正常)
# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])

在这里插入图片描述

#使用线性函数拟合非线性函数(这里是三阶多项式函数),线性模型很容易欠拟合
# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],
      labels[:n_train], labels[n_train:])

![在这里插入图片描述](https://img-blog.csdnimg.cn/10ad87a962cf40ccb39909b0fa84c7ea.png

#使用一个阶数过高的复杂多项式模型来训练会造成过拟合。在这种情况下,没有足够的数据用于学到高阶系数应该具有接近于零的值。 因此,这个过于复杂的模型会轻易受到训练数据中噪声的影响。 虽然训练损失可以有效地降低,但测试损失仍然很高。 
# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],
      labels[:n_train], labels[n_train:], num_epochs=1500)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1065667.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VMware 17pro安装流程附带密钥手把手教

VMware 17pro centos-8.5.2111-isos-x86_64安装包下载_开源镜像站-阿里云 安装VMware 17pro 下一步 勾选我接营许可协议中的条款点击下一步 更改路径后点击下一步 注意两个都要取消勾选不然会自动更新 下一步即可 最后一步为安装就行(我电脑上有VMware 16pro所以我的…

3.绘制一个点(鼠标点击)

愿你出走半生,归来仍是少年! 通过鼠标点击交互,实现在gl中绘制点。 1.知识点 1.1.点击坐标转换为Gl坐标 通过canvas的点击事件将会获得鼠标在浏览器客户区中的坐标。通过移除canvas自身位置后可获取鼠标在canvas中的点击位置。同时通过canvas的长宽将坐…

Mini-dashboard 和meilisearch配合使用

下载的meilisearch一般是development模式,内置客户端,修改客户端后需要重要全部编译,花时间太长了。前后端分离才是正道,客户端修改不用重新编译后端。 方法如下: 1、修改配置文件/etc/meilisearch.toml,…

磁盘io使用率高问题排查

Linux系统出现了性能问题,一般我们可以通过top、iostat、free、vmstat等命令来查看初步定位问题。其中iostat可以给我们提供丰富的IO状态数据。 1.小文件读写的磁盘性能瓶颈是寻址(随机读写性能更差)评估标准:TPS 2.大文件读写的磁盘性能瓶颈…

cereal:支持C++11的开源序列化库

cereal:支持C11的开源序列化库 文章目录 一:引言二、cereal简介三、cereal的下载和使用 一:引言 序列化 (Serialization) 程序员在编写应用程序的时候往往需要将程序的某些数据存储在内存中,然后将其写入某个文件或是将它传输到网络中的另…

认识计算机主板

目录 定义主要部件简单图示 主要功能 定义 计算机主板(Motherboard)是计算机系统中的核心组件之一,也被称为系统板、主板或母板。它是一个电子电路板,用于连接和支持计算机的各种硬件组件,包括中央处理器(…

零售业:别让数据安全成为业务的绊脚石!(附文件下载)

零售业上榜! 截至2023年8月31日,南都大数据研究院通过各地行政执法公示平台、媒体报道等公开渠道,收集到146起依据《数据安全法》作出行政处罚决定的案例。梳理发现,零售业以10.27%的占比位居行业第三,成为数据安全行…

尚硅谷Nginx教程由浅入深--笔记

尚硅谷Nginx教程由浅入深--笔记 Nginx简介Nginx相关概念反向代理负载均衡动静分离 Nginx安装Nginx命令Nginx配置Nginx配置实例反向代理1反向代理2负载均衡动静分离 Nginx简介 Nginx是一个高性能的HTTP和反向代理服务器,特点是内存占用少,并发能力强。 …

高铁站高速稳定用网秘籍,赶紧收藏

中秋国庆黄金周将至,销售旺季即将来临。车来车往、人潮涌动,稳定可靠的网络连接,成为了各大小商户抢占市场、掌握流量密码的关键。 在湖南省郴州市,某食品连锁商店负责人正在为店铺网络问题发愁。该连锁店部分销售网点位于繁忙的高…

Pikachu靶场——暴力破解

文章目录 1. 暴力破解1.1 基于表单的暴力破解1.2 验证码绕过(on server)1.3 验证码绕过(on client)1.4 token防爆破?1.5 漏洞防御 1. 暴力破解 暴力破解漏洞是指攻击者通过尝试各种组合的用户名和密码,以暴力方式进入系统或应用程序的方法。它利用了系统或应用程序…

睿趣科技:新手无货源怎么开抖音小店

抖音小店的开设对于许多商家来说是一个有吸引力的选择,尤其是对于那些喜欢短视频和社交媒体的年轻人。然而,对于没有货源的新手来说,这可能是一个令人头疼的问题。这篇文章将为你提供一些解决方案。 首先,你可以考虑从批发市场购买…

掌握SKILL语言:数字IC设计师必备的技能之一

去年在各个平台更新了一篇《数字IC必学之《Skill入门教程》建议收藏!》,阅读量在每个平台都很客观,且这半年以来,不断有粉丝留言想要获取这份资料。看来大家对于SKILL的需求是很大的,想要掌握SKILL语言:数字…

95、Spring Data Redis 之使用RedisTemplate 实现自定义查询 及 Spring Data Redis 的样本查询

Spring Data Redis 之使用RedisTemplate 实现自定义查询 Book实体类 原本的接口,再继承我们自定义的接口 自定义查询接口----CustomBookDao 实现类:CustomBookDaoImpl 1、自定义添加hash对象的方法 2、自定义查询价格高于某个点的Book对象 测试&a…

微信小程序wxs标签 在wxml文件中编写JavaScript逻辑

PC端开发 可以在界面中编写JavaScript脚本 vue/react这些框架更是形成了一种常态 因为模板引擎和jsx语法本身就都是在js中的 我们小程序中其实也有类似的奇妙写法 不过先声明 这东西不是很强大 我们可以先写一个案例代码 wxml代码参考 <view><wxs module"wordSt…

如何与瑞诺司Rhenus 建立EDI连接?

Rhenus Automotive 是德国百年家族企业Rethmann Group的子公司&#xff0c;提供从零部件的有序供应、即装即用模块的组装&#xff0c;一直到整车的组装。主要在全球范围内为劳斯莱斯&#xff0c;宝马&#xff0c;奔驰&#xff0c;奥迪等汽车企业提供智能制造解决方案。 项目挑战…

阿拉伯文排版是如何实现的

背景&#xff1a; 今天开工了&#xff0c;无意间看到多语言的页面&#xff0c;毕竟我们网站也是有多语言的 但是并没有阿拉伯语。但是我很好奇&#xff0c;分析阿拉伯语言的css并没有发现什么猫腻&#xff01; 到底是怎么实现的呢&#xff1f; 解&#xff1a; html dir 属性 …

聊聊MySQL的聚簇索引和非聚簇索引

文章目录 1. 索引的分类1. 存储结构维度2. 功能维度3. 列数维度4. 存储方式维度5. 更新方式维度 2. 聚簇索引2.1 什么是聚簇索引2.2 聚簇索引的工作原理 3. 非聚簇索引&#xff08;MySQL官方文档称为Secondary Indexes&#xff09;3.1 什么是非聚簇索引3.2 非聚簇索引的工作原理…

win7怎么录屏视频?小白也能轻松学会

“win7怎么录屏视频呀&#xff1f;在学校机房上课&#xff0c;电脑都是win7系统的&#xff0c;每次需要录屏的时候都找不到方法&#xff0c;问了老师也解决不了&#xff0c;有人知道win7怎么录屏吗&#xff1f;” Windows 7系统已经逐渐淡出了主流操作系统的行列&#xff0c;但…

JVM上篇之虚拟机与java虚拟机介绍

目录 虚拟机 java虚拟机 简介 特点 作用 位置 整体结构 类装载子系统 运行时数据区 java执行引擎 Java代码执行流程 jvm架构模型 基于栈式架构 基于寄存器架构 总结 jvm的生命周期 1.启动 2.执行 3.退出 JVM的发展历程 虚拟机 所谓虚拟机&#xff0c;指的…

要体验 AI 编程助手吗?

能不能用 AI 编程辅助写代码&#xff1f; 亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者对接世界最前沿技术&#xff0c;观点&#xff0c;和项目&#xff0c;并将中国优秀开发者或技术…