pytorch深度学习基础 6(简单的参数估计学习2)

news2024/9/21 23:40:09

上一节我们建立了一个简单的模型进行分析散点图,利用均方差来实现损失函数的计算,但是并没有计算出具体的参数值,这次我们来计算损失函数的损失值以及不断减小损失值,计算出最优的参数,代码原理非常简单大家可以自行理解查阅资料

import numpy as np
import torch
import time
torch.set_printoptions(edgeitems=2)
t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0,
                    3.0, -4.0, 6.0, 13.0, 21.0])
t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9,
                    33.9, 21.8, 48.4, 60.4, 68.4])
t_un = 1 * t_u
def model(t_u, w, b):
    return w * t_u + b
def loss_fn(t_p, t_c):
    squared_diffs = (t_p - t_c)**2
    return squared_diffs.mean()
params = torch.tensor([1.0, 0.0], requires_grad=True)
loss = loss_fn(model(t_u, *params), t_c)
loss.backward()
if params.grad is not None:
    params.grad.zero_()


def training_loop(n_epochs, learning_rate, params, t_u, t_c):
    for epoch in range(1, n_epochs + 1):
        if params.grad is not None:  # <1>
            params.grad.zero_()

        t_p = model(t_u, *params)
        loss = loss_fn(t_p, t_c)
        loss.backward()

        with torch.no_grad():  # <2>
            params -= learning_rate * params.grad

        if epoch:
            # time.sleep(0.2)
            print('Epoch %d, Loss %f' % (epoch, float(loss)))
            print('params:',params)
    return params
training_loop(
    n_epochs = 10000,
    learning_rate = 1e-2,
    params = torch.tensor([1.0, 0.0], requires_grad=True), # <1>
    t_u = t_un, # <2>
    t_c = t_c)

我来简单说一下代码的思想,代码中规定了一个model()函数用于计算出输入参数得到的预测值,也就是t_p,把预测的值存在t_p中,然后利用t_p与实际值也就是t_c计算出一个方差,loss_fn()函数的作用就是计算预测值与实际值之间的方差,由于我们使用梯度下降算法,所以每次计算完都需要把梯度重新归零防止梯度累加,training_loop()这个函数作用就是循环训练,并更新所有训练的参数,接下来,我们运行代码

实际上,我们得到的损失值竟然在快速变大,越来越离谱,发生了什么呢?也就是出现我们经常说的梯度爆炸的现象

梯度爆炸

梯度爆炸是指在训练深度神经网络时,梯度(即损失函数对参数的导数)变得异常大,导致参数更新幅度过大,破坏了模型的稳定性,并可能使损失函数值急剧增加。这通常发生在以下几种情况:

  1. 不合适的初始学习率:如果学习率设置得过高,可能会导致在每次参数更新时,参数的改变量过大,从而使得损失函数值快速增加。

  2. 激活函数选择不当:某些激活函数(如Sigmoid或Tanh)在极端值(如非常大或非常小的输入)下可能会导致梯度消失或梯度爆炸。虽然这种情况更常见于梯度消失,但在某些情况下,特别是与其他因素结合时,也可能导致梯度爆炸。

  3. 深度过深的网络:非常深的网络可能导致梯度在反向传播过程中累积放大,特别是当每一层的梯度都比较大时。

  4. 损失函数或优化器设置不当:某些损失函数或优化器可能在特定条件下表现不佳,导致梯度不稳定。

