【一起撸个DL框架】5 实现:自适应线性单元

news2024/9/28 19:23:34
  • CSDN个人主页:清风莫追
  • 欢迎关注本专栏:《一起撸个DL框架》
  • GitHub获取源码:https://github.com/flying-forever/OurDL

文章目录

  • 5 实现:自适应线性单元🍇
    • 1 简介
    • 2 损失函数
      • 2.1 梯度下降法
      • 2.2 补充
    • 3 整理项目结构
    • 4 损失函数的实现
    • 5 修改节点类(Node)
    • 6 自适应线性单元

5 实现:自适应线性单元🍇

1 简介

上一篇:【一起撸个DL框架】4 反向传播求梯度

上一节我们实现了计算图的反向传播,可以求结果节点关于任意节点的梯度。下面我们将使用梯度来更新参数,实现一个简单的自适应线性单元

我们本次拟合的目标函数是一个简单的线性函数: y = 2 x + 1 y=2x+1 y=2x+1,通过随机数生成一些训练数据,将许多组x和对应的结果y值输入模型,但是并不告诉模型具体函数中的系数参数“2”和偏置参数“1”,看看模型能否通过数据“学习”到参数的值。

图1:自适应线性单元的计算图

2 损失函数

2.1 梯度下降法

损失是对模型好坏的评价指标,表示模型输出结果与正确答案(也称为标签)之间的差距。所以损失值越小就说明模型越准确,训练过程的目的便是最小化损失函数的值

自适应线性单元是一个回归任务,我们这里将使用绝对值损失,将模型输出与正确答案之间的差的绝对值作为损失函数的值,即 l o s s = ∣ l − a d d ∣ loss=|l-add| loss=ladd

评价指标有了,可是如何才能达标呢?或者说如何才能降低损失函数的值?计算图中有四个变量: x , w , b , l x,w,b,l x,w,b,l,而我们训练过程的任务是调整参数 w , b w,b w,b的值,以降低损失。因此训练过程中的自变量是w和b,而把x和l看作常量。此时损失函数是关于w和b的二元函数 l o s s = f ( w , b ) loss=f(w,b) loss=f(w,b),我们只需要求函数的梯度 ▽ f ( w , b ) = ( ∂ f ∂ w , ∂ f ∂ b ) \triangledown f(w,b)=(\frac{\partial f}{\partial w},\frac{\partial f}{\partial b}) f(w,b)=(wf,bf),则梯度的反方向就是函数下降最快的方向。沿着梯度的方向更新参数w和b的值,就可以降低损失。这就是经典的优化算法:梯度下降法

2.2 补充

关于损失和优化的概念,大家可能还是有些模糊。上面损失只讲到了一个输入x值对应的模型输出与实际结果之间的差距,但使用整个数据集的平均差距可能更容易理解,就像中学的线性回归

图2所示,改变直线的斜率w,将改变直线与数据点的贴近程度,即改变了损失函数loss的值。

在这里插入图片描述
图2:损失与参数更新示意图

参考: 【深度学习】3-从模型到学习的思路整理_清风莫追的博客-CSDN博客

3 整理项目结构

我们的小项目的代码也渐渐多起来了,好的目录结构将使它更加易于扩展。关于python包结构的知识大家可以自行去了解,大致目录结构如下:

- example
- ourdl
	- core
		- __init__.py
		- node.py
	- ops
		- __init__.py
		- loss.py
		- ops.py
	__init__.py

给这个简单框架的名字叫做OurDL,使用框架搭建的计算图等程序放在example目录下。在ourdl/core/node.py中存放了节点基类和变量类的定义,在ourdl/ops/下存放了运算节点的定义,包括损失函数和加法、乘法节点等。

4 损失函数的实现

/ourdl/ops/loss.py中,

from ..core import Node

class ValueLoss(Node):
    '''损失函数:作差取绝对值'''
    def compute(self):
        self.value = self.parent1.value - self.parent2.value
        self.flag = self.value > 0
        if not self.flag:
            self.value = -self.value
    def get_parent_grad(self, parent):
        a = 1 if self.flag else -1
        b = 1 if parent == self.parent1 else -1
        return a * b

其中compute()方法很显然就是对两个输入作差取绝对值;get_parent_grad()方法求本节点关于父节点的梯度。有绝对值如何求梯度?大家可以画一画绝对值函数的图像。

5 修改节点类(Node)

ourdl/core/node.py

class Node:
    pass  # 省略了一些方法的定义,大家可以查看上一篇文章

    def clear(self):
        '''递归清除父节点的值和梯度信息'''
        self.grad = None
        if self.parent1 is not None:  # 清空非变量节点的值
            self.value = None
        for parent in [self.parent1, self.parent2]:
            if parent is not None:
                parent.clear()
    def update(self, lr=0.001):
        '''根据本节点的梯度,更新本节点的值'''
        self.value -= lr * self.grad  # 减号表示梯度的反方向

