【机器学习 - 1】:knn算法

news2024/11/27 15:44:48

文章目录

  • 机器学习的概念和基础
  • knn算法的实现过程
  • 封装knn算法
  • 总结

机器学习的概念和基础


机器学习可以两类任务: 分类任务和回归任务

以机器学习本身来进行分类可分为: 监督学习 非监督学习 半监督学习 增强学习

监督学习:给机器的训练数据 有标记label 有答案

  • 猫狗识别
  • mnist手写识别数据集 0 - 9 既有x,又有y(x为矩阵,代表样本特征,y代表样本目标值) 学习的算法大部分是监督学习的算法,但是不代表非监督学习不重要

非监督学习: 我们给机器的训练数据没有标记 没有答案

  • 一般是用来辅助监督学习,比如对于没有标记的数据进行分类:聚类分析
  • 对数据进行降维:
  • 异常的样本

半监督学习:一部分有标记,一部分没有标记

增强学习:根据周围环境的情况,采取行动,然后根据行动的结果,学习行动的方式,非常适合机器人,例如alpha go zero

从其他维度对机器学习进行分类,可分为批量学习和在线学习
批量学习:模型固定,不会发生改变,面对新的数据,无法进行优化

  • 问题:无法适应环境变化,比如垃圾邮件
  • 解决方案:定时重新批量学习(成本巨大)

在线学习:不断的学习,应用模型 得到结果,同时不断的更新模型

  • 优点: 及时反映新的环境变化
  • 问题: 新的数据可能会带来不好的改变
  • 解决方案: 加强数据的监控

knn算法的实现过程


本次以预测肿瘤为良性肿瘤或恶性肿瘤为例,先给出训练数据进行训练,获取模型,然后输入测试数据x,计算x与所有点的距离,并获取离x点最近的3个点,根据这三个点的类型(良性肿瘤或恶性肿瘤),哪种类型的肿瘤较多,则x很可能为该肿瘤类型。以上算法也可称之为k近邻算法(如下图)

观察下图,绿色为待预测点,离它最近的3个点有一个良性肿瘤(红色),两个恶性肿瘤(蓝色),因为恶性肿瘤数量>良性肿瘤,所以绿色的点很有可能为恶性肿瘤。
在这里插入图片描述

以下是程序编写过程

  • 导入模块
import numpy as np
import matplotlib.pyplot as plt
  • 准备训练数据,并将训练数据转换成矩阵
raw_data_X = [[3.3935, 2.3312],
              [3.1101, 1.7815],
              [1.3438, 3.3684],
              [3.5823, 4.6792],
              [2.2804, 2.8670],
              [7.4234, 4.6965],
              [5.7451, 3.5340],
              [9.1722, 2.5111],
              [7.7928, 3.4241],
              [7.9398, 0.7916]]
raw_data_y =[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]  #  0是良性,1是恶性

# 训练数据,利用np.array将数据向量数据变为矩阵
X_train = np.array(raw_data_X)
y_train = np.array(raw_data_y)
  • 给定待预测数据x,并绘制散点图。X_train[y_train==0,0]表示在X_train中取出y_train等于0的点,并取出这些点的第一个索引0,即横坐标。
# 给定待预测数据,预测他的结果
x = np.array([7, 3])
plt.scatter(X_train[y_train==0,0], X_train[y_train==0,1], color='r')
plt.scatter(X_train[y_train==1,0], X_train[y_train==1,1], color='b')
plt.scatter(x[0], x[1], color='g')
plt.show()

在这里插入图片描述

  • 计算待预测点x与所有点的距离,并保存在distance列表中。计算所有点的距离使用到了欧拉距离公式,如下图,计算各向量的差值求平方,再相加,最后开根号。使用x_train-x,实现了各向量相减,减少了for循环的使用。
    在这里插入图片描述
from math import sqrt
# 保存和其他所有点的距离
distance = [sqrt(np.sum((x_train-x)**2)) for x_train in X_train]

在这里插入图片描述

  • 使用np.argsort(list),获取list中从小到大的值的索引,例如以下的np.argsort([1, 0, 5, 3]),结果得到的是[1, 0, 3, 2],其中1是list中0的索引,0为list中1的索引,3为list中3的索引,2为list中5的索引。
    在这里插入图片描述
# 找出离待预测点距离最近的k个点
k = 3
nearest = np.argsort(distance)
nearest = [i for i in nearest[:k]]

在这里插入图片描述

  • 找出最近的k个点下标值,即[8, 6, 5],依据其在y_train中找出这些样本对应的目标值,观察下图即可知离待遇测点x最近的三个点为恶性肿瘤。
top_K = [i for i in y_train[nearest]]

在这里插入图片描述

  • 根据获取的top_K,使用collections中的Counter函数,获取最终投票结果,观察最终结果可以看到1有3票
from collections import Counter
votes = Counter(top_K)

在这里插入图片描述

  • 使用votes.most_common获取最终的投票结果,votes.most_common(1)[0][0]获取到最终投票结果,由图可知x点很可能为恶性肿瘤。
    在这里插入图片描述
