【深度学习】3-4 神经网络的学习- 学习算法的实现

news2024/11/24 7:24:50

神经网络的学习步骤如下所示:

步骤1(mini-batch)
从训练数据中随机选出一部分数据,目标是减小mini-batch的损失函数的值

步骤2(计算梯度)
为了减小mini-batch的损失函数的值,需要求出各个权重参数的梯度

步骤3(更新参数)
将权重参数沿梯度方向进行微小更新.

步骤4(重复)
重复步骤1、步骤2、步骤3。

这里因为使用的数据是随机选择的mini batch数据,所以称为随机梯度下降法(stochastic gradient descent)。深度学习的很多框架中,随机梯度下降法一般由一个名为SGD的函数来实现, SGD来源于随机梯度下降法的英文名称的首字母。

下面,来实现手写数字识别的神经网络。这里以2层神经网(隐藏层为1层的网络)为对象,使用MNIST数据集进行学习。
首先,下面看这个名为TwoLayerNet的类

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
from common.functions import *
from common.gradient import numerical_gradient


class TwoLayerNet:

    def __init__(self, input_size, hidden_size, output_size, weight_init_std=0.01):
        # 初始化权重
        # 保存神经网络的参数的字典型变量(实例变量)
        self.params = {}
        self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)
        self.params['b1'] = np.zeros(hidden_size)
        self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size)
        self.params['b2'] = np.zeros(output_size)

	# 进行识别(推理)
    def predict(self, x):
        W1, W2 = self.params['W1'], self.params['W2']
        b1, b2 = self.params['b1'], self.params['b2']
    
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)
        a2 = np.dot(z1, W2) + b2
        y = softmax(a2)
        
        return y
        
    # x:输入数据, t:监督数据
    def loss(self, x, t):
        y = self.predict(x)
        
        return cross_entropy_error(y, t)
    
    # 计算识别精度
    def accuracy(self, x, t):
        y = self.predict(x)
        y = np.argmax(y, axis=1)
        t = np.argmax(t, axis=1)
        
        accuracy = np.sum(y == t) / float(x.shape[0])
        return accuracy
        
    # x:输入数据, t:监督数据
    def numerical_gradient(self, x, t):
        loss_W = lambda W: self.loss(x, t)
        
        # 保存梯度的字典型变量
        grads = {}
        grads['W1'] = numerical_gradient(loss_W, self.params['W1'])
        grads['b1'] = numerical_gradient(loss_W, self.params['b1'])
        grads['W2'] = numerical_gradient(loss_W, self.params['W2'])
        grads['b2'] = numerical_gradient(loss_W, self.params['b2'])
        
        return grads
        
    def gradient(self, x, t):
        W1, W2 = self.params['W1'], self.params['W2']
        b1, b2 = self.params['b1'], self.params['b2']
        grads = {}
        
        batch_num = x.shape[0]
        
        # forward
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)
        a2 = np.dot(z1, W2) + b2
        y = softmax(a2)
        
        # backward
        dy = (y - t) / batch_num
        grads['W2'] = np.dot(z1.T, dy)
        grads['b2'] = np.sum(dy, axis=0)
        
        da1 = np.dot(dy, W2.T)
        dz1 = sigmoid_grad(a1) * da1
        grads['W1'] = np.dot(x.T, dz1)
        grads['b1'] = np.sum(dz1, axis=0)

        return grads

mini-batch的实现
神经网络的学习的实现使用的是前面介绍过的mini-batch学习。下面,就以TwoLayerNet类为对象,使用MNIST数据集进行学习

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from two_layer_net import TwoLayerNet

# 读入数据
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)

network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)

iters_num = 10000  # 适当设定循环的次数
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1

train_loss_list = []
train_acc_list = []
test_acc_list = []

iter_per_epoch = max(train_size / batch_size, 1)

for i in range(iters_num):
	# 获取mini-batch
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]
    
    # 计算梯度
    #grad = network.numerical_gradient(x_batch, t_batch)
    grad = network.gradient(x_batch, t_batch)
    
    # 更新参数
    for key in ('W1', 'b1', 'W2', 'b2'):
        network.params[key] -= learning_rate * grad[key]
    
    #记录学习过程
    loss = network.loss(x_batch, t_batch)
    train_loss_list.append(loss)
    
    # 计算每个epoch的识别精度
    if i % iter_per_epoch == 0:
        train_acc = network.accuracy(x_train, t_train)
        test_acc = network.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        print("train acc, test acc | " + str(train_acc) + ", " + str(test_acc))

