【机器学习 - 2】:数据集的处理

news2024/12/26 21:38:29

文章目录

  • 训练集和数据集分离
  • 获取最优模型
    • 超参数
    • 寻找最优模型
  • 网格搜索的使用

训练集和数据集分离


训练集和数据集分离的原理:当我们获取一个数据集时,我们需要将其一小部分拿出来作为测试集,剩余的作为训练集。例如对于一个训练集,将其20%作为测试集,80%作为训练集,这20%的测试集是已经有目标值了的,将训练集进行拟合,获得模型,再通过测试集进行测试,获得最终结果,将最终结果和已知的目标值进行比对,可预测其训练模型的精确度。
在这里插入图片描述
以下使用sklearn中的knn算法进行预测,以识别鸢尾花为例。

  • 先获取数据集,观察下图中y的值,可将0,1,2分别看做鸢尾花的不同种类。
from sklearn.datasets import load_iris
# 获取数据集
iris = load_iris()
X = iris.data
y = iris.target

在这里插入图片描述

  • 由上图可看出数据集的目标值是有一定顺序的,我们需要将其打乱后再分出训练集和测试集,打乱用到的函数为np.random.permutation(),下图中shuffle_indexs里是0-150的随机索引
import numpy as np
shuffle_indexs = np.random.permutation(len(X))

在这里插入图片描述

  • 打乱数据后开始取训练集和测试集,训练集取80%,测试集取20%
test_ratio = 0.2 # 取20%做测试集
test_size = int(len(X) * test_ratio)

test_indexs = shuffle_indexs[:test_size] # 测试集索引
train_indexs = shuffle_indexs[test_size:] # 训练集索引

# 获得训练集
X_train = X[train_indexs]
y_train = y[train_indexs]
# 获得测试集
X_test = X[test_indexs]
y_test = y[test_indexs]
  • 调用sklearn中的knn算法,将训练集进行拟合,获得模型,测试集通过训练的模型,获得最终的预测结果,观察下图可看到y_predict和y_test(标准答案),大部分是相同的。
from sklearn.neighbors import KNeighborsClassifier
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_train)	# 拟合,获得模型

y_predict = knn_clf.predict(X_test)	# 获取测试集的最终结果

在这里插入图片描述

  • 根据y_predict和y_test,计算出模型的准确度,每次运行获得的准确度都不一样,但是准确率都在90%以上,说明模型的准确度较高。
np.sum(np.array(y_predict == y_test, dtype='int'))/len(y_test)

在这里插入图片描述

获取最优模型


参数的不同,会导致模型的不同,从而我们需要找到最合适的参数,从而训练出最优的模型。

我们可以自定义一个train_test_split函数,获取到训练集和测试集数据。根据以上的代码,编写的函数如下:

import numpy as np
def train_test_split(X, y, test_ratio=0.2, random_state=None):
    if random_state:
        np.random.seed(random_state) # 设置随机种子
    shuffle_indexs = np.random.permutation(len(X))
    test_ration = test_ration
    test_size = int(len(X) * test_ratio)
    
    test_indexs = shffle_indexs[:test_size]
    train_indexs = shuffle_indexs[test_size:]
    
    # 训练集
    X_train = X[train_indexs]
    y_train = y[train_indexs]
    
    # 测试集
    X_test = X[test_indexs]
    y_test = y[test_indexs]
    
    return X_train, X_test, y_train, y_test

我们也可以使用sklearn中封装好的train_test_split

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)

在这里插入图片描述

超参数

超参数:在执行程序之前需要确定的参数。

举个例子:在knn分类器中,即KNeighborsClassifier(n_neighbors=3),n_neighbors值的不同,会导致模型的准确率不同,我们需要不断调整参数,找到某个数更加拟合我们的数据,这就是超参数

权重问题:在【机器学习 - 1】:knn算法这一篇文章里,我们举了一个使用knn算法判断肿瘤为恶性肿瘤或良性肿瘤的例子,这个例子中我们主要以离待预测点周围最近的3个点进行判断。
而在如下图的情况中,待预测点(绿点)离红点(良性肿瘤)比较近,则它更可能为良性肿瘤,若以上篇文章中的思路来判断,因为它周围有2两个恶性肿瘤(蓝点),所以它很可能为恶性肿瘤。根据以上两种判断情况,我们需要把距离个数这两种判断特征都考虑进来。

即绿色的点离红色的点最近,我们可以给这些距离加一个权重,这样及时周围有两个蓝点,但红点最近的距离权重大于这两个蓝点的距离权重,绿色的点可能就为良性肿瘤
在这里插入图片描述
在KNeighborsClassifier()中可设置权重参数:weights
weights=uniform时,不考虑距离带来的权重问题
weights=distance时,距离作为计算的权重

我们先看一下KNeighborsClassifier()的源码(如下图2),weights默认为uniform,p=2这个p是距离方法(如下图1),当=1时为曼哈顿距离,p=2时为欧拉距离,p增大,计算距离的方法不同。在KNeighborsClassifier()中默认为欧拉距离
在这里插入图片描述
在这里插入图片描述

