权重衰减(Weight Decay)

news2024/11/15 7:52:47

       在深度学习中,权重衰减(Weight Decay)是一种常用的正则化技术,旨在减少模型的过拟合现象。权重衰减通过向损失函数添加一个正则化项,以惩罚模型中较大的权重值。

一、权重衰减

       在深度学习中,模型的训练过程通常使用梯度下降法(或其变种)来最小化损失函数。梯度下降法的目标是找到损失函数的局部最小值,使得模型的预测能力最好。然而,当模型的参数(即权重)过多或过大时,容易导致过拟合问题,即模型在训练集上表现很好,但在测试集上表现较差。

       权重衰减通过在损失函数中引入正则化项来解决过拟合问题。正则化项通常使用L1范数或L2范数来度量模型的复杂度。L2范数正则化(也称为权重衰减)是指将模型的权重的平方和添加到损失函数中,乘以一个较小的正则化参数$ \lambda $这个额外的项迫使模型学习到较小的权重值,从而减少模型的复杂度。

       具体而言,对于一个深度学习模型的损失函数$L(w, b)$,其中$w,b$表示模型的参数(权重和偏置),权重衰减可以通过以下方式实现:

$ L'\left( w,b \right) =L\left( w,b \right) +\lambda \cdot \lVert w \rVert ^2 $

       其中,$ L'\left( w,b \right) $是添加了权重衰减的损失函数,$ \lVert w \rVert ^2 $表示参数的L2范数的平方和,$ \lambda $是正则化参数,用于控制正则化项的重要性。

       在训练过程中,梯度下降法将同时更新损失函数和权重。当计算梯度时,权重衰衰减的正则化项将被添加到梯度中,从而导致权重更新的幅度减小。这使得模型的权重趋向于减小,避免过拟合现象。

       需要注意的是,正则化参数$ \lambda $的选择对模型的性能有重要影响。较小的$ \lambda $值会导致较强的正则化效果,可能会使模型欠拟合。而较大的$ \lambda $值可能会减少正则化效果,使模型过拟合。因此,选择合适的正则化参数是权衡模型复杂度和泛化能力的关键。

       偏置(biases)在神经网络中起到平移激活函数的作用,通常不会像权重那样导致过度拟合。偏置的主要作用是调整激活函数的位置,使其更好地对应所需的输出。由于偏置的影响较小,因此将权重衰减应用于偏置通常不是常见的做法。

二、权重衰减数学解释

       L2范数正则化在解决过拟合问题方面具有一定的效果,这是因为它在损失函数中引入了权重的平方和作为正则化项。下面我将解释一下L2范数正则化的数学原理。

       在深度学习中,我们的目标是最小化损失函数,该函数包括两部分:经验误差和正则化项。对于L2范数正则化,我们将正则化项定义为权重的平方和的乘以一个正则化参数$ \lambda $

       针对损失函数$ L'\left( w,b \right)$,我们使用梯度下降法来最小化这个损失函数。在梯度下降的每一步中,我们计算损失函数的梯度,然后更新权重。对于L2范数正则化,梯度的计算中包含了正则化项的贡献。

       具体来说,我们计算损失函数对权重w的梯度,记为$ \nabla L\left( w,b \right) $。那么加入L2范数正则化后的梯度可以写为:

$ \nabla L'\left( w,b \right) =\nabla L\left( w,b \right) +2\lambda w $

       这里,$ 2\lambda w $是正则化项的梯度贡献,其中$ 2\lambda $是正则化参数$ \lambda $的倍数,$w$是权重的梯度。

       当我们使用梯度下降法更新权重时,梯度的负方向指示了损失函数下降的方向。由于L2范数正则化项的存在,权重的梯度会受到惩罚,从而导致权重的更新幅度减小。

       这种减小权重更新幅度的效果使得模型倾向于学习到较小的权重值,从而降低了模型的复杂度。通过减小权重的幅度,L2范数正则化可以有效地控制模型的过拟合,提高模型的泛化能力。

       总结起来,L2范数正则化通过引入权重的平方和作为正则化项,在梯度计算和权重更新中对权重进行惩罚,从而减小了模型的复杂度,防止过拟合现象的发生。

也可以参考李沐老师的课件:

三、代码从零开始实现

import torch
from torch import nn
from d2l import torch as d2l

1、生成数据

       首先,我们像以前一样生成一些数据,生成公式如下:

$y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2).$

       我们选择标签是关于输入的线性函数。标签同时被均值为0,标准差为0.01高斯噪声破坏。为了使过拟合的效果更加明显,我们可以将问题的维数增加到$d = 200$(w的长度为200),并使用一个只包含20个样本的小训练集。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5   # 训练集长度为20、验证机长度为100、权重参数有200个、批量大小为5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05   # 真实的权重和偏置
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

2、初始化模型参数

       我们将定义一个函数来随机初始化模型参数。

def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

3、定义L2范数惩罚

       实现这一惩罚最方便的方法是对所有项求平方后并将它们求和。

def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

4、定义训练代码实现

       下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。和之前线性回归一样,线性网络和平方损失没有变化,所以我们通过`d2l.linreg`和`d2l.squared_loss`导入它们。唯一的变化是损失现在包括了惩罚项。

def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())

5、忽略正则化直接训练

       我们现在用`lambd = 0`禁用权重衰减后运行这个代码。注意,这里训练误差有了减少,但测试误差没有减少,这意味着出现了严重的过拟合。

train(lambd=0)
w的L2范数是: 12.963241577148438

 

6、使用权重衰减

       下面,我们使用权重衰减来运行代码。注意,在这里训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。

train(lambd=3)
w的L2范数是: 0.3556520938873291

 

四、简洁实现

       由于权重衰减在神经网络优化中很常用,深度学习框架为了便于我们使用权重衰减,将权重衰减集成到优化算法中,以便与任何损失函数结合使用。此外,这种集成还有计算上的好处,允许在不增加任何额外的计算开销的情况下向算法中添加权重衰减。由于更新的权重衰减部分仅依赖于每个参数的当前值,因此优化器必须至少接触每个参数一次。

1、定义训练代码实现

       在下面的代码中,我们在实例化优化器时直接通过`weight_decay`指定weight decay超参数。默认情况下,PyTorch同时衰减权重和偏移。这里我们只为权重设置了`weight_decay`,所以偏置参数$b$不会衰减。

def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd}, {"params":net[0].bias}],
                              lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())

2、忽略正则化直接训练

train_concise(0)
w的L2范数: 13.727912902832031

3、使用权重衰减

train_concise(3)
w的L2范数: 0.3890590965747833

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1316025.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flutter在Visual Studio Code上首次创建运行应用

一、创建Flutter应用 1、前提条件 安装Visual Studio Code并配置好运行环境 2、开始创建Flutter应用 1)、打开Visual Studio Code 2)、打开 View > Command Palette。 3)、在搜索框中输入“flutter”,弹出内容如下图所示,选择“ Flutter: New Pr…

SoloLinker第一次使用记录,解决新手拿到板子的无所适从

本文目录 一、简介二、进群获取资料2.1 需要下载资料2.2 SDK 包解压 三、SDK 编译3.1 依赖安装3.2 编译配置3.3 启动编译3.4 编译后的固件目录 四、固件烧录4.1 RV1106 驱动安装4.2 打开烧录工具4.3 进入boot 模式(烧录模式)4.4 烧录启动固件4.5 烧录升级…

浏览器录屏技术探究与实践

一、引言 随着网络技术的快速发展,浏览器已经成为人们获取信息的主要途径。浏览器录屏技术作为一种新兴的媒体捕捉和分享方式,逐渐受到广泛关注。本文将对浏览器录屏技术进行深入探讨,分析其实现原理,并给出实际应用中的解决方案…

MC-30A (32.768 kHz用于汽车应用的晶体单元)

MC-30A 32.768 kHz用于汽车应用的晶体,车规晶振中的热销型号之一。该款石英晶体谐振器,可以在-40 to 85 C的温度内稳定工作,能满足起动振动的要求。同时满足AEC-Q200无源元件质量标准认证,满足汽车仪表系统的所有要求。 频率范围…

内网穿透工具,如何保障安全远程访问?

内网穿透工具是一种常见的技术手段,用于在没有公网IP的情况下将本地局域网服务映射至外网。这种工具的使用极大地方便了开发人员和网络管理员,使得他们能够快速建立起本地服务与外部网络之间的通信渠道。然而,在享受高效快捷的同时&#xff0…

win10电脑字体大小怎么设置?介绍四种方法

在Win10操作系统中,字体大小的设置对于用户来说是一个非常重要的问题。合适的字体大小能够保护我们的视力,提高我们的工作效率。本文将介绍几种常用的方法来调整Win10电脑的字体大小,帮助用户轻松设置自己喜欢的字体大小。 方法一&#xff1…

安装鸿蒙开发者工具DevEco Studio

