【PyTorch深度学习实践】02_梯度下降

news2025/1/16 4:44:28

文章目录

    • 梯度下降
      • 1.梯度下降算法
        • 实现代码
      • 2.随机梯度下降
        • 实现代码
      • 3.小批量随机梯度下降

梯度下降

1.梯度下降算法

之前可以使用穷举的方法逐个测试找使损失函数最小的点,但当数据过多时,维度过高,会使穷举变得非常困难,因此需要优化,梯度下降法就是其中一种优化方式。
要找到最小值的点,可以让点沿着下降最快的方向移动,梯度的负方向就是下降最快的方向,w随之更新。
在这里插入图片描述
图中公式的α值代表学习率,通常是一个很小的数,代表步长。

这样的计算存在两个问题,如下图:
在这里插入图片描述在这里插入图片描述
本例中的梯度计算结果如下图:
在这里插入图片描述

实现代码

import numpy as np               # 导入必须的包
import matplotlib.pyplot as plt

# 1.建立线性模型
x_data = [1.0, 2.0, 3.0]  # 数据集整理
y_data = [2.0, 4.0, 6.0]

w = 1.0                  # 给权重w一个初始值

def forward(x):      # 定义线性模型,计算y_hat
    return x * w

def cost(xs,ys):        # 定义mse
    cost = 0
    for x,y in zip(xs,ys):
        y_pred = forward(x)   #计算y_hat
        cost += (y_pred - y) ** 2 # 求平方并累加
    return cost / len(xs)  # 求均值

def gradient(xs,ys):
    grad = 0
    for x,y in zip(xs,ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)

print("Predict (before training)",4 ,forward(4))

for epoch in range(100):            # 进行100轮训练
    cost_val = cost(x_data,y_data)      # 计算损失值
    grad_val = gradient(x_data,y_data)
    w -= 0.01 * grad_val            # 学习率取0.01
    print("Epoch",epoch, "w=",w ,"loss=", cost_val)  #打印日志

print("Predict (after training)",4 ,forward(4))

2.随机梯度下降

cost:N个求平均->loss:从N个随机挑一个

在这里插入图片描述
原因如下图:
在这里插入图片描述

实现代码

# 1.建立线性模型
x_data = [1.0, 2.0, 3.0]  # 数据集整理
y_data = [2.0, 4.0, 6.0]

w = 1.0                  # 给权重w一个初始值

def forward(x):      # 定义线性模型,计算y_hat
    return x * w

def loss(x,y):          # 求其中一个样本的损失值
    y_pred = forward(x)
    return (y_pred - y) ** 2

def gradient(x,y):      # 求梯度
    return 2 * x * (x * w - y)

print("Predict (before training)",4 ,forward(4))	# 训练前

for epoch in range(100):              # 进行100轮训练
    print(f"这是第{epoch}轮")
    for x,y in zip(x_data, y_data):	  # 以下是w的更新,并不随机!
        grad = gradient(x,y)          # 对一个样本求梯度,注意这里没有随机性!不符合随机梯度下降,只是按顺序选了其中一个!
        w = w - 0.01 * grad           # 进行更新,学习率取0.01
        l = loss(x,y)
        print("\tgrad", x, y, w, grad)

print("Predict (after training)",4 ,forward(4))     # 训练后

3.小批量随机梯度下降

随机梯度下降并行度较差,时间复杂度高,但是性能较好。
梯度下降并行度更好,时间复杂度低,但是性能较差。(实话讲这里没听懂老师说的原因,只知道结论了,先记录下来)
因此,引入小批量随机梯度下降(mini-batch),给数据分组,组内使用梯度下降,组间使用随机梯度下降。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/133546.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

K8s 数据管理

目录前言一、Volume1.1 emptyDir1.1.1 基本概念1.1.2 应用案例1.2 hostPath1.2.1 基本概念1.2.2 应用案例1.3 外部 Storage Provider二、Persistent Volume2.1 基本概念2.1.1 PersistentVolume2.1.2 PersistentVolumeClaim2.2 NFS PersistentVolume前言 与 Docker 类似&#x…

QML教程(一)基础语法

目录 一、导入 二、对象声明 三、对象属性 1.声明对象属性 2.信号属性 3.方法属性 4.附加属性略 5.枚举属性 6.对象属性赋值 四、自定义对象 一、导入 模块导入 语法&#xff1a; import <ModuleIdentifier> [<Version.Number>] [as <Qualifier>…

面向对象设计原则概述

面向对象设计原则概述 软件的可维护性和可复用性 软件工程和建模大师Peter coad认为&#xff0c;一个好的系统设计与应该具备如下三个性质 可扩展性 灵活性 可插入性 软件的可维护性和可复用性 软件的复用和重用拥有众多优点&#xff0c;如可以提高软件的开发效率&#xf…

Educational Codeforces Round 92 (Rated for Div. 2) B. Array Walk

翻译&#xff1a; 给定一个数组&#x1d44e;1&#xff0c;&#x1d44e;2&#xff0c;…&#xff0c;&#x1d44e;&#x1d45b;&#xff0c;由&#x1d45b;个正整数组成。 最初&#xff0c;您位于索引1&#xff0c;分数等于&#x1d44e;1。你可以执行两种动作: 向右移动…

CDN

CDN——Content Delivery Network&#xff0c;内容分发网络。 具体来说&#xff0c;CDN就是采用更多的缓存服务器&#xff08;CDN边缘节点&#xff09;&#xff0c;布放在用户访问相对集中的地区或网络中。当用户访问网站时&#xff0c;利用全局负载技术&#xff0c;将用户的访…

【CSP】邻域均值

