【深度学习】3-2 神经网络的学习- mini-batch学习

news2025/1/10 20:49:40

机器学习使用训练数据进行学习。使用训练数据进行学习,就是针对训练数据计算损失函数的值,也就是说,训练数据有100个的话,就要把这 100个损失函数的总和作为学习的指标。

求多个数据的损失函数,要求所有训练数据的损失函数的综合,可以写成如下式子:
在这里插入图片描述
其实就是把求单个数据的损失函数的式子扩大到了N份数据,不过最后还要除以N进行正规化。通过除以N,可以求单个数据的“平均损失函数”。通过这样的均化,可以获得和训练数据的数量无关的统一指标。

在以大数据为对象求损失函数的和,需要花费较长的时间,因此,我们从全部数据中选出一部分,作为全部数据的“近似”。神经网络的学习也是从训练数据中选出一批数据(称为mini-batch,小批量),然后对每个mini-batch进行学习。

下面来编写从训练数据中随机选择指定个数的数据的代码,以进行mini-batch学习。

import sys,os
sys.path.append(os .pardir)
import numpy as np
from dataset.mnist import load_mnist
# 读人MNIST数据集
(x_train, t train), (x test, t test) =load_mnist(normalize=True, one_hot_label=True)
print(x_train.shape) #(60000784)
print(t_train.shape) # (6000010)

train_size =x_train.shape[0]
batch size =10

batch_mask = np.random.choice(train_size,batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]

使用np.random.choice()可以从指定的数字中随机选择想要的数字
np.random.choice(60000,10)会从0到59999之间随机选择10个数子,可以得到一个包含被选数据的索引的数组

>>>np.random.choice(6000010)
array([ 8013,14666, 58210, 23832, 52091, 10153, 8107, 19410, 27262, 14111])

之后,只需指定这些随机选出的索引,取出mini-batch
用随机量数据( mini-batch)作为全体训练数据的近似值。

mini-batch版交叉熵误差的实现
要实现对应mini-batch的交叉误差,需要改良之前实现的单个数据的交叉熵误差,让它可以同时处理单个数据和批量数据(数据作为batch集中输人)

def cross_entropy_error(y,t):
	if y.ndim == 1:
		t = t.reshape(1, t.size)
		y = y.reshape(1, y.size)
	batch_size = y.shape[0]
	return -np.sum(t * np.log(y + 1e-7)) / batch_size

这里,y是神经网络的输出,t是监督数据。y的维度为1时,即求单个数据的交叉熵误差时,需要改变数据的形状。并且,当输人为mini-batch时,要用batch的个数进行正规化,计算单个数据的平均交叉熵误差

此外,当监督数据是标签形式(非one-hot)表示,而是像“2”“7”这样的标签时,交叉熵误差可通过如下代码实现:

def cross_entropy_error(y, t):
	if y.ndim == 1:
		t = t.reshape(1, t.size)
		y = y.reshape(1, y.size)
		
	batch_size = y.shape[0]
	return -np.sum(np.log(y[np.arange(batch_size), t] + le-7)) / batch_size

实现的要点是,由于one-hot表示中t为0的元素的交叉嫡误差也为0,因此针对这些元素的计算可以忽略。换言之,如果可以获得神经网络在正确解标签处的输出,就可以计算交叉熵误差。
此外关于 np.log(y[np.arange(batch_size), t] + 1e-7)
np.arange(batch_size)会生成一个从0到batch_size-1的数组。因为t中标签是以[2,7,0,9,4]的形式存储的,所以y[np.arange(batch_size), t] 能抽出各个数据的正确解标签对应的神经网络的输出在这个例子中y[np.arange(batch_size), t] 会输出NumPy数组[y[0,2], y[1,7]],y[2,0],y[3,9],y[4,4]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/663191.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

INTERSPEECH2023|达摩院语音实验室入选论文全况速览

近日,语音技术领域旗舰会议INTERSPEECH 2023公布了本届论文审稿结果,阿里巴巴达摩院语音实验室有17篇论文被大会收录。 01 论文题目:FunASR: A Fundamental End-to-End Speech Recognition Toolkit 论文作者:高志付,…

基于 AntV G2Plot 来实现一个 堆叠柱状图 加 折线图 的多图层案例

前言 最近研究了一下antv/g2的组合图例,并尝试做了一个不算太难的组合图,下面介绍一下整个图里的实现过程。 最终效果图 先来看一下最终的效果图 该图表有两部分组成,一部分是柱状图,准确说是堆叠的柱状图,一个柱…

【TA100】图形 3.5 Early-z和Z-prepass

一、深度测试:Depth Test 1.回顾深度测试的内容 深度测试位于渲染管线哪个位置 ○ 深度测试位于逐片元操作中、模板测试后、透明度混合前 为什么做深度测试 ● 深度测试可以解决:物体的可见遮挡性问题 ○ 我们可以用一个例子说明 ■ 图的解释&…

windows应急整理

windows应急整理 Virustotal 网站分析恶意样本 BrowingHistoryView 查看浏览器所有历史记录,可能会请求攻击者的恶意网站或者下载东西 启动项检查 开机启动项文件夹 msconfig 注册表run 键值查看 启动项 临时文件检查,temp 目录权限特殊,容易成为被利用对象 %temp%查看 tem…

华为HCIP第一天---------RSTP

一、介绍 1、以太网交换网络中为了进行链路备份,提高网络可靠性,通常会使用冗余链路,但是这也带来了网络环路的问题。网络环路会引发广播风暴和MAC地址表震荡等问题,导致用户通信质量差,甚至通信中断。为了解决交换网…

