【PyTorch】多项式回归

news2025/2/24 17:47:41

文章目录

  • 1. 模型与代码实现
    • 1.1. 模型
    • 1.2. 代码实现
      • 1.2.1. 完整代码
      • 1.2.2. 输出结果
  • 2. Q&A
    • 2.1. 欠拟合与过拟合

1. 模型与代码实现

1.1. 模型

  • 将多项式特征值预处理为线性模型的特征值。即
    y = w 0 + w 1 x + w 2 x 2 + ⋯ + w n x n y = w_0+w_1x+w_2x^2+\dots+w_nx^n y=w0+w1x+w2x2++wnxn变换为 y = w 0 + w 1 z 1 + w 2 z 2 + ⋯ + w n z n y=w_0+w_1z_1+w_2z_2+\dots+w_nz_n y=w0+w1z1+w2z2++wnzn
  • 为了避免指数值过大,可以将 x i x^i xi调整为 x i i ! \frac{x^i}{i!} i!xi,即 y = w 0 + w 1 x 1 ! + w 2 x 2 2 ! + ⋯ + w n x n n ! y = w_0+w_1\frac{x}{1!}+w_2\frac{x^2}{2!}+\dots+w_n\frac{x^n}{n!} y=w0+w11!x+w22!x2++wnn!xn

1.2. 代码实现

1.2.1. 完整代码

import os
import numpy as np
import math, torch
from d2l import torch as d2l
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tensorboardX import SummaryWriter
from rich.progress import track

def evaluate_loss(dataloader):
    """评估给定数据集上模型的损失"""
    metric.reset()
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
            loss = criterion(net(X), y)
            metric.add(loss.sum(), loss.numel())
        return metric[0] / metric[1]

def load_dataset(data_arrays):
    """加载数据集"""
    dataset = TensorDataset(*data_arrays)
    return DataLoader(dataset, batch_size, shuffle=True, pin_memory=True,
        num_workers=num_workers, prefetch_factor=prefetch_factor)


if __name__ == '__main__':
    # 全局参数设置
    learning_rate = 0.01
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_epochs = 400
    batch_size = 10
    num_workers = 0
    prefetch_factor = 2

    max_degree = 20             # 多项式最高阶数
    model_degree = 1           # 多项式模型阶数
    n_train, n_test = 100, 100  # 训练集和测试集大小

    true_w = np.zeros(max_degree+1)
    true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

    # 创建记录器
    def get_logdir():
        root = 'runs'
        if not os.path.exists(root):
            os.mkdir(root)
        order = len(os.listdir(root)) + 1
        return f'runs/exp{order}'
    writer = SummaryWriter(get_logdir())

    # 生成数据集
    features = np.random.normal(size=(n_train + n_test, 1))
    np.random.shuffle(features)
    poly_features = np.power(features, np.arange(max_degree+1).reshape(1, -1))
    for i in range(max_degree+1):
        poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
    labels = np.dot(poly_features, true_w)
    labels += np.random.normal(scale=0.1, size=labels.shape)    # 加高斯噪声服从N(0, 0.01)

    poly_features, labels = [
        torch.as_tensor(x, dtype=torch.float32) for x in [
            poly_features, labels]]
    
    # 创建模型
    net = nn.Sequential(nn.Linear(model_degree+1, 1, bias=False)).to(device, non_blocking=True)
    def init_weights(m):
        if type(m) == nn.Linear:
            nn.init.normal_(m.weight, mean=0, std=0.01)
    net.apply(init_weights)
    criterion = nn.MSELoss(reduction='none')
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

    # 加载数据集
    features_train, labels_train = poly_features[:n_train, :model_degree+1], labels[:n_train].reshape(-1, 1)
    features_test, labels_test = poly_features[n_train:, :model_degree+1], labels[n_train:].reshape(-1, 1)
    dataloader_train = load_dataset((features_train, labels_train))
    dataloader_test = load_dataset((features_test, labels_test))
    
    # 训练循环
    metric = d2l.Accumulator(2)  # 损失的总和, 样本数量
    for epoch in track(range(num_epochs)):
        for X, y in dataloader_train:
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
            loss = criterion(net(X), y)
            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()

        writer.add_scalars(f"{model_degree}-degree", {
            "train_loss": evaluate_loss(dataloader_train),
            "test_loss": evaluate_loss(dataloader_test),
        }, epoch)
    print("weights =", net[0].weight.data.cpu().numpy())

    writer.close()

