第4章 神经网络【1】——损失函数

news2025/1/30 12:54:39

4.1.从数据中学习

        实际的神经网络中,参数的数量成千上万,因此,需要由数据自动决定权重参数的值。

        4.1.1.数据驱动

                数据是机器学习的核心。

                我们的目标是要提取出特征量,特征量指的是从输入数据/图像中提取出的本质的数                       据,特征量通常表示为向量的形式。                

                有两种方法:a. 使用人想到的特征量将图像数据转换为向量,然后对转换后的向量使用机器学习中的SVM、KNN等分类器进行学习【关于这一点,我的想法是,如果使用传统算法来提取特征,就根据经验针对不同的问题选取合适的特征量】;b.直接使用神经网络来实现端到端【从原始数据直接获得输出结果】的学习。 这两个方法目的一样,都是为了从原始数据中提取出本质的数据或信息。

        4.1.2.训练数据和测试数据

        获得泛化能力是机器学习的最终目标。       

        仅仅用一个数据集去学习和评价参数,是不客观的,可能会导致可以顺利地处理某个数据集,但无法处理其他数据集的情况,即过拟合。

        为了避免过拟合,追求模型的泛化能力【指处理未被观察过的数据】【举例来说,识别手写数字的问题,泛化能力可能会被用在自动读取明信片的邮政编码的系统上,此时,手写识别的就是“任何一个人写的任意文字”,而不是“特定某个人写的特定的文字”】,需要划分训练集和测试集。使用训练数据进行学习,寻找最优的参数,然后,利用测试数据评价训练得到的模型的实际能力。

4.2.损失函数

        神经网络的学习中使用损失函数来寻找最优权重参数,这里的损失函数可以用任意函数,一般用均方误差和交叉熵误差。                

        4.2.1.均方误差

        【one-hot表示:正确解标签表示为1,其他标签表示为0】 

def mean_squared_error(y, t):
    return 0.5 * np.sum((y-t)**2)

        4.2.2.交叉熵误差

        

        这里的tk是正确解标签,并且,只有正确解标签的索引为1,其他的索引均为0(one-hot表示),因此,式子4.2实际上只计算对应正确解标签的输出的自然对数。

def cross_entropy_error(y, t): 
    delta = 1e-7
    return -np.sum(t * np.log(y + delta))

        这里在log里加了一个很小的delta的值,为了防止y为0时,log值为-inf,这样会导致后续计算无法进行,即相当于一个保护性对策。

        4.2.3.mini-batch学习

        MNIST 数据集的训练数据有 60000 个,一些大的数据,数据量页会有几百万、几千万之多,这种情况下以全部数据为对象计算平均损失函数是不现实的。因此,从全部数据中选出一部分,作为全部数据的“近似”。神经网络的学习也是从训练数据中选出一批数据,然后对每个mini-batch进行学习。这种学习方式称为mini-batch学习。

        以交叉熵误差为例,求所有训练数据的损失函数的总和,把单个数据的“平均损失函数”的式扩大到了N份数据,最后除以N进行正规化,即得出单个数据的“平均损失函数”:【通过这样的平均化,可以获得和训练数据的数量无关的统一指标】

       举例介绍一下mini-batch学习的编码过程:

        a.读入 MNIST 数据集

import sys, os sys.path.append(os.pardir)
import numpy as np
from dataset.mnist import load_mnist
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
print(x_train.shape) # (60000, 784) print(t_train.shape) # (60000, 10)

        one_hot_label设置为True,表示正确解标签为1,其余为0。

        b.从训练数据中随机选取10笔数据

        使用NumPy的np.random.choice(),可以从指定的数字中随机选取想要的数字,即

train_size = x_train.shape[0]
batch_size = 10
batch_mask = np.random.choice(train_size, batch_size) 
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]

         之后,指定这些随机选取的索引,取出mini-batch,然后使用mini-batch计算损失函数即可。

        4.2.4.mini-batch版交叉熵误差的实现

        当监督数据t是one-hot形式时,可实现一个同时处理单个数据和批量数据batch两种情况的函数:

def cross_entropy_error(y, t):
 if y.ndim == 1:
     t = t.reshape(1, t.size)
     y = y.reshape(1, y.size)
 batch_size = y.shape[0]
 return -np.sum(t * np.log(y + 1e-7)) / batch_size

        当监督数据t是标签形式时(非 one-hot 表示,而是像“2”“7”这样的 标签),可通过如下代码实现:

