【一起撸个深度学习框架】6 折与曲的相会——激活函数

news2024/11/26 1:40:37
  • CSDN个人主页:清风莫追
  • 欢迎关注本专栏:《一起撸个DL框架》
  • GitHub获取源码:https://github.com/flying-forever/OurDL
  • blibli视频合集:https://space.bilibili.com/3493285974772098/channel/series

文章目录

  • 6 折与曲的相会——激活函数🍈
    • 1 前言
    • 2 激活函数
      • 2.1 Relu
      • 2.3 LeakyRelu
    • 3 拟合曲线的尝试
      • 3.1 设计计算图
      • 3.2 实现训练过程
      • 3.3 训练效果
      • 3.4 训练过程的可视化动画
    • 4 补充:参数初始化的影响


6 折与曲的相会——激活函数🍈

1 前言

  • 在上一节,我们实现了一个“自适应线性单元”,不断地将一个一次函数的输入和输出“喂”给它,它就可以自动地找到一次函数 y = w x + b y=wx+b y=wx+b中合适的参数值w和b。计算图通过前向传播和反向传播,初步展现了它的神奇之处。

  • 但在实际遇到的问题中,输入与输出之间往往并不是简单的线性关系,它们之间的函数关系可能是二次的、指数、甚至分段的。此时”自适应线性单元“就不足以满足我们的需求了。而”激活函数“,将为计算图带来一种拟合这些非线性函数关系的能力。

  • 同时为了得到对于激活函数更加清晰和形象化的认知,本节我们还将使用matplotlib对拟合过程进行一些可视化的展现。

  • 本节任务:在计算图中加入激活函数relu,拟合二次函数 y = x 2 y=x^2 y=x2在区间[0,2]的一小段曲线。

可视化效果

请添加图片描述
图1:二次函数拟合动画

2 激活函数

