Pytorch 权重衰减

news2024/7/6 20:08:03

目录

1、权重衰减

2、L2正则化和L1正则化

3、高维线性回归演示权重衰减


1、权重衰减

一般来说,我们总是可以通过去收集更多的训练数据来缓解过拟合。 但这可能成本很高,耗时颇多,或者完全超出我们的控制,因而在短期内不可能做到。所以我们需要将重点放在正则化技术上,权重衰减(weight decay)是最广泛使用的正则化的技术之一, 它通常也被称为L2正则化

2、L2正则化和L1正则化

L2正则化线性模型构成经典的岭回归(ridge regression)算法, L1正则化线性回归是统计学中类似的基本模型, 通常被称为套索回归(lasso regression)。 使用L2范数的一个原因是它对权重向量的大分量施加了巨大的惩罚。 这使得我们的学习算法偏向于在大量特征上均匀分布权重的模型。 在实践中,这可能使它们对单个变量中的观测误差更为稳定。 相比之下,L1惩罚会导致模型将权重集中在一小部分特征上, 而将其他权重清除为零。 这称为特征选择(feature selection),这可能是其他场景下需要的。

L1正则化

L2正则化

3、高维线性回归演示权重衰减

导入所需库

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

定义拟合公式

为了使过拟合效果更加明显,将问题的维数增加到d=200,并只使用包含20个小样本的训练集训练

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

初始化模型参数

def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

定义L2范数

def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

定义训练代码

def train(lambd):
    w, b = init_params() # 初始化权重
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss #定义Lambda函数
    num_epochs, lr = 100, 0.003 #定义训练论述和学习率
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test']) #定义动画
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())

设置lambd=0进行训练,显然发生过拟合(测试损失不变,训练损失下降)

train(lambd=0)

 使用权重衰减,即设置lambd=3,训练误差增大了,但测试误差减小,正则化起到效果。

train(lambd=3)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/149846.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenTelemetry日志体系

前言 OpenTelemetry为了实现其可观测性有三大体系:Trace,Metrics和Log。本文将对于OpenTelemetry实现的日志体系进行详细的讲述。 日志 说到日志相比大家都能侃侃而谈:帮助快速定位出现的问题,数据追踪,性能分析等等…

联合证券|新年再现“A吃A” 建发股份拟控股美凯龙

2022年,A股上市公司之间的并购超越10起。2023年伊始,上市公司“A吃A”或将再添新事例。 建发股份(600153)1月8日晚公告,公司正在谋划经过现金方法协议收购红星美凯龙控股集团有限公司(以下简称“红星控股”)持有的美凯…

导致 MySQL 索引失效

1、索引失效情况1:非最左匹配 最左匹配原则指的是,以最左边的为起点字段查询可以使用联合索引,否则将不能使用联合索引。 我们本文的联合索引的字段顺序是 sn name age,我们假设它们的顺序是 A B C,以下联合索引的…

YOLOv8来啦 | 详细解读YOLOv8的改进模块!YOLOv5官方出品YOLOv8!

YOLOv8是Ultralytics开发的 YOLO(You Only Look Once)物体检测和图像分割模型的最新版本,详细介绍可以参考Ultralytics发布的网址,可以通过ultralytics python 包获取代码,暂时还没有官方公布代码 安装ultralytics py…

再见,Python 循环,向量化已超神

使用向量化 -- Python中循环的超级快速替代品 我们在几乎所有的编程语言中都学习过循环。所以,默认情况下,只要有重复性的操作,我们就会开始实施循环。但是当我们处理大量的迭代(数百万/数十亿行)时,使用循…

啊哈哈哈,2023年Python学习清单来喽;这清单都上齐了,怎么不收藏啊

不知不觉已经在CSDN写了三百多篇博客,这些博客中,Python相关的内容占了绝大多数,而这些与Python有关的内容中,绝大多数又都是我个人学习的总结,本文希望把我的Python学习过程做一个总结,也希望能够帮助不同…

【Java】阻塞队列

【Java】阻塞队列 什么是阻塞队列? 阻塞队列(BlockingQueue)是一个支持两个附加操作的队列。这2个附加的操作支持阻塞的插入和移除方法。 支持阻塞的插入方法:当队列满时,队列会阻塞插入元素的线程,直到…

4.3 集成运放电路简介

从本质上看,集成运放是一种高性能的直接耦合放大电路。尽管品种繁多,内部结构也各不相同,但是它们的基本组成部分、结构形式和组成原则基本一致。因此,对于典型电路的分析具有普遍意义,一方面可以从中理解集成运放的性…