# 绘制图形
markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()

神经网络的学习中,必须确认是否能够正确识别训练数据以外的其他数据,即确认是否会发生过拟合。过拟合是指,虽然训练数据中的数字图像被正确辨别,但是不在训练数据中的数字图像却无法被识别的现象。
神经网络学习的最初目标是掌握泛化能力,因此,要评价神经网络的泛化能力,就必须使用不包含在训练数据中的数据。所以在进行学习过程中,要定期地对训练数据和测试数据记录识别精度。这里,每经过一个epoch,都会记录下训练数据和测试数据的识别精度。

epoch是一个单位。一个epoch表示学习中所有训练数据均被使用一次时的更新次数。比如,对于10000笔训练数据,用大小为100笔数据的mini-batch进行学习时,重复随机梯度下降法100次有的训练数据就都被“看过”了。此时,100次就是一个epoch

把从上面的代码中得到的结果用图标表示的话,如下图:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/668617.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

redhat 6.4安装oracle11g RAC (四)

创建集群数据库 在节点rac1上用oracle用户执行dbca创建RAC数据库 [rootrac1 ~]# su - oracle [oraclerac1 ~]$ dbca选择创建数据库 自定义数据库(也可以是通用) 配置类型选择Admin-Managed,输入全局数据库名orcl,每个节点实例SI…

java线上问题排查基本命令

1、jvm基本命令 1.1、java命令 1.1.1、简介 java命令启动java应用程序。它通过启动Java运行时环境(JRE)、加载指定的类并调用该类的main()方法来实现这一点。 1.1.2、命令链接 https://docs.oracle.com/javase/8/docs/techno…

electron 连接打印机打印pdf文件

electron 打印内容 区分系统 类似unix系统的使用 npm包:unix-printwindow系统使用: pdf-to-printer 运行线程 视图线程 函数参数 两个包都提供了print函数来打印文件,配置基本一致,只是参数形式有所不同,pdf-to-pr…

ESP32开发环境搭建Windows VSCode集成Espressif IDF插件ESP32_IDF_V5.0开发编译环境搭建

一、安装ESP32-IDF库 下载网址:https://dl.espressif.com/dl/esp-idf/ 打开上面的网页,选择单击页面中 ESP32-IDF v5.0.2 - Offine Installer,5.0.2是当前最新版本,如果没有ESP32-IDF v5.0.2 - Offine Installer,说明…

JS获取省市区/县,layui获取省市区,layui实现省市区联动,jquery实现省市区联动

前言 通过JS方式获取省市区数据,可自己手动更改JS文件数据 非常简单 效果 实现 百度网盘链接: https://pan.baidu.com/s/1RktJgXY0NP7Eq0ohvBPOEA 提取码: 477z gitee下载链接:https://gitee.com/yuanyongqiang/common-files/blob/master/area.js 下…

超高压系列IXBX50N360HV、IXBT14N300HV、IXBH32N300高压反向导通 (BiMOSFET™) IGBT器件

器件介绍: 超高压系列3000V - 3600V反向导通 (BiMOSFET™) IGBT将MOSFET和IGBT的优势相结合。这些高压器件的饱和电压和内置二极管的正向电压降均具有正电压温度系数,因此非常适合用于并联运行。“自由”内置体二极管用作保护二极管,为器件关…

LibOS Gramine安装

文章目录 参考资料Gramine安装运行helloworld升级kernel到5.15 参考资料 Gramine Quick start Gramine安装 Gramine安装要求: Linux 内核版本至少为 5.11(启用 SGX 驱动程序) 如果是5.4.0-150-generic版本,则可以参考《Install …

