机器学习第5天:多项式回归与学习曲线

news2024/11/17 23:36:02

文章目录

多项式回归介绍

方法与代码

方法描述

分离多项式

学习曲线的作用

场景

学习曲线介绍

欠拟合曲线

示例

结论

过拟合曲线

示例

​结论 


多项式回归介绍

当数据不是线性时我们该如何处理呢,考虑如下数据

import matplotlib.pyplot as plt
import numpy as np


np.random.seed(42)

x = 8 * np.random.rand(100, 1) - 4
y = 2*x**2+3*x+np.random.randn(100, 1)

plt.scatter(x, y)
plt.show()


方法与代码

方法描述

先讲思路,以这个二元函数为例

y=3*x^{2}+2*x+c

将多项式化为多个单项的,也就是将x的平方和x两个项分离开,然后单独给线性模型处理,求出参数,最后再组合在一起,很好理解,让我们来看一下代码


分离多项式

我们使用机器学习库的PolynomialFeatures来分离多项式

from sklearn.preprocessing import PolynomialFeatures


poly_features = PolynomialFeatures(degree=2, include_bias=False)
x_poly = poly_features.fit_transform(x)
print(x[0])
print(x_poly[0])

运行结果

可以看到,4, 5行代码将原始x和x平方挑选了出来,这时我们再把这个数据进行线性回归

model = LinearRegression()
model.fit(x_poly, y)
print(model.coef_)

 这段代码使用处理后的x拟合y,再打印模型拟合的参数,可以看到模型的两个参数分别是2.9和2左右,而我们的方程的一次参数和二次参数分别是3和2,可见效果还是很好的

把预测的结果绘制出来

model = LinearRegression()
model.fit(x_poly, y)
pre_y = model.predict(x_poly)

# 这里是为了让x升序的排序算法, 可以尝试不加这段代码图会变成什么样
sorted_indices = sorted(range(len(x)), key=lambda k: x[k])
x_sorted = [x[i] for i in sorted_indices]
y_sorted = [pre_y[i] for i in sorted_indices]

plt.plot(x_sorted, y_sorted, "r-")
plt.scatter(x, y)
plt.show()


学习曲线的作用

场景

设想一下,当你需要预测房价,你也有多组数据,包括离学校距离,交通状况等,但是问题来了,你只知道这些特征可能与房价有关,但并不知道这些特征与房价之间的方程关系,这时我们进行回归任务时,就可能导致欠拟合或者过拟合,幸运的是,我们可以通过学习曲线来判断


学习曲线介绍

学习曲线图就是以损失函数为纵坐标,数据集大小为横坐标,然后在图上画出训练集和验证集两条曲线的图,训练集就是我们用来训练模型的数据,验证集就是我们用来验证模型性能的数据集,我们往往将数据集分成训练集与验证集

我们先定义一个学习曲线绘制函数

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression


def plot_learning_curves(model, x, y):
    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)
    train_errors, val_errors = [], []
    for m in range(1, len(x_train)):
        model.fit(x_train[:m], y_train[:m])
        y_train_predict = model.predict(x_train[:m])
        y_val_predict = model.predict(x_val)
        train_errors.append(mean_squared_error(y_train[:m], y_train_predict))
        val_errors.append(mean_squared_error(y_val, y_val_predict))
    plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")
    plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")
    plt.legend()
    plt.show()

 简单介绍一下,这个函数接收模型参数,x,y参数,然后在for循环中,取不同数据集大小来计算RMSE损失(就是\sqrt{MSE}),然后把曲线绘制出来


欠拟合曲线

我们知道欠拟合就是模拟效果不好的情况,可以想象的到,无论在训练集还是验证集上,他的损失都会比较高

示例

我们将线性模型的学习曲线绘制出来

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression


def plot_learning_curves(model, x, y):
    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)
    train_errors, val_errors = [], []
    for m in range(1, len(x_train)):
        model.fit(x_train[:m], y_train[:m])
        y_train_predict = model.predict(x_train[:m])
        y_val_predict = model.predict(x_val)
        train_errors.append(mean_squared_error(y_train[:m], y_train_predict))
        val_errors.append(mean_squared_error(y_val, y_val_predict))
    plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")
    plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")
    plt.legend()
    plt.show()


x = np.random.rand(100, 1)
y = 2 * x + np.random.rand(100, 1)

model = LinearRegression()
plot_learning_curves(model, x, y)

 

