【标准化方法】(4) Weight Normalization 原理解析、代码复现,附Pytorch代码

news2024/11/16 19:26:49

今天和各位分享一下深度学习中常用的归一化方法,权重归一化(Weight Normalization, WN),通过理论解析,用 Pytorch 复现一下代码。

Weight Normalization 的论文地址如下:https://arxiv.org/pdf/1903.10520.pdf


1. 原理解析

权重归一化(Weight  Normalization,WN)选择对神经网络的权值向量 W 进行参数重写,参数化权重改善条件最优问题来加速收敛,灵感来自批归一化算法,但是并不像批归一化算法一样依赖于批次大小,不会对梯度增加噪声且计算量很小。权重归一化成功用于 LSTM 和对噪声敏感的模型,如强化学习和生成模型。

对深度学习网络权值 W 进行归一化的操作公式如下:

 

w = \frac{g}{||v||} v

通过一个 k 维标量 g 和一个向量 V 对权重向量 W 进行解耦合。标量 g=||W|| ,即权重 W 的大小,||v|| 表示 v 的欧几里得范数(二范数)。

作者提出对参数 v,g 直接重新参数化然后执行新的随机梯度下降,并且认为通过将权重向量(g)的范数与(\frac{v}{||v||})的方向解耦,加速了随机梯度下降的收敛

假设代价函数记为 L,此时的深度学习网络权值的梯度计算公式为:

\Delta_{_g}L=\Delta_{_w}L\cdot\Delta_{_g}W=\frac{\Delta_{_w}L\cdot\nu}{||\nu||}

M_w=I-\frac{ww'}{||w||^2},其中 M_w 是投影矩阵。梯度计算可以写成\Delta_{_v}L=\frac{g}{||v||}\cdot M_{_w}\Delta_{_w}L

\frac{||\Delta v||}{||v||} = c当梯度噪声大时,c 变大,有 \|v'\|=(\|v\|^2+c^2\|v\|^2)^{1/2}>\|v\|,则 \Delta_{v'}L 变小。

当梯度较小时,c 变小趋于0,有 \|v'\|=(\|v\|^2+c^2\|v\|^2)^{1/2} \approx \|v\|。即:权重归一化 WN 使用这种机制做到梯度稳定。另外,作者也发现 ||v|| 对学习率有很强的鲁棒性。

WN 不像 BN 还具有固定神经网络各层产生的特征尺度的好处,WN 需要小心的参数初始化给 v 的范数设定一个范围(正态分布均值为零,标准差为 0.05),这样虽然延长了参数更新的时间,但收敛后的测试性能会比较好。

t = \frac{v \cdot x}{||v||},仅在初始化期间取 g\leftarrow\frac{1}{\sigma[t]},b\leftarrow\frac{-\mu[t]}{\sigma[t]}

可以得到应用 WN 后,

\begin{aligned} & y=\phi(w\cdot x+b) \\ &=\phi(g\cdot{\frac{v}{||v||}}x+b) \\ &=\phi(\frac{1}{\sigma[t]}\cdot\frac{v}{||v||}x-\frac{\mu[t]}{\sigma[t]}) \\ &=\phi(\frac{t-\mu[t]}{\sigma[t]}) \end{aligned}

由上式可得,当 WN 进行参数初始化时可以在一开始达到和 BN 相同的作用。


2. 代码演示

这里以《Micro-Batch Training with Batch-Channel Normalization and Weight Standardization》这篇文章中的权重归一化方法为例,展示一下代码,比较简单,只需要对权重文件的每个通道做归一化处理。示意图如下。

import torch