针对你的问题,你可以尝试以下几种方法来解决梯度爆炸问题:

  • 降低学习率:使用更小的学习率可以减缓参数更新的速度,从而减少梯度爆炸的风险。

  • 使用梯度裁剪(Gradient Clipping):在梯度更新之前,对梯度值进行裁剪,确保梯度的最大绝对值不超过某个阈值。

  • 更换激活函数:尝试使用ReLU及其变体(如Leaky ReLU、PReLU等),这些激活函数在大多数情况下能更好地缓解梯度消失或爆炸的问题。

  • 简化网络结构:如果可能的话,尝试简化网络结构,减少网络深度或宽度,以减少梯度累积放大的可能性。

  • 更换优化器:某些优化器(如Adam、RMSprop等)自带了一定的梯度缩放机制,可能更适合你的任务。

  • 正则化:通过添加L1或L2正则化项来限制参数的更新幅度,这有助于保持参数的稳定性。

  • 数据预处理:确保输入数据已经过适当的预处理(如归一化、标准化等),以避免由于数据规模不当导致的梯度问题。  

实际上也就是发生了过度修正的现象,就是每次修正的步长太大了,导致参数接收到的更新太大了,参数开始来回波动,每一次的修正过度,就会导致下一次的修正也发生过度的现象,优化不稳定,反而是其变得发散而不是收敛,从损失值快速变大就可以发现,问题显而易见,params -= learning_rate * params.grad这行代码所导致的,那我们如何优化来限制它的大小呢,调整学习率和梯度

调整学习率

我们可以调整一下学习率,较小的学习率会使得损失值慢慢减小,我们修改一下学习率继续实验

training_loop(
    n_epochs = 10000,
    learning_rate = 1e-4,
    params = torch.tensor([1.0, 0.0], requires_grad=True), # <1>
    t_u = t_un, # <2>
    t_c = t_c)

不难发现,和上一次的结果截然相反,这个损失值慢慢减小,直到训练轮数的不断提升,参数和损失值都趋于稳定,由于学习率降低了,从原理出发,我们需要提高训练的轮数来提高结果的稳定,我这里训练轮数改为了10w轮,看看结果吧

可以看到损失值变化非常小,如果加大轮数很容易得到一个最近似的参数值

我们除了调整学习率之外,还可以进行梯度的调整,我们采用的方法叫做归一化输入

归一化输入

从第一次运行完代码的图片中可以发现,权重w的梯度与偏置b的梯度相差大概50倍,这很好地说明权重与偏置的存在于不同的比例空间中,在这中情况下如果学习率的值偏大,很容易导致只能有效更新其中一个参数,对于另外一个参数而言,学习率就会变得不稳定,无法对其进行有效的更新,每个参数都有自己的学习率,他们彼此是独立的,除非参数之间相差不大,但是为了得到准确的结果,还是建议使用适合参数自身的学习率,可以有效的保证模型训练的精度,这里的话小编就不介绍复杂的归一化算法,我们就使用最简单的归一化输入算法。

那我们改变输入,这样梯度就不会有太大的不同,也就是对输入的值进行缩放,仅仅用作计算,最后输出把倍率放回去就行,我们采用对x进行缩小成原来的0.1倍

t_un = 0.1 * t_u

我自己设置的是训练1k轮,看看效果吧

上面10w轮损失值才到3.7,而我归一化处理以后1k就可以到达3.8,可见效率大大提升了。正常的网络当中的参数都是百万级别的,我们仅仅用了最简单的两个办法就可以大大加快了训练的效率,在真实的训练当中我们需要花费大量的时间在进行网络的训练,有两许多优化器以及许多优化算法,可以大大减少了训练的时间,这在深度学习中非常重要。

最后我们来看看计算出来的参数的可视化效果

可视化

params =training_loop(
    n_epochs = 10000,
    learning_rate = 1e-2,
    params = torch.tensor([1.0, 0.0], requires_grad=True),
    t_u = t_un,
    t_c = t_c)

from matplotlib import pyplot as plt

t_p = model(t_un, *params)

fig = plt.figure(dpi=100)
plt.xlabel("Temperature (°Fahrenheit)")
plt.ylabel("Temperature (°Celsius)")
plt.plot(t_u.numpy(), t_p.detach().numpy()) # <2>
plt.plot(t_u.numpy(), t_c.numpy(), 'o')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2068941.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TOMCAT-企业级WEB应用服务器