结论

可以看到,在只有一点数据时,模型在训练集上效果很好(因为就是开始这一些数据训练出来的),而在验证集上效果不好,但随着训练集增加(模型学习到的越多),验证集上的误差逐渐减小,训练集上的误差增加(因为是学到了一个趋势,不会完全和训练集一样了)

这个图的特征是两条曲线非常接近,且误差都较大(差不多在0.3) ,这是欠拟合的表现(模型效果不好)


过拟合曲线

过拟合就是完全以数据集来模拟曲线,泛化能力很差

示例

我们来试试将一次函数模拟成三次函数,再来看看学习曲线(毫无疑问过拟合了)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline


def plot_learning_curves(model, x, y):
    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)
    train_errors, val_errors = [], []
    for m in range(1, len(x_train)):
        model.fit(x_train[:m], y_train[:m])
        y_train_predict = model.predict(x_train[:m])
        y_val_predict = model.predict(x_val)
        train_errors.append(mean_squared_error(y_train[:m], y_train_predict))
        val_errors.append(mean_squared_error(y_val, y_val_predict))
    plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")
    plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")
    plt.legend()
    plt.show()


np.random.seed(10)
x = np.random.rand(200, 1)
y = 2 * x + np.random.rand(200, 1)

poly_regression = Pipeline([
    ("Poly", PolynomialFeatures(degree=3, include_bias=False)),
    ("Line", LinearRegression())
])

plot_learning_curves(poly_regression, x, y)

结论 

这条曲线的特征是训练集的效果比验证集好(两条线之间有一定间距),这往往是过拟合的表现(在训练集上效果好,验证集差,表面泛化能力差) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1222330.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

皮具生产ERP都有哪些功能?企业应当如选型

箱包皮具型号多、款式多、营销渠道广泛,各个销售平台的管理模式各不相同,而皮具生产企业经营策略和管理模式的差异,也导致企业在市场竞争力等方面参差不齐。 此外,不同的客户有不同的需求,再加上皮具行业竞争的加剧&a…

【鸿蒙应用ArkTS开发系列】- 云开发入门简介

目录 概述开发流程工程概览工程模板工程结构 工程创建与配置 概述 HarmonyOS云开发是DevEco Studio新推出的功能,可以让您在一个项目工程中,使用一种语言完成端侧和云侧功能的开发。 基于AppGallery Connect Serverless构建的云侧能力,开发…

直线插补-逐点比较法

直线插补-逐点比较法 逐点比较法四个节拍的工作流程如图所示举例1 逐点比较法 逐点比较法逐点比较法是通过逐点比较刀具与所需插补曲线之间的相对位置,确定刀具的进给方向,进而加工出工件轮廓的插补方法。刀具从加工起点开始,按照“靠近曲线…

ssrf学习笔记总结

SSRF概述 ​ 服务器会根据用户提交的URL 发送一个HTTP 请求。使用用户指定的URL,Web 应用可以获取图片或者文件资源等。典型的例子是百度识图功能。 ​ 如果没有对用户提交URL 和远端服务器所返回的信息做合适的验证或过滤,就有可能存在“请求伪造”的…

mount /dev/mapper/centos-root on sysroot failed处理

今天发现centos7重启开不进去系统 通过查看日志主要告警如下 修复挂载目录 xfs_repair /dev/mapper/centos-root不行加-L参数 xfs_repair -L /dev/mapper/centos-root重启 reboot

软件测试基础-01

目标: 1.软件测试的定义; 2.7种测试分类的区别; 3.质量模型的重点5项; 4.测试流程的6个步骤; 5.测试模板8要素; 就业方向: 1.功能测试接口测试 2.功能测试性能测试 3.功能测试web自动化…

DAY57 739. 每日温度 + 496.下一个更大元素 I

739. 每日温度 题目要求: 请根据每日 气温 列表,重新生成一个列表。对应位置的输出为:要想观测到更高的气温,至少需要等待的天数。如果气温在这之后都不会升高,请在该位置用 0 来代替。 例如,给定一个列…

ChatGPT暂时停止开通puls,可能迎来封号高峰期

前言: 前两日,chat gpt的创始人 San Altman在网上发表了,由于注册的使用量超过了他们的承受能力,为了确保每个人的良好使用体验,chat gpt将暂时停止开通gpt plus。 情况: 前段时间好像出现了官网崩溃的情况,就连api key都受到了影响,所以现在就开始了暂时停止puls的注…