关于”激活函数“这个名称(非专业解释),首先我们可以看阶跃函数。当输入超过0这个阈值时,输出就从0跳到了1,0是一个非激活的状态,而1是一个激活的状态。这个和生物领域中神经元间的突触有一定相似性,当突触间的兴奋性神经元递质超过某个阈值后,下一个神经元才会进入兴奋状态继续传递信号。
f ( x ) = { 0 , x < 0 1 , x > = 0 f(x)=\begin{cases} 0, & x<0 \\ 1, & x>=0 \end{cases} f(x)={0,1,x<0x>=0

在这里插入图片描述
图2:阶跃函数的图像

2.1 Relu

人们发明了许多各式各样的激活函数,它们有着不同的特点,而Relu是其中比较常用的一种。Relu是一个简单的分段函数,它的核心思想是通过多段折线来贴近曲线,折线段越多、越短,拟合效果就越好,理论上使用relu几乎可以较好地任何曲线。
r e l u ( x ) = { x , x > = 0 0 , x < 0 relu(x)=\begin{cases} x, & x>=0 \\ 0, & x<0 \end{cases} relu(x)={x,0,x>=0x<0

在这里插入图片描述
图3:Relu函数图像

Relu节点的实现

# ourdl/ops/ops.py
class Relu(Op):
    def compute(self):
        assert len(self.parents) == 1
        self.value = self.parents[0].value if self.parents[0].value >= 0 else 0
    def get_parent_grad(self, parent):
        return 1. if self.parents[0].value > 0 else 0  # 发现relu的导函数就是step
    @staticmethod
    def relu(x: float):
        '''静态方法 --> 在计算图之外使用relu'''
        return x if x >= 0. else 0.

在前向传播的过程中,它接受一个父节点的输入,并产生一个输出。我们还使用装饰器@staticmethod,实现了一个静态方法relu(x),这样我们也可以在计算图之外直接调用relu函数了,例如可以在使用matplotlib绘制函数图像时用到。

2.3 LeakyRelu

和加法节点、乘法节点等节点一样,激活函数也是计算图中的一个运算节点,需要在该节点类中实现对应的get_parent_grad()方法对父节点进行求导。Relu函数在输入小于0时函数值都是0,对应的导数也是0,这种情况下参数就不会进行更新了。

人们提出了一种对Relu函数的修正方案,那就是LeakyRelu。在输入大于等于0的部分函数值不变,仍然是x;但是在输入小于0的部分取 0.1 x 0.1x 0.1x,这样在反向传播的过程中,节点的输入小于0时,虽然导数只有0.1,但并没有直接消失,参数仍然可以进行更新。(这里0.1是一个”超参数“,也可以取其它值)

在我的一些尝试中,使用Relu函数时训练过程会卡住一直无法拟合,但LeakyRelu可以一定程度上缓解问题,仍然可以拟合只是比较慢。

在这里插入图片描述
图4:LeakyRelu的函数图像

LeakyRelu节点的实现

# ourdl/ops/ops.py
class LeakyRelu(Op):
    '''消除了relu中导数为0的情况'''
    def compute(self):
        assert len(self.parents) == 1
        t = self.parents[0].value
        self.value = t if t >= 0 else t * 0.1
    def get_parent_grad(self, parent):
        return 1. if self.parents[0].value > 0 else 0.1  # 发现relu的导函数就是step
    @staticmethod
    def relu(x: float):
        '''静态方法 --> 在计算图之外使用leakyrelu'''
        return x if x >= 0. else x * 0.1

超参数”0.1"直接写死在代码中了,因为它通常并不需要改变。

3 拟合曲线的尝试

3.1 设计计算图

在这里插入图片描述
图5:计算图的设计

这个计算图中明确地画出了所有的节点,看起来有一些复杂。从整体看,计算图包含了三次变换:

输入 − − > 线性变换 − − > 激活函数 − − > 线性变换 − − > 输出 输入-->线性变换-->激活函数-->线性变换-->输出 输入>线性变换>激活函数>线性变换>输出

上述三次变换各自的意义是什么?计算图为什么设计成这个样子?这都是很重要的问题。

其实在图一中大致就能得到答案。

  • 第一次变换,产生两条不同的直线,它们有着不同的斜率,更重要的是:它们与x轴有着不同的交点
  • 第二次变换,使用激活函数relu(图一中是LeakyRelu),两条直线都在与x轴的交点处折断,得到两条折线。
  • 第三次变换,两条折线线性叠加,由于它们折断点不同,故得到是一个三段折线。

而我所希望的,就是利用三段的折线去尽可能地贴合二次函数的曲线。

3.2 实现训练过程

1、计算图的搭建

# example/01_esay/04_relu与二次拟合.py
import sys
sys.path.append('../..')  # 父目录的父目录
from ourdl.core import Varrible
from ourdl.ops import Mul, Add
from ourdl.ops.loss import ValueLoss
from ourdl.ops import LeakyRelu as Relu
import matplotlib.pyplot as plt
import numpy as np
import random

# 1.1 线性变换一
x = Varrible()
w_11 = Varrible()
w_12 = Varrible()
mul_11 = Mul([x, w_11])
mul_12 = Mul([x, w_12])
b_11 = Varrible()
b_12 = Varrible()
add_11 = Add([mul_11, b_11])
add_12 = Add([mul_12, b_12])
# 1.2 激活函数 --> 非线性变换
relu_11 = Relu([add_11])
relu_12 = Relu([add_12])
# 1.3 线性变换二
w_21 = Varrible()
w_22 = Varrible()
mul_21 = Mul([relu_11, w_21])
mul_22 = Mul([relu_12, w_22])
b_21 = Varrible()
add_21 = Add([mul_21, mul_22, b_21])
# 1.4 损失函数
label = Varrible()
loss = ValueLoss([label, add_21])

在完成计算图的设计后,搭建的过程比较简单,就是一些节点的创建和连接。由于节点的数量比较多,因此稍有些繁琐,后面我们会实现一些对象和方法用于批量创建和连接节点以及计算图的封装,简化计算图的搭建过程。一个个节点地创建也有其优势——灵活。

2、初始化计算图参数

# example/01_esay/04_relu与二次拟合.py
# 2 参数初始化
params = [w_11, w_12, b_11, b_12, w_21, w_22, b_21]
for param in params:
    param.set_value(random.uniform(-1, 1))
print([param.value for param in params])

这里调用了random库,使用均匀分布进行参数的随机初始化。将所有需要训练的参数加入到了一个列表params中,方便批量进行初始化以及后面的参数更新。

有时为了对比多次训练的效果,需要进行固定的初始化(初始化有时可以很大程度地影响训练效果),你可以手动地指定这些参数的初始值,例如

values = [-0.1571950013426796, -0.1070365984042347, 0.3791639008324807, 0.31960284774415215, 0.4263410176300597, 0.5097967360623379, 0.7597168751185974]
for i in range(len(params)):
    params[i].set_value(values[i])

如果你使用的是Relu激活函数而不是LeakyRelu,同时采用随机参数初始化,你将发现你的训练时而成功时而失败。

3、构造训练数据

# example/01_esay/04_relu与二次拟合.py
# 3 生成数据
data_x = [random.uniform(0, 2) for i in range(1500)]  # 似乎实数比离散的[0, 1, 2]要好
data_label = [x * x for x in data_x]

使用均匀分布,在[0, 2]的范围内生成了1500个随机值,作为输入的x。然后使用了列表推导式得到对应的二次函数输出值,作为计算图中的标签。

4、训练过程

# example/01_esay/04_relu与二次拟合.py
# 4 开始训练
losses = []
for i in range(len(data_x)):
    x.set_value(data_x[i])
    label.set_value(data_label[i])
    loss.forward()
    for param in params:
        param.get_grad()
        param.update(lr=0.01)
    if i % 100 == 0:
        print(f'[{i}]:loss={loss.value},', [param.value for param in params])
    losses.append(loss.value)
    loss.clear()

# 5 画出训练过程中loss的变化曲线
show_x = [i for i in range(len(losses))]
show_y = [_ for _ in losses]
plt.plot(show_x, show_y)
plt.show()

训练过程与上一节“自适应线性单元”基本相同。在第三步构造训练数据时,我们生成了1500个数据样本,因此绘制曲线来观察训练过程中的损失变化会更加直观。我们使用losses列表记录了每次参数更新后的损失值,并使用matplotlib库绘制损失的变化曲线。

3.3 训练效果

这里我们就完成了训练过程的所有代码编写,让我们运行一下代码看看效果吧!

在这里插入图片描述
图6:损失变化曲线

可以看到随着训练过程的进行,损失值呈下降的趋势,并渐渐趋于平稳。损失值越低表示着模型的输出越准确,我们的模型看起来好像训练得还不错。

但只看损失函数其实还是不太直观,我们可以直接将模型所表示的函数,与二次函数 y = x 2 y=x^2 y=x2画在一起,看看它们到底贴得近不近。

3.4 训练过程的可视化动画

当然,我觉得只看一个最后的贴合结果还不够,甚至只看输出的结果也仍然不能很清晰地了解训练的过程。所以我决定将中间节点的输出也画出来,并随着训练的过程以动画的形式呈现。

效果大家已经看过啦!就是图1所示的动画。

# example/01_esay/04_relu与二次拟合.py
# 4 开始训练
# 4.1 创建画布
fig = plt.figure(figsize=(15,4))
ax = fig.subplots(1,3,sharex=True,sharey=False)  # ax是包含一行三列,一共三块子画布的列表
# 4.2 训练,同时绘制动画
losses = []
for i in range(len(data_x)):
    x.set_value(data_x[i])
    label.set_value(data_label[i])
    loss.forward()
    for param in params:
        param.get_grad()
        param.update(lr=0.01)
    if i % 200 == 0:
        print(f'[{i}]:loss={loss.value},', [param.value for param in params])
    losses.append(loss.value)
    loss.clear()
    if i % 40 == 0:
        show_ax_mul(ax)  # 用于绘制多图动画

首先我们修改了 4、训练过程 部分的代码,创建画布,然后将画布对象传递给show_ax_mul()函数,绘制图像。在show_ax_mul()每次绘制完成后,调用plt.pause()让画面暂停下(否则画面会一闪而逝啥也看不清),然后清空画布方便下次绘图。

画布的绘制、清空都是在后台的,因此“清空画布”操作不会直接清空已经画出的函数图像。在下次绘制图像时,才会将原来的图像覆盖掉。

通过反复的绘制——清空,就形成了动画的效果。这样绘制的效率比较低,大家也可以自行搜索其它的动画绘制方法。

# example/01_esay/04_relu与二次拟合.py
def show_ax_mul(ax):
    '''
    用于绘制多图动画\n
	'''
    # 1 画真实的二次函数曲线
    show_x = np.linspace(-0.5, 2.5, 30, endpoint=True)
    show_y = [x_one * x_one for x_one in show_x]
    ax[2].plot(show_x, show_y)
    # 2 画模型拟合的曲线
    show_x = np.linspace(-0.5, 2.5, 10, endpoint=True)
    shows = {'add1':[], 'add2':[], 'mul1':[], 'mul2':[], 'y':[]}
    for x_one in show_x:
        x.set_value(x_one)
        add_21.clear()
        add_21.forward()
        shows['y'].append(add_21.value)
        shows['add1'].append(add_11.value)
        shows['add2'].append(add_12.value)
        shows['mul1'].append(mul_21.value)
        shows['mul2'].append(mul_22.value)
    ax[2].plot(show_x, shows['y'])
    ax[0].plot(show_x, shows['add1'])
    ax[0].plot(show_x, shows['add2'])
    ax[1].plot(show_x, shows['mul1'])
    ax[1].plot(show_x, shows['mul2'])
    y_0 = [0 for _ in show_x]  # 水平的参考线
    ax[2].plot(show_x, y_0)
    ax[0].plot(show_x, y_0)
    ax[1].plot(show_x, y_0)
    plt.pause(0.01)
    for i in range(ax.shape[0]):
        ax[i].cla()

4 补充:参数初始化的影响

参数的初始化对训练效果的影响真的很大!它可以影响训练的速度,最终成功拟合的方式,以及是否能够成功拟合。大家可以尝试着调整各个初始化参数的正负以及大小,观察它们对应训练效果的影响,并思考这样的影响是怎样产生的。

同时,由于我们的计算图还比较简单,也方便进行比较透彻的思考。

对比 图7图8 的拟合结果,发现它们虽然有所区别,但是都较好的拟合了二次函数在[0, 2]的这一段曲线。这种区别是损失函数曲线中所无法体现出来的信息,这便是可视化的意义之一。

图8的训练迭代了1500次,而 图7 迭代了10000次以上,两次训练的区别仅仅就是参数的初始化不同。随着计算图参数规模的增大,对应参数初始化的敏感程度会没有那么高。但利用不同的随机分布、不同的数值范围进行初始化,有时训练效果仍会有较大的差距。

# 图8对应的参数初始化
values = [-0.1571950013426796, -0.1070365984042347, 0.3791639008324807, 0.31960284774415215, 0.4263410176300597, 0.5097967360623379, 0.7597168751185974]
# 图7对应的参数初始化
values = [-0.4571950013426796, -0.4070365984042347, 0.3791639008324807, 0.31960284774415215, 0.4263410176300597, 0.5097967360623379, 0.7597168751185974]
在这里插入图片描述
图7:实际的拟合方式
在这里插入图片描述
图8:我期待的拟合方式

下节预告:计算图的封装

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/529644.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

史蒂夫·青木主题的 Game Jam

准备好潜入史蒂夫青木的脑海中&#xff0c;创造一个探索他内心思想的游戏吧&#xff01;史蒂夫青木主题 Game Jam 正式推出&#xff0c;这是一场为期两周的游戏制作比赛&#xff0c;鼓励参赛者创造和史蒂夫青木内心世界有关的游戏。 探索这位传奇艺术家和 DJ 潜意识&#xff0c…

nginx压测记录

nginx压测记录 1 概述2 原理3 环境3.1 设备与部署3.2 nginx配置/服务器配置 4 netty服务5 步骤6 结果7 写在最后 1 概述 都说nginx的负载均衡能力很强&#xff0c;最近出于好奇对nginx的实际并发能力进行了简单的测试&#xff0c;主要测试了TCP/IP层的长链接负载均衡 2 原理 …

Python 与数据科学实验(Exp9)

实验9 多分类手写数字识别实验 1.实验数据 &#xff08;1&#xff09;训练集 所给数据一共有42000张灰度图像&#xff08;分辨率为28*28&#xff09;&#xff0c;目前以train_data.csv文件给出. 图像内容涵盖了10个手写数字0-9。 图像示例如图所示&#xff1a; train_data.…

算法(一)—— 回溯(4)困难题

文章目录 1 37 解数独2 51 N 皇后 1 37 解数独 首先明确需要两个for循环&#xff0c;这样才可以遍历整个9*9的表。 此题数字的选取逻辑再次展现了回溯的暴力性。 此题需要拥有返回值&#xff0c;与数据结构&#xff08;六&#xff09;—— 二叉树&#xff08;5&#xff09;中…

物联网和云计算:如何将设备数据和云端服务相结合

第一章&#xff1a;引言 物联网和云计算是当今IT领域中的两个重要概念&#xff0c;它们的结合为企业和个人带来了巨大的机遇和挑战。物联网通过连接各种设备和传感器&#xff0c;实现了设备之间的互联互通&#xff0c;而云计算则提供了强大的计算和存储能力。本文将深入探讨如何…

MySQL学习(基础篇1.0)

MySQL概述&#xff08;基础&#xff09; SQL 全称Structured Query Language,结构化察浑语言。操作关系型数据库的编程语言&#xff0c;定义了一套操作关系型数据库的统一标准。 SQL通用语法 SQL语言的统统用语法&#xff1a; SQL语句可以单行或多行书写&#xff0c;以分号…

论文阅读|基于图神经网络的配电网故障定位方法

来源&#xff1a;北京交通大学硕士学位论文&#xff0c;2022 摘要 电网拓扑形态多样&#xff0c;重构场景频繁&#xff0c;&#xff0c;传统故障定位方法的单一阈值设定无法满足要求&#xff0c;基于人工智能的配电网故障定位技术具有很大的应用潜力&#xff0c;但仍存在着拓…

HTML概述及常用语法

什么是 HTML HTML 用来描述网页的一种语言 HTML -- hyper text markup language 超文本标记语言 超文本包括&#xff1a;文字、图片、音频、视频、动画等等 标记语言&#xff1a;是一套标记标签&#xff0c; HTML 使用标记标签来 描述 网页 <> HTML 发展史 HTML5 …

Web基础 ( 二 ) CSS

2.CSS 2.1.概念与基础 2.1.1.什么是CSS Cascading Style Sheets 全称层叠样式单 简称样式表。 是告诉浏览器如何来显示HTML的元素的特殊标记 2.1.2.编写方式 2.1.2.1.外部文件 在html文件的<head>中加入<link>结点来引入外部的文件 <link rel"stylesh…

Go Wails Docker图形界面管理工具 (5)

文章目录 1. 前言2. 效果图3. 代码 1. 前言 接上篇&#xff0c;本次添加Docker存储卷功能 待优化: 优化分页效果添加存储卷大小查看功能 2. 效果图 3. 代码 直接调用官方库 app.go func (a *App) VolumeList() ([]*volume.Volume, error) {resp, err : Cli.VolumeList(context…

Linux中关于时间修改的命令

目录 Linux中关于时间修改的命令 data命令 语法格式 示例 date命令中的参数以及作用 常用格式示例 timedatectl命令 语法格式 timedatectl 命令中的参数以及作用 常用格式 Linux中关于时间修改的命令 data命令 data --- 用于显示或设置系统的时间与日期 用户只需在强…

干货丨警惕!14个容易导致拒稿的常见错误

Hello,大家好&#xff01; 这里是壹脑云科研圈&#xff0c;我是喵君姐姐~ 从做研究、到写论文、再到投稿&#xff0c;每一步都是巨大的挑战。以下列举了一些在这些过程中可能导致拒稿的常见错误&#xff0c;希望能帮助大家避开。 01 格式问题 1.没有遵守投稿须知 期刊提供了…

oracle基于时间点恢复遇到ORA-10877错误

一次给客户进行基于时间点恢复的时候,出现报错ORA-10877,如下: 这里很奇怪,这个归档日志有的,当前全库的备份是05-14 23点的,所以应该是可以恢复的,检查一下alter日志: 这里报错,指定的时间scn不属于当前的incarnation,那么检查一下当前的incarnation: 这里当前的incarnation是…

Linux实操篇---常用的基本命令3(用户(组)管理命令、文件权限类、搜索查找类、压缩解压类)

一、用户管理命令 Linux是一个多用户&#xff0c;多任务的分时操作系统。甚至有可能同时登录&#xff0c;同时操作。所以给用户不同的账号。 useradd添加新用户 基本语法&#xff1a; 只能用root进行操作。 useradd 用户名 添加新用户 useradd -g 组名 用户名 添加新用…

MyBatis Plus 代码生成器

一、引入POM依赖 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.5.3.1</version></dependency><dependency><groupId>com.baomidou</groupId&g…

3ds Max云渲染平台哪个好?

3ds Max云渲染平台哪个好&#xff1f; 3ds Max是一款包含建模、动画、粒子动力学等强大功能的三维动画制作软件&#xff0c;3ds Max对特定如游戏建模、特效制作、产品模型设计等领域都具备了过硬的专业能力&#xff0c;同时3ds Max也是很多CGer青睐的CG软件。 作为支持3ds Ma…

黑马Redis笔记高级篇 | Redis最佳实践

黑马Redis笔记高级篇 | Redis最佳实践 1、Redis键值设计1.1、优雅的key结构1.2、拒绝BigKey1.3、恰当的数据类型1.4、总结 2、批处理优化1.1、Pipeline1.2、集群下的批处理 3、服务端优化3.1、持久化配置3.2、慢查询3.3、命令及安全配置3.4、内存配置 4、集群最佳实践 1、Redis…

深度学习用于医学预后-第二课第三周14-15节-评估方法比较以及Kaplan-Meier估计

评估对比 我们现在对 t25 的生存率得出了一个新的估计值&#xff0c;为0.56。现在&#xff0c;让我们将其与之前所做的估计进行比较。 当我们假设所有患者在他们截尾时间立即死亡时&#xff0c;我们获得了一个低生存概率为0.29。而在另一极端&#xff0c;如果我们假设他们永久…

LeetCode 周赛 345(2023/05/14)体验一题多解的算法之美

本文已收录到 AndroidFamily&#xff0c;技术和职场问题&#xff0c;请关注公众号 [彭旭锐] 提问。 往期回顾&#xff1a;LeetCode 双周赛第 104 场 流水的动态规划&#xff0c;铁打的结构化思考 周赛概览 T1. 找出转圈游戏输家&#xff08;Easy&#xff09; 标签&#xff…

微信小程序入门02-安装mysql

我们上一篇介绍的是微信开发者工具的安装&#xff0c;开发一个小程序肯定要有后端服务&#xff0c;有后端服务首先要可以存储和查询数据。 数据库种类比较多&#xff0c;我们这里选择mysql&#xff0c;为啥选择这个呢&#xff0c;因为首先用的人多比较稳定&#xff0c;再一个免…