深度学习笔记:神经网络的学习(1)

news2025/1/4 19:15:55

机器学习的核心在于从数据中提取规律和特征,并用于分类或预测。对于识别手写数字,如果人工设计一个识别算法逻辑是十分困难的。一种方法是任务在数据中提取更重要的特征量,然后利用机器学习算法如SVM或KNN。而神经网络的方法则是完全由机器自主提取特征,中间没有任何人工干预

1 训练数据和测试数据

一般来说,机器学习的数据分为训练数据(或监督数据)和测试数据。首先使用训练数据寻找最优参数,然后使用测试数据评估模型泛化能力。

泛化能力指的是模型处理未被观察到数据的能力,是机器学习的目标。

如果模型只能处理特定数据集,泛化能力差,这种状态被称为过拟合(over fitting)

2 损失函数

神经网络利用损失函数作为指标表现现在的状态,再利用这一指标寻找更优权重参数。损失函数表现了网络和数据集不拟合程度,损失函数越低,网络更优

均方差法(mean squared error):
在这里插入图片描述
yk: 神经网络输出
tk:监督数据
(k代表数据维数)

import numpy as np

def mean_squared_error(y, t):
    return 0.5 * np.sum((y - t) ** 2)


t = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
y = np.array([0.1, 0.05, 0.6, 0, 0.05, 0.1, 0, 0.1, 0, 0])
print(mean_squared_error(y, t)) # 0.0975

在本例中,数组t代表监督数据,里面2位置标为1,其他位置标为0代表结果为2.这一表示法叫one-hot表示。利用均方差得到的总损失为0.0975

交叉熵误差(cross entropy error):
在这里插入图片描述
由于tk的one-hot表示里只有正确解值为1,其他解值为0.交叉熵的值为-log(y),其中y为正确解的输出概率

def cross_entropy_error(y, t):
    delta = 1e-7
    return -np.sum(t * np.log(y + delta))

t = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
y = np.array([0.1, 0.05, 0.6, 0, 0.05, 0.1, 0, 0.1, 0, 0])
print(cross_entropy_error(y, t)) # 0.511

这里我们加上一个很小的值delta,这使得在t为0时不会出现log(0)导致无限大

mini-batch学习

机器学习的过程即为对训练数据计算损失函数,并且找到使损失函数值最小的参数,以交叉熵为例:
在这里插入图片描述
这里我们把损失函数式子扩大到n个数据。最后除以n得到“平均损失函数”

一般来说,我们无法将全部数据作为训练样本,而是取全部数据里的一批样本(称为mini-batch)

import sys, os
sys.path.append(os.pardir)
import numpy as np
from dataset.mnist import load_mnist

(x_train, t_train), (x_test, t_test) = load_mnist(normalize = True, one_hot_label = True)

print(x_train.shape)
print(t_train.shape)

这一步从mnist库里导出数据集。x_train形状为(60000, 784),因为训练数据集个数60000,输入数据784维(28 X 28),t_train形状为(60000, 10),训练数据集个数60000,监督数据one-hot标签量为10

train_size = x_train.shape[0]
batch_size = 10
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]

这一步利用np.random.choice方法在数据集里随机抽取10个数据,并赋值给x_batch和t_batch

实现mini-batch交叉熵计算:

def cross_entropy_error(y, t):
    if y.ndim == 1:
        t = t.reshape(1, t.size)
        y = y.reshape(1, y.size)

    batch_size = y.shape[0]
    return -np.sum(t * np.log(y + 1e-7)) / batch_size

这里我们创建的cross_entropy_error可以处理单个数据和批数据。对于单个数据先将其变换为一个大小为1的数组,在代入交叉熵公式计算。将所以数据交叉熵之和除以batch_size相当于得到了数据集的平均交叉熵

使用损失函数的意义:
损失函数的作用是使用一个连续函数来代表模型的识别精度。神经网络的学习依靠对损失函数求导(准确来说是偏导)找到函数极小值处的参数

导数:
导数在数学上定义为函数切线斜率(lim(h -> 0) f(x + h) - f(x) / h )。但是在程序里实现求导,我们不可能使用一个无限趋于0的h值。h值的选取不得过小,否则会受到round-off误差的影响较大。

为了弥补这一误差,我们可以计算函数在(x + h) 和(x - h)间的差分,称为中心差分。计算(x + h)和x的差分叫前向差分