一 WEB技术 1.1 HTTP协议和B/S 结构 HTTP&#xff08;HyperText Transfer Protocol&#xff09;协议即超文本传输协议&#xff0c;是用于在万维网&#xff08;WWW&#xff09;上传输超文本内容的基础协议。 一、HTTP 协议的特点 1、简单快速 客户向服务器请求服务时&#…

八股(3)——计网

八股&#xff08;3&#xff09;——计网 3. 计算机基础3.1 计算机网络OSI 七层模型是什么&#xff1f;每一层的作用是什么&#xff1f;TCP/IP 四层模型是什么&#xff1f;每一层的作用是什么&#xff1f;1. 应用层&#xff08;Application layer&#xff09;2. 传输层&#xff…

【iOS安全】iPhone8 iOS14.4.2 越狱教程

环境配置 iPhone 8&#xff1a; 固件版本 iOS 14.4.2 (18D70) 产品类型 iPhone10,1 (A1906) 销售型号 MQ862J/A MacBook Pro&#xff1a; macOS 10.15.7 装有CheckRa1n beta 0.12.4 概述 尝试了几个版本的unc0ver和Taurine&#xff0c;发现都不好使 unc0ver显示unsupported…

如何实现一棵AVL树

目录 1.什么是AVL树&#xff1f; 2.AVL树的实现 2.1AVL树结点的定义 2.2AVL树的插入 2.2.1插入的步骤 2.2.2插入情况分析 2.2.3旋转操作的分析 2.3AVL树的查找 3.AVL树的验证 4.AVL树的性能分析 1.什么是AVL树&#xff1f; AVL树其实就是一棵加了限制条件的二叉搜索树…

day38.动态规划+MySql数据库复习

844.比较含退格的字符串 给定 s 和 t 两个字符串&#xff0c;当它们分别被输入到空白的文本编辑器后&#xff0c;如果两者相等&#xff0c;返回 true 。# 代表退格字符。 注意&#xff1a;如果对空文本输入退格字符&#xff0c;文本继续为空 思路:定义两个栈&#xff0c;将字符…

集合及数据结构第九节————树和二叉树

系列文章目录 集合及数据结构第九节————树和二叉树 树和二叉树 树型结构的概念树的概念树的表示形式&#xff08;了解&#xff09;树的应用二叉树的概念两种特殊的二叉树二叉树的性质二叉树的性质练习二叉树的存储二叉树的遍历二叉树的基本操作二叉树相关练习题 文章目录…

flutter 中 ssl 双向证书校验

SSL 证书&#xff1a; 在处理 https 请求的时候&#xff0c;通常可以使用 中间人攻击的方式 获取 https 请求以及响应参数。应为通常我们是 SSL 单向认证&#xff0c;服务器并没有验证我们的客户端的证书。为了防止这种中间人攻击的情况。我么可以通过 ssl 双向认证的方式。即…

Leetcode JAVA刷刷站(91)解码方法

一、题目概述 二、思路方向 这个问题是一个典型的动态规划问题&#xff0c;其中我们可以使用一个数组来存储到达每个位置时的解码方法的总数。 我们定义一个数组 dp&#xff0c;其中 dp[i] 表示字符串 s 的前 i 个字符&#xff08;从索引 0 到 i-1&#xff09;的解码方法总数。…

企业数字化转型管控平台探索 ---基于流程的企业经络管理框架DEM

篇幅有限&#xff0c;获取完整内容、更多感兴趣的内容 见下图

OpenCV几何图像变换(8)调整图像大小的函数resize()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 resize 函数调整图像 src 的大小&#xff0c;使其缩小或放大至指定的大小。需要注意的是&#xff0c;初始的 dst 类型或大小不被考虑。相反&…

解决Jasper Studio报表工具中预览正常显示,但部署到服务器上面无法正常显示的问题

