吴恩达机器学习课后作业-03多分类、神经网络前向传播

news2025/1/11 14:08:30

这里写目录标题

  • 逻辑回归解决多分类问题(逻辑回归的“一对多”(One-vs-All)策略。)
    • 绘制图像
    • 结果
  • 神经网络
    • 前向传播
    • 数字识别

逻辑回归解决多分类问题(逻辑回归的“一对多”(One-vs-All)策略。)

在这里插入图片描述
手写数字识别
在这里插入图片描述
在这里插入图片描述

绘制图像

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io as sio
def plot_an_image(x) :
    pick_one=np.random.randint(5000)
    image=x[pick_one,:]#取第一行
    fig,ax=plt.subplots(figsize=(1,1))
    ax.imshow(image.reshape(20,20).T,cmap="gray_r")
    plt.xticks([])
    plt.yticks([])
    plt.show()

def plot_100_image(x):
    sample_index=np.random.choice(len(x),100)
    image=x[sample_index,:]#随机取100行,也就是取一百张图片
    fig, ax = plt.subplots(figsize=(8,8),nrows=10,ncols=10,sharey=True,sharex=True)
    plt.xticks([])
    plt.yticks([])
    for r in range(10):
        for c in range(10):
            ax[r,c].imshow(image[10*r+c].reshape(20,20).T,cmap="gray_r")
    plt.show()

data=sio.loadmat("E:/学习/研究生阶段/python-learning/吴恩达机器学习课后作业/code/ex3-neural network/ex3data1.mat")

raw_x=data["X"]
raw_y=data["y"]

plot_100_image(raw_x)

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

结果

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy.optimize import minimize
def plot_an_image(x) :
    pick_one=np.random.randint(5000)
    image=x[pick_one,:]#取第一行
    fig,ax=plt.subplots(figsize=(1,1))
    ax.imshow(image.reshape(20,20).T,cmap="gray_r")
    plt.xticks([])
    plt.yticks([])
    plt.show()

def plot_100_image(x):
    sample_index=np.random.choice(len(x),100)
    image=x[sample_index,:]#随机取100行,也就是取一百张图片
    fig, ax = plt.subplots(figsize=(8,8),nrows=10,ncols=10,sharey=True,sharex=True)
    plt.xticks([])
    plt.yticks([])
    for r in range(10):
        for c in range(10):
            ax[r,c].imshow(image[10*r+c].reshape(20,20).T,cmap="gray_r")
    plt.show()



"""
代价函数
"""
def sigmoid(z):
    return 1/(1+np.exp(-z))


"""
回答一下:此时要求theta必须放在第一位,因为分类器那里
所要用到的函数theta是做为要优化的参数来的,其他的参数叫args

"""
def cost_function(theta,x,y,lamda):#
    y_=sigmoid(x @ theta)

    reg=theta[:1]@theta[:1]*(lamda/(2*len(x)))#正则化
    return np.sum(-(y*np.log(y_)+(1-y)*np.log(1-y_))/len(x))+reg

"""
梯度向量
"""


def gradient_reg(theta,x,y,lamda):
    reg = theta[1:] * (lamda / len(x))  #
    reg = np.insert(reg, 0, values=0, axis=0)  # 在第一个元素前0,为了与后面维数匹配
    first=x.T @ (sigmoid(x @ theta)-y)/len(x)
    return first+reg
"""
定义梯度下降函数
alpha:学习速率
inters:迭代次数
lamda
"""
def gradientDescent(x,y,theta,alpha,inters,lamda):
    costs = []
    for i in range(inters):

        reg = theta[1:] * (lamda / len(x))  #
        reg = np.insert(reg, 0, values=0, axis=0) #在第一个元素前插入0,为了与后面维数匹配
        theta=theta-alpha*x.T @ (sigmoid(x @ theta)-y)/len(x)


        cost=cost_function(x,y,theta,lamda)
        costs.append(cost)
        # if i%1000==0:
        #     print(cost)

    return theta,costs


