【PyTorch】权重衰减

news2025/1/21 5:51:44

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现

1. 理论介绍

  • 通过对模型过拟合的思考,人们希望能通过某种工具调整模型复杂度,使其达到一个合适的平衡位置。
  • 权重衰减(又称 L 2 L_2 L2正则化)通过为损失函数添加惩罚项,用来惩罚权重的 L 2 L_2 L2范数,从而限制模型参数值,促使模型参数更加稀疏或更加集中,进而调整模型的复杂度,即: L ( w , b ) + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2 L(w,b)+2λw2其中 λ \lambda λ权重衰减的超参数
  • L p L_p Lp范数: ∥ x ∥ p = ( ∑ i = 1 n ∣ x i ∣ p ) 1 / p \|\mathbf{x}\|_p = \left(\sum_{i=1}^n \left|x_i \right|^p \right)^{1/p} xp=(i=1nxip)1/p
    p = 1 p=1 p=1时称为 L 1 L_1 L1范数;当 p = 2 p=2 p=2时称为 L 2 L_2 L2范数。
    惩罚 L 1 L_1 L1范数会导致模型将权重集中在一小部分特征上, 而将其他权重清除为零, 这称为特征选择;惩罚 L 2 L_2 L2范数会导致模型在大量特征上均匀分布权重,使得模型对单个变量的观测误差更为稳定。
  • 通常不建议对偏置进行正则化,因为偏置的取值并不像权值那样会随着训练过程而变化,因此对偏置进行正则化对于控制模型的复杂度影响较小;另外,对偏置进行正则化可能会导致对数据中的偏移进行过度拟合,而减弱了模型对其他特征的学习。

2. 实例解析

2.1. 实例描述

使用以下公式生成包含20个样本的小训练集和100个样本的测试集,并用线性网络进行拟合: y = 0.05 + ∑ i = 1 200 0.01 x i + ϵ  where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^{200} 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=12000.01xi+ϵ where ϵN(0,0.012).

2.2. 代码实现

  • 主要代码
optimizer = optim.SGD([
            {"params": net.weight,"weight_decay": weight_decay},
            {"params": net.bias}
            ], lr=lr)
  • 完整代码
import os
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from tensorboardX import SummaryWriter
from rich.progress import track

def data_generator(w, b, num):
    """为线性模型生成数据"""
    X = torch.randn(num, len(w))
    y = torch.sum(X @ w, dim=1) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape(-1, 1)

def load_dataset(*tensors):
    """加载数据集"""
    dataset = TensorDataset(*tensors)
    return DataLoader(dataset, batch_size, shuffle=True)

def evaluate_loss(dataloader, net, criterion):
    """评估模型在指定数据集上的损失"""
    num_examples = 0
    loss_sum = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.cuda(), y.cuda()
            loss = criterion(net(X), y)
            num_examples += y.shape[0]
            loss_sum += loss.sum()
        return loss_sum / num_examples


if __name__ == '__main__':
    # 全局参数设置
    lr = 0.003
    num_epochs = 100
    batch_size = 5

    # 创建记录器
    def log_dir():
        root = "runs"
        if not os.path.exists(root):
            os.mkdir(root)
        order = len(os.listdir(root)) + 1
        return f'{root}/exp{order}'
    writer = SummaryWriter(log_dir=log_dir())
    
    # 合成数据集
    num_inputs = 200
    n_train, n_test = 20, 100
    true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
    X, y = data_generator(true_w, true_b, n_train + n_test)

    # 加载数据集
    dataloader_train = load_dataset(X[:n_train], y[:n_train])
    dataloader_test = load_dataset(X[n_train:], y[n_train:])

    def loop(weight_decay):
        # 定义模型
        net = nn.Linear(num_inputs, 1).cuda()
        nn.init.normal_(net.weight)
        nn.init.constant_(net.bias, 0)
        criterion = nn.MSELoss(reduction='none')
        optimizer = optim.SGD([
            {"params": net.weight,"weight_decay": weight_decay},
            {"params": net.bias}
            ], lr=lr)

        # 训练循环
        for epoch in track(range(num_epochs), description=f'wd={weight_decay}'):
            for X, y in dataloader_train:
                X, y = X.cuda(), y.cuda()
                loss = criterion(net(X), y)
                optimizer.zero_grad()
                loss.mean().backward()
                optimizer.step()
            writer.add_scalars(f'wd={weight_decay}', {
                'train_loss': evaluate_loss(dataloader_train, net, criterion),
                'test_loss': evaluate_loss(dataloader_test, net, criterion),
            }, epoch)


    for weight_decay in [0, 3]:
        loop(weight_decay)
    writer.close()
  • 输出结果
    • weight_decay = 0
      0
    • weight_decay = 3
      3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1290352.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用23种设计模式打造一个cocos creator的游戏框架----(七)代理模式

1、模式标准 模式名称:代理模式 模式分类:结构型 模式意图:为其他对象提供一种代理以控制对这个对象的访问。 结构图: ​ 适用于: 远程代理:也称为大使,这是最常见的类型,在分…

中文BERT模型预训练参数总结以及转化为pytorch的方法

1.目前针对中文的bert预训练模型有三家: 谷歌发布的chinese_L-12_H-768_A-12 还有哈工大的chinese-bert-wwm / chinese-bert-wwm-ext 以及HuggingFace上的bert-base-chinese(由清华大学基于谷歌的BERT在中文数据集上训练开发的模型,上传在HuggingFace) …

ElasticSearch篇---第四篇

系列文章目录 文章目录 系列文章目录前言一、elasticsearch 是如何实现 master 选举的?二、elasticsearch 索引数据多了怎么办,如何调优,部署?三、说说你们公司 es 的集群架构,索引数据大小,分片有多少?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽…

