《PyTorch深度学习实践》第三讲 梯度下降算法

news2024/11/24 8:33:12

《PyTorch深度学习实践》第三讲 梯度下降算法

  • 问题描述
  • 梯度下降
    • 问题分析
    • 编程实现
      • 代码
      • 实现效果
  • 随机梯度下降
    • 问题分析
    • 编程实现
      • 代码
      • 实现效果
  • 参考资料

问题描述

在这里插入图片描述

梯度下降

问题分析

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

编程实现

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

代码

import matplotlib.pyplot as plt

# 训练集数据
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 设置初始权重猜测
w = 1.0

# 前馈计算
def forward(x):
    return x * w

# 计算损失
def cost(xs, ys):
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost / len(xs)

# 计算梯度
def gradien(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)

print('Predict(before training)', 4, forward(4))

# 存放每轮的数据
cost_list = []
epoch_list = []

# 训练过程
for epoch in range(100):  # 训练100轮
    cost_val = cost(x_data, y_data)
    grad_val = gradien(x_data, y_data)   # 更新梯度
    w -= 0.01 * grad_val    # 0.01 学习率
    print('Epoch:', epoch, 'w = ', w, 'loss = ', cost_val)
    cost_list.append(cost_val)
    epoch_list.append(epoch)

print('Predict(after training)', 4, forward(4))

# 绘图展示
plt.plot(epoch_list, cost_list)
plt.xlabel('Epoch')
plt.ylabel('Cost')
plt.show()

实现效果

在这里插入图片描述

随机梯度下降

使用随机梯度下降对上述问题进行求解,随机梯度下降法和梯度下降法的主要区别在于:
1、损失函数由计算所有训练数据的损失,更改为计算一个训练数据的损失。
2、梯度函数由计算所有训练数据的梯度,更改为计算一个训练数据的梯度。

问题分析

在这里插入图片描述

编程实现

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

代码

import matplotlib.pyplot as plt

# 训练集数据
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 设置初始权重猜测
w = 1.0

# 前馈计算
def forward(x):
    return x * w

# 计算损失
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# 计算梯度
def gradien(x, y):
    return 2 * x * (x * w - y)

print('Predict(before training)', 4, forward(4))

# 存放每轮的数据
loss_list = []
epoch_list = []

# 训练过程
for epoch in range(100):  # 训练100轮
    for x, y in zip(x_data, y_data):
        grad = gradien(x, y)
        w = w - 0.01 * grad
        print('\tgrad:', x, y, grad)
        l = loss(x, y)
    epoch_list.append(epoch)
    loss_list.append(l)

print('Predict(after training)', 4, forward(4))

# 绘图展示
plt.plot(epoch_list, loss_list)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

实现效果

在这里插入图片描述

参考资料

传送门梯度下降算法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1091945.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++算法 —— 贪心(1)介绍

文章目录 1、什么是贪心算法2、特点3、学习方向 1、什么是贪心算法 贪心应当是一个策略,通过局部找到最优,来找到全局最优。它把解决问题的过程分为若干步,解决每一步的时候,都选择当前看起来最优的解法,通过这样做希…

Python数据分析实战-实现F检验(附源码和实现效果)

实现功能 F 检验(F-test)是一种常用的统计方法,用于比较两个或多个样本方差是否存在显著差异。它可以应用于多种场景,其中一些常见的应用场景包括: 方差分析(ANOVA):F 检验在方差分…

【软考-中级】系统集成项目管理工程师-进度管理历年案例

持续更新。。。。。。。。。。。。。。。 进度管理历年案例和解析 2023 上 试题二(20分)2023 上 试题二(20分) 问题1(5分) 结合案例: (1)请写出项目关键路径,并计算项目工期。 答案:项目关键路径 ACEFGJKN,项目工期 80 解析(2)如果活动L工期拖延10天,对整个工期是否有影响…

C语言-程序环境和预处理(1)编译、连接介绍以及预处理函数,预处理符号详解及使用说明。

前言 本篇文章讲述了程序的翻译环境和执行环境,编译、连接,预定义符号,#define,#符号和##符号的相关知识。 文章目录 前言1.程序的翻译环境和执行环境2.编译链接2.1 翻译环境2.2 运行环境 3.预处理详解(各预处理符号使…

Java之SPI

Java的SPI(Service Provider Interface)是一种面向接口编程的机制,用于实现组件之间的解耦和扩展。通过SPI机制,我们可以定义接口,并允许第三方提供不同的实现,从而实现可插拔、可扩展的架构。 SPI讲解 它…

Studio One6.5最新中文版安装步骤

在唱歌效果调试当中,我们经常给客户安装的几款音频工作站。第一,Studio One 6是PreSonus公司开发的一款功能强大的音频工作平台,具有丰富的音频处理功能和灵活的工作流程。以下是Studio One6的一些主要特点: 1.多轨录音和编辑&…

ezEIP信息泄露

