✅ 一、Batch Normalization(BN)回顾
BN 层在训练和推理阶段的行为是不一样的,核心区别就在于:
训练时用 mini-batch 里的均值方差,预测时用全局的“滑动平均”均值方差。
🧪 二、训练阶段(Training mode)
• 每个小批量(batch)都会计算:
为了后面预测用得上,训练时还会维护全局“滑动平均”:
其中是动量参数(momentum),通常为 0.9 或 0.99。
🧠 三、推理阶段(Evaluation / Inference)
推理阶段不会再计算当前 batch 的均值和方差。
而是使用训练时积累的滑动平均:
这样能保证预测过程中结果稳定、不依赖 batch 大小或数据分布波动。
🧰 四、在 PyTorch / TensorFlow 中自动切换
PyTorch:
model.train() # 启用训练模式,BN 用 batch 均值方差
model.eval() # 启用评估模式,BN 用滑动均值方差
TensorFlow (Keras):
model.fit(...) # 自动使用训练模式
model.evaluate(...) # 自动使用推理模式
📌 总结一句话:
BN 层预测时的均值和方差,来自 训练期间累计的滑动平均值,而不是实时计算。
五、补充知识:Keras 是什么
🧠 一句话定义:
Keras 是一个高级神经网络 API,用来快速搭建、训练和部署深度学习模型,底层运行在 TensorFlow 上。
📦 二、Keras 的定位
特性 | 说明 |
---|---|
高级封装 | 用几行代码就能搭建复杂模型,适合快速开发 |
基于 TensorFlow | 现在是 TensorFlow 的官方高层 API(tf.keras) |
易学易用 | 类似积木式的拼接方式,语法简洁,初学者友好 |
灵活性强 | 同时支持顺序模型(Sequential)和函数式模型(Functional API) |
支持多种任务 | 图像分类、NLP、生成模型、时间序列、强化学习等 |
支持多平台部署 | 可以导出为 SavedModel,支持 TensorFlow Serving、TFLite、ONNX、Web 等 |
⚙️ 三、简单例子(Keras 搭建一个 MLP 分类器)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
model = Sequential([
Dense(64, activation='relu', input_shape=(100,)),
Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# 假设 x_train.shape = (1000, 100),y_train 是 one-hot 标签
model.fit(x_train, y_train, epochs=10, batch_size=32)
🏗 四、Keras 模型的两种构建方式
1. Sequential(顺序模型)
• 一层接一层,简单好用
model = Sequential([...])
2. Functional API(函数式模型)
• 灵活连接,适合多输入/多输出、残差连接等复杂结构
from tensorflow.keras import Model, Input
x = Input(shape=(100,))
h = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(h)
model = Model(inputs=x, outputs=y)
🔥 五、Keras 常见模块
模块 | 作用 |
---|---|
tf.keras.models | 创建模型(Sequential、Model) |
tf.keras.layers | 各种神经网络层(Dense、Conv2D、LSTM 等) |
tf.keras.optimizers | 优化器(SGD、Adam、RMSprop 等) |
tf.keras.losses | 损失函数(MSE、CrossEntropy 等) |
tf.keras.metrics | 评价指标(Accuracy、Precision 等) |
tf.keras.callbacks | 回调函数(EarlyStopping、ModelCheckpoint 等) |
📌 总结一句话:
Keras = 深度学习“乐高”,用来快速搭建模型,适合初学者,也支持复杂自定义模型,是 TensorFlow 的核心部分。