基于PyTorch实战权重衰减——L2范数正则化方法(附代码)

news2024/11/17 12:43:55

文章目录

      • 0. 前言
      • 1. 权重衰减方法作用
      • 2. 权重衰减方法原理介绍
      • 3. 验证权重衰减法实例说明
        • 3.1 训练数据样本
        • 3.2 网络模型
        • 3.3 损失函数
        • 3.4 训练参数
      • 4. 结果对比
      • 5. 源码

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文旨在通过实例验证权重衰减法(L2范数正则化方法)对深度学习神经元网络模型训练过程中出现的过拟合现象的抑制作用,加深对这个方法的理解。

1. 权重衰减方法作用

在训练神经元网络模型时,如果训练样本不足或者网络模型过于复杂,往往会导致训练误差可以快速收敛,但是在测试数据集上的泛化误差很大,即出现过拟合现象。

出现这种情况当然可以通过增多训练样本数来解决,但是如果增加额外的训练数据很困难,对应这类过拟合问题常用方法就是权重衰减法

2. 权重衰减方法原理介绍

权重衰减等价于L2范数正则化,其方法是在损失函数中增加权重的L2范数作为惩罚项。以MSE均方误差为例,原本损失函数应该是:

l o s s = 1 n Σ ( y − y ^ ) 2 loss = \dfrac{1}{n} \Sigma (y - \widehat{y})^2 loss=n1Σ(yy )2

增加L2范数后变成:

l o s s = 1 n Σ ( y − y ^ ) 2 + λ 2 n ∣ ∣ w ∣ ∣ 2 loss = \dfrac{1}{n} \Sigma (y - \widehat{y})^2+ \dfrac{\lambda}{2n}||w||^2 loss=n1Σ(yy )2+2nλ∣∣w2
其中 ∣ ∣ w ∣ ∣ 2 ||w||^2 ∣∣w2代表权重的二范数, λ \lambda λ为权重二范数的系数, λ \lambda λ≥0。

可以见得,如果 λ \lambda λ越大,权重的“惩罚力度”就越大,权重 w w w的绝对值就越接近0,如果 λ \lambda λ=0,相当于没有“惩罚力度”。

3. 验证权重衰减法实例说明

3.1 训练数据样本

本次演示实例使用的输入训练数据为x_train = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],输出训练数据为y_train = [0.52, 8.54, 6.94, 20.76, 32.17, 30.65, 40.46, 80.12, 75.12, 98.83]。

这个数据集是由 y = x 2 y = x^2 y=x2函数增加一个噪声数据生成得出,可以理解为 y = x 2 y = x^2 y=x2为该实例的真实解析解(真实规律)。

3.2 网络模型

使用torch.nn.Sequential()构建6层全连接层网络,每层神经元个数为:
InputLayer = 1,HiddenLayer1 = 3,HiddenLayer2 = 5,HiddenLayer3 = 10,HiddenLayer4 = 5,OutputLayer = 1

3.3 损失函数

选择MSE均方差损失函数,使用 torch.norm()计算权重的L2范数。

3.4 训练参数

无论是否增加L2范数惩罚项,训练参数都是一样的(控制变量):优化函数选用torch.optim.Adam(),学习速率lr=0.005,训练次数epoch=3000。

4. 结果对比

增加L2范数学习结果为:
请添加图片描述
其中红点为训练数据;黄色线为解析解,即 y = x 2 y=x^2 y=x2;蓝色线为训练后的模型在 x = [ 0 , 10 ] x=[0, 10] x=[0,10]上的预测结果。

不加L2惩罚项的学习结果为:
在这里插入图片描述
可以见得增加L2范数惩罚项后,测试的输出数据可以明显更贴合 y = x 2 y=x^2 y=x2理论曲线,尤其是在0~4范围上。

这里也可以增加一个类似损失函数的方式通过数据说明增加L2范数后学习结果更好,定义为:

l o s s = Σ ( y − y ^ ) 2 y ^ 2 loss = \Sigma\dfrac{(y - \widehat{y})^2}{\widehat{y}^2} loss=Σy 2(yy )2
其中 y ^ \widehat{y} y 为网络模型的输出结果, y = x 2 y=x^2 y=x2

不加L2范数惩罚项 l o s s w i t h o u t L 2 = 183.65 loss_{without L2}=183.65 losswithoutL2=183.65
增加L2范数惩罚项后 l o s s w i t h L 2 = 115.70 loss_{with L2}=115.70 losswithL2=115.70

5. 源码

import torch
import matplotlib.pyplot as plt

torch.manual_seed(25)

