权重衰减-笔记

news2025/4/4 21:39:55

《动手学深度学习》-4.5-笔记

权重衰减就像给模型“勒紧裤腰带”,不让它太贪心、不让它学太多。

你在学英语单词,别背太多冷门单词,只背常见的就行,这样考试时更容易拿分。”

—— 这其实就是在“限制你学的内容复杂度”。

在训练一个模型的时候,也要防止它“死记硬背”训练集的数据。
权重衰减就是给模型一个规则: “你可以学,但不能太激进,尽量简单点!”

因为模型在训练的时候可能会过拟合,就是它把训练数据背得滚瓜烂熟,但换一套题就不会了。
权重衰减就像提醒它:“你不能只靠死记硬背,要理解和归纳。”

权重衰减是一种正则化方法,用于防止模型过拟合。它的核心思想是:在训练模型时,对模型的权重(参数)进行约束,让权重的值不要变得太大。具体来说,就是在损失函数中加入一个额外的惩罚项,这个惩罚项与权重的大小有关。

将原来的训练目标(最小化训练标签上的预测损失)调整为:最小化预测损失 + 惩罚项

即:原本模型只关心“预测对不对”,现在还要关心“学得是不是太复杂”。

那什么是“权重”呢?

“权重”就是模型中学到的参数,相当于它的“记忆”或“经验”。

  • 没有衰减:模型会很用力去记住训练数据,容易“死记硬背”。

  • 有衰减:模型会被“约束”,学得更稳,考试(预测新数据)更容易拿分。

  • 权重衰减 = 不让模型的“记忆力”太强,逼它学得简单一点,防止它考试(泛化)翻车。
     

    权重(weight)模型的记忆
    衰减(decay)给它一点压力,让它别太极端
    权重衰减限制模型学得太多,防止过拟合

假设原本的损失函数是:Loss=预测误差

加了权重衰减以后变成:

  • w:模型的权重

  • :所有权重平方加起来(L2 范数)

  • λ:惩罚项的强度(可以调节)

# 定义训练集和测试集的大小,以及输入特征数量和批量大小
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
# n_train:训练样本数量为 20
# n_test:测试样本数量为 100
# num_inputs:每个样本有 200 个特征(就是输入维度)
# batch_size:每次训练时喂给模型 5 个样本(小批量训练)

# 生成“真实的”权重和偏置,用于生成数据时使用
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# true_w:一个 shape 为 (200, 1) 的向量,每个元素是 0.01
# true_b:偏置是 0.05
# 这两个值是我们“人为设定”的,用来生成“假的数据”(方便我们验证模型能不能学出来)

# 生成训练数据(特征 + 标签),使用上面设定的权重和偏置
train_data = d2l.synthetic_data(true_w, true_b, n_train)
# synthetic_data:是 d2l 提供的函数,用来生成 y = Xw + b + 噪声 的假数据
# train_data 是一个二元组:(features, labels),大小是 (20, 200) 和 (20, 1)

# 把训练数据放进“迭代器”,每次取 batch_size 个样本
train_iter = d2l.load_array(train_data, batch_size)
# load_array:把数据打包成 PyTorch 可用的 DataLoader,支持按批次取数据

# 同样的方法,生成测试数据
test_data = d2l.synthetic_data(true_w, true_b, n_test)
# 测试数据有 100 个样本

# 把测试数据放进测试用的迭代器中(is_train=False 表示测试模式)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

 下面我们将从头开始实现权重衰减,只需将L2的平方惩罚添加到原始目标函数中。

首先,随机初始化模型参数,从正态分布中随机生成数据,模拟“随机初始化权重”。

def init_params():
    # 初始化权重 w,形状为 (特征数 num_inputs, 1)
    # 从均值为0,标准差为1的正态分布中随机生成
    # requires_grad=True 表示这个参数需要计算梯度(用于训练)
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)

    # 初始化偏置 b,初始为0,且也需要计算梯度
    b = torch.zeros(1, requires_grad=True)

    # 返回这两个参数,供模型使用
    return [w, b]

实现这一惩罚最方便的方法是对所有项求平方后并将它们求和。(核心)

def l2_penalty(w):
    # 计算权重 w 的 L2 范数平方,并除以 2
    # 这是 L2 正则化中常用的形式
    return torch.sum(w.pow(2)) / 2

L2 正则化项的数学形式是:

def train(lambd):
    # 初始化模型参数 w 和 b(包含梯度信息)
    w, b = init_params()

    # 定义模型 net(就是线性回归函数)和损失函数(平方损失)
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss

    # 设置训练轮数和学习率
    num_epochs, lr = 100, 0.003

    # 动画器,用来实时画出训练/测试集的损失变化(log坐标轴)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])

    # 开始训练模型
    for epoch in range(num_epochs):
        for X, y in train_iter:  # 遍历一个个 batch
            # 计算损失:平方损失 + L2惩罚项(权重衰减)
            # l2_penalty(w) 是一个数,广播后加到每个样本的损失上
            l = loss(net(X), y) + lambd * l2_penalty(w)

            # 反向传播,计算梯度
            l.sum().backward()

            # 使用随机梯度下降法更新参数 w 和 b
            d2l.sgd([w, b], lr, batch_size)

        # 每训练5轮画一次图
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (
                d2l.evaluate_loss(net, train_iter, loss),  # 训练集损失
                d2l.evaluate_loss(net, test_iter, loss)   # 测试集损失
            ))

    # 最后输出训练好的权重向量的 L2 范数
    print('w 的 L2 范数是:', torch.norm(w).item())

 忽略正则化直接训练,“不使用正则化(也就是不加权重惩罚项),进行模型训练。”

