实现Newton方法的最小化函数(pytorch)

news2024/12/23 20:09:28

首先,我们要明确需求
def newton(theta, f, tol = 1e-8, fscale=1.0, maxit = 100, max_half = 20)
● theta是优化参数的初始值的一个向量。
● f是要最小化的目标函数。该函数应将PyTorch张量作为输入,并返回一个张量。
● tol是收敛容忍度。
● fscale 粗略估计最佳状态下f的大小–用于收敛性测试。
● maxit 在放弃之前尝试的牛顿迭代的最大数量。
● max_half 一个步骤在得出该步骤未能改善目标的结论之前应该被减半的最大次数。
目标。

1.函数需检查初始θ是否有限,并计算目标值f0。然后,使用“torch.autograd.functional”计算f在初始θ处的雅可比和hessian(同时检查它们是否是有限的)。
2.当Hessian不是正定时,使用变量“time” * 10。
3.变量“try_half”确保只输出一个关于达到最大减半次数的警告
4.更重要的是,它使用牛顿步长和方向迭代更新θ,直到达到收敛或最大迭代次数。
5.在每次迭代过程中,函数检查目标或导数在当前或新θ下是否是有限的,以及该步骤是否导向较小的目标。如果步骤未能减少目标,则函数会将步骤大小减半,直到目标减少或达到步骤减半的最大次数。如果达到步骤减半的最大次数,该功能将发出警告。
6.该函数还通过评估梯度向量的范数以及Hessian是否是正定的来检查收敛性。如果梯度向量足够接近零,函数会检查Hessian是否是正定的。如果Hessian不是正定的,函数将单位矩阵的一个小倍数(10^-8)加到Hessian上,然后重试(time*10)。
7.如果在没有收敛的情况下达到最大迭代次数,则函数会发出错误。该函数返回一个dict,其中包含theta的最终值、目标值f0、迭代次数iter_count和f在最终theta处的梯度。

def newton(theta, f, tol=1e-8, fscale=1.0, maxit=100, max_half=20): 
    # check if the initial theta is finite
    if not torch.isfinite(theta).all():
        raise ValueError("The initial theta must be finite.")
    
    # initialize variables
    iter_count = 0
    #theta is a vector of initial values for the optimization parameters
    f0 = f(theta)
    #f is the objective function to minimize, which should take PyTorch tensors as inputs and returns a Tensor
    
    #autograd functionality for the calculation of the jacobian and hessians
    grad = torch.autograd.functional.jacobian(f, theta).squeeze()
    hess = torch.autograd.functional.hessian(f, theta)
    
    # check if the objective or derivatives are not finite at the initial theta
    if not torch.isfinite(f0) or not torch.isfinite(grad).all() or not torch.isfinite(hess).all():
        raise ValueError("The objective or derivatives are not finite at the initial theta.")
    
    #multiplication factor of Hessian is positive definite 
    time = 1
    #trying max_ half step halfings
    try_half=True
    
    # iterate until convergence or maximum number of Newton iterations
    while iter_count < maxit:
        
        # calculate the Newton step and direction
        step = torch.linalg.solve(hess, grad.unsqueeze(-1)).squeeze()
        direction = -step
        
        # initialize variables for step halfing
        halfing_count = 0
        new_theta = theta + direction
        
        # iterate until the objective is reduced or max_half is reached(maximum number of times a step should be halfed before concluding that the step has failed to improve the objective)
        while halfing_count < max_half:
            
            # calculate the objective and derivatives at the new theta
            new_f = f(new_theta)
            new_grad = torch.autograd.functional.jacobian(f, new_theta).squeeze()
            new_hess = torch.autograd.functional.hessian(f, new_theta)
            
            # check if the objective or derivatives are not finite at the new theta
            if not torch.isfinite(new_f) or not torch.isfinite(new_grad).all() or not torch.isfinite(new_hess).all():
                raise ValueError("The objective or derivatives are not finite at the new theta.")
            
            # check if the step leads to a smaller objective
            if new_f < f0:
                theta = new_theta
                f0 = new_f
                grad = new_grad
                hess = new_hess
                break
                
            # if not, half the step size
            else:
                direction /= 2
                new_theta = theta + direction
                halfing_count += 1
                
        # if max_half is reached, step fails to reduce the objective, issue a warning
        if halfing_count == max_half:
            if(try_half):
                print("Warning: Max halfing count reached.")
                try_half = False
                
        # check for convergence(tol), judge whether the gradient vector is close enough to zero
        # The gradient can be judged to be close enough to zero when they are smaller than tol multiplied by the objective value
        if torch.norm(grad) < tol * (torch.abs(f0) + fscale): #add fscale to the objective value before multiplying by tol
            cholesky_hess = None
            # Check if Hessian is positive definite
            try:
                cholesky_hess = torch.linalg.cholesky(hess)
            #Hessian is not positive definite
            except RuntimeError: 
                pass
            #Hessian is still not positive definite
            if cholesky_hess is not None: 
                theta_dict = {"f": f0, "theta": theta, "iter": iter_count, "grad": grad}#return f,theta, iter, grad
                return theta_dict
            else:#If it is not positive definite, add a small multiple of the identity matrix to the Hessian and try again
                #we adding ε(the largest absolute value in your Hessian multiplied by 10^-8)
                epsilon = time * 10 ** -8 * torch.max(torch.abs(hess)).item() 
                hess += torch.eye(hess.shape[0], dtype=hess.dtype) * epsilon
                # keep multiplying ε by 10 until Hessian is positive definite
                time *= 10
                
                
        # update variables for the next iteration
        iter_count += 1
    
    # if maxit is reached without convergence, issue an error
    raise RuntimeError("Max iterations reached without convergence.")

