机器学习之过拟合和欠拟合

news2024/9/21 16:18:49

文章目录

  • 前言
  • 什麽是过拟合和欠拟合?
  • 过拟合和欠拟合产生的原因:
    • 欠拟合(underfitting):
    • 过拟合(overfitting):
  • 解决欠拟合(高偏差)的方法
      • 1、模型复杂化
      • 2、增加更多的特征,使输入数据具有更强的表达能力
      • 3、调整参数和超参数
      • 4、增加训练数据往往没有用
      • 5、降低正则化约束
  • 解决过拟合(高方差)的方法:
      • 1、增加训练数据数
      • 2、使用正则化约束
      • 3、减少特征数
      • 4、调整参数和超参数
      • 5、降低模型的复杂度
      • 6、使用Dropout
      • 7、提前结束训练
  • 小例
  • 實戰
  • 總結


前言

随着人工智能的不断发展,机器学习这门技术也越来越重要,很多人都开启了学习机器学习,本文就介绍了机器学习的基础内容之过拟合和欠拟合。


什麽是过拟合和欠拟合?

  • 欠拟合是指模型在训练集、验证集和测试集上均表现不佳的情况;
  • 过拟合是指模型在训练集上表现很好,到了验证和测试阶段就很差,即模型的泛化能力很差。

过拟合和欠拟合产生的原因:

欠拟合(underfitting):

  • 模型复杂度过低
  • 特征量过少

过拟合(overfitting):

  • 建模样本选取有误,如样本数量太少,选样方法错误,样本标签错误等,导致选取的样本数据不足以代表预定的分类规则
  • 样本噪音干扰过大,使得机器将部分噪音认为是特征从而扰乱了预设的分类规则
  • 假设的模型无法合理存在,或者说是假设成立的条件实际并不成立
  • 参数太多,模型复杂度过高
  • 对于决策树模型,如果我们对于其生长没有合理的限制,其自由生长有可能使节点只包含单纯的事件数据(event)或非事件数据(no event),使其虽然可以完美匹配(拟合)训练数据,但是无法适应其他数据集
  • 对于神经网络模型:a)对样本数据可能存在分类决策面不唯一,随着学习的进行,,BP算法使权值可能-收敛过于复杂的决策面;b)权值学习迭代次数足够多(Overtraining),拟合了训练数据中的噪声和训练样例中没有代表性的特征

解决欠拟合(高偏差)的方法

1、模型复杂化

对同一个算法复杂化。例如回归模型添加更多的高次项,增加决策树的深度,增加神经网络的隐藏层数和隐藏单元数等
弃用原来的算法,使用一个更加复杂的算法或模型。例如用神经网络来替代线性回归,用随机森林来代替决策树等

2、增加更多的特征,使输入数据具有更强的表达能力

特征挖掘十分重要,尤其是具有强表达能力的特征,往往可以抵过大量的弱表达能力的特征。
特征的数量往往并非重点,质量才是,总之强特最重要。
能否挖掘出强特,还在于对数据本身以及具体应用场景的深刻理解,往往依赖于经验。

3、调整参数和超参数

超参数包括:
神经网络中:学习率、学习衰减率、隐藏层数、隐藏层的单元数、Adam优化算法中的β1和β2参数、batch_size数值等。
其他算法中:随机森林的树数量,k-means中的cluster数,正则化参数λ等。

4、增加训练数据往往没有用

欠拟合本来就是模型的学习能力不足,增加再多的数据给它训练它也没能力学习好。

5、降低正则化约束

正则化约束是为了防止模型过拟合,如果模型压根不存在过拟合而是欠拟合了,那么就考虑是否降低正则化参数λ或者直接去除正则化项

解决过拟合(高方差)的方法:

1、增加训练数据数

  • 发生过拟合最常见的现象就是数据量太少而模型太复杂
  • 过拟合是由于模型学习到了数据的一些噪声特征导致,增加训练数据的 量能够减少噪声的影响,让模型更多地学习数据的一般特征
  • 增加数据量有时可能不是那么容易,需要花费一定的时间和精力去搜集处理数据
  • 利用现有数据进行扩充或许也是一个好办法。例如在图像识别中,如果没有足够的图片训练,可以把已有的图片进行旋转,拉伸,镜像,对称等,这样就可以把数据量扩大好几倍而不需要额外补充数据
  • 注意保证训练数据的分布和测试数据的分布要保持一致,二者要是分布完全不同,那模型预测真可谓是对牛弹琴了。

