吴恩达《机器学习》——Logistic多分类与神经网络

news2024/9/21 4:38:50

Logistic多分类与神经网络

  • 1. MINIST数据集与Logistic多分类
    • MINIST简介
      • 数据集可视化
    • Logistic如何实现多分类?
      • One-Hot向量
    • Python实现
  • 2. 神经网络(Neural Network, NN)
    • 神经网络
    • 前馈传播
    • Python实现
  • 3. 基于PyTorch框架的网络搭建

数据集、源文件可以在Github项目中获得
链接: https://github.com/Raymond-Yang-2001/AndrewNg-Machine-Learing-Homework

1. MINIST数据集与Logistic多分类

在本文中,我们将实现通过Logistic回归对Minist数据集的均匀采样子数据集进行多分类的任务。

MINIST简介

MINST数据集(Modified National Institute of Standards and Technology),是一个大型手写数字数据库,通常用于训练各种图像处理系统,包含60,000个示例的训练集以及10,000个示例的测试集。
MINIST数据集最早在1998年被LeCun大佬使用,后来被包括线性分类,SVM,KNN,NN等在内的各种机器学习算法广泛使用,已经成为评价算法性能的通用数据库。

Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. “Gradient-based learning applied to document recognition.” Proceedings of the IEEE, 86(11):2278-2324, November 1998.

本文使用到数据集是MINIST的一个均匀采样的子集,包括5000个0~9九个数字,以MATLAB文件.mat格式给出。由于MATLAB不支持0索引,所以数字0的标签以“10”代替。随后我们会对这一部分进行详细的解释和操作。

数据集可视化

import scipy.io as scio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

由于数字"0"的标签为“10”,我们可以用"0"来替换原标签。

data = scio.loadmat('ex3data1.mat')
x = data.get('X')
y = data.get('y')
y = np.expand_dims(np.where(y[:,0]==10,0,y[:,0]), axis=1)

此时x和y的形状分别是 ( 5000 , 40 ) (5000,40) (5000,40) ( 5000 , 1 ) (5000,1) (5000,1)。接下来随机挑选25张图像进行可视化展示。

def plot_image(img):
    """ 
    assume the image is square
    img: (5000, 400)
    """
    sample_idx = np.random.choice(np.arange(img.shape[0]), 25)  # 100*400
    sample_images = img[sample_idx, :]

    fig, ax_array = plt.subplots(nrows=5, ncols=5, sharey=True, sharex=True, figsize=(5, 5))

    for r in range(5):
        for c in range(5):
            ax_array[r, c].matshow(sample_images[5 * r + c].reshape((20, 20)).T,
                                   cmap=matplotlib.cm.binary)
            plt.xticks(np.array([]))
            plt.yticks(np.array([]))


plot_image(x)

在这里插入图片描述

Logistic如何实现多分类?

回顾前文,我们知道在伯努利分布的理论基础上,Logistic回归只能实现对样本的二分类(即输出一个伯努利分布)。但是显然,对于多分类,我们的输出要符合一个多项式分布,即 p i ≥ 0 ( 1 ≤ i ≤ n ) , ∑ i = 1 n p i = 1 p_{i}\ge0\quad (1\le i \le n), \sum_{i=1}^{n}{p_{i}}=1 pi0(1in),i=1npi=1。其中 p i p_i pi代表了样本被分为第 i i i类的概率。

从另一个角度来思考,对一个样本进行多分类,相当于对其进行 n n n次二分类( n n n为类别数目),每一次二分类输出的positive的预测概率,都可以看做多项式分布的一项。当然,这里的说法并不严谨,因为多项式分布要求 ∑ i = 1 n p i = 1 \sum_{i=1}^{n}{p_{i}}=1 i=1npi=1。不过,对于预测分类来说,我们只关心概率最大的一项,和是否为1并没有很大的影响,在后面我们还将看到如何使得输出的分布符合一个多项式分布的要求。

按照这个思路,我们可以训练十个分类器,分别对应0~9的类别分类。对于每一个要分类的样本,我们同时将其送入十个分类器中,取positive分类概率最高的类为预测输出。

One-Hot向量

