[动手学习深度学习]12.权重衰退

news2025/3/13 2:57:55

1.介绍

权重衰退是常见的处理过拟合的方法

  • 控制模型容量方法
    1. 把模型控制的比较小,即里面参数比较少
    2. 使参数选择范围小
  • 约束就是正则项
    每个特征的权重都大会导致模型复杂,从而导致过拟合。
    控制权重矩阵范数可以使得减少一些特征的权重,甚至使他们权重为0,从而导致模型简单,减轻过拟合

使用均方范数作为硬性限制

权重衰退即是通过控制参数选择范围来控制模型容量的

  • 公式表达:
    m i n   l ( w , b )    s u b j e c t   t o ∣ ∣ w ∣ ∣ 2 ≤ θ min\ l(w,b)\ \ subject\ to ||w||^2 ≤ \theta min l(w,b)  subject to∣∣w2θ
    l l l:损失函数
    w w w:参数
    b b b:偏移
    在最小化损失函数时加上限制,使参数的平方和小于一个特定的值,也就说明每个参数的值要小于 θ \theta θ开根
    通常不限制偏移b
    小的 θ \theta θ意味着更强的正则项

使用均方范数作为柔性限制

  • Df:对每个 θ \theta θ,都可以找到 λ \lambda λ使得之前的目标函数等价于下面:
    m i n   l ( w , b ) + λ 2 ∣ ∣ w ∣ ∣ 2 min \ l(w,b)+\frac{\lambda}{2} || w||^2 min l(w,b)+2λ∣∣w2
    (可以通过拉格朗日乘子来证明)
  • 超参数 λ \lambda λ控制了正则项的重要程度
    • λ = 0 \lambda=0 λ=0:无作用(当 λ = 0 \lambda=0 λ=0时,即没有后面的限制,相当于上一个公式里 θ = ∞ \theta=\infty θ=
    • λ → ∞ , w ∗ → 0 \lambda \rightarrow \infty, w^* \rightarrow0 λ,w0:相当于上面 θ → 0 \theta \rightarrow0 θ0,也就使 w ∗ → 0 w^* \rightarrow0 w0

想通过控制模型参数使模型不要太复杂时,可以通过增加 λ \lambda λ来满足需求(这里 λ \lambda λ是一个平滑的,不像以前的硬性限制)

请添加图片描述

  • 这里可以理解拉格朗日乘子法:
    • 拉格朗日乘子法原本是用于解决约束条件下的多元函数极值问题。举例,求f(x,y)的最小值,但是有约束C(x,y) = 0。乘子法给的一般思路是,构造一个新的函数g(x,y,λ) = f(x,y) +λC(x,y),当同时满足g’x = g’y = 0时,函数取到最小值。这件结论的几何含义是,当f(x,y)与C(x,y)的等高线相切时,取到最小值。
    • 具体到机器学习这里, C ( x , y ) = w 2 − θ C(x,y) = w^2 -θ C(x,y)=w2θ。所以视频中的黄色圆圈,代表不同θ下的约束条件。θ越小,则最终的parameter离原点越近。
  • 绿色的线就是原始损失函数l的等高线,优化损失函数l的最优解(波浪号即最优解)在中心位置
  • 当原始损失加入 λ 2 \frac{\lambda}{2} 2λ项之后,这个项是一个二次项,假如w就两个值,x1(横轴)x2(纵轴),则在图上这个二次项的损失以原点为中心的等高线为橙色的图所示。所以合并后的损失为绿色和黄色的线加一起的损失
  • 当加上损失项后,可知原来最优解对应的二次项的损失特别大,因此原来的最优解不是加上二次项后的公式的最优解了。若沿着橙色的方向走,原有l损失值会大一些,但是二次项罚的损失会变小,当拉到平衡点以内时,惩罚项减少的值不足以原有l损失增大的值,这样w*就是惩罚项后的最优解
  • 损失函数加上正则项成为目标函数,目标函数最优解不是损失函数最优解。
    正则项就是防止达到损失函数最优导致过拟合,把损失函数最优点往外拉一拉。 鼓励权重分散,将所有特征运用起来,而不是依赖其中的少数特征,并且权重分散的话他的内积就小一点
  • l2正则项会对大叔之的权值进行惩罚

回顾平方损失:
请添加图片描述
相对原来的权重更新,再减去一个值后,使得这个权重更进一步减小,这样会导致这个权重所占的比例进一步减小请添加图片描述

参数更新法则

在这里插入图片描述
请添加图片描述

2. 代码实现(手动实现)

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

像以前一样生成一些人工数据:
在这里插入图片描述

n_train, n_test, num_inputs, batch_size = 20,100,200,5
# 数据越简单,模型越复杂,越容易过拟合。
# num_inputs:特征维度

true_w, true_b = torch.ones((num_inputs, 1))*0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train) # 生成人工数据集
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