2、使用正则化约束

代价函数后面添加正则化项,可以避免训练出来的参数过大从而使模型过拟合。使用正则化缓解过拟合的手段广泛应用,不论是在线性回归还是在神经网络的梯度下降计算过程中,都应用到了正则化的方法。常用的正则化有l1正则和l2正则,具体使用哪个视具体情况而定,一般l2正则应用比较多。

3、减少特征数

欠拟合需要增加特征数,那么过拟合自然就要减少特征数。去除那些非共性特征,可以提高模型的泛化能力.

4、调整参数和超参数

不论什么情况,调参是必须的

5、降低模型的复杂度

欠拟合要增加模型的复杂度,那么过拟合正好反过来。

6、使用Dropout

这一方法只适用于神经网络中,即按一定的比例去除隐藏层的神经单元,使神经网络的结构简单化。
Dropout是在训练网络时用的一种技巧(trike),相当于在隐藏单元增加了噪声。Dropout 指的是在训练过程中每次按一定的概率(比如50%)随机地“删除”一部分隐藏单元(神经元)。所谓的“删除”不是真正意义上的删除,其实就是将该部分神经元的激活函数设为0(激活函数的输出为0),让这些神经元不计算而已。

7、提前结束训练

即early stopping,在模型迭代训练时候记录训练精度(或损失)和验证精度(或损失),如果模型训练的效果不再提高,比如训练误差一直在降低但是验证误差却不再降低甚至上升,这时候便可以结束模型训练了。

小例

我们知道酶的活性会随温度的变化而变化,在最适温度达到最高,高于或低于最适温度都会有所降低,那么我们若想预测某温度下酶的活性时,该怎么判断模型是否欠拟合或过拟合了呢?来看几幅图

在这里插入图片描述
可以看到大概在60度左右酶的活性最高,在这之前和之后都会有所降低,我们预想的模型大概是这样的。它的决策边界可能如图所示。

在这里插入图片描述

										r = θ₀ + θ₁t + θ₂t²

如果说模型过于简单了,就可能就会得到一条直线。这种训练好后的模型既不能满足训练数据的预期,也不能满足新数据的预期,此时它就是属于欠拟合。(训练数据和新数据的预测结果都不准确)

在这里插入图片描述

						                 r = θ₀ + θ₁t

当然,如果模型过于复杂了也不是一件好事。虽然它可以很好的拟合我们的训练数据,但在新数据的预测上就不尽人意了。如下图,对应训练数据,该模型的预测结果是非常准确的,但是我们知道在温度超过最适温度后,酶的活性就会降低,图示结果明显错误,这种情况就是过拟合。(对训练数据的预测结果非常准确,但对新数据的预测结果不准确)
在这里插入图片描述
在这里插入图片描述

實戰

酶活性預測實戰task:

  1. 基於T-R-train.csv數據,建立綫性回歸模型,計算其在T-R-test.csv數據上的r2分數,可視化模型預測結果
  2. 加入多項式特徵(2次、3次),建立回歸模型
  3. 計算多項式回歸模型對測試數據進行預測的r2分數,判斷哪個模型預測更準確
  4. 可視化多項式回歸模型數據預測結果,判斷哪個模型預測更准確
#load the data
import pandas as pd
import numpy as np
data_train = pd.read_csv('T-R-train.csv')
data_train

在这里插入图片描述

#define X_train and y_train
X_train = data_train.loc[:,'T']
y_train = data_train.loc[:,'rate']
#visualize the data
from matplotlib import pyplot as plt
fig1 = plt.figure(figsize=(5,5))
plt.scatter(X_train,y_train)
plt.title('raw data')
plt.xlabel('temperature')
plt.ylabel('rate')
plt.show()

在这里插入图片描述