在数字电路和机器学习中,one-hot向量(独热向量)是常常使用的一种概念。在one-hot向量中,只有一位为“高位”,即1;其余位均为“低位”,即0。其表示如下所示:
[ 0 ⋯ 1 ⋯ 0 ] \left[ \begin{array} {c c c} 0 & \cdots & 1 & \cdots & 0 \end{array} \right] [010]

显然,这非常适合作为多分类任务的标签。在二分类任务中,因为只存在两种状态,所以用1为的“0”或“1”就足以表示类别。在多分类任务中,one-hot向量将真实类的位置设为“1”,其余类为“0”,不仅能有效表示类别,同时这也符合我们之前提到的多项式分布的要求。

def onehot_encode(label):
    """
    Onehot编码
    :param label: (n,1)
    :return: onehot label (n, n_cls); cls (n,)
    """
    # cls (n_cls,)
    cls = np.unique(label)
    y_matrix = []
    for cls_idx in cls:
        y_matrix.append((label == cls_idx).astype(int))
    one_hot = np.array(y_matrix).T.squeeze()
    return one_hot, cls

y_onehot, cls = onehot_encode(y)

得到了one-hot编码的标签"y_onehot"和类别索引"cls"

Python实现

LogisticRegression类来自上一篇博文的代码实现 Logistic回归

classifier_list = []
for cls_idx in cls:
    classifier = LogisticRegression(x=train_x,y=train_y_ex[:,:,cls_idx], val_x=val_x, val_y=val_y_ex[:,:,cls_idx], epoch=epochs,lr=alpha,normalize=normalize, regularize="L2", scale=2, show=False)
    classifier.train()
    classifier_list.append(classifier)

print("Total Classifiers: {}".format(len(classifier_list)))

Total Classifiers: 10

划分训练集和验证集

from sklearn.model_selection import train_test_split
train_x, val_x, train_y, val_y = train_test_split(x, y_onehot, test_size=0.2)
for cls_idx in cls:
    train_sample_n = np.where(train_y[:,cls_idx]==1)[0].shape[0]
    val_sample_n = np.where(val_y[:,cls_idx]==1)[0].shape[0]
    print("Class {}:\t{} train samples\t{} val samples".format(cls_idx, train_sample_n, val_sample_n))
print("Total train samples: {}\n"
      "Total val samples: {}".format(train_y.shape[0],val_y.shape[0]))
Class 0:	406 train samples	94 val samples
Class 1:	401 train samples	99 val samples
Class 2:	389 train samples	111 val samples
Class 3:	393 train samples	107 val samples
Class 4:	395 train samples	105 val samples
Class 5:	416 train samples	84 val samples
Class 6:	406 train samples	94 val samples
Class 7:	394 train samples	106 val samples
Class 8:	397 train samples	103 val samples
Class 9:	403 train samples	97 val samples
Total train samples: 4000
Total val samples: 1000

查看分类性能,这里直接使用sklearn库的方法查看分类性能。

# (cls_n, sample_n, 1)
prob_list = [classifier_i.get_prob(val_x) for classifier_i in classifier_list]
# (sample_n, cls_n)
prob_arr = np.array(prob_list).squeeze().T
# (sample_n,)
multi_pred = np.argmax(prob_arr,axis=1)

from sklearn.metrics import classification_report
report = classification_report(multi_pred, np.argmax(val_y, axis=1), digits=4)
print(report)
          precision    recall  f1-score   support

       0     0.9574    0.9184    0.9375        98
       1     0.9495    0.8468    0.8952       111
       2     0.8378    0.9118    0.8732       102
       3     0.8505    0.8667    0.8585       105
       4     0.8952    0.8785    0.8868       107
       5     0.6905    0.9062    0.7838        64
       6     0.8830    0.8737    0.8783        95
       7     0.8962    0.9048    0.9005       105
       8     0.8738    0.7965    0.8333       113
       9     0.8351    0.8100    0.8223       100

    accuracy                         0.8690      1000
   macro avg     0.8669    0.8713    0.8669      1000
weighted avg     0.8742    0.8690    0.8699      1000

平均精度为86.99%。

2. 神经网络(Neural Network, NN)

接下来我们将使用一种全新的机器学习算法:神经网络,来进行图像分类。

神经网络

在这里插入图片描述
本文使用的神经网络的示意图如上所示。