我在节点类中新增了两个方法,其中clear()用于清除多余的节点值和梯度信息,因为当节点值或梯度已经存在时会直接返回结果而不会递归去求了(get_grad()forward()的代码)。update()有一个学习率参数lr,更新幅度太大可能导致参数值一直在目标值左右晃悠,无法收敛

6 自适应线性单元

/example/01_esay/自适应线性单元.py

import sys
sys.path.append('../..')
from ourdl.core import Varrible
from ourdl.ops import Mul, Add
from ourdl.ops.loss import ValueLoss

if __name__ == '__main__':
    # 搭建计算图
    x = Varrible()
    w = Varrible()
    mul = Mul(parent1=x, parent2=w)
    b = Varrible()
    add = Add(parent1=mul, parent2=b)
    label = Varrible()
    loss = ValueLoss(parent1=label, parent2=add)
    # 参数初始化
    w.set_value(0)
    b.set_value(0)
    # 生成训练数据
    import random
    data_x = [random.uniform(-10, 10) for i in range(10)]  # 按均匀分布生成[-10, 10]范围内的随机实数
    data_label = [2 * data_x_one + 1 for data_x_one in data_x]
    # 开始训练
    for i in range(len(data_x)):
        x.set_value(data_x[i])
        label.set_value(data_label[i])
        loss.forward()  # 前向传播 --> 求梯度会用到损失函数的值
        w.get_grad()
        b.get_grad()
        w.update(lr=0.05)
        b.update(lr=0.1)
        loss.clear()
        print("w:{:.2f}, b:{:.2f}".format(w.value, b.value))
    print("最终结果:{:.2f}x+{:.2f}".format(w.value, b.value))
    

运行结果:

w:0.13, b:0.10
w:0.36, b:0.20
w:0.58, b:0.10
w:0.74, b:0.00
w:1.13, b:0.10
w:1.43, b:0.20
w:1.62, b:0.30
w:1.94, b:0.20
w:1.50, b:0.30
w:1.87, b:0.40
最终结果:1.87x+0.40

上面自适应线性单元的训练,已经能够大致展现深度学习模型的训练流程:

  • 搭建模型 --> 初始化参数 --> 准备数据 --> 使用数据更新参数的值

我们这里参数只更新了10次,结果就已经大致接近了我们的目标函数 y = 2 x + 1 y=2x+1 y=2x+1。大家可以试试更改学习率lr,训练数据集的大小,观察运行结果会发生怎样的变化。(必备技能:调参)


下节预告:激活函数与计算图的非线性拟合能力

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/486896.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第二十七章 Unity碰撞体Collision(下)

本章节我们继续研究碰撞体,并且探索一下碰撞体与刚体之间的联系。我们回到之前的工程,然后给我们的紫色球体Sphere1也添加一个刚体组件。如下所示 此时,两个球体都具备了碰撞体和刚体组件。接下来,我们Play运行查看效果 我们发现&…

第二十六章 Unity碰撞体Collision(上)

在游戏世界中,游戏物体之间的交互都是通过“碰撞接触”来进行交互的。例如,攻击怪物则是主角与怪物的碰撞,触发机关则是主角与机关的碰撞。在DirectX课程中,我们也大致介绍过有关碰撞检测的内容。游戏世界中的3D模型的形状是非常复…

浅谈区块链1.0-比特币

1. 比特币解决的问题 高度自治:国际经济危机无国界贸易:不同国家进行的贸易或者不同平台进行贸易 不可窜改:例如银行交易可能会被窜改数据 隐私安全:传统汇款方式会暴露你的个人信息,一旦数据库被别人入侵&#xff0c…

android基础知识复习

