求函数最小值-torch版

news2024/9/27 12:18:07

目标:torch实现下面链接中的梯度下降法

先计算 y=x^2 的导函数 y{}'=2x ,然后计算导函数 在x_{0}=7.5处的梯度 (导数)

让 x_{0}沿着 梯度的负方向移动

x\leftarrow x-y{}'x_{0}

自变量x 的更新过程如下

x_{1}\leftarrow x_{0}-y{}'_{x0}

x_{2}\leftarrow x_{1}-y{}'_{x1}

x_{3}\leftarrow x_{2}-y{}'_{x2}

\cdot \cdot \cdot

x_{n}\leftarrow x_{n-1}-y{}'_{x_{n-1}}

torch代码实现如下

import torch

x = torch.tensor([7.5],requires_grad=True)
# print(x.grad)

optimizer = torch.optim.SGD([x], lr=1)

print('x_0 = {}'.format(x))

for i in range(10):
    y = x * x
    optimizer.zero_grad()
    y.backward()

    optimizer.step()
    print('x_{} = {}'.format(i+1,x))

运行效果如下:

x_0 = tensor([7.5000], requires_grad=True)
x_1 = tensor([-7.5000], requires_grad=True)
x_2 = tensor([7.5000], requires_grad=True)
x_3 = tensor([-7.5000], requires_grad=True)
x_4 = tensor([7.5000], requires_grad=True)
x_5 = tensor([-7.5000], requires_grad=True)
x_6 = tensor([7.5000], requires_grad=True)
x_7 = tensor([-7.5000], requires_grad=True)
x_8 = tensor([7.5000], requires_grad=True)
x_9 = tensor([-7.5000], requires_grad=True)
x_10 = tensor([7.5000], requires_grad=True)

给梯度加系数

我们可以给 梯度 加个系数,如下

x_{1}\leftarrow x_{0}-0.01*y{}'_{x0}

x_{2}\leftarrow x_{1}-0.01*y{}'_{x1}

x_{3}\leftarrow x_{2}-0.01*y{}'_{x2}

\cdot \cdot \cdot

x_{n}\leftarrow x_{n-1}-0.01*y{}'_{x_{n-1}}

torch代码实现如下

import torch

x = torch.tensor([7.5],requires_grad=True)
# print(x.grad)

optimizer = torch.optim.SGD([x], lr=0.01)

print('x_0 = {}'.format(x))

for i in range(10):
    y = x * x
    optimizer.zero_grad()
    y.backward()

    optimizer.step()
    print('x_{} = {}'.format(i+1,x))

运行效果如下:

x_0 = tensor([7.5000], requires_grad=True)
x_1 = tensor([7.3500], requires_grad=True)
x_2 = tensor([7.2030], requires_grad=True)
x_3 = tensor([7.0589], requires_grad=True)
x_4 = tensor([6.9178], requires_grad=True)
x_5 = tensor([6.7794], requires_grad=True)
x_6 = tensor([6.6438], requires_grad=True)
x_7 = tensor([6.5109], requires_grad=True)
x_8 = tensor([6.3807], requires_grad=True)
x_9 = tensor([6.2531], requires_grad=True)
x_10 = tensor([6.1280], requires_grad=True)

调迭代次数

发现 x变化的很慢,我们可以增加迭代次数,如下

import torch

x = torch.tensor([7.5],requires_grad=True)
# print(x.grad)

optimizer = torch.optim.SGD([x], lr=0.01)

print('x_0 = {}'.format(x))

for i in range(200):
    y = x * x
    optimizer.zero_grad()
    y.backward()

    optimizer.step()
    print('x_{} = {}'.format(i+1,x))

运行结果如下:

x_0 = tensor([7.5000], requires_grad=True)
x_1 = tensor([7.3500], requires_grad=True)
x_2 = tensor([7.2030], requires_grad=True)
...
x_199 = tensor([0.1346], requires_grad=True)
x_200 = tensor([0.1319], requires_grad=True)

调梯度系数

我们把 0.01 换成 0.1 试试

import torch

