【机器学习】036_权重衰退

news2025/1/10 1:30:22

一、范数

· 定义:向量的范数表示一个向量有多大(分量的大小)

L1范数:

        · 即向量元素绝对值之和,用符号 ‖ v ‖ 1 表示。

        · 公式:\left \| x \right \|_1 = \sum_{n}^{i=1}|x_i|

L2范数:

        · 即向量的模,向量各元素绝对值的平方之和再开根号,用符号 ‖ v ‖ 2 表示。

        · 公式:\left \| x \right \|_2=\sqrt{\sum_{n}^{i=1}x_i^2}

Lp范数:

        · 即向量范数的一般形式,各元素绝对值的p次幂之和再开p次根号,用符号 ‖ v ‖ p 表示。

        · 公式:\left \| x \right \|_p = (\sqrt[p]{\sum_{n}^{i=1}|x|^p})

二、权重衰减(L2正则化)

模型(函数)复杂度的度量:

· 一般通过线性函数 f(x) = w^Tx 中的权重向量的某个范数(如 \left \| w \right \|^2)来度量其复杂度

要想避免模型的过拟合,就要控制模型容量,使模型的权重向量尽可能小

· 通过限制参数值的选择范围来控制模型容量

衰减方法:

借助损失函数,将权重范数作为惩罚项添加到最小化损失中;使得损失函数的作用变为“最小化预测损失和惩罚项之和”。

损失函数公式如下:

J(w,b)=L(w,b)+\frac{\lambda }{2}\left \| w \right \|^2

· 其中,L(w,b) 是模型原本的损失函数,\frac{\lambda }{2}\left \| w \right \|^2 是新添加的惩罚项。

· 正则化常数 \lambda 用来描绘这种权衡,其为一个非负超参数。

· \lambda 的值越大,表示对 w 的约束较大;反之 \lambda 的值越小,表示对 w 的约束较小。

※为何选用平方范数而不是标准范数:

        · 便于计算。平方范数可以去掉平方根使得导数更容易计算,利于反向传播过程。

        · 使用L2范数是因为它会对权重向量的大分量施加巨大的惩罚,使各权重均匀分布。

        · L1范数惩罚会导致权重集中在某一小部分特征上,其它权重被清除为0(特征选择)。

使用该损失函数,就可以使梯度下降的优化算法在训练的每一步都衰减权重,避免过拟合发生。

如上图所示,现在模型的损失函数同时受两项影响,一是误差项,二是惩罚项。

        现在在等高线图上,梯度下降最终收敛的位置不再是某一个项所造成的最低点,因为在这时,可能误差项达到最小了,但是惩罚项很大,使得惩罚项拉着损失函数再向另一个方向移动。

        只有当达到了两个项共同作用下的一个平衡点时,损失函数才具有最小值,这个时候的模型往往复杂度也降低了,虽然有可能造成训练损失增大,但是测试损失会减小。

三、代码实现权重衰减

从零实现代码如下:

import matplotlib
import torch
from torch import nn
from d2l import torch as d2l

# 训练数据集、测试数据集、输入值、训练批次
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
# 初始化w和b的真实值
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# 拿到训练数据
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

# 初始化模型参数w和b
def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]
# 定义L2范数惩罚项
def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2
# 实现训练代码,读入参数为兰姆达(正则化参数)
def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())
# 使用权重进行训练
train(lambd=3)

简洁实现代码如下:

import torch
from torch import nn
from d2l import torch as d2l

# 训练数据集、测试数据集、输入值、训练批次
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
# 初始化w和b的真实值
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# 拿到训练数据
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())

    train_concise(3)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1232588.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

对一个Series序列内的元素逐个扩展同一聚合操作一个序列中共有m个元素,从指定的第n个元素开始,对前i元素进行聚合计算Series.expanding()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 一个序列中共有m个元素 从指定的第n个元素开始 对前i元素进行聚合计算 其中&#xff1a;n < i < m 聚合计算&#xff1a;求最大、平均值等 Series.expanding(n).max() Series.expanding(…

重磅,瑞士药监局 发布 EU GMP附录1《无菌药品生产》官方解读!

近日&#xff0c;瑞士药监局发布了EU GMP附录1《无菌药品生产》&#xff08;同时也是PIC/S和WHO GMP附录1&#xff09;的解读文件&#xff0c;该文件侧重于新版EU、PIC/S和WHO GMP附录1的一些最重要的变化&#xff0c;也涵盖了长期以来反复引起问题的方面。反映了检查员对这些主…

HALCON中的运算符和控制流算子

HALCON中的运算符 Haclon中的运算符包括算术运算符、逻辑运算符、关系运算符&#xff0c;其功能和用法与C语言相类似。但HALCON中每个运算符都有一个算子相对应&#xff0c;如表2-2所示。 HALCON中的控制流算子 HALCON通过控制流算子来控制程序的走向&#xff0c;包括条件选…

MQTT协议消息代理服务远程连接

目录 1. Linux 搭建 Mosquitto 2. Linux 安装Cpolar 3. 创建MQTT服务公网连接地址 4. 客户端远程连接MQTT服务 5. 代码调用MQTT服务 6. 固定连接TCP公网地址 7. 固定地址连接测试 Mosquitto是一个开源的消息代理&#xff0c;它实现了MQTT协议版本3.1和3.1.1。它可以在不…

在浏览器中使用WebRTC获取用户IP地址

本文翻译自 Discover WebRTC: Obtain User IP Addresses in the Browser&#xff0c;作者&#xff1a;Zack&#xff0c; 略有删改。 如果需要在程序中获取当前用户的IP&#xff0c;通常手段都是需要使用服务器。但现在借助WebRTC的强大功能&#xff0c;我们可以直接在浏览器客户…