架构: 应用框架层(Java API Framework)所提供的主要组件: 名称功能描述Activity Manager(活动管理器)管理各个应用程序生命周期,以及常用的导航回退功能Location Manager(位置管理器…

SpringBoot整合Mybatis-plus实现多级评论

在本文中,我们将介绍如何使用SpringBoot整合Mybatis-plus实现多级评论功能。同时,本文还将提供数据库的设计和详细的后端代码,前端界面使用Vue2。 数据库设计 本文的多级评论功能将采用MySQL数据库实现,下面是数据库的设计&…

Boonz-KeygenMe#1(★★★)

运行程序 错误: 查壳 没有壳,是汇编写的程序 载入OD 前面是在读取输入内容,到这里开始做计算了 分析 首先遍历了用户名,计算结果保存在EBX,在存放到 0x40E0F8 对EBX中的值再次计算,最后结果保存到 …

JavaFX: Java音乐播放读取歌词

JavaFX: Java音乐播放读取歌词 1、lrc歌词文件2、解析lrc歌词2.1 读取每行歌词2.2 解析歌词时间标签Time-tag2.3 解析歌词标识标签ID-tags2.4 创建对象包含歌词相关信息 3、播放显示歌词** 相关文献 JavaFX: Java音乐播放 1、lrc歌词文件 lrc歌词文件的扩展名 1、标准格式&a…

图像处理:Retinex算法

目录 前言 概念介绍 Retinex算法理论 单尺度Retinex(SSR) 多尺度Retinex(MSR) 多尺度自适应增益Retinex(MSRCR) Opencv实现Retinex算法 SSR算法 MCR算法 MSRCR算法 效果展示 总结 参考文章 前…

基频建模方法总结

基频F0建模方法 语音合成领域需要对基频进行建模,具体到文语转换TTS、语音转换VC、情感语音转换EVC领域等。 语音合成F0 包括文语转换,情感语音转换 TTEF:text-to-emotional-features synthesis EVC:emotional voice conversio…

这些你熟知的 app 和服务,都用上了人工智能

从微软在 Microsoft 365 服务中全面整合 GPT-4 能力 ,让 PPT、Word 文档、Excel 表格的制作变成了「一句话的事」,到 Adobe 刚刚发布 Adobe Firefly模型集合,让图形设计、字体风格、视频渲染乃至 3D 建模的门槛显著降低——你我熟知的那些工…

idea的快捷键

一.idea的快捷键: 递进选择&#xff1a;ctrl w复制行&#xff1a;ctrl d删除行&#xff1a;ctrl y大小写切换&#xff1a;ctrl shift u展开/折叠&#xff1a;ctrl shift 减号/加号向前/向后&#xff1a;ctrl <— / —>Live Template(例如 输入psvm会自动打出mai…

华为OD机试题,用 Java 解【最远足迹】问题

华为Od必看系列 华为OD机试 全流程解析+经验分享,题型分享,防作弊指南华为od机试,独家整理 已参加机试人员的实战技巧华为od 2023 | 什么是华为od,od 薪资待遇,od机试题清单华为OD机试真题大全,用 Python 解华为机试题 | 机试宝典使用说明 参加华为od机试,一定要注意不要…

Python实战项目:手势识别控制电脑音量

今天给大家带来一个OpenCV的实战小项目——手势识别控制电脑音量 先上个效果图&#xff1a; 通过大拇指和食指间的开合距离来调节电脑音量&#xff0c;即通过识别大拇指与食指这两个关键点之间的距离来控制电脑音量大小 技术交流 技术要学会分享、交流&#xff0c;不建议闭…

石英晶体振荡器【Multisim】【高频电子线路】

目录 一、实验目的与要求 二、实验仪器 三、实验内容与测试结果 1、观察输出波形&#xff0c;测量振荡频率和输出电压幅度 2、测量静态工作点的变化范围(IEQmin~IEQmax) 3、测量当静态工作点在上述范围时输出频率和输出电压的变化 4、测量负载变化对振荡频率和输出电压幅…

SpringCloud:微服务保护之初识Sentinel

1.初识Sentinel Sentinel是阿里巴巴开源的一款微服务流量控制组件。Sentinel官网 Sentinel具有以下特征&#xff1a; 丰富的应用场景&#xff1a;Sentinel承接了阿里巴巴近 10 年的双十一大促流量的核心场景&#xff0c;例如秒杀&#xff08;即突发流量控制在系统容量可以承受…

JavaEE阶段测试复习

文章全部内容在个人站点内的置顶文章中,访问密码:AIIT 小凯的宝库 模块三、面向对象 继承: a. 单继承:Java只支持单继承,即一个子类只能有一个直接父类。但子类可以间接地继承多个父类。 b. 构造方法与继承:在子类中可以通过super()关键字调用父类的构造方法。如果子类没…

探索深度学习中的计算图:PyTorch的动态图解析

❤️觉得内容不错的话&#xff0c;欢迎点赞收藏加关注&#x1f60a;&#x1f60a;&#x1f60a;&#xff0c;后续会继续输入更多优质内容❤️ &#x1f449;有问题欢迎大家加关注私戳或者评论&#xff08;包括但不限于NLP算法相关&#xff0c;linux学习相关&#xff0c;读研读博…

Windows系统的JDK安装与配置

1 选择JDK版本 以在Windows 64位平台上安装JDK 8版本为例。JDK 8 Windows版官网下载地址&#xff1a;https://www.oracle.com/java/technologies/downloads/#java8-windows 现在下载需要先注册并登录Oracle的账号。 2 安装 双击jdk安装包&#xff0c;进入安装程序页面直接选择…

freetype用法

freetype用法 文章目录 freetype用法0.实现1.变量定义2.lcd操作获取屏幕信息3.freetype初始化4.绘画 1.字形度量2.类1.FT 中的面向对象2.FT_Library 类3.FT_Face 类4 FT_Size 类5 FT_GlyphSlot 类 3.函数1.把一个字符码转换为一个字形索引FT_Get_Char_Index函数2.从 face 中装…

银行家算法--申请资源

银行家算法–申请资源 问题描述&#xff1a; 输入N个进程(N<100)&#xff0c;以及M类资源&#xff08;M<100&#xff09;&#xff0c;初始化各种资源的总数&#xff0c;T0时刻资源的分配情况。例如&#xff1a; 假定系统中有5个进程{P0&#xff0c;P1&#xff0c;P2&…