注:使用微小差分求导被称为数值微分(numerical differentiation),基于数学式求导称为解析求导(analytic differentiation)

def numerical_diff(f, x):
    h = 1e-4
    return ((f(x + h) - f(x - h)) / (2 * h))

经测试,数值差分结果和真实导数0.2差距可以忽略不计
在这里插入图片描述
偏导数:
对带有多个自变量的函数求导时,需要指定关于其某一个变量求导,而将其他变量当做常数,这样得到的导数被称为偏导

如对于函数f(x1, x2) = x1 ^ 2 + x2 ^ 2, ∂/∂x1 = 2x1 ∂/∂x2 = 2x2

由关于f(x)所有变量偏导汇总起来的向量被称为梯度,实现代码如下:

def numerical_gradient(f, x):
    h = 1e-4
    grad = np.zeros_like(x)
    
    for i in range(x.size):
        temp = x[i]
        x[i] = temp + h
        fxh1 = f(x)
        x[i] = temp - h
        fxh2 = f(x)
        grad[i] = (fxh1 - fxh2) / (2 * h)
        x[i] = temp
        
    return grad

注:np.zeros_like(x)会生成一个形状和x相同,所有值为0的数组

该函数遍历所有自变量x逐个求偏导,最后汇总起来的数组即为梯度

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/179898.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ISIS的3级别(level-1、level-2、level-1-2)4大类(IIH、LSP、CSNP、PSNP)9小类与邻接关系建立LSP交互过程介绍

2.2.0 ISIS 4种报文类型IIH、LSP、CSNP、PSNP、邻居建立过程、交互LSP过程 ISIS的3级别4大类9小类 ISIS拥有3种级别的路由器,分别是level-1、level-2、level-1-2。 不同级别之间进行交互的报文也是有所区别的,常规的ISIS报文分有4大类:IIH、…

cubeIDE开发, stm32人工智能开发应用实践(Cube.AI).篇一

一、cube.AI简介及cubeIDE集成 1.1 cube.AI介绍 cube.AI准确来说是STM32Cube.AI,它是ST公司的打造的STM32Cube生态体系的扩展包X-CUBE-AI,专用于帮助开发者实现人工智能开发。确切地说,是将基于各种人工智能开发框架训练出来的算法模型&#…

Vue3商店后台管理系统设计文稿篇(六)

记录使用vscode构建Vue3商店后台管理系统,这是第六篇,从这一篇章开始,所有的预备工作结束,正式进入商店后台管理系统的开发 文章目录一、创建后台管理系统的标题栏二、安装Icon 图标三、创建Menu菜单正文内容: 一、创…

PowerShell 学习笔记:操作JSON文件

