Adam算法及python实现

news2024/11/24 1:39:03

文章目录

  • 算法介绍
  • 代码实现
  • 结果展示
  • 参考

算法介绍

Adam算法的发展经历了:SGD->SGDM->SGDNA->AdaGrad->AdaDelta->Adam->Adamax的过程。它是神经网络优化中的常用算法,在收敛速度上比较快,比SGD对收敛速度的纠结上有了很大的改进。但是该算法的学习率是不断减少的,可能收敛不到真正的最优解,实践中经常是前期Adam,后期SGD进行优化。
在这里插入图片描述

代码实现

现以如下无约束凸优化问题为例进行算法实施,
min ⁡ 5 x 1 2 + 2 x 2 2 + 3 x 1 − 10 x 2 + 4 \min 5x^2_1+2x^2_2+3x_1−10x_2+4 min5x12+2x22+3x110x2+4

# Adam之实现

import numpy
from matplotlib import pyplot as plt


# 目标函数0阶信息
def func(X):
    funcVal = 5 * X[0, 0] ** 2 + 2 * X[1, 0] ** 2 + 3 * X[0, 0] - 10 * X[1, 0] + 4
    return funcVal


# 目标函数1阶信息
def grad(X):
    grad_x1 = 10 * X[0, 0] + 3
    grad_x2 = 4 * X[1, 0] - 10
    gradVec = numpy.array([[grad_x1], [grad_x2]])
    return gradVec


# 定义迭代起点
def seed(n=2):
    seedVec = numpy.random.uniform(-100, 100, (n, 1))
    return seedVec


class Adam(object):

    def __init__(self, _func, _grad, _seed):
        '''
        _func: 待优化目标函数
        _grad: 待优化目标函数之梯度
        _seed: 迭代起始点
        '''
        self.__func = _func
        self.__grad = _grad
        self.__seed = _seed

        self.__xPath = list()
        self.__JPath = list()


    def get_solu(self, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1.e-8, zeta=1.e-6, maxIter=3000000):
        '''
        获取数值解,
        alpha: 步长参数
        beta1: 一阶矩指数衰减率
        beta2: 二阶矩指数衰减率
        epsilon: 足够小正数
        zeta: 收敛判据
        maxIter: 最大迭代次数
        '''
        self.__init_path()

        x = self.__init_x()
        JVal = self.__calc_JVal(x)
        self.__add_path(x, JVal)
        grad = self.__calc_grad(x)
        m, v = numpy.zeros(x.shape), numpy.zeros(x.shape)
        for k in range(1, maxIter + 1):
            # print("k: {:3d},   JVal: {}".format(k, JVal))
            if self.__converged1(grad, zeta):
                self.__print_MSG(x, JVal, k)
                return x, JVal, True

            m = beta1 * m + (1 - beta1) * grad
            v = beta2 * v + (1 - beta2) * grad * grad
            m_ = m / (1 - beta1 ** k)
            v_ = v / (1 - beta2 ** k)

            alpha_ = alpha / (numpy.sqrt(v_) + epsilon)
            d = -m_
            xNew = x + alpha_ * d
            JNew = self.__calc_JVal(xNew)
            self.__add_path(xNew, JNew)
            if self.__converged2(xNew - x, JNew - JVal, zeta ** 2):
                self.__print_MSG(xNew, JNew, k + 1)
                return xNew, JNew, True

            gNew = self.__calc_grad(xNew)
            x, JVal, grad = xNew, JNew, gNew
        else:
            if self.__converged1(grad, zeta):
                self.__print_MSG(x, JVal, maxIter)
                return x, JVal, True

        print("Adam not converged after {} steps!".format(maxIter))
        return x, JVal, False


    def get_path(self):
        return self.__xPath, self.__JPath


    def __converged1(self, grad, epsilon):
        if numpy.linalg.norm(grad, ord=numpy.inf) < epsilon:
            return True
        return False


    def __converged2(self, xDelta, JDelta, epsilon):
        val1 = numpy.linalg.norm(xDelta, ord=numpy.inf)
        val2 = numpy.abs(JDelta)
        if val1 < epsilon or val2 < epsilon:
            return True
        return False


    def __print_MSG(self, x, JVal, iterCnt):
        print("Iteration steps: {}".format(iterCnt))
        print("Solution:\n{}".format(x.flatten()))
        print("JVal: {}".format(JVal))


    def __calc_JVal(self, x):
        return self.__func(x)


    def __calc_grad(self, x):
        return self.__grad(x)


    def __init_x(self):
        return self.__seed


    def __init_path(self):
        self.__xPath.clear()
        self.__JPath.clear()


    def __add_path(self, x, JVal):
        self.__xPath.append(x)
        self.__JPath.append(JVal)


