《统计学习方法:李航》笔记 从原理到实现(基于python)-- 第 2章感知机

news2024/11/16 7:50:29

文章目录

  • 第 2章感知机
    • 2.1 感知机模型
    • 2.2 感知机学习策略
      • 2.2.1 数据集的线性可分性
      • 2.2.2 感知机学习策略
    • 2.3 感知机学习算法
      • 2.3.1 感知机学习算法的原始形式
      • 2.3.2 算法的收敛性
      • 2.3.3 感知机学习算法的对偶形式
    • 实践:二分类模型(iris数据集)
      • 数据集可视化:
      • Perceptron
      • scikit-learn实例

《统计学习方法:李航》笔记 从原理到实现(基于python)-- 第 2章感知机
《统计学习方法:李航》笔记 从原理到实现(基于python)-- 第1章 统计学习方法概论

我算是有点基础的(有过深度学习和机器学的项目经验),但也是半路出家,无论是学Python还是深度学习,都是从问题出发,边查边做,没有系统的学过相关的知识,这样的好处是入门快(如果想快速入门,大家也可以试试,直接上手项目,从小项目开始),但也存在一个严重的问题就是,很多东西一知半解,容易走进死胡同出不来(感觉有点像陷入局部最优解,找不到出路),所以打算系统的学习几本口碑比较不错的书籍。
  书籍选择: 当然,机器学习相关的书籍有很多,很多英文版的神书,据说读英文版的书会更好,奈何英文不太好,比较难啃。国内也有很多书,周志华老师的“西瓜书”我也有了解过,看了前几章,个人感觉他肯能对初学者更友好一点,讲述的非常清楚,有很多描述性的内容。对比下来,更喜欢《统计学习方法》,毕竟能坚持看完才最重要。
  笔记内容: 笔记内容尽量省去了公式推导的部分,一方面latex编辑太费时间了,另一方面,我觉得公式一定要自己推到一边才有用(最好是手写)。尽量保留所有标题,但内容会有删减,通过标黑和列表的形式突出重点内容,要特意说一下,标灰的部分大家最好读一下(这部分是我觉得比较繁琐,但又不想删掉的部分)。
  代码实现: 最后是本章内容的实践,如果想要对应的.ipynb文件,可以留言

第 2章感知机

  感知机 (perceptron) 是二类分类的线性分类模型,其输入为实例的特征向量,输 出为实例的类别,取 +1 和-1 二值。

  感知机对应于输入空间(特征空间)中将实例划 分为正负两类的分离超平面,属于判别模型

  感知机学习旨在求出将训练数据进行线性划分的分离超平面,为此:

  • 导入基于误分类的损失函数,
  • 利用梯度下降法对损失函 数进行极小化,求得感知机模型。

  感知机学习算法具有简单而易于实现的优点,分为 原始形式对偶形式。

