深度学习案例:带有一个隐藏层的平面数据分类

news2024/10/26 16:57:31

该案例来自吴恩达深度学习系列课程一《神经网络和深度学习》第三周编程作业,作业内容是设计带有一个隐藏层的平面数据分类。作业提供的资料包括测试实例(testCases.py)和任务功能包(planar_utils.py),下载请移步参考链接一。

文章目录

  • 1 介绍
    • 1.1 案例应用核心公式
    • 1.2 涉及库和主要接口
  • 2 编码
    • 2.1 查看数据集相关参数
    • 2.2 简单逻辑回归应用效果
    • 2.3 搭建神经网络
    • 2.4 隐藏层单元数的变换
  • 3 调试
    • 3.1 运行时警告
    • 3.2 弃用警告
    • 3.3 数据转换警告
  • 4 参考

1 介绍

1.1 案例应用核心公式

1.2 涉及库和主要接口

numpy:用Python进行科学计算的基本软件包。

numpy.round:均匀地四舍五入到给定的小数位数。

numpy.random.seed:设置随机数生成器的种子,可以使随机数的生成具有可重复性。

sklearn:为数据挖掘和数据分析提供的机器学习库。

sklearn.linear_model.LogisticRegressionCV:逻辑回归CV(又名logit,MaxEnt)分类器。

matplotlib:用于在Python中绘制图表的库。

matplotlib.pyplot.scatter:具有不同标记大小和/或颜色的y与x的散点图。


2 编码

2.1 查看数据集相关参数

check_data.py

import matplotlib.pyplot as plt
from planar_utils import load_planar_dataset

X, Y = load_planar_dataset()  # 加载数据

# 查看数据散点图
plt.scatter(X[0, :], X[1, :], c=Y, s=40, cmap=plt.cm.Spectral)
plt.show()

# 计算和打印相关参数
print("X的维度为:" + str(X.shape))
print("Y的维度为:" + str(Y.shape))
print("数据集里的数据个数为:" + str(Y.shape[1]))

在这里插入图片描述

在这里插入图片描述

2.2 简单逻辑回归应用效果

logic_nn.py

import numpy as np
import matplotlib.pyplot as plt
import sklearn.linear_model
from planar_utils import plot_decision_boundary, load_planar_dataset

X, Y = load_planar_dataset()

# 搭建模型并训练
clf = sklearn.linear_model.LogisticRegressionCV()
clf.fit(X.T, Y.T)

# 应用模型预测结果
predictions = clf.predict(X.T)
correct_predictions = ((np.dot(Y, predictions.T) + np.dot(1 - Y, 1 - predictions.T)) / float(Y.size)).reshape(1,)
print('准确率: %.2f' % (correct_predictions[0] * 100) + '%')

# 绘制颜色块边界
plot_decision_boundary(lambda x: clf.predict(x), X, Y)
plt.title("Logistic Regression")
plt.show()

在这里插入图片描述

经过测试,准确性只有47%,原因是数据集不是线性可分的,所以逻辑回归表现不佳。

线性可分:指在特征空间中,存在一个超平面能够将不同类别的数据点完全分开,在二维空间中超平面表现为一条直线。

2.3 搭建神经网络

double_layer_nn.py

import numpy as np
import matplotlib.pyplot as plt
from planar_utils import plot_decision_boundary, sigmoid, load_planar_dataset


# 层单元数量
def layer_sizes(X, Y):
    n_x = X.shape[0]
    n_h = 4
    n_y = Y.shape[0]
    return n_x, n_h, n_y


# 初始化模型参数
def initialize_parameters(n_x, n_h, n_y):
    W1 = np.random.rand(n_h, n_x) * 0.01
    b1 = np.random.rand(n_h, 1)
    W2 = np.random.rand(n_y, n_h) * 0.01
    b2 = np.random.rand(n_y, 1)

    parameters = {"W1": W1, "b1": b1, "W2": W2, "b2": b2}
    return parameters


# 前向传播
def forward_propagation(X, parameters):
    W1 = parameters["W1"]
    b1 = parameters["b1"]
    W2 = parameters["W2"]
    b2 = parameters["b2"]

    Z1 = np.dot(W1, X) + b1
    A1 = np.tanh(Z1)
    Z2 = np.dot(W2, A1) + b2
    A2 = sigmoid(Z2)

    cache = {"Z1": Z1, "A1": A1, "Z2": Z2, "A2": A2}
    return cache