OpenAI的Whisper蒸馏:蒸馏后的Distil-Whisper速度提升6倍

1 Distil-Whisper诞生 Whisper 是 OpenAI 研发并开源的一个自动语音识别(ASR,Automatic Speech Recognition)模型,他们通过从网络上收集了 68 万小时的多语言(98 种语言)和多任务(multitask&am…

零代码编程:用ChatGPT自动合并多个Word文件

一个文件夹中有多个docx格式的word文档: 想要把它们都合并成一个文件,然后打印,可以在ChatGPT中输入提示词: 你是一个Python编程专家,要完成一个处理word内容的任务,具体步骤如下: 打开文件夹…

lvgl 画圆弧时进入 HardFault

目录 一、现象描述 lvgl 版本 二、问题分析 lvgl 需要的资源新建mcu 工程时默认分配的资源问题解决 一、现象描述 移植完lvgl 之后,能正常显示label,但是button arc 等复杂的控件都不能正常显示。调用官方的画圆弧demo 时,在多次调用 _lv…

机器学习第4天:模型优化方法—梯度下降

文章目录 前言 梯度下降原理简述 介绍 可能的问题 批量梯度下降 随机梯度下降 基本算法 存在的问题 退火算法 代码演示 小批量梯度下降 前言 若没有机器学习基础,建议先阅读同一系列以下文章 机器学习第1天:概念与体系漫游-CSDN博客 机器学习…

C语言第入门——第十六课

目录 一、分治策略与递归 二、递归 1.求解n的阶乘 2.输入整数、倒序输出 3.输入整数、正序输出 4.计算第n位Fibonacci数列 ​编辑5.无序整数数组打印 6.找到对应数组下标 一、分治策略与递归 在我们遇到大问题的时候,我们的正确做法是将它分解成小问题&a…

3DMAX平铺插件MaxTiles教程

MaxTiles 结合了一组材质和地图插件,任何建筑师或 3D 可视化艺术家都会喜欢。与静态位图纹理不同,MaxTiles 材质可以更改键合图案、替换和混合砖块、更改边缘、随机化颜色、位置、表面等等。MaxTiles 结合了以下功能: 墙壁和瓷砖 – 用于创建…

【计算思维】蓝桥杯STEMA 科技素养考试真题及解析 5

1、要把下面4张图片重新排列成蜗牛的画像,该如何排列这些图片 A、 B、 C、 D、 答案:A 2、将下图的绳子沿虚线剪开后,绳子被分成了()部分 A、6 B、7 C、8 D、9 答案:C 3、下面的立体图形,沿箭头方向看去&#…

“世亚智博会,世亚软博会”双展联动,3月上海,4月杭州,6月北京

2024世亚智博会与世亚软博会双展联动,3月上海,4月杭州,6月北京,历经多年沉淀与打磨,随着扩张速度的不断加快,参展企业的数量也水涨船高,引领行业前沿趋势,已成为智能产业和软件行业的…

【uniapp】Google Maps

话不多说 直接上干货 提前申请谷歌地图账号一、新建地图 使用h5获取当前定位或者使用三方uniapp插件 var coords ""navigator.geolocation.getCurrentPosition(function(position) {coords {lat: position.coords.latitude,lng: position.coords.longitude};lats …

windows安装wsl2以及ubuntu

查看自己系统的版本 必须运行 Windows 10 版本 2004 及更高版本(内部版本 19041 及更高版本)或 Windows 11 才能使用以下命令 在设置,系统里面就能看到 开启windows功能 直接winQ搜 开启hyber-V、使用于Linux的Windows子系统、虚拟机平…

让资产权利归于建设者:Kiosk使过程变得更简单

区块链凭借着其将人的权利地位置于平台之上的能力,可以重塑互联网,而自托管为个人提供了控制和管理其资产和数据的能力。链上交易支持建设者和客户之间的点对点交易。这些特质联合起来,可以将数字世界从基于价值提取的模式转变为基于价值创造…

机器学习第7天:逻辑回归

文章目录 介绍 概率计算 逻辑回归的损失函数 单个实例的成本函数 整个训练集的成本函数 鸢尾花数据集上的逻辑回归 Softmax回归 Softmax回归数学公式 Softmax回归损失函数 调用代码 参数说明 结语 介绍 作用:使用回归算法进行分类任务 思想:…