def cross_entropy_error(y, t): 
    if y.ndim == 1:
        t = t.reshape(1, t.size) 
        y = y.reshape(1, y.size)
    batch_size = y.shape[0]
    return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size

        介绍一下代码实现中的np.log(y[np.arange(batch_size), t] + 1e-7):np.arange(batch_size)会生成一个从0到batch_size-1的数组。例如当batch_size为5时,np.arange(batch_size)会生成一个NumPy数组[0,1,2,3,4]。由于t中标签是以[2,7,0,9,4]的形式存储的,所以y[np.arange(batch_size), t]能抽出各个数据的正确解标签对应的神经网络的输出(在这个例子中,y[np.arange(batch_size), t]会生成NumPy数组[y[0,2], y[1,7], y[2,0], y[3,9], y[4,4]]。

        4.2.5.为什么要设定损失函数

        以数字识别任务为例,目的既然是能提高识别精度的参数,那特意导入一个损失函数不是有些重复劳动吗?为什么不直接把识别精度作为指标?

        对于这个疑问,我们来关注一下神经网络的某一个权重参数,对该权重参数的损失函数求导,如果导数值为正,则该权重参数向负方向改变可减小损失函数的值,反之,权重参数向正方向改变可减小损失函数的值。若导数为0,则无论权重参数向哪个方向变化,损失函数的值都不会变,即权重参数的更新会停留在此处。【而之所以不用识别精度作为指标,是因为绝大多数地方的导数都会变为0,导致参数无法更新,而且识别精度的值也不像损失函数作为指标时那样连续变化,即识别精度对微小的参数变化基本上没有什么反应】

       

                

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2286432.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Go的内存逃逸

Go的内存逃逸 内存逃逸是 Go 语言中一个重要的概念,指的是本应分配在栈上的变量被分配到了堆上。栈上的变量在函数结束后会自动回收,而堆上的变量需要通过垃圾回收(GC)来管理,因此内存逃逸会增加 GC 的压力&#xff0…

StarRocks BE源码编译、CLion高亮跳转方法

阅读SR BE源码时,很多类的引用位置爆红找不到,或无法跳转过去,而自己的Linux机器往往缺乏各种C依赖库,配置安装比较麻烦,因此总体的思路是通过CLion远程连接SR社区已经安装完各种依赖库的Docker容器,进行编…

Vue 响应式渲染 - 待办事项简单实现

Vue 渐进式JavaScript 框架 基于Vue2的学习笔记 - Vue 响应式渲染 - 待办事项简单实现 目录 待办事项简单实现 页面初始化 双向绑定的指令 增加留言列表设置 增加删除按钮 最后优化 总结 待办事项简单实现 页面初始化 对页面进行vue的引入、创建输入框和按钮及实例化V…

SpringBoot基础概念介绍-数据源与数据库连接池

🙋大家好!我是毛毛张! 🌈个人首页: 神马都会亿点点的毛毛张 毛毛张今天介绍的SpringBoot中的基础概念-数据源与数据库连接池,同时介绍SpringBoot整合两种连接池的教程 文章目录 1 数据库与数据库管理系统2 JDBC与数…

Microsoft Visual Studio 2022 主题修改(补充)

Microsoft Visual Studio 2022 透明背景修改这方面已经有很多佬介绍过了,今天闲来无事就补充几点细节。 具体的修改可以参考:Microsoft Visual Studio 2022 透明背景修改(快捷方法)_material studio怎么把背景弄成透明-CSDN博客文…

(done) ABI 相关知识补充:内核线程切换、用户线程切换、用户内核切换需要保存哪些寄存器?

由于操作系统和编译器约定了 ABI,如下: 编译器在对 C 语言编译时,会自动 caller 标注的寄存器进行保存恢复。保存的步骤通常发生在进入函数的时候,恢复的步骤通常发生在从函数返回的时候。 内核线程切换需要保存的寄存器&#…

Linux 多路转接select

Linux 多路转接select 1. select select() 是一种较老的多路转接IO接口,它有一定的缺陷导致它不是实现多路转接IO的最优选择,但 poll() 和 epoll() 都是较新版的Linux系统提供的,一些小型嵌入式设备的存储很小,只能使用老版本的…

【实践案例】使用Dify构建文章生成工作流【在线搜索+封面图片生成+内容标题生成】

文章目录 概述开始节点图片封面生成关键词实时搜索主题参考生成文章详情和生成文章标题测试完整工作流运行测试结果 概述 使用Dify构建文章生成工作流,使用工具包括:使用 Tavily 执行的搜索查询,使用Flux生成封面图片,使用Stable…

Web3 如何赋能元宇宙,实现虚实融合的无缝对接

随着技术的飞速发展,元宇宙作为一个未来数字世界的概念,正在吸引全球范围内的关注。而 Web3 技术的兴起,为元宇宙的实现提供了强大的支撑。Web3 是基于区块链技术的去中心化网络,它在改变互联网的同时,也推动着虚拟世界…

LangChain的开发流程

文章目录 LangChain的开发流程开发密钥指南3种使用密钥的方法编写一个取名程序 LangChain表达式 LangChain的开发流程 为了更深人地理解LangChain的开发流程,本文将以构建聊天机器人为实际案例进行详细演示。下图展示了一个设计聊天机器人的LLM应用程序。 除了Wb服务…

电商系统-用户认证(四)Oauth2授权模式和资源服务授权

本文章介绍:Oauth2.0 常见授权模式,资源服务授权 。 准备工作 搭建认证服务器之前,先在用户系统表结构中增加如下表结构: CREATE TABLE oauth_client_details (client_id varchar(48) NOT NULL COMMENT 客户端ID,主…

[答疑]DDD伪创新哪有资格和仿制药比

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 远航 2025-1-24 10:40 最近的热门话题仿制药,想到您经常批评的伪创新,这两者是不是很像? UMLChina潘加宇 伪创新哪有资格和仿制药比。 仿制药的…

图漾相机——Sample_V1示例程序

文章目录 1.SDK支持的平台类型1.1 Windows 平台1.2 Linux平台 2.SDK基本知识2.1 SDK目录结构2.2 设备组件简介2.3 设备组件属性2.4 设备的帧数据管理机制2.5 SDK中的坐标系变换 3.Sample_V1示例程序3.1 DeviceStorage3.2 DumpCalibInfo3.3 NetStatistic3.4 SimpleView_SaveLoad…

系统架构设计师教材:信息系统及信息安全

信息系统 信息系统的5个基本功能:输入、存储、处理、输出和控制。信息系统的生命周期分为4个阶段,即产生阶段、开发阶段、运行阶段和消亡阶段。 信息系统建设原则 1. 高层管理人员介入原则:只有高层管理人员才能知道企业究竟需要什么样的信…

Kafka 深入客户端 — 事务

Kafka 事务确保了数据在写入Kafka时的原子性和一致性。 1 幂等 幂等就是对接口的多次调用所产生的结果和调用一次是一致的。 Kafka 生产者在进行重试的时候可能会写入重复的消息,开启幂等性功能后就可以避免这种情况。将生产者客户端参数enable.idempotence设置为…

ZZNUOJ(C/C++)基础练习1011——1020(详解版)

1011 : 圆柱体表面积 题目描述 输入圆柱体的底面半径r和高h,计算圆柱体的表面积并输出到屏幕上。要求定义圆周率为如下宏常量 #define PI 3.14159 输入 输入两个实数,表示圆柱体的底面半径r和高h。 输出 输出一个实数,即圆柱体的表面积&…

Baklib探索内容中台的核心价值与实施策略

内容概要 在数字化转型的背景下,内容中台逐渐成为企业数字化策略中的关键组成部分。内容中台是一个集成的内容管理体系,旨在打破信息孤岛,使内容能够在各个业务部门和平台之间高效流通。这种管理体系不仅能够提升内容的生产效率,…

网络安全攻防实战:从基础防护到高级对抗

📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 引言 在信息化时代,网络安全已经成为企业、政府和个人必须重视的问题。从数据泄露到勒索软件攻击,每一次…

论文阅读(十三):复杂表型关联的贝叶斯、基于系统的多层次分析:从解释到决策

1.论文链接:Bayesian, Systems-based, Multilevel Analysis of Associations for Complex Phenotypes: from Interpretation to Decision 摘要: 遗传关联研究(GAS)报告的结果相对稀缺,促使许多研究方向。尽管关联概念…

“““【运用 R 语言里的“predict”函数针对 Cox 模型展开新数据的预测以及推理。】“““

主题与背景 本文主要介绍了如何在R语言中使用predict函数对已拟合的Cox比例风险模型进行新数据的预测和推理。Cox模型是一种常用的生存分析方法,用于评估多个因素对事件发生时间的影响。文章通过具体的代码示例展示了如何使用predict函数的不同参数来获取生存概率和…