def WS(weight:torch.Tensor, eps:float):
    # 权重shape=[c_out, c_in, k_h, k_w]
    c_out, c_in, *kernel_shape = weight.shape
    # [c_out, c_in, k_h, k_w]-->[c_out, c_in*k_h*k_w]
    weight = weight.view(c_out, -1)
    # 计算 [c_in*k_h*k_w] 维度上的均值和方差 --> [c_out,1]
    var, mean = torch.var_mean(weight, dim=1, keepdim=True)
    # 权重标准化
    weight = (weight-mean) / torch.sqrt(var+eps)
    # [c_out, c_in*k_h*k_w]-->[c_out, c_in, k_h, k_w]
    return weight.view(c_out, c_in, *kernel_shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/506869.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GEE:基于主成分分析(PCA)的风险筛选环境指标(RSEI)计算方法

作者:CSDN @ _养乐多_ 利用主成分分析(Principal Component Analysis, PCA)进行风险筛选环境指标(Risk-Screening Environmental Indicators, RSEI)的计算是一种常用的方法。本文介绍了基于主成分分析的RSEI计算方法,通过将多个基于遥感指数的环境指标转化为少数几个主成…

电脑cpu占用率高?怎么办?1分钟快速解决!

案例:电脑cup过高怎么办? 【我的电脑运行缓慢,导致我学习和工作的效率很低。刚刚查看了一下电脑,发现它的cpu占用率很高。有没有小伙伴知道如何解决此电脑cpu过高的问题?】 电脑是我们生活中不可缺少的工具&#xff…

Linux 多线程(1)线程概念与线程控制

多线程:概念、线程控制(创建、终止、等待、分离),线程安全(问题&实现),应用(生产者与消费者模型,线程池,单例模式) (重要&#xf…

linux系统(进程间通信)06_IPC概念,pipe管道,fifo通信,mmap 共享映射区

01 学习目标 1.熟练使用pipe进行父子进程间通信 2.熟练使用pipe进行兄弟进程间通信 3.熟练使用fifo进行无血缘关系的进程间通信 4.熟练掌握mmap函数的使用 5.掌握mmap创建匿名映射区的方法 6.使用mmap进行有血缘关系的进程间通信 7.使用mmap进行无血缘关系的进程间通信 02 IPC概…

Netfilter和iptables命令详解,从入门到精通

本文目录 1、netfilter架构和工作原则简介2、iptables操作命令说明2.1 、Filtering Specifications2.2、Target Specifications2.3、一个基于Linux的基本的防火墙的配置例子 netfilter 是Linux内核里网络部分的一个重要框架,内核通过netfilter完成IP报文的一些操作。…

缓存雪崩问题

缓存雪崩:指在同一时段大量的缓存key同时失效或者Redis服务宕机,导致大量的请求到达数据库,带来巨大的压力 解决方案: 1.给不同的key的TTL添加随机值 2.利用redis集群提高服务的可用性 3.给缓存业务添加降级限流策略 4.给业务添…

扫雷,咱就是一扫一大片(C语言完美递归版)

🤩本文作者:大家好,我是paperjie,感谢你阅读本文,欢迎一建三连哦。 🥰内容专栏:这里是《C语言》专栏,笔者用重金(时间和精力)打造,基础知识一网打尽,希望可以…

零基础入门 Stable Diffusion - 无需显卡把 AI 绘画引擎搬进家用电脑

我从小特别羡慕会画画的伙伴。他们能够将心中的想法画出来,而我最高水平的肖像画是丁老头。但在接触 Stable Diffusion 之后,我感觉自己脱胎换骨,给自己贴上了「会画画」的新标签。 丁老头进化旅程 Stable Diffusion 是一个「文本到图像」的…

区间预测 | MATLAB实现QRLSTM长短期记忆神经网络分位数回归时间序列区间预测

区间预测 | MATLAB实现QRLSTM长短期记忆神经网络分位数回归时间序列区间预测 目录 区间预测 | MATLAB实现QRLSTM长短期记忆神经网络分位数回归时间序列区间预测效果一览基本介绍模型描述程序设计参考资料 效果一览 进阶版 基础版 基本介绍 MATLAB实现QRLSTM长短期记忆神经…

微波方向有哪些SCI期刊推荐? - 易智编译EaseEditing

微波方向的SCI期刊推荐包括: IEEE Transactions on Microwave Theory and Technology: 该期刊是电磁场与微波技术领域的著名期刊,被世界上许多研究机构和大学广泛引用。 IEEE Transactions on Antennas and Propagation: 该期刊…

C++学习记录——이십일 AVL树

文章目录 1、了解AVL树2、模拟实现3、旋转1、左单旋2、右单旋3、双旋(先左后右)4、双旋(先右后左) 4、检查平衡5、测试性能(随机数)6、删除 1、了解AVL树 如果数据有序或接近有序,二叉搜索树将…

Java+Python+Paddle提取长文本文章中词频,用于Echart词云图数据

公司有个需求,就是需要提供给echart词云图的数据,放在以前我们的数据来源都是从产品那直接要,产品也是跑的别的接口,那怎么行呢,当然有自己的一套可以随便搞了,那么操作来了 Java package cn.iocoder.yud…

推荐几款2023年还在用的IDE工具

近期有不少刚学编程的小伙伴来问我,市面上那么多IDE工具,该怎么选?今天在这里跟大家分享几款个人比较钟爱的IDE工具,供大家参考。 Visual Studio 优点:支持多种语言,包括C#, C, Visual Basic等&#xff0c…

【Linux】进程信号“疑问?坤叫算信号吗?“

鸡叫当然也算信号啦~ 文章目录 前言一、认识信号量二、信号的产生 1.调用系统函数向进程发信号2.由软件条件产生信号3.硬件异常产生信号总结 前言 信号在我们生活中很常见,下面我们举一举生活中信号的例子: 你在网上买了很多件商品,再等待不…

【跟着陈七一起学C语言】今天总结:函数、数组、指针之间的关系

友情链接:专栏地址 知识总结顺序参考C Primer Plus(第六版)和谭浩强老师的C程序设计(第五版)等,内容以书中为标准,同时参考其它各类书籍以及优质文章,以至减少知识点上的错误&#x…

深度学习实战29-AIGC项目:利用GPT-2(CPU环境)进行文本续写与生成歌词任务

大家好,我是微学AI,今天给大家介绍一下深度学习实战29-AIGC项目:利用GPT-2(CPU环境)进行文本续写与生成歌词任务。在大家没有GPU算力的情况,大模型可能玩不动,推理速度慢,那么我们怎么才能跑去生成式的模型…

14 KVM虚拟机配置-配置虚拟设备(其它常用设备)

文章目录 14 KVM虚拟机配置-配置虚拟设备(其它常用设备)14.1 概述14.2 元素介绍14.3 配置示例 14 KVM虚拟机配置-配置虚拟设备(其它常用设备) 14.1 概述 除存储设备、网络设备外,XML配置文件中还需要指定一些其他外部…

Python+selenium,轻松搭建 Web 自动化测试框架

在程序员的世界中,一切重复性的工作,都应该通过程序自动执行。「自动化测试」就是一个最好的例子。 随着互联网应用开发周期越来越短,迭代速度越来越快,只会点点点,不懂开发的手工测试,已经无法满足如今的…

云渲染靠谱吗,使用云渲染会不会被盗作品?

云渲染靠谱吗、安全吗?如果使用 云渲染会不会被盗作品......Renderbus瑞云渲染作为一个正经的云渲染平台,也时不时会收到这类疑问,首先,瑞云渲染是肯定靠谱的,各位可以放心使用。另外小编也将在本篇教你如何辨别云渲染平台是否安全…

通达信W底形态选股公式,也称双底形态

W底形态,也称双底形态,是一种经典的技术分析形态,代表了跌势的逆转。看起来像字母 "W",描述了一波下跌,反弹,再次下跌到与上一波下跌相同或相近的位置,最后是另一波反弹。W底形态两次…