P11 PyTorch Momentum

news2024/11/9 4:51:37

参考:

关于梯度下降与Momentum通俗易懂的解释_ssswill的博客-CSDN博客_梯度 momentum

前言:

      P9讲梯度的时候,讲到过这种算法的梯度更新方法

这边重点讲解一下原理

      Momentum算法又叫做冲量算法,其迭代更新公式如下:

     v_t= \beta v_{t-1}+(1-\beta)\bigtriangledown w

     w_{t}=w_{t}-\alpha v_{t}

实验表明,相比于标准梯度下降算法,Momentum算法具有更快的收敛速度


目录:

    1:   标准的梯度下降问题(w维度为1)

    2:   标准的梯度下降问题(W维度为2)

    3:   Momentum


一    标准的梯度下降问题(w维度为1)

           w_t=w_t-\alpha \bigtriangledown w

     

# -*- coding: utf-8 -*-
"""
Created on Tue Jan  3 21:48:15 2023

@author: cxf
"""

import numpy as np
import matplotlib.pyplot as plt


#计算梯度
def gradient(x):
    return 2*x


'''
a: 搜索的起始点
step: 步伐
epoch: 迭代次数
'''
def gradient_descent(a, step, epoch):
    
     x = a
     for i in range(epoch):
         
         grad = gradient(x)
         x = x-step*grad
         
         print('epoch:{},x= {},gradient={}'.format(i,round(x,3),round(grad,3)))
         
         if abs(grad)<1e-6 :
             return x
     return x

if __name__ == "__main__":
    x = np.linspace(-5,5,100)
    y = x**2 #损失函数
    
    plt.plot(x,y)
    gradient_descent(4, 1.0,20)
    

     1.1  学习率过小, 却使得学习过程过于缓慢

          gradient_descent(4, 0.1,20)

         

    1.2   学习率过大,无法收敛,发散

     gradient_descent(4, 1.0,20)

    


二  标准的梯度下降问题(W维度为2)

 假设权重系数为

             w=[w_0,w_1]

  假设损失函数

              l=w_0^2+50*w_1^2

  则  

             \triangledown w= [2*w_0,100*w_1]

   SGD 更新过程

            w=w-\alpha \bigtriangledown w

    问题:

         红线是标准梯度下降法,可以看到收敛过程中产生了一些震荡。这些震荡在纵轴方向上(w_1)是均匀的,几乎可以相互抵消,也就是说如果直接沿着横轴方向迭代,收敛速度可以加快

    ​​​​​​​

 对应的代码

# -*- coding: utf-8 -*-
"""
Created on Wed Jan  4 20:41:03 2023

@author: cxf
"""

import numpy as np
import pylab as plt



'''
显示等高线
'''
def  get_contour():
      w0 = np.linspace(-10, 10,100)
      w1 = np.linspace(-10, 10,100)
      X,Y = np.meshgrid(w0,w1)
      Z = X**2+Y**2
      
      return X,Y,Z
      
      
      
'''
计算梯度
'''
def gradient(w):
    
    w = np.array(w)
    w0 = w[0]
    w1 = w[1]
    grad = np.array([2*w0,100*w1])
    return grad



'''
w: 相当于权重系数
step: 步伐
epoch: 迭代次数
'''
def SGD(w,step,epoch):
    
    w = np.array(w,dtype='float64')
    w_list=[]
    
    for i in range(epoch):

        t = w.copy()
        w_list.append(t)
        grad = gradient(w)
        
        
        print('epoch: %d  权重系数x: [%5.3f   %5.3f]  梯度 [%5.3f  %5.3f ] '%(i,w[0],w[1],grad[0],grad[1]))
        w= w- step*grad
        if sum(abs(grad))<=1e-6:
            return w
    return w,w_list

'''
'''
def momentum(w0, step,mu, epoch):

     w = np.array(w0)
     w_list = []
     
     pre_gd = np.array([0,0])
     for i in range(epoch):
         
         t = w.copy()
         w_list.append(t)
         grad = gradient(w)
         pre_gd = mu*pre_gd+grad
         
         w = w- step*pre_gd
      
         print('epoch: %d  权重系数x: [%5.3f   %5.3f]  梯度 [%5.3f  %5.3f ] '%(i,w[0],w[1],grad[0],grad[1]))
         if sum(abs(grad))<=1e-6:
             return w
     return w,w_list
         
         
     
      
'''
模拟SGD梯度下降训练的过程
'''
def  main():
    
    X,Y,Z = get_contour()
    plt.figure(figsize=(15,7))
    C= plt.contour(X,Y,Z,[1,10,20,40,80,100])#find_grad
    plt.clabel(C, inline=True, fontsize=15)
    plt.plot(0,0,marker='*',markersize=20,color='r')
    
    mu =0.7
    step = 0.02 #步伐
    epoch = 50 #迭代次数
    w =[10,10] #初始化的权重系数
    #x, x_list = SGD(w, step,epoch)
    x, x_list =  momentum(w, step,mu, epoch)
    N = len(x_list)
    print("\n N",np.shape(x_list))
    for i in range(N-1):
        #print(i)
        plt.plot([x_list[i][0],x_list[i+1][0]],[x_list[i][1],x_list[i+1][1]])
        
      
    
    
