Pytorch深度学习实践笔记3

news2024/9/20 20:24:13

🎬个人简介:一个全栈工程师的升级之路!
📋个人专栏:pytorch深度学习
🎀CSDN主页 发狂的小花
🌄人生秘诀:学习的本质就是极致重复!

视频来自【b站刘二大人】

目录

1 梯度下降(Gradient Descent)

2 随机梯度下降(SGD Stochastic Gradient Descent)

3 批量梯度下降(BGD Batch Gradient Descent)

4 小批量梯度下降(mini-Batch GD,mini-batch Gradient Descent)

5 代码


1 梯度下降(Gradient Descent)

  • 梯度:

方向导数在该点的最大值,建立cost与w的关系,优化使得w能够快速收敛

  • 引入:

可以发现上一章节我们寻找权重ω的时候,使用的是遍历 ω 的方法,显然在工程上这是不可行的,于是引入了梯度下降(Gradient Descent)算法。

  • 方案:

我们的目标是找出 ω∗ 最小化 cost(ω)函数 。梯度下降使用公式ω=ω−α∗∂cost∂ω,其中α是人为设定的学习率,显然 ω 总是往cost局部最小化的方向趋近(可以注意并不总是往全局最优的方向)

  • 局部最优:

我们经常担心模型在训练的过程中陷入局部最优的困境中,但实际上由于Mini-batch的存在在实际工程中模型陷入鞍点(局部最优)的概率是很小的

  • epoch:

轮次,一个epoch指的是所有的训练样本在模型中都进行了一次正向传播和一次反向传播


2 随机梯度下降(SGD Stochastic Gradient Descent)


随机梯度下降算法在梯度下降算法的基础上进行了一定的优化,其对于每一个实例都进行更新,也就是不是用MSE,而是

对 ω进行更新
优势:SGD更不容易陷入鞍点之中,同时其拥有更好的性能
优点:
由于不是在全部训练数据上的损失函数,而是在每轮迭代中,随机优化某一条训练数据上的损失函数,这样每一轮参数的更新速度大大加快。
缺点:
准确度下降。由于即使在目标函数为强凸函数的情况下,SGD仍旧无法做到线性收敛。
可能会收敛到局部最优,由于单个样本并不能代表全体样本的趋势.
不易于并行实现。

for i in range(number of epochs):
        np.random.shuffle(data)
        for each in data:
                weights_grad = evaluate_gradient(loss_function, each, weights)
                weights = weights - learning_rate * weights_grad


3 批量梯度下降(BGD Batch Gradient Descent)


BGD通常是取所有训练样本损失函数的平均作为损失函数,每次计算所有样本的梯度,进行求均值,计算量比较大,会陷入鞍点
优点:
一次迭代是对所有样本进行计算,此时利用矩阵进行操作,实现了并行。
由全数据集确定的方向能够更好地代表样本总体,从而更准确地朝向极值所在的方向。当目标函数为凸函数时,BGD一定能够得到全局最优。
缺点:
当样本数目 m 很大时,每迭代一步都需要对所有样本计算,训练过程会很慢。(有些样本被重复计算,浪费资源)

for i in range(number of epochs):
        np.random.shuffle(data)
        for each in data:
                weights_grad = evaluate_gradient(loss_function, each, weights)
                weights = weights - learning_rate * weights_grad


4 小批量梯度下降(mini-Batch GD,mini-batch Gradient Descent)


mini-batch GD采取了一个折中的方法,每次选取一定数目(mini-batch)的样本组成一个小批量样本,然后用这个小批量来更新梯度,这样不仅可以减少计算成本,还可以提高算法稳定性。

for i in range(number of epochs):
        np.random.shuffle(data)
        for batch in get_batches(data, batch_size = batch_size):
                weights_grad = evaluate_gradient(loss_function, batch, weights)
                weights = weights - learning_rate * weights_grad


优点:融合了BGD和SGD优点

  • 通过矩阵运算,每次在一个batch上优化神经网络参数并不会比单个数据慢太多。
  • 每次使用一个batch可以大大减小收敛所需要的迭代次数,同时可以使收敛到的结果更加接近梯度下降的效果。
  • 可实现并行化。

梯度下降:BGD、SGD、mini-batch GD介绍及其优缺点​

blog.csdn.net/qq_41375609/article/details/112913848​编辑


5 代码

  • BGD
import matplotlib.pyplot as plt
import numpy as np