class AdamPlot(object):

    @staticmethod
    def plot_fig(adamObj):
        x, JVal, tab = adamObj.get_solu(0.1)
        xPath, JPath = adamObj.get_path()

        fig = plt.figure(figsize=(10, 4))
        ax1 = plt.subplot(1, 2, 1)
        ax2 = plt.subplot(1, 2, 2)

        ax1.plot(numpy.arange(len(JPath)), JPath, "k.", markersize=1)
        ax1.plot(0, JPath[0], "go", label="starting point")
        ax1.plot(len(JPath)-1, JPath[-1], "r*", label="solution")

        ax1.legend()
        ax1.set(xlabel="$iterCnt$", ylabel="$JVal$")

        x1 = numpy.linspace(-100, 100, 300)
        x2 = numpy.linspace(-100, 100, 300)
        x1, x2 = numpy.meshgrid(x1, x2)
        f = numpy.zeros(x1.shape)
        for i in range(x1.shape[0]):
            for j in range(x1.shape[1]):
                f[i, j] = func(numpy.array([[x1[i, j]], [x2[i, j]]]))
        ax2.contour(x1, x2, f, levels=36)
        x1Path = list(item[0] for item in xPath)
        x2Path = list(item[1] for item in xPath)
        ax2.plot(x1Path, x2Path, "k--", lw=2)
        ax2.plot(x1Path[0], x2Path[0], "go", label="starting point")
        ax2.plot(x1Path[-1], x2Path[-1], "r*", label="solution")
        ax2.set(xlabel="$x_1$", ylabel="$x_2$")
        ax2.legend()

        fig.tight_layout()
        # plt.show()
        fig.savefig("plot_fig.png")



if __name__ == "__main__":
    adamObj = Adam(func, grad, seed())

    AdamPlot.plot_fig(adamObj)

结果展示

在这里插入图片描述

参考

https://www.cnblogs.com/xxhbdk/p/15063793.html
论文:Adam: A method for stochastic optimization

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/91437.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

单商户商城系统功能拆解46—应用中心—足迹气泡

单商户商城系统&#xff0c;也称为B2C自营电商模式单店商城系统。可以快速帮助个人、机构和企业搭建自己的私域交易线上商城。 单商户商城系统完美契合私域流量变现闭环交易使用。通常拥有丰富的营销玩法&#xff0c;例如拼团&#xff0c;秒杀&#xff0c;砍价&#xff0c;包邮…

基于微信小程序的课程分享平台-计算机毕业设计

项目介绍 随着社会的发展&#xff0c;社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景&#xff0c;运用软件工程原理和开发方法&#xff0c;它主要是采用java语言技术和mysql数据库来完成对系统的设计。整…

[附源码]Node.js计算机毕业设计高校就业管理信息系统Express

项目运行 环境配置&#xff1a; Node.js最新版 Vscode Mysql5.7 HBuilderXNavicat11Vue。 项目技术&#xff1a; Express框架 Node.js Vue 等等组成&#xff0c;B/S模式 Vscode管理前后端分离等等。 环境需要 1.运行环境&#xff1a;最好是Nodejs最新版&#xff0c;我…

学生竞赛网站

开发工具(eclipse/idea/vscode等)&#xff1a; 数据库(sqlite/mysql/sqlserver等)&#xff1a; 功能模块(请用文字描述&#xff0c;至少200字)&#xff1a; 模块划分&#xff1a;通知类型、通知信息、学院信息、学生信息、学科信息、竞赛信息、报名信 息、成果上传、评分排名 管…

YOLOv5小目标切图检测

当我们在检测较大分辨率的图片时&#xff0c;对小目标的检测效果一直是较差的&#xff0c;所以就有了下面几种方法&#xff1a; 将图片压缩成大尺寸进行训练&#xff08; 想法&#xff1a;没显存&#xff0c;搞不来&#xff09;添加小检测头&#xff08;想法&#xff1a;P5模型…

【爬虫实战项目】Python爬虫批量下载相亲网站数据并保存本地(附源码)

前言 今天给大家介绍的是Python爬虫批量下载相亲网站图片数据&#xff0c;在这里给需要的小伙伴们代码&#xff0c;并且给出一点小心得。 首先是爬取之前应该尽可能伪装成浏览器而不被识别出来是爬虫&#xff0c;基本的是加请求头&#xff0c;但是这样的纯文本数据爬取的人会…

数据结构---树和二叉树

树和二叉树定义二叉树二叉树的物理结构链式存储数组二叉树应用查找维持相对顺序二叉树的遍历深度优先遍历前序遍历中序遍历后序遍历二叉树广度优先遍历层序遍历定义 有且仅有一个特定的称为根的节点。当n>1时&#xff0c;其余节点可分为m&#xff08;m>0&#xff09;个互…

数据结构与算法——Java实现栈、逆波兰计算器(整数加减乘除)

目录 一、栈 1.1 基本介绍 1.2 栈的思路分析 1.3 栈的代码实现 二、栈实现综合计算器 2.1 思路分析 2.2 代码实现&#xff08;中缀表达式实现&#xff09; 三、栈的前缀&#xff08;波兰&#xff09;、中缀、后缀&#xff08;逆波兰&#xff09;表达式 3.1 表达式的介绍…