if __name__ == "__main__":
    

  
    main()

问题1:steps = 0.015 #步伐 epoch = 50 #迭代次数 w =[10,10] #初始化的权重系数

发现收敛速度很慢

​​​​​​​

 问题2: 增大 steps = 0.02

   步伐 epoch = 50 #迭代次数
    w =[10,10] #初始化的权重系数

       w_1 参数震荡,无法收敛

 


三 Momentum

      Momentum通过对原始梯度做了一个平滑,正好将纵轴方向的梯度抹平了(红线部分),使得参数更新方向更多地沿着横轴进行,因此速度更快。

  code 跟上面一样,差别是参数更新过程

 一个用的是SGD, 一个用的是momentum

 

 

​​​​​​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/141781.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Vue中使用Echarts】响应式布局flexible.js+rem适配方案

文章目录一、vue集成flexible方案第一种&#xff1a;在编译的时候自动转换(px2rem-loader)第二种&#xff1a;直接在写css样式的时候转换(cssrem)二、安装&配置lib-flexible.js三、安装插件cssrem一、vue集成flexible方案 先介绍几个基本的概念 px像素&#xff08;Pixel&…

LeetCode分类刷题---数组篇

刷题班数组1.二分查找704.二分查找35.搜索插入位置34.在排序数组中查找元素的第一个和最后一个位置。69.X的平方和367.有效的完全平方数2.移除元素27.移除元素283.移动零844.比较含退格的字符串977.有序数组的平方3.长度最小的子数组209.长度最小的子数组904.水果成蓝76.最小覆…

2023-01-05 长亭科技 Go 后端开发实习生二面

由于面试官前几天 &#x1f40f; 了&#xff0c;在 HR 面后补了技术二面&#xff0c;不过问得倒也不难&#xff0c;但还是记录下。 1、请做 3 ~ 5 分钟的自我介绍。 2、你说研究生的方向是漏洞挖掘和模糊测试&#xff0c;可以介绍一下吗&#xff1f; 3、简单介绍下缓冲区溢出漏…

2022年中国特色智能工厂领航制造业升级分析报告

易观&#xff1a;当前&#xff0c;新一轮的科技革命和产业变革正在重塑世界格局&#xff0c;科技创新也成为影响国家竞争力的决定性因素。在全球制造业格局重塑的过程中&#xff0c;智能工厂作为全球智能制造产业实践的示范标杆与标准载体&#xff0c;是引领全球制造业企业与工…

机器学习中的数学原理——逻辑回归

这个专栏主要是用来分享一下我在机器学习中的学习笔记及一些感悟&#xff0c;也希望对你的学习有帮助哦&#xff01;感兴趣的小伙伴欢迎私信或者评论区留言&#xff01;这一篇就更新一下《白话机器学习中的数学——逻辑回归》&#xff01;什么是逻辑回归算法逻辑回归 (Logistic…

【信管6.3】成本挣值计算

成本挣值计算铺垫了那么久&#xff0c;不知道大家期待不期待。总算到了挣值计算这一课&#xff0c;这个名字很奇怪呀&#xff0c;什么叫做挣值&#xff1f;成本不就是我们的投资吗&#xff1f;这个挣值到底是要干嘛&#xff1f;带着这些疑问&#xff0c;我们就来看看挣值计算到…

2022全年度奶粉十大热门品牌销量榜单

随着居民收入水平的提升、消费观念的转变及健康饮食意识的逐渐增强&#xff0c;消费者对食品品质的要求也越来越高&#xff0c;奶粉市场也同样如此。当前&#xff0c;国内婴幼儿奶粉市场规模呈稳步增长态势&#xff0c;同时&#xff0c;“三孩政策”的发布实施&#xff0c;也利…

C++模板 - 提高编程

引言 本阶段主要针对C泛型编程和STL技术做详细的讲解&#xff0c;探讨C更深层的使用 1 模板 1.1 模板的概念 模板就是建立通用的模具&#xff0c;大大提高复用性 例如生活中的模板&#xff1a; 一寸照片模板&#xff1a; 模板的特点&#xff1a; 模板不可以直接使用&#…

制造业项目管理软件如何帮助企业做好项目费用管理?

在项目导向型制造型企业中&#xff0c;项目的成本管理与费用控制是企业进行项目评价与利润管控、指导市场选择和项目筛选的重要手段。而传统的手工管理模式下&#xff0c;制造企业管理层很难快速了解到哪些项目出现了延误、哪些项目发生了费用超支、哪些项目产生了变更等问题与…