# BGD 批量梯度下降

x_data = np.arange(1.0,200.0,1.0)
y_data = np.arange(2.0,400.0,2.0)

def forward(x,w):
    return x*w

def cost(x,y,w):
    cost = 0
    for x_val,y_true in zip(x,y):
        y_pred = forward(x_val,w)
        loss_val = (y_true - y_pred)**2
        cost = cost + loss_val
    return cost/len(x)

def gradient(x,y,w):
    gradient = 0
    for x_val,y_true in zip(x,y):
        gradient_temp = 2 * x_val *(x_val * w - y_true)
        gradient = gradient + gradient_temp
    
    return gradient/len(x)

w = 1.0
lr = 0.00001

epoch_list = []
cost_list = []

print("Before train 4: ",forward(400,w))
for epoch in range(100):
    cost_val = cost(x_data,y_data,w)
    gradient_val = gradient(x_data,y_data,w)
    w = w - lr * gradient_val
    print("epoch: ",epoch," loss: ",cost_val," w: ",w)
    epoch_list.append(epoch)
    cost_list.append(cost_val)

    if (cost_val < 1e-5):
        break

print("After train 4: ",forward(400,w))

plt.plot(epoch_list,cost_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig("./data/pytorch2.png")

  • SGD
import matplotlib.pyplot as plt

import numpy as np

# SGD随机梯度下降

x_data = np.arange(1.0,200.0,1.0)
y_data = np.arange(2.0,400.0,2.0)

def forward(x,w):
    return x * w

def loss(x,y_true,w):
    y_pred = forward(x,w)

    return (y_pred-y_true)**2

def gradient(x,y,w):

    return 2 *x *(x *w-y)

w = 1.0
lr = 0.00001

epoch_list = []
loss_list = []

print("Before train 4: ",forward(400,w))
for epoch in range(1000):
    seed = np.random.choice(range(len(x_data)))
    loss_val = loss(x_data[seed],y_data[seed],w)
    gradient_val = gradient(x_data[seed],y_data[seed],w)
    w -= lr*gradient_val
    print("epoch: ",epoch," loss: ",loss_val," w: ",w)
    epoch_list.append(epoch)
    loss_list.append(loss_val)
    if (loss_val < 1e-7):
        break
print("After train 4: ",forward(400,w))

plt.plot(epoch_list,loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig("./data/pytorch2_1.png")

  • mini-Batch GD
import matplotlib.pyplot as plt

import numpy as np
import random

# mini-batch GD 小批量随机梯度下降

x_data = np.arange(1.0,200.0,1.0)
y_data = np.arange(2.0,400.0,2.0)

def forward(x,w):
    return x*w

def cost(x,y,w):
    cost = 0
    for x_val,y_true in zip(x,y):
        y_pred = forward(x_val,w)
        loss_val = (y_true - y_pred)**2
        cost = cost + loss_val
    return cost/len(x)

def gradient(x,y,w):
    gradient = 0
    for x_val,y_true in zip(x,y):
        gradient_temp = 2 * x_val *(x_val * w - y_true)
        gradient = gradient + gradient_temp
    
    return gradient/len(x)

 
def get_seed_two(nums):
    # 从数组中随机取两个索引
    index1 = random.randrange(len(nums))
    index2 = random.randrange(len(nums))
    while index2 == index1:
        index2 = random.randrange(len(nums))
    
    return index1, index2
 
w = 1.0
lr = 0.00001
batch_size = 2

epoch_list = []
loss_list = []

# 存储随机取出的数的索引
seed = []

# 存储一个batch的数据
x_data_mini = []
y_data_mini = []

print("Before train 4: ",forward(400,w))
for epoch in range(1000):
    # 设定Batchsize 大小为2,每次随机取所有数据中的两个,作为一个batch,进行训练
    idx1,idx2= get_seed_two(x_data)
    seed.append(idx1)
    seed.append(idx2)

    for i in range(batch_size):
        x_data_mini.append(x_data[seed[i]])
        y_data_mini.append(y_data[seed[i]])

    loss_val = cost(x_data_mini,y_data_mini,w)
    gradient_val = gradient(x_data_mini,y_data_mini,w)

    w -= lr*gradient_val
    print("epoch: ",epoch," loss: ",loss_val," w: ",w)
    epoch_list.append(epoch)
    loss_list.append(loss_val)
    if (loss_val < 1e-7):
        break
print("After train 4: ",forward(400,w))

plt.plot(epoch_list,loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig("./data/pytorch2_2.png")

🌈我的分享也就到此结束啦🌈
如果我的分享也能对你有帮助,那就太好了!
若有不足,还请大家多多指正,我们一起学习交流!
📢未来的富豪们:点赞👍→收藏⭐→关注🔍,如果能评论下就太惊喜了!
感谢大家的观看和支持!最后,☺祝愿大家每天有钱赚!!!欢迎关注、关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1693896.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

html简述——part1

HTML概述 HTML&#xff08;HyperText Markup Language&#xff09;是一种用于创建网页的标准标记语言&#xff0c;具体指超文本标记语言。它不是一种编程语言&#xff0c;而是一种标记语言&#xff0c;用于描述网页的结构和内容。通过HTML&#xff0c;开发者可以定义网页的标题…

【算法】递归、搜索与回溯——简介

简介&#xff1a;递归、搜索与回溯&#xff0c;本节博客主要是简单记录一下关于“递归、搜索与回溯”的相关简单概念&#xff0c;为后续算法做铺垫。 目录 1.递归1.1递归概念2.2递归意义2.3学习递归2.4写递归代码步骤 2.搜索3.回溯与剪枝 递归、搜索、回溯的关系&#xff1a; …

广告圈策划大师课:活动策划到品牌企划的深度解析

对于刚接触营销策划的新人来说&#xff0c;在这个知识密集型行业里生存&#xff0c;要学习非常多各种意思相近的概念&#xff0c;常常让人感到头疼&#xff0c;难以区分。 这里对这些策划概念进行深入解析&#xff0c;帮助您轻松理清各自的含义和区别。 1. 活动策划&#xff…

CCF20230901——坐标变换(其一)

CCF20230901——坐标变换&#xff08;其一&#xff09; #include<bits/stdc.h> using namespace std; int main() {int n,m,x[101],y[101],x1[101],y1[101];cin>>n>>m;for(int i0;i<n;i)cin>>x1[i]>>y1[i];for(int j0;j<m;j)cin>>x[…

PD协议:引领电子设备充电新时代

随着科技的飞速发展&#xff0c;电子设备已成为我们日常生活中不可或缺的一部分。然而&#xff0c;这些设备的充电问题一直困扰着广大用户。传统的充电方式不仅效率低下&#xff0c;而且存在着安全隐患。为了解决这一问题&#xff0c;USB Implementers Forum&#xff08;USB-IF…

IPv6 地址创建 EUI-64 格式接口 ID 的过程

IPv6 接口标识符 IPv6 地址中的接口标识符&#xff08;ID&#xff09;用于识别链路上的唯一接口&#xff0c;有时被称为 IPv6 地址的 “主机部分”。接口 ID 在链路上必须是唯一的&#xff0c;始终为 64 位长&#xff0c;并且可以根据数据链路层地址动态创建。 MAC 地址 中的…

【C++项目】实时聊天的在线匹配五子棋对战游戏

目录 项目介绍 开发环境 核心技术 项目前置知识点介绍 Websocketpp 1. WebSocket基本认识 2. WebSocket协议切换原理解析 3. WebSocket报文格式 4. Websocketpp介绍 5. 搭建一个简单WebSocket服务器 JsonCpp 1. Json格式的基本认识 2. JsonCpp介绍 3. 序列化与反序…

在ubuntu中关于驱动得问题:如何将nouveau驱动程序加入黑名单和安装NVIDIA显卡驱动

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、nouveau驱动程序加入黑名单二、安装NVIDIA显卡驱动 一、nouveau驱动程序加入黑名单 (1) 打开黑名单列表文件 终端输入&#xff1a; sudo gedit /etc/modprobe…

CCF20230501——重复局面

CCF20230501——重复局面 代码如下&#xff1a; #include<bits/stdc.h> using namespace std; int main() {int n;cin>>n;char a[101][64];int i,j;for(i0;i<n;i){for(j0;j<64;j){cin>>a[i][j];}}int temp0,flag1;for(i0;i<n;i){flag1;for(j0;j<…

Linux程序开发(十二):线程与多线程同步互斥实现抢票系统

Tips&#xff1a;"分享是快乐的源泉&#x1f4a7;&#xff0c;在我的博客里&#xff0c;不仅有知识的海洋&#x1f30a;&#xff0c;还有满满的正能量加持&#x1f4aa;&#xff0c;快来和我一起分享这份快乐吧&#x1f60a;&#xff01; 喜欢我的博客的话&#xff0c;记得…

Mongodb分布式id

1、分布式id使用场景 分布式ID是指在分布式系统中用于唯一标识每个元素的数字或字符串。在分布式系统中&#xff0c;各个节点或服务可能独立运行在不同的服务器、数据中心或地理位置&#xff0c;因此需要一种机制来确保每个生成的ID都是全局唯一的&#xff0c;以避免ID冲突。 …

Pytorch线性模型(Linear Model)

基本步骤 ①首先准备好数据集&#xff08;DataSet&#xff09; ②模型的选择或者设计&#xff08;Model&#xff09; ③进行训练&#xff08;Train&#xff09;大部分模型都需要训练&#xff0c;有些不需要。这一步后我们会确定不同特征的权重 ④推理&#xff08;inferring…

就业班 第三阶段(ELK) 2401--5.20 day1 ELK 企业实战 ES+head+kibana+logstash部署(最大集群)

ELKkafkafilebeat企业内部日志分析系统 1、组件介绍 1、Elasticsearch&#xff1a; 是一个基于Lucene的搜索服务器。提供搜集、分析、存储数据三大功能。它提供了一个分布式多用户能力的全文搜索引擎&#xff0c;基于RESTful web接口。Elasticsearch是用Java开发的&#xff…

学习单向链表带哨兵demo

一、定义 在计算机科学中&#xff0c;链表是数据元素的线性集合&#xff0c;其每个元素都指向下一个元素&#xff0c;元素存储上并不连续。 1.可以分三类为 单向链表&#xff0c;每个元素只知道其下一个元素是谁 双向链表&#xff0c;每个元素知道其上一个元素和下一个元素 …

mySql从入门到入土

基础篇 在cmd中使用MYSQL的相关指令&#xff1a; net start mysql // 启动mysql服务 net stop mysql // 停止mysql服务 mysql -uroot -p1234//登录MYSQL&#xff08;-u为用户名-p为密码&#xff09; //登录参数 mysql -u用户名 -p密码 -h要连接的mysql服务器的ip地址(默认1…

记一次安卓“Low on memory“崩溃问题

前言 最近再调人脸识别算法相关demo,发现调试期间总是偶发性崩溃&#xff0c;捕获不到异常的那种&#xff0c;看日志发现原因是Low on memory&#xff0c;一开始还疑惑 App内存不够应该是OOM啊,怎么会出现这种问题&#xff0c;百思不得其解&#xff0c;直到我打开了 Android s…

Git 仓库中 -- 代码冲突产生、定位、解决的流程

目录 前置知识1 工具环境2 冲突的产生2.1 仓库中的源代码2.2 人员 A 首先更改代码2.3 人员 B 更改代码&#xff0c;产生冲突2.3.1 第一次错误提示&#xff1a;2.3.2 第二次错误提示&#xff1a; 3 查看冲突4 手动解决冲突4.1 方式一4.2 方式二&#xff08;tortoisegit&#xff…

Vitis HLS 学习笔记--控制驱动任务示例

目录 1. 简介 2. 代码解析 2.1 kernel 代码回顾 2.2 功能分析 2.3 查看综合报告 2.4 查看 Schedule Viewer 2.5 查看 Dataflow Viewer 3. Vitis IDE的关键设置 3.1 加载数据文件 3.2 设置 Flow Target 3.3 配置 fifo 深度 4. 总结 1. 简介 本文对《Vitis HLS 学习…

CSAPP(datalab)解析

howManyBits /* howManyBits - 返回用二进制补码表示x所需的最小位数* 示例: howManyBits(12) 5* howManyBits(298) 10* howManyBits(-5) 4* howManyBits(0) 1* howManyBits(-1) 1* howManyBits(0x80000000) …

【Linux】TCP协议【上】{协议段属性:源端口号/目的端口号/序号/确认序号/窗口大小/紧急指针/标记位}

文章目录 1.引入2.协议段格式4位首部长度16位窗口大小32位序号思考三个问题【demo】标记位URG: 紧急指针是否有效提升某报文被处理优先级【0表示不设置1表示设置】ACK: 确认号是否有效PSH: 提示接收端应用程序立刻从TCP缓冲区把数据读走RST: 对方要求重新建立连接; 我们把携带R…