x = torch.tensor([7.5],requires_grad=True)
# print(x.grad)

optimizer = torch.optim.SGD([x], lr=0.1)

print('x_0 = {}'.format(x))

for i in range(10):
    y = x * x
    optimizer.zero_grad()
    y.backward()

    optimizer.step()
    print('x_{} = {}'.format(i+1,x))

运行结果如下:

x_0 = tensor([7.5000], requires_grad=True)
x_1 = tensor([6.], requires_grad=True)
x_2 = tensor([4.8000], requires_grad=True)
x_3 = tensor([3.8400], requires_grad=True)
x_4 = tensor([3.0720], requires_grad=True)
x_5 = tensor([2.4576], requires_grad=True)
x_6 = tensor([1.9661], requires_grad=True)
x_7 = tensor([1.5729], requires_grad=True)
x_8 = tensor([1.2583], requires_grad=True)
x_9 = tensor([1.0066], requires_grad=True)
x_10 = tensor([0.8053], requires_grad=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1914481.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【xinference】(15):在compshare上,使用docker-compose运行xinference和chatgpt-web项目,配置成功!!!

视频演示 【xinference】(15):在compshare上,使用docker-compose运行xinference和chatgpt-web项目,配置成功!!! 1,安装docker方法: #!/bin/shdistribution$(…

【SVN-CornerStone客户端使用SVN-多人开发-解决冲突 Objective-C语言】

一、接下来,我们来说第三方的图形化界面啊, 1.Corner Stone:图形化界面,使用SVN, Corner Stone的界面,大概就是这样的, 1)左下角:是我们远程的一个仓库, 2)右上角:是我们本地的一些东西, 首先,在我的服务器上,再开一个仓库,叫做wechat, 我在这个里边,新建…

红队常用命令速查大全(非常详细)零基础入门到精通,收藏这一篇就够了

这里我整合并且整理成了一份【282G】的网络安全/红客技术从零基础入门到进阶资料包,需要的小伙伴文末免费领取哦,无偿分享!!! 对于从来没有接触过网络安全的同学,我们帮你准备了详细的学习成长路线图。可以…

开放式耳机什么品牌好?四款音质好的开放式耳机推荐

长时间佩戴耳机,舒适度成为了不可忽视的因素。开放式耳机通常采用轻量化材料和透气耳垫,减轻耳朵的负担,即使长时间聆听,也能保持耳朵的舒适与干爽。 然而,众多品牌的开放式耳机琳琅满目,究竟哪个品牌的开…

从重庆元宇宙国风秀看未来元宇宙发展趋势

2024年2月24日,为纪念梅兰芳先生诞辰130周年,以“新国风东方美”为主题的【承华灵境】元宇宙国风秀在重庆市人民大礼堂发布。这场活动将中国经典艺术与数字化技术融合,呈现了一场新国风东方美学的跨越时空人文科技之旅,其中的重点…

【Linux】数据流重定向

数据流重定向(redirect)由字面上的意思来看,好像就是将【数据给它定向到其他地方去】的样子? 没错,数据流重定向就是将某个命令执行后应该要出现在屏幕上的数据,给它传输到其他的地方,例如文件或…

前端发布项目后,解决缓存的老版本文件问题

最近碰到如题目所说的问题,用了思路一的解决方法,结束之后又上网看技术大牛们的解决方法,总结得出下面的文章。 方式一:纯前端 每次打包发版时都使用webpack构建一个version.json文件,文件里的内容是一个随机的字符串…

递归(五)—— 初识暴力递归之“如何利用递归实现栈逆序”

题目:要求不使用额外的数据结构,仅利用递归函数实现栈的逆序。 题目分析: 利用实例来理解题意,栈内元素从栈底到栈顶一次是3,2,1 ,要求经过处理后,栈底到栈顶依次是1,2…

FastAPI 学习之路(三十五)项目结构优化

之前我们创建的文件都是在一个目录中,但是在我们的实际开发中,肯定不能这样设计,那么我们去创建一个目录,叫models,大致如下。 主要目录是: __init__.py 是一个空文件,说明models是一个package…