测试案例

对于以上newton()函数,采用以下五个函数测试找到全局最小值验证有效性(使用3个不同的起始值(theta))
f ( x , y ) = x 2 − 2 x + 2 y 2 + y + 3 f(x,y) = x^2-2x+2y^2+y+3 f(x,y)=x22x+2y2+y+3
在这里插入图片描述
f ( x , y ) = 100 ∗ ( y − x 2 ) 2 + ( 1 − x ) 2 f(x,y) = 100*(y-x^2)^2 + (1-x)^2 f(x,y)=100(yx2)2+(1x)2
在这里插入图片描述
f ( x 1 , … , x 5 ) = ∑ i = 1 4 [ 100 ( x i + 1 − x i 2 ) 2 + ( 1 − x i ) 2 ] f(x_1,\ldots,x_5) = \sum_{i=1}^{4} \left[100(x_{i+1}-x_i^2)^2+(1-x_i)^2\right] f(x1,,x5)=i=14[100(xi+1xi2)2+(1xi)2]
在这里插入图片描述
f ( x 1 , … , x 1 0 ) = ∑ i = 1 10 x i 2 f(x_1,\ldots,x_10) = \sum_{i=1}^{10} x_i^2 f(x1,,x10)=i=110xi2
在这里插入图片描述
f ( x , y ) = ( 1.5 − x + x y ) 2 + ( 2.25 − x + x y 2 ) 2 + ( 2.625 − x + x y 3 ) 2 f(x,y) = (1.5-x+xy)^2 + (2.25-x+xy^2)^2+(2.625-x+xy^3)^2 f(x,y)=(1.5x+xy)2+(2.25x+xy2)2+(2.625x+xy3)2
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/479671.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Leetcode -328.奇偶链表 - 725.分隔链表】

Leetcode Leetcode -328.奇偶链表Leetcode - 725.分隔链表 Leetcode -328.奇偶链表 题目&#xff1a;给定单链表的头节点 head &#xff0c;将所有索引为奇数的节点和索引为偶数的节点分别组合在一起&#xff0c;然后返回重新排序的列表。 第一个节点的索引被认为是 奇数 &am…