# 计算代价函数
def compute_cost(A2, Y):
    m = Y.shape[1]
    cost = (-1 / m) * np.sum(Y * np.log(A2) + (1 - Y) * np.log(1 - A2))
    cost = float(np.squeeze(cost))
    return cost


# 反向传播
def backward_propagation(parameters, cache, X, Y):
    m = X.shape[1]
    W1 = parameters["W1"]
    W2 = parameters["W2"]
    A1 = cache["A1"]
    A2 = cache["A2"]

    dZ2 = A2 - Y
    dW2 = (1 / m) * np.dot(dZ2, A1.T)
    db2 = (1 / m) * np.sum(dZ2, axis=1, keepdims=True)
    dZ1 = np.dot(W2.T, dZ2) * (1 - np.power(A1, 2))
    dW1 = (1 / m) * np.dot(dZ1, X.T)
    db1 = (1 / m) * np.sum(dZ1, axis=1, keepdims=True)

    grads = {"dW1": dW1, "db1": db1, "dW2": dW2, "db2": db2}
    return grads


# 更新参数
def update_parameters(parameters, grads, learning_rate=1.2):
    W1 = parameters["W1"]
    b1 = parameters["b1"]
    W2 = parameters["W2"]
    b2 = parameters["b2"]

    dW1 = grads["dW1"]
    db1 = grads["db1"]
    dW2 = grads["dW2"]
    db2 = grads["db2"]

    W1 = W1 - learning_rate * dW1
    b1 = b1 - learning_rate * db1
    W2 = W2 - learning_rate * dW2
    b2 = b2 - learning_rate * db2

    parameters = {"W1": W1, "b1": b1, "W2": W2, "b2": b2}
    return parameters


# 构建神经网络
def nn_model(X, Y, n_h, num_iterations, learning_rate=0.5, print_cost=False):
    n_x = layer_sizes(X, Y)[0]
    n_y = layer_sizes(X, Y)[2]

    parameters = initialize_parameters(n_x, n_h, n_y)

    for i in range(num_iterations):
        cache = forward_propagation(X, parameters)
        cost = compute_cost(cache["A2"], Y)
        grads = backward_propagation(parameters, cache, X, Y)
        parameters = update_parameters(parameters, grads, learning_rate)
        if print_cost and (i % 1000 == 0):
            print("第 %i 次循环,成本为: %f" % (i, cost))

    return parameters


# 预测函数
def predict(parameters, X):
    cache = forward_propagation(X, parameters)
    predictions = np.round(cache["A2"])
    return predictions


# 进行深度学习
X, Y = load_planar_dataset()
n_h = 4
parameters = nn_model(X, Y, n_h, num_iterations=10000, learning_rate=0.5, print_cost=True)
predictions = predict(parameters, X)
correct_predictions = ((np.dot(Y, predictions.T) + np.dot(1 - Y, 1 - predictions.T)) / float(Y.size)).reshape(1,)
print('准确率: %.2f' % (correct_predictions[0] * 100) + '%')

# 绘制颜色块边界
plot_decision_boundary(lambda x: predict(parameters, x.T), X, Y)
plt.title("Decision Boundary for hidden layer size %i : %.2f " % (n_h, correct_predictions[0] * 100) + '%')
plt.show()

在这里插入图片描述

在这里插入图片描述

2.4 隐藏层单元数的变换

变换隐藏层的单元数量,观察对预测结果是否产生哪些影响。

在这里插入图片描述在这里插入图片描述
在这里插入图片描述在这里插入图片描述
在这里插入图片描述在这里插入图片描述
在这里插入图片描述在这里插入图片描述

3 调试

3.1 运行时警告

E:\pythonPrograming\DLHomework\course1-week3\double_layer_nn.py:53: RuntimeWarning: divide by zero encountered in log

logprobs= np.multiply(np.log(A2), Y) + np.multiply((1 - Y), np.log(1 - A2))

E:\pythonPrograming\DLHomework\course1-week3\planar_utils.py:25: RuntimeWarning: overflow encountered in exp

s = 1/(1+np.exp(-x))