前端面试39(关于git)

针对前端开发者的Git面试题可以覆盖Git的基础概念、常用命令、工作流程、团队协作、以及解决冲突等方面。以下是一些具体的Git面试 Git基础知识 什么是Git? Git是一个分布式版本控制系统,用于跟踪计算机文件的更改,并协调多个人共同在一个项…

最简单详细的jwt用户登录校验教程(新手必看)

首先简单建张用户表。 DROP TABLE IF EXISTS user; CREATE TABLE user (id bigint NOT NULL AUTO_INCREMENT,name varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL,username varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL…

C++ 编译体系入门指北

前言 之从入坑C之后,项目中的编译构建就经常跟CMake打交道,但对它缺乏系统的了解,遇到问题又陷入盲人摸象。对C的编译体系是如何发展的,为什么要用CMake,它的运作原理是如何的比较感兴趣,所以就想系统学习…

迁移至 AI-Ready 基础架构:日立内容平台至 MinIO

借助我们的 HCP-to-MinIO 工具,从 Hitachi Content Platform (HCP) 过渡到 MinIO 从未如此简单。该工具旨在支持客户不断变化的存储需求,可在 GitHub 上免费获得,大大简化了迁移过程。许多组织正在转型,以利…

台灯怎么选对眼睛好?六大重点教你台灯怎么选不踩雷

根据2024年国家卫健委最新公布的数据,我国儿童青少年总体近视率为52.7%,其中,小学生为42%,初中生为80.7%,高中生为85.7%。儿童的学习环境对学习效果和视力健康都有很大影响。面对日益严峻的近视形势,家长和…

基因检测3 - 遗传性耳聋

1. 耳聋简介 在每1000个新生儿中有1-3个耳聋患儿,绝大部分为遗传学耳聋。遗传性耳聋疾病的遗传方式包括常染色体隐性遗传、常染色体显性遗传、线粒体遗传以及伴性遗传。 根据遗传性耳聋除听力损失外是否存在其他表型,将耳聋分为综合征型耳聋 &#xff…

c++ 多边形 xyz 数据 获取 中心点方法,线的中心点取中心值搞定 已解决

有需求需要对。多边形 获取中心点方法&#xff0c;绝大多数都是 puthon和java版本。立体几何学中的知识。 封装函数 point ##########::getCenterOfGravity(std::vector<point> polygon) {if (polygon.size() < 2)return point();auto Area [](point p0, point p1, p…

数据结构之顺序表(入门)

在了解顺序表之前我们需要先了解什么是线性表 1.线性表的定义 线性表(List)&#xff1a;由零个或多个数据组成的有限数列&#xff0c;线性表是一种在实际中广泛使用的数据结构&#xff0c; 常见的线性表&#xff1a;顺序表&#xff0c;链表&#xff0c;栈&#xff0c;队列&…

LoRaWAN网络协议Class A/Class B/Class C三种工作模式说明

LoRaWAN是一种专为广域物联网设计的低功耗广域网络协议。它特别适用于物联网&#xff08;IoT&#xff09;设备&#xff0c;可以在低数据速率下进行长距离通信。LoRaWAN 网络由多个组成部分构成&#xff0c;其中包括节点&#xff08;终端设备&#xff09;、网关和网络服务器。Lo…

python中unittest框架应用

1、Unittest为Python内嵌的测试框架&#xff0c;不需要特殊配置 2、编写规范 需要导入 import unittest 测试类必须继承unittest.TestCase 测试方法以 test_开头 模块和类名没有要求 TestCase 理解为写测试用例 TestSuite 理解为测试用例的集合 TestLoader 理解为的测试…

数字经济时代,你有数商吗?

引言&#xff1a;随着科技的飞速发展&#xff0c;我们正步入一个全新的数字经济时代。在这个时代里&#xff0c;数据成为了新的石油&#xff0c;是推动经济增长和社会进步的关键要素。而在这个数据洪流中&#xff0c;一个新兴的概念——“数商”&#xff0c;正逐渐进入公众的视…