【人工智能原理自学】梯度下降和反向传播:能改

news2025/1/6 20:04:45

😊你好,我是小航,一个正在变秃、变强的文艺倾年。
🔔笔记来自B站UP主Ele实验室的《小白也能听懂的人工智能原理》。
🔔本文讲解梯度下降和反向传播:能改,一起卷起来叭!

目录

  • 一、“挪”
  • 二、再“挪”
  • 三、梯度下降

一、“挪”

一步到位计算固然是好,但是非常消耗计算资源,抛物线最低点的寻找过程,其实不必一步到位,大可以采用一点点挪动的方式。
在这里插入图片描述
那么问题来了:该怎样挪呢?

机制如你,会想到用斜率判断:
在这里插入图片描述
我们知道最低点的斜率为0,当K > 0则,小球则向左挪动,当K < 0,则向右挪动。
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
那么斜率怎么求呢?机制如你又想到求导

在这里插入图片描述
接下来如何调整斜率,我们希望当W距离最低点比较远的时候,我们其实希望它能快一点,而逐渐接近最低点的时候,稳如老狗。

在这里插入图片描述


我们对上述想法进行代码实现:

豆豆数据集模拟:dataset.py

import numpy as np

def get_beans(counts):
	xs = np.random.rand(counts)
	xs = np.sort(xs)
	ys = [1.2*x+np.random.rand()/10 for x in xs]
	return xs,ys

使用固定步长下降算法:sgd_step.py

import dataset
import matplotlib.pyplot as plt
import numpy as np

# 豆豆数量m
m = 100
xs, ys = dataset.get_beans(m)

# 配置图像
plt.title("Size-Toxicity Function", fontsize=12)
plt.xlabel("Bean Size")
plt.ylabel("Toxicity")
plt.scatter(xs, ys)

w = 0.1
y_pre = w * xs
plt.plot(xs, y_pre)
plt.show()

step = 0.01
for _ in range(100):
    # 抛物线代价函数
    # e = x0^2 * w^2 + (-2x0y0) * w + y0^2
    # 斜率k = 2aw + b(求导)
    k = 2 * np.sum(xs ** 2) * w + np.sum(-2 * xs* ys)
    k = k / m
    if k > 0 :
        w = w - step
    else :
        w = w + step
    y_pre = w * xs
    # 绘制动态
    plt.clf() ## 清空窗口
    plt.scatter(xs, ys)
    plt.xlim(0, 1)
    plt.ylim(0, 1.2)
    plt.plot(xs, y_pre)
    plt.pause(0.01) # 暂停0.01秒

实验结果:
在这里插入图片描述
使用随机梯度下降算法:sgd.py

import dataset
import matplotlib.pyplot as plt
import numpy as np

# 豆豆数量m
m = 100
xs, ys = dataset.get_beans(m)

# 配置图像
plt.title("Size-Toxicity Function", fontsize=12)
plt.xlabel("Bean Size")
plt.ylabel("Toxicity")
plt.scatter(xs, ys)

w = 0.1
y_pre = w * xs
plt.plot(xs, y_pre)
plt.show()

for _ in range(100):
    for i in range(100):
        x = xs[i]
        y = ys[i]
        # 抛物线代价函数
        # e = x0^2 * w^2 + (-2x0y0) * w + y0^2
        # 斜率k = 2aw + b(求导)
        k = 2 * (x**2) * w + (-2 * x * y)
        # alpha为学习率
        alpha = 0.1
        w = w - alpha * k
        # 绘制动态
        plt.clf() ## 清空窗口
        plt.scatter(xs, ys)
        y_pre = w * xs
        plt.xlim(0, 1)
        plt.ylim(0, 1.2)
        plt.plot(xs, y_pre)
        plt.pause(0.01) # 暂停0.01秒

# 重新绘制散点图和预测曲线
# plt.scatter(xs, ys)
# y_pre = w * xs
# plt.plot(xs, y_pre)
# plt.show()

使用批量梯度下降算法:sgd_batch.py

