【Python机器学习】梯度下降法的讲解和求解方程、线性回归实战(Tensorflow、MindSpore平台 附源码)

news2024/11/25 6:34:02

需要全部源码请点赞关注收藏后评论区留言私信~~~

基本思想

迭代关系式是迭代法应用时的关键问题,而梯度下降(Gradient Descent)法正是用梯度来建立迭代关系式的迭代法。 机器学习模型的求解一般可以表示为:

其中,f(x)为机器学习模型的损失函数。

也称为无约束最优化模型。

对于无约束最优化问题argmin┬xf(x),其梯度下降法求解的迭代关系式为:

 

式中,x为多维向量,记为x=(x^(1),x^(2),…,x^(n));α为正实数,称为步长,也称为学习率;df(x)/dx=(■8(■8(∂f(x)/∂x^(1)&∂f(x)/∂x^(2))&■8(⋯&∂f(x)/∂x^(n))))是f(x)的梯度函数。

 梯度下降法的几个问题:

1)梯度下降法的结束条件,一般采用:①迭代次数达到了最大设定;②损失函数降低幅度低于设定的阈

2)关于步长α,过大时,初期下降的速度很快,但有可能越过最低点,如果“洼地”够大,会再折回并反复振荡。如果步长过小,则收敛的速度会很慢。因此,可以采取先大后小的策略调整步长,具体大小的调节可根据f(x)降低的幅度或者x前进的幅度进行。

3)关于特征归一化问题,梯度下降法应用于机器学习模型求解时,对特征的取值范围也是敏感的,当不同的特征值取值范围不一样时,相同的步长会导致尺度小的特征前进比较慢,从而走之字型路线,影响迭代的速度,甚至不收敛。

梯度下降法解方程

梯度下降法求解方程示例:为了迭代到取值为0的点,可采取对原函数取绝对值或者求平方作为损失函数。

在MindSpore中,通过mindspore.ops.GradOperation提供对任意函数式自动求导的支持。

在TensorFlow2中,通过GradientTape提供对自动微分的支持,它记录了求微分的过程,为后续自动计算导数奠定了基础。

 部分代码如下

### 求方程的根
class loss_func(ms.nn.Cell): # 用方程的平方作为求导目标函数
    def __init__(self):
        super(loss_func, self).__init__()
        self.mspow = ms.ops.Pow()       
    def construct(self, x):
        y = self.mspow(x, 3.0) + self.mspow(math.e, x)/2.0 + 5.0*x - 6
        y = self.mspow(y, 2) # 方程的输出的平方
        return y
x = ms.Tensor([0.0], dtype=ms.float32)
for i in range(200): # 200次迭代
    grad = GradNetWrtX(loss_func())(x)
    #print(grad)
    x = x - 2.0 * alpha * grad # 步长加大一倍
    print(str(i)+":"+str(x))

import tensorflow as tf
x = tf.constant(1.0)
with tf.GradientTape() as g:
    g.watch(x)
    y = x**3 + (math.e**x)/2.0 + 5.0*x - 6
dy_dx = g.gradient(y, x)
print(dy_dx)
>>> tf.Tensor(9.35914, shape=(), dtype=float32)
x = tf.constant(0.0)
for i in range(200):
    with tf.GradientTape() as g:
        g.watch(x)
        loss = tf.pow(f(x), 2)
    grad = g.gradient(loss, x)
    x = x – 2.0 * alpha * grad
    print(str(i)+":"+str(x))

 梯度下降法解线性回归问题

线性回归问题中m个样本的损失函数表示为:

回归系数的更新过程如下:

 

 

50000次迭代后效果如下

 

 部分代码如下

alpha = 0.00025
class loss_func2(ms.nn.Cell):
    def __init__(self):
        super(loss_func2, self).__init__()
        self.transpose = ms.ops.Transpose()
        self.matmul = ms.ops.MatMul()
        
    def construct(self, W, X, y):
        k = y - self.matmul(X, W)
        return self.matmul(self.transpose(k, (1,0)), k) / 2.0

class GradNetWrtW(ms.nn.Cell):
    def __init__(self, net):
        super(GradNetWrtW, self).__init__()
        self.net = net
        self.grad_op = ms.ops.GradOperation()
        
    def construct(self, W, X, y):
        gradient_func = self.grad_op(self.net)
        return gradient_func(W, X, y)
    
