BN 层做预测的时候, 方差均值怎么算

news2025/4/16 20:51:51

✅ 一、Batch Normalization(BN)回顾

 

BN 层在训练和推理阶段的行为是不一样的,核心区别就在于:

训练时用 mini-batch 里的均值方差,预测时用全局的“滑动平均”均值方差。

🧪 二、训练阶段(Training mode)

• 每个小批量(batch)都会计算:

\mu_{\text{batch}} = \frac{1}{m} \sum_{i=1}^{m} x_i

\sigma^2_{\text{batch}} = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\text{batch}})^2

\hat{x}_i = \frac{x_i - \mu_{\text{batch}}}{\sqrt{\sigma^2_{\text{batch}} + \epsilon}}

为了后面预测用得上,训练时还会维护全局“滑动平均”:
\mu_{\text{running}} = \rho \cdot \mu_{\text{running}} + (1 - \rho) \cdot \mu_{\text{batch}} \sigma^2_{\text{running}} = \rho \cdot \sigma^2_{\text{running}} + (1 - \rho) \cdot \sigma^2_{\text{batch}}

其中\rho是动量参数(momentum),通常为 0.9 或 0.99。


🧠 三、推理阶段(Evaluation / Inference)

推理阶段不会再计算当前 batch 的均值和方差。

而是使用训练时积累的滑动平均

\hat{x}i = \frac{x_i - \mu{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}}

这样能保证预测过程中结果稳定、不依赖 batch 大小或数据分布波动


🧰 四、在 PyTorch / TensorFlow 中自动切换

PyTorch:

model.train()   # 启用训练模式,BN 用 batch 均值方差
model.eval()    # 启用评估模式,BN 用滑动均值方差

TensorFlow (Keras):

model.fit(...)       # 自动使用训练模式
model.evaluate(...)  # 自动使用推理模式

📌 总结一句话:

BN 层预测时的均值和方差,来自 训练期间累计的滑动平均值,而不是实时计算。

五、补充知识:Keras 是什么


🧠 一句话定义:

Keras 是一个高级神经网络 API,用来快速搭建、训练和部署深度学习模型,底层运行在 TensorFlow 上。

📦 二、Keras 的定位

特性

说明

高级封装

用几行代码就能搭建复杂模型,适合快速开发

基于 TensorFlow

现在是 TensorFlow 的官方高层 API(tf.keras)

易学易用

类似积木式的拼接方式,语法简洁,初学者友好

灵活性强

同时支持顺序模型(Sequential)和函数式模型(Functional API)

支持多种任务

图像分类、NLP、生成模型、时间序列、强化学习等

支持多平台部署

可以导出为 SavedModel,支持 TensorFlow Serving、TFLite、ONNX、Web 等