import dataset
import matplotlib.pyplot as plt
import numpy as np

# 豆豆数量m
m = 100
xs, ys = dataset.get_beans(m)

# 配置图像
plt.title("Size-Toxicity Function", fontsize=12)
plt.xlabel("Bean Size")
plt.ylabel("Toxicity")
plt.scatter(xs, ys)

w = 0.1
y_pre = w * xs
plt.plot(xs, y_pre)
plt.show()

alpha = 0.01
for _ in range(100):
    # 抛物线代价函数
    # e = x0^2 * w^2 + (-2x0y0) * w + y0^2
    # 斜率k = 2aw + b(求导)
    k = 2 * np.sum(xs ** 2) * w + np.sum(-2 * xs* ys)
    k = k / m
    w = w - alpha * k
    y_pre = w * xs
    # 绘制动态
    plt.clf() ## 清空窗口
    plt.scatter(xs, ys)
    plt.xlim(0, 1)
    plt.ylim(0, 1.2)
    plt.plot(xs, y_pre)
    plt.pause(0.01) # 暂停0.01秒

二、再“挪”

而事实上,一个直线完整的函数 应该是:y = wx + b(截距)
在这里插入图片描述
我们得到的代价函数是这样的:
在这里插入图片描述)
现在我们给b留出一个维度:
在这里插入图片描述
我们会发现得到的是一个“碗”状的曲面:
在这里插入图片描述)

曲面的最低点是:
在这里插入图片描述)
现在我们的目标很明确了,如何求出该最低点(Wmin,Bmin)

我们在b=0处沿着W的方向切上一刀,得到一个开口向上的抛物线,很容易得到最低点:
在这里插入图片描述)
但是我们发现此刻曲线的最低点并不是曲面的最低点
在这里插入图片描述)
那我们接着沿着b的方向切上一刀,得到一个开口向上的抛物线:
在这里插入图片描述)

我们把两个方向上的运动合为一个方向,这样我们完成了一次调整:

在这里插入图片描述)
在这里插入图片描述)
在这里插入图片描述)
我们分别对e函数有关w、b求偏导:
在这里插入图片描述)
在这里插入图片描述)
这样我们的整个过程可以总结为:

在这里插入图片描述)


我们对上述过程代码实现:

豆豆数据集模拟:dataset.py

import numpy as np

def get_beans(counts):
	xs = np.random.rand(counts)
	xs = np.sort(xs)
	ys = np.array([(0.7*x+(0.5-np.random.rand())/5+0.5) for x in xs])
	return xs,ys

豆豆毒性分布如下:
在这里插入图片描述)
代价函数(W为自变量):cost_function_w.py

import dataset
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 从数据中获取随机豆豆
m = 100
xs, ys = dataset.get_beans(m)

# 配置图像
plt.title("Size-Toxicity Function", fontsize=12)
plt.xlabel('Bean Size')
plt.ylabel('Toxicity')

# 豆豆毒性散点图
plt.scatter(xs, ys)

# 预测函数
w = 0.1
b = 0.1
y_pre = w * xs + b

# 预测函数图像
plt.plot(xs, y_pre)

# 显示图像
plt.show()

# 代价函数
ws = np.arange(-1, 2, 0.1)
bs = np.arange(-2, 2, 0.1)

# 配置3D图像显示插件
fig = plt.figure()
ax = Axes3D(fig)
ax.set_zlim(0, 2)

for b in bs:  # 每次取不同的w
    es = []
    for w in ws:
        y_pre = w * xs + b
        # 得到w和b的关系
        e = (1 / m) * np.sum((ys - y_pre) ** 2)
        es.append(e)
    # plt.plot(ws,es)
    ax.plot(ws, es, b, zdir='y')

# 显示图像
plt.show()

实验结果:
在这里插入图片描述)
代价函数(B为自变量):cost_function_b.py

import matplotlib.pyplot as plt
import dataset
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

#从数据中获取随机豆豆
m=100
xs,ys = dataset.get_beans(m)