def one_vs_all(x,y,lamda,k):
    n=x.shape[1]
    theta_all=np.zeros((k,n))
    for i in range(1,k+1):
        theta_i=np.zeros(n,)
        res=minimize(fun=cost_function,
                     x0=theta_i,
                     args=(x,y==i,lamda),
                     method="TNC",
                     jac=gradient_reg)
        theta_all[i-1,:]=res.x
    return theta_all

"""

预测函数
"""

def predict(x,theta_finall):
    h=sigmoid(x@theta_finall.T)
    h_argmax=np.argmax(h,axis=1)
    return h_argmax+1



data=sio.loadmat("E:/学习/研究生阶段/python-learning/吴恩达机器学习课后作业/code/ex3-neural network/ex3data1.mat")

raw_x=data["X"]
raw_y=data["y"]
x=np.insert(raw_x,0,1,axis=1)
y=raw_y.flatten()#变为一维的

# plot_100_image(raw_x)

lamda=1
k=10
theta_finall=one_vs_all(x,y, lamda, k)
y_predict=predict(x,theta_finall)
acc=np.mean(y_predict==y)
print(acc)


神经网络

前向传播

在这里插入图片描述

数字识别

在这里插入图片描述

import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio

def sigmoid(z):
    return 1/(1+np.exp(-z))



data=sio.loadmat("E:/学习/研究生阶段/python-learning/吴恩达机器学习课后作业/code/ex3-neural network/ex3data1.mat")

raw_x=data["X"]
raw_y=data["y"]
x=np.insert(raw_x,0,1,axis=1)
y=raw_y.flatten()#变为一维的
theta=sio.loadmat("E:/学习/研究生阶段/python-learning/吴恩达机器学习课后作业/code/ex3-neural network/ex3weights.mat")
theta1=theta["Theta1"]#输入层到隐藏层传递参数
theta2=theta["Theta2"]#隐藏层到输出层传递参数


a1=x
z2=x@theta1.T
a2=sigmoid(z2)
a2=np.insert(a2,0,1,axis=1)
z3=a2 @ theta2.T
a3=sigmoid(z3)


y_pred=np.argmax(a3,axis=1)
y_pred=y_pred+1
acc=np.mean(y_pred==y)
print(acc)

代码看不太懂,下面是ai的解释

你的代码实现了使用已加载的权重(theta1 和 theta2)在一个简单的神经网络上进行预测,并计算了预测的准确率。这个神经网络有一个输入层、一个隐藏层和一个输出层。不过,代码中有些小地方可以优化或注意,以下是对你代码的详细解释和一些建议:

导入库和定义sigmoid函数:这部分是正确的,你已经成功导入了所需的库并定义了sigmoid函数。
加载数据:你使用scipy.io.loadmat来加载MATLAB的.mat文件,这是处理这类文件的标准方式。
处理输入和输出数据:
你将raw_x(输入数据)的每一行前面插入了一个1,这是为了加入偏置项(bias term),这个处理是正确的。
raw_y(目标输出)被转换为一维数组,这也是正确的,因为np.argmax需要一个一维数组来确定最大值的索引。
加载权重:你正确地加载了权重矩阵theta1和theta2。
前向传播:
z2的计算是正确的,它代表了输入层到隐藏层的线性组合。
a2通过sigmoid函数激活,然后你在其前面插入了1以添加隐藏层的偏置项。然而,这一步通常不是必要的,因为隐藏层的偏置项通常作为权重矩阵theta1的一部分被处理(即theta1的每一行都会有一个偏置项对应的权重)。如果你决定在a2中手动添加偏置项,那么你应该在theta2中相应地调整,但这通常不是标准做法。
z3和a3的计算是正确的,它们代表了隐藏层到输出层的计算。
预测和准确率计算:
使用np.argmax(a3, axis=1)来获取预测类别,这是正确的,因为a3的每一行都代表了一个样本属于各个类别的概率。
由于MATLAB的索引从1开始,而Python的索引从0开始,你通过y_pred=y_pred+1来调整了预测类别的索引,这是必要的,但请确保你的y(真实标签)也是以MATLAB的索引方式(即从1开始)给出的。如果不是,你可能需要调整y的值。
准确率计算acc=np.mean(y_pred==y)是正确的。
建议:

检查y的索引是否也是从1开始的,如果不是,你需要调整y的值以匹配y_pred的索引。
如果隐藏层不需要手动添加偏置项(即theta1已经包含了偏置项),则应该从a2中移除np.insert(a2, 0, 1, axis=1)这一行。
确保theta1和theta2的维度与你的网络架构相匹配(例如,theta1的列数应该等于x的行数(加上偏置项),theta2的列数应该等于输出类别的数量)。
最后,如果你的y数组是从0开始索引的,你应该在比较之前将y转换为从1开始的索引,或者修改y_pred的索引方式以匹配y的索引方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2065040.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[Algorithm][综合训练][孩子们的游戏][大数加法][拼三角]详细讲解

目录 1.孩子们的游戏1.题目链接2.算法原理详解 && 代码实现 2.大数加法1.题目链接2.算法原理详解 && 代码实现 3.拼三角1.题目链接2.算法原理详解 && 代码实现 1.孩子们的游戏 1.题目链接 孩子们的游戏 2.算法原理详解 && 代码实现 问题抽象…

LongWriter——从长文本语言模型中释放出10,000+字的生成能力

概述 当前的长上下文大型语言模型 (LLM) 可以处理多达 100,000 个词的输入,但它们很难生成超过 2,000 个词的输出。受控实验表明,该模型的有效生成长度本质上受到监督微调(SFT) 期间看到的示例的限制。换句话说,这种输出限制源于现有 SFT 数…

三维模型单体化软件:地理信息与遥感领域的精细化革命

在地理信息与遥感科学日新月异的发展浪潮中,单体化软件作为一股强大的驱动力,正引领着我们迈向空间信息处理与应用的新纪元。本文旨在深度解析单体化软件的核心价值、技术前沿、实践应用及面临的挑战,共同探讨这一技术如何塑造行业的未来。 …

【手撕OJ题】——BM8 链表中倒数最后k个结点

目录 🕒 题目⌛ 方法① - 直接遍历⌛ 方法② - 快慢指针 🕒 题目 🔎 BM8 链表中倒数最后k个结点【难度:简单🟢】 输入一个长度为 n 的链表,设链表中的元素的值为 a i a_i ai​ ,返回该链表中倒…

一款MySQL数据库实时增量同步工具,能够监听MySQL二进制日志(Binlog)的变动(附源码)

背景 作为一名CURD的程序员,少不了跟MySQL打交道,在同步数据的时候,MySQL的Binlog显得重中之重,所以处理Binlog的工具尤为重要。 其中阿里巴巴开源的canal 更是耳闻目睹,但是今天小编给大家介绍另外一款MySQL数据库实…

【C++11】常用新语法②(类的新功能 || 可变参数模板 || lambda表达式 || 包装器)

🔥个人主页: Forcible Bug Maker 🔥专栏: C 目录 🌈前言🔥类的新功能新增默认成员函数强制生成默认函数的关键字default禁止生成默认函数的关键字delete 🔥可变参数模板递归函数方式展开参数包…

论文翻译:Benchmarking Large Language Models in Retrieval-Augmented Generation

https://ojs.aaai.org/index.php/AAAI/article/view/29728 检索增强型生成中的大型语言模型基准测试 文章目录 检索增强型生成中的大型语言模型基准测试摘要1 引言2 相关工作3 检索增强型生成基准RAG所需能力数据构建评估指标 4实验设置噪声鲁棒性结果负面拒绝测试平台结果信息…

算法5:位运算

文章目录 小试牛刀进入正题 没写代码的题,其链接点开都是有代码的。开始前请思考下图: 小试牛刀 位1的个数 class Solution { public:int hammingWeight(int n) {int res 0;while (n) {n & n - 1;res;}return res;} };比特位计数 class Solution…

计算机毕业设计选题推荐-猫眼电影数据可视化分析-Python爬虫-k-means算法

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

进程和文件痕迹排查——LINUX