1.2.2. 输出结果

  • 采用1阶多项式(线性模型)拟合:
    1degree

  • 采用3阶多项式拟合
    3degree

  • 采用20阶多项式拟合
    20degree

2. Q&A

2.1. 欠拟合与过拟合

数据集是按照3阶多项式生成的。使用1阶多项式去拟合,发现最后损失始终降不下去,这种情况称为欠拟合,说明模型复杂度不够;使用20阶多项式去拟合,发现测试损失最后还增长了,训练和测试损失总体也比3阶多项式模型的值高,这种情况称为过拟合,说明模型太复杂了,训练过程受到了噪声的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1290127.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

自动生成实体类,mapper类和mapper.xml文件(解放双手,定义好数据库表就不要手写啦)

背景 建的表有四十多个字段,建好了已经很累了,映射成Javabean还要再写一次!! 吐槽 在建立好了sql表之后,我们已经写了一次建表了,难道还要我们自己再一次手写模Java模型吗,我的表有几十个字段…

10_企业架构NOSQL数据库之MongoDB

企业架构NOSQL数据库之MongoDB 学习目标和内容 1、能够简单描述MongoDB的使用特点 2、能够安装配置启动MongoDB 3、能够使用命令行客户端简单操作MongoDB 4、能够实现基本的数据操作 5、能够实现MongoDB基本安全设置 6、能够操作安装php的MongoDB扩展 一、背景描述及其方案设计…

Linux访问NFS存储及自动挂载

本章主要介绍NFS客户端的使用 创建NFS服务器并通过NFS共享一个目录在客户端上访问NFS共享的目录自动挂载的配置和使用 1.1 访问NFS存储 前面那篇介绍了本地存储,本章就来介绍如何使用网络上上的存储设备。NFS即网络文件系统,所实现的是Linux和Linux之…

IDEA快速生成lambda表达式的方法

IDEA快速生成lambda表达式的方法-CSDN博客 建议修改成 shift/

Redis对象

Redis根据基本数据结构构建了自己的一套对象系统。主要包括字符串对象、列表对象、哈希对象、集合对象和有序集合对象 同时不同的对象都有属于自己的一些特定的redis指令集,而且每种对象也包括多种编码类型,和实现方式。 Redis对象结构 struct redisOb…

使用TouchSocket适配一个c++的自定义协议

这里写目录标题 说明一、新建项目二、创建适配器三、创建服务器和客户端3.1 服务器3.2 客户端3.3 客户端发送3.4 客户端接收3.5 服务器接收与发送 四、关于同步Send 说明 今天有小伙伴咨询我,他和同事(c端)协商了一个协议,如果使…

二叉树的右视图[中等]

优质博文:IT-BLOG-CN 一、题目 给定一个二叉树的 根节点root,想象自己站在它的右侧,按照从顶部到底部的顺序,返回从右侧所能看到的节点值。 示例 1: 输入: [1,2,3,null,5,null,4] 输出: [1,3,4] 示例 2: 输入: [1,null,3] 输出…

强化学习第1天:强化学习概述

☁️主页 Nowl 🔥专栏《机器学习实战》 《机器学习》 📑君子坐而论道,少年起而行之 ​​ 文章目录 介绍 强化学习要素 强化学习任务示例 环境搭建:gym 基本用法 环境信息查看 创建智能体 过程可视化 完整代码 结语…

0基础学java-day15