#配置图像
plt.title("Size-Toxicity Function", fontsize=12)
plt.xlabel('Bean Size')
plt.ylabel('Toxicity')

# 豆豆毒性散点图
plt.scatter(xs, ys) 

#预测函数
w=0.1
b=0.1
y_pre = w*xs+b

#预测函数图像
plt.plot(xs,y_pre) 

#显示图像
plt.show()  




#代价函数
ws = np.arange(-1,2,0.1)
bs = np.arange(-2,2,0.1)


fig = plt.figure()
ax = Axes3D(fig)

ax.set_zlim(0,2)

for w in ws:#每次取不同的w
	es = []	
	for b in bs:
		y_pre = w*xs+b
		#得到w和b的关系
		e = (1/m)*np.sum((ys-y_pre)**2)
		es.append(e)
	#plt.plot(ws,es) 
	figure = ax.plot(bs, es, w, zdir='y')


#显示图像
plt.show()  

实验结果:
在这里插入图片描述)
当然我们也可以使用plot_surface函数绘制曲面绘制曲面版的:cost_function_surface.py

import matplotlib.pyplot as plt
import dataset
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

#从数据中获取随机豆豆
m=100
xs,ys = dataset.get_beans(m)




#配置图像
plt.title("Size-Toxicity Function", fontsize=12)
plt.xlabel('Bean Size')
plt.ylabel('Toxicity')

# 豆豆毒性散点图
plt.scatter(xs, ys) 

#预测函数
w=0.1
b=0.1
y_pre = w*xs+b

#预测函数图像
plt.plot(xs,y_pre) 

#显示图像
plt.show()  



#代价函数
ws = np.arange(-1,2,0.1)
bs = np.arange(-2,2,0.1)

#把ws和bs变成一个网格矩阵
#这个网格矩阵的含义可以参考这篇文章:
#https://blog.csdn.net/lllxxq141592654/article/details/81532855
ws,bs = np.meshgrid(ws,bs)
print(ws)#打印出来瞅瞅
print(bs)


es = 0
#因为ws和bs已经变成了网格矩阵了
#一次性带入全部计算,我们需要一个一个的算
for i in range(m):
	y_pre = ws*xs[i]+bs#取出一个样本在网格矩阵上计算,得到一个预测矩阵
	e = (ys[i]-y_pre)**2#标准值减去预测(矩阵)得到方差矩阵
	es += e#把单样本上的方差矩阵不断累加到es上
es = es/m#求平均值,这样es方差矩阵每个点的位置就是对应的ws和bs矩阵每个点位置预测得到的方差



fig = plt.figure()
ax = Axes3D(fig)

ax.set_zlim(0,2)

#plot_surface函数绘制曲面
#cmap='rainbow表示彩虹图(用不同的颜色表示不同值)
ax.plot_surface(ws, bs, es, cmap='rainbow')

#显示图像
plt.show()  

实验结果:
在这里插入图片描述)

使用随机梯度下降算法:sgd_w_b.py

import dataset
import matplotlib.pyplot as plt

# 豆豆数量m
m = 100
xs, ys = dataset.get_beans(m)

# 配置图像
plt.title("Size-Toxicity Function", fontsize=12)
plt.xlabel("Bean Size")
plt.ylabel("Toxicity")
plt.scatter(xs, ys)

w = 0.1
b = 0.1
y_pre = w * xs + b
plt.plot(xs, y_pre)
plt.show()

# alpha为学习率
alpha = 0.01
# 训练500次
for _ in range(500):
    for i in range(100):
        x = xs[i]
        y = ys[i]
        # 抛物线代价函数
        # 斜率k(求导)
        dw = 2 * (x ** 2) * w + 2 * x * b - 2 * x * y
        db = 2 * b + 2 * x * w - 2 * y

        w = w - alpha * dw
        b = b - alpha * db
    # 训练一次后刷新
    # 绘制动态
    plt.clf()  ## 清空窗口
    plt.scatter(xs, ys)
    y_pre = w * xs + b
    plt.xlim(0, 1)
    plt.ylim(0, 1.2)
    plt.plot(xs, y_pre)
    plt.pause(0.01)  # 暂停0.01秒