苏州百特电器有限公司网站设计

苏州百特电器有限公司网站设计 五一假期作业企业门户网站布局设计 基于 <div> 的企业门户网站设计 by 小喾苦 我这里仅仅是使用 html css 来实现这个网站的效果&#xff0c;并不是宣传这个网站(现在这个网站已经过时并且无法进入) 实现效果 https://xkk1.github.io/…

出差在外,远程访问企业局域网象过河ERP系统「内网穿透」

文章目录 概述1.查看象过河服务端端口2.内网穿透3. 异地公网连接4. 固定公网地址4.1 保留一个固定TCP地址4.2 配置固定TCP地址 5. 使用固定地址连接 转载自远程穿透文章&#xff1a;公网远程访问公司内网象过河ERP系统「内网穿透」 概述 ERP系统对于企业来说重要性不言而喻&am…

初识中央处理器CPU

目录 一、CPU功能 1.控制器功能 2.运算器功能 3.功能执行顺序 4.其他功能 二、CPU结构图 1.CPU与系统总线 2.CPU内部结构 3.运算器中的寄存器组 4.控制器中的寄存器组 三、执行指令的过程 1.指令周期的基本概念 2.完整的指令周期流程 3.数据通路 4.指令周期的数据…

React超级简单易懂全面的有关问题回答(面试)

目录 React事件机制&#xff1a; 2、React的事件和普通的HTML有什么不同&#xff1a; - 事件命名的规则不同&#xff0c;原生事件采用全小写&#xff0c;react事件采用小驼峰 3、React组件中怎么做事件代理&#xff1f;他的原理是什么&#xff1f; 4、React高阶组件、Rend…

【SpringBoot】 整合RabbitMQ 保证消息可靠性传递

生产者端 目录结构 导入依赖 修改yml 业务逻辑 测试结果 生产者端 目录结构 导入依赖 <dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter</artifactId></dependency>&…

无人机集群路径规划:淘金优化算法(Gold rush optimizer,GRO)提供MATLAB代码

一、淘金优化算法GRO 淘金优化算法&#xff08;Gold rush optimizer&#xff0c;GRO&#xff09;由Kamran Zolf于2023年提出&#xff0c;其灵感来自淘金热&#xff0c;模拟淘金者进行黄金勘探行为。 参考文献&#xff1a; K. Zolfi. Gold rush optimizer: A new population-ba…

Python小姿势 - #### Python技术博客:Python多线程编程

Python技术博客&#xff1a;Python多线程编程 你好&#xff0c;这里是自媒体技术博主Aurora&#xff0c;今天我想分享一下Python多线程编程。 首先&#xff0c;什么是多线程编程&#xff1f;多线程编程是一种让多个线程同时执行的编程方式&#xff0c;它可以让程序的执行更加高…

2023年华中杯C题计算结果

经过一晚上代码的编写&#xff0c;论文的写作&#xff0c;C题完整版论文已经发布&#xff0c; 注&#xff1a;蓝色字体为说明备注解释字体&#xff0c;不能出现在大家的论文里。黑色字体为论文部分&#xff0c;大家可以根据红色字体的注记进行摘抄。对应的详细的写作视频教程&…

推荐一款网站内链爬取python脚本

目标 使用 web-tools 提供的webSpider来爬取网站内链&#xff0c;并且将其导出。 webSpider介绍&#xff1a; 官网链接&#xff1a;https://web-tools.cn/web-spider 仓库地址&#xff1a;https://github.com/duerhong/web-spider Web Spider 专门用于爬取网站内链&#xf…

C++ srand()和rand()用法

参考C rand 与 srand 的用法 计算机的随机数都是由伪随机数&#xff0c;即是由小M多项式序列生成的&#xff0c;其中产生每个小序列都有一个初始值&#xff0c;即随机种子。&#xff08;注意&#xff1a; 小M多项式序列的周期是65535&#xff0c;即每次利用一个随机种子生成的随…