x_train = torch.tensor([1,2,3,4,5,6,7,8,9,10],dtype=torch.float32).unsqueeze(-1)
y_train = torch.tensor([0.52,8.54,6.94,20.76,32.17,30.65,40.46,80.12,75.12,98.83],dtype=torch.float32).unsqueeze(-1)
plt.scatter(x_train.detach().numpy(),y_train.detach().numpy(),marker='o',s=50,c='r')

class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=1, out_features=3),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=3,out_features=5),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=10),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=10,out_features=5),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=1),
            torch.nn.ReLU(),
        )

    def forward(self,x):
        return self.layers(x)

linear = Linear()

opt = torch.optim.Adam(linear.parameters(),lr= 0.005)
loss = torch.nn.MSELoss()


for epoch in range(3000):
    l = 0
    L2 = 0
    for iter in range(10):

        for w in linear.parameters():
            L2 = torch.norm(w, p=2)*1e8  #计算权重的L2范数,如果要取消L2正则化惩罚只要把这项*0就可以了
        opt.zero_grad()
        output = linear(x_train[iter])
        loss_L2 = loss(output, y_train[iter]) + L2
        loss_L2.backward()
        l = loss_L2.detach() + l
        opt.step()
    print(epoch,L2,loss_L2)

#     plt.scatter(epoch, l, s=5,c='g')
#
# plt.show()


if __name__ == '__main__':
    predict_loss = 0
    for i in range(1000):
        x = torch.tensor([i/100], dtype=torch.float32)
        y_predict = linear(x)
        plt.scatter(x.detach().numpy(),y_predict.detach().numpy(),s=2,c='b')
        plt.scatter(i/100,i*i/10000,s=2,c='y')
        predict_loss = (i*i/10000 - y_predict)**2/(y_predict)**2 + predict_loss   #计算神经元网络模型输出对解析解的loss
# plt.show()

# print(linear.state_dict())
print(predict_loss)

本文的主要参考文献:
[1]Aston Zhang, Mu Li. Dive into deep learning.北京:人民邮电出版社.2021-8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/757000.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java 设计模式——迭代器模式