C#,图像二值化(14)——全局阈值的最佳迭代算法及源代码

1、图像二值化 图像二值化是将彩色图像转换为黑白图像。大多数计算机视觉应用程序将图片转换为二进制表示。图像越是未经处理&#xff0c;计算机就越容易解释其基本特征。 二值化过程 在计算机存储器中&#xff0c;所有文件通常以灰度级的形式存储&#xff0c;灰度级具有从0…

欢迎来到,个人数据安全“世界杯”

2022年国际足联世界杯&#xff0c;巴西止步8强&#xff0c;克罗地亚挺到半决赛&#xff0c;阿根廷与法国双强对决最终阿根廷点球大战胜出……精彩纷呈的世界杯已经落幕&#xff0c;而我们因足球而起的激情和热爱不会消退。世界杯是属于每个人的&#xff0c;每个球迷在世界杯中都…

03-redis篇 架构设计之一: 主从复制

目录 第一篇: 主从复制 二. 实践操作 1. 准备工作 -> ps: 安装redis的文章: docker版 的redis安装 2. 制作docker镜像 -> 2.1 制作redis6379 -> 2.2 制作redis6380 -> 2.3 制作redis6381 3. 查看主镜像redis6379的ip地址 -> 3.1 IPAddress位置在这: …

【数据库数据恢复】mdb_catalog.wt文件丢失的MongoDB数据恢复案例

MongoDB数据库数据恢复环境&#xff1a; MongoDB数据库部署在一台虚拟机上&#xff0c;虚拟机操作系统为Windows Server2012。 MongoDB数据库故障&分析&#xff1a; 由于业务发展需求&#xff0c;需要对MongoDB数据库内的文件进行迁移&#xff0c;在MongoDB服务开启的状态…

内部排序:希尔排序

希尔排序&#xff0c;又称为“缩小增量排序”&#xff0c;是直接插入排序的优化。 对于直接插入排序&#xff0c;当待排记录序列处于正序时&#xff0c;时间复杂度可达O(n)&#xff0c;若待排记录序列越接近有序&#xff0c;直接插入排序越高效。希尔排序的思想正是基于这个点…

QT(5)-QHeaderView

QHeaderView1 说明2 函数2.1 级联调整大小2.2 默认对齐方式2.3 count()2.4 表头默认单元格大小2.5 hiddenSectionCount()2.6 分区显示和隐藏2.7 表头高亮2.8 是否可以移动第一列2.7 是否显示排序索引2.8 表头长度2.9 逻辑索引2.10 表头分区最大/小大小2.11 移动分区2.12 表头偏…

Qlik帮助提升数据素养:新一代打工人“必备招式”

“营销”在业务推进过程中扮演着至关重要的角色。然而&#xff0c;当前营销的影响力却往往未得到广泛理解和重视。 在数字世界里&#xff0c;数据浩瀚如海&#xff0c;但如果“探险者”没有乘风破浪的能力&#xff0c;这片数据汪洋只能沉寂在角落里“吃灰”。而数据素养&#…

Ubuntu20.04 rosdep 失败解决方法

参考文章http://www.autolabor.com.cn/book/ROSTutorials/chapter1/12-roskai-fa-gong-ju-an-zhuang/124-an-zhuang-ros.htmlsudo gedit ./rosdistro/__init__.py sudo gedit ./rosdep2/gbpdistro_support.py sudo gedit ./rosdep2/sources_list.py sudo gedit ./rosdep2/rep3.…

厚积薄发打卡Day112:堆栈实践(二)<汉诺塔问题>

厚积薄发打卡Day112&#xff1a;堆栈实践&#xff08;二&#xff09;&#xff1c;汉诺塔问题&#xff1e; 问题 相传在古印度圣庙中&#xff0c;有一种被称为汉诺塔(Hanoi)的游戏。该游戏是在一块铜板装置上&#xff0c;有三根杆(编号A、B、C)&#xff0c;在A杆自下而上、由大…

Jvm知识点二(GC)

GC 相关知识点一、垃圾收集器二、 java 中的引用三、 怎么判断对象是否可以被回收&#xff1f;四、 Java对象在虚拟机中的生命周期五、垃圾收集算法标记-清除算法复制算法补充知识点深拷贝和浅拷贝标记-压缩算法&#xff08;Mark-Compact&#xff09;分代收集算法Java堆的分区六…

SSH实验部署

一&#xff0c;实验要求 1&#xff0c;两台机器&#xff1a;第一台机器作为客户端&#xff0c;第二台机器作为服务器&#xff0c;在第一台使用rhce用户免 密登录第二台机器 2&#xff0c;禁止root用户远程登录和设置三个用户sshuser1, sshuser2, sshuser3&#xff0c; 只允许ss…