深度学习笔记:神经网络的学习(2)

news2025/2/24 16:29:00

本章上一篇可见链接

https://blog.csdn.net/Raine_Yang/article/details/128682091?spm=1001.2014.3001.5501

梯度下降法(gradient descend)

神经网络学习的目标是找到使损失函数最小的参数(权重和偏置)。通过求得损失函数(总损失关于权重和偏置的函数)梯度,寻找梯度下降的发现,即可找到函数最小值。

注意利用梯度下降发得到的不一定是最小值,而仅仅为一个极小值,及梯度为0.另外,当函数呈扁平状,学习可能会进入一个平坦区域,难以进展,被称为学习高原

梯度法即为从当前函数取值沿梯度方向前进一定距离,然后重新求梯度,再继续迭代。其中每一步前进步幅被称为学习率(learning rate)

用公式表示如下:
在这里插入图片描述
梯度下降法程序实现:

import numpy as np

def gradient_descent(f, init_x, lr = 0.01, step_num = 100):
    x = init_x;
    
    for i in range(step_num):
        grad = numerical_gradient(f, x)
        x -= lr * grad
        
    return x

注:f 要优化的函数,init_x初始值,lr学习率,step_num函数重复次数
numerical_gradient(f, x)为求梯度函数,代码实现在上一篇文章

学习率这样的参数被称为超参数。超参数不由神经网络自己训练获得,而必须人工设定。超参数一般要尝试多个值才能找到一个合适的设定

学习率过大会导致训练不精确,而过小会导致学习速度慢

神经网络的梯度

神经网络的梯度值损失函数关于权重参数的梯度,如对于一个2 * 3的神经网络权重W,损失函数为L,梯度即为∂L/∂W
在这里插入图片描述
梯度中每一个值的意义即为当该权重值变化时,损失函数的变化率。梯度的形状和W形状相同

使用梯度下降法处理神经网络的输出

# coding: utf-8
import sys, os
sys.path.append('D:\AI learning source code')  # 为了导入父目录中的文件而进行的设定
import numpy as np
from common.functions import softmax, cross_entropy_error
from common.gradient import numerical_gradient


class simpleNet:
    def __init__(self):
        self.W = np.random.randn(2,3)

    def predict(self, x):
        return np.dot(x, self.W)

    def loss(self, x, t):
        z = self.predict(x)
        y = softmax(z)
        loss = cross_entropy_error(y, t)

        return loss

x = np.array([0.6, 0.9])
t = np.array([0, 0, 1])

net = simpleNet()

f = lambda w: net.loss(x, t)
dW = numerical_gradient(f, net.W)

print(dW)

1

def __init__(self):
        self.W = np.random.randn(2,3)

初始化隐藏层,权重设为随机值

2

def predict(self, x):
        return np.dot(x, self.W)

将输入值x和权重W相乘,得到神经网络第一层输出

3

def loss(self, x, t):
    z = self.predict(x)
    y = softmax(z)
    loss = cross_entropy_error(y, t)

    return loss

使用softmax函数求得分类结果,使用交叉熵得到损失函数

4

f = lambda w: net.loss(x, t)

这里使用了lambda表达式。python里的lambda表达式可以便捷创建简单函数。该式子的意思为 f(W) = net.loss(x, t) 这里网络权重W为一个伪参数。我们将net.loss(x, t)的值定义为 f 关于 W 的函数。

注:神经网络损失函数是关于网络输出输出的函数,而网络输出又是关于W的函数,所有损失函数也为关于W的函数

5

dW = numerical_gradient(f, net.W)