寻找最优模型

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

%%time
best_k = 0
best_score = 0.0
best_clf = None
best_method = None
best_p = 0
for p in range(1, 6):
    for weight in ['uniform', 'distance']:
        for k in range(1, 21):
            knn_clf = KNeighborsClassifier(n_neighbors=k, weights=weight, p=p)
            knn_clf.fit(X_train, y_train)
            score = knn_clf.score(X_test, y_test)
            if score>best_score:
                best_score = score
                best_k = k
                best_clf = knn_clf
                best_method = weight
                best_p = p
print(best_k)
print(best_score)
print(best_clf)
print(best_method)
print(best_p)

在这里插入图片描述

网格搜索的使用


本次网格搜索的数据集以手写识别数据集为例。

  1. 获取数据,可以打印描述信息进行查看。
from sklearn.datasets import load_digits # 导入手写识别数据集
import numpy as np
from matplotlib import pyplot as plt

digits = load_digits()

X = digits.data
y = digits.target

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y)

在这里插入图片描述

  1. 绘制出手写数字
x = X_train[1000].reshape(8, -1)
plt.imshow(x, cmap=plt.cm.binary)
plt.show()

在这里插入图片描述

  1. 使用sklearn中的grid search
# 创建网格参数,每一组参数放在一个字典中
param_grid = [
    {'weights':['uniform'],
     'n_neighbors':[i for i in range(1,21)]
    },
    {
        'weights':['distance'],
        'n_neighbors': [i for i in range(1,21)],
        'p':[i for i in range(1,6)]
    }
] 

from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

knn_clf = KNeighborsClassifier()

%%time
# 尝试寻找最佳参数
grid_search = GridSearchCV(knn_clf, param_grid, verbose=2, n_jobs=-1) # verbose越大越详细,n_jobs调用几个cpu进行计算,当n_jobs=-1时表示调用所有cpu进行计算
grid_search.fit(X_train, y_train)

grid_search.best_estimator_
grid_search.best_score_
grid_search.best_params_

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/157621.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RocketChip RISC-V生成RTL到仿真全流程

一、Scala配置项修改和RTL代码生成可以通过对scala中的配置项修改,来达到定制化配置RISC-V的目的,这里总结几个比较常用的配置项、配置项含义和所在的scala中的位置:1.$rocket-chip/src/main/scala/system/Config.scala1)new With…

机器学习-2-安装Python 3.6和Pytorch 1.1.0

0. 说明: 之前根据GPU版本安装了CUDA 9.0,因此现安装与CUDA 9.0相对应的Pytorch版本,但在安装Pytorch之前要先确认一下Python的版本。 1. 查看 CUDA 9.0 对应的 Pytorch 从https://pytorch.org/get-started/previous-versions/中查找CUDA …

程序的机器级表示part1——程序编码与数据格式