C# WebSocketSharp 框架的用法

效果: 一、概述 WebSocketSharp 是一个 C# 实现 websocket 协议客户端和服务端,WebSocketSharp 支持RFC 6455;WebSocket客户端和服务器;消息压缩扩展;安全连接;HTTP身份验证;查询字符串,起始标题和Cookie;通过HTTP代理服务器连接;.NET Framework 3.5或更高版本(包括…

腾讯云服务器云监控是什么?

腾讯云服务器云监控是什么?云监控用于监控云服务器性能资源指标如CPU利用率、内存使用量、内网外网出入带宽、TCP连接数、硬盘IOPS、硬盘IO等性能指标,云服务器吧建议免费开通云监控功能。 什么是云监控? 腾讯云服务器CVM云监控是什么&…

从小白到大神之路之学习运维第43天---第三阶段----LVS-----keepalived+LVS(DR)搭建部署

第三阶段基础 时 间:2023年6月19日 参加人:全班人员 内 容: keepalivedLVS(DR)搭建部署 目录 一、作用 技术特点: 与nginx的区别: 安全性: 配置文件: 二、环境简介 三、操作步骤 …

SPEC 2006 gcc version 8.3.0 (Uos 8.3.0.3-3+rebuild) x86_64 源码编译tools 错误处理笔记

编译tools 拷贝tools到安装目录 cp /mnt/iso/tools /opt/speccpu2006/ -r 执行编译 su rootcd /opt/speccpu2006/tools/src sh -x buildtools 错误 undefined reference to __alloca 编辑./make-3.82/glob/glob.c,注释掉以下宏判断 you should not run config…

unittest教程__测试报告(6)

用例执行完成后,执行结果默认是输出在屏幕上,其实我们可以把结果输出到一个文件中,形成测试报告。 unittest自带的测试报告是文本形式的,如下代码: import unittestif __name__ __main__:# 识别指定目录下所有以tes…

springcloud 中RestTemplate 是怎么和 ribbon整合,实现负载均衡的?源码分析

一、RestTemplate 拦截器了解 RestTemplate 内置了一个 ClientHttpRequestInterceptor,这个是一个拦截器操作,我们可以在请求的前后做一些事情。然后我们看一下这个类,这个类里面 有一个 intercept方法。我们看下这个实现类,里面有一个 LoadBalancerInterceptor实现类。 …

pm2详解

对于后台进程的管理,常用的工具是crontab,可用于两种场景:定时任务和常驻脚本。关于常驻脚本,今天介绍一款更好用的工具:pm2,基于nodejs开发的进程管理器,适用于后台常驻脚本管理,同…

whisper语音识别部署及WER评价

1.whisper部署 详细过程可以参照:🏠 创建项目文件夹 mkdir whisper cd whisper conda创建虚拟环境 conda create -n py310 python3.10 -c conda-forge -y 安装pytorch pip install --pre torch torchvision torchaudio --extra-index-url 下载whisper p…

STM32单片机LED显示屏驱动原理与实现

STM32单片机驱动LED显示屏的原理与实现方法与Arduino类似,但涉及到的具体硬件资源和库函数可能会有所不同。下面是一个详细的介绍: 原理: STM32单片机驱动LED显示屏的原理是通过控制GPIO引脚的电平状态来控制LED的亮灭。通过设置引脚的输出电…

Jetpack Compose中的附带效应及效应处理器

Jetpack Compose中的附带效应及效应处理器 将在任何可组合函数范围之外运行的代码称为附带效应。 为什么要编写在任何可组合函数范围之外的代码? 这是因为可组合项的生命周期和属性(例如不可预测的重组)会执行可组合项的重组。 让我们通过一…

软考高级系统架构设计师(一) 考什么

目录 一、背景 二、软考(高级)的用途 三、考什么 第一科:综合知识 第二科:案例分析 第三科:论文 四、系统架构设计师常见的考试内容 五、模拟与训练 一、背景 系统架构设计师,属于软考高级考试中的一种。 二、软考(高级)…

Node搭建前端服务Mysql数据库交互一篇搞定

目录 介绍 安装环境及数据准备 代码示例 mysql连接工具类 测试方法文件 单表总量查询 单表条件查询 新增数据 修改 删除 ​编辑 ​编辑 联表查询 联表过滤 搭配express服务搭建api使用 介绍 在前端开发中,可以使用纯node前端进行服务搭建与mysql进行数据库的交互,这样…

Bun vs. Node.js

Bun vs. Node.js 你知道 Bun 吗?Bun 是新的 JavaScript 运行时,最近在技术领域引起轰动,它声称比 Node.js 更好。本文将展示如何使用基准分数对其进行测试。 在本文中,我们将介绍最近在技术领域引起轰动的新的 Bun 运行时。我们…

螯合剂试剂:DOTA-CH2-Ph-azide(HCl salt),分子式:C21H34Cl3N7O6,的相关参数信息

文章关键词:双功能螯合剂,azide叠氮 为大家介绍(CAS:N/A),试剂仅用于科学研究,不可用于人类,非药用,非食用。 分子式:C21H34Cl3N7O6 分子量:586.9 英文名称&a…

限定国家及时间|心理学老师如期赴意大利访学

S老师由于个人情况变化需要办理CSC改派,并限定了国家且要求年底出国。我们最终用意大利巴里大学的邀请函,助其成功申请了CSC改派并如期出国。 S老师背景: 申请类型: CSC访问学者 工作背景: 高校教师 教育背景&#…