y_predict = votes.most_common(1)[0][0]

在这里插入图片描述

封装knn算法

将knn算法进行封装,并放在jupter notebook中运行,本次传入的待遇测数据为二维数据。

以下是封装浩的knn算法——knn.py

import numpy as np
from math import sqrt
from collections import Counter

# 在sklearn中,对于数据的拟合,创建模型,是放在fit方法中

class Knn:
    def __init__(self, n_neighbor=3):
        self.X_train = None
        self.y_train = None
        self.n_neighbor = n_neighbor

    def fit(self, X_train, y_train):
        # 给定X_train和y_train,得到训练模型
        assert X_train.shape[0] == len(y_train)  # 校验X_train的行数是否等于y_train的列数
        self.X_train = X_train
        self.y_train = y_train
        # 返回对象的描述
        return self

    def predict(self, X):
        # 对于给定的待预测数据,返回预测结果
        assert self.X_train is not None
        assert self.y_train is not None
        assert self.X_train.shape[1] == X.shape[1]  # 校验X_train的列数是否等于X的列数
        return np.array([self._predict(x) for x in X])

    def _predict(self, x):
        # 给定一个样本求出一个结果
        distance = [sqrt(np.sum((x_train - x)**2)) for x_train in self.X_train]
        nearest = np.argsort(distance)
        nearest = [i for i in nearest[:self.n_neighbor]]
        top_K = [i for i in self.y_train[nearest]]
        votes = Counter(top_K)
        y_predict = votes.most_common(1)[0][0]
        return y_predict

    def __repr__(self):
        # 返回对象的描述
        return "KnnClassifier(n_neighbor=3)"

将封装好的knn算法放在项目同目录下,导入封装好的算法,进行计算。
在这里插入图片描述

总结

根据以上学习,我们可以总结出机器学习的算法实现过程如下图所示,先得到训练数据集,通过机器学习算法,使用fit方法对数据进行拟合,获取模型,最后通过输入样例(待预测数据),获得输出结果。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/153420.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

android架构拆分方案-结构相关方案与技术

很纯、很生硬的架构技术归纳blog上上文https://blog.csdn.net/dongyi1988/article/details/128617738接上文https://blog.csdn.net/dongyi1988/article/details/128629011android架构官网地址https://source.android.google.cn/docs/core/architecture?hlzh-cnGKI(…

VBO、VAO、EBO学习记录

在这里要先了解一下OpenGL的一个幕后大致运作流程,可以直接阅读OPENGL CN 我自己大概总结了一下就是,OpenGL本身就是一个巨大的状态机,我们通过更改状态变量(上下文)来告诉OpenGL如何去绘制图像。一般通过设置选项,修改缓冲来更改…

【网络与系统安全】国科大《网络与系统安全》复习大纲整理 + 考试记忆版

国科大《网络与系统安全》复习整理笔记 重在理解概念考试不算太难 文章目录一、新形势安全面临挑战和安全保障能力提升二、网络与系统安全的需求与目标三、自主与强制访问控制1.访问控制的基本概念2.访问控制的要素3.访问控制3种基本类型4.访问控制矩阵、访问控制列表、访问控制…

【Linux修炼】13.缓冲区

每一个不曾起舞的日子,都是对生命的辜负。 缓冲区的理解一. C接口打印两次的现象二. 理解缓冲区问题为什么要有缓冲区缓冲区刷新策略的问题所说的缓冲区在哪里?指的是什么缓冲区?三. 解释打印两次的现象四. 模拟实现五. 缓冲区与OS的关系一. …

ThinkPHP 表单验证使用

对前端或表单请求的数据,一定要做校验,而使用ThinkPHP 验证器则可以事半功倍。 可以使用validate助手函数(或者封装验证方法)进行验证。TP版本6.1。 目录 验证场景 验证器 创建验证器 定义规则和提示 数据验证 独立验证&…

Arbotix使用

内容学自赵虚左的视频及资料 需求描述: 控制机器人模型在 rviz 中做圆周运动 1.安装 Arbotix 方式1:命令行调用 sudo apt-get install ros-<<VersionName()>>-arbotix <<VsersionName()>> 替换成当前 ROS 版本名称 添加 arbotix 所需配置文件 # …

Web原型设计规范

上篇文章为大家介绍了app端在进行原型设计时的设计规范&#xff0c;本篇将继续为大家介绍一下Web端&#xff08;这里主要指网页端&#xff09;的设计规范。其实web端的设计规范并没有像app端那样多&#xff0c;因为展示的空间比较大&#xff0c;所有要求也就没有那么严苛。 电脑…

Spring_事务

事务的主要内容 事务定义 特性&#xff1a;ACID 并发时产生的问题 事务的隔离级别 锁 事务的传播特性 异常处理 超时 只读事务 TransactionDefinition 并发时产生的问题 一个数据库可以允许多个客户端同时访问&#xff0c;即并发的方式访问数据库。数据库中的同一个数据可能同…