def train(lambd):
    ...
    l = loss(...) + lambd * l2_penalty(w)
里面这个 lambd * l2_penalty(w) 是控制“正则化强度”的。

打印

正则化情况影响
lambd = 0没有权重惩罚,模型容易过拟合
lambd > 0会限制权重变大,提升泛化能力(防止过拟合)

简单实现:

def train_concise(wd):  # wd 就是 weight decay 权重衰减系数(λ)
    # 定义一个线性回归模型:输入是 num_inputs 维,输出是 1 维
    net = nn.Sequential(nn.Linear(num_inputs, 1))

    # 初始化模型参数为服从正态分布(均值0,标准差1)
    for param in net.parameters():
        param.data.normal_()

    # 定义损失函数为 MSE(均方误差),每个样本独立输出,不求平均
    loss = nn.MSELoss(reduction='none')

    # 设置训练轮数和学习率
    num_epochs, lr = 100, 0.003

    # 定义优化器:使用 SGD(随机梯度下降)
    # 注意:只有权重使用 weight_decay,偏置 bias 不加惩罚项
    trainer = torch.optim.SGD([
        {"params": net[0].weight, "weight_decay": wd},  # 对 weight 使用正则化
        {"params": net[0].bias}  # 对 bias 不使用正则化
    ], lr=lr)

    # 用于画图:训练过程中的 train/test 损失变化(对数坐标)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])

    # 开始训练
    for epoch in range(num_epochs):
        for X, y in train_iter:  # 遍历训练数据的小批量
            trainer.zero_grad()  # 梯度清零
            l = loss(net(X), y)  # 计算当前 batch 的损失
            l.mean().backward()  # 求平均后反向传播计算梯度
            trainer.step()  # 优化器更新参数

        # 每隔 5 个 epoch 可视化一次 train/test 的损失
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),  # 训练集损失
                          d2l.evaluate_loss(net, test_iter, loss)))  # 测试集损失

    # 打印训练后模型的权重大小(L2 范数)
    print('w 的 L2 范数:', net[0].weight.norm().item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2325543.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Hyperliquid 遇袭「拔网线」、Polymarket 遭治理攻击「不作为」,从双平台危机看去中心化治理的进化阵痛

作者:Techub 热点速递 撰文:Glendon,Techub News 继 3 月 12 日「Hyperliquid 50 倍杠杆巨鲸」引发的 Hyperliquid 清算事件之后,3 月 26 日 晚间,Hyperliquid 再次遭遇了一场针对其流动性和治理模式的「闪电狙击」。…

软考笔记6——结构化开发方法

第六章节——结构化开发方法 结构化开发方法 第六章节——结构化开发方法一、系统分析与设计概述1. 系统分析概述2. 系统设计的基本原理3. 系统总体结构设计 二、结构化分析方法1. 结构化分析方法概述2. 数据流图(DFD)3. 数据字典 三、结构化设计方法(了解&#xff…

一种C# Winform的UI处理

效果 圆角 阴影 突出按钮 说明 这是一种另类的处理,不是多层窗口 也不是WPF 。这种方式的特点是比较简单,例如圆角、阴影、按钮等特别容易修改过。其实就是html css DirectXForm。 在VS中如下 圆角和阴影 然后编辑这个窗体的Html模板&#xff0c…

为什么视频文件需要压缩?怎样压缩视频体积即小又清晰?

在日常生活中,无论是为了节省存储空间、便于分享还是提升上传速度,我们常常会遇到需要压缩视频的情况。本文将介绍为什么视频需要压缩,压缩视频的好处与坏处,并教你如何使用简鹿视频格式转换器轻松完成MP4视频文件的压缩。 为什么…

Nginx — Nginx处理Web请求机制解析

一、Nginx请求默认页面资源 1、配置文件详解 修改端口号为8080并重启服务: 二、Nginx进程模型 1、nginx常用命令解析 master进程:主进程(只有一个) worker进程:工作进程(可以有多个,默认只有一…

5.0 WPF的基础介绍1-Grid,Stack,button

WPF: Window Presentation Foundation. WPF与WinForms的对比如下: 特性WinFormsWPF技术基础基于传统的GDI(图形设备接口)基于DirectX,支持硬件加速的矢量渲染UI设计方式拖拽控件事件驱动代码(简单但局限)…

Docker 端口映射原理

在 Docker 中,默认情况下容器无法直接与外部网络通信。 为了使外部网络能够访问容器内的服务,Docker 提供了端口映射功能,通过将宿主机的端口映射到容器内的端口,外部可以通过宿主机的IP和端口访问容器内的服务 以下通过动手演示…

SDL —— 将sdl渲染画面嵌入Qt窗口显示(附:源码)

🔔 SDL/SDL2 相关技术、疑难杂症文章合集(掌握后可自封大侠 ⓿_⓿)(记得收藏,持续更新中…) 效果 使用QWidget加载了SDL的窗口,渲染器使用硬件加速跑GPU的。支持Qt窗口缩放或显示隐藏均不影响SDL的图像刷新。   操作步骤 1、在创建C++空工程时加入SDL,引入头文件时需…

算法每日一练 (23)

💢欢迎来到张翊尘的技术站 💥技术如江河,汇聚众志成。代码似星辰,照亮行征程。开源精神长,传承永不忘。携手共前行,未来更辉煌💥 文章目录 算法每日一练 (23)最大正方形题目描述解题思路解题代码…

UE5学习笔记 FPS游戏制作28 显式玩家子弹数

文章目录 添加变量修改ShootOnce方法,设计时减少子弹,没有子弹不能开枪在UI上显示 添加变量 在Gun类中添加BulletNum和ClipSize两个参数 BulletNum是当前还有多少子弹,ClipSize是一个弹匣多少子弹 Rifle的ClipSzie设置为30,Laun…

《构建有效的AI代理》学习笔记

原文链接:https://www.anthropic.com/engineering/building-effective-agents 《构建有效的AI代理》学习笔记 一、概述 核心结论 • 成功的AI代理系统往往基于简单、可组合的模式,而非复杂框架。 • 需在性能、成本与延迟之间权衡,仅在必要时增加复杂度…

数据处理专题(四)

目标 使用 Matplotlib 进行基本的数据可视化。‍ 学习内容 绘制折线图 绘制散点图 绘制柱状图‍ 代码示例 1. 导入必要的库 import matplotlib.pyplot as pltimport numpy as npimport pandas as pd 2. 创建示例数据集 # 创建示例数据集data { 月份: [1月, 2月, 3…

【目标检测】【深度学习】【Pytorch版本】YOLOV1模型算法详解

【目标检测】【深度学习】【Pytorch版本】YOLOV1模型算法详解 文章目录 【目标检测】【深度学习】【Pytorch版本】YOLOV1模型算法详解前言YOLOV1的模型结构YOLOV1模型的基本执行流程YOLOV1模型的网络参数YOLOV1模型的训练方式 YOLOV1的核心思想前向传播阶段网格单元(grid cell)…

云钥科技多通道工业相机解决方案设计

项目应用场景分析与需求挑战 1. 应用场景 ‌目标领域‌:工业自动化检测(如精密零件尺寸测量、表面缺陷检测)、3D立体视觉(如物体建模、位姿识别)、动态运动追踪(如高速生产线监控)等。 ‌核心…

从零到一:ESP32与豆包大模型的RTC连续对话实现指南

一、对话效果演示 ESP32与豆包大模型的RTC连续对话 二、ESP-ADF 介绍 乐鑫 ESP-ADF(Espressif Audio Development Framework)是乐鑫科技(Espressif Systems)专为 ESP32 系列芯片开发的一款音频开发框架。它旨在简化基于 ESP32 芯…

【深度学习与实战】2.3、线性回归模型与梯度下降法先导案例--最小二乘法(向量形式求解)

为了求解损失函数 对 的导数,并利用最小二乘法向量形式求解 的值‌ 这是‌线性回归‌的平方误差损失函数,目标是最小化预测值 与真实值 之间的差距。 ‌损失函数‌: 考虑多个样本的情况,损失函数为所有样本的平方误差之和&a…

【Django】教程-2-前端-目录结构介绍

【Django】教程-1-安装创建项目目录结构介绍 3. 前端文件配置 3.1 目录介绍 在app下创建static文件夹, 是根据setting中的配置来的 STATIC_URL ‘static/’ templates目录,编写HTML模板(含有模板语法,继承,{% static ‘xx’ …

详解list容器

1.list的介绍 list的底层结构是双向带头循环链表,允许随机的插入和删除,但其内存空间不是连续的。随机访问空间能力差,需要从头到尾遍历节点,不像vector一样高效支持 2.list的使用 构造函数 1.默认构造函数:创建一个…

leetcode_977. 有序数组的平方_java

977. 有序数组的平方https://leetcode.cn/problems/squares-of-a-sorted-array/ 1.题目 给你一个按 非递减顺序 排序的整数数组 nums,返回 每个数字的平方 组成的新数组,要求也按 非递减顺序 排序。 示例 1: 输入:nums [-4,-1…

网络探索之旅:网络原理(第二弹)

上篇文章,小编分享了应用层和传输层深入的一点的知识,那么接下来,这篇文章,继续分享网络层和数据链路层。 网络层 了解这个网络层,那么其实就是重点来了解下IP这个协议 对于这个协议呢,其实也是和前面的…