求得损失函数梯度,本例中结果为:
[[ 0.05613215 0.3069528 -0.36308495]
[ 0.08419822 0.4604292 -0.54462742]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/198057.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JavaEE】快速了解什么是Maven?

✨哈喽,大家好,我是辰柒!✨ 🛰️🛰️系列专栏:【JavaEE】 ✈️✈️本篇内容:学习如何使用maven! 🚀🚀代码存放仓库github:JavaEE代码! ⛵⛵作者简介&#xff…

C++STL——stack与queue

stack与queuestack与queuepriority_queue容器适配器vector与list的反向迭代器模拟实现仿函数deque(了解)stack与queue模拟实现priority_queue模拟实现stack与queue 这两个就是之前数据结构学过的栈和队列,只不过多了几个接口。 stack: queue&#xff…

【MySQL】MySQL 8 的 JSON 新特性详解(1)JSON 数据类型

目录一、概述二、MySQL 8 的环境搭建三、创建数据库、数据表并插入默认数据四、JSON格式数据的增加和查询1. 增加一条带JOSN格式的数据2.查询JSON内数据3.带筛选条件的查询五、总结一、概述 你好,我是小雨青年,一名使用MySQL 8 的程序员。 MySQL 8 引入…

Hadoop安装(二) --- Hadoop安装

目录 Hadoop安装(一)---JDK安装 修改hadoop313的权限 更改配置文件 配置core-site.xml 配置hadoop-env.sh 配置hdfs-site.xml 配置mapred-site.xml 配置yarn-site.xml 配置环境 刷新当前的shell环境 初始化 启动所有 SH 修改hadoop31…

Android Studio 从安装到第一个Android 应用Demo

安装Android Studio 安装需要 上网 ,我这挺顺利的,就是在官网下载安装包,一路 Next,大概连下载总共半个小时。 第一个应用 参考官方教程:https://developer.android.com/codelabs/basic-android-kotlin-compose-firs…

Redis最佳实践

一、Redis键值设计 1.1、优雅的key结构 Redis的key,最佳实践约定: 遵循基本格式:【业务名称】:【数据名】:【id】长度不超过44字节不包含特殊字符 好处 可读性强避免key冲突方便管理更节省内存 1.2、拒绝BigKey BigKey通常以Key的大小和…

SOLIDWORKS PDM的智能报表自动生成工具

一、SOLIDWORKS企业高级报表软件介绍: SolidKits.Reports(企业高级报表)是一款无缝集成于SOLIDWORKS PDM的智能报表自动生成工具,可以自动生成企业所需的各类报表数据,涵盖结构数据报表、离散数据报表、变更数据报表、…

rocketmq源码-consumer负载均衡逻辑

前言 这篇笔记主要记录consumer在启动过程中,负载均衡的逻辑,多个消费者组成一个消费者组,对于集群模式,同一个消费者组中的多个消费者共同消费一个topic下的所有消息,所以每个consumer可能会处理N个messageQueue&…

【4】KVM管理 | 虚拟机的管理 | 克隆 | 快照

目录 1、虚机基本管理 2、虚机的克隆 3、增量镜像 4、虚机快照 1、虚机基本管理 查看正在运行的虚机 [rootlocalhost ~]# virsh list Id Name State ----------------------------------------------------查看所有的虚机 [rootlocalhost ~…

Apache Oozie(1):Apache Oozie简介

1 Oozie 概述 Oozie 是一个用来管理 Hadoop 生态圈 job 的工作流调度系统。由Cloudera 公司贡献给 Apache。Oozie 是运行于 Java servlet 容器上的一个 java web 应用。Oozie 的目的是按照 DAG(有向无环图)调度一系列的 Map/Reduce或者Hive 等任务。Ooz…

Java SE 进阶(二)之 HashSet底层原理

文章目录前言HashSet底层原理1.哈希表2.哈希值3.底层原理4.回答三个问题前言 关于Set和HashSet的API使用可参见 集合基础入门(Collection,ArrayList,HashSet,HashMap) HashSet底层原理 1.哈希表 HashSet集合底层采…

Vue组件 —— 单文件组件

追溯vue组件问题 在未讲项目之前,在 这一篇内容当中就讲到了组件引入使用,有内置的组件和动态组件以及封装一个swiper组件,组件也分为全局组件和局部组件,在讲在项目当中去使用组件之前先简单的回顾一下组件的编写: &…

89.【SpringBoot-02】

SpringBoot聊一聊如何构建一个网站(十四)、.SpringBoot整合数据库操作1.整合JDBC(1).SpringData简介(2).整合JDBC(3).JdbcTemplate ⭐2. 整合Druid数据源 (德鲁伊)(1).Druid简介(2).配置数据源(3).配置Druid数据源监控(4).配置Druid数据源过滤器(5).注解…

Echarts的Y轴添加定值横线的示例

第010个点击查看专栏目录Echarts折线图的y轴要画一条横线,主要是在series中设置markLine的图表标线参数,具体的参考源代码。文章目录示例效果示例源代码(共142行)相关资料参考专栏介绍示例效果 示例源代码(共142行&…

怎么在Windows电脑更新 DirectX ?

玩游戏的人应该都对DirectX不陌生,它可以提高游戏或多媒体程序的运行效率,增强3d图形和声音效果。但很多人都不知道DirectX该如何更新,这篇文章将以Win10为例,教大家怎么在电脑上更新DirectX。 一、检查当前DirectX版本 如果你不…

简单聊一聊组件封装

封装一个思维导图组件 最近封装了一个简单的思维导图组件,在此简单记录一下心里历程 组件样式 组件结构设计 节点之间的线分成三部分,分别是竖线左边的横线A、竖线B、竖线右边的横线C,所以一个节点可以包含以下几个元素: 横线…

VBA提高篇_18 VBA代码录制优化Select(tion)及表格合并Merge(cells()/Rows()/Columns()

文章目录1. Cells(1,1)2. Rows(Str)和Columns(Str)3. VBA合并单元格3.1 Range.MergeCells属性:3.2 Range.Merge/UnMerage属性:3.3 Range.Merge(参数True/False)3.4 操作合并/取消合并单元格的两种方法4. Select / Selection 和 录制宏的代码优化4.1 Select / Selection4.2 录制…

anconda的pip下载包出现的问题

问题一: 在anconda里面如何创建新的python环境(也就是更换新的python版本) 1.先打开anconda软件,创建需要的环境 2. 环境创建好之后,去pycharm里面进行配置解释器 3. 这样就可以用了 问题二:pip的安装软件时出现包找不到的问题? 注意:因为我们刚刚创建了一个python环境,等…

Python基于已知的分幅条带号筛选出对应遥感影像文件的方法

本文介绍基于Python语言,结合已知研究区域中所覆盖的全部遥感影像的分幅条带号,从大量的遥感影像文件中筛选落在这一研究区域中的遥感影像文件的方法。 首先,先来明确一下本文所需实现的需求。现已知一个研究区域(四川省&#xff…

【C++】C++入门 函数重载

前言 自然语言中,一个词可以有多重含义,人们可以通过上下文来判断该词真实的含义,即该词被重载了。 函数重载一、函数重载定义二、函数重载的条件1. 参数类型不同2. 参数个数不同3. 参数类型顺序不同三、函数重载的原理--名字修饰(name Mangl…