1.进入官网下载工具 https://developer.harmonyos.com/cn/develop/deveco-studio/ 选择您电脑对应的系统下载即可 2.安装 很简单直接点击“next”,此处不做赘述 3.配置环境 安装完成后,打开DevEco Studio 会提示配置环境。安装node.js和ohpm 如果不小心关了&a…

linux性能优化-上下文切换

如何理解上下文切换 Linux 是一个多任务操作系统,它支持远大于 CPU 数量的任务同时运行,这是通过频繁的上下文切换、将CPU轮流分配给不同任务从而实现的。 CPU 上下文切换,就是先把前一个任务的 CPU 上下文(CPU 寄存器和程序计数…

NO-IOT翻频,什么是翻频,电信为什么翻频

1.1 翻频迁移最终的目的就是减少网络的相互干扰,提供使用质量. 1.2 随着与日俱增的网络规模的扩大,网内干扰已成了影响网络的质量标准之一,为了保障电信上网体验,满足用户日益增长的网速需求,更好的服务客户,电信针对…

Git中stash的使用

Git中stash的使用 stash命令1. stash保存当前修改2. 重新使用缓存3. 查看stash3. 删除 使用场景 stash命令 1. stash保存当前修改 git stash 会把所有未提交的修改(包括暂存的和非暂存的)都保存起来. git stashgit stash save 注释2. 重新使用缓存 #…

Python 直观理解基尼系数

基尼系数最开始就是衡量人群财富收入是否均衡,大家收入平平,那就是很平均,如果大家收入不平等,那基尼系数就很高。 还是给老干部们讲的言简意赅。 什么是基尼系数 我们接下来直接直观地看吧,程序说话 # -*- coding:…

【AI】YOLO学习笔记

作为经典的图像识别网络模型,学习YOLO的过程也是了解图像识别的发展过程,对于初学者来说,也可以了解所采用算法的来龙去脉,构建解决问题的思路。 1.YOLO V1 论文地址:https://arxiv.org/abs/1506.02640 YOLO&#x…

TSINGSEE视频智能解决方案边缘AI智能与后端智能分析的区别与应用

视频监控与AI人工智能的结合是当今社会安全领域的重要发展趋势。随着科技的不断进步,视频监控系统已经不再局限于简单的录像和监视功能,而是开始融入人工智能技术,实现更加智能化的监控和安全管理。传统的监控系统往往需要人工操作来进行监控…

在滴滴和网易划水4年,过于真实了...

先简单交代一下吧,猫哥是某不知名985的本硕,19年毕业加入滴滴,之后跳槽到了网易,一直从事测试开发相关的工作。之前没有实习经历,算是四年半的工作经验吧。 这四年半之间他完成了一次晋升,换了一家公司&am…

Bootstrap 响应式实用工具-来自Twitter,目前最受欢迎的前端框架

Bootstrap 提供了一些辅助类,以便更快地实现对移动设备友好的开发。这些可以通过媒体查询结合大型、小型和中型设备,实现内容对设备的显示和隐藏。 需要谨慎使用这些工具,避免在同一个站点创建完全不同的版本。响应式实用工具目前只适用于块和表切换。 超小屏幕 手机 (<…

Linux 常用的操作命令

我们习惯的使用Windows,安装软件进行使用&#xff0c;比如 WPS&#xff0c;浏览器&#xff0c;一些工具&#xff0c;但是在Linux上就需要用命令去操作&#xff0c;也可以使用像Ubuntu 和 CentOS这类的可视化面板 Linux系统是开源的&#xff0c;所以开发人员可以反复的发现Bug以…

HTTP代理服务器脚本录制

1、报错1 target controller is configured to “use recording Controller“ but no such controller exists,ensure_target controller is configured to "use recording -CSDN博客

《打造第二大脑》—如何构建高效的笔记系统

最近看了一本书&#xff0c;因为我也用Obsidian来记笔记&#xff0c;&#xff08;Obsidian之前有介绍过Obsidian使用教程&#xff08;如何构建你的个人知识库&#xff0c;第二大脑&#xff09;&#xff09;看完这本书后发现里面给的方法跟Obsidian很契合&#xff0c;所以就整理…

RabbitMQ消息顺序性保障

RabbitMQ 没有属性设置消息的顺序性&#xff0c;只能设置消息的优先级&#xff0c;因此消息顺序性保障只能在 consumer 上实现 场景分析&#xff1a; 生产者向 RabbitMQ 里发送了三条数据&#xff0c; 顺序依次是 data1-> data2 -> data3&#xff0c;压入的是一个内存…