X_train = np.array(X_train).reshape(-1,1)

轉一下類型不然後面會報下面錯誤

ValueError: Expected 2D array, got 1D array instead:
array=[46.53 48.14 50.15 51.36 52.57 54.18 56.19 58.58 61.37 63.34 65.31 66.47
 68.03 69.97 71.13 71.89 73.05 74.21].
Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

在这里插入图片描述

#linear regression model prediction
from sklearn.linear_model import LinearRegression
lr1 = LinearRegression()
lr1.fit(X_train,y_train)

訓練綫性回歸模型
在这里插入图片描述

#load the test data
data_test = pd.read_csv('T-R-test.csv')
data_test

在这里插入图片描述

#define X_test and y_test
X_test = data_test.loc[:,'T']
y_test = data_test.loc[:,'rate']
X_test = np.array(X_test).reshape(-1,1)
#make prediction on the training and testing data
y_train_predict = lr1.predict(X_train)
y_test_predict = lr1.predict(X_test)
from sklearn.metrics import r2_score
r2_train = r2_score(y_train,y_train_predict)
r2_test = r2_score(y_test,y_test_predict)
print('training r2:',r2_train)
print('test r2:',r2_test)

可以看出r2_score值很低也就是説明模型很差

在这里插入图片描述

#generate new data
X_range = np.linspace(40,90,300).reshape(-1,1)
y_range_predict = lr1.predict(X_range)
fig2 = plt.figure(figsize=(5,5))
plt.plot(X_range,y_range_predict)
plt.scatter(X_train,y_train)

plt.title('prediction data')
plt.xlabel('temperature')
plt.ylabel('rate')
plt.show()

可視化看一下,可以明顯看出不是一個好的訓練模型,這是個明顯的欠擬合的bad fit

在这里插入图片描述

加入多項式特徵(2次、3次),建立回歸模型

#多項式模式
#generate new features
from sklearn.preprocessing import PolynomialFeatures
poly2 = PolynomialFeatures(degree=2) # 二階
X_2_train = poly2.fit_transform(X_train)
X_2_test = poly2.transform(X_test)

poly5 = PolynomialFeatures(degree=5) # 五階
X_5_train = poly5.fit_transform(X_train)
X_5_test = poly5.transform(X_test)
print(X_5_train.shape)

看一下五階的維度
在这里插入图片描述

看一下二階的數據
在这里插入图片描述

lr2 = LinearRegression()
lr2.fit(X_2_train,y_train)

y_2_train_predict = lr2.predict(X_2_train)
y_2_test_predict = lr2.predict(X_2_test)
r2_2_train = r2_score(y_train,y_2_train_predict)
r2_2_test = r2_score(y_test,y_2_test_predict)

lr5 = LinearRegression()
lr5.fit(X_5_train,y_train)

y_5_train_predict = lr5.predict(X_5_train)
y_5_test_predict = lr5.predict(X_5_test)
r2_5_train = r2_score(y_train,y_5_train_predict)
r2_5_test = r2_score(y_test,y_5_test_predict)


print('training r2_2:',r2_2_train)
print('test r2_2:',r2_2_test)
print('training r2_5:',r2_5_train)
print('test r2_5:',r2_5_test)

看一下r2_score越接近1越好,可以看出多項式的比綫性回歸的效果好的多。同時也可以看出五階對於訓練數據r2分數高(預測準確),但對於預測數據r2分數低(預測不準確)

在这里插入图片描述

X_2_range = np.linspace(40,90,300).reshape(-1,1)
X_2_range = poly2.transform(X_2_range)
y_2_range_predict = lr2.predict(X_2_range)

X_5_range = np.linspace(40,90,300).reshape(-1,1)
X_5_range = poly5.transform(X_5_range)
y_5_range_predict = lr5.predict(X_5_range)
fig3 = plt.figure(figsize=(5,5))
plt.plot(X_range,y_2_range_predict)
plt.scatter(X_train,y_train)
plt.scatter(X_test,y_test)

plt.title('polynomial prediction result (2)')
plt.xlabel('temperature')
plt.ylabel('rate')
plt.show()