目录 1.1、错误描述 1.2、解决方案 1.1、错误描述 之前有遇到过一个Jasper Studio报表开发相关的问题&#xff0c;这里记录一下&#xff0c;方便其他小伙伴可以快速解决问题。问题是这样的&#xff1a;当我在Jasper Studio报表工具里面设计好样式之后&#xff0c;预览报表发…

[论文阅读] mobile aloha实验部分

DP:[1] CHI C, FENG S, DU Y, et al. Diffusion Policy: Visuomotor Policy Learning via Action Diffusion[J]. 2023. Diffusion Policy: Visuomotor Policy Learning via Action Diffusion精读笔记&#xff08;一&#xff09;-CSDN博客 VINN:[1] PARI J, SHAFIULLAH N, ARU…

视频达人的秘密武器:全能型剪辑软件深度剖析

剪辑视频&#xff0c;作为视频创作过程中的关键环节&#xff0c;其重要性不言而喻。无论是专业影视制作团队&#xff0c;还是热衷于Vlog创作的个人&#xff0c;都离不开一款强大且易用的视频剪辑工具。今天&#xff0c;就让我们一起踏上一场探索之旅&#xff0c;对市面上的视频…

java基础 之 关键字static

文章目录 前言1、特征2、修饰变量3、修饰方法4、修饰代码块优缺点应用场景代码理解 前言 本文主要是从类与对象的方向来讲&#xff0c;所以在文章开始前&#xff0c;我们先理解一下类和对象 类是一个模板&#xff0c;对象是一个实例。 如【手机】是一个类&#xff08;一个模板…

MySQL系统性的学习--基础

学习资料是黑马的mysql课程 Mysql概述 相关概念 数据模型 关系型数据库 数据模型 SQL SQL通用语法 SQL分类 DDL 数据库操作 表操作 查询 创建 数据类型 修改/删除 DML 添加数据INSERT 修改数据UPDATE 删除数据DELETE DQL 基础查询 条件查询 聚合函数 分组查询 排序查询 分…

Otterctf 2018 内存取证 (复现)

题目地址: https://otterctf.com/challenges 1 - What the password? 描述:you got a sample of ricks PCs memory. can you get his user password? 首先查看一下镜像的信息 python2 vol.py -f /home/kali/Desktop/OtterCTF.vmem imageinfo 题目描述需要获取密码, 使用mi…

el-form中使用v-model和prop实现动态校验

如何在Vue的el-form中使用v-model和prop实现动态校验&#xff0c;包括多个变量控制校验、数组循环校验和字段级条件显示。通过实例演示了如何配合rules和自定义验证函数来确保表单的完整性和有效性。 公式&#xff1a; 动态校验项的v-model的绑定值 el-form的属性 :model的值 …

PCSE不同播种时间的对比

目录 简介对比图源代码简介 设置为2022年10月15日播种,然后每隔5天往后播种一次,然后探究播种时间对于作物各个长势的影响 对比图 源代码 import sys, os import matplotlib from matplotlib import style matplotlib.style.use("seaborn-whitegrid") import ma…

ST 表算法

ST 表 ST 表&#xff0c;主要思想是空间换时间&#xff0c;用于解决可重复贡献问题和 RMQ 问题。 可重复贡献问题 指某个运算 o p op op&#xff0c;有 x o p x x x\ op\ x\ \ x x op x x 。例如 m a x ( x , x ) x m i n ( x , x ) x g c d ( x , x ) x max(x,x)x\…

Linux基础环境开发工具gcc/g++ make/Makefile git

1.Linux编译器-gcc/g使用 1. 预处理&#xff08;进行宏替换) 预处理功能主要包括宏定义,文件包含,条件编译,去注释等。 预处理指令是以#号开头的代码行。 实例: gcc –E hello.c –o hello.i 选项“-E”,该选项的作用是让 gcc 在预处理结束后停止编译过程。 选项“-o”是指目标…