训练/测试、过拟合问题

news2024/12/23 6:37:31

在机器学习中,我们创建模型来预测某些事件的结果,比如之前使用重量和发动机排量,预测了汽车的二氧化碳排放量

要衡量模型是否足够好,我们可以使用一种称为训练/测试的方法

训练/测试是一种测量模型准确性的方法

之所以称为训练/测试,是因为我们将数据集分为两组:训练集和测试集

80% 用于训练,20% 用于测试

使用训练集来训练模型、

使用测试集来测试模型

训练模型意味着创建模型

测试模型意味着测试模型的准确

下面是模拟的数据:我们的数据集展示了商店中的 100 位顾客及其购物习惯

import numpy
import matplotlib.pyplot as plt

# 使用 `numpy.random.seed()` 函数设定种子可以确保每次生成的随机数序列是相同的
# 从而保证算法的可重复性和稳定性
numpy.random.seed(2)

x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / x

plt.scatter(x, y)
plt.show()

散点图如下

x 轴表示购买前的分钟数

y 轴表示在购买上花费的金额

训练集应该是原始数据的 80% 的随机选择

测试集应该是剩余的 20%

train_x = x[:80]
train_y = y[:80]

test_x = x[80:]
test_y = y[80:]

显示与训练集相同的散点图

plt.scatter(train_x, train_y)
plt.show()

如下所示

import numpy
import matplotlib.pyplot as plt
numpy.random.seed(2)

x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / x

train_x = x[:80]
train_y = y[:80]

test_x = x[80:]
test_y = y[80:]

plt.scatter(train_x, train_y)
plt.show()

为了确保测试集不是完全不同,我们还要看一下测试集

plt.scatter(test_x, test_y)
plt.show()

 

 进行拟合数据集,通过数据点画一条线,我们使用 matplotlib 模块的 plott() 方法

绘制穿过数据点的多项式回归线

import numpy
import matplotlib.pyplot as plt
numpy.random.seed(2)

x = numpy.random.normal(3, 1, 100)

# 对应位置逐个元素相除,可以用来进行归一化、标准化等数据预处理操作
y = numpy.random.normal(150, 40, 100) / x

train_x = x[:80]
train_y = y[:80]

test_x = x[80:]
test_y = y[80:]

mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))

# 生成 0 ~ 6 之间的100个 等差数列用于拟合曲线
myline = numpy.linspace(0, 6, 100)

plt.scatter(train_x, train_y)
plt.plot(myline, mymodel(myline))
plt.show()

此结果可以支持我们对数据集拟合多项式回归的建议,即使如果我们尝试预测数据集之外的值会给我们带来一些奇怪的结果。例如:该行表明某位顾客在商店购物 6 分钟,会完成一笔价值 200 的购物。这可能是过拟合的迹象

但是 R-squared 分数呢? R-squared score很好地指示了我的数据集对模型的拟合程度

 R2,也称为 R平方(R-squared),它测量 x 轴和 y 轴之间的关系,取值范围从 0 到 1,其中 0 表示没有关系,而 1 表示完全相关

sklearn 模块有一个名为 rs_score() 的方法,该方法将帮助我们找到这种关系

在这里,我们要衡量顾客在商店停留的时间与他们花费多少钱之间的关系

import numpy
from sklearn.metrics import r2_score
numpy.random.seed(2)

x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / x

train_x = x[:80]
train_y = y[:80]

test_x = x[80:]
test_y = y[80:]

mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))

r2 = r2_score(train_y, mymodel(train_x))

print(r2)

 因此,从上面的情况来看,在训练数据方面,我们已经建立了一个不错的模型

然后,我们要使用测试数据来测试模型,以检验是否给出相同的结果

import numpy
from sklearn.metrics import r2_score
numpy.random.seed(2)

x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / x

train_x = x[:80]
train_y = y[:80]

test_x = x[80:]
test_y = y[80:]

mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))

r2 = r2_score(test_y, mymodel(test_x))

print(r2)

 结果 0.809 表明该模型也适合测试集,我们确信可以使用该模型预测未来值

如果购买客户在商店中停留 5 分钟,他/她将花费多少钱?

import numpy
from sklearn.metrics import r2_score
numpy.random.seed(2)

x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / x

train_x = x[:80]
train_y = y[:80]

test_x = x[80:]
test_y = y[80:]

mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))

print(mymodel(5))

 该例预测客户花费了 22.88 美元,似乎与图表相对应

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/560465.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

springmvc升级到springboot2踩的坑

声明:删除springmvc的jar配置改成springboot的,若别的组件依赖springboot该升级就升级,该删掉就删掉,此文章只记录升级后的坑,升级springboot所需的jar请自行百度。 一.Hibernate的坑 概念:jpa和Hibernate的关系,jpa…

【JAVAEE】网络编程的简单介绍及其实现

目录 1.什么是网络编程 网络编程中的基本概念 常见的客户端服务端模型 2.Socket套接字 Socket套接字分类 举例对比TCP和UDP 3.UDP数据报套接字编程 DatagramSocket API DatagramPacket API InetSocketAddress API 4.实现一个简单的UDP回显服务器与客户端 服务端与客…

当前最新免费使用GPT-4方法汇总

目录 前言 温馨提示 Ora AI 使用方式 使用测试 Forefont chat 使用方式 使用测试 Perplexity AI 使用方式 使用测试 Poe 总结 前言 目前GPT-4的收费对于大多数人而言都还是不便宜,且付费方式复杂,使用上还有每3小时25个问题的限制&#xff…

Aspose.OCR For NET 23.5 Crack

使用几行代码将光学字符识别 (OCR) 添加到您的 .NET 应用程序。 适用于 .NET 的 Aspose.OCRAspose.OCR 文档 Aspose.OCR for .NET 是一个功能强大但易于使用且具有成本效益的光学字符识别 API。有了它,您可以用不到 5 行代码将 OCR 功能添加到您的 .NET 应用程序…

【Linux】初识优雅的Linux编辑器——Vim

❤️前言 大家好!今天给大家带来的博客内容是关于Linux操作系统下的一款多模式文本编辑器Vim。本文将和大家一起来了解Vim编辑器的一些基础知识。 正文 Vim是一个多模式的文本编辑器(一共有十二种模式),其中我们当我们初学Vim时主要了解如下三种工作模式…

Linux——多线程(线程概念|进程与线程|线程控制)

目录 地址空间和页表 如何看待地址空间和页表 虚拟地址如何转化到物理地址的 线程与进程的关系 什么叫进程? 什么叫线程? 如何看待我们之前学习进程时,对应的进程概念呢?和今天的冲突吗? windows线程与linux线…

Leetcode665. 非递减数列

Every day a Leetcode 题目来源:665. 非递减数列 解法1:贪心 本题是要维持一个非递减的数列,所以遇到递减的情况时(nums[i] > nums[i 1]),要么将前面的元素缩小,要么将后面的元素放大。 …

K8s in Action 阅读笔记——【2】First steps with Docker and Kubernetes

K8s in Action 阅读笔记——【2】First steps with Docker and Kubernetes 2.1 Creating, running, and sharing a container image 2.1.1 Installing Docker and running a Hello World container 在电脑上安装好Docker环境后,执行如下命令, $ dock…

真会玩:莫言用ChatGPT为余华写了一篇获奖词

5月16日,《收获》杂志65周年庆典暨新书发布活动在上海舞蹈中心举行。 典礼现场,余华凭借《文城》获得收获文学榜2021年长篇小说榜榜首。 作为老友,莫言在颁奖时故意卖了个关子:“这次获奖的是一个了不起的人物,当然了&…

OMA通道-2