3.2 弃用警告

E:\pythonPrograming\DLHomework\course1-week3\double_layer_nn.py:135: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)

print ('准确率: %d' % float((np.dot(Y, predictions.T) + np.dot(1 - Y, 1 - predictions.T)) / float(Y.size) * 100) + '%')

3.3 数据转换警告

E:\pythonPrograming\DLHomework\course1-week3.venv\Lib\site-packages\sklearn\utils\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().
y = column_or_1d(y, warn=True)

clf.fit(X.T, Y.T)

此处问题出现在逻辑回归效果 logic_nn.py 中,由于提供的 Y 值是一个二维数组,与函数所需的一维数组存在数据转换问题,可以人工使用ravel()flatten() 将 Y 转换为一维数组。


4 参考

【中文】【吴恩达课后编程作业】Course 1 - 神经网络和深度学习 - 第三周作业-CSDN博客

NumPy reference — NumPy v2.1 Manual

scikit-learn 1.5.2 documentation

API Reference — Matplotlib 3.9.2 documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2224082.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Jetpack架构组件_LiveData组件

1.LiveData初识 LiveData:ViewModel管理要展示的数据(VM层类似于原MVP中的P层),处理业务逻辑,比如调用服务器的登陆接口业务。通过LiveData观察者模式,只要数据的值发生了改变,就会自动通知VIEW层&#xf…

Flutter TextField和Button组件开发登录页面案例

In this section, we’ll go through building a basic login screen using the Button and TextField widgets. We’ll follow a step-bystep approach, allowing you to code along and understand each part of the process. Let’s get started! 在本节中,我们…

【Python爬虫系列】_031.Scrapy_模拟登陆中间件

课 程 推 荐我 的 个 人 主 页:👉👉 失心疯的个人主页 👈👈入 门 教 程 推 荐 :👉👉 Python零基础入门教程合集 👈👈虚 拟 环 境 搭 建 :👉👉 Python项目虚拟环境(超详细讲解) 👈👈PyQt5 系 列 教 程:👉👉 Python GUI(PyQt5)教程合集 👈👈…

ArcGIS001:ArcGIS10.2安装教程

摘要:本文详细介绍arcgis10.2的安装、破解、汉化过程。 一、软件下载 安装包链接:https://pan.baidu.com/s/1T3UJ7t_ELZ73TH2wGOcfpg?pwd08zk 提取码:08zk 二、安装NET Framework 3.5 双击打开控制面板,点击【卸载程序】&…

World of Warcraft [CLASSIC][80][the Ulduar]

Ulduar 奥杜尔副本介绍 奥杜尔共计14个BOSS,通常说的10H就是10个苦难模式就是全通,9H就是除了【观察者奥尔加隆】,特别说明开启【观察者奥尔加隆】,是需要打掉困难模式4个守护者的。 所以人们经常说的类似“10H 观察者”、“10H…

Python开发日记 -- 实现bin文件的签名

目录 1.数据的不同表现形式签名值不一样? 2.Binascii模块简介 3.问题定位 4.问题总结 1.数据的不同表现形式签名值不一样? Happy Muscle试运行了一段时间,组内同事再一次提出了新的需求:需要对bin文件签名。 PS:服…

react18中的函数组件底层渲染原理分析

react 中的函数组件底层渲染原理 react组件没有局部与全局之分,它是一个整体。这点跟vue的组件化是不同的。要实现 react 中的全局组件,可以将组件挂在react上,这样只要引入了react,就可以直接使用该组件。 函数式组件的创建 …

Kafka之消费者客户端

1、历史上的二个版本 与生产者客户端一样,在Kafka的发展过程当中,消费者客户端主要有两个大的版本: 旧消费者客户端(Old Consumer):基于Scala语言开发的版本,又称为Scala消费者客户端。新消费…

【力扣】GO解决子序列相关问题

文章目录 一、引言二、动态规划方法论深度提炼子序列问题的通用解法模式 三、通用方法论应用示例:最长递增子序列(LeetCode题目300)Go 语言代码实现 四、最长连续递增序列(LeetCode题目674)Go 语言代码实现 五、最长重…

ffmpeg视频滤镜:定向模糊-dblur

