R语言深度学习-5-深度前馈神经网络

news2024/11/16 16:25:07

本教程参考《RDeepLearningEssential》

本篇我们将学习如何建立并训练深度预测模型。我们将关注深度前馈神经网络


5.1 深度前馈神经网络

我们还是使用之前提到的H2O包,详细可以见之前的博客:R语言深度学习-1-深度学习入门(H2O包安装报错解决及接入/H2O包连接数据集)-CSDN博客

library('h2o')
cl <- h2o.init(
  max_mem_size = "20G",
  nthreads = 10,
  ip = "127.0.0.1", port = 54321)

深度前馈神经网络,也被称为前馈神经网络或多层感知机(MLP),是一种典型的多层神经网络,其中数据在神经元之间单向流动,从输入层经过一个或多个隐藏层传递到输出层。这种类型的网络不包含任何循环或反馈连接,意味着信息的流向是从上到下,不会从输出层返回至输入层。

深度前馈神经网络的核心优势在于其能够通过学习输入数据和目标输出之间的复杂映射关系来执行各种任务。这得益于它的层次结构,每一层都从前一层接收信息并产生输出,这些输出作为下一层的输入。随着网络层次的加深,它能够捕捉更抽象的特征,从而提升模型的性能和泛化能力。

如图所示,来自输入X到输出Y的全部映射是一个多层函数。第一个隐藏层是:

H_{1} = f^{(1)}(X,\omega _{1},\alpha _{1})

而在每一层中有多少隐藏神经元及使用什么激活函数,我们在5.2讨论,另一个关键是成本或损失函数,常用的是交叉熵(cross-entropy)和二次的函数均方差(MSE)。

5.2 激活函数

激活函数在神经网络中扮演着至关重要的角色。它们通常被嵌入到神经网络的隐藏层中,用以引入非线性因素,使得神经网络能够学习和模拟复杂的数据模式。没有激活函数,无论神经网络有多少层,最终都只是相当于一个线性变换,无法解决非线性问题。

激活函数的种类多样,每种都有其特定的用途和特性。以下是一些常见的激活函数及其特点:

1. Sigmoid函数:Sigmoid函数可以将任意实数映射到(0,1)区间内,这使得它可以用来做二分类问题的输出层。然而,当输入值较大或较小时,Sigmoid函数的梯度接近于0,容易导致梯度消失

2. Tanh函数:Tanh函数是Sigmoid函数的变种,它将实数映射到(-1,1)区间内,相比于Sigmoid函数,Tanh函数的输出以0为中心,但它同样存在梯度消失的问题。

3. ReLU函数:ReLU(Rectified Linear Unit)函数是目前最常用的激活函数之一。它在输入大于0时直接输出该值,小于等于0时输出0。ReLU函数解决了梯度消失问题,计算简单且加速了神经网络的训练。但ReLU函数也有缺点,比如当输入为负数时,梯度始终为0,可能导致神经元“死亡”。

4. Leaky ReLU函数:Leaky ReLU是对ReLU的改进,它在输入小于等于0时,梯度不为0,而是一个很小的正数。这样可以缓解ReLU的死神经元问题。

5. Softmax函数:Softmax函数常用于多分类问题的输出层,它可以将一组数值转化为概率分布。

6. Swish函数:Swish函数是一个平滑且非单调的激活函数,由谷歌提出,在某些情况下比ReLU表现更好。

7. Mish函数:Mish函数结合了ReLU和Swish的优点,具有更好的性能表现。

5.3 选取超参数

我们在之前的模型选择参数,一般是选取如权重或者截距。不过还有一些参数能够被学到或者能被优化,我们在进行模型选择的时候,也是一种超参数。我们还是使用之前的手写字数据进行实践:

R语言深度学习-2-训练预测模型-CSDN博客

我们使用H2O的深度学习算法来训练一个分类器,并比较不同学习率对模型性能的影响。以下是两个深度学习模型的配置和它们的运行时间分析。

options(width = 70, digits = 2)
#初始化
dig_train <- read.csv("C:\\Users\\Huzhuocheng\\Desktop\\digit-recognizer\\train.csv")
dim(dig_train) #数据维度查看
dig_train$label <- factor(dig_train$label, levels = 0:9)