⚙️ 三、简单例子(Keras 搭建一个 MLP 分类器)

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential([
    Dense(64, activation='relu', input_shape=(100,)),
    Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# 假设 x_train.shape = (1000, 100),y_train 是 one-hot 标签
model.fit(x_train, y_train, epochs=10, batch_size=32)

🏗 四、Keras 模型的两种构建方式

1. Sequential(顺序模型)

• 一层接一层,简单好用

model = Sequential([...])

2. Functional API(函数式模型)

• 灵活连接,适合多输入/多输出、残差连接等复杂结构

from tensorflow.keras import Model, Input
x = Input(shape=(100,))
h = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(h)
model = Model(inputs=x, outputs=y)

🔥 五、Keras 常见模块

模块

作用

tf.keras.models

创建模型(Sequential、Model)

tf.keras.layers

各种神经网络层(Dense、Conv2D、LSTM 等)

tf.keras.optimizers

优化器(SGD、Adam、RMSprop 等)

tf.keras.losses

损失函数(MSE、CrossEntropy 等)

tf.keras.metrics

评价指标(Accuracy、Precision 等)

tf.keras.callbacks

回调函数(EarlyStopping、ModelCheckpoint 等)


📌 总结一句话:

Keras = 深度学习“乐高”,用来快速搭建模型,适合初学者,也支持复杂自定义模型,是 TensorFlow 的核心部分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2329740.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JS 其他事件类型

页面加载 事件 window.addEvent() window.addEventListener(load,function(){const btn document.querySelector(button)btn.addEventListener(click,function(){alert(按钮)})})也可以给其他标签加该事件 HTML加载事件 找html标签 也可以给页面直接赋值

AI Agent设计模式五:Orchestrator

概念 :中央任务调度中枢 ✅ 优点:全局资源协调,确保任务执行顺序❌ 缺点:单点故障风险,可能成为性能瓶颈 import operator import osfrom langchain.schema import SystemMessage, HumanMessage from langchain_opena…

MySQL基础 [三] - 数据类型

目录 数据类型分类 ​编辑 数值类型 tinyint bit 浮点类型 float decimal 字符串类型 char varchar varchar和char的比较和选择 日期和时间类型 enum和set enum类型 set类型 enum和set的类型查找 数据类型分类 数值类型 tinyint TINYINT[(M)] [UNSIGNED]是 …

不用训练,集成多个大模型产生更优秀的输出

论文标题 Collab: Controlled Decoding using Mixture of Agents for LLM Alignment 论文地址 https://arxiv.org/pdf/2503.21720 作者背景 JP摩根,马里兰大学帕克分校,普林斯顿大学 动机 大模型对齐(alignment)的主要目的…

随笔1 认识编译命令

1.认识编译命令 1.1 解释gcc编译命令: gcc test1.cpp -o test1 pkg-config --cflags --libs opencv 命令解析: gcc:GNU C/C 编译器,用于编译C/C代码。 test1.cpp:源代码文件。 -o test1:指定输出的可执行文件名为t…

Hyperlane 框架路由功能详解:静态与动态路由全掌握

Hyperlane 框架路由功能详解:静态与动态路由全掌握 Hyperlane 框架提供了强大而灵活的路由功能,支持静态路由和动态路由两种模式,让开发者能够轻松构建各种复杂的 Web 应用。本文将详细介绍这两种路由的使用方法。 静态路由:简单…

铰链损失函数 Hinge Loss和Keras 实现

一、说明 在为了了解 Keras 深度学习框架的来龙去脉,本文介绍铰链损失函数,然后使用 Keras 实现它们以进行练习并了解它们的行为方式。在这篇博客中,您将首先找到两个损失函数的简要介绍,以确保您在我们继续实现它们之前直观地理解…

瑞数信息发布《BOTS自动化威胁报告》,揭示AI时代网络安全新挑战

近日,瑞数信息正式发布《BOTS自动化威胁报告》,力求通过全景式观察和安全威胁的深度分析,为企业在AI时代下抵御自动化攻击提供安全防护策略,从而降低网络安全事件带来的影响,进一步增强业务韧性和可持续性。 威胁一&am…

FLV格式:流媒体视频的经典选择

FLV格式:流媒体视频的经典选择 FLV(Flash Video)格式曾经是流媒体视频的主力军,在互联网视频的早期时代广泛应用于视频网站和多媒体平台。凭借其高效的压缩和较小的文件体积,FLV成为了许多视频内容创作者和平台的首选…

需求分析-用例图绘制、流程图绘制

第一,引论 需求分析是开发的第一步,也是我个人认为最重要的一步。 技术难题的克服,甚至在我心里,还要排在需求分析后面。 如果需求分析做好了,数据库就更容易建立,数据库建好了,业务逻辑写起…

Windows安装 PHP 8 和mysql9,win下使用phpcustom安装php8.4.5和mysql9

百度搜索官网并下载phpcustom,然后启动环境,点击网站管理 里面就有php8最新版,可以点mysql设置切mysql9最新版,如果你用最新版无法使用,说明你的php程序不支持最新版的mysql MySQL 9.0 引入了一些新的 SQL 模式和语法变…

http://noi.openjudge.cn/_2.5基本算法之搜索_1804:小游戏

文章目录 题目深搜代码宽搜代码深搜数据演示图总结 题目 1804:小游戏 总时间限制: 1000ms 内存限制: 65536kB 描述 一天早上,你起床的时候想:“我编程序这么牛,为什么不能靠这个赚点小钱呢?”因此你决定编写一个小游戏。 游戏在一…

手写JSX实现虚拟DOM

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》、《前端求职突破计划》 🍚 蓝桥云课签约作者、…

Spring Boot 中的 Bean

2025/4/6 向全栈工程师迈进&#xff01; 一、Bean的扫描 在之前&#xff0c;对于Bean的扫描&#xff0c;我们可以在XML文件中书写标签&#xff0c;来指定要扫描的包路径&#xff0c;如下所示,可以实通过如下标签的方式&#xff1a; <context:component-scan base-package&…

ST 芯片架构全景速览:MCU、无线 SoC、BLE 模块、MPU 差异详解

在嵌入式开发中,ST 是一个非常常见的芯片厂商,其产品线覆盖了 MCU、无线芯片、BLE 模块以及运行 Linux 的 MPU 等多个领域。很多开发者初次接触 ST 时会对这些产品之间的关系感到困惑。 本文从分类视角出发,带你快速了解 ST 芯片家族的核心架构和主要用途。 🧭 ST 芯片四…

AtCoder Beginner Contest 400(ABCDE)

A - ABC400 Party 翻译&#xff1a; 在 ABC400 的纪念仪式上&#xff0c;我们想把 400 人排成 A 行 B 列的长方形&#xff0c;且不留任何空隙。 给你一个正整数 A&#xff0c;请打印可以这样排列的正整数 B 的值。如果没有这样的正整数 B&#xff0c;则打印-1。 思路&#xff…

Flask+Vue构建图书管理系统及Echarts组件的使用

教程视频链接从零开始FlaskVue前后端分离图书管理系统 后端 项目下载地址 其中venv为该项目的虚拟环境&#xff0c;已安装所有依赖 使用方法&#xff1a; 在pycharm终端中flask create一下&#xff08;因为写了一个自定义命令的代码&#xff09;&#xff0c;初始化books数据…

【项目管理】第2章 信息技术发展 --知识点整理

Oracle相关文档,希望互相学习,共同进步 风123456789~-CSDN博客 (一)知识总览 对应:第1章-第5章 (二)知识笔记 二、信息技术的发展 1. 信息技术及其发展 1)计算机软硬件 计算机硬件由电子机械、光电元件等组成的物理装置,提供物质基础给计算机软件运行。软件包括程…

4-c语言中的数据类型

一.C 语⾔中的常量 1.生活中的数据 整数&#xff1a; 100,200,300,400,500 小数: 11.11 22.22 33.33 字母&#xff1a; a&#xff0c;b&#xff0c;c&#xff0c;d A&#xff0c;B&#xff0c;C&#xff0c;D 在 C 语⾔中我们把字⺟叫做字符. 字符⽤单引号引⽤。例如A’ 单词…

LORA+llama模型微调全流程

LORAllama.cpp模型微调全流程 准备阶段 1.下载基础大模型 新建一个download.py脚本 from modelscope import snapshot_download#模型存放路径 model_path /root/autodl-tmp #模型名字 name itpossible/Chinese-Mistral-7B-Instruct-v0.1 model_dir snapshot_download(na…