论文学习笔记:Transformer Attention Is All You Need

Transformer: Attention Is All You Need 2022 年年底&#xff0c;一个大语言模型 ChatGPT 横空出世&#xff0c;并且迅速点燃了普罗大众对 AI 的热情&#xff0c;短短两个月&#xff0c; ChatGPT 就成为了史上最快成为上亿月活的应用&#xff0c;并且持续受到关注&#xff0c…

【Vue2.0源码学习】变化侦测篇-Object的变化侦测

文章目录 1. 前言2. 使Object数据变得“可观测”3. 依赖收集3.1 什么是依赖收集3.2 何时收集依赖&#xff1f;何时通知依赖更新&#xff1f;3.3 把依赖收集到哪里 4. 依赖到底是谁5. 不足之处6. 总结 1. 前言 我们知道&#xff1a;数据驱动视图的关键点则在于我们如何知道数据发…

记录docker swarm的使用

在前面的几篇文章中我们依次学习了dockerfile、docker-compose的使用&#xff0c;接下来是docker有一个比较 重要的使用&#xff0c;docker swarm的使用&#xff0c;与dockerfile和docker-compose相比较而言&#xff0c;docker swarm是在 多个服务器或主机上创建容器集群服务准…

Leetcode——66. 加一

&#x1f4af;&#x1f4af;欢迎来到的热爱编程的小K的Leetcode的刷题专栏 文章目录 1、题目2、暴力模拟(自己的第一想法)3、官方题解 1、题目 给定一个由 整数 组成的 非空 数组所表示的非负整数&#xff0c;在该数的基础上加一。最高位数字存放在数组的首位&#xff0c; 数组…

CTF-PHP反序列化漏洞2-典型题目

作者&#xff1a;Eason_LYC 悲观者预言失败&#xff0c;十言九中。 乐观者创造奇迹&#xff0c;一次即可。 一个人的价值&#xff0c;在于他所拥有的。可以不学无术&#xff0c;但不能一无所有&#xff01; 技术领域&#xff1a;WEB安全、网络攻防 关注WEB安全、网络攻防。我的…

【纯属娱乐】随机森林预测双色球

目录 一、数据标准化二、预测代码三、后续 一、数据标准化 首先&#xff0c;我们需要对原始数据进行处理&#xff0c;将其转换为可用于机器学习的格式。我们可以将开奖号码中的红球和蓝球分开&#xff0c;将其转换为独热编码&#xff0c;然后将其与期数一起作为特征输入到机器…

ETL工具 - Kettle 查询、连接、统计、脚本算子介绍

一、 Kettle 上篇文章对 Kettle 流程、应用算子进行了介绍&#xff0c;本篇对查询、连接、统计、脚本算子进行讲解&#xff0c;下面是上篇文章的地址&#xff1a; ETL工具 - Kettle 流程、应用算子介绍 二、查询算子 数据输入使用 MySQL 表输入&#xff0c;表结构如下&#x…

给httprunnermanager接口自动化测试平台换点颜色瞧瞧

文章目录 一、背景1.1、修改注册表单的提示颜色1.2、修改后台代码&#xff1a;注册错误提示&#xff0c;最后提交注册&#xff0c;密码校验&#xff1b;1.3、修改了注册&#xff0c;那登录呢&#xff0c;也不能放过二、总结 一、背景 虽然咱给HttpRunnerManger引入进来&#xf…

【云台】开源版本SimpleBGC的电机驱动与控制方式

前言 最近想学习一下云台&#xff0c;发现资料确实还不太好找&#xff0c;比较有参考价值的是俄版的开源版本的云台代码&#xff0c;后面就不开源了&#xff0c;开源版本的是比较原始的算法&#xff0c;差不多是玩具级别的&#xff0c;不过还是决定学习一下&#xff0c;了解一…