2.1 感知机模型

  感知机是一种线性分类模型,属于判别模型。

  感知机模型的假设空间是定义在特征空间中的所有线性分类模型(linear classification modeD 或线性分类器 (linear classifier) ,即函数集合

f ∣ f ( x ) = ω • x + b {{f|f(x) = ω • x + b}} ff(x)=ωx+b

几何解释:线性方程

ω ⋅ x + b = 0 ω\cdot x+b=0 ωx+b=0

  对应于特征空间 R n R^n Rn 中的一个超平面 S , 其中 ω超平面的法向量b超平面的截距

  这个超平面将特征空间划分为两个部分。位于两部分的点(特征向量)分别被分为 正、负两类。因此,超平面 S称为分离超平面 (separating hyperplane) ,如图 2.1 所示。

2.2 感知机学习策略

2.2.1 数据集的线性可分性

  给定一个数据集T:

T = ( x l , y 1 ) , ( x 2 , y 2 ) , … , ( x n , y n ) T = {(x_l ,y_1), (x_2 ,y_2) ,… , (x_n,y_n)} T=(xl,y1),(x2,y2),,(xn,yn)

  其中 , x i ∈ X = R n , y i ∈ Y = ( + 1 , − 1 ) , i = 1 , 2 , … , n x_i \in X = R^n, y_i \in Y=(+1 ,-1) , i= 1 , 2,… , n xiX=Rn,yiY=(+1,1)i=12n

  如果存在某个超乎面 S

ω ⋅ x + b = 0 ω\cdot x+b=0 ωx+b=0

  能够将数据集的正实例点和负实例点完全正确地划分到超平面的两侧,则称数据集 T 为线性可分数据集( linearly separable data set ) ;否则,称数据集 T 线性不可分

2.2.2 感知机学习策略

  假设训练数据集是线性可分的,感知机学习的目标是求得一个能够将训练集正实例点和负实例点完全正确分开分离超平面
在这里插入图片描述

  损失函数的一个自然选择是误分类点的总数。但是,这样的损失函数不是连续可导函数,不易优化。损失函数的另一个选择是误分类点到超平面 S 的总距离

  • 输入空间 R n R^n Rn 中任一 x o x_o xo 到超平面 S S S 的 距离:

1 ∣ ∣ w ∣ ∣ ∣ w ⋅ x 0 + b ∣ \frac{1}{||w||}|w \cdot x_0+b| ∣∣w∣∣1wx0+b

  • 对于误分类的数据 ( x i , x i ) (x_i,x_i) (xi,xi) 来说,

− y i ( ω ⋅ x i + b ) > O -y_i(ω \cdot x_i+b)>O yi(ωxi+b)>O

  • ω • x i + b > 0 ω • x_i + b > 0 ωxi+b>0 时 , y i = − 1 y_i = -1 yi=1
  • ω • x i + b < 0 ω • x_i + b < 0 ωxi+b<0 时, x i = + 1 x_i = +1 xi=+1
  • 所有误分类点超平面 S总 距离

− 1 ∣ ∣ w ∣ ∣ ∑ x i ∈ M y i ( w ⋅ x 0 + b ) -\frac{1}{||w||}\sum_{x_i\in M}y_i(w \cdot x_0+b) ∣∣w∣∣1xiMyi(wx0+b)

  感知机 s i g n ( w • x + b ) sign(w • x + b) sign(wx+b) 学习的损失函数定义为:

L ( w , b ) = − ∑ x i ∈ M y i ( w ⋅ x 0 + b ) L(w,b)=-\sum_{x_i\in M}y_i(w \cdot x_0+b) L(w,b)=xiMyi(wx0+b)

  其中 M 为误分类点的集合。

  这个损失函数就是感知机学习的经验风险函数。

2.3 感知机学习算法

  感知机学习问题转化为求解损失函数式的最优化问题,最优化的方法是随 机梯度下降法。

2.3.1 感知机学习算法的原始形式

求参数 w , b w, b wb , 使其为以下损失函数极小化问题的解:

m i n w , b L ( w , b ) = − ∑ x i ∈ M y i ( w ⋅ x i + b ) min_{w,b}L(w,b)=-\sum_{x_i\in M}y_i(w \cdot x_i+b) minw,bL(w,b)=xiMyi(wxi+b)

  其中 M 为误分类点的集合。

求解思路:

  • 感知机学习算法是误分类驱动的,具体采用随机梯度下降法 (stochastic gradient descent)。
  • 首先,任意选取一个超平面 w 0 , b 0 w_0,b_0 w0,b0 , 然后用梯度下降法不断地极小化目标函数(损失函数)
  • 极小化过程中不是一次使M 中所有误分类点的梯度下降,而是一次随机 选取一个误分类点使其梯度下降
  • 假设误分类点集合 M 是固定的,那么损失函数 L ( w , b ) L(w,b) L(w,b)梯度由下式给出:

∇ w L ( w , b ) = − ∑ x i ∈ M y i x i \nabla_w L(w,b)=-\sum_{x_i\in M}{y_ix_i} wL(w,b)=xiMyixi

∇ b L ( w , b ) = − ∑ x i ∈ M y i \nabla _b L(w,b)=-\sum_{x_i\in M}{y_i} bL(w,b)=xiMyi

  • 随机选取一个误分类点 ( x i , y i ) (x_i,y_i) xi,yi ω , b ω, b ωb 进行更新:

w ← w + η y i x i w\leftarrow w+ηy_ix_i ww+ηyixi

b ← b + η y i b \leftarrow b+ηy_i bb+ηyi

  式中 η ( 0 < η ≤ 1 ) η(0 <η\leq1) η(0<η1) 是步长,在统计学习中又称为学习率(learning rate) 。

在这里插入图片描述

  这种学习算法直观上有如下解释:

  当一个实例点被误分类,即位于分离超平面的 错误一侧时,则调整 ω, b 的值,使分离超平面向该误分类点的一侧移动,以减少该误分类点与超平面间的距离,直至超平面越过该误分类点使其被正确分类。

2.3.2 算法的收敛性

在这里插入图片描述

  定理表明,误分类的次数 k 是有上界的,经过有限次搜索可以找到将训练数据完 全正确分开的分离超平面。也就是说,当训练数据集线性可分时,感知机学习算法原 始形式迭代是收敛的。

2.3.3 感知机学习算法的对偶形式

  对偶形式的基本想法是,将 ω ω ω b b b 表示为实例 x i x_i xi标记 y i y_i yi线性组合的形式, 通过求解其系数而求得 ω ω ω b b b

在这里插入图片描述

  对偶形式中训练实例仅以内积的形式出现。

  为了方便,可以预先将训练集中实例间的内积计算出来并以矩阵的形式存储,这个矩阵就是所谓的 Gram 矩阵 (Gram matrix):
G = [ x i ⋅ x i ] N × N G=[x_i \cdot x_i]_{N \times N} G=[xixi]N×N

实践:二分类模型(iris数据集)

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
%matplotlib inline
#load data
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
df.columns = [
							'sepal length', 'sepal width', 'petal length', 'petal width', 'label'
							]
df.label.value_counts()
=========================
2    50
1    50
0    50
Name: label, dtype: int64

数据集可视化:

plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()

在这里插入图片描述

data = np.array(df.iloc[:100, [0, 1, -1]])
X, y = data[:,:-1], data[:,-1]
y = np.array([1 if i == 1 else -1 for i in y])

Perceptron

# 数据线性可分,二分类数据
# 此处为一元一次线性方程
class Model:
    def __init__(self):
        self.w = np.ones(len(data[0]) - 1, dtype=np.float32)
        self.b = 0
        self.l_rate = 0.1
        # self.data = data

    def sign(self, x, w, b):
        y = np.dot(x, w) + b
        return y

    # 随机梯度下降法
    def fit(self, X_train, y_train):
        is_wrong = False
        while not is_wrong:
            wrong_count = 0
            for d in range(len(X_train)):
                X = X_train[d]
                y = y_train[d]
                if y * self.sign(X, self.w, self.b) <= 0:
                    self.w = self.w + self.l_rate * np.dot(y, X)
                    self.b = self.b + self.l_rate * y
                    wrong_count += 1
            if wrong_count == 0:
                is_wrong = True
        return 'Perceptron Model!'

    def score(self):
        pass

训练

perceptron = Model()
perceptron.fit(X, y)
===============================
'Perceptron Model!'

分类&可视化

x_points = np.linspace(4, 7, 10)
y_ = -(perceptron.w[0] * x_points + perceptron.b) / perceptron.w[1]
plt.plot(x_points, y_)

plt.plot(data[:50, 0], data[:50, 1], 'bo', color='blue', label='0')
plt.plot(data[50:100, 0], data[50:100, 1], 'bo', color='orange', label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()

在这里插入图片描述

scikit-learn实例

import sklearn
from sklearn.linear_model import Perceptron
===============
sklearn.__version__
'0.21.2'
clf = Perceptron(fit_intercept=True, 
                 max_iter=1000, 
                 shuffle=True)
clf.fit(X, y)
=================================
Perceptron(alpha=0.0001, class_weight=None, early_stopping=False, eta0=1.0,
           fit_intercept=True, max_iter=1000, n_iter_no_change=5, n_jobs=None,
           penalty=None, random_state=0, shuffle=True, tol=0.001,
           validation_fraction=0.1, verbose=0, warm_start=False)
# Weights assigned to the features.
print(clf.coef_)
===============================
[[ 23.2 -38.7]]
# 截距 Constants in decision function.
print(clf.intercept_)
================================
[-5.]

可视化

# 画布大小
plt.figure(figsize=(10,10))

# 中文标题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.title('鸢尾花线性数据示例')

plt.scatter(data[:50, 0], data[:50, 1], c='b', label='Iris-setosa',)
plt.scatter(data[50:100, 0], data[50:100, 1], c='orange', label='Iris-versicolor')

# 画感知机的线
x_ponits = np.arange(4, 8)
y_ = -(clf.coef_[0][0]*x_ponits + clf.intercept_)/clf.coef_[0][1]
plt.plot(x_ponits, y_)

# 其他部分
plt.legend()  # 显示图例
plt.grid(False)  # 不显示网格
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()

在这里插入图片描述

注意 !

在上图中,有一个位于左下角的蓝点没有被正确分类,这是因为 SKlearn 的 Perceptron 实例中有一个tol参数。

tol 参数规定了如果本次迭代的损失和上次迭代的损失之差小于一个特定值时,停止迭代。所以我们需要设置 tol=None 使之可以继续迭代:

clf = Perceptron(fit_intercept=True, 
                 max_iter=1000,
                 tol=None,
                 shuffle=True)
clf.fit(X, y)

# 画布大小
plt.figure(figsize=(10,10))

# 中文标题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.title('鸢尾花线性数据示例')

plt.scatter(data[:50, 0], data[:50, 1], c='b', label='Iris-setosa',)
plt.scatter(data[50:100, 0], data[50:100, 1], c='orange', label='Iris-versicolor')

# 画感知机的线
x_ponits = np.arange(4, 8)
y_ = -(clf.coef_[0][0]*x_ponits + clf.intercept_)/clf.coef_[0][1]
plt.plot(x_ponits, y_)

# 其他部分
plt.legend()  # 显示图例
plt.grid(False)  # 不显示网格
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()

在这里插入图片描述

现在可以看到,所有的两种鸢尾花都被正确分类了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1409902.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

爬虫正则+bs4+xpath+综合实战详解

Day3 - 1.数据解析概述_哔哩哔哩_bilibili 聚焦爬虫&#xff1a;爬取页面中指定的页面内容 编码流程&#xff1a;指定url -> 发起请求 -> 获取响应数据 -> 数据解析 -> 持久化存储 数据解析分类&#xff1a;正则、bs4、xpath(本教程的重点) 数据解析原理概述&am…

2024群硕荣誉首响,第十三届公益节斩获企业大奖

2024年1月23日至24日&#xff0c;第十三届公益节在北京顺利举行。 历经多年的探索和实践&#xff0c;公益节已经成为中国公益慈善领域颇具影响力的年度盛事。本届公益节全面恢复线下活动&#xff0c;各大企业齐聚现场&#xff0c;展现社会责任的力量&#xff0c;现场气氛热烈而…

什么是5G RedCap?5G RedCap有什么优势?

5G RedCap&#xff08;Reduced Capability&#xff09;是指5G轻量化技术&#xff0c;即通过对5G技术进行一定程度的“功能裁剪”&#xff0c;来降低终端和模组的复杂度、成本、尺寸和功耗等指标&#xff0c;从而“量体裁衣”适配不同的物联需求&#xff0c;实现兼顾物联网系统的…

【送书活动八期】docker容器中登陆并操作postgresql

这里的背景比较简单&#xff0c;因为区块链浏览器使用的是blockscout&#xff0c;blockscout的数据库选择的是postgresql&#xff0c;这些服务组件都是使用的docker容器来管理&#xff0c;今天进行区块链上交易查询的时候&#xff0c;发现数据存在部分问题&#xff0c;因此需要…

大数据信用查询系统能查到什么呢?

在金融助贷行业&#xff0c;大数据有叫大数据信用或者网贷大数据&#xff0c;在申贷的时候&#xff0c;想必大多数人都有听说过&#xff0c;很多人因为大数据不良的原因申贷被拒过&#xff0c;那大数据信用查询系统能查到什么呢?本文就简单为大家总结几点大数据信用查询的内容…

freeRTOS总结(十)消息 队列

1&#xff0c;队列简介&#xff08;了解&#xff09; 队列是任务到任务、任务到中断、中断到任务数据交流的一种机制&#xff08;消息传递&#xff09; 与全局变量的区别 类似全局变量&#xff1f;假设有一个全局变量a 0&#xff0c;现有两个任务都在写这个变量a 假如 当任务…

三层架构-pc通外网小实验

要求:pc端能上外网(isp) 效果图:pc1(VLAN2)和pc3(vlan3)都能ping通2.2.2.2(R2环回) 代码:#先配置好r1,r2,端口ip # [R1] ip route-static 0.0.0.0 0.0.0.0 10.1.1.2 acl 2000 rule permit source any interface GigabitEthernet0/0/2 nat outbound 2000 …

Android消息推送 SSE(Server-Sent Events)方案实践

转载请注明出处&#xff1a;https://blog.csdn.net/kong_gu_you_lan/article/details/135777170 本文出自 容华谢后的博客 0.写在前面 最近公司项目用到了消息推送功能&#xff0c;在技术选型的时候想要找一个轻量级的方案&#xff0c;偶然看到一篇文章讲ChatGPT的对话机制是基…

Bank_Code_FullName_2020.06.16.xlsx

Bank_Code_FullName_2020.06.16.xlsx 银行联行号和全称 https://download.csdn.net/download/spencer_tseng/88780566 144692条记录&#xff0c;没法子贴上去

抖音VR直播:沉浸式体验一键打通360度精彩

随着5G技术的发展&#xff0c;VR直播近年来也逐步进入到大众的视野中&#xff0c;相比于传统直播&#xff0c;VR直播能够提供更加丰富的内容和多样化的互动方式&#xff0c;让观众更有沉浸感和参与感。现如今&#xff0c;抖音平台也上线了VR直播&#xff0c;凭借沉浸式体验和有…

基于 pytorch-openpose 实现 “多目标” 人体姿态估计

前言 还记得上次通过 MediaPipe 估计人体姿态关键点驱动 3D 角色模型&#xff0c;虽然节省了动作 K 帧时间&#xff0c;但是网上还有一种似乎更方便的方法。MagicAnimate 就是其一&#xff0c;说是只要提供一张人物图片和一段动作视频 (舞蹈武术等)&#xff0c;就可以完成图片…

【Kubernetes】深入了解Kubernetes(K8s):现代容器编排的引领者

欢迎来到英杰社区&#xff1a; https://bbs.csdn.net/topics/617804998 欢迎来到阿Q社区&#xff1a; https://bbs.csdn.net/topics/617897397 作者简介&#xff1a; 辭七七&#xff0c;目前大二&#xff0c;正在学习C/C&#xff0c;Java&#xff0c;Python等 作者主页&#xf…

JVM-初始JVM

什么是JVM JVM 全称是 Java Virtual Machine&#xff0c;中文译名 Java虚拟机。JVM 本质上是一个运行在计算机上的程序&#xff0c;他的职责是运行Java字节码文件。 Java源代码执行流程如下&#xff1a; JVM的功能 1 - 解释和运行 2 - 内存管理 3 - 即时编译 解释和运行 解释…

LeetCode.2865. 美丽塔 I

题目 题目链接 分析 闲谈&#xff1a;每次读 LeetCode 的题目描述都要费老大劲&#xff0c;o(╥﹏╥)o 题意&#xff1a;这个其实意思就是以数组的每一位作为最高点&#xff0c;这个点&#xff08;数字&#xff09;左右两边的数字都不能大于这个数字(可以等于)&#xff0c;…

Qt基础-屏蔽qDebug()、qWarning()调试和警告消息

本文讲解Qt如何-屏蔽qDebug()、qWarning()调试和警告消息 在工程文件.pro里面添加 DEFINES QT_NO_WARNING_OUTPUT\ QT_NO_DEBUG_OUTPUT 如果只想Release版本的时候不打印&#xff1a; Release:DEFINES QT_NO_WARNING_OUTPUT\ QT_NO_DEBUG_OUTPUT 这样只是在Release版本…

想要透明拼接屏展现更加效果,视频源是技术活,尤其作为直播背景

随着科技的飞速发展&#xff0c;视频制作和显示技术也在不断进步。透明拼接屏视频作为一种新型的视频形式&#xff0c;在许多场合都得到了广泛的应用。尼伽小编将深入探讨透明拼接屏视频的制作过程、要求、清晰度&#xff0c;以及目前常作为直播背景的优势。 一、透明拼接屏视频…

Make.com的发送邮件功能已经登峰造极

make.com的发送邮件功能已经做到了登峰造极。 我给你个任务&#xff0c;让你发送个新邮件给谁谁&#xff0c;你一定想到SMTP服务器不就行了。 我给你第二个任务&#xff0c;我让你自动回复一个邮件&#xff0c;注意是回复。 做不到了吧&#xff5e;&#xff5e;&#xff01;…

【3万字】modbus简易不简单的教程

&#x1f396;️Modbus简易不简单的教程 文章目录 &#x1f396;️Modbus简易不简单的教程&#x1f3ab;一、简介1.1 Modbus&#xff1a;工业通信的革命1.2 理解标准化通信1.3 Modbus协议的变体 &#x1f380;二、例程引入2.1 示例&#xff1a;使用01功能码读取灯的开关状态2.2…

电商一年挣100w的赚钱模型

现在有多少人还不知道电商具体应该怎么干&#xff0c;有多少人还是看了身边的朋友做电商挣钱了也跟着做了。然后做半天没做起来&#xff0c;然后就找各种原因&#xff0c;你看别人每天上架你也上架&#xff0c;别人开车你也开车&#xff0c;别人亏钱你也亏钱&#xff0c;别人赚…

dns被劫持怎么修复?6种常用修复方法解读

当遇到DNS被劫持的情况时&#xff0c;通常表现出来的症状是无法正常访问某些网站&#xff0c;或者访问被重定向到不正确的地址。DNS劫持可能是由于恶意软件、黑客活动或者ISP&#xff08;Internet服务提供商&#xff09;的问题导致的。 以下是修复DNS劫持的六种方法&#xff1…