滤镜简述 dblur 官网链接 > https://ffmpeg.org/ffmpeg-filters.html#dblur 有一个模糊滤镜&#xff0c;我试了一下&#xff0c;没有感觉到它的特殊之处, 这里简单介绍一下。 滤镜使用 滤镜的参数 angle <float> ..FV.....T. set angle (from 0 t…

找不到包的老版本???scikit-learn,numpy,scipy等等!!

废话不多说 直接上链接了&#xff1a; https://pypi.tuna.tsinghua.edu.cn/simple/https://pypi.tuna.tsinghua.edu.cn/simple/https://pypi.tuna.tsinghua.edu.cn/simple/xxx/ 后面的这个xxx就是包的名字 大家需要什么包的版本&#xff0c;直接输进去就可以啦 举个栗子&#…

零基础Java第十期:类和对象(一)

目录 一、拜访对象村 1.1. 什么是面向对象 1.2. 面向对象与面向过程 二、类定义和使用 2.1. 类的定义格式 2.2. 类的定义练习 三、类的实例化 3.1. 什么是实例化 3.2. 类和对象的说明 四、this引用 4.1. 什么是this引用 4.2. this引用的特性 一、拜访对象村 在…

<项目代码>YOLOv8路面病害识别<目标检测>

YOLOv8是一种单阶段&#xff08;one-stage&#xff09;检测算法&#xff0c;它将目标检测问题转化为一个回归问题&#xff0c;能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法&#xff08;如Faster R-CNN&#xff09;&#xff0c;YOLOv8具有更高的…

STMicroelectronics意法半导体车规芯片系列--亿配芯城(ICgoodFind)

在汽车电子领域&#xff0c;意法半导体的车规级芯片系列一直备受瞩目。亿配芯城作为电子元器件领域的可靠供应商&#xff0c;为大家介绍意法半导体车规级芯片系列的卓越之处。 意法半导体在车规级芯片领域拥有深厚的技术积累和丰富的经验。 其车规级芯片涵盖了多个关键领域&am…

8.three.js相机详解

8.three.js相机详解 1、 认识相机 在Threejs中相机的表示是THREE.Camera&#xff0c;它是相机的抽象基类&#xff0c;其子类有两种相机&#xff0c;分别是正投影相机THREE.OrthographicCamera和透视投影相机THREE.PerspectiveCamera&#xff1a; 正投影和透视投影的区别是&am…

【Java】常用方法合集

以 DemoVo 为实体 import lombok.Data; import com.alibaba.excel.annotation.ExcelProperty; import com.alibaba.excel.annotation.ExcelIgnoreUnannotated;Data ExcelIgnoreUnannotated public class ExportPromoteUnitResult {private String id;ExcelProperty(value &qu…

贪心算法记录 - 下

135. 分发糖果 困难 n 个孩子站成一排。给你一个整数数组 ratings 表示每个孩子的评分。 你需要按照以下要求&#xff0c;给这些孩子分发糖果&#xff1a; 每个孩子至少分配到 1 个糖果。相邻两个孩子评分更高的孩子会获得更多的糖果。 请你给每个孩子分发糖果&#xff0c…

一文搞懂指令周期,机器周期和时钟周期

如图&#xff1a; 指令周期 > 机器周期 > 时钟周期 指令周期&#xff1a;一个指令&#xff0c;从取值到执行的全部周期。一个指令执行过程包括取值&#xff0c;译码和执行阶段。 机器周期&#xff1a;,取指、间址、执行和中断等 时钟周期&#xff1a;时钟频率的倒数&am…

什么样的JSON编辑器才好用

简介 JSON&#xff08;JavaScript Object Notation&#xff09;是一种轻量级的数据交换格式&#xff0c;易于人阅读和编写&#xff0c;同时也便于机器解析和生成。随着互联网和应用程序的快速发展&#xff0c;JSON已经成为数据传输和存储的主要格式之一。在处理和编辑JSON数据…

python查询并安装项目所依赖的所有包

引言 如果需要进行代码的移植&#xff0c;肯定少不了在另一台pc或者服务器上进行环境的搭建&#xff0c;那么首先是要知道在已有的工程的代码中用到了哪些包&#xff0c;此时&#xff0c;如果是用人工去一个一个的代码文件中去查看调用了哪些包&#xff0c;这个工作甚是繁琐。…