李宏毅gpt个人记录

参考: 李宏毅机器学习--self-supervised:BERT、GPT、Auto-encoder-CSDN博客 用无标注资料的任务训练完模型以后,它本身没有什么用,GPT 1只能够把一句话补完,可以把 Self-Supervised Learning 的 Model做微微的调整&am…

【改进YOLOv8】融合感受野注意力卷积RFCBAMConv的杂草分割系统

1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 研究背景与意义 随着计算机视觉技术的不断发展,图像分割成为了一个重要的研究领域。图像分割可以将图像中的不同对象或区域进行有效的分离,对于许多应用领…

shell脚本生成随机双色球号码

[rootcentos7 ~]#cat lottery.sh #!/bin/bash #定义零长度数组 arr() length${#arr[]} while [ "${length}" -lt 6 ]do#取1到33的随机数s$[$RANDOM%331]#判断随机数是否在数组中,不在就赋值给数组if [[ ! "${arr[]}" ~ "${s}" ]]then…

什么是神经网络的超参数

1 引言 超参数在神经网络的设计和训练中起着至关重要的作用。它们是在开始训练之前设置的参数,与网络的结构、训练过程和优化算法有关。正确的超参数选择对于达到最优模型性能至关重要。 2 神经网络结构的超参数 层数(Layers): 决…

pyecharts可视化作图1:基金净值-折线图

近期,接触到pyecharts模块,感觉其在可视化作图上比较强大,虽然无法和前端页面相比,但对于基础的数据展示,可以轻松处理。 本期主要以基金净值走势为案例,绘制相应的折线图,由于该模块较为简单&a…

多用户商城系统支付模块 用户支付的钱到哪里去了

多用户商城系统是类似京东天猫的电商平台,用户一般使用微信或者支付宝支付,在购买商品或服务支付后,商家发货或提供服务后,平台需要将钱结算给提供商品或者服务的商户。 这时会涉及平台和商户的结算问题,一般有两种解决…

【Qt开发流程】之对象模型3:对象树及其所有权

描述 Qt对象树是一种基于父子关系的对象管理机制,用于管理Qt应用程序中的所有对象。在Qt中,每个对象都可以拥有一个或多个子对象,并且每个子对象只能属于一个父对象。每个对象的所有权(也称为生存期)由其父对象控制。…

LangChain学习一:模型-实战

文章目录 上一节内容学习目标:模型(models)学习内容一:模型分类学习内容二:不同模型实战3.1 Chat-聊天模型3.1.1 声明3.1.2 Chat-聊天类型实战3.1.2.1 AIMessage(AI 消息)3.1.2.2 HumanMessage&…

力扣46. 全排列(java 回溯法)

Problem: 46. 全排列 文章目录 题目描述思路解题方法复杂度Code 题目描述 给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 思路 1.该题目要求求出一个数组的全排列,我们可以利用回溯模拟出一个对数组中所有…

9_企业架构队列缓存中间件分布式Redis

企业架构队列缓存中间件分布式Redis 学习目标和内容 1、能够描述Redis作用及其业务适用场景 2、能够安装配置启动Redis 3、能够使用命令行客户端简单操作Redis 4、能够实现操作基本数据类型 5、能够理解描述Redis数据持久化机制 6、能够操作安装php的Redis扩展 7、能够操作实现…

maven生命周期回顾

目录 文章目录 **目录**两种最常用打包方法:生命周期: 两种最常用打包方法: 1.先 clean,然后 package2.先 clean,然后install 生命周期: 根据maven生命周期,当你执行mvn install时&#xff0c…

JAVA IO:NIO

1.阻塞 IO 模型 ​ 最传统的一种 IO 模型,即在读写数据过程中会发生阻塞现象。当用户线程发出 IO 请求之后,内核会去查看数据是否就绪,如果没有就绪就会等待数据就绪,而用户线程就会处于阻塞状态,用户线程交出 CPU。当…

Unity 简单打包脚本

打包脚本 这个打包脚本适用于做demo,脚本放在Editor目录下 using System; using System.Collections; using System.Collections.Generic; using System.IO; using UnityEditor; using UnityEngine;public class BuildAB {[MenuItem("Tools/递归遍历文件夹下…

构建第一个事件驱动型 Serverless 应用

我相信,我们从不缺精彩的应用创意,我们缺少的把这些想法变成现实的时间和付出。 我认为,无服务器技术真的有助于最大限度节省应用开发和部署的时间,并且无服务器技术用可控的成本,实现了我的那些有趣的想法。 在我 2…

kali学习

目录 黑客法则: 一:页面使用基础 二:msf和Windows永恒之蓝漏洞 kali最强渗透工具——metasploit 介绍 使用永恒之蓝进行攻击 ​编辑 使用kali渗透工具生成远程控制木马 渗透测试——信息收集 域名信息收集 黑客法则: 一&…

Java架构师系统架构设计原则应用

目录 1 导语2 如何设计高并发系统:局部并发原则3 如何设计高并发系统:服务化与拆分4 高可用系统有哪些设计原则?5 如何保持简单轻量的架构-DRY、KISS,YAGNI原则6 如何设计组件间的交互和行为-HCLC,CQS,SOC7 框架层面的发展趋势-约定大于配置想学习架构师构建流程请跳转:…

有源滤波器在矿区配电网中的应用

针对目前有源滤波器应用于矿区谐波治理时电网频率适应能力较低的问题,针对定采样点数字控制系统提出了一种具有频率自适应能力的谐振控制策略。该策略不仅可以实现对电网频率波动的自适应,提高滤波器补偿效果,而且不需要在线对控制器参数进行…