JSON文件(字符串)是有一定格式要求的文本文件。百度百科JSON(JavaScriptObject Notation, JS对象简谱)是一种轻量级的数据交换格式。它基于 ECMAScript(European Computer Manufacturers Association, 欧洲计算机协会制…

初识Linux常见指令汇总

文章目录前言1.对文件或目录的常用指令1.查看当前路径下的文件或目录相关信息2.进入指定路径3.创建删除文件或者目录4.使用nano简单编辑文件查看文件属性5.复制移动重命名文件或目录6.输入输出重定(查看文件内容)向和搜索查找1.输入输出重定向2.搜索查找7.打包压缩文件2.时间相…

如何使用Maven构建Java项目?Maven的使用详细解读

文章目录1. 前言2. Maven 快速入门2.1 Maven 项目模型2.2 Maven 仓库3. Maven的安装配置3.1 安装3.2 配置环境变量3.4 Maven 配置4. Maven 的常用命令4.1 编译4.2 清理4.3 打包4.4 测试4.5 安装5. Maven生命周期6. 总结Java编程基础教程系列:1. 前言 在 Java 开发中…

C++初阶:list类

文章目录1 list介绍2 list的模拟实现2.1 类的定义2.2 默认成员函数2.2.1 构造函数2.2.2 析构函数2.2.3 拷贝构造2.2.4 赋值重载2.3 迭代器2.3.1 正向迭代器2.3.2 反向迭代器2.4 修改接口2.4.1 任意位置插入2.4.2 任意位置删除2.5 其他接口2.5.1 尾插2.5.2 头插2.5.3 尾删2.5.3 …

3.7-2动态规划--图像压缩(举例子和写代码)

3.7动态规划--图像压缩_昵称什么的不存在的博客-CSDN博客 问题描述(再写一遍) 这篇文章是接着上面这一篇写的,就是写一个例子方便理解,模拟填写数组的过程 l: l[i]存放第i段长度, 表中各项均为8位长,限制了相同位数的…

CGAL 点云精配准之ICP算法

文章目录 一、简介二、相关参数三、实现过程三、举个栗子四、实现效果参考资料一、简介 ICP算法总共分为6个阶段,如下图所示: (1)挑选发生重叠的点云子集,这一步如果原始点云数据量比较巨大,一般会对原始点云进行下采样操作。 (2)匹配特征点。通常是距离最近的两个点,…

如何批量增加视频的音量(ffmpeg)

问题背景 由于之前爷爷的唱戏机充不进去电,过年时给爷爷买了个新的。但这个新买的机子,它的曲目(视频)在U盘里,声音普遍较低,我爷爷的耳朵不好,声音需要比正常的声音调大一些。 在Videolouder这…

【数据结构和算法】认识线性表中的链表,并实现单向链表

本文接着上文,上文我们认识了线性表的概念,并实现了静态、动态顺序表。接下来我们认识一个新概念链表。并实现单向链表的各种操作。顺序表还有不明白的看这一篇文章 (13条消息) 【数据结构和算法】实现线性表中的静态、动态顺序表_小王学代码的博客-CSDN…

leetcode--链表

链表1.链表的基本操作(1)反转链表(206)(2) 合并两个有序链表(21)(3)两两交换链表中的节点(24)2.其它链表技巧(1)相交链表(160)(2)回文链表(234)3.练习&#x…

力扣 2293. 极大极小游戏

题目 给你一个下标从 0 开始的整数数组 nums ,其长度是 2 的幂。 对 nums 执行下述算法: 设 n 等于 nums 的长度,如果 n 1 ,终止 算法过程。否则,创建 一个新的整数数组 newNums ,新数组长度为 n / 2 &…

手把手带初学者快速入门 JAVA Web SSM 框架

博主也是刚开始学习SSM,为了帮大家节省时间,写下SSM快速入门博客 有什么不对的地方还请 私信 或者 评论区 指出 ​只是一个简单的整合项目,让初学者了解一下SSM的大致结构 项目先把框架写好,之后在填写内容 项目压缩包 完整的蓝奏…

浅谈phar反序列化漏洞

目录 基础知识 前言 Phar基础 Phar文件结构 受影响的函数 漏洞实验 实验一 实验二 过滤绕过 补充 基础知识 前言 PHP反序列化常见的是使用unserilize()进行反序列化,除此之外还有其它的反序列化方法,不需要用到unserilize()。就是用到了本文…

C 语言零基础入门教程(十一)

C 数组 C语言支持数组数据结构,它可以存储一个固定大小的相同类型元素的顺序集合。数组是用来存储一系列数据,但它往往被认为是一系列相同类型的变量。 数组的声明并不是声明一个个单独的变量,比如 runoob0、runoob1、…、runoob99&#xf…

【Linux】调试器 - gdb 的使用

目录 一、背景知识 二、debug 与 release 1、生成两种版本的可执行程序 2、debug 与 release 的区别 三、gdb 的使用 1、调试指令与指令集 2、源代码显示、运行与退出调试 3、断点操作 4、逐语句与逐过程 5、调试过程中的数据监视 6、调试过程中快速定位问题 一、背…

吴恩达机器学习笔记(三)逻辑回归

机器学习(三) 学习机器学习过程中的心得体会以及知识点的整理,方便我自己查找,也希望可以和大家一起交流。 —— 吴恩达机器学习第五章 —— 四、逻辑回归 线性回归局限性 线性回归对于分类问题的局限性:由于离群点…

LeetCode动态规划经典题目(九):入门

学习目标: 了解动态规划 学习内容: 1. LeetCode509. 斐波那契数https://leetcode.cn/problems/fibonacci-number/ 2. LeetCode70. 爬楼梯https://leetcode.cn/problems/climbing-stairs/ 3. LeetCode746. 使用最小花费爬楼梯https://leetcode.cn/proble…

ice规则引擎==启动流程和源码分析

启动 git clone代码 创建数据库ice,执行ice server里的sql,修改ice server的配置文件中的数据库信息 启动ice server 和ice test 访问ice server localhost:8121 新增一个app,默认给了个id为1,这个1可以看到在ice test的配置文件中指定…