2023年12306购票平台自动化购票终|解决乘客选择与车票提交(附自动化购票完整源代码与演示视频)

目录 一、说明 1.1、背景 1.2、说明 二、步骤 2.1、切换视角检索乘车乘客 2.2、选择乘客 2.3、关闭学生票选择界面 2.4、提交订单 2.5、选择座位并确认 三、完整代码与视频演示 3.1、完整源代码如下 3.2、视频演示代码运行 四、结果 4.1、代码运行结果 五、总结…

windows获取iOS设备信息

依赖环境&#xff1a; 1.python3.6以上版本&#xff0c; 2.配置python的系统环境变量。 3.python已经安装pip。 安装tidevice: 1.打开cmd&#xff0c;输入命令pip3 install -U "tidevice[openssl]"如图所示&#xff0c;安装成功。 2.查看tidevice版本号&#xff0c…

网络超火的音效素材、BGM,全在这里了。

推荐几个超好用的音效素材网站&#xff0c;全网火爆的音效、BGM这里都能找到&#xff0c;自媒体、视频剪辑小伙伴必备&#xff01;建议收藏&#xff01; 1、菜鸟图库 https://www.sucai999.com/audio.html?vNTYwNDUx 菜鸟图库是一个综合性素材网站&#xff0c;站内涵盖设计、…

vector模拟实现之迭代器失效及深浅拷贝的问题

vector模拟实现 Tips&#xff1a;new申请空间不用判断&#xff0c;因为失败的话会抛异常。 STL源代码中vector的私有成员变量如下&#xff1a; private:iterator _start;//首元素iterator _finish;//最后一个有效数据的下一个&#xff0c;-_start为sizeiterator _endofstora…

6-3分布散度的9个梯度

( A, B )---1*30*2---( 1, 0 )( 0, 1 ) 让网络的输入只有1个节点&#xff0c;AB各由9张二值化的图片组成&#xff0c;排列组合A和B的所有可能性&#xff0c;固定收敛误差为7e-4&#xff0c;统计收敛迭代次数&#xff0c;并比较迭代次数的变化规律。 差值结构 A-B 迭代次数 …

Huawei Matebook X Pro 2018 Space Gray电脑 Hackintosh 黑苹果efi引导文件

硬件型号驱动情况主板Huawei Matebook X Pro 2018 Space Gray处理器Intel Core i7-8550U已驱动内存16 GB LPDDR4 2133 MHz已驱动硬盘LiteON SSD PCIe NVMe 512 GB [CA3-8D512]已驱动显卡NVIDIA GeForce MX150 (Disabled) / Intel(R) UHD Graphics 620已驱动声卡瑞昱ALC256 英特…

微积分——导数和切线问题

目录 1. 切线(Tangent Line)问题 2. 函数的导数(derivative) 3. 函数的可微性(differentiability)与连续性(Continuity) 1. 切线(Tangent Line)问题 微积分的出现源于17世纪欧洲数学家们正在研究解决的四个主要的问题&#xff1a; (1) 切线(tangent line)问题&#xf…

使用Alexnet实现CIFAR10数据集的训练

如果对你有用的话&#xff0c;希望能够点赞支持一下&#xff0c;这样我就能有更多的动力更新更多的学习笔记了。&#x1f604;&#x1f604; 使用Alexnet进行CIFAR-10数据集进行测试&#xff0c;这里使用的是将CIFAR-10数据集的分辨率扩大到224X224&#xff0c;因为在测试…

第03讲:Docker 容器的数据卷

一、什么是数据卷 数据卷是宿主机中的一个目录或文件&#xff0c;当容器目录或者文件和数据卷目录或者文件绑定后&#xff0c;对方的修改会立即同步&#xff0c;一个数据卷可以被多个容器同时挂载&#xff0c;一个容器也可以被挂载多个数据卷&#xff0c;数据卷的作用:容器数据…

基于遥感卫星影像水体提取方法综述

水体提取分类依据及基础 水体提取分类依据 水体提取的方法很多,很多学者也进行了分类,大体上有一个分类框架,主要是基于光学影像的分类,比如王航等[7]将水体提取分成3类,分别是基于阈值法、分类器法和自动化法; 李丹等[8]更深一步进行总结,引入近些年发展火热的基于雷达影像数…

Redisson自定义序列化

配置RedissonClientBean public RedissonClient redissonClient() {Config config new Config();// 单节点模式SingleServerConfig singleServerConfig config.useSingleServer();singleServerConfig.setAddress("redis://127.0.0.1:6379");singleServerConfig.set…

LeetCode二叉树经典题目(六):二叉搜索树

目录 28. LeetCode617. 合并二叉树 29. LeetCode700. 二叉搜索树中的搜索 30. LeetCode98. 验证二叉搜索树 31. LeetCode530. 二叉搜索树的最小绝对差 32. LeetCode501. 二叉搜索树中的众数 33. LeetCode236. 二叉树的最近公共祖先​ 28. LeetCode617. 合并二叉树 递归&…