# 初始化模型参数
def init_params():
    w=torch.normal(0,1,size=(num_inputs,1), requires_grad=True)
        # 均值为0,方差为1,长度时num_inputs*1的向量,需要梯度
    b=torch.zeros(1,requires_grad=True)
        # b:为全0的标量
    return [w,b]

# 定义L2范数惩罚项(核心)
def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2
    # 注意不要把lambda写进去,因为要写在外面

def train(lambd):
    w, b = init_params() # 初始化模型参数
    net, loss = lambda X:d2l.linreg(X,w,b), d2l.squared_loss
    # net做了个很简单的线性回归
    # 损失函数用平方损失
    num_epochs, lr = 100, 0.003 # 因为数据量很小,所以可以多训练几次
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log', xlim=[5,num_epochs], legend=['train', 'test']) # 实现动画效果
    
    # 标准训练过程
    for epoch in range(num_epochs):
        for X,y in train_iter:
            # with torch.enable_grad():
            l = loss(net(X), y) + lambd*l2_penalty(w)  # L2范数惩罚项
            l.sum().backward()
            d2l.sgd([w,b], lr, batch_size) # 使用小批量随机梯度下降迭代模型参数
        if (epoch+1)%5==0:
            animator.add(epoch+1, 
                         (d2l.evaluate_loss(net, train_iter, loss), 
                          d2l.evaluate_loss(net, test_iter, loss))
            )
    print('w的L2范数是:', torch.norm(w).item())

在这里插入图片描述在这里插入图片描述

3.简单实现(使用框架)

