吴恩达机器学习-C2W3-多类分类

news2025/1/12 3:49:46

目标

在本实验中,您将探索一个使用神经网络进行多类分类的示例。
在这里插入图片描述

工具

您将使用一些绘图例程。它们存储在这个目录下的lab_utils_multiclass_TF.py中。

import numpy as np
import matplotlib.pyplot as plt
%matplotlib widget
from sklearn.datasets import make_blobs
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
np.set_printoptions(precision=2)
from lab_utils_multiclass_TF import *
import logging
logging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)

2.0多类分类

神经网络经常被用来对数据进行分类。例如神经网络:

  • 拍摄照片,并将照片中的主题分类为{狗,猫,马,其他};
  • 拍摄一个句子,并将其元素的“词性”分类:{名词,动词,形容词等。}
    这种类型的网络将在其最后一层有多个单元。每个输出都与一个类别相关联。当一个输入示例应用于网络时,具有最高值的输出是预测的类别。如果输出应用于softmax函数,则softmax的输出将提供输入在每个类别中的概率。
    在本实验中,您将看到一个在Tensorflow中构建多类网络的示例。然后我们将看看神经网络是如何进行预测的。
    让我们从创建一个四类数据集开始。

2.1准备和可视化我们的数据

我们将使用Scikit-Learn make_blobs函数制作一个包含4个类别的训练数据集,如下图所示。

# make 4-class dataset for classification
classes = 4
m = 100
centers = [[-5, 2], [-2, -2], [1, 2], [5, -2]]
std = 1.0
X_train, y_train = make_blobs(n_samples=m, centers=centers, cluster_std=std,random_state=30)
plt_mc(X_train,y_train,classes, centers, std=std)

每个点代表一个训练样例。轴(x0,x1)是输入,颜色表示与示例相关联的类。一旦训练完毕,模型将呈现一个新的例子(x0,x1),并将预测该类。
在生成时,该数据集代表了许多现实世界的分类问题。有几个输入特征(x0,…,xn)和几个输出类别。该模型被训练成使用输入特征来预测正确的输出类别。

# show classes in data set
print(f"unique classes {np.unique(y_train)}")
# show how classes are represented
print(f"class representation {y_train[:10]}")
# show shapes of our dataset
print(f"shape of X_train: {X_train.shape}, shape of y_train: {y_train.shape}")

2.2模型

本实验将使用如图所示的2层网络。与二元分类网络不同,该网络有四个输出,每个类一个。给定一个输入示例,具有最大值的输出是该输入的预测类。
下面是一个如何在Tensorflow中构建该网络的示例。注意,输出层使用线性激活而不是softmax激活。虽然可以在输出层中包含softmax,但如果在训练期间将线性输出传递给损失函数,则在数值上更稳定。如果该模型用于预测概率,那么softmax可以
在这里插入图片描述

tf.random.set_seed(1234)  # applied to achieve consistent results
model = Sequential(
    [
        Dense(2, activation = 'relu',   name = "L1"),
        Dense(4, activation = 'linear', name = "L2")
    ]
)

下面的语句编译并训练网络。将from_logits=True设置为损失函数的参数,指定输出激活是线性的,而不是softmax。

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(0.01),
)

model.fit(
    X_train,y_train,
    epochs=200
)

通过训练模型,我们可以看到模型是如何对训练数据进行分类的。

plt_cat_mc(X_train, y_train, model, classes)

上面的决策边界显示了模型如何划分输入空间。这个非常简单的模型对训练数据进行分类没有问题。它是如何做到的呢?让我们更详细地看看这个网络。
下面,我们将从模型中提取训练好的权重,并使用它来绘制每个网络单元的函数。再往下看,对结果有更详细的解释。要成功地使用神经网络,您不需要知道这些细节,但它可能有助于获得更多关于这些层如何组合以解决分类问题的直觉。

# gather the trained parameters from the first layer
l1 = model.get_layer("L1")
W1,b1 = l1.get_weights()
# plot the function of the first layer
plt_layer_relu(X_train, y_train.reshape(-1,), W1, b1, classes)
# gather the trained parameters from the output layer
l2 = model.get_layer("L2")
W2, b2 = l2.get_weights()
# create the 'new features', the training examples after L1 transformation
Xl2 = np.zeros_like(X_train)
Xl2 = np.maximum(0, np.dot(X_train,W1) + b1)

plt_output_layer_linear(Xl2, y_train.reshape(-1,), W2, b2, classes,
                        x0_rng = (-0.25,np.amax(Xl2[:,0])), x1_rng = (-0.25,np.amax(Xl2[:,1])))

解释

这些图显示了网络第一层中单元0和1的函数。输入在轴上是( x 0 , x 1 x_0,x_1 x0,x1)该单元的输出由背景的颜色表示。每个图表右侧的颜色条表示了这一点。请注意,由于这些单元使用的是ReLu,因此输出不一定落在0到1之间,在本例中,输出的峰值大于20。
在这里插入图片描述