实验结果:
在这里插入图片描述)

对于单个样本的曲面实际上是一个U型的特殊碗

三、梯度下降

训练神经网络的时候,基本就是三个步骤:

在这里插入图片描述)

  1. 正向计算网络输出;
  2. 计算Loss;
  3. 反向传播,计算Loss的梯度来更新参数(即梯度下降)。

在这里插入图片描述)

小的训练集上训练的时候,通常每次对所有样本计算Loss之后通过梯度下降的方式更新参数(批量梯度下降),但是在大的训练集时,这样每次计算所有样本的Loss再计算一次梯度更新参数的方式效率是很低的。

因此梯度下降常常分为:随机梯度下降、mini-batch梯度下降以及batch梯度下降。


随机梯度下降(Stochastic Gradient Descent):

在这里插入图片描述
随机梯度下降每次迭代(iteration)计算单个样本的损失并进行梯度下降更新参数,这样在每轮epoch就能进行 m 次参数更新

优点:

  • 参数更新速度大大加快,因为计算完每个样本的Loss都会进行一次参数更新

缺点:

  • 计算量大且无法并行。批量梯度下降能够利用矩阵运算和并行计算来计算Loss,但是SGD每遍历到一个样本就进行梯度计算和参数下降,无法进行有效的并行计算。

  • 容易陷入局部最优导致模型准确率下降。因为单个样本的Loss无法代替全局Loss,这样计算出来的梯度方向也会和全局最优的方向存在偏离。但是由于样本数量多,总体的Loss会保持降低,只不过Loss的变化曲线会存在较大的波动。像下图这样:

    在这里插入图片描述


批量梯度下降(Batch Gradient Descent):

批量梯度下降就是每个epoch计算所有样本的Loss,进而计算梯度进行反向传播、参数更新:

在这里插入图片描述
m 为训练集样本数,l 为损失函数,ϵ 表示学习率

优点:

  • 每个epoch通过所有样本来计算Loss,这样计算出的Loss更能表示当前分类器在于整个训练集的表现,得到的梯度的方向也更能代表全局极小值点的方向。如果损失函数为凸函数,那么这种方式一定可以找到全局最优解。

缺点:

  • 每次都需要用所有样本来计算Loss,在样本数量非常大的时候即使也只能有限的并行计算,并且在每个epoch计算所有样本Loss后只更新一次参数,即只进行一次梯度下降操作,效率非常低。

小批量梯度下降(min-Batch Gradient Descent):

在这里插入图片描述
小批量梯度下降将所有的训练样本划分到 batches 个min-batch中,每个mini-batch包含 batchsize 个训练样本。每个iteration计算一个mini-batch中的样本的Loss,进而进梯度下降和参数更新,这样兼顾了批量梯度下降的准确度和随机梯度下降的更新效率。

  • 当 batch_size=m 时,小批量梯度下降就变成了批量梯度下降;
  • 当 batch_size=1 ,就退化为了SGD。

一般来说 batch_size 取2的整数次方的值。不得不说,“折中调和”真是经久不衰的智慧。

事实上,我们平时用梯度下降的时候说的最多的SGD指的是小批量梯度下降,各种论文里所说的SGD也大都指的mini-batch梯度下降这种方式。tensorflow中也是通过定义batch_size的方式在优化过程中使用小批量梯度下降的方式(当然,也取决于batch_size的设置)


相关代码仓库链接,欢迎Star:传送门

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/146633.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Rollup Decentralization

1. 引言 当前的Rollup为中心化的&#xff0c;这并不是必须的&#xff0c;而是当前的选择。 2. 何为Rollup&#xff1f; Rollup与其它L2或侧链的主要区别在于&#xff1a; L1数据可用性 Rollup&#xff1a;只要L1的数据可用性存在&#xff0c;用户可重构L2状态&#xff0c;然…