h2odigits <- as.h2o(
  dig_train,
  destination_frame = "h2odigits")
i <- 1:32000
h2odigits.train <- h2odigits[i, ]
itest <- 32001:42000
h2odigits.test <- h2odigits[itest, ]
xnames <- colnames(h2odigits.train)[-1]

#训练模型
system.time(ex1 <- h2o.deeplearning(
  x = xnames,
  y = "label",
  training_frame= h2odigits.train,
  validation_frame = h2odigits.test,
  activation = "RectifierWithDropout",
  hidden = c(100),
  epochs = 10,
  adaptive_rate = FALSE,
  rate = .001,
  input_dropout_ratio = 0,
  hidden_dropout_ratios = c(.2)
))
system.time(ex2 <- h2o.deeplearning(
  x = xnames,
  y = "label",
  training_frame= h2odigits.train,
  validation_frame = h2odigits.test,
  activation = "RectifierWithDropout",
  hidden = c(100),
  epochs = 10,
  adaptive_rate = FALSE,
  rate = .01,
  input_dropout_ratio = 0,
  hidden_dropout_ratios = c(.2)
))

 我们选择了不同的学习率,ex1中学习率是0.001,在ex2中,学习率是0.01,我们发现ex1的运行时间长很多,但是就模型效果来说,ex1更好:

 深刻理解超参数,对我们在模型进行训练中有事半功倍的效果,有的时候不是模型不行,而是选择了错误的超参数,这很重要。

5.4 深度神经网络训练及预测

我们使用之前提到的UCI数据进行演示:UCI Machine Learning Repository

#数据导入
train_x <- read.table("C:/Users/Huzhuocheng/Desktop/UCI数据/UCI HAR Dataset/UCI HAR Dataset/train/X_train.txt")
train_Y <- read.table("C:/Users/Huzhuocheng/Desktop/UCI数据/UCI HAR Dataset/UCI HAR Dataset/train/y_train.txt")
test_x <- read.table("C:/Users/Huzhuocheng/Desktop/UCI数据/UCI HAR Dataset/UCI HAR Dataset/test/X_test.txt")
test_Y <- read.table("C:/Users/Huzhuocheng/Desktop/UCI数据/UCI HAR Dataset/UCI HAR Dataset/test/y_test.txt")
barplot(table(train_Y))
train_x <- as.data.frame(train_x)
train_Y <- as.data.frame(train_Y)

train_Y <- factor(train_Y)  
test_Y <- factor(test_Y)   

use.train <- cbind(train_x, Outcome = train_Y)
use.test <- cbind(test_x, Outcome = test_Y)

use.labels <- read.table("C:\\Users\\Huzhuocheng\\Desktop\\UCI数据\\UCI HAR Dataset\\UCI HAR Dataset\\activity_labels.txt")
h2oactivity.train <- as.h2o(
  use.train,
  destination_frame = "h2oactivitytrain")
h2oactivity.test <- as.h2o(
  use.test,
  destination_frame = "h2oactivitytest")

 接下来,我们使用H2O的deeplearning包进行深度学习,使用的激活函数是线性整流器,并且使用了我们上次讲的丢弃正则化,带有输入变量20%丢弃和隐藏神经元50%丢弃,并且我们建立的是一个50神经元和10次迭代的浅层网络,损失函数是交叉熵。

mt1 <- h2o.deeplearning(
  x = colnames(train_x),
  y = "Outcome",
  training_frame= h2oactivity.train,
  activation = "RectifierWithDropout",
  hidden = c(50),
  epochs = 10,
  loss = "CrossEntropy",
  input_dropout_ratio = .2,
  hidden_dropout_ratios = c(.5),
  export_weights_and_biases = TRUE
)

显示了层数及每个层中单元的个数,单元的类型,丢弃百分比和其他正则信息。

这个则显示了模型的性能,包括均方误,对数损失等。

混淆矩阵显示了预测与真实值的差距。

5.5 小结

我们本次使用H2O包对深度神经网络进行了学习应用,不过我们在例子中构建的都是浅层的神经网络,大家可以自己调参数实现更好的理解与应用。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1522305.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring炼气之路(炼气一层)