MapGIS用投影变换功能绘制多条测线

1 问题的提出 在做测线设计的时候,经常要在MapGIS里投点,投线。投点可以用section自带功能实现,但投线还是另有讲究的。可以用MapGIS自带的投影变换功能来实现。 先看下我已知线在奥维地图里是什么样的。 下面就来对这些线,进行投影变换,生成wl线文件,从而可以放入设计…

1. 【prometheus 学习】架构Architecture

prometheus是开源的系统监控及告警系统,很多企业、互联网公司应用prometheus,搭配可视化的grafana,实现对系统的全面度量。 prometheus应用的场景: 1)对于数据准确率要求不高,可以粗略反映监控数据走势的场…

前端实战:Vue实现数据导出导入案例

❤️作者主页:IT技术分享社区 ❤️作者简介:大家好,我是IT技术分享社区的博主,从事C#、Java开发九年,对数据库、C#、Java、前端、运维、电脑技巧等经验丰富。 ❤️荣誉: CSDN博客专家、数据库优质创作者🏆&…

linux内核调度浅析

目录 进程控制块PCB 就绪队列结构体 调度队列成员 下一个进程的选择 进程切换 加入就绪队列 linux进程调度相关的知识再重新梳理一遍。抽取主要数据结构中的主要成员,以最简单的方式实现进程调度。 进程控制块PCB task_struct /* 进程PCB */ struct task_s…

人脸识别速度超高识别度超高项目,可实时进行检测,一看就会!

1.本项目属于pytorch-facenet项目,核心代码是facenet算法,经过1周的代码修改,可以进行入库和识别的连续操作,经过测试,识别效果很好,在GPU环境中可以进行实时摄像头的识别,同时项目将放在百度网…

知行之桥传输带附件的文件示例

在大多数的项目中,交易伙伴往往只要求传输报文消息,业务数据经由报文内容来进行传输。但有些交易伙伴也会要求传输带附件的文件,比如在与大众和延锋汽车YFAI对接的项目当中,交易伙伴要求传输VDA4951 ENGDAT报文,该业务…

vue3 销毁组件方法

问题描述:使用elementplus的dialog,当关闭弹窗后不刷新页面,直接再次打开发现弹窗中还存留上一次的数据。尝试定义关闭事件,或者使用api中提供的属性destroy-on-close 都不行。后来发现这是一个误区。弹窗关闭时并不代表这个组件已经被销毁了…

Linux测试主机之间连通性和端口是否开放的方法

文章目录测试主机之间的连通性测试端口是否开放(curl)测试端口是否开放(wget)测试端口是否开放(ssh)下面每一种测试方式都给出了成功通信的截图,如果与截图不相符可以根据你控制台的报错调试。测试主机之间的连通性 测试两个主机之间是否可以通信,通常使…

Odoo 16 企业版手册 - 库存管理之规则与路线

规则和路线 产品上定义的路线将帮助您理解和跟踪产品的每一次调拨。它是用于库存调拨的操作规则或路线。没有适当的策略,就很难监控和管理公司的库存变动。根据您的公司政策,您可以设置某些操作规则来定义库存中的产品调拨。使用这些规则,Odo…

何为 Vue3 组件标注 TS 类型,看这篇文章就够了!

文章目录前言一、为 props 标注类型使用 < script setup >非 < script setup >二、为 emits 标注类型使用 < script setup >非 < script setup >三、为 ref() 标注类型默认推导类型通过接口指定类型通过泛型指定类型四、为 reactive() 标注类型默认推导…

什么真无线蓝牙耳机值得入手?蓝牙耳机全方位挑选攻略

从我们的日常生活中可以看到&#xff0c;蓝牙耳机的使用频率真的是越来越高了&#xff0c;这主要得益于蓝牙耳机的使用便捷性以及近几年的快速发展。很多人在选择时不禁有些疑问&#xff0c;不知道哪款真无线蓝牙耳机值得入手&#xff1f; 都说买新不买旧&#xff0c;所以&…

黑马2022新版SSM框架教程(SpringMVC_day01)

SpringMVC_day01 文章目录SpringMVC_day011&#xff0c;SpringMVC简介1.1 SpringMVC概述2&#xff0c;SpringMVC入门案例2.1 需求分析2.2 案例制作步骤1:创建Maven项目&#xff0c;并导入对应的jar包步骤2:创建控制器类步骤3:创建配置类步骤4:创建Tomcat的Servlet容器配置类步骤…