一、泛型 1 泛型的理解和好处 1.1 看一个需求 【不小心加入其它类型,会导致出现类型转换异常】 package com.hspedu.generic;import java.util.ArrayList;/*** author 林然* version 1.0*/ public class Generic01 {SuppressWarnings("all")public st…

企业数字档案馆室建设指南

数字化时代,企业数字化转型已经成为当下各行业发展的必然趋势。企业数字化转型不仅仅是IT系统的升级,也包括企业内部各种文件、档案、合同等信息的数字化管理。因此,建设数字档案馆室也变得尤为重要。本篇文章将为您介绍企业数字档案馆室建设…

SpringMVC修炼之旅(2)基础入门

一、第一个程序 1.1环境配置 略 1.2代码实现 package com.itheima.controller;import org.springframework.stereotype.Controller; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.ResponseBody;//定义…

初识MQ——消息队列技术选型

文章目录 同步和异步通讯同步通讯异步通讯 技术对比 同步和异步通讯 微服务间通讯有同步和异步两种方式: 同步通讯:就像打电话,需要实时响应。 异步通讯:就像发邮件,不需要马上回复。 两种方式各有优劣&#xff0c…

CCF编程能力等级认证GESP—C++1级—20230318

CCF编程能力等级认证GESP—C1级—20230318 单选题(每题 2 分,共 30 分)判断题(每题 2 分,共 20 分)编程题 (每题 25 分,共 50 分)每月天数长方形面积 答案及解析单选题判断题编程题1编程题2 单选…

SQL手工注入漏洞测试(Sql Server数据库)-墨者

———靶场专栏——— 声明:文章由作者weoptions学习或练习过程中的步骤及思路,非正式答案,仅供学习和参考。 靶场背景: 来源: 墨者学院 简介: 安全工程师"墨者"最近在练习SQL手工注入漏洞&#…

国内AI翘楚,看看有没有你心动的offer?

科技创新争占高地,AI领域各显神通。从一战成名的阿尔法狗到引起轩然大波的ChatGPT,我们早已卷入了一场没有硝烟的革命。前方世人看到的科技日新日异、岁月静好,后方是各大企业的绞尽脑汁、争先恐后。人工智能时代,AI是挡不住的时代…

Lebesgue积分及应用

Lebesgue积分及应用 文章目录 Lebesgue积分及应用一、Lebesgue测度和可测函数1.1 Riemann积分和Lebesgue积分1.2 直线上的Lebesgue测度【定义】外测度(Outer Measure)【定理】外测度的性质【定义】内测度【定义】可测、Lebesgue测度【定理】卡氏条件&…

Java注册并监听全局快捷键

背景 之前在博客中分享了SWT托盘功能, 随之带来一个问题, 当程序最小化后无法快速唤醒, 按照平时使用软件的思路, 自然想到了注册全局快捷键, 本文介绍使用java方式实现全局快捷键的注册. 方案 通过google,搜到一个现成的库: jintellitype, 使用maven可以直接引用, 非常方便…

C语言易错知识点八(整形与浮点型在内存中存储的实质)

整形与浮点型在内存中存储的实质 当我们在刷抖音或者其他短视频平台时,可能会时不时(总是,我相信大家肯定是不会被外表骗到的那一类人ヾ(●゜ⅴ゜)ノ)刷到各种帅哥美女的视频,或者我们在学校里看到帅哥美女时,如果我们只…

NFC和蓝牙在物联网中有什么意义?如何选择?

#NFC物联网# #蓝牙物联网# 在物联网中,NFC和蓝牙有什么意义? NFC在物联网中代表近场通信技术。它是一种短距离、高频的无线通信技术,可以在近距离内实现设备间的数据传输和识别。NFC技术主要用于移动支付、电子票务、门禁、移动身份识别、防…

Vue2中v-html引发的安全问题

前言:v-html指令 1.作用:向指定节点中渲染包含html结构的内容。 2.与插值语法的区别: (1).v-html会替换掉节点中所有的内容,{{xx}}则不会。 (2).v-html可以识别html结构。 3.严重注意:v-html有安全性问题&#xff0…