SQL LIKE 操作符

LIKE 操作符用于在 WHERE 子句中搜索列中的指定模式。 SQL LIKE 操作符 LIKE 操作符用于在 WHERE 子句中搜索列中的指定模式。 SQL LIKE 语法 SELECT column1, column2, ... FROM table_name WHERE column LIKE pattern; 参数说明&#xff1a; column1, column2, ...&…

php://filter伪协议(总结)

文章目录php://filter伪协议总结php://filter伪协议介绍php://filter伪协议使用方法php://filter过滤器分类filter字符串过滤器string.rot13string.toupperstring.tolowerstring.strip_tagsfilter转换过滤器convert.base64-encodeconvert.base64-decodeconvert.quoted-printabl…

【Kotlin】空安全 ① ( Kotlin 的空安全机制 | 变量可空性 | 默认变量不可赋空值 | 声明可空类型变量 )

文章目录一、Kotlin 的空安全机制二、变量可空性1、默认变量不可赋空值2、声明可空类型变量一、Kotlin 的空安全机制 Java 中的空指针问题 : 在 Java 语言 编写的程序中 , 出现最多的崩溃就是 NullPointerException 空指针异常 , 该异常是 运行时 才爆出的 , 在 代码编写时 以…

冰冰学习笔记:C++11的新特性

欢迎各位大佬光临本文章&#xff01;&#xff01;&#xff01; 还请各位大佬提出宝贵的意见&#xff0c;如发现文章错误请联系冰冰&#xff0c;冰冰一定会虚心接受&#xff0c;及时改正。 本系列文章为冰冰学习编程的学习笔记&#xff0c;如果对您也有帮助&#xff0c;还请各位…

k8s入门教程

文章导读 kubernetes&#xff0c;是一个全新的基于容器技术的分布式架构领先方案&#xff0c;是谷歌严格保密十几年的秘密武器----Borg系统的一个开源版本&#xff0c;于2014年9月发布第一个版本&#xff0c;2015年7月发布第一个正式版本。 kubernetes的本质是一组服务器集群&…

创建自己的docker镜像

dockerfile案例1导入dockerfiel以及java文件导入后第一步docker build -t javaweb:1.0 .docker build -t &#xff08;名称以及对应的版本&#xff09;javaweb:1.0 .&#xff08;空格之后的一个点表示从当前目录开始&#xff09;导入成功之后运行容器即可docker run --name web…

二十七、linux系统详解

一、Linux基础篇 1. Linux目录结构 ⑴ 基本介绍: linux的文件系统是采用级层式的树状目录结构&#xff0c;在此结构中的最上层是根目录“/”&#xff0c;然后在此目录下再创建其他的目录。 深刻理解linux树状文件目录是非常重要的&#xff0c;这里我给大家说明一下。 记住一…

Markdown生成目录结构的方法

参考文章&#xff1a;https://www.cnblogs.com/abc-x/p/13470575.htmlmarkdown生成目录结构的方法&#xff1a;示例&#xff1a;project│ README.md│ file001.txt │└───folder1│ │ file011.txt│ │ file012.txt│ ││ └───subfolder1│ │ file111.txt│ │ fil…

【Linux多线程编程】5. 线程锁(2)——死锁、读写锁

前言 上篇文章【Linux多线程编程】4. 线程锁&#xff08;1&#xff09;——互斥锁 我们介绍了线程同步的其中一种方式——互斥锁&#xff0c;互斥锁也可以理解为独占锁&#xff0c;只要有一个线程拿到该锁&#xff0c;其他的线程想要获取只能阻塞等待。但互斥锁的使用不当也可…

【云原生】Ceph 在 k8s中应用

文章目录一、概述二、Ceph Rook 介绍三、通过Rook在k8s中部署Ceph1&#xff09;下载部署包2&#xff09;部署 Rook Operator3&#xff09;创建 Rook Ceph 集群4&#xff09;部署Rook Ceph 工具5&#xff09;部署Ceph Dashboard6&#xff09;检查6&#xff09;通过ceph-tool工具…

