§1.3.29

BatchNorm/LayerNorm/RMSNorm/GroupNorm/DeepNorm 的公式与差异?

手写练习
  • 手写 RMSNorm

核心概念

归一化层(Normalization Layers)是深度学习模型中的一类关键组件,其核心思想是通过对层输入或激活值进行重新中心化(re-centering)和重新缩放(re-scaling)来规范化其数据分布。这有助于解决深度网络训练过程中的“内部协变量偏移”(Internal Covariate Shift)问题,从而稳定训练过程、加快收敛速度、支持使用更高的学习率,并降低模型对初始化参数的敏感度。不同的归一化方法,如 BatchNormLayerNorm 等,其主要区别在于计算均值和方差的维度集合不同。

原理与推导

所有主流的归一化技术都可以被抽象为一个两步过程:

  1. 标准化 (Standardization): 将输入 x 的某个集合 S 内的元素,调整为均值为 0,方差为 1 的分布。 x^i=xiμSσS2+ϵ\hat{x}_i = \frac{x_i - \mu_S}{\sqrt{\sigma_S^2 + \epsilon}} 其中 μS=1SiSxi\mu_S = \frac{1}{|S|} \sum_{i \in S} x_i 是集合 S 的均值,σS2=1SiS(xiμS)2\sigma_S^2 = \frac{1}{|S|} \sum_{i \in S} (x_i - \mu_S)^2 是集合 S 的方差,ϵ\epsilon 是一个很小的正数(如 1e-5)以防止除以零。

  2. 仿射变换 (Affine Transformation): 引入两个可学习的参数,缩放因子 γ\gamma 和平移因子 β\beta,来恢复网络的表达能力。 yi=γx^i+βy_i = \gamma \hat{x}_i + \beta 如果网络认为不需要归一化,它可以学习到 γ=σS2+ϵ\gamma = \sqrt{\sigma_S^2 + \epsilon}β=μS\beta = \mu_S,从而将输入恒等映射回来。

下面我们详细分析各种归一化方法在计算均值和方差的集合 S 上的差异。假设我们有一个 mini-batch 的输入张量,其形状对于 CNN 是 (N, C, H, W),对于 Transformer(N, L, D),其中 N 是批大小,C 是通道数,H, W 是高和宽,L 是序列长度,D 是特征维度。


1. Batch Normalization (BN)

  • 归一化维度: 对每个特征通道(channel),在整个批次(batch)内进行归一化。
  • 计算集合 S: 对于 (N, C, H, W) 的输入,BN 为 C 个通道中的每一个通道,独立计算其均值和方差。计算 S 的范围是 (N, H, W) 维度。这意味着均值和方差与样本有关,但与特征图的空间位置无关。
  • 数学公式: μc=1NHWn=1Nh=1Hw=1Wxn,c,h,w\mu_c = \frac{1}{N \cdot H \cdot W} \sum_{n=1}^N \sum_{h=1}^H \sum_{w=1}^W x_{n,c,h,w} σc2=1NHWn=1Nh=1Hw=1W(xn,c,h,wμc)2\sigma_c^2 = \frac{1}{N \cdot H \cdot W} \sum_{n=1}^N \sum_{h=1}^H \sum_{w=1}^W (x_{n,c,h,w} - \mu_c)^2 x^n,c,h,w=xn,c,h,wμcσc2+ϵ\hat{x}_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} yn,c,h,w=γcx^n,c,h,w+βcy_{n,c,h,w} = \gamma_c \hat{x}_{n,c,h,w} + \beta_c
  • 注意: γ\gammaβ\beta 的维度是 (C,),即每个通道有一对独立的参数。
  • 几何解释: 将不同样本在同一特征通道上的值拉到相似的分布范围内,减少了样本间的差异性对模型的影响。
  • 复杂度: 训练时需要计算批次统计量,推理时使用训练期间累积的滑动平均均值和方差。时间/空间复杂度与输入大小成正比。