虽然现在有很多流行的说法阐释神经网络与人脑系统之间的关系,但是两者的关联微乎其微。除去示意图的计算单元与人脑的神经元有些相似以外,神经网络是基于数学和信息论的严谨的数学模型。

在Logistic回归中,我们有以下计算公式:
h ( x ; θ ) = σ ( θ x ⊤ ) h(\boldsymbol{x};\boldsymbol{\theta})=\sigma(\boldsymbol{\theta x^{\top}}) h(x;θ)=σ(θx)
其中, σ \sigma σ代表了sigmoid运算,也可以叫做激励函数。事实上,除去sigmoid以外,我们还可以使用很多激励函数,最常见的如ReLU,SiLU,tanh函数等。

前馈传播

在神经网络中,我们简单的将模型分为输入层、隐藏层、输出层三个部分。

在输入层, a ( 1 ) = x a^{(1)}=x a(1)=x,添加 a 0 ( 1 ) = 0 a^{(1)}_{0}=0 a0(1)=0,使得偏置项的计算更加方便。
在隐藏层, z ( 2 ) = θ ( 1 ) a ( 1 ) z^{(2)}=\theta^{(1)}a^{(1)} z(2)=θ(1)a(1) a ( 2 ) = σ ( z ( 2 ) ) a^{(2)}=\sigma(z^{(2)}) a(2)=σ(z(2)),这相当于做了一次线性运算,并使用激励函数进行非线性的变换。(如果没有激励函数的话,无论多么复杂的神经网络,都相当于对输入进行线性运算,网络容量会变得极其有限。)同时,也在这里添加 a 0 ( 2 ) = 0 a^{(2)}_{0}=0 a0(2)=0
在输出层, z ( 3 ) = θ ( 2 ) a ( 2 ) z^{(3)}=\theta^{(2)}a^{(2)} z(3)=θ(2)a(2) a ( 3 ) = σ ( z ( 3 ) ) = h ( θ ; x ) a^{(3)}=\sigma(z^{(3)})=h(\theta;x) a(3)=σ(z(3))=h(θ;x)

Python实现

这里使用训练好的参数直接进行前馈传播,关于神经网络如何训练和优化,将在之后的文章中进行讲解。

class ForwardModel:
    def __init__(self):
        self.theta1 = None
        self.theta2 = None

    def load_parameters(self, parameters):
        self.theta1 = parameters[0]
        self.theta2 = parameters[1]

    def __call__(self, x, *args, **kwargs):
        # x (n,d)
        t = np.ones(shape=(x.shape[0], 1))
        # x (n,d+1)
        a1 = np.concatenate((t, x), axis=1)
        # a2 (n, hidden_size)
        a2 = sigmoid(np.matmul(a1, self.theta1))
        # a2 (n, hidden_size + 1)
        a2 = np.concatenate((t, a2), axis=1)
        # a3 (n, cls_n)
        a3 = sigmoid(np.matmul(a2, self.theta2))
        return a3
model = ForwardModel()
model.load_parameters([theta1, theta2])
pred_prob = model(x)
# pred_prob = np.concatenate([np.expand_dims(pred_prob[:,-1],axis=1), pred_prob[:,:-1]], axis=1)

pred = np.argmax(pred_prob,axis=1) + 1
from sklearn.metrics import classification_report
report = classification_report(pred, y, digits=4)
print(report)

查看分类性能

              precision    recall  f1-score   support

           1     0.9820    0.9684    0.9752       507
           2     0.9700    0.9818    0.9759       494
           3     0.9600    0.9776    0.9687       491
           4     0.9680    0.9699    0.9690       499
           5     0.9840    0.9723    0.9781       506
           6     0.9860    0.9782    0.9821       504
           7     0.9700    0.9778    0.9739       496
           8     0.9820    0.9781    0.9800       502
           9     0.9580    0.9657    0.9618       496
          10     0.9920    0.9822    0.9871       505

    accuracy                         0.9752      5000
   macro avg     0.9752    0.9752    0.9752      5000
weighted avg     0.9753    0.9752    0.9752      5000

平均精度为97.52。

3. 基于PyTorch框架的网络搭建

