【动手学习深度学习--逐行代码解析合集】08模型选择、欠拟合和过拟合

news2024/11/25 7:43:51

【动手学习深度学习】逐行代码解析合集

08模型选择、欠拟合和过拟合


视频链接:动手学习深度学习–模型选择、欠拟合和过拟合
课程主页:https://courses.d2l.ai/zh-v2/
教材:https://zh-v2.d2l.ai/

1、生成数据集

在这里插入图片描述

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

"====================1、生成数据集===================="
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6]) # 多项式前四项系数

# 随机生成200个样本:均值为0,方差为0.01,形状(200,1)的特征样本
features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)  # 打乱

"计算出每个样本的所有输入特征(包括阶乘)"
# np.power(a,b),求a的b次方
# poly_features:200个数组,每组20个值
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    # poly_features:(200,20)
    poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!

"计算每个样本的真实标签,加上噪声"
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)   # 点乘运算,形成200个多项式
# 噪声项服从均值为0且标准差为0.1的正态分布
labels += np.random.normal(scale=0.1, size=labels.shape)

"NumPy ndarray转换为tensor"
true_w, features, poly_features, labels = [torch.tensor(x, dtype=
    torch.float32) for x in [true_w, features, poly_features, labels]]

print(features[:2], poly_features[:2, :], labels[:2])

运行结果

在这里插入图片描述

2、对模型进行训练和测试

"====================2、对模型进行训练和测试===================="
# 计算网络模型在训练集或数据集上的损失均值
# net:定义的网络模型
# data_iter:打乱的并且根据批量大小切割好的训练集或测试集
# loss:损失函数
def evaluate_loss(net, data_iter, loss):  #@save
    """评估给定数据集上模型的损失"""
    metric = d2l.Accumulator(2)  # 损失的总和,样本数量
    for X, y in data_iter:
        # 计算一个批量的预测值
        out = net(X)
        y = y.reshape(out.shape)
        # 计算一个批量的损失
        l = loss(out, y)
        # 将损失总和和样本数量  两个值累加
        metric.add(l.sum(), l.numel())
    # 返回损失均值
    return metric[0] / metric[1]

# 定义训练函数
def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss(reduction='none')  # 均方误差损失函数
    input_shape = train_features.shape[-1]
    # 不设置偏置,因为我们已经在多项式中实现了它
    # 网络模型:input_shape个输入,1个输出,没有偏置项
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])  # 设置批量大小为10
    # 按批量大小取出训练集(特征 + 对应的标签)
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),
                                batch_size)
    # 按批量大小取出测试集(特征 + 对应的标签)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),
                               batch_size, is_train=False)
    # 定义优化器
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    # 定义动画,显示训练结果
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    # 训练400轮
    for epoch in range(num_epochs):
        # 训练一轮
        d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        # 每隔20轮,将训练得到的模型在训练集和测试集上分别计算一次损失(训练损失、泛化损失)
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
                                     evaluate_loss(net, test_iter, loss)))
    # 输出训练得到的模型权重
    print('weight:', net[0].weight.data.numpy())

3、三阶多项式函数拟合(正常)

"====================3、三阶多项式函数拟合(正常)===================="
# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
# poly_features[:n_train, :4] :前100个样本,前4个维度  作为训练集
# poly_features[n_train:, :4] :后100个样本,前4个维度  作为训练集
# labels[:n_train]:将前100个样本的输出值作为训练集标签
# labels[n_train:]:将后100个样本的输出值作为测试集标签
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])
d2l.plt.show()
"输出 weight: [[ 4.9942217  1.1960176 -3.4083142  5.5780005]]"

运行结果

在这里插入图片描述

4. 线性函数拟合(欠拟合)

"====================4、线性函数拟合(欠拟合)===================="
# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],
      labels[:n_train], labels[n_train:])
d2l.plt.show()
"输出 weight: [[3.214788  4.6012254]]"

运行结果

在这里插入图片描述

5. 高阶多项式函数拟合(过拟合)