可視化二階,看出擬合和預測效果都還不錯,是個good fit

在这里插入图片描述

fig4 = plt.figure(figsize=(5,5))
plt.plot(X_range,y_5_range_predict)
plt.scatter(X_train,y_train)
plt.scatter(X_test,y_test)

plt.title('polynomial prediction result (5)')
plt.xlabel('temperature')
plt.ylabel('rate')
plt.show()

可視化五階,看出擬合很完美但是預測效果不佳,是個過擬合bad fit

在这里插入图片描述

總結

酶活性預測實戰summary:

  1. 通過建立二階多項式回歸模型,對酶活性實現一個較好的預測,無論針對訓練或測試數據都得到一個高的r2分數
  2. 通過建立綫性回歸、五階多項式回歸模型,發現存在欠擬合或過擬合情況。過擬合情況下,對於訓練數據r2分數高(預測準確),但對於預測數據r2分數低(預測不準確)
  3. 無論是通過r2分數,或是可視化模型結果,都可以發現二階多項式回歸模型效果最好

這就是本次學習过拟合和欠拟合的筆記
附上本次實戰的數據集和源碼:
鏈接:https://github.com/fbozhang/python/tree/master/jupyter

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/61691.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java项目:SSM游戏点评网站

作者主页:源码空间站2022 简介:Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 本项目分为前后台,前台为普通用户登录,后台为管理员登录; 管理员角色包含以下功能: 管理员登录…

jenkins-pipeline语法总结(最全)

1、jenkins总结之pipeline语法 jenkins总结之pipeline语法1、jenkins总结之pipeline语法1.1必要的Groovy知识1.2pipeline的组成1.2.1pipeline最简结构1.3post部分1.4pipeline支持的指令• environment:• tools:• input:• options&#xff…

大学网课查题接口