PyTorch是Facebook开源的深度学习框架,除去手工搭建神经网络,我们还可以使用PyTorch框架进行神经网络的搭建。关于PyTorch安装的相关知识,可以参考B站up主小土堆的教程。

from torch import nn

class PytorchForward(nn.Module):
    def __init__(self):
        super(PytorchForward, self).__init__()
        self.layer1 = nn.Linear(in_features=400, out_features=25)
        self.layer2 = nn.Linear(in_features=25, out_features=10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

model = PytorchForward()
for name, parameters in model.named_parameters():
    print(name,':',parameters.size())

可以看到参数的大小:

layer1.weight : torch.Size([25, 400])
layer1.bias : torch.Size([25])
layer2.weight : torch.Size([10, 25])
layer2.bias : torch.Size([10])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/133224.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用CMake构建静态库和动态库

使用CMake构建静态库和动态库一、准备工作二、动态库的构建2.1 工程改造2.2 编译动态库2.3 更多的说明三、静态库的构建3.1 错误的尝试3.2 新的构建指令四、动态库的版本号五、安装动态库和头文件一、准备工作 本机演示环境为: 主机windows11 vscode 虚拟机安装的…

人工智能在药物研发和生物技术中的应用:回顾与展望

人工智能(Artificial intelligence, AI)的出现正在重新塑造整个制药和生物技术行业的发展。几乎所有大大小小的生命科学和药物发现机构,都对采用人工智能驱动的发现平台表现出浓厚的兴趣,希望通过AI来简化研发工作,减少发现时间和成本&#x…

【C++基础】10:STL(二)

CppSTL(二) OVERVIEWCppSTL(二)一、函数对象1.函数对象:(1)概述:(2)简单使用:2.谓词:(1)一元谓词:…

LVGL学习笔记10 - 按钮Button

目录 1. Check型按钮 2. 修改样式 2.1 设置背景 2.1.1 颜色 2.1.2 透明度 2.1.3 渐变色 2.1.4 渐变色起始位置设置 2.2 修改边界 2.2.1 宽度 2.2.2 颜色 2.2.3 透明度 2.2.4 指定边 2.3 修改边框 2.4 修改阴影 2.4.1 宽度 2.4.2 透明度 2.4.3 偏移坐标 2.3.…

PHP代码审计系列(五)

PHP代码审计系列&#xff08;五&#xff09; 本系列将收集多个PHP代码安全审计项目从易到难&#xff0c;并加入个人详细的源码解读。此系列将进行持续更新。 数字验证正则绕过 源码如下 <?phperror_reporting(0); $flag flag{test}; if ("POST" $_SERVER[…

数值优化之凸集

本文ppt来自深蓝学院《机器人中的数值优化》 目录 1 凸集的定义 2 凸集的运算 1 凸集的定义 集合中任意两点连线形成的线段属于这个集合&#xff0c;这个集合是凸集。 注意&#xff1a;是否是凸集&#xff0c;集合的边界是否属于这个集合很重要 这涉及到构造最小凸包的问题…

行锁功过:怎么减少行锁对性能的影响?

在上一篇文章中,我跟你介绍了 MySQL 的全局锁和表级锁,今天我们就来讲讲 MySQL 的行锁。 MySQL 的行锁是在引擎层由各个引擎自己实现的。但并不是所有的引擎都支持行锁,比如 MyISAM 引擎就不支持行锁。不支持行锁意味着并发控制只能使用表锁,对于这种引擎的表,同一张表上…

Java中的四种引用类型

1.对象引用介绍 从 JDK1.2 版本开始&#xff0c;把对象的引用分为四种级别&#xff0c;从而使程序更加灵活的控制对象的生命周期。这四种级别由高到低依次为&#xff1a;强引用、软引用、弱引用和虚引用。 用表格整理之后&#xff0c;各个引用类型的区别如下&#xff1a; 2.强…

BOM编程:location对象

document 对象和window对象的location的区别 document对象的位置获取步骤是返回这个相关的全局对象的位置对象&#xff0c;如果这是完全活动的&#xff0c;否则为空。 Window对象的位置获取步骤是返回它的位置对象。每个Window对象都与创建Window对象时分配的Location对象的唯…

Spark01:Spark工作原理

1. Spark执行数据计算的整个流程 首先通过Spark客户端提交任务到Spark集群&#xff0c;然后Spark任务在执行的时候会读取数据源HDFS中的数据&#xff0c;将数据加载到内存中&#xff0c;转化为RDD&#xff0c;然后针对RDD调用一些高阶函数对数据进行处理&#xff0c;中间可以调…

ElementUI——案例1用户管理(基于SpringBoot)

1.前期准备 备注&#xff1a;主要涉及组件container组件&#xff0c;导航菜单组件&#xff0c;router路由组件&#xff0c;carousel 走马灯组件&#xff0c;Image组件&#xff0c;Table表格组件 #1.在项目开发目录使用脚手架新建vue项目&#xff08;需要提前安装好node和webp…

无字母数字webshell提高

前言 元旦快乐 -- 转眼就到了2023年 新的一年继续努力 在p神博客中看到一个 通过上传临时文件进行rce&#xff0c;便想着写一篇文章&#xff0c;记录一下这个小trick。太强了 比如给你下面这么一串代码。正如文章标题 无字母数字&#xff0c;如果匹配到字母和数字&#xf…

【Vuex快速入门】vuex基础知识与安装配置

vuex快速入门——什么是vuex&#xff1f;创作背景vuex基础知识一、vuex是什么&#xff1f;二、vuex的组成三、为什么使用vuex&#xff1f;四、什么时候使用vuex&#xff1f;vuex的安装配置一、直接下载 / CDN引用二、npm安装vuex三、yarn安装四、自己构建更多内容可参考Vuex官方…

[从零开始]用python制作识图翻译器·二

AlsoEasy-RecognitionTranslator需求分析系统分析功能拆解工程语言选择技术可行性分析具体实现需求分析 见上篇[从零开始]用python制作识图翻译器一 上篇分析了该产品的需求以及市场上的可行性&#xff08;没有被吊打的竞品&#xff09;。而本篇将着重于分析如何实现。 系统分析…

gateway基本配置

目录 1、gateway简介 2、gateway核心概念 3、路由 4、断言 5、过滤器 5.1、过滤器介绍 5.2、内置局部过滤器与使用 5.3、内置全局过滤器 5.4、自定义全局过滤器 5.4.1、黑名单校验 5.4.2、模拟登录校验 6、一个简单的gateway配置实例 1、gateway简介 路由转发 执行…

Linear Regression with PyTorch 用PyTorch实现线性回归

文章目录4、Linear Regression with PyTorch 用PyTorch实现线性回归4.1 Prepare dataset 准备数据集4.2 Design Model 设计模型4.2.1 __call__() 作用4.3 Construct Loss and Optimizer 构造损失和优化器4.4 Training Cycle 训练周期4.5 Test Model 测试模型4.6 Different Opti…

redis缓存淘汰策略

定时删除 Redis不可能时时刻刻遍历所有被设置了生存时间的key&#xff0c;来检测数据是否已经到达过期时间&#xff0c;然后对它进行删除。 立即删除能保证内存中数据的最大新鲜度&#xff0c;因为它保证过期键值会在过期后马上被删除&#xff0c;其所占用的内存也会随之释放。…

zookeeper学习笔记2(小D课堂)

zookeeper数据模型&#xff1a; 我们的zookeeper是以节点的形式存在的&#xff0c;这样的形式和数据结构中的树的形式很像。同时也很像我们的linux的结构&#xff0c;例如linux的/user/local目录下可以有我们的/usr/local/tomcat目录。这样的节点形式。 我们的zookeeper中的每…

算法练习-常用查找算法复现

一个不知名大学生&#xff0c;江湖人称菜狗 original author: jacky Li Email : 3435673055qq.com Time of completion&#xff1a;2023.1.1 Last edited: 2023.1.1 目录 算法练习-常用查找算法复现&#xff08;PS&#xff1a;1 -- 3自己写的&#xff0c;4、5懒得写了&#xf…

PHP开发者之路

我们经常会发现&#xff0c;历时四年软件专业的大学生毕业居然找不到工作&#xff0c;即便找到了工作也只能是做一些简单的辅助性工作。 那么我们不禁要问&#xff0c;究竟是什么原因让我们可爱的大学生们学而无用&#xff0c;或者用而不学呢&#xff1f; 我认为主要是因为现…