访问pcie总线地址内容

调用代码如下&#xff1a; uint32_t value;void * addr;printk("------1--------\n");addr0x2730000;struct resource *res;char const *name dev_name(&pdev->dev);printk("dev_name%s\n", name);res request_mem_region(addr, 16, "name1&…

【语音之家公开课】SRD: A Dataset and Benchmark Perspective

本次语音之家公开课邀请到陈果果进行分享Speech Recognition Development: A Dataset and Benchmark Perspective。 公开课简介 主题&#xff1a;Speech Recognition Development: A Dataset and Benchmark Perspective 时间&#xff1a;12月15日&#xff08;周四&#xff09…

web网页设计期末课程大作业:美食餐饮文化主题网站设计——HTML+CSS+JavaScript美食餐厅网站设计与实现 11页面

&#x1f468;‍&#x1f393;静态网站的编写主要是用HTML DIVCSS JS等来完成页面的排版设计&#x1f469;‍&#x1f393;,常用的网页设计软件有Dreamweaver、EditPlus、HBuilderX、VScode 、Webstorm、Animate等等&#xff0c;用的最多的还是DW&#xff0c;当然不同软件写出的…

C# IO及文件管理

一 System.IO ① System.IO名字空间&#xff1b; ② 提供了许多用于&#xff1b; ③ 文件和数据流进行读写操作的类&#xff1b; 二 流的分类 1 Stream类 按存取位置分&#xff1a;FileStream,MemeryStream,BufferedStream; 2 读写类 BinaryReader和BinaryWriter; TextRe…

从 0 到 1 搞一个 Compose Desktop 版本的玩天气之打包

从 0 到 1 搞一个 Compose Desktop 版本的玩天气之打包 大家好&#xff0c;前两篇文章大概介绍了下上手 Compose Desktop 和自定义绘制时遇到的一些问题&#xff0c;项目的最终实现效果如下&#xff1a; 视频代码写好了&#xff0c;该弄的动画也弄了&#xff0c;该请求的网络数…

【数据结构】八大排序算法详解

&#x1f9d1;‍&#x1f4bb;作者&#xff1a; 情话0.0 &#x1f4dd;专栏&#xff1a;《数据结构》 &#x1f466;个人简介&#xff1a;一名双非编程菜鸟&#xff0c;在这里分享自己的编程学习笔记&#xff0c;欢迎大家的指正与点赞&#xff0c;谢谢&#xff01; 排序前言一…

汇编语言第一章:基础知识

1. 基础知识 机器语言 机器语言是机器指令的集合&#xff0c;是一台机器可以正确执行的命令。现在一般电子计算机的机器指令是一列二进制数字。机器指令集是机器语言。 汇编语言 机器语言难以辨别和记忆&#xff0c;所以产生了汇编语言。汇编语言的主体是汇编指令。 操作&…

on-device training

又搬来个好玩呃 说来又想试试了 , 仅用256KB就实现单片机上的神经网络训练&#xff08;training,notinference&#xff09;&#xff0c;从此终端智能不再是单纯的推理&#xff0c;而是能持续的自我学习自我进化 On-Device Training under 256KB Memory 说到神经网络训练&#…

编译原理实验四

编译原理实验四 实验要求 cminus-f的词法分析和语法分析部分已经完成&#xff0c;最终得到的是语法分析树。而为了产生目标代码&#xff0c;还需要将语法分析树转为抽象语法树&#xff0c;通过抽象语法分析树生成中间代码(即IR)&#xff0c;最后使用中间代码来进行优化并生成…

easyExcel导出表头合并 不得不说真牛

有个导出单元格合并的任务&#xff0c;表头不规则合并格式&#xff0c;看得就烦&#xff0c;尤其是对于没玩儿过合并的我来说&#xff0c;任务放在哪里不知咋做&#xff0c;网上也看了一堆合并的方法&#xff0c;自己写注解来写的那些&#xff0c;麻烦得要命&#xff0c;我写一…

48.python break语句-终止循环

48.break语句-终止循环 文章目录48.break语句-终止循环1.循环控制2.break的作用3. 语法4. 实操练习5. 知识扩展&#xff1a;print的位置6. break语句循环图1.循环控制 在循环的过程中如果要退出循环&#xff0c;我们可以用break语句和continue语句。 2.break的作用 break [b…

Android入门第49天-使用RadioGroup+Fragment来重构类首页底部4个按钮的界面

简介 我们在&#xff1a;Android入门第47天-Fragment的基本使用 中使用Fragment制作了一个类首页底部含4个按钮的界面。今天的课程我们要做的是把第47天里的代码中一部分共用的东西抽象到res/values/themes.xml文件中。另外我们使用RadioGroup天然的只有一个可以被选中来代替…