【一起撸的DL框架】5 实现:自适应线性单元

news2025/1/11 6:03:05
  • CSDN个人主页:清风莫追
  • 欢迎关注本专栏:《一起撸个DL框架》
  • GitHub获取源码:https://github.com/flying-forever/OurDL

文章目录

  • 5 实现:自适应线性单元🍇
    • 1 简介
    • 2 损失函数
      • 2.1 梯度下降法
      • 2.2 补充
    • 3 整理项目结构
    • 4 损失函数的实现
    • 5 修改节点类(Node)
    • 6 自适应线性单元

5 实现:自适应线性单元🍇

1 简介

上一篇:【一起撸个DL框架】4 反向传播求梯度

上一节我们实现了计算图的反向传播,可以求结果节点关于任意节点的梯度。下面我们将使用梯度来更新参数,实现一个简单的自适应线性单元

我们本次拟合的目标函数是一个简单的线性函数: y = 2 x + 1 y=2x+1 y=2x+1,通过随机数生成一些训练数据,将许多组x和对应的结果y值输入模型,但是并不告诉模型具体函数中的系数参数“2”和偏置参数“1”,看看模型能否通过数据“学习”到参数的值。

图1:自适应线性单元的计算图

2 损失函数

2.1 梯度下降法

损失是对模型好坏的评价指标,表示模型输出结果与正确答案(也称为标签)之间的差距。所以损失值越小就说明模型越准确,训练过程的目的便是最小化损失函数的值

自适应线性单元是一个回归任务,我们这里将使用绝对值损失,将模型输出与正确答案之间的差的绝对值作为损失函数的值,即 l o s s = ∣ l − a d d ∣ loss=|l-add| loss=ladd

评价指标有了,可是如何才能达标呢?或者说如何才能降低损失函数的值?计算图中有四个变量: x , w , b , l x,w,b,l x,w,b,l,而我们训练过程的任务是调整参数 w , b w,b w,b的值,以降低损失。因此训练过程中的自变量是w和b,而把x和l看作常量。此时损失函数是关于w和b的二元函数 l o s s = f ( w , b ) loss=f(w,b) loss=f(w,b),我们只需要求函数的梯度 ▽ f ( w , b ) = ( ∂ f ∂ w , ∂ f ∂ b ) \triangledown f(w,b)=(\frac{\partial f}{\partial w},\frac{\partial f}{\partial b}) f(w,b)=(wf,bf),则梯度的反方向就是函数下降最快的方向。沿着梯度的方向更新参数w和b的值,就可以降低损失。这就是经典的优化算法:梯度下降法

2.2 补充

关于损失和优化的概念,大家可能还是有些模糊。上面损失只讲到了一个输入x值对应的模型输出与实际结果之间的差距,但使用整个数据集的平均差距可能更容易理解,就像中学的线性回归

图2所示,改变直线的斜率w,将改变直线与数据点的贴近程度,即改变了损失函数loss的值。

在这里插入图片描述
图2:损失与参数更新示意图

参考: 【深度学习】3-从模型到学习的思路整理_清风莫追的博客-CSDN博客

3 整理项目结构

我们的小项目的代码也渐渐多起来了,好的目录结构将使它更加易于扩展。关于python包结构的知识大家可以自行去了解,大致目录结构如下:

- example
- ourdl
	- core
		- __init__.py
		- node.py
	- ops
		- __init__.py
		- loss.py
		- ops.py
	__init__.py

给这个简单框架的名字叫做OurDL,使用框架搭建的计算图等程序放在example目录下。在ourdl/core/node.py中存放了节点基类和变量类的定义,在ourdl/ops/下存放了运算节点的定义,包括损失函数和加法、乘法节点等。

4 损失函数的实现

/ourdl/ops/loss.py中,

from ..core import Node

class ValueLoss(Node):
    '''损失函数:作差取绝对值'''
    def compute(self):
        self.value = self.parent1.value - self.parent2.value
        self.flag = self.value > 0
        if not self.flag:
            self.value = -self.value
    def get_parent_grad(self, parent):
        a = 1 if self.flag else -1
        b = 1 if parent == self.parent1 else -1
        return a * b

其中compute()方法很显然就是对两个输入作差取绝对值;get_parent_grad()方法求本节点关于父节点的梯度。有绝对值如何求梯度?大家可以画一画绝对值函数的图像。

5 修改节点类(Node)

ourdl/core/node.py

class Node:
    pass  # 省略了一些方法的定义,大家可以查看上一篇文章

    def clear(self):
        '''递归清除父节点的值和梯度信息'''
        self.grad = None
        if self.parent1 is not None:  # 清空非变量节点的值
            self.value = None
        for parent in [self.parent1, self.parent2]:
            if parent is not None:
                parent.clear()
    def update(self, lr=0.001):
        '''根据本节点的梯度,更新本节点的值'''
        self.value -= lr * self.grad  # 减号表示梯度的反方向