Baumer工业相机堡盟工业相机如何通过BGAPISDK将相机图像写入相机内存(C#)

Baumer工业相机堡盟工业相机如何通过BGAPISDK将相机图像写入相机内存(C#) Baumer工业相机Baumer工业相机BGAPISDK和相机内存的技术背景Baumer工业相机通过BGAPISDK将相机图像写入相机内存功能1.引用合适的类文件2.通过BGAPISDK将相机图像写入相机内存功能…

ADManager Plus:提升企业管理效率的强大利器

导语: 在当今数字化时代,企业管理的重要性不言而喻。有效的企业管理可以提高生产力、优化业务流程,并促进组织的持续增长。而ADManager Plus作为一款功能强大的企业管理工具,为企业提供了全面的解决方案,帮助企业管理…

从0到1精通自动化测试,pytest自动化测试框架,Fixture之conftest.py与yield实现teardown(四)

目录 一、Fixture之conftest.py 1、Fixture优势 2、fixture参数传入(scope”function”) 3、conftest.py配置 二、Fixture之yield实现teardown 1、scope“module” 2、yield执行teardown 3、yield遇到异常 4、addfinalizer终结函数 一、Fixture…

【云原生•监控】基于Prometheus的云原生集群监控(理论+实践)-01

【云原生•监控】基于Prometheus的云原生集群监控(理论实践)-01 前言 「笔者已经在公有云上搭建了一套临时环境,可以先登录体验下:」 http://124.222.45.207:17000/login 账号:root/root.2020 云原生监控挑战 Prometheus 是用 Go 语言编写&am…

FreeRTOS实时操作系统(四)中断任务管理

系列文章目录 文章目录 系列文章目录前言中断优先级FreeRTOS中的中断管理一系列中断管理寄存器中断配置寄存器中断屏蔽寄存器 中断管理实战 前言 跟着正点原子学习一下中断管理,正好之间没有总结过,还有些地方不清楚。 中断优先级 中断的工作方式就不介…

Oracle的DCL、DDL、DML语言学习使用——oracle入门学习(一)

Oracle的DCL、DDL、DML语言学习使用 前言1.SQL Plus1.1 命令行SQL PLUS使用sqlplus /nologsqlplus / as sysdba 1.2 oracle自带SQL PLUS使用1.3 sys和system用户的区别 2. Oracle的体系结构3.DCL语言什么是DCL语言3.1 查看数据文件位置和表空间3.2 创建表空间3.3 删除表空间3.4…

鉴源实验室丨TBOX通讯模组AT指令测试

作者 | 李伟 上海控安安全测评部总监 来源 | 鉴源实验室 引言:上一篇文章我们讲了整车的OTA升级测试(详解车载设备FOTA测试),本篇我们介绍在车载零配件上比较少见却很实用的测试:通讯模组的AT(Attention)指…

总结Nginx的安装、配置与设置开机自启

在Ubuntu下安装Nginx有以下方法,但是如果想要安装最新版本的就必须下载源码包编译安装。 一、Nginx安装 1、基于APT源安装 sudo apt-get install nginx 安装好的文件位置: /usr/sbin/nginx:主程序 /etc/nginx:存放配置文件 /us…

高压放大器可以驱动电机吗

高压放大器可以驱动电机。事实上,高压放大器在许多应用中都是电机控制的核心部件之一。高压放大器可以将输入信号转换为高电压和高电流,从而驱动电动机。 一、高压放大器的原理 高压放大器是一种电子设备,用于将低功率信号转换为高功率信号。…

云原生之深入解析如何在Kubernetes下快速构建企业级云原生日志系统

一、概述 ELK 是三个开源软件的缩写,分别表示 Elasticsearch , Logstash, Kibana , 它们都是开源软件。新增了一个 FileBeat,它是一个轻量级的日志收集处理工具 (Agent),Filebeat 占用资源少,适合于在各个服务器上搜集日志后传输…

Lombok你不知道的用法

Lombok是大家经常用的一款工具,它可以帮我们减少很多重复代码的书写,但是我们对它的使用,可能更多局限于实体类的编写,比如说通过Data注解实现减少getter/setter/toString等方法的编写,其实它还有更多的注解功能&#…

美依礼芽破圈!小红书数据解读,如何拨动二次元心弦?

今年,二次元存在感爆棚。新世代下二次元群体愈发壮大,不少博主发布二次元内容、品牌也试图融入二次元圈。那么,如何与之打成一片呢?今天,通过小红书数据,我们来解读二次元的内容方向。 点赞破8亿&#xff0…

【每日一题】1595. 连通两组点的最小成本

【每日一题】1595. 连通两组点的最小成本 1595. 连通两组点的最小成本题目描述解题思路 1595. 连通两组点的最小成本 题目描述 给你两组点,其中第一组中有 size1 个点,第二组中有 size2 个点,且 size1 > size2 。 任意两点间的连接成本…