该图中的等高线表示输出 a j [ 1 ] a^{[1]}_j aj[1]为零和非零之间的过渡点。回想一下ReLu的图
在这里插入图片描述

图中的等高线是ReLu中的拐点。

单元0将类别0和1与类别2和3分开。行左边的点(类0和1)将输出0,而右边的点将输出大于0的值。
单元1将0、2班与1、3班分开。线以上的点(类0和2)将输出零,而线以下的点将输出大于零的值。让我们看看下一层是如何实现的!

-----------------这是一条手动分割线------------------
在这里插入图片描述

这些图中的点是由第一层翻译的训练样本。考虑这个问题的一种方法是,第一层创建了一组新的特性,供第二层评估。这些图中的轴是前一层 a 0 [ 1 ] a^{[1]}_0 a0[1] a 1 [ 1 ] a^{[1]}_1 a1[1]的输出。如上所述,类0和1(绿色和蓝色)有 a 0 [ 1 ] = 0 a^{[1]}_0 = 0 a0[1]=0,而类1和2(蓝色和绿色)有 a 1 [ 1 ] = 0 a^{[1]}_1 = 0 a1[1]=0
同样,背景颜色的强度表示最高值。
单元0将产生接近(0,0)的值的最大值,其中类0(蓝色)已被映射
单元1在左上角选择类别1(绿色)产生最大值。
第二单元的目标是右下角3班(橙色)所在的地方。
第3单元在右上方选择我们的最终类(紫色)时产生最大值。
图表中不明显的另一个方面是,单位之间的值是协调的。对于一个单位来说,为它所选择的类别产生最大值是不够的,它还必须是该类别中所有点数单位的最大值。这是通过隐含的softmax函数完成的,该函数是损失函数(’ SparseCategoricalCrossEntropy ')的一部分。与其他激活函数不同,softmax适用于所有输出。
你可以成功地使用神经网络,而不需要知道每个单元的细节。希望这个例子

恭喜

您已经学习了如何构建和操作用于多类分类的神经网络。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2049623.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

centos7.9最小化安装之后的配置与下载

一. 配置yum源 1.备份系统自带源文件 cd /etc/yum.repos.d/ mkdir bak mv *.repo bak 2. 配置阿里云yum源 若有wget wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-7.repo若没有wget,先直接把 http://mirrors.aliyun.com/re…

11.怎么做好一个动态标签页

效果 步骤 1.在Elementui找一个标签页组件 复制粘贴到代码 2.将他写活 将很多页面需要用的方法和变量写入store editableTabsValue: 2,editableTabs: [{title: 首页,name: index,},],addTab(state, tab) {if (state.editableTabs.findIndex(item > item.title tab.titl…

mysql5.7主从同步失败原因总结-windows

1,主库data文件复制到从库,之后主库要同步的实例data一定不要在修改; 1.1,修改之后就要重新覆盖一遍修改过的data 2,如果状态不对:一定要查看日志;比如slave_io_state是空时,需要查…

KEEPALIVED 全csdn最详细----理论+实验(干货扎实,包教会的)

环境准备 主机名IP虚拟IP(VIP)功能ka1172.25.254.10172.25.254.100keepalived服务ka2172.25.254.20172.25.254.100keepalived服务realserver1172.25.254.110web服务realserver2172.25.254.120web服务 注意一定要关闭selinux,和防火墙,不然在…

zabbix7.0 设置中文语言( Debian GNU/Linux 12)

本例为安装zabbix7.0 zabbix_server (Zabbix) 6.4.17 Revision c12261f00b4 15 July 2024, compilation time: Jul 15 2024 11:05:06 系统版本信息为 lsb_release -a No LSB modules are available. Distributor ID: Debian Description: Debian GNU/Linux 12 (bookworm) Rele…

防疫物资管理信息系统pf

TOC springboot379防疫物资管理信息系统pf 第1章 绪论 1.1选题动因 当前的网络技术,软件技术等都具备成熟的理论基础,市场上也出现各种技术开发的软件,这些软件都被用于各个领域,包括生活和工作的领域。随着电脑和笔记本的广泛…

【Django开发】前后端分离django美多商城项目第2篇:展示用户注册页面,1. 创建用户模块子应用【附代码文档】

全套笔记资料代码移步: 前往gitee仓库查看 感兴趣的小伙伴可以自取哦~ 本教程的知识点为: 项目准备 项目准备 配置 1. 修改settings/dev.py 文件中的路径信息 2. INSTALLED_APPS 3. 数据库 用户部分 图片 1. 后端接口设计: 视图原型 2. 具体…

【网络安全】SSO登录过程实现账户接管

未经许可,不得转载。 文章目录 正文正文 登录页面展示了“使用 SSO 登录”功能: 经分析,单点登录(SSO)系统的身份验证过程如下: 1、启动SSO流程:当用户点击按钮时,浏览器会发送一个GET请求到指定的URL: /idp/auth/mid-oidc?req=[UNIQUE_ID]&redirect_uri=[REDI…

在 Mac 上更改 24小时制时间显示

使用“日期与时间”设置设定或更改 Mac 上的日期和时间。如果日期和时间正确,那么电子邮件、信息和文件上的时间戳也是准确的。了解如何设定日期和时间。 若要更改这些设置,请选取苹果菜单 >“系统设置”,点按边栏中的“通用” &#x…

[星瞳科技]OpenMV使用时有哪些常见错误和解决办法?

常见代码错误 ImportError:no module named xxx 这个错误是Import错误,没有stepper这个模块。 原因: 你没有把stepper.py这个文件拖到你的板子里。见:模块的使用 拖过去之后,需要重启,使模块生效 MemoryError:FB …

Class字节码文件结构

class字节码文件结构 类型名称说明长度数量u4magic魔数,识别Class文件格式4个字节1u2minor_version副版本号(小版本)2个字节1u2major_version主版本号(大版本)2个字节1u2constant_pool_count常量池计数器2个字节1cp_infoconstant_pool常量池表n个字节constant_pool_count-1u2a…

马头拧紧驱动器维修 拧紧控制器故障

马头拧紧控制器作为工业自动化领域不可或缺的重要设备,其稳定运行对于生产线的效率与安全性至关重要。然而,在实际应用中,难免会遇到各种Desoutter拧紧工具控制器故障,影响生产进度和设备性能。 拧紧轴控制器维修 拧紧装置 马头…

ubuntu设置共享文件夹,非虚拟机,服务器版

在Ubuntu中共享文件夹通常可以通过几种不同的方式来实现,比如使用Samba服务、NFS(Network File System)或者通过虚拟机软件如VirtualBox或VMware的内置共享文件夹功能。这里我假设您是在询问如何在Ubuntu主机上设置一个简单的文件共享服务&am…

MongoDB Redis 快速上手:NoSQL数据库操作精要

先言之 ☘️随着大数据时代的到来,非关系型数据库因其灵活性和扩展性逐渐受到开发者的青睐。MongoDB 和 Redis 作为两种非常流行的 NoSQL 数据库,各自拥有独特的特性和应用场景。MongoDB 是一款面向文档的数据库,适用于需要存储复杂数据结构…

ESP32-C3在MQTT访问时出现“transport_base: Poll timeout or error”问题的分析(2)

接前一篇文章:ESP32-C3在MQTT访问时出现“transport_base: Poll timeout or error”问题的分析(1) 前一篇文章在分析定位笔者所遇MQTT(MQTTs)传输时问题的时候,定位到了问题是出自于components\components\tcp_transport\transport_ssl.c的ssl_write函数。本回开始,就围…

【四】阿伟开始学Kafka

阿伟开始学Kafka 概述 人生若只如初见,阿伟心里回想起了第一次和Kafka见面的场景,记忆虽然已经有些模糊,但是感觉初次见面是美好的。积累了一些实战经验之后,阿伟感觉不能再是面对百度开发了,于是决心系统的学习一下Ka…

liblzma库Android平台编译

1.下载源码: git clone https://github.com/tukaani-project/xz.git --recursive 2.配置交叉编译环境: 生成Android平台makefile export ANDROID_API=25 export ANDROID_NDK=/opt/aarch64-darwin-android export ANDROID_NDK_REVISION=r25b export AR=/opt/aarch64-darwin-a…

李宏毅 机器学习与深度学习【2022版】 01

文章目录 一、基本概念二、深度学习内容总览三、预测YouTube播放量的模型1、假设一个含有未知参数的函数式2、根据Training Data定义一个 Loss3、最优化Optimization4、测试集验证模型性能5、线性模型特征维度提升6、非线性模型7、ReLU 四、深度学习概述1、Fully Connect Feedf…

基于改进YOLOv8的景区行人检测算法

贵向泉, 刘世清, 李立, 秦庆松, 李唐艳. 基于改进YOLOv8的景区行人检测算法[J]. 计算机工程, 2024, 50(7): 342-351. DOI: 10.19678/j.issn.10 原文链接如下:基于改进YOLOv8的景区行人检测算法https://www.ecice06.com/CN/rich_html/10.19678/j.issn.1000-3428.006…

DOM Clobbring个人理解

目录 toString One Level Two Level Three Level More Dom Clobbering:就是⼀种将 HTML 代码注⼊⻚⾯中以操纵 DOM 并最终更改⻚⾯上 JavaScript ⾏为的技术 DOM Clobbering中的操作也是根据JavaScript行为的层级来分为一层、两层、三层和更多 toString 我们…