目录 一、IOC 1.1 控制反转是什么&#xff1f; 1.2 什么是IOC容器&#xff1f; 1.3 IOC容器的作用 1.4 IOC容器存放的是什么&#xff1f; 二、DI 2.1 依赖注入是什么&#xff1f; 2.2 依赖注入的作用 三、IOC案例实现 3.1下载Maven 3.2 配置Maven中的settings.xml文…

Stable Diffusion科普文章【附升级gpt4.0秘笈】

随着人工智能技术的飞速发展&#xff0c;我们越来越多地看到计算机生成的艺术作品出现在我们的生活中。其中&#xff0c;Stable Diffusion作为一种创新的图像生成技术&#xff0c;正在引领一场艺术创作的革命。本文将为您科普Stable Diffusion的相关知识&#xff0c;带您走进这…

部署一个本地的ChatGPT(Ollama)

一 下载Ollama Ollama下载地址&#xff1a;https://ollama.com/download 下载完后 二 安装运行 双击下载好的OllamaSetup.exe开发 安装Ollama: 安装完成后&#xff0c;多了一个Ollama的菜单如下图 &#xff1a; Ollama安装好默认是配置开机运行&#xff0c;如果没有运行可以在…

python的opencv最最基础初学

localhost中详解OpenCV的函数imread()和函数imshow(),并利用它们实现对图像的读取和显示_opencv imshow-CSDN博客 其实以下均为numpy 显示一张图片 import cv2 ####opencv读取的格式是BGR import matplotlib.pyplot as plt import numpy as np %matplotlib inline imgcv2.…

Golang协程详解

一.协程的引入 1.通过案例文章引入并发,协程概念 见:[go学习笔记.第十四章.协程和管道] 1.协程的引入,调度模型&#xff0c;协程资源竞争问题 通过上面文章可以总结出Go并发编程原理: 在一个处理进程中通过关键字 go 启用多个协程&#xff0c;然后在不同的协程中完成不同的子任…

Spark-Scala语言实战(1)

在之前的文章中&#xff0c;我们学习了如何在Linux安装Spark以及Scala&#xff0c;想了解的朋友可以查看这篇文章。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的不错&#xff0c;请留下你宝贵的点赞&#xff0c;谢谢。 Spark及Scala的安装https:/…

图像处理ASIC设计方法 笔记11 像素误差与字长优化

P108 P105 定点误差分析与字长优化 1 像素误差是什么原因导致的? 在本书所说的算法中,像素误差是由几次定点运算累加导致的: 首先由行(列)号与定点正弦/正切值计算出该行(列)的小数平移量,然后将这些小数平移量截取一定字长用来计算插值核,再将这些插值核也截取一…

VMware Worksation 问题

几个晚上在虚拟机装了好多东西&#xff0c;配置mysql&#xff0c;配置docker、Git工具等等&#xff0c;可能废寝忘食导致太困强制关了虚拟机&#xff0c;结果第二天晚上回来发现打不开&#xff0c;心态直接崩了。 问题&#xff1a; 疯狂百度告知要删除后缀为.lck的文件夹及文件…

pytorch 实现线性回归(Pytorch 03)

一 线性回归框架 线性模型的四个模块&#xff1a;训练的数据集&#xff0c;线性模型&#xff0c;损失函数&#xff0c;优化算法。 1.1 数据集 使用房价预测数据集&#xff0c;我们希望根据房屋的面积和房龄等来估算房屋价格。 1.2 线性模型 预测公式&#xff0c; 价格 权重…

蓝桥杯练习系统(算法训练)ALGO-969 N车

资源限制 内存限制&#xff1a;256.0MB C/C时间限制&#xff1a;1.0s Java时间限制&#xff1a;3.0s Python时间限制&#xff1a;5.0s 问题描述 给定NN的棋盘&#xff0c;问有多少种放置N个车使他们不互相攻击的方案。 输入格式 一行一个整数&#xff0c;N。 输出格式…

又是一场心碎的div2