X = ms.Tensor((np.mat([[1,1,1,1,1,1], temperatures])).T, dtype=ms.float32)
y = ms.Tensor((np.mat(flowers)).T, dtype=ms.float32)
W = ms.Tensor([[0.0],[0.0]], dtype=ms.float32)
for i in range(50000):
    grad = GradNetWrtW(loss_func2())(W, X, y)
    #print(grad)
    W = W - alpha * grad
    print(i,'--->', '\tW:', W)

alpha = 0.00025
X = tf.constant( (np.mat([[1,1,1,1,1,1], temperatures])).T, shape=[6, 2], dtype=tf.float32)
y = tf.constant( (np.mat(flowers)), shape=[6, 1], dtype=tf.float32)

def linear_mode(X, W):
    return tf.matmul(X, W)

W = tf.ones([2,1], dtype=tf.float32)

for i in range(50000):
    with tf.GradientTape() as g:
        g.watch(W)
        loss = tf.reduce_sum( tf.pow(linear_mode(X, W) - y, 2) ) /2.0
    grad = g.gradient(loss, W)
    #print(grad)
    W = W - alpha * grad
    print(i,'--->', '\tW:', W)#, '\t\tloss:', loss)

随机梯度下降和批梯度下降

从梯度下降算法的处理过程,可知梯度下降法在每次计算梯度时,都涉及全部样本。在样本数量特别大时,算法的效率会很低。

随机梯度下降法(Stochastic Gradient Descent,SGD),试图改正这个问题,它不是通过计算全部样本来得到梯度,而是随机选择一个样本来计算梯度。随机梯度下降法不需要计算大量的数据,所以速度快,但得到的并不是真正的梯度,可能会造成不收敛的问题。

批梯度下降法(Batch Gradient Descent,BGD)是一个折衷方法,每次在计算梯度时,选择小批量样本进行计算,既考虑了效率问题,又考虑了收敛问题。

创作不易 觉得有帮助请点赞关注收藏~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/101053.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LSTM返向传播代码实现——LSTM从零实现系列(4)

一、前言 这个LSTM系列是在学习时间序列预测过程中的一些学习笔记,包含理论分析和源码实现两部分。本质属于进阶内容,因此神经网络的基础内容不做过多讲解,想学习基础,可看之前的神经网络入门系列文章: https://blog.…

IntelliJ IDEA建立SSM论文基本增删改查管理系统

IntelliJ IDEA建立SSM论文基本增删改查管理系统 1、启动IntelliJ IDEA程序 2、点击File----->New ------>Project建立项目 3、在弹出的对话框中,左边点击”maven”建立maven项目,右边的选择框不要选择,选择maven-archetype-webapp不能…

GAN2 ~

这也是第二季了 近年来,基于生成对抗式网络(Generative Adversarial Network, GAN)的图片生成研究工作取得了显著的进展。除了能够生成高分辨率、逼真的图片之外,许多创新应用也应运而生,诸如图片个性化编辑、图片动画…

C++ STL算法(一)利用STL算法解决很常见的一些子问题

