李沐深度学习记录4:12.权重衰减/L2正则化

news2024/11/21 2:23:13

权重衰减从零开始实现

#高维线性回归
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

#整个流程是,1.生成标准数据集,包括训练数据和测试数据
#          2.定义线性模型训练
#           模型初始化(函数)、包含惩罚项的损失(函数)
#           定义epochs进行训练,每训练5轮评估一次模型在训练集和测试集的损失,画图显示
#           训练结束后分别查看并比较是否添加范数惩罚项损失对应的训练结果w的L2范数
#生成数据集
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5  #训练数据样本数20,测试样本数100,数据维度200,批量大小5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05  #生成w矩阵(200,1),w值0.01,偏置b为0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train) #生成训练数据集X(20,200),y(20,1),y=Xw+b+噪声,train_data接收返回的X,y
train_iter = d2l.load_array(train_data, batch_size)  #传入数据集和批量大小,构造训练数据迭代器
test_data = d2l.synthetic_data(true_w, true_b, n_test) #生成测试数据集
test_iter = d2l.load_array(test_data, batch_size, is_train=False)  #构造测试数据迭代器

#初始化模型参数
def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

#定义L2范数惩罚项
def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2  #L2范数公式需要开平方根,但这里L2范数惩罚项是L2范数的平方,所以不需要开平方根了

#训练代码
def train(lambd):  #输入λ超参数
    w, b = init_params()  #初始化模型参数
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss  #net线性模型torch.matmul(X, w) + b;loss是均方误差
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):  #进行多次迭代训练
        for X, y in train_iter:  #每个epoch,取训练数据
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)  #loss计算加上了λ×范数惩罚项
            l.sum().backward()  #这里计算损失和,下面参数更新时会对梯度求平均再更新参数
            d2l.sgd([w, b], lr, batch_size)  #进行参数更新操作
        if (epoch + 1) % 5 == 0:  #每5次epoch训练,评估一次模型的训练损失和测试损失
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())  #训练结束后,计算w的L2范数(没有平方)

#λ为0,无正则化项,训练
train(lambd=0)
d2l.plt.show()

在这里插入图片描述

#λ为10,有正则化项,训练
train(lambd=5)
d2l.plt.show()

在这里插入图片描述

权重衰减的简洁实现

#权重衰减的简洁实现
def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))   #定义模型
    for param in net.parameters():   #初始化参数
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')  #计算loss,这里不包含正则项
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    #在参数优化部分,计算梯度时加入了权重衰减
    #所以是计算loss时没计算正则项,只是在计算梯度时加入了权重衰减吗?
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):   #训练100轮
        for X, y in train_iter:  #对于每轮,取数据训练
            trainer.zero_grad()   #梯度清零
            l = loss(net(X), y)  #计算loss
            l.mean().backward() #反向传播
            trainer.step()  #更新梯度
        if (epoch + 1) % 5 == 0:   #每5轮评估一次模型在测试集和训练集的损失
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())
#没有进行权重衰减
train_concise(0)

在这里插入图片描述

#进行权重衰减
train_concise(5)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1061973.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【docker】数据卷和数据卷容器

一、如何管理docker容器中的数据? 二、数据卷 1、数据卷原理 将容器内部的配置文件目录,挂载到宿主机指定目录下 数据卷默认会一直存在,即使容器被删除 宿主机和容器是两个不同的名称空间,如果想进行连接需要用ssh,…

联合概率和条件概率的区别和联系

联合概率P(A∩B) 两个事件一起(或依次)发生的概率。 例如:掷硬币的概率是 ⁄₂ 50%,翻转 2 个公平硬币的概率是 ⁄₂ ⁄₂ ⁄₄ 25%(这也可以理解为 50% 的 50%) 对于 2 个硬币,样本空间将…

开机可用内存分析Tip

一、开机内存简介 开机内存指的是开机一段时间稳定后的可用内存。一般项目都会挑选同平台其他优秀竞品内存数据,这个也是衡量性能的一个重要标准。所以要进行开机内存检测,同时优化非法内存进程占用。 二、测试前期核查任务 开机内存测试前要进行测试机…

十二、同步互斥与通信

1、概述 (1)可以把多任务系统当做一个团队,里面的每一个任务就相当于团队中的一个人。团队成员之间要协调工作进度(同步)、争用会议室(互斥)、沟通(通信)。多任务系统中所涉及的概念,都可以在现实生活中找到例子。 (2)各类RTOS都会涉及这些概念&#x…

C语言编程经典100例——11至20例