"====================5、高阶多项式函数拟合(过拟合)===================="
# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],
      labels[:n_train], labels[n_train:], num_epochs=1500)
d2l.plt.show()
'''
输出:
weight: [[ 4.95916700e+00  1.27137506e+00 -3.25717926e+00  5.24726105e+00
  -3.11582983e-01  1.13641846e+00  2.20295087e-01 -8.79566371e-02
   4.93251672e-03 -1.06145725e-01  7.90703818e-02  1.64333731e-02
   8.57480839e-02 -1.81607231e-01  1.93943262e-01 -1.26601264e-01
   2.00300813e-01 -1.24204971e-01  1.35094225e-01 -3.30150127e-03]]
'''

运行结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/725011.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ROS:参数名称设置

目录 一、前言二、rosrun设置参数三、launch文件设置参数四、编码设置参数4.1C实现4.1.1ros::param设置参数4.1.2ros::NodeHandle设置参数 4.2python实现 一、前言 在ROS中节点名称话题名称可能出现重名的情况,参数名称也可能重名。 关于参数重名的处理&#xff0c…

Css:浮动相关

1.为什么需要浮动? 多个块级元素纵向排列找 标准流,多个块级元素横线排列找 浮动 2.浮动的特性 浮动会脱离标准流(脱标) 浮动的盒子不再保留原来的位置 3.清除浮动

BM68-矩阵的最小路径和

题目 给定一个 n * m 的矩阵 a,从左上角开始每次只能向右或者向下走,最后到达右下角的位置,路径上所有的数字累加起来就是路径和,输出所有的路径中最小的路径和。 数据范围: 1≤n,m≤500,矩阵中任意值都满足 0≤ai,j…

T100新程序的开发【完整步骤】

简易程序的开发 记录T100中一个简易程序的开发完整步骤。 一、程序基本数据设置作业 打开作业 azzi900,弹出作业详情。 新增一个程序编号。 一些属性概念 程序编号:手动输入你建立的新程序。程序名称:手动输入你建立的名称。归属模块:取决于你程序编号的第一个字母。归属…

C语言判断当前目录下是否存在某一个文件

要判断当前目录下是否存在文件A&#xff0c;可以使用C语言中的标准库函数access来实现。access函数用于检查指定文件是否存在及是否具有指定的访问权限。 #include <stdio.h> #include <unistd.h>int main() {const char* filename "fileName";// 检查…

MongoDB【Springboot访问MongoDB、MongoDB安全认证、MongoDB内置角色 】(五)-全面详解(学习总结---从入门到深化)

目录 Springboot访问MongoDB MongoDB安全认证 MongoDB内置角色 Springboot访问MongoDB MongoTemplate方式 引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-mongodb</artifactId>…

【裸机开发】SPI 通信接口(一)—— SPI 通信流程及四种工作模式

目录 一、SPI 简介 二、SPI 的基本通信流程 三、SPI 的四种工作模式 1、极性和相位 2、四种工作模式 一、SPI 简介 SPI 采用主从的方式工作&#xff0c;可以一个主设备对应一个从设备&#xff0c;也可以一个主设备对应多个从设备。虽然是一个主设备对多个从设备的关系&am…

如何建立一套完整的人事管理制度?

一、什么是人事管理制度 人事管理制度是企业为有效管理和运营人力资源而建立的一系列规章制度、流程和政策。它是人力资源管理的基础&#xff0c;旨在确保企业拥有合适的员工队伍&#xff0c;并通过有效的管理和激励机制&#xff0c;使员工能够充分发挥自己的潜力&#xff0c;…

计算机网络 day2 物理层-数据链路层-帧-MAC地址 交换机的工作原理

目录 物理层&#xff08;physical layer&#xff09; 数据链路层&#xff08;Data link layer&#xff09; MAC地址&#xff1a; 网络地址&#xff1a; 帧的格式&#xff1a; MTU&#xff1a;最大传输单元 max transfer unit 1500 &#xff08;ip add可以查看&#xf…

【动态规划算法练习】day16