真要破防了&#xff0c;还是没做出C题&#xff0c;感觉这次C已经很简单了。 C题这么多人过&#xff0c;反观D题这个人数有点诡异。但是这么多人过我都没过。看了一个半小时就是没看出哪写错了。 就完全是浪费这么多时间。我真碎了。受不了了。还是晚安吧&#xff0c;每天抄作业…

RT-Thread之USB组件的使用记录(SD卡和USB同时挂载)

前言 使用usb-host组件读取u盘记录同时挂载sd和u盘用到的芯片为stm32f407zgt6u盘的格式为fat 组件选择 文件相关的宏定义 /* DFS: device virtual file system */ /* 设备虚拟文件系统 */ #define RT_USING_DFS #define DFS_USING_WORKDIR #define DFS_FILESYSTEMS_MAX 3 //…

MIT线性代数-方程组的几何解释

文章目录 1. 二维空间1.1 行方向1.2 列方向 2. 三维空间2.1 行方向2.2 列方向 假设有一个方程组 A X B AXB AXB表示如下 2 x − y 0 (1) 2x-y0\tag{1} 2x−y0(1) − x 2 y 3 (2) -x2y3\tag{2} −x2y3(2) 矩阵表示如下&#xff1a; [ 2 − 1 − 1 2 ] [ x y ] [ 0 3 ] (3)…

Python基础入门 --- 4.循环语句

文章目录 Python基础入门第四章&#xff1a;4.1 while循环语句4.1.1 while循环的嵌套4.2 for循环语句4.2.1 range语句4.2.2 变量作用域4.2.3 for循环的嵌套应用 4.3 循环中断 continue和break Python基础入门 第四章&#xff1a; 4.1 while循环语句 语法结构&#xff1a; w…

Day66:WEB攻防-Java安全SPEL表达式SSTI模版注入XXEJDBCMyBatis注入

目录 JavaSec搭建 Hello-Java-Sec搭建 Java安全-SQL注入-JDBC&MyBatis Java安全-XXE注入-Reader&Builder Java安全-SSTI模版-Thymeleaf&URL Java安全-SPEL表达式-SpringBoot框架 知识点&#xff1a; 1、Java安全-SQL注入-JDBC&MyBatis 2、Java安全-XXE注…

html编辑器

HTML 编辑器推荐 html可以使用记事本编辑 但是更建议使用专业的 HTML 编辑器来编辑 HTML&#xff0c;我在这里给大家推荐几款常用的编辑器&#xff1a; VS Code&#xff1a;https://code.visualstudio.com/WebStorm: https://www.jetbrains.com/webstorm/Notepad: https://no…

相机与相机模型(针孔/鱼眼/全景相机)

本文旨在较为直观地介绍相机成像背后的数学模型&#xff0c;主要的章节组织如下&#xff1a; 第1章用最简单的针孔投影模型为例讲解一个三维点是如何映射到图像中的一个像素 第2章介绍除了针孔投影模型外其他一些经典投影模型&#xff0c;旨在让读者建立不同投影模型之间的建模…

考研C语言复习进阶(6)

目录 1. 程序的翻译环境和执行环境 2. 详解编译链接 2.1 翻译环境 ​编辑​编辑 2.2 编译本身也分为几个阶段&#xff1a; 2.3 运行环境 3. 预处理详解 3.1 预定义符号 3.2 #define 3.2.1 #define 定义标识符 3.2.2 #define 定义宏 2.2.3 #define 替换规则 3.2.4…

MySQL语法分类 DQL(6)分页查询

为了更好的学习这里给出基本表数据用于查询操作 create table student (id int, name varchar(20), age int, sex varchar(5),address varchar(100),math int,english int );insert into student (id,name,age,sex,address,math,english) values (1,马云,55,男,杭州,66,78),…

物联网竞赛板CubMx全部功能简洁配置汇总

目录 前言&#xff1a;1、按键&LED灯配置&#xff1a;2、OLED配置&#xff1a;3、继电器配置&#xff1a;4、LORA模块配置&#xff1a;5、矩阵模块&#xff1a;6、串口模块&#xff1a;7、RTC配置&#xff1a;8、ADC模块配置&#xff1a;9、温度传感器模块&#xff1a;后续…