2. Layer Normalization (LN)

  • 归一化维度: 对每个样本,在所有特征通道(或嵌入维度)上进行归一化。
  • 计算集合 S: 对于 (N, L, D) 的输入(如 Transformer),LN 为 N 个样本中的每一个样本,独立计算其均值和方差。计算 S 的范围是 (L, D) 维度(或仅 D 维度,取决于具体实现)。它与批次大小完全无关。
  • 数学公式 (以 Transformer 中常见形式为例,对最后一个维度 D 归一化): μn,l=1Dd=1Dxn,l,d\mu_{n,l} = \frac{1}{D} \sum_{d=1}^D x_{n,l,d} σn,l2=1Dd=1D(xn,l,dμn,l)2\sigma_{n,l}^2 = \frac{1}{D} \sum_{d=1}^D (x_{n,l,d} - \mu_{n,l})^2 x^n,l,d=xn,l,dμn,lσn,l2+ϵ\hat{x}_{n,l,d} = \frac{x_{n,l,d} - \mu_{n,l}}{\sqrt{\sigma_{n,l}^2 + \epsilon}} yn,l,d=γdx^n,l,d+βdy_{n,l,d} = \gamma_d \hat{x}_{n,l,d} + \beta_d
  • 注意: γ\gammaβ\beta 的维度是 (D,),即每个特征维度有一对独立的参数。
  • 几何解释: 将单个样本的所有特征拉到同一个固定的分布上,使得模型对特征的整体缩放不敏感。
  • 复杂度: 复杂度与输入大小成正比,但计算完全在单个样本内完成,与批次大小无关。

3. Group Normalization (GN)

  • 归一化维度: 对每个样本,将通道分成若干组(Group),在每个组内进行归一化。
  • 计算集合 S: GN 是 BN 和 LN 的折衷。对于 (N, C, H, W) 的输入,它首先将 C 个通道分成 G 组,每组有 C/G 个通道。然后对每个样本的每个组内进行归一化。计算 S 的范围是 (C/G, H, W) 维度。
  • 数学公式: 假设通道 c 属于第 g 组, μn,g=1(C/G)HWcgroupgh=1Hw=1Wxn,c,h,w\mu_{n,g} = \frac{1}{(C/G) \cdot H \cdot W} \sum_{c \in \text{group}_g} \sum_{h=1}^H \sum_{w=1}^W x_{n,c,h,w} σn,g2=1(C/G)HWcgroupgh=1Hw=1W(xn,c,h,wμn,g)2\sigma_{n,g}^2 = \frac{1}{(C/G) \cdot H \cdot W} \sum_{c \in \text{group}_g} \sum_{h=1}^H \sum_{w=1}^W (x_{n,c,h,w} - \mu_{n,g})^2 x^n,c,h,w=xn,c,h,wμn,gσn,g2+ϵ\hat{x}_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_{n,g}}{\sqrt{\sigma_{n,g}^2 + \epsilon}} yn,c,h,w=γcx^n,c,h,w+βcy_{n,c,h,w} = \gamma_c \hat{x}_{n,c,h,w} + \beta_c
  • 注意: 当 G=1 时,GN 约等于 LN;当 G=C 时,GN 变为 Instance Normalization。
  • 几何解释: 假设在通道维度上存在局部的相关性,因此将相关的通道分为一组共同归一化是合理的。
  • 复杂度: 与 LN 类似,与批次大小无关。

4. RMS Normalization (RMSNorm)

  • 归一化维度: 与 LN 类似,但移除了重新中心化(re-centering)步骤,即不减去均值。
  • 核心思想: 论文作者认为,对于稳定训练而言,重新缩放(re-scaling)比重新中心化更重要。移除均值计算可以简化计算,提升速度。
  • 数学公式: RMS(x)=1Dd=1Dxd2\text{RMS}(x) = \sqrt{\frac{1}{D} \sum_{d=1}^D x_d^2} x^n,l,d=xn,l,d1Di=1Dxn,l,i2+ϵ\hat{x}_{n,l,d} = \frac{x_{n,l,d}}{\sqrt{\frac{1}{D} \sum_{i=1}^D x_{n,l,i}^2 + \epsilon}} yn,l,d=γdx^n,l,dy_{n,l,d} = \gamma_d \hat{x}_{n,l,d}
  • 注意: RMSNorm 通常不使用平移参数 β\beta
  • 几何解释: 它只调整向量的幅度(L2 范数),而不改变其方向。它将输入向量投影到单位球上,然后通过 γ\gamma 进行缩放。
  • 复杂度: 比 LN 更快,因为它省去了均值计算。