文章目录 一、完全背包1.题目简介2.解题思路3.代码4.运行结果 二、322. 零钱兑换1.题目简介2.解题思路3.代码4.运行结果 三、518. 零钱兑换 II1.题目简介2.解题思路3.代码4.运行结果 四、279. 完全平方数1.题目简介2.解题思路3.代码4.运行结果 总结 一、完全背包 1.题目简介 …

【百日冲大厂】第二十篇,牛客网选择题+编程题 字符串反转+公共子串计算(dp问题)

前言&#xff1a; 大家好&#xff0c;我是良辰丫&#xff0c;第二十篇,牛客网选择题编程题 字符串反转公共子串计算(dp问题).&#x1f49e;&#x1f49e;&#x1f49e;生活就像一只盲盒&#xff0c;藏着意想不到的辛苦&#xff0c;当然也有万般惊喜的可能。不管是次次都如愿以偿…

初学者一步步学习python 学习提纲

当学习Python时&#xff0c;可以按照以下提纲逐步学习&#xff1a; 入门基础 了解Python的历史和应用领域安装Python解释器和开发环境&#xff08;如Anaconda、IDLE等&#xff09;学习使用Python的交互式解释器或集成开发环境&#xff08;IDE&#xff09;进行简单的代码编写和…

浅谈C++下观察者模式的实现

为什么要有观察者模式 想象一个场景&#xff0c;有一只猫和一群老鼠&#xff0c;当猫出现的时候&#xff0c;每一只老鼠都要逃跑 用最简单的方法实现一个去模拟这一个过程 #include<iostream>class Mouse_1 {public:void CatCome(){std::cout<<"Mouse_1 Ru…

【MyBatis-Plus】DQL编程控制

1&#xff0c;DQL编程控制 增删改查四个操作中&#xff0c;查询是非常重要的也是非常复杂的操作&#xff0c;这块需要我们重点学习下&#xff0c;这节我们主要学习的内容有: 条件查询方式查询投影查询条件设定字段映射与表名映射 1. 条件查询 1. 条件查询的类 MyBatisPlus…

PHP 训练成绩管理系统mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP 训练成绩管理系统 是一套完善的web设计系统&#xff0c;对理解php编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为PHP APACHE&#xff0c;数据库 为mysql5.0&#xff0c;使用php语言开发。 …

css设计表格圆角最简单的方法

代码如下&#xff1a; table {width: 100%;/* border-collapse: collapse; */background-color: #FBFBFB; /* 背景颜色; */border-collapse: separate; /* 让border-radius有效 */border-spacing: 0; /*表格中每个格边距设为0*/border: 1px solid #DFDFDF;/*边框*/border-radi…

软件测试面试简历,三年测试项目经验怎么写?

作为三年左右的测试工程师&#xff0c;简历上有五六个项目经历很正常&#xff0c;那如何设计这几个项目&#xff0c;其实设计好三两个就行&#xff0c;面试官能关注到的也只有最新的三两个&#xff0c;两年前的项目也没有关注的必要啦&#xff0c;所以在这两三个项目中一定要体…

mysql8.0 navicat mysql 2059报错

进入mysql安装目录&#xff1a; 输入用户名密码连接mysql 设置密码 刷新 测试连接&#xff0c;连接成功

如何保障业务稳定性?一文详解蚂蚁业务智能可观测平台BOS

随着业务规模的不断扩大以及AI、云计算、大数据等技术的不断发展&#xff0c;大量的企业希望利用上云来加速其数字化转型&#xff0c;全面提升可靠性、安全性和灵活性&#xff0c;并且降低运营成本。 不过对于大多数企业来说&#xff0c;全面上云是一项颇具难度的挑战。这里面…

阿里图标库中图标的下载使用

一 iconfont-阿里巴巴矢量图标库 进去找到你想要的图标 二 点这个 三 点这个 点这个 新建自己的项目 选择这个点下载 解压出来&#xff0c;除了两个demo不要都添加到你的代码中的文件夹保存 四 main.js中全局导入 import ./xxxx/xxxx/iconfont.css 五 页面使用 <…