VirusTotal智能搜索itw查找从github下载的恶意Android样本

1. Introduction ITW是in the wild的缩写&#xff0c;VirusTotal提供了itw这个搜索关键词&#xff0c;可以搜到从某个url&#xff08;部分url&#xff09;上下载到的样本。 作者写过的其他VirusTotal智能搜索用法的文章见参考1和2. 2. itw使用 比如为了查找从github下载的恶…

Day852.Thread-Per-Message模式 -Java 性能调优实战

Thread-Per-Message模式 Hi&#xff0c;我是阿昌&#xff0c;今天学习记录的是关于Thread-Per-Message模式的内容。 Thread-Per-Message 模式&#xff0c;简言之就是为每个任务分配一个独立的线程。 并发编程领域的问题总结为三个核心问题&#xff1a; 分工同步互斥 其中&…

client-go源码学习(三):Indexer、SharedInformer

本文基于Kubernetes v1.22.4版本进行源码学习&#xff0c;对应的client-go版本为v0.22.4 3、Informer机制 4&#xff09;、Indexer Indexer中有Informer维护的指定资源对象的相对于etcd数据的一份本地缓存&#xff0c;可通过该缓存获取资源对象&#xff0c;以减少对Kubernete…

计算Java对象大小(附实际例子分析)

对象大小如何计算 对象大小包括俩部分的内容&#xff0c;对象头和对象内容。&#xff08;图片源于网络&#xff09; 对象头 此处假设是64位的JVM 对象地址&#xff0c;占4个字节。对象标记&#xff0c;占8个字节&#xff0c;包括锁标记&#xff0c;hashcode, age 等。数组…

python 如何使用 pandas 在 flask web 网页中分页显示 csv 文件数据

目录 一、实战场景 二、知识点 python 基础语法 python 文件读写 python 分页 pandas 数据处理 flask web 框架 jinja 模版 三、菜鸟实战 初始化 Flask 框架&#xff0c;设置路由 jinja 模版 渲染列表数据 分页请求数据 显示详情页数据示例 运行结果 运行截图 …

[oeasy]python0040_换行与回车的不同_通用换行符_universal_newlines

换行回车 回忆上次内容 区分概念 terminal终端 主机网络中 最终的 端点 TeleTYpewriter 电传打印机终端硬件 shell 终端硬件基础上的 软件壳子 Console 控制台 主机旁边 的 控制面板 存储文件 的 时候 我 在文件里 打了回车\n系统 将0x0a存入字节 进文件换行 自动就有 回车…

航空客运订票系统(C语言,软件用的DEV)

这两天整理之前的作业代码&#xff0c;把自己一点一点敲出来的系统又看了一下&#xff0c;挑几个发出来供大家参考。想要源码、报告可以找我啦&#xff0c;代码的注释之前写的都是非常详细的&#xff01; 但是不是无偿的啦&#xff08;不坑&#xff0c;一杯奶茶喽&#xff0c;不…

Java逃逸分析(附实际例子分析)

Java逃逸分析 1. 什么是Java逃逸分析 我们知道对象一般是在堆上生成的&#xff0c;但这并不是绝对的。特例就是今天要说的逃逸分析。 JVM 在分析代码以后&#xff0c;发现一个对象在声明之后&#xff0c;只有在它当前声明的这个函数中调用&#xff0c;那么它就会将这个对象在…

《微SaaS创富周刊》第3期:GPT-3\ChatGPT、Stable Diffusion等AI模型驱动的微SaaS创意盘点

大家新年好&#xff01;第3期《微SaaS创富周刊》问世啦&#xff01;本周刊面向独立开发者、早期创业团队&#xff0c;报道他们主要的产品形态——微SaaS如何变现的最新资讯和经验分享等。所谓微SaaS&#xff0c;就是“针对利基市场的SaaS”&#xff0c;特点是一般由个人或者小团…