漏洞描述 ezEIP存在信息泄露漏洞,通过遍历Cookie中的参数值获取敏感信息 漏洞复现 漏洞Url为 /label/member/getinfo.aspx访问时添加Cookie(通过遍历获取用户的登录名电话邮箱等信息) WHIR_USERINFORwhir_mem_member_pid1;漏洞证明&…

同比增长29.89%,长城汽车9月销售新车超12万辆

10月8日,长城汽车股份有限公司(股票代码601633.SH、02333.HK、82333.HK;以下简称“长城汽车”)发布2023年9月产销数据。今年9月,长城汽车销售新车121,632辆,同比增长29.89%,1-9月累计销售864,04…

安捷伦E9321A射频传感器

安捷伦E9321A射频传感器 E9321A 是 Agilent 使用的 6 GHz 0.1 瓦射频传感器。电子测试设备传感器测量波形的功率,例如多音和调制射频 (RF) 波形。传感器使用二极管检测器收集高度精确的调制测量值。50 MHz 至 6 GHz 300 kHz 视频带宽 功率范围:-65 至 20…

Android JNI调用流程

文章目录 前言一、JNI是什么二、JNI的优劣三、JNI的开发流程Java调用C函数1、创建声明native方法的Java工程,加载native函数的动态库,生成.h文件2、创建实现C函数的C工程,将本地代码编译成动态库C函数和Java本地方法的隐式映射(相…

压缩炸弹,Java怎么防止

一、什么是压缩炸弹,会有什么危害 1.1 什么是压缩炸弹 压缩炸弹(ZIP):一个压缩包只有几十KB,但是解压缩后有几十GB,甚至可以去到几百TB,直接撑爆硬盘,或者是在解压过程中CPU飙到100%造成服务器宕机。虽然…

11-网络篇-DNS步骤

1.URL URL就是我们常说的网址 https://www.baidu.com/?from1086k https是协议 m.baidu.com是服务器域名 ?from1086k是路径 2.域名 比如https://www.baidu.com 顶级域名.com 二级域名baidu 三级域名www 3.域名解析DNS DNS就是将域名转换成IP的过程 根域名服务器&#xff1a…

python2 paramiko 各种报错解决方案

一、介绍 paramiko是一个基于SSHv2协议的python库,支持以加密和认证的方式进行远程服务器的连接,用于实现远程文件的上传、下载或通过ssh远程执行命令。 paramiko支持Python(2.7,3.4)版本 paramiko库可直接使用pip …

谈谈C++中模板分离式编译出现的一些问题

什么是分离式编译 通俗的来讲就是将声明和定义分离在不同文件中 一个程序由若干个源文件共同实现,而每个源文件单独编译生成目标文件,最后将所有 目标文件链接起来形成单一的可执行文件的过程称为分离编译模式。 正常函数与模板分离式编译 看代码&…

生物制剂\化工\化妆品等质检损耗、制造误差处理作业流程图(ODOO15/16)

生物制剂、化工、化妆品等行业,因为产品为液体,产品形态和质量容易在各个业务环节发生变化,常常导致实物和账面数据不一致,如果企业业务流程不清晰,会导致系统大量的库存差异,以及财务难以核算的问题&#…

上门服务小程序源码 理疗,足疗,美容SAP上门服务小程序源码

上门服务小程序源码 理疗,足疗,美容SAP上门服务小程序源码 运行环境:Nginx 1.20PHP7.1MySQL 5.6 通过HBuilder X编译小程序APP版本 一、上门预定操作 1、技师管理。 技师满意度进行统一跟踪评估,进行分级管理,分级…

Web测试框架SeleniumBase

首先,SeleniumBase支持 pip安装: > pip install seleniumbase它依赖的库比较多,包括pytest、nose这些第三方单元测试框架,是为更方便的运行测试用例,因为这两个测试框架是支持unittest测试用例的执行的。 Seleniu…

Canal安装

安装和配置Canal Canal Framework 是阿里巴巴开源的一款基于数据库增量日志解析和同步的数据中间件。它主要用于解决分布式系统中数据同步的问题,支持多种数据源,如 MySQL、SQL Server、PostgreSQL、Oracle 等,同时也支持多种数据目标&#…

函数栈帧的创建与销毁(保姆级讲解)

局部变量是怎么创建的? 在为main函数开辟栈帧空间时,在一定范围内初始化成0CCCCC,再把里面0CCCC的一些开辟空间给局部变量使用。 为什么局部变量的值是随机值? 因为我们在为main函数开辟栈帧空间时,会将一定范围内空间初始成0CCCCCC里面…

【宏实现二进制奇偶位交换】

文章目录 一. 二进制奇偶位交换说明意思?二. 解题思路三. 代码验证四. 总结 一. 二进制奇偶位交换说明意思? 就是一个int类型的整数在操作系统下是32位二进制01序列,第一位和第二位交换,第二位和第三位交换,依次类推。…