def train_concise(wd):
    net=nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss()
    num_epoch, lr = 100,0.003
    trainer = torch.optim.SGD(
        [{"params":net[0].weight,"weight_decay":wd},{'params':net[0].bias}], 
        lr=lr)
    # 惩罚项既可以写在目标函数里,也可以写在训练算法里,每一次更新之前把当前的w乘以衰退因子weight_decay

    animator=d2l.Animator(xlabel='epochs',ylabel='loss',yscale='log',xlim=[5, num_epoch],legend=['train','test'])
    for epoch in range(num_epoch):
        for X,y in train_iter:
            with torch.enable_grad():
                trainer.zero_grad()
                l = loss(net(X), y)
            l.backward()
            trainer.step()
            if (epoch+1) % 5 == 0:
                animator.add(epoch+1, (d2l.evaluate_loss(net, train_iter, loss), d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', net[0].weight.norm().item())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2314050.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaEE_多线程(二)

目录 1. 线程的状态2. 线程安全2.1 线程不安全问题的原因 3. 线程安全中的部分概念3.1 原子性3.2 可见性3.3 指令重排序 4. 解决线程安全问题4.1 synchronized关键字4.1.1 可重入4.1.2 synchronized使用 4.2 volatile关键字4.2.1 volatile使用 5. wait和notify5.1 wait()方法5.…

【unity小技巧】分享vscode如何进行unity开发,且如何开启unity断点调试模式,并进行unity断点调试(2025年最新的方法,实测有效)

文章目录 前言一、前置条件1、已安装Visual Studio Code,并且unity首选项>外部工具>外部脚本编辑器选择为Visual Studio Code [版本号],2、在Visual Studio Code扩展中搜索Unity,并安装3、同时注意这个插件下面的描述,需要根…

【Hadoop】详解HDFS

Hadoop 分布式文件系统(HDFS)被设计成适合运行在通用硬件上的分布式文件系统,它是一个高度容错性的系统,适合部署在廉价的机器上,能够提供高吞吐量的数据访问,非常适合大规模数据集上的应用。为了做到可靠性,HDFS创建了…

Spring(4)——响应相关

一、返回静态页面 1.1**RestController和Controller** 想返回如下页面: 如果我们依旧使用原来的**RestController** 可以看到的是仅仅返回了字符串。 此时将**RestController改为Controller** 可以看到这次返回的是html页面。 那么**RestController和Controller…

axure11安装教程包含下载、安装、汉化、授权(附安装包)图文详细教程

文章目录 前言一、axure11安装包下载二、axure11安装教程1.启动安装程序2.安装向导界面3.安装协议协议页面2.选择安装位置3.开始安装4.完成安装 三、axure11汉化教程1.axure11汉化包2.axure11汉化设置 四、axure11授权教程1.打开axure112.设置使用方式3.输入许可证号4.axure11安…

Redis-缓存穿透击穿雪崩

1. 穿透问题 缓存穿透问题就是查询不存在的数据。在缓存穿透中,先查缓存,缓存没有数据,就会请求到数据库上,导致数据库压力剧增。 解决方法: 给不存在的key加上空值,防止每次都会请求到数据库。布隆过滤器…

Windows server网络安全

摘要 安全策略 IP安全策略,简单的来说就是可以通过做相应的策略来达到放行、阻止相关的端口;放行、阻止相关的IP,如何做安全策略,小编为大家详细的写了相关的步骤: 解说步骤: 阻止所有: 打…

Python从入门到精通1:FastAPI

引言 在现代 Web 开发中,API 是前后端分离架构的核心。FastAPI 凭借其高性能、简洁的语法和自动文档生成功能,成为 Python 开发者的首选框架。本文将从零开始,详细讲解 FastAPI 的核心概念、安装配置、路由设计、请求处理以及实际应用案例&a…

Leetcode做题记录----2

1、两数之和 思路: 1、不能使用相同元素,可以想到哈希表,,C#中可以通过字典建立当前值和下标的关系 2、显然,依次判断数组中的每个数即可 3、定义other target - num[ i ] 这个other就是我们用于在字典中进行寻找…

批量合并 Word 文档,支持合并成一个 Word,也支持按文件夹合并

我们经常会碰到需要将多个 Word 文档批量合并成一个 Word 文档的场景,比如需要合并后打印、合并后方便整理存档等等。如果是人工的操作,会非常的麻烦。因此我们通常会借助一些批量处理脚本或者寻找批量处理的工具来帮我们实现批量合并 Word 文档的操作。…

项目实操分享:一个基于 Flask 的音乐生成系统,能够根据用户指定的参数自动生成 MIDI 音乐并转换为音频文件

在线体验音乐创作:AI Music Creator - AI Music Creator 体验者账号密码admin/admin123 系统架构 1.1 核心组件 MusicGenerator 类 负责音乐生成的核心逻辑 包含 MIDI 生成和音频转换功能 管理音乐参数和音轨生成 FluidSynth 集成 用于 MIDI 到音频的转换 …

神经网络为什么要用 ReLU 增加非线性?

在神经网络中使用 ReLU(Rectified Linear Unit) 作为激活函数的主要目的是引入非线性,这是神经网络能够学习复杂模式和解决非线性问题的关键。 1. 为什么需要非线性? 1.1 线性模型的局限性 如果神经网络只使用线性激活函数&…

动态规划详解(二):从暴力递归到动态规划的完整优化之路

目录 一、什么是动态规划?—— 从人类直觉到算法思维 二、暴力递归:最直观的问题分解方式 1. 示例:斐波那契数列 2. 递归树分析(以n5为例) 3. 问题暴露 三、第一次优化:记忆化搜索(Memoiza…

ubuntu下在pycharm中配置已有的虚拟环境

作者使用的ubuntu系统位于PC机上的虚拟机。系统版本为: 在配置pycharm解释器之前你需要先创建虚拟环境以及安装pycharm。 作者创建的虚拟环境位于/home/topeet/miniconda3/envs/airproject/,如下图所示: 作者安装的pycharm版本为2023社区…

爬虫中一些有用的用法

文本和标签在一个级别下 如果文本和a标签在一个级别下 比如: # 获取a标签后的第一个文本节点text_node a.xpath(following-sibling::text()[1])[0].strip() 将xpath的html代码转换成字符串 etree.tostring(root, pretty_printTrue, encoding"utf-8")…

DeepIn Wps 字体缺失问题

系统缺失字体 Symbol 、Wingdings 、Wingdings2、Wingdings3、MT—extra 字体问题 问了下DeepSeek 在应用商店安装或者在windows 里面找 装了一个GB-18030 还是不行 在windows里面复制了缺失的字体 将字体复制到DeepIn 的字体目录(Ubuntu 应该也是这个目录&am…

【webrtc debug tools】 rtc_event_log_to_text

一、rtc_event_log 简介 在学习分析webrtc的过程中,发现其内部提供了一个实时数据捕获接口RtcEventLog。通过该接口可以实时捕获进出webrtc的RTP报文头数据、音视频配置参数、webrtc的探测数据等。其内容实现可参考RtcEventLogImpl类的定义。其文件所在路径 loggin…

数字IC后端项目典型问题(2025.03.10数字后端项目问题记录)

小编发现今天广大学员发过来的问题都比较好,立即一顿输出分享给大家(每天都有好多种类的数字后端问题)。后续可能会经常通过这种方式来做分享。其实很多问题都是实际后端项目中经常遇到的典型问题。希望通过这种方式的分享能够帮助到更多需要…

Redis 持久化详解:RDB 与 AOF 的机制、配置与最佳实践

目录 引言 1. Redis 持久化概述 1.1 为什么需要持久化? 1.2 Redis 持久化的两种方式 2. RDB 持久化 2.1 RDB 的工作原理 RDB 的触发条件 2.2 RDB 的配置 2.3 RDB 的优缺点 优点 缺点 3. AOF 持久化 3.1 AOF 的工作原理 AOF 的触发条件 3.2 AOF 的配置…

说一下spring的事务隔离级别?

大家好,我是锋哥。今天分享关于【说一下spring的事务隔离级别?】面试题。希望对大家有帮助; 说一下spring的事务隔离级别? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Spring的事务隔离级别是指在数据库事务管理中…