大学网课查题接口 本平台优点: 多题库查题、独立后台、响应速度快、全网平台可查、功能最全! 1.想要给自己的公众号获得查题接口,只需要两步! 2.题库: 查题校园题库:查题校园题库后台(点击跳…

项目管理逻辑:老板为什么赔钱的项目也做?为什么害怕你闲着?

目录 1.波士顿矩阵 2.为什么企业还要做没有市场占有率,也没有销售增长率的产品? 2.1项目层级划分 2.2项目集 2.3组合管理 2.4赔钱也做的项目案例 1.波士顿矩阵 项目经理没有资源, 公司不给足够的支持 在任何一个企业老板的脑子里,都会有这样一个矩阵, 纵向表示销售增长…

数据结构与算法,MySQL数据库面试专题及答案

文章目录数据结构面试题及答案数组问题字符串相关问题链表问题二叉树问题编程面试问题之杂项答案数据结构与算法时间复杂度 并不是计算程序具体运行的时间,而是算法执行语句的次数 O(2^n) 表示对 n 数据处理需要进行 2^n 次计算 多项式的时间复杂度 数据 n 在表达式…

Docker安装部署Redis集群

目录 概述 一、创建文件和目录 1.1 创建需要挂载的文件和目录 1.2 同步操作 二、随机从节点模式 2.1 创建master节点的redis容器 2.2 在同一台机器上创建另外2个节点 2.3 其他2台机器同步操作 2.4 配置主从集群 2.4.1 进入任意一个 Redis 实例 2.4.2 配置集群 2.4…

《未来简史:从智人到智神》笔记一——人类的新议题

目录 一、人类的旧议题演变 二、人类的新议题 1、长生不死 2、追求幸福快乐 3、努力把自己升级为神 三、研究历史的意义——不是为了重复过去,而是为了摆脱过去并从中获得解放 四、生命的意义 1、主观体验有两个基本特征 2、生命的意义? 一、人类…

C语言第十三课:初阶指针

目录 前言: 一、指针是什么: 1.那么指针到底是什么呢? 2.内存中的数据存储原理: 3.数据存储与指针使用实例: 4.存储编址原理: 二、指针和指针类型: 1.决定了指针的步长: 2.决定了…

【VSCode + Anaconda】VSCode [WinError 126]找不到指定模块

【VSCode Anaconda】VSCode [WinError 126]找不到指定模块问题解决一解决二问题 在 Anaconda Prompt 中的 python 环境测试,可以使用 import torch 命令 现在在 VSCode 中测试,发现相关异常 图中,已经选择了相应的 conda 环境的 python.exe…

分片集群中的分片集合

分片集群中的分片集合 MongoDB 中 分片集群有专门推荐的模式,例如 分片集合 它是一种基于分片键的逻辑对文档进行分组,分片键的选择对分片是非常重要的,分片键一旦确定,MongoDB 对数据的分片对应用是透明的 mongodb 分片中&#…

MySQL高级语句(三)

一、正则表达式(REGEXP) 1、正则表达式匹配符 字符解释举列^匹配文本的开始字符’ ^aa ’ 匹配以 aa 开头的字符串$匹配文本的结束字符’ aa$ ’ 匹配以aa结尾的字符串.匹配任何单个字符’ a.b 匹配任何a和b之间有一个字符的字符串*匹配零个或多个在它…

数据结构—树、有序二叉树

文章目录树的概述树的分类二叉树的遍历有序二叉树代码通过链表方式构建有序二叉树通过递归方式实现有序二叉树递归遍历有序二叉树中序遍历:先序遍历:后序遍历:删除节点1、删除叶子节点删除叶子节点总结图示2、删除只有一个子树的节点删除只有…

毕业设计-基于深度学习火灾烟雾检测识别系统-yolo

前言 📅大四是整个大学期间最忙碌的时光,一边要忙着准备考研,考公,考教资或者实习为毕业后面临的就业升学做准备,一边要为毕业设计耗费大量精力。近几年各个学校要求的毕设项目越来越难,有不少课题是研究生级别难度的,对本科同学来说是充满挑战。为帮助大家顺利通过…

Spring循环依赖源码解析(深度理解)

文章目录前言本章目标一、什么是循环依赖?1、那么循环依赖是个问题吗?2、但是在Spring中循环依赖就是一个问题了,为什么?二、Bean的生命周期2.1、在Spring中,Bean是如何生成的?2.2、那么这个注入过程是怎样…

GitLab CI/CD系列教程(一)

来自:GitLab CI/CD系列教程(一):Docker安装GitLab_哔哩哔哩_bilibili 1. 创建虚拟机并连接Xterm 创建一个4G内存的虚拟机,否则很容易启动不了,报502 虚拟机的创建看这篇: VMware16的安装及VM…

基于java+ssm+vue+mysql的网上书店

项目介绍 本网上系统是针对目前网上的实际需求,从实际工作出发,对过去的网上系统存在的问题进行分析,结合计算机系统的结构、概念、模型、原理、方法,在计算机各种优势的情况下,采用目前最流行的B/S结构和java中流行的…

从0开始搭建vue2管理后台基础模板

网站主要完成:侧边菜单栏、页面标签卡、内容栏 源代码gitee地址:https://gitee.com/zhao_liangliang1997/navigation-bar 一、起步 1、创建vue项目 vue create 项目名2、引入element 3、其他安装 1、首先需要安装如下 cnpm install vuex cnpm install…

DockerCompose安装、使用 及 微服务部署实操

1 什么是DockerCompose DockerCompose是基于Compose文件帮助我们快速的部署分布式应用。 解决容器需手动一个个创建和运行的问题! DockerCompose本质上也是一个文本文件,其通过指令定义集群中的每个容器如何运行。我们可以将其看做是将多个docker run…

Ansible 自动化运维工具的使用

目录 一、Ansible简介 二、Ansible 的安装和使用 1.下载 2.使用 三、Ansible命令和模块 1.命令格式 2.命令行模块 (1)command 模块 (2)shell 模块 (3)cron 模块 (4)user …

多线程 3

多线程 3 : 文章目录1.线程安全2. 产生线程安全的原因3. synchronized - 加锁操作4.可重入5.死锁问题6. volatile 关键字7.wait 和 notify1.线程安全 为啥会出现线程安全 ?   罪魁祸首,还是多线程的抢占式执行, 正因为抢占式执行&#xff0c…