5. Deep Normalization (DeepNorm)

  • 核心思想: DeepNorm 不是一个独立的归一化层,而是一套用于训练极深 Transformer 模型的方法论。它通过理论推导,将模型更新的幅度与模型深度解耦,从而实现稳定训练。
  • 关键组成:
    1. 特定的残差连接: 对于一个子层(如 Attention 或 FFN)函数 f(x)f(x),传统的 Post-LN 是 LayerNorm(x + f(x))。DeepNorm 修改为 LayerNorm(α * x + f(x))
    2. 特定的参数初始化: 子层 f(x)f(x) 的参数需要被初始化,使其输出的方差很小。
  • 数学公式: xl+1=LayerNorm(αxl+fl(xl))x_{l+1} = \text{LayerNorm}(\alpha \cdot x_l + f_l(x_l))
  • 参数 α\alpha: 这是一个与模型深度 N 相关的常数,用于平衡残差连接 x 和子层输出 f(x) 的贡献。对于 Encoder,通常设为 α=(2N)1/4\alpha = (2N)^{1/4};对于 Decoder,设为 α=(8N)1/4\alpha = (8N)^{1/4}
  • 动机:
    • 标准的 Post-LN 结构在深层网络中,梯度会逐层累积,导致梯度爆炸。
    • 标准的 Pre-LN 结构 (x + f(LayerNorm(x))) 虽然梯度稳定,但主干分支的输出会逐层累积,导致数值爆炸。
    • DeepNorm 通过理论分析,给出了一个可以使模型更新幅度保持有界的 α 值和初始化方案,从而能够稳定训练上千层的 Transformer

代码实现

这里提供一个手动实现的 RMSNorm,并与 PyTorch 内置的 LayerNorm 进行对比。

python
1import torch
2import torch.nn as nn
3
4class RMSNorm(nn.Module):
5 """
6 手写实现 RMSNorm。
7
8 参数:
9 dims (int): 输入张量最后一个维度的尺寸。
10 eps (float): 防止除以零的小数。
11 """
12 def __init__(self, dims: int, eps: float = 1e-6):
13 super().__init__()
14 self.eps = eps
15 # gamma 参数,在 RMSNorm 中通常称为 weight
16 self.weight = nn.Parameter(torch.ones(dims))
17
18 def _norm(self, x):
19 # 核心公式: x / sqrt(mean(x^2) + eps)
20 # 为什么这样做:
21 # 1. x.pow(2): 计算 x 的平方
22 # 2. .mean(-1, keepdim=True): 沿着最后一个维度计算均值。keepdim=True 是为了保持维度,方便后续的广播操作。
23 # 3. torch.rsqrt: 计算平方根的倒数 (1/sqrt),比直接计算 sqrt 再做除法效率更高。
24 variance = x.pow(2).mean(-1, keepdim=True)
25 return x * torch.rsqrt(variance + self.eps)
26
27 def forward(self, x):
28 # 为什么这样做:
29 # 1. 将输入转换为 float32 进行计算,保证数值精度。
30 # 2. 调用 _norm 函数进行归一化。
31 # 3. 乘以可学习的 gamma (self.weight) 参数进行缩放。
32 # 4. 将输出类型转换回原始输入类型。
33 output = self._norm(x.float()).type_as(x)
34 return output * self.weight
35
36# --- 使用示例 ---
37
38# 设置参数
39batch_size = 4
40seq_len = 10
41model_dim = 128
42
43# 创建随机输入张量
44input_tensor = torch.randn(batch_size, seq_len, model_dim)
45
46# 1. 使用手写的 RMSNorm
47print("--- 手写 RMSNorm ---")
48my_rmsnorm = RMSNorm(dims=model_dim)
49output_my_rmsnorm = my_rmsnorm(input_tensor)
50print("输入形状:", input_tensor.shape)
51print("输出形状:", output_my_rmsnorm.shape)
52# 检查输出的 L2 范数(在乘以 weight 之前,应该接近 sqrt(model_dim))
53with torch.no_grad():
54 normalized_output = my_rmsnorm._norm(input_tensor)
55 l2_norm = torch.linalg.norm(normalized_output[0, 0, :])
56 print(f"一个样本归一化后的 L2 范数 (期望约 {model_dim**0.5:.2f}): {l2_norm.item():.2f}")
57print("可学习的 weight (gamma) 参数形状:", my_rmsnorm.weight.shape)
58
59
60# 2. 对比 PyTorch 的 LayerNorm
61print("\n--- PyTorch LayerNorm ---")
62# elementwise_affine=True 表示使用可学习的 gamma 和 beta
63# LayerNorm 默认会减去均值,为了模拟 RMSNorm 的行为,我们只关注其缩放部分
64pytorch_layernorm = nn.LayerNorm(model_dim, elementwise_affine=True)
65output_pytorch_ln = pytorch_layernorm(input_tensor)
66print("输出形状:", output_pytorch_ln.shape)
67print("可学习的 weight (gamma) 参数形状:", pytorch_layernorm.weight.shape)
68print("可学习的 bias (beta) 参数形状:", pytorch_layernorm.bias.shape)
69
70# 总结:RMSNorm 是 LayerNorm 的一个简化版本,它移除了均值中心化和 bias (beta) 参数,
71# 从而在保持性能的同时提高了计算效率。