目录 介绍步骤 介绍 进程(Process)是计算机中的程序关于某数据集合上的一次运行活动,是系统进行资源分配和调度的基本单位,是操作系统结构的基础。 在早期面向进程设计的计算机结构中,进程是程序的基本执行实体&…

fastadmin 安装

环境要求,大家可以参考官方文档的,我这里使用的是phpstudy,很多已经集成了。 注意一点,PHP 版本:PHP 7.4 。 第二步:下载 下载地址:https://www.fastadmin.net/download.html 进入下载地址后…

IDEA:Terminal找不到npm

Terminal的命令失效通过修改cmd.exe的方式还是不生效的话,考虑是windwos11 默认idea不是通过管理员启动的,如下图修改就可以了。

前端vue 3中使用 顶象 vue3 版本

顶象 验证 的插件 不知道大家使用过没有 顶象-业务安全引领者&#xff0c;让数字世界无风险 可以防止 机器人刷接口 等 可以在任何 加密操作中使用 下面我直接 贴代码 解释 <script src"https://cdn.dingxiang-inc.com/ctu-group/captcha-ui/v5/index.js" cro…

第12章 网络 (2)

目录 12.5 网络命名空间 12.6 套接字缓冲区 12.6.1 使用 sk_buff 管理数据 12.6.2 管理套接字缓冲区数据 本专栏文章将有70篇左右&#xff0c;欢迎关注&#xff0c;查看后续文章。 12.5 网络命名空间 一个网卡可能只在某个特定命名空间可见。 struct net&#xff1a; 表…

C语言贪吃蛇之BUG满天飞

C语言贪吃蛇之BUG满天飞 今天无意间翻到了大一用C语言写的贪吃蛇&#xff0c;竟然还标注着BUG满天飞&#xff0c;留存一下做个纪念&#xff0c;可能以后就找不到了 /* 此程序 --> 贪吃蛇3.0 Sur_流沐 当前版本&#xff1a; Bug满天飞 */ #include<stdio.h> #includ…

Linux C、C++编程之线程同步

【图书推荐】《Linux C与C一线开发实践&#xff08;第2版&#xff09;》_linux c与c一线开发实践pdf-CSDN博客《Linux C与C一线开发实践&#xff08;第2版&#xff09;&#xff08;Linux技术丛书&#xff09;》(朱文伟&#xff0c;李建英)【摘要 书评 试读】- 京东图书 (jd.com…

qt处理表格,Qtxlsx库文件的安装以及导入

qt想要处理excel表格的&#xff0c;这个过程中避免不了使用Qtxlsx这个库文件。这几天花了几天时间&#xff0c;终于本地调通了。记录一下。 关于Qtxlsx的使用&#xff0c;大致分为2中方法。 方法一&#xff1a;直接下载对应的xlsx文件&#xff0c;然后在.pro文件中 这种方法是…

使用Java往Geoserver发布tif图层和shp图层

1. Maven依赖 栅格文件对应Tif文件 (即: 栅格就是tif) 矢量文件对应shp文件(即: 矢量就是shp) 注: 有的依赖可能在中央仓库及一些镜像仓库找不到需要手动指定仓库, 在依赖最下方 <!-- 中文转拼音工具类 --><dependency><groupId>com.belerweb</groupId&g…

指针的学习和理解

初级 1、指针的概念 在64位操作系统中&#xff0c;不管什么类型的指针都占8个字节 int a1; int* p&a;//p就是一个整型的指针&#xff0c;保存了a的地址2、指针和变量 int* p&a;* p100; // 等价于a100p //p&a*有两种定义&#xff1a; 定义的时候&#xff08;前…

【工具类】Java优雅的将XML转为JSON格式、XML转JSON

Java优雅的将XML转为JSON格式、XML转JSON 1. 导入依赖1.1 Maven使用1.2 Gradle使用 2. 代码编写3.运行示例 1. 导入依赖 1.1 Maven使用 <dependency><groupId>org.dom4j</groupId><artifactId>dom4j</artifactId><version>2.1.3</vers…