邻域均值 邻域均值 题意比较好理解&#xff0c;就是算一些数字。如果采用暴力方法的话&#xff0c;就是用一个边长为 2∗r12*r12∗r1 的正方形框框住大矩阵&#xff0c;然后遍历这个框&#xff0c;求出其平均值&#xff0c;然后移动正方形框&#xff0c;直到大矩阵内所有像…

【免费开放源码】审批类小程序项目实战(预约审批端)

第一节&#xff1a;什么构成了微信小程序、创建一个自己的小程序 第二节&#xff1a;微信开发者工具使用教程 第三节&#xff1a;深入了解并掌握小程序核心组件 第四节&#xff1a;初始化云函数和数据库 第五节&#xff1a;云数据库的增删改查 第六节&#xff1a;项目大纲以及制…

6.5 特殊用途语言特性

文章目录默认实参使用默认实参调用函数默认实参声明默认实参初始值内联函数和constexpr函数内联函数constexpr 函数把内联函数和constexpr函数声明在头文件内调试帮助assert预处理宏NDEBUG预处理变量默认实参 某些函数有这样一种形参,在函数的很多次调用中它们都被赋予一个相同…

电子游戏销售之缺失值检测与处理

电子游戏销售之缺失值检测与处理 文章目录电子游戏销售之缺失值检测与处理0、写在前面1、数据缺失值预处理1.1 表的形状1.2 原始数据每个特征缺失和非缺失的数目1.3 每个特征缺失的率1.4 处理后各特征缺失值的数目1.5 删除缺失值后的数据展示2、替换法处理缺失值2.1 替换法2.2 …

1.Springboot配置细节

一、参考资料 13-SpringBoot配置-项目外部配置加载顺序_哔哩哔哩_bilibili 二、配置 2.1 配置文件 注意变量后面是:&#xff0c;而不是等号 2.2 读取配置文件 2.2.1 Value 比如配置文件application.properities中定义了一个name&#xff0c;其值为abc。 代码里面只需按照如…

一、软件安装与配置

一、PyTorch环境软件安装与配置 1.安装anaconda参考 anaconda老版本下载方法&#xff08;如何查看anaconda与python版本对应关系&#xff09;及安装教程_breadth_的博客-CSDN博客_anaconda旧版本下载 2.在anconda下安装和激活pytorch环境 此步并没有下载pytorch 3.下载pyto…

云计算运营—03 KVM虚拟化技术方案介绍

KVM虚拟化技术方案介绍 1.背景介绍 KVM&#xff08;Kernel-based Virtual Machine&#xff09; 开源全虚拟化方案 支持体系结构 x86(32位,64位)、IA64、PowerPC、S390 依赖x86硬件支持&#xff1a;Intel VT-x/ AMD-V内核模块&#xff0c;使得linux内核成为hypervisor XEN架构 …

《B-树》

tips&#xff1a;B-树读成b树&#xff0c;并不是b减树 【一】基本搜索结构 种类数据格式时间复杂度顺序查找无要求O(N)二分查找有序O(log2N)二叉搜索树无要求O(N)二叉平衡树&#xff08;AVL和红黑树&#xff09;无要求&#xff0c;最后随机O(log2N)哈希无要求O(1)位图无要求O…

linux系统中SPI驱动框架的基本原理与实现

大家好&#xff0c;今天主要和大家聊一聊&#xff0c;如何使用linux系统中SPI驱动ICM-20608六轴传感器的操作。 目录 第一&#xff1a;linux系统下SPI驱动框架简介 第二&#xff1a;SPI设备驱动编写 第三&#xff1a;SPI设备和驱动匹配过程 第一&#xff1a;linux系统下SPI驱…

MySQL数据库高级面试题(1)

✅作者简介&#xff1a;热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏&#xff1a;Java面试题…

CSDN年度征文 | 你好,2023

祝大家新年快乐~&#x1f9e7;&#x1f9e7;&#x1f9e7;⭐过去的2022⭐2022已成过去&#xff0c;2023慢步向我们走来。回首2022&#xff0c;这一年不是平凡的一年。这一年&#xff0c;有苦也有乐。冬奥会的成功举办、香港回归25周年、二十大胜利召开、航天任务圆满成功等等都…

设计 | 分享5个好用的PPT模板网站

第一PPT 这个老牌的模板网站了&#xff0c;全站都是免费下载&#xff0c;还是不错的 但是素材质量嘛&#xff0c;免费所以不太高。 第一PPT下载https://www.1ppt.com/ 模板狗 这个是最近发现的一个网站&#xff0c;其中内容比较精美。 而且不用开会员也能单独购买&#x…

【Android】APT

引言&#xff1a; 安卓中APT又叫Annotation Processing Tool&#xff0c;即注解处理器。Java中注解分为编译时注解和运行时注解&#xff0c;编译时注解无法通过反射的方式进行获取和处理&#xff0c;所以我们可以利用APT处理编译时注解。比如常见的三方库ButterKnife、ARouter…

mysql删除重复记录并且只保留一条

准备的测试表结构及数据 插入的数据中A,B,E存在重复数据,C没有重复记录 CREATE TABLE tab ( id int(11) NOT NULL AUTO_INCREMENT, name varchar(20) DEFAULT NULL, PRIMARY KEY (id) ) ENGINEInnoDB AUTO_INCREMENT13 DEFAULT CHARSETutf8; -- --------------------…

使用root用户和普通用户完成分配任务案例 ansible(1)

目录 案例一&#xff1a; 控制主机和受控主机通过root用户通过免密验证方式远程控制受控主机实施对应任务。 案例二&#xff1a; 控制主机连接受控主机通过普通用户以免密验证远程控制受控主机实施特权指定操作。 案例一&#xff1a; 控制主机和受控主机通过root用户通过免…