18. 权重衰退的代码的从零实现和简洁实现

news2024/9/22 7:21:36

通过一个简单的例子来演示权重衰减。

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

在这里插入图片描述

0.01就是权重,xi是随机的输入,噪音是均值为0,方差为0.01的一个正态分布

n_train,n_test,num_inputs,batch_size = 20,100,200,5  # num_inputs特征的维度
true_w,true_b = torch.ones((num_inputs,1))* 0.01,0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train) # 生成人工训练数据集
train_iter = d2l.load_array(train_data, batch_size) # 读取内存中的数组变成迭代器,返回的是随机挑选的数据:特征和标签
test_data = d2l.synthetic_data(true_w, true_b, n_test) 
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

下面我们将从头开始实现权重衰减,只需将的L2平方惩罚添加到原始目标函数中。

1. 初始化模型参数

def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

2. 定义L2范数惩罚

实现这一惩罚最方便的方法是对所有项求平方后并将它们求和。

def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

L1范数定义如下:

# 定义L1范数惩罚
def l1_penalty(w):
    return torch.sum(torch.abs(w))

最后的训练结果如下,可以看出曲线不够平滑。

在这里插入图片描述

3. 定义训练代码实现

下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。 线性网络和平方损失没有变化, 所以我们通过d2l.linreg和d2l.squared_loss导入它们。 唯一的变化是损失现在包括了惩罚项

def train(lambd):
    w, b = init_params()
    # lambda定义了一个net(X)函数
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs): # 扫num_epochs次数据
        for X, y in train_iter: #从迭代器中读取真实的输入和输出(标签)
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
                                     
    # norm(self, input, p=2),input是输入,p代表是求p范数,默认p=2;
    # item()返回的是一个浮点型数据
    print('w的L2范数是:', torch.norm(w).item())

4. 忽略正则化直接训练

我们现在用lambd = 0禁用权重衰减后运行这个代码。 注意,这里训练误差有了减少,但测试误差没有减少, 这意味着出现了严重的过拟合。

train(lambd=0)

运行结果:

在这里插入图片描述

w的L2范数是14.22…,训练损失一直在降,测试的损失不变,并且差值越来越大,这是非常明显的过拟合。因为损失一直在降说明一直在拟合函数,拟合噪音,但是在新的数据集上,在测试集上面看不到有任何进展,是非常严重的过拟合。

5. 使用权重衰减

我们使用权重衰减来运行代码。 注意,在这里训练误差增大,但测试误差减小。 这正是我们期望从正则化中得到的效果。

train(lambd=3)

在这里插入图片描述

6. 简洁实现

在下面的代码中,我们在实例化优化器时直接通过weight_decay指定weight decay超参数。 默认情况下,PyTorch同时衰减权重和偏移。 这里我们只为权重设置了weight_decay,所以偏置参数b不会衰减。

更多关于optim.SGD第一个参数可以传入 定义了参数组的dict 的解释链接:原文链接

def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())

这些图看起来和我们从零开始实现权重衰减时的图相同。 然而,它们运行得更快,更容易实现。 对于更复杂的问题,这一好处将变得更加明显。

train_concise(0)

在这里插入图片描述

train_concise(3)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/79815.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Alibaba 官方微服务扛把子「SpringCloudAlibaba 全彩学习手册.PDF」,开源学习ing,

最近我在知乎上看过的一个热门回答: 初级 Java 开发面临的最大瓶颈在于,脱离不出自身业务带来的局限。日常工作中大部分时间在增删改查、写写接口、改改 bug,久而久之就会发现,自己的技术水平跟刚工作时相比没什么进步。 所以我们…

Spring Batch 批处理入门案例

引言 书接上篇 何为Spring Batch?怎么玩? ,前面普及了一下Spring Batch 相关介绍,本篇来一个helloword,简单试一下Spring Batch 怎么玩 批量处理流程 开始前,先了解一下Spring Batch程序运行大纲&#x…

[附源码]计算机毕业设计惠农微信小程序论文Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

这些不知道,别说你熟悉 Spring

写在前面 我们大多数 Java 程序员的日常工作基本都是在做业务开发,俗称 crudboy。 作为 crudboy 的你有没有这些烦恼呢? 随着业务的迭代,新功能的加入,代码变得越来越臃肿,可维护性越来越低,慢慢变成了屎…

【亲测有用】ERR_PROXY_CONNECTION_FAILED的解决方案(电脑明明有网络,但就是无法访问浏览器的网页!)

一、问题描述 就我而言,每次遇到这种问题,都是因为电脑意外关机导致的。昨天,我忘记给电脑充电,结果一觉醒来,发现电脑明明有网络,因为微信、QQ甚至向日葵远程连接别的电脑都没有问题,但就是所…

基于java+springboot+mybatis+vue+mysql的漫画动漫管理网站

项目介绍 本系统主要包括管理员和用户两个角色组成,主要包括以下功能: (1)前台:首页、漫画资源、排行榜、交流论坛、公告信息、个人中心、后台管理 。 (2)管理员:首页、个人中心、…

算法基础篇-06-排序-NB三人组(快速/堆/归并排序)

