权重衰减学习

news2024/11/24 12:24:07

1.权重衰减是最广泛使用的正则化技术之一

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

2.生成数据

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

3.初始化模型参数

def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

4.定义L2范数惩罚

def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

5.定义训练代码实现

def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())

6.忽略正则化直接训练

train(lambd=0)

7.使用权重衰减来运行代码。这里训练误差增大,但测试误差减小。

train(lambd=3)

8.简洁实现

def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())

9.训练过程中使用train_concise方法查看损失函数(loss)变化情况

train_concise(0)


​​​​​​​train_concise(3)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2223871.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

论文笔记:LaDe: The First Comprehensive Last-mile Delivery Dataset from Industry

2023 KDD 1 intro 1.1 背景 随着城市化进程的加快和电子商务的发展,最后一公里配送已成为一个关键的研究领域 最后一公里配送,如图1所示,是指连接配送中心和客户的包裹运输过程,包括包裹的取件和配送除了对客户满意度至关重要外…

《等保测评新视角:安全与发展的双赢之道》

在数字化转型的浪潮中,企业面临的不仅是技术革新的挑战,更有信息安全的严峻考验。等保测评,作为国家网络安全等级保护的一项重要措施,不仅为企业的安全护航,更成为推动企业高质量发展的新引擎。本文将从全新的视角&…

如何用 Spring AI + Ollama 构建生成式 AI 应用

为了构建生成式AI应用,需要完成两个部分: • AI大模型服务:有两种方式实现,可以使用大厂的API,也可以自己部署,本文将采用ollama来构建• 应用构建:调用AI大模型的能力实现业务逻辑,…

mfc之tab标签控件的使用--附TabSheet源码

TabSheet源码 TabSheet.h #if !defined(AFX_TABSHEET_H__42EE262D_D15F_46D5_8F26_28FD049E99F4__INCLUDED_) #define AFX_TABSHEET_H__42EE262D_D15F_46D5_8F26_28FD049E99F4__INCLUDED_#if _MSC_VER > 1000 #pragma once #endif // _MSC_VER > 1000 // TabSheet.h : …

es实现桶聚合

目录 聚合 聚合的分类 DSL实现桶聚合 dsl语句 结果 聚合结果排序 限定聚合范围 总结 聚合必须的三要素: 聚合可配置的属性 DSL实现metric聚合 例如:我们需要获取每个品牌的用户评分的min,max,avg等值 只求socre的max 利用RestHighLevelClien…

【Multisim14.0正弦波>方波>三角波】2022-6-8

缘由有没有人会做啊Multisim14.0-其他-CSDN问答参考方波、三角波、正弦波信号产生 - 豆丁网

arcgis中dem转模型导入3dmax

文末分享素材 效果 1、准备数据 (1)DEM (2)DOM 2、打开arcscene软件 3、加载DEM、DOM数据 4、设置DOM的高度为DEM

【虚幻引擎UE】UE5 音频共振特效制作

UE5 音频共振特效制作 一、基础准备1.插件准备2.音源准备 二、创建共感NRT解析器和设置1.解析器选择依据2. 创建解析器3. 创建解析器设置(和2匹配)4.共感NRT解析器设置参数调整5.为共感NRT解析器关联要解析的音频和相应设置 三、蓝图控制1.创建Actor及静…

Openlayers高级交互(8/20):选取feature,平移feature

本示例介绍如何在vue+openlayers中使用Translate,选取feature,平移feature。选择的时候需要按住shift。Translate 功能通常是指在地图上平移某个矢量对象的位置。在 OpenLayers 中,可以通过修改矢量对象的几何位置来实现这一功能。 效果图 配置方式 1)查看基础设置:http…

AnaTraf | 全面掌握网络健康状态:全流量的分布式网络性能监测系统

AnaTraf 网络性能监控系统NPM | 全流量回溯分析 | 网络故障排除工具AnaTraf网络流量分析仪是一款基于全流量,能够实时监控网络流量和历史流量回溯分析的网络性能监控与诊断系统(NPMD)。通过对网络各个关键节点的监测,收集网络性能…