1 简介 本文档中指定的 API 使移动应用程序能够访问移动设备中的不同 SE,例如 SIM 或嵌入式 SE。 本规范提供了接口定义和 UML 图,以允许在各种移动平台和不同的编程语言中实现。 如果编程语言支持命名空间,则它应为 org.simalliance.openmob…

Foxit PDF SDK OCR Add-on Library (C++, Windows)-Crk

OCR文档扫描--Crack version 使用Foxit PDF SDK OCR Add-on的光学字符识别(OCR)软件将扫描的文档转换为可搜索的文本PDF。专为扫描、归档和数字化而设计,我们的插件输出13种不同的文件格式,包括PDF和PDF/A。 在不投资数据输入人员…

Linux 禁用23端口

禁用23端口 前言 23端口是用于Telnet服务的默认端口。Telnet是一种早期的网络协议,允许用户使用一个远程终端连接到远程计算机上,以便在远程计算机上执行命令和操作。通过输入用于Telnet服务器的IP地址和端口号,用户可以在本地计算机上打开一…

【Java|golang】1090. 受标签影响的最大值---关联数组排序问题以及切片排序失败

我们有一个 n 项的集合。给出两个整数数组 values 和 labels ,第 i 个元素的值和标签分别是 values[i] 和 labels[i]。还会给出两个整数 numWanted 和 useLimit 。 从 n 个元素中选择一个子集 s : 子集 s 的大小 小于或等于 numWanted 。 s 中 最多 有相同标签的 …

数据结构初阶--栈和队列OJ题

目录 前言有效的括号思路分析代码实现 用队列实现栈思路分析代码实现 用栈实现队列思路分析代码实现 设计循环队列思路分析代码实现 前言 本篇文章将对部分栈和队列综合运用题进行讲解,以对栈和队列有一个更深层次的理解。 有效的括号 先来看题 思路分析 这里…

优秀的流程图应该怎么设计呢?

优秀的流程图应该怎么绘制呢? 本文将带大家学习优秀流程图的绘制要点和技巧,以及讲解流程图与UML活动图、BPMN图之间的关系和区别。 1、认识流程图 流程图简单讲就是用图描述流程,这种流程可以是一种有先后顺序的操作组成,可以…

2024王道数据结构考研丨第六篇:查找、排序

到此,2024王道数据结构考研笔记专栏“基础知识”部分已更新完毕,其他内容持续更新中,欢迎 点此 订阅,共同交流学习… 文章目录 第七章 查找7.1查找表相关概念 第八章 排序8.1排序的基本概念8.2 插入排序8.2.1直接插入排序8.2.2折半…

使用Maven管理项目、导入依赖、测试打包项目、常用依赖

使用Maven管理项目 文章目录 使用Maven管理项目Maven项目结构Maven依赖导入Maven依赖作用域Maven可选依赖Maven排除依赖Maven继承关系Maven常用命令Maven测试项目Maven打包项目 Maven 翻译为"专家"、“内行”,是 Apache 下的一个纯 Java 开发的开源项目。…

hive函数03

自定义函数 Hive 自带了一些函数,比如:max/min等,但是数量有限,自己可以通过自定义UDF来方便的扩展。 在企业中处理数据的时候,对于敏感数据往往需要进行脱敏处理。比如手机号。我们常见的处理方式是将手机号中间4位…

MySQL表设计原则

前言 这里简单整理一些常用的数据库表设计原则以及常用字段的使用范围。 表的设计准则 1、命名规范 表名、字段名必须使用小写字母或者数字,禁止使用数字开头,禁止使用拼音,并且一般不使用英文缩写。主键索引名为 pk_字段名;唯…

SSL/TLS认证握手过程

一: SSL/TLS介绍 什么是SSL,什么是TLS呢?官话说SSL是安全套接层(secure sockets layer),TLS是SSL的继任者,叫传输层安全(transport layer security)。说白点,就是在明文的上层和TCP层之间加上一层加密,这样就保证上层信…