1. NB 三人组介绍 1.1 快速排序(Quick Sort) 时间复杂度:O(nlogn) 归位: 让元素去它该去的位置,保证左边的元素都比他小,右边都比他大; 1.1.1 原理图示: 假设初始列表: 我们从左边第一个…

三秒钟,我要拿到世界杯所有队伍阵容信息

文章目录🕐Im coming~🕑我写了个啥?🕔咋写的?🕘代码供上🕛 See you next time专栏Python零基础入门篇🔥Python网络蜘蛛🔥Python数据分析Django基础入门宝典🔥…

硬核,阿里自爆虐心万字面试手册,Github上获赞89.7K

开篇小叙 现在Java面试可以说是老生常谈的一个问题了,确实也是这么回事。面试题、面试宝典、面试手册......各种Java面试题一搜一大把,根本看不完,也看不过来,而且每份面试资料也都觉得Nice,然后就开启了收藏之路。 …

CSS - 02. CSS进阶

文章目录CSS进阶1 Emmet语法1.1 快速生成HTML结构语法1.2 快速生成CSS样式语法1.3 快速格式化代码2 CSS的复合选择器2.1 什么是复合选择器2.2 后代选择器2.3 子选择器2.4 并集选择器2.5 伪类选择器2.6 链接伪类选择器2.7 :focus 伪类选择器2.8 复合选择器总结3 CSS 的元素显示模…

Pr:导出设置之多路复用器与常规

多路复用器 MULTIPLEXERH.264、HEVC(H.265)和 MPEG 等格式中包含多路复用器 MULTIPLEXER模块,可用于控制如何将视频和音频数据合并到单个流中(又称“混合”)。基本设置Basic Settings多路复用器Multiplexer视频和音频流…

SolidWorks综合教程

SolidWorks综合教程 SolidWorks 认证工程师 (CSWA​​) 考试的完美指南,包含实例、测验和实践培训 课程英文名:SOLIDWORKS Academy A Comprehensive Course on SolidWorks 此视频教程共11.0小时,中英双语字幕,画质清晰无水印&a…

Android编译优化~Gradle构建基准测试

背景 之前对安卓打包编译优化有所实践,但当时对优化提升结果采取了手动测试的办法才拿到结果,而且遇到大型工程更是痛不欲生。不过当时采取的策略是将增量测试代码提到了Git,编译一次抄一次代码,样本数据只重复了10次&#xff0c…

【实时数仓】业务数据采集之Maxwell的介绍、底层原理、安装及初始化、监控功能、采集服务和MySQL的binlog

文章目录一 业务数据采集0 业务数据采集思路1 Maxwell 介绍2 Maxwell工作原理(1) MySQL主从复制过程(2)Maxwell的工作原理3 MySQL的binlog(1)什么是binlog(2)binlog的开启&#xff0…

算法基础篇-07-排序-希尔排序(Shell Sort)

1. 希尔排序简介 希尔排序(Shell Sort) 是一种分组插入排序算法,是基于插入排序算法演进而来;首先取一个整数d1n/2, 将元素分为d1个组,每组相邻元素之间距离d1,在各组内进行直接插入排序;取第二个整数d2d1/2,重复上述分组排序过程…

微信小程序框架(二)-全面详解(学习总结---从入门到深化)

目录 组件_基础视图 容器 view 文本 text 图片 image 组件_滑块视图容器 基本实现 Swiper常用属性说明 组件_滚动视图区域 水平滚动 垂直滚动 常用属性 组件_icon 图标使用 字体图标属性 组件_progress 基本进度条 属性说明 组件_表单 登录页面 组件_button 按钮属…

这个队列的思路是真的好,现在它是我简历上的亮点了。

前几天在一个开源项目的 github 里面看到这样的一个 pr: 光是看这个名字,里面有个 MemorySafe,我就有点陷进去了。 我先给你看看这个东西: 这个肯定很眼熟吧?我是从阿里巴巴开发规范中截的图。 为什么不建议使用 Fix…

k8s-使用kube-install一键部署(亲测超详细)

文章目录测试环境操作系统环境配置升级操作系统内核基础配置k8s前置配置docker 安装kube-install1.获取kube-install软件包kube-install提供两种安装方式通过Web平台安装kubernetes集群通过命令行快速安装kubernetes集群扩容与销毁Node|修复Master|卸载集群Kubernetes Dashboar…

“特耐苏“牌传递窗——洁净区与洁净之间的辅助设备

传递窗安于高洁净要求的洁净区与洁净之间; 传递窗是一种洁净室的辅助设备,主要用于洁净区与洁净区之间、洁净区与非洁净区之间小件物品的传递,以减少洁净室的开门次数,把对洁净室的污染降低到zui低程度。传递窗采用不锈钢板制作,…

【期末过过过90+】【数据库系统概述】亲笔备考笔记+知识点整理+备考建议

文章目录:故事的开头总是极尽温柔,故事会一直温柔……💜一、前言二、考试题型三、必考重点四、知识点梳理及笔记第1章:绪论第2章:关系数据库第3章:关系数据库标准语言SQL第4章:数据库安全性第5章…