文章目录next_permutationlower_bound 与 upper_boundpartial_sumsort 与 uniquenext_permutation cplusplus: next_permutation 作用:得到所有的全排列 例题: P1706 全排列问题 void test1() {int n;cin >> n;int* arr new int[n…

Oracle基础版

这是上上周的事情,我们甲方强烈要求使用oracle数据库,也就上学的时候玩过Oracle也忘得差不多了,所以一直不想弄,我们开会产品说要提测了,我还没弄,这不得哐哐开始干活,过程吧还算顺利&#xff0…

Java学习之第八章练习题-1

目录 第一题 题目 我的代码 Person类 错误 正确写法 输出结果 附加要求 代码 结果 第二题 题目 答案 第三题 题目 代码 总结不足 创建对象并运行 第四题 题目 运行结果​编辑 第五题 题目 第六题 题目 第一题 题目 我的代码 Person类 package com.hspedu…

DBCO-NHS 1353016-71-3,二苯基环辛炔-活性酯 可用于以高特异性和反应性标记叠氮化物修饰的生物分子

名称 DBCO-NHS ester 中文名称 二苯基环辛炔-活性酯 英文名称 DBCO-NHS NHS-DBCO 分子量 402.40 CAS 1353016-71-3 溶剂 溶于DMSO, DMF, DCM, THF, Chloroform 存储条件 -20冷冻保存 保存时间 一年 结构式 DBCO(二苯并环辛炔)是一种环炔烃&…

怎么将视频转为音频mp3格式?这些转换方法一分钟就能学会

随着现在娱乐方式的多样化,我们可以在闲暇时间做一些令人放松的事情。对于我来说,就很喜欢一边听歌一边发呆。我之前喜欢的一位歌手,他的翻唱歌曲以及原创音乐都得到了网友很高的评价,但是有些歌曲在平台上没有音源,我…

【内网安全-CS】Cobalt Strike启动运行上线方法

目录 一、启动运行 1、第一步:进入cs目录 2、第二步:查看本机ip 3、第三步:启动"团队服务器" 4、第四步:客户端连接 二、上线方法 1、第一步:生成监听器 2、第二步:生成木马 3、第三步&…

如何将智能设备关联至云开发中的项目?

将应用中已经连接的设备关联至云项目后,就可以在 涂鸦 IoT 开发平台 通过云开发主动管理和控制对应的设备。云开发提供多种应用中的设备关联方式: 关联自有 App 账号关联自有小程序关联涂鸦 App 账号关联 SaaS 方式一:关联自有 App 大家可以…

深度学习入门(六十)循环神经网络——门控循环单元GRU

深度学习入门(六十)循环神经网络——门控循环单元GRU前言循环神经网络——门控循环单元GRU课件关注一个序列门候选隐状态隐状态总结教材1 门控隐状态1.1 重置门和更新门1.2 候选隐状态1.4 隐状态2 从零开始实现2.1 初始化模型参数2.2 定义模型2.3 训练与…

前端本地存储数据库 IndexedDB 存储文件

介绍 IndexedDB 是一种底层 API,用于在客户端存储大量的结构化数据。目前各浏览器都已支持,兼容性很好。 特点 IndexedDB 是一个基于 JavaScript 的面向对象数据库,IndexedDB 允许您存储和检索用键索引的对象;可以存储结构化克隆…

MySQL8.0基础篇

文章目录一、MySQL概述1、数据库概述1.1 数据库作用1.2 数据库的相关概念2、MySQL概述2.1 概述2.2 RDBMS与非RDBMS3、MySQL环境安装3.1 MySQL的下载、安装、配置(win)3.2 MySQL登录3.3 MySQL演示使用3.4 MySQL目录结构与源码二、SQL查询1、SQL详情1.1 SQL分类1.2 SQL语言的规则…

Docker和docker-compose中部署nginx-rtmp实现流媒体服务与oob和ffmpeg推流测试

场景 Windows上搭建Nginx RTMP服务器并使用FFmpeg实现本地视频推流: Windows上搭建Nginx RTMP服务器并使用FFmpeg实现本地视频推流_霸道流氓气质的博客-CSDN博客_nginx-rtmp-win64 上面讲的是在windows中搭建nginx-rtmp,如果实在centos中使用docker或…

使用Git拉取和推送到仓库

使用Git拉取和推送到仓库 0、前置工作 首先安装和配置git ,参考: git安装教程_嘴巴嘟嘟的博客-CSDN博客_全局安装gitGit上传文件代码到GitHub(超详细)_蓝布棉的博客-CSDN博客_git上传文件到github仓库 没有仓库的情况 创建仓…

项目总结篇

注意会话管理:cookie,session的作用;(Redis等) 过滤敏感词(相关算法),事务(Spring怎么管理) Redis的数据结构适合那种情况 kafka:框架背后通用的原则,模式,生…

jsp+ssm计算机毕业设计房屋租赁管理系统【附源码】

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: JSPSSM mybatis Maven等等组成,B/S模式 Mave…

大数据 集群测试部分

查看HDFS集群状态 在浏览器里访问http://master:9870 不能通过主机名master加端口9870的方式,原因在于没有在hosts文件里IP与主机名的映射,现在只能通过IP地址加端口号的方式访问:http://192.168.1.101:9870 修改宿主机的C:\Windows\System…

2023年大学毕业生,我有话想对你说

虽然每年都说大学毕业生有多少多少,就业难,但貌似以往的经济寒冬,互联网寒冬都不如2022年2023年这么寒冷。 可以说,2022年一整年都是在裁员的声音中度过的,有的公司逐渐取消年终奖,原本熙熙攘攘的办公室&am…

看看欧洲国际学校的IB分数排名

大家好,今天为大家整理了欧洲的国际学校IB分数排名,信息搬运自IB分数网站。如果有偏差还请好心人出来指正。 可以看到,整个榜单瑞士的国际学校数量最多。确实,其实大部分国家的一线国际学校都是集齐在首都城市。 而瑞士的国际学校…