我在节点类中新增了两个方法,其中clear()用于清除多余的节点值和梯度信息,因为当节点值或梯度已经存在时会直接返回结果而不会递归去求了(get_grad()forward()的代码)。update()有一个学习率参数lr,更新幅度太大可能导致参数值一直在目标值左右晃悠,无法收敛

6 自适应线性单元

/example/01_esay/自适应线性单元.py

import sys
sys.path.append('../..')
from ourdl.core import Varrible
from ourdl.ops import Mul, Add
from ourdl.ops.loss import ValueLoss

if __name__ == '__main__':
    # 搭建计算图
    x = Varrible()
    w = Varrible()
    mul = Mul(parent1=x, parent2=w)
    b = Varrible()
    add = Add(parent1=mul, parent2=b)
    label = Varrible()
    loss = ValueLoss(parent1=label, parent2=add)
    # 参数初始化
    w.set_value(0)
    b.set_value(0)
    # 生成训练数据
    import random
    data_x = [random.uniform(-10, 10) for i in range(10)]  # 按均匀分布生成[-10, 10]范围内的随机实数
    data_label = [2 * data_x_one + 1 for data_x_one in data_x]
    # 开始训练
    for i in range(len(data_x)):
        x.set_value(data_x[i])
        label.set_value(data_label[i])
        loss.forward()  # 前向传播 --> 求梯度会用到损失函数的值
        w.get_grad()
        b.get_grad()
        w.update(lr=0.05)
        b.update(lr=0.1)
        loss.clear()
        print("w:{:.2f}, b:{:.2f}".format(w.value, b.value))
    print("最终结果:{:.2f}x+{:.2f}".format(w.value, b.value))
    

运行结果:

w:0.13, b:0.10
w:0.36, b:0.20
w:0.58, b:0.10
w:0.74, b:0.00
w:1.13, b:0.10
w:1.43, b:0.20
w:1.62, b:0.30
w:1.94, b:0.20
w:1.50, b:0.30
w:1.87, b:0.40
最终结果:1.87x+0.40

上面自适应线性单元的训练,已经能够大致展现深度学习模型的训练流程:

  • 搭建模型 --> 初始化参数 --> 准备数据 --> 使用数据更新参数的值

我们这里参数只更新了10次,结果就已经大致接近了我们的目标函数 y = 2 x + 1 y=2x+1 y=2x+1。大家可以试试更改学习率lr,训练数据集的大小,观察运行结果会发生怎样的变化。(必备技能:调参)


下节预告:激活函数与计算图的非线性拟合能力

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/478106.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

glibc 系统C文件库

下载时候经常会需要选择glibc glibc是GNU发布的libc库,即c运行库。 glibc是linux系统中最底层的api,几乎其它任何运行库都会依赖于glibc。glibc除了封装linux操作系统所提供的系统服务外,它本身也提供了许多其它一些必要功能服务的实现。 li…

tcpdump使用教程

一、概述 tcpdump是一个功能强大的,用于抓取网络数据包的命令行工具,与带界面的Wireshark一样,基于libpcap库构建。这篇文章主要介绍tcpdump的使用。关于如何使用tcpdump的资料中,最有用的就是tcpdump的两个手册。 tcpdump使用手…

CANOE入门到精通——CANOE系列教程记录2

本系列以初学者角度记录学习CANOE,以《CANoe开发从入门到精通》参考学习,CANoe16 demo版就可以进行学习 创建工程 在一个路径中,创建这几个文件夹 创建工程,将工程命名Vehicle_System_CAN.cfg 创建Database dbc文件 在实际开…

SignOff Criteria——OCV applied and results

文章目录 1. O v e r v i e w Overview Overview1.1 w h a t i s o c v what\ is\ ocv what is ocv?1.2 O C V . E f f e c t o n s i g n o f f OCV.\ Effect\ on\ signoff OCV. Effect on signoff1.3 H o w t o r e m o v e t h e e f f e c t s o f O C V Ho…

【hello Linux】可重入函数、volatile和SIGCHLD信号

目录 1. 可重入函数 2. volatile 3. SIGCHLD信号 Linux!🌷 1. 可重入函数 先来谈一下重入函数的概念:重入函数便是在该函数还没有执行完毕便重复进入该函数(一般发生在多线程中); 可重入函数&#xff1a…

C++程序设计——lambda表达式

一、问题引入 在C98中,如果想对一个数据集合中的元素进行排序,可以使用sort()方法,但如果待排序元素为自定义类型,就需要用户自己定义排序时的比较规则。 随着C语法的发展,人们开始觉得其编写比较复杂,每次…

Word2vec原理+实战学习笔记(一)

来源:投稿 作者:阿克西 编辑:学姐 视频链接:https://ai.deepshare.net/detail/p_5ee62f90022ee_zFpnlHXA/6 文章标题: Efficient Estimation of Word Representations in Vector Space 基于向量空间中词表示的有效估计…

【计算机网络】学习笔记:第四章 网络层(七千字详细配图)【王道考研】