目录 1. 汇编语言和机器级语言 1.1 不同的编程语言 1.2 Linux下的汇编语言 2. 程序编码 1.1 机器级代码 1.2 代码示例 3. 数据格式 本文基于CSAPP第三章撰写,主要介绍部分x86-64汇编的相关知识,后续会将该部分内容慢慢完善(PS&a…

Web Spider XHR断点 千千XX 歌曲下载(三)

Web Spider XHR断点 千千XX 歌曲下载 首先声明: 此次案例只为学习交流使用,切勿用于其他非法用途 注:网站url、接口url请使用base64.b64decode自行解码 文章目录Web Spider XHR断点 千千XX 歌曲下载前言一、资源推荐二、任务说明三、网站分析四、XHR断点…

knife4j使用与步骤

1、导入依赖<dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-spring-boot-starter</artifactId><version>3.0.3</version> </dependency>2、knife4j的配置类&#xff08;可有可无&#xff09;package…

为什么程序员的工资比其他行业高这么多?

不止一次听到有人说程序员工资高&#xff0c;更有甚者喊着“把IT工资打下来”。 拜托大哥大姐们&#xff01;看事情要客观好吧&#xff01;&#xff01; 虽然看起来程序员工资是不少&#xff0c;对比其他行业确实会高一些&#xff0c;但并不代表程序员这个岗位工资就要压到三千…

Elasticsearch:如何在 Elasticsearch 中正确使用同义词功能

同义词用于提高搜索质量并扩大匹配范围。 例如&#xff0c;搜索 England 的用户可能希望找到包含 British 或 UK 的文档&#xff0c;尽管这三个词完全不同。 Elasticsearch 中的同义词功能非常强大&#xff0c;如果实施得当&#xff0c;可以使你的搜索引擎更加健壮和强大。 在…

详解结构体内存对齐

目录 前言 一、结构体内存对齐规则 二、 offsetof 宏 三、结构体内存对齐的原因 四、 修改默认对齐数 前言 引入问题&#xff1a; #include <stdio.h>struct S {char c1;int i;char c2; };int main() {printf("%zd\n", sizeof(struct S));return 0; } 程…

干货 | 人脸识别技术的风险及应对方案

以下内容整理自清华大学《数智安全与标准化》课程大作业期末报告同学的汇报内容。第一部分&#xff1a;人脸识别技术概述人脸识别的发展阶段&#xff0c;主要分为三个阶段&#xff1a;起步阶段&#xff08;1950s-1980s&#xff09;&#xff0c;这一阶段的人脸识别只是作为一般性…

房产管理系统---系统安全性需求分析

数图互通高校房产管理系统是基于公司自主研发的FMCenterV5.0平台&#xff0c;是针对中国高校房产的管理特点和管理要求&#xff0c;研发的一套标准产品&#xff1b;通过在中国100多所高校的成功实施和迭代&#xff0c;形成了一套成熟、完善、全生命周期的房屋资源管理解决方案。…

Linux学习笔记——HBase集群安装部署

5.11、大数据NoSQL数据库HBase集群部署 5.11.1、简介 HBase是一种分布式、可扩展、支持海量数据存储的NoSQL数据库。 和Redis一样&#xff0c;HBase是一款KeyValue型存储的数据库。 不过和Redis设计方向不同&#xff1a; Redis设计为少量数据&#xff0c;超快检索HBase设计…

【部署】Docker容器

Docker 使用 Google 公司推出的 Go 语言进行开发实现&#xff0c;基于 Linux 内核的 cgroup、namespace 以及 OverlayFS 类的 Union FS 等技术&#xff0c;对进程进行封装隔离&#xff0c;属于操作系统层面的虚拟化技术。由于隔离的进程独立于宿主和其它的隔离的进程&#xff0…

算法刷题打卡第63天:对称二叉树

对称二叉树 难度&#xff1a;简单 给你一个二叉树的根节点 root &#xff0c; 检查它是否轴对称。 示例 1&#xff1a; 输入&#xff1a;root [1,2,2,3,4,4,3] 输出&#xff1a;true示例 2&#xff1a; 输入&#xff1a;root [1,2,2,null,3,null,3] 输出&#xff1a;false…

BOM浏览器对象模型

文章目录一、BOM概述1、什么是BOM2、BOM的构成二、window 对象的常见事件1、窗口加载事件&#xff08;1&#xff09;window.onload&#xff08;3&#xff09;DOMContentLoaded2、调整窗口大小事件三、定时器1、两种定时器2、setTimeout()定时器3、停止 setTimeout() 定时器4、s…

如何使用CMD修复硬盘命令来解决硬盘问题?

随着计算机的越来越普及&#xff0c;现在在我们的日常生活中都会使用到计算机电脑。硬盘作为计算机电脑的主要存储设备&#xff0c;里面存储着我们平时使用的软件文件、文档资料、照片等重要的数据文件。一旦硬盘损坏会给我们带来许多不必要的麻烦&#xff0c;那硬盘损坏有哪些…

图解卡尔曼滤波(Kalman Filter)

背景关于滤波首先援引来自知乎大神的解释。“一位专业课的教授给我们上课的时候&#xff0c;曾谈到&#xff1a;filtering is weighting&#xff08;滤波即加权&#xff09;。滤波的作用就是给不同的信号分量不同的权重。最简单的loss pass filter&#xff0c; 就是直接把低频的…

【Linux操作系统】1. Linux操作系统简介、安装

前言 本系列是Linux操作系统的一些知识以及实践内容&#xff0c;Linux操作系统作为开发最常使用的操作系统&#xff0c;是必备的一门求职、提升技术。本文先介绍Linux操作系统&#xff0c;并安装一个Linux操作系统。 Linux操作系统简介 Linux&#xff0c;全称GNU/Linux&#…

Javadoc

Javadoc 在学习JavaSE时&#xff0c;我们知道Java支持三种注释方式&#xff1a; 单行注释多行注释文档注释 Javadoc是文档注释&#xff0c;用来对类或方法进行标准的注释&#xff0c;在开发中写好JavaDoc非常重要。 在调用方法时&#xff0c;你可能会看到这样的情景 这种注…

Unity - 搬砖日志 - 如何设置AssetDatabase.Create(“xxx.asset“, mesh) 的Read/Write=false

最近很忙&#xff0c;想写的 BLOG 都遗漏编写了 踩坑的时间比较多&#xff0c;充电的时间少了很多 为了减少以后自己填坑时间&#xff0c;随便简单的记录一下 搬砖日志 环境 unity : 2020.3.37f1 pipeline : brp 问题 因为之前搜索、购买、使用了各式各样的 LOD 插件、工具…

机器学习100天(三十一):031 K近邻回归算法

机器学习100天,今天讲的是:K 近邻回归算法! 《机器学习100天》完整目录:目录 一、理论介绍 我们之前讲了 K 近邻分类算法,用来处理分类问题。其实 K 近邻也可以用来处理回归问题。 如左图所示,K 近邻分类算法的思路是选取与测试样本距离最近的前 k 个训练样本。然后对…