工程实践

  • BatchNorm:
    • 场景: CNNs (计算机视觉任务如图像分类、分割)。当批次大小(Batch Size)较大且稳定时(例如 > 16),效果非常好。
    • 超参: 无需太多调整,epsilonmomentum 通常使用默认值。
    • 权衡: 训练和推理的行为不一致,需要 model.train()model.eval() 切换。小批量(< 4-8)会导致统计量噪声大,性能急剧下降。占用额外内存存储滑动平均统计量。
  • LayerNorm:
    • 场景: Transformers (NLP, ViT) 和 RNNs。由于其计算与批次大小无关,是这些序列模型的标配。
    • 超参: epsilon 使用默认值即可。
    • 权衡: 在某些 CNN 任务上可能不如 BN。在 Transformer 中,Pre-LN (LN -> Sublayer -> +) 比 Post-LN (Sublayer -> + -> LN) 训练更稳定,是目前主流。
  • GroupNorm:
    • 场景: 当 CNN 任务的批次大小受限时(如高分辨率图像的分割、检测),GN 是 BN 的优秀替代品。
    • 超参: num_groups 是关键超参,通常设为 32 或 16。它需要在通道数 C 和性能之间做权衡。
    • 权衡: 性能几乎与批次大小无关,但引入了 num_groups 这个需要调整的超参数。
  • RMSNorm:
    • 场景: 现代大型语言模型(LLMs)的首选,如 Llama 系列。因其计算效率高且效果与 LN 相当或更好。
    • 超参: epsilon 使用默认值。
    • 权衡: 相比 LN,计算速度快约 7%-15%。在大多数 Transformer 场景下,可以作为 LN 的直接替代品,以获取性能提升。
  • DeepNorm:
    • 场景: 专门用于训练非常深(> 48层,甚至上千层)的 Transformer 模型,当标准 Pre-LN 结构也无法稳定训练时使用。
    • 超参: alpha 的设置和参数的特殊初始化是关键,需要严格遵循论文指导。
    • 权衡: 这是一个专家级的工具,使用复杂。对于常规深度的模型(如 12-24 层),Pre-LN 通常足够,不需要使用 DeepNorm。

常见误区与边界情况

  • 误区1: BN 在推理时也使用当前批次的统计量。

    • 正解: 这是最常见的错误。BN 在推理时(model.eval() 模式)必须使用训练阶段学习到的全局滑动平均均值和方差(running_meanrunning_var),否则会导致结果不稳定且依赖于推理时的批次数据。忘记切换模式是经典的 bug 来源。
  • 误区2: 归一化层会限制模型的表达能力。

    • 正解: 单纯的标准化(减均值除方差)确实会限制。但所有主流的归一化层都引入了可学习的仿射变换参数 γ\gammaβ\beta。这使得网络可以自适应地缩放和偏移归一化后的输出,在最坏情况下,网络可以学习到一组参数来完全抵消归一化操作,从而恢复原始的激活值。因此,表达能力并未丢失,而是被置于网络的可控范围内。
  • 误区3: LayerNormRMSNorm 在任何场景下都优于 BatchNorm

    • 正解: 没有“银弹”。在经典的 CV 任务和中等以上批次大小的场景下,BatchNorm 及其变种(如 SyncBatchNorm)通常仍然是性能最好的选择。LN/RMSNorm 的优势在于其对批次大小不敏感,因此在 Transformer 和小批量场景下胜出。
  • 面试追问: 为什么 Pre-LN 比 Post-LN 更稳定?

    • 回答要点:
      1. Post-LN (原始 Transformer): output = LayerNorm(x + Sublayer(x))。残差连接的输出在进入下一个模块前才被归一化。在深层网络中,x 的数值会逐层累加,可能导致数值范围过大,训练不稳定。同时,梯度在反向传播时会流经 LayerNorm,导致梯度在深层网络中可能消失或爆炸(所谓的“梯度上坡”问题)。
      2. Pre-LN: output = x + Sublayer(LayerNorm(x))。输入在进入子层(Sublayer)前就被归一化。这保证了每个子层的输入分布都是稳定的,从而使得训练过程更加平滑。残差主干 (x) 不经过 LayerNorm,梯度可以直接流过,有效缓解了梯度消失问题。这使得 Pre-LN 能够支持更深的模型和更宽松的超参(如更高的学习率)。
  • 边界情况: 输入方差为零

    • 处理: 如果一个归一化集合 S 内的所有元素都相同,其方差为零。此时,epsilon (ϵ\epsilon) 参数就起到了关键作用,它能防止分母为零,保证数值计算的稳定性。这是所有归一化层都必须包含 epsilon 的原因。
相关题目