基于本人观看学习b站王道计算机网络课程所做的笔记&#xff0c;不做任何获利 仅进行交流分享 特此鸣谢王道考研 若有侵权请联系&#xff0c;立删 如果本篇笔记帮助到了你&#xff0c;还请点赞 关注 支持一下 ♡>&#x16966;<)!! 主页专栏有更多&#xff0c;如有疑问欢迎…

安装chatglm

地址 下载源代码 下载完成后解压 安装cuda 输入nvcc -V查看是否安装cuda 输入nvidia-smi查看支持的最高版本&#xff0c;最高支持12.1 下载cudahttps://developer.nvidia.com/cuda-downloads 双击安装 同意之后点击下一步 选择精简模式即可 等待下载安装包 …

链接sqlite

一.sqlite库函数 1.sqlite3_open()函数 语法&#xff1a;*sqlite3_open(const char *filename, sqlite3 *ppDb) 作用&#xff1a;该例程打开一个指向 SQLite 数据库文件的连接&#xff0c;返回一个用于其他 SQLite 程序的数据库连接对象。 参数1&#xff1a;如果 filename …

如何在自己的Maven工程上搭建Mybatis框架?

编译软件&#xff1a;IntelliJ IDEA 2019.2.4 x64 操作系统&#xff1a;win10 x64 位 家庭版 Maven版本&#xff1a;apache-maven-3.6.3 Mybatis版本&#xff1a;3.5.6 目录 前言 一. 什么是Mybatis框架&#xff1f;1.1 框架是什么&#xff1f;1.2 什么是MyBatis &#xff1f;1…

3.11 C结构体及结构体数组

结构体的意义 问题&#xff1a;学籍管理需要每个学生的下列数据&#xff1a;学号、姓名、性别、年龄、分数&#xff0c;请用C语言程序存储并处理一组学生的学籍。 思考&#xff1a;如果有多个学生&#xff0c;该怎么定义 已学数据类型无法解决。 结构体概述 正式&#xff1a;…

【Sping学习详解】

重新学习Spring很久了&#xff0c;也看了不少的视频&#xff0c;但是没有系统总结&#xff0c;容易忘记&#xff0c;网上寻找相关博客&#xff0c;也没有找到按照路线总结的&#xff0c;只能说不顺我心&#xff0c;所以自己总结一下&#xff01;&#xff01;&#xff01; 从下…

vulnhub靶机dpwwn1

准备工作 下载连接&#xff1a;https://download.vulnhub.com/dpwwn/dpwwn-01.zip 网络环境&#xff1a;DHCP、NAT 下载完后解压&#xff0c;然后用VMware打开dpwwn-01.vmx文件即可导入虚拟机 信息收集 主机发现 端口发现 继续查看端口服务信息 打开网站发现只有Apache默认…

【Spring篇】IOC/DI注解开发

&#x1f353;系列专栏:Spring系列专栏 &#x1f349;个人主页:个人主页 目录 一、IOC/DI注解开发 1.注解开发定义bean 2.纯注解开发模式 1.思路分析 2.实现步骤 3.注解开发bean作用范围与生命周期管理 1.环境准备 2.Bean的作用范围 3.Bean的生命周期 4.注解开发依赖…

行为识别 Activity Recognition

行为识别 行为检测是一个广泛的研究领域&#xff0c;其应用包括安防监控、健康医疗、娱乐等。 课程大纲 导论 图卷积在行为识别中的应用&#xff1a;论文研读&#xff0c;代码解读&#xff0c;实验 Topdown关键点检测中的hrnet&#xff1a;论文研读&#xff0c;代码解读&a…

ETL工具 - Kettle 流程、应用算子介绍

一、Kettle 流程和应用算子 上篇文章对Kettle 转换算子进行了介绍&#xff0c;本篇文章继续对Kettle 的流程和应用算子进行讲解。 下面是上篇文章的地址&#xff1a; ETL工具 - Kettle 转换算子介绍 流程算子主要用来控制数据流程和数据流向&#xff1a; 应用算子则是Kettle给…

ESP32 ESP-Rainmaker 本地点灯控制Demo测试

基于ESP-Rainmaker 本地点灯控制Demo测试 &#x1f33f;ESP-Rainmaker项目地址&#xff1a;https://github.com/espressif/esp-rainmaker/tree/master ✨这个项目早些时候就已经开始测试了&#xff0c;最后卡在了手机APP连接esp32设备端一直无法连接上&#xff0c;也一直没有找…

性能:Intel Xeon(Ice Lake) Platinum 8369B阿里云CPU处理器

阿里云服务器CPU处理器Intel Xeon(Ice Lake) Platinum 8369B&#xff0c;基频2.7 GHz&#xff0c;全核睿频3.5 GHz&#xff0c;计算性能稳定。目前阿里云第七代云服务器ECS计算型c7、ECS通用型g7、内存型r7等规格均采用该款CPU。 Intel Xeon(Ice Lake) Platinum 8369B Intel …