麒麟v10 arm64 部署 kubesphere 3.4 修改记录

arm64环境&#xff0c;默认安装 kubesphere 3.4 &#xff0c;需要修改几个地方的镜像&#xff0c;并且会出现日志无法显示 1 fluentbit:v1.9.4 报错 <jemalloc>: Unsupported system page size Error in GnuTLS initialization: ASN1 parser: Element was not found. &…

2.3 机器学习--朴素贝叶斯(分类)

本章我们来学习朴素贝叶斯算法&#xff0c;贝叶斯分类是一类分类算法的总称&#xff0c;这类算法均以贝叶斯定理为基础&#xff0c;故统称为贝叶斯分类。而朴素贝叶斯分类是贝叶斯分类中最简单&#xff0c;也是常见的一种分类方法。 我们常常遇到这样的场景。与友人聊天时&…

合合信息亮相2024中国模式识别与计算机视觉大会,用AI构建图像内容安全防线

近日&#xff0c;第七届中国模式识别与计算机视觉大会&#xff08;简称“PRCV 2024”&#xff09;在乌鲁木齐举办。大会由中国自动化学会&#xff08;CAA&#xff09;、中国图象图形学学会&#xff08;CSIG&#xff09;、中国人工智能学会&#xff08;CAAI&#xff09;和中国计…

分享几个办公类常用的AI工具

办公类 WPS AI讯飞智文iSlideProcessOn亿图脑图ChatPPT WPS AI 金山办公推出的协同办公 AI 应用&#xff0c;具有文本生成、多轮对话、润色改写等多种功能&#xff0c;可以辅助用户进行文档编辑、表格处理、演示文稿制作等办公操作。 https://ai.wps.cn/ 讯飞智文 科大讯飞推…

我谈Canny算子

在Canny算子的论文中&#xff0c;提出了好的边缘检测算子应满足三点&#xff1a;①检测错误率低——尽可能多地查找出图像中的实际边缘&#xff0c;边缘的误检率&#xff08;将边缘识别为非边缘&#xff09;低&#xff0c;且避免噪声产生虚假边缘&#xff08;将非边缘识别为边缘…

高效文本编辑与导航:Vim中的三种基本模式及粘滞位的深度解析

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…

有色行业测温取样机器人 - SNK施努卡

SNK施努卡有色行业熔炼车间机器人测温取样 在有色行业&#xff0c;测温取样机器人专门设计用于自动化处理高温熔体的温度监测和样品采集任务。这类机器人在铜、铝、锌等金属冶炼过程中扮演着关键角色&#xff0c;以提高生产效率、确保产品质量并增强工作安全性。 主要工作项 …

ABAP ALV

目录 一、基本概念 1、ALV概览 2、基本概念 二、属性更改 1、FIELDCAT 2、宏 3、LAYOUT 4、升序降序SORT 5、FILTER 三、交互 1、实现自己的按钮 一、基本概念 1、ALV概览 ALV&#xff1a;SAP List View,是SAP提供的一个强大的数据报表显示工具 ALV实际上是一个…

前端零基础入门到上班:【Day3】从零开始构建网页骨架HTML

HTML 基础入门&#xff1a;从零开始构建网页骨架 目录 1. 什么是 HTML&#xff1f;HTML 的核心作用 2. HTML 基本结构2.1 DOCTYPE 声明2.2 <html> 标签2.3 <head> 标签2.4 <body> 标签 3. HTML 常用标签详解3.1 标题标签3.2 段落和文本标签3.3 链接标签3.4 图…

CRLF、UTF-8这些编辑器右下角的选项的意思

经常使用编辑器的小伙伴应该经常能看到右下角会有这么两个选项&#xff0c;下图是VScode中的示例&#xff0c;那么这两个到底是啥作用呢&#xff1f; 目录 字符编码ASCII 字符集GBK 字符集Unicode 字符集UTF-8 编码 换行 字符编码 此部分参考博文 在计算机中&#xff0c;所有…