BatchNorm/LayerNorm/RMSNorm/GroupNorm/DeepNorm 的公式与差异?
- —手写 RMSNorm
核心概念
归一化层(Normalization Layers)是深度学习模型中的一类关键组件,其核心思想是通过对层输入或激活值进行重新中心化(re-centering)和重新缩放(re-scaling)来规范化其数据分布。这有助于解决深度网络训练过程中的“内部协变量偏移”(Internal Covariate Shift)问题,从而稳定训练过程、加快收敛速度、支持使用更高的学习率,并降低模型对初始化参数的敏感度。不同的归一化方法,如 BatchNorm、LayerNorm 等,其主要区别在于计算均值和方差的维度集合不同。
原理与推导
所有主流的归一化技术都可以被抽象为一个两步过程:
-
标准化 (Standardization): 将输入
x的某个集合S内的元素,调整为均值为 0,方差为 1 的分布。 其中 是集合S的均值, 是集合S的方差, 是一个很小的正数(如1e-5)以防止除以零。 -
仿射变换 (Affine Transformation): 引入两个可学习的参数,缩放因子 和平移因子 ,来恢复网络的表达能力。 如果网络认为不需要归一化,它可以学习到 和 ,从而将输入恒等映射回来。
下面我们详细分析各种归一化方法在计算均值和方差的集合 S 上的差异。假设我们有一个 mini-batch 的输入张量,其形状对于 CNN 是 (N, C, H, W),对于 Transformer 是 (N, L, D),其中 N 是批大小,C 是通道数,H, W 是高和宽,L 是序列长度,D 是特征维度。
1. Batch Normalization (BN)
- 归一化维度: 对每个特征通道(channel),在整个批次(batch)内进行归一化。
- 计算集合
S: 对于(N, C, H, W)的输入,BN 为C个通道中的每一个通道,独立计算其均值和方差。计算S的范围是(N, H, W)维度。这意味着均值和方差与样本有关,但与特征图的空间位置无关。 - 数学公式:
- 注意: 和 的维度是
(C,),即每个通道有一对独立的参数。 - 几何解释: 将不同样本在同一特征通道上的值拉到相似的分布范围内,减少了样本间的差异性对模型的影响。
- 复杂度: 训练时需要计算批次统计量,推理时使用训练期间累积的滑动平均均值和方差。时间/空间复杂度与输入大小成正比。
2. Layer Normalization (LN)
- 归一化维度: 对每个样本,在所有特征通道(或嵌入维度)上进行归一化。
- 计算集合
S: 对于(N, L, D)的输入(如Transformer),LN 为N个样本中的每一个样本,独立计算其均值和方差。计算S的范围是(L, D)维度(或仅D维度,取决于具体实现)。它与批次大小完全无关。 - 数学公式 (以
Transformer中常见形式为例,对最后一个维度D归一化): - 注意: 和 的维度是
(D,),即每个特征维度有一对独立的参数。 - 几何解释: 将单个样本的所有特征拉到同一个固定的分布上,使得模型对特征的整体缩放不敏感。
- 复杂度: 复杂度与输入大小成正比,但计算完全在单个样本内完成,与批次大小无关。
3. Group Normalization (GN)
- 归一化维度: 对每个样本,将通道分成若干组(Group),在每个组内进行归一化。
- 计算集合
S: GN 是 BN 和 LN 的折衷。对于(N, C, H, W)的输入,它首先将C个通道分成G组,每组有C/G个通道。然后对每个样本的每个组内进行归一化。计算S的范围是(C/G, H, W)维度。 - 数学公式: 假设通道
c属于第g组, - 注意: 当
G=1时,GN 约等于 LN;当G=C时,GN 变为 Instance Normalization。 - 几何解释: 假设在通道维度上存在局部的相关性,因此将相关的通道分为一组共同归一化是合理的。
- 复杂度: 与 LN 类似,与批次大小无关。
4. RMS Normalization (RMSNorm)
- 归一化维度: 与 LN 类似,但移除了重新中心化(re-centering)步骤,即不减去均值。
- 核心思想: 论文作者认为,对于稳定训练而言,重新缩放(re-scaling)比重新中心化更重要。移除均值计算可以简化计算,提升速度。
- 数学公式:
- 注意:
RMSNorm通常不使用平移参数 。 - 几何解释: 它只调整向量的幅度(L2 范数),而不改变其方向。它将输入向量投影到单位球上,然后通过 进行缩放。
- 复杂度: 比 LN 更快,因为它省去了均值计算。
5. Deep Normalization (DeepNorm)
- 核心思想: DeepNorm 不是一个独立的归一化层,而是一套用于训练极深
Transformer模型的方法论。它通过理论推导,将模型更新的幅度与模型深度解耦,从而实现稳定训练。 - 关键组成:
- 特定的残差连接: 对于一个子层(如 Attention 或 FFN)函数 ,传统的 Post-LN 是
LayerNorm(x + f(x))。DeepNorm 修改为LayerNorm(α * x + f(x))。 - 特定的参数初始化: 子层 的参数需要被初始化,使其输出的方差很小。
- 特定的残差连接: 对于一个子层(如 Attention 或 FFN)函数 ,传统的 Post-LN 是
- 数学公式:
- 参数 : 这是一个与模型深度
N相关的常数,用于平衡残差连接x和子层输出f(x)的贡献。对于 Encoder,通常设为 ;对于 Decoder,设为 。 - 动机:
- 标准的 Post-LN 结构在深层网络中,梯度会逐层累积,导致梯度爆炸。
- 标准的 Pre-LN 结构 (
x + f(LayerNorm(x))) 虽然梯度稳定,但主干分支的输出会逐层累积,导致数值爆炸。 - DeepNorm 通过理论分析,给出了一个可以使模型更新幅度保持有界的
α值和初始化方案,从而能够稳定训练上千层的Transformer。
代码实现
这里提供一个手动实现的 RMSNorm,并与 PyTorch 内置的 LayerNorm 进行对比。
1import torch2import torch.nn as nn34class RMSNorm(nn.Module):5 """6 手写实现 RMSNorm。78 参数:9 dims (int): 输入张量最后一个维度的尺寸。10 eps (float): 防止除以零的小数。11 """12 def __init__(self, dims: int, eps: float = 1e-6):13 super().__init__()14 self.eps = eps15 # gamma 参数,在 RMSNorm 中通常称为 weight16 self.weight = nn.Parameter(torch.ones(dims))1718 def _norm(self, x):19 # 核心公式: x / sqrt(mean(x^2) + eps)20 # 为什么这样做:21 # 1. x.pow(2): 计算 x 的平方22 # 2. .mean(-1, keepdim=True): 沿着最后一个维度计算均值。keepdim=True 是为了保持维度,方便后续的广播操作。23 # 3. torch.rsqrt: 计算平方根的倒数 (1/sqrt),比直接计算 sqrt 再做除法效率更高。24 variance = x.pow(2).mean(-1, keepdim=True)25 return x * torch.rsqrt(variance + self.eps)2627 def forward(self, x):28 # 为什么这样做:29 # 1. 将输入转换为 float32 进行计算,保证数值精度。30 # 2. 调用 _norm 函数进行归一化。31 # 3. 乘以可学习的 gamma (self.weight) 参数进行缩放。32 # 4. 将输出类型转换回原始输入类型。33 output = self._norm(x.float()).type_as(x)34 return output * self.weight3536# --- 使用示例 ---3738# 设置参数39batch_size = 440seq_len = 1041model_dim = 1284243# 创建随机输入张量44input_tensor = torch.randn(batch_size, seq_len, model_dim)4546# 1. 使用手写的 RMSNorm47print("--- 手写 RMSNorm ---")48my_rmsnorm = RMSNorm(dims=model_dim)49output_my_rmsnorm = my_rmsnorm(input_tensor)50print("输入形状:", input_tensor.shape)51print("输出形状:", output_my_rmsnorm.shape)52# 检查输出的 L2 范数(在乘以 weight 之前,应该接近 sqrt(model_dim))53with torch.no_grad():54 normalized_output = my_rmsnorm._norm(input_tensor)55 l2_norm = torch.linalg.norm(normalized_output[0, 0, :])56 print(f"一个样本归一化后的 L2 范数 (期望约 {model_dim**0.5:.2f}): {l2_norm.item():.2f}")57print("可学习的 weight (gamma) 参数形状:", my_rmsnorm.weight.shape)585960# 2. 对比 PyTorch 的 LayerNorm61print("\n--- PyTorch LayerNorm ---")62# elementwise_affine=True 表示使用可学习的 gamma 和 beta63# LayerNorm 默认会减去均值,为了模拟 RMSNorm 的行为,我们只关注其缩放部分64pytorch_layernorm = nn.LayerNorm(model_dim, elementwise_affine=True)65output_pytorch_ln = pytorch_layernorm(input_tensor)66print("输出形状:", output_pytorch_ln.shape)67print("可学习的 weight (gamma) 参数形状:", pytorch_layernorm.weight.shape)68print("可学习的 bias (beta) 参数形状:", pytorch_layernorm.bias.shape)6970# 总结:RMSNorm 是 LayerNorm 的一个简化版本,它移除了均值中心化和 bias (beta) 参数,71# 从而在保持性能的同时提高了计算效率。
工程实践
BatchNorm:- 场景: CNNs (计算机视觉任务如图像分类、分割)。当批次大小(Batch Size)较大且稳定时(例如 > 16),效果非常好。
- 超参: 无需太多调整,
epsilon和momentum通常使用默认值。 - 权衡: 训练和推理的行为不一致,需要
model.train()和model.eval()切换。小批量(< 4-8)会导致统计量噪声大,性能急剧下降。占用额外内存存储滑动平均统计量。
LayerNorm:- 场景: Transformers (NLP,
ViT) 和 RNNs。由于其计算与批次大小无关,是这些序列模型的标配。 - 超参:
epsilon使用默认值即可。 - 权衡: 在某些 CNN 任务上可能不如 BN。在
Transformer中,Pre-LN (LN -> Sublayer -> +) 比 Post-LN (Sublayer -> + -> LN) 训练更稳定,是目前主流。
- 场景: Transformers (NLP,
- GroupNorm:
- 场景: 当 CNN 任务的批次大小受限时(如高分辨率图像的分割、检测),GN 是 BN 的优秀替代品。
- 超参:
num_groups是关键超参,通常设为 32 或 16。它需要在通道数C和性能之间做权衡。 - 权衡: 性能几乎与批次大小无关,但引入了
num_groups这个需要调整的超参数。
RMSNorm:- 场景: 现代大型语言模型(LLMs)的首选,如 Llama 系列。因其计算效率高且效果与 LN 相当或更好。
- 超参:
epsilon使用默认值。 - 权衡: 相比 LN,计算速度快约 7%-15%。在大多数
Transformer场景下,可以作为 LN 的直接替代品,以获取性能提升。
- DeepNorm:
- 场景: 专门用于训练非常深(> 48层,甚至上千层)的
Transformer模型,当标准 Pre-LN 结构也无法稳定训练时使用。 - 超参:
alpha的设置和参数的特殊初始化是关键,需要严格遵循论文指导。 - 权衡: 这是一个专家级的工具,使用复杂。对于常规深度的模型(如 12-24 层),Pre-LN 通常足够,不需要使用 DeepNorm。
- 场景: 专门用于训练非常深(> 48层,甚至上千层)的
常见误区与边界情况
-
误区1: BN 在推理时也使用当前批次的统计量。
- 正解: 这是最常见的错误。BN 在推理时(
model.eval()模式)必须使用训练阶段学习到的全局滑动平均均值和方差(running_mean和running_var),否则会导致结果不稳定且依赖于推理时的批次数据。忘记切换模式是经典的 bug 来源。
- 正解: 这是最常见的错误。BN 在推理时(
-
误区2: 归一化层会限制模型的表达能力。
- 正解: 单纯的标准化(减均值除方差)确实会限制。但所有主流的归一化层都引入了可学习的仿射变换参数 和 。这使得网络可以自适应地缩放和偏移归一化后的输出,在最坏情况下,网络可以学习到一组参数来完全抵消归一化操作,从而恢复原始的激活值。因此,表达能力并未丢失,而是被置于网络的可控范围内。
-
误区3:
LayerNorm和RMSNorm在任何场景下都优于BatchNorm。- 正解: 没有“银弹”。在经典的 CV 任务和中等以上批次大小的场景下,
BatchNorm及其变种(如 SyncBatchNorm)通常仍然是性能最好的选择。LN/RMSNorm的优势在于其对批次大小不敏感,因此在Transformer和小批量场景下胜出。
- 正解: 没有“银弹”。在经典的 CV 任务和中等以上批次大小的场景下,
-
面试追问: 为什么 Pre-LN 比 Post-LN 更稳定?
- 回答要点:
- Post-LN (原始
Transformer):output = LayerNorm(x + Sublayer(x))。残差连接的输出在进入下一个模块前才被归一化。在深层网络中,x的数值会逐层累加,可能导致数值范围过大,训练不稳定。同时,梯度在反向传播时会流经LayerNorm,导致梯度在深层网络中可能消失或爆炸(所谓的“梯度上坡”问题)。 - Pre-LN:
output = x + Sublayer(LayerNorm(x))。输入在进入子层(Sublayer)前就被归一化。这保证了每个子层的输入分布都是稳定的,从而使得训练过程更加平滑。残差主干 (x) 不经过LayerNorm,梯度可以直接流过,有效缓解了梯度消失问题。这使得 Pre-LN 能够支持更深的模型和更宽松的超参(如更高的学习率)。
- Post-LN (原始
- 回答要点:
-
边界情况: 输入方差为零
- 处理: 如果一个归一化集合
S内的所有元素都相同,其方差为零。此时,epsilon() 参数就起到了关键作用,它能防止分母为零,保证数值计算的稳定性。这是所有归一化层都必须包含epsilon的原因。
- 处理: 如果一个归一化集合