目录 1.概述2.结构3.案例实现3.1.抽象迭代器3.2.具体迭代器3.3.抽象聚合3.4.具体聚合3.5.测试 4.优缺点5.使用场景6.JDK 源码解析——Iterator 1.概述 迭代器模式 (Iterator Pattern) 是一种行为型设计模式,它提供一种顺序访问聚合对象(如列表、集合等&…

JVM学习之内存与垃圾回收篇1

文章目录 1 JVM与Java体系结构1.0 Java发展重大事件1.1 虚拟机和Java虚拟机1.3 JVM整体结构1.4 Java代码执行流程1.5 JVM架构模型1.6 JVM的生命周期1.7 JVM发展历程 2 类加载子系统2.1 ClassLoader2.2 用户自定义类加载器2.2.1 为什么需要自定义类加载器2.2.2 自定义类加载器的…

【框架篇】对象注入的三种实现方式

对象注入的实现 一,实现方式的使用 对象注入也可被称为对象装配,是把Bean对象获取出来放到某个类中。 对象注入的实现方式有3种,分别为属性注入,Setter注入和构造方法注入。 为了更好地理解对象注入的实现方式,搞个…

24 MFC文档串行化和单文档应用程序

文章目录 文档串行化全部代码 单文档应用程序搭建原理搭建框架Win32 过度到MFC 三部曲设置ID资源全部代码 单文档应用程序设置标题绘图 简单的管理系统部分代码 文档串行化 ui 设计 保存 void CfileDemoDlg::OnBnClickedBtnSave() {UpdateData();//CFile file(L"Demo.dat…

python+pytest接口自动化(9)-cookie绕过登录(保持登录状态)

目录 cookie工作原理 cookie绕过登录 总结 在编写接口自动化测试用例或其他脚本的过程中,经常会遇到需要绕过用户名/密码或验证码登录,去请求接口的情况,一是因为有时验证码会比较复杂,比如有些图形验证码,难以通过…

旅游信息推荐系统带文档springboot+vue

功能 用户注册和登录:用户可以注册一个账户并登录到系统中。旅游项目展示:系统展示各种旅游项目的信息,包括目的地、行程、费用等。旅游项目搜索和筛选:用户可以搜索和筛选旅游项目,根据目的地、日期、费用等条件。预…

Linux系统终端窗口ctrl+c,ctrl+z,ctrl+d的区别

时常在Linux系统上,执行某命令停不下来,就这几个ctrl组合键按来按去,今天稍微总结下具体差别,便于以后linux系统运维操作 1、ctrlc强制中断程序,相应进程会被杀死,中断进程任务无法恢复执行 2、ctrlz暂停正…

mongodb集群搭建

下载地址: https://www.mongodb.com/try/download/community下载mongodb-linux-x86_64-rhel70-5.0.18 搭建集群 tar -zxvf mongodb-linux-x86_64-rhel70-5.0.18.tgz mkdir -p data/dp cd mongodb-linux-x86_64-rhel70-5.0.18 mkdir -p data/db mkdir log mkdir c…

Ubuntu 23.04安装最新版本Halcon 23.05

Ubuntu 23.04安装最新版本Halcon 23.05 官网下载安装环境变量设置创建快捷方式给个最新ubuntu的镜像源地址 官网下载 去Halcon官网:https://www.mvtec.com/products/halcon/,注册或登录,点击Download: 或者进入大恒网站&#xf…

Ubuntu最新版本23.05配置Flameshot(途中解决疑难杂症)

Ubuntu最新版本23.05配置Flameshot截图软件 安装方法:添加Ubuntu的快捷键遇到的问题解决 安装方法: sudo apt install flameshot出现该页面表示成功: 可以直接在终端输入:flameshot gui flameshot gui进行截图。 添加Ubuntu的…

云计算与大数据——MPI集群配置

什么是MPI集群? MPI(消息传递接口)是一种用于编写并行程序的标准,它允许在多个计算节点上进行通信和协作。MPI集群配置是指在一个或多个计算节点上设置MPI环境以实现并行计算。 MPI集群配置的步骤: 硬件选型&#x…

三菱PLC上位机测试

利用三菱的MX Component与三菱PLC进行以太网通信,我们可以用官方的dll编写C#代码,特别简单,最后附上整个源码下载。 1. 安装MX Component(必须)和GX WORKS3(主要是仿真用,实际可以不装&#xf…

空间光通信-调制解调滤波与同步

图文并茂,讲解电磁波传播原理_哔哩哔哩_bilibili 深入浅出空间光通信-3.调制解调滤波与同步_哔哩哔哩_bilibili 傅里叶变换这样学,何愁不会呢?直观理解傅里叶变换_哔哩哔哩_bilibili 第二十三课:声音编辑必看!&…

【六袆 - windows】windows计划任务,命令行执行,开启计划任务,关闭计划任务,查询计划任务

windows计划任务 查看 Windows 自动执行的指令取消 Windows 中的计划任务启动执行计划任务 查看 Windows 自动执行的指令 您可以使用以下方法: 使用任务计划程序:任务计划程序是 Windows 内置的工具,可以用于创建、编辑和管理计划任务。您可…

pytest+allure运行出现乱码的解决方法

pytestallure运行出现乱码的解决方法 报错截图: 这是因为没有安装allure运行环境或者没有配置allure的环境变量导致,解决方案: 1.安装allure运行环境 官方下载地址:https://github.com/allure-framework/allure2/releases 百度…

JavaSE - 内部类

目录 final定义常量 1. 内部类 1.1 实例内部类 1.1.1 如何获取实例内部类的对象 1.1.2 实例内部类中不能有静态的成员变量 1.1.3 实例内部类方法中可以直接访问外部类中的任何成员 1)在实例内部类方法中访问同名的成员时,优先访问自己的&#xff0…

nacos注册中心+Ribbon负载均衡+完成openfeign的调用

目录 1.注册中心 1.1.nacos注册中心 1.2. 微服务注册和拉取注册中心的内容 2.3.修改订单微服务的代码 3.负载均衡组件 3.1.什么是负载均衡 3.2.什么是Ribbon 3.3.Ribbon 的主要作用 3.4.Ribbon提供的负载均衡策略 4.openfeign完成服务调用 4.1.什么是OpenFeign 4.2…

5.2 Python高阶特性之---切片迭代

一、 切片 一般用于提取指定区间内的内容,常用于:str、list、tuple等类型的的局部变量,如下为具体案例1、 【列表切片】 res_list [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95]1) 无步长: …

C++——类的六大默认成员构造函数

文章目录 1.默认成员函数思维导图2.构造函数定义特性 2.析构函数定义特性 3.拷贝构造函数定义特性 4.赋值构造函数定义特性 5.重载取地址运算符定义特性 6.重载const取地址运算符定义特性 1.默认成员函数思维导图 2.构造函数 定义 在面向对象编程中,构造函数是一种…

RHCSA——Linux网络、磁盘及软件包管理

ZY目录 Linux操作系统讲解:一、网络管理1、NetworkManager1.1、nmtui界面:1.2、nmcli使用方法: 2、配置网络2.1、网络接口以及网络连接2.2、配置方法:2.3、ping命令:2.4、wget命令 二、磁盘管理2.1、分区得两种格式2.1…