你知道如何实现游戏中的透视效果吗?

引言 游戏中的透视效果可以合理运用CtrlCV实现。 不知道大家有没有这样一段经历&#xff1a;在做Cocos项目时需要一些特定的Shader去做一些特定的效果&#xff0c;例如透视、高光、滤镜等等&#xff0c;想自己写吧&#xff0c;不怎么会啊&#xff0c;网上又找不到&#xff0c…

Gensim库——文本处理和主题建模的强大工具

在信息时代&#xff0c;海量的文本数据不断地涌现。如何从这如山如海的文本中提取有意义的信息&#xff0c;成为了一项关键任务。Python语言提供了许多优秀的库和工具来处理文本数据&#xff0c;其中一款备受推崇的工具就是Gensim库。Gensim是一个开源的Python库&#xff0c;它…

ESP32 MicroPython AI摄像头应用⑩

ESP32 MicroPython AI摄像头应用⑩ 1、AI摄像头应用2、移动检测&#xff08;LCD显示&#xff09;3、实验内容3、参考代码4、实验结果 1、AI摄像头应用 我们小车MCU支持AI(人工智能)加速&#xff0c;可以用于加速神经网络计算和信号处理等工作的向量指令 (vector instructions)…

Haclon简介及数据类型

Haclon简介 HALCON是由德国MVtec公司开发的机器视觉算法包&#xff0c;它由一千多个各自独立的函数&#xff08;算子&#xff09;构成&#xff0c;其中除了包含各类滤波、色彩以及几何、数学转换、形态学计算分析、图像校正&#xff0c;目标分类辨识、形状搜寻等基本的图像处理…

【SEO学习】专家优化

创建、编辑和推广独特的高质量内容既困难又耗时。如果你真的认真对待搜索引擎优化&#xff0c;但却没有取得预期效果&#xff0c;那么最好聘请一位搜索引擎优化专家。 搜索引擎优化专家会执行以下任务&#xff1a; 代码验证和清理 - 确保代码对搜索引擎友好并符合标准。网站结…

CentOS7安装Docker遇到的问题笔记

笔记/朱季谦 以下是笔者本人学习搭建docker过程当中记录的一些实践笔记&#xff0c;过程当中也遇到了一些坑&#xff0c;但都解决了&#xff0c;就此记录&#xff0c;留作以后再次搭建时可以直接参考。 一、首先&#xff0c;先检查CentOS版本&#xff0c;保证在CentOS7版本以…

Linux:详解(yum的使用、vim编辑器命令集合以及gcc/g++编译器的使用)

Linux 软件包管理器 yum 什么是软件包&#xff1a; 在Linux下安装软件, 一个通常的办法是下载到程序的源代码, 并进行编译, 得到可执行程序. 但是这样太麻烦了, 于是有些人把一些常用的软件提前编译好, 做成软件包(可以理解成windows上的安装程序)放在一个服务器上, 通…

【opencv】debug报错HEAP CORRUPTION DETECTED

运行至第一句涉及矩阵运算的代码&#xff08;如cv::multiply&#xff09;时报错 HEAP CORRUPTION DETECTED: after Normal block (#45034) at 0x000001BDC586F0E0. CRT detected that the application wrote to memory after end of heap buffer.release下不会报错&#xff0…

探索人工智能领域——每日30个名词详解【day4】

目录 前言 正文 总结 &#x1f308;嗨&#xff01;我是Filotimo__&#x1f308;。很高兴与大家相识&#xff0c;希望我的博客能对你有所帮助。 &#x1f4a1;本文由Filotimo__✍️原创&#xff0c;首发于CSDN&#x1f4da;。 &#x1f4e3;如需转载&#xff0c;请事先与我联系以…

(免费领源码)python#flask#mysql旅游数据可视化81319-计算机毕业设计项目选题推荐

摘要 信息化社会内需要与之针对性的信息获取途径&#xff0c;但是途径的扩展基本上为人们所努力的方向&#xff0c;由于站在的角度存在偏差&#xff0c;人们经常能够获得不同类型信息&#xff0c;这也是技术最为难以攻克的课题。针对旅游数据可视化等问题&#xff0c;对旅游数据…

vscode设置前进、后退快捷键

前言 在我们使用vscode编写程序时&#xff0c;经常需要在不同的文件之间跳来跳去&#xff0c;如果只是依靠个人记忆去操作会显得非常不方便。本文介绍如何设置vscode的前进、后退快捷键。 1 vscode设置前进、后退快捷键 点击“设置”图标&#xff0c;然后点击“键盘快捷方式…

注解案例:山寨Junit与山寨JPA

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 上篇讲了什么是注解&am…

一文讲明 网络调试助手的基本使用 NetAssist

我 | 在这里 &#x1f575;️ 读书 | 长沙 ⭐软件工程 ⭐ 本科 &#x1f3e0; 工作 | 广州 ⭐ Java 全栈开发&#xff08;软件工程师&#xff09; &#x1f383; 爱好 | 研究技术、旅游、阅读、运动、喜欢流行歌曲 &#x1f3f7;️ 标签 | 男 自律狂人 目标明确 责任心强 ✈️公…

初刷leetcode题目(7)——数据结构与算法

&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️Take your time ! &#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️…

NameServer源码解析

1 模块入口代码的功能 本节介绍入口代码的功能&#xff0c;阅读源码的时候&#xff0c;很多人喜欢根据执行逻辑&#xff0c;先从入口代码看起。NameServer部分入口代码主要完成命令行参数解析&#xff0c;初始化Controller的功能。 1.1 入口函数 首先看一下NameServer的源码目…