目录 第 11 例 第 12 例 第 13 例 第 14 例 第 15 例 第 16 例 第 17 例 第 18 例 第 19 例 第 20 例 第 11 例 程序源码: /* 题目:古典问题(兔子生崽):有一对兔子,从出生后第3个月起每个月都生…

洛谷题目题解详细解答

洛谷是一个很不错的刷题软件,可是找不到合适的题解是个大麻烦,大家有啥可以私信问我,以下是我已经通过的题目。 你如果有哪一题不会(最好是我通过过的,我没过的也没关系),可以私信我&#xff0…

yolo如何添加模块???修改parse_model()

如何修改添加模块!!! 先贴代码,加模块时有些地方需要修改,只讲核心部分!!!! def parse_model(d, ch): # model_dict, input_channels(3)logger.info(\n%3s%18s%3s%10s …

应用层协议——DNS、DHCP、HTTP、FTP

目录 1、DNS 协议 1-1)Hosts 文件 1-2)DNS 系统 1-3)域名的组成、分类和树状结构 1-4)DNS 域名服务器类型 1-5)DNS 查询方式 1-6)DNS 域名解析的一般步骤 1-7)对象类型与资源记录 2、D…

数据结构-优先级队列(堆)

文章目录 目录 文章目录 前言 一 . 堆 二 . 堆的创建(以大根堆为例) 堆的向下调整(重难点) 堆的创建 堆的删除 向上调整 堆的插入 三 . 优先级队列 总结 前言 大家好,今天给大家讲解一下堆这个数据结构和它的实现 - 优先级队列 一 . 堆 堆(Heap&#xff0…

lv7 嵌入式开发-网络编程开发 10 TCP协议是如何实现可靠传输的

目录 1 TCP 最主要的特点 1.1 特点 1.2 面向流的概念 1.3 Socket 有多种不同的意思 2 TCP是如何实现可靠传输的? 3 TCP报文段的首部格式 4 作业 1 TCP 最主要的特点 TCP 是面向连接的运输层协议,在无连接的、不可靠的 IP 网络服务基础之上提供可…

【实用工具】谷歌浏览器插件开发指南

谷歌浏览器插件开发指南涉及以下几个方面: 1. 开发环境准备:首先需要安装Chrome浏览器和开发者工具。进入Chrome应用商店,搜索“Extensions Reloader”和“Manifest Viewer”两个插件进行安装,这两个插件可以方便开发和调试。 2…

MyBatisPlus(十一)判空查询:in

说明 判空查询&#xff0c;对应SQL语句中的 in 语句&#xff0c;查询参数包含在入参列表之内的数据。 in Testvoid inNonEmptyList() {// 非空列表&#xff0c;作为参数List<Integer> ages Stream.of(18, 20, 22).collect(Collectors.toList());in(ages);}Testvoid in…

基于Kylin的数据统计分析平台架构设计与实现

目录 1 前言 2 关键模块 2.1 数据仓库的搭建 2.2 ETL 2.3 Kylin数据分析系统 2.4 数据可视化系统 2.5 报表模块 3 最终成果 4 遇到问题 1 前言 这是在TP-LINK公司云平台部门做的一个项目&#xff0c;总体包括云上数据统计平台的架构设计和组件开发&#xff0c;在此只做…

李沐深度学习记录4:11模型选择、欠拟合和过拟合

权重衰减从零开始实现 #高维线性回归 %matplotlib inline import torch from torch import nn from d2l import torch as d2l#整个流程是&#xff0c;1.生成标准数据集&#xff0c;包括训练数据和测试数据 # 2.定义线性模型训练 # 模型初始化&#xff08;函…

一种4g扫码付费通电控制器方案

之前开发了一款扫码付款通电控制器 功能&#xff1a;用户扫码付款后设备通电&#xff0c;开始倒计时&#xff0c;倒计时结束后设备断电&#xff0c;资金到账商家的商家助手里面&#xff0c;腾讯会收取千分之6手续费。 产品主要应用场景 本产品主要应用于各类无人值守或者自助…

【算法基础】基础算法(二)--(高精度、前缀和与差分)

一、高精度 当一个数很大&#xff0c;大到 int 无法存下时&#xff0c;我们可以考虑用数组来进行存储&#xff0c;即数组中一个位置存放一位数。 但是对于数组而言&#xff0c;一个数顺序存入数组后&#xff0c;对其相加减是很简单的。但是当需要进位时&#xff0c;还是很麻烦的…

(c语言)调试——习题

第一题 题目&#xff1a; 解析&#xff1a; 答案&#xff1a;C 栈溢出属于运行时错误&#xff0c;在错误分类时不能分作一类 第二题 题目 &#xff1a; 解析&#xff1a; 答案&#xff1a;A F5是调试不执行&#xff0c;ctrlF5是开始执行不调试 第三题 题目&#xff1a; …

做好微信CRM,这些功能你不可不知!

在当前的数字化时代&#xff0c;微信已成为我们日常生活中的重要元素&#xff0c;无论是社交交流、信息传递还是商务合作&#xff0c;微信都扮演着不可或缺的角色。为了更有效地管理微信资源并提高工作效率&#xff0c;很多组织和公司都选择引入微信CRM系统。那么&#xff0c;怎…

服装服饰小程序商城的作用是什么

服装绝对算是市场重要的组成部分&#xff0c;零售批发都有大量从业者&#xff0c;随着线下流量匮乏、经营困难重重&#xff0c;很多厂家商家选择线上经营&#xff0c;主要方式是直播、入驻第三方平台等&#xff0c;同时私域节奏加快及线上平台限制等&#xff0c;不少商家也是通…

Java架构师设计思想

目录 1 设计核心思想封装2 设计核心思想隔离2.1 隔离的好处3 设计思想由大到小,由粗到精,逐步细化3.1 由大到小3.2 由粗到精3.3 逐步细化4 设计思想迭代4.1 和设计思想由大到小,由粗到精,逐步细化区别5 总结1 设计核心思想封装 首先我们来看一下什么是封装。那封装呢也叫做…