Transformer 为什么用 LayerNorm 而不用 BatchNorm?
好的,我们来深入剖析 Transformer 中 Layer Normalization 的选择。
核心概念
Batch Normalization (BN) 和 Layer Normalization (LN) 都是为了解决深度神经网络训练过程中的“内部协变量偏移”(Internal Covariate Shift)问题,旨在加速模型收敛并提高其泛化能力。
Batch Normalization (BN):对一个 mini-batch 内的所有样本,在同一个特征维度上进行归一化。它的统计量(均值和方差)是基于当前批次数据计算的,因此对批次大小(batch size)敏感。
Layer Normalization (LN):对单个样本的所有特征维度进行归一化。它的统计量是基于当前单个样本计算的,因此与批次大小无关,非常适合处理变长序列数据。
简单来说,如果数据是 [N, D] (N个样本, D个特征),BN 是在 N 这个维度上计算统计量,而 LN 是在 D 这个维度上计算。
原理与推导
为了清晰地理解两者的区别,我们假设输入数据为一个 mini-batch ,其形状为 ,其中 是批次大小, 是特征维度。 表示第 个样本的第 个特征。
1. Batch Normalization (BN)
BN 的核心思想是在每个特征维度上,将所有样本的激活值调整为均值为0,方差为1的分布。
- 计算均值和方差: BN 对每个特征 (即数据的每一列)计算均值 和方差 。
- 归一化: 对每个样本 的特征 进行归一化。
其中 是一个很小的常数,防止分母为零。
- 仿射变换: 为了保持模型的表达能力,BN 引入了可学习的缩放参数 和平移参数 。
注意, 和 的维度是 ,即每个特征维度有一对独立的参数。
- 推理阶段:在推理时,我们可能只有一个样本,无法计算批次统计量。因此,BN 在训练时会通过移动平均(moving average)估算全局的均值和方差,并在推理时使用这些全局统计量。
2. Layer Normalization (LN)
LN 的核心思想是,对于每一个独立的样本,将其所有特征的激活值调整为均值为0,方差为1的分布。
- 计算均值和方差: LN 对每个样本 (即数据的每一行)计算其所有特征的均值 和方差 。
- 归一化: 对样本 的每个特征 进行归一化。
- 仿射变换: 同样,LN 也引入可学习的缩放参数 和平移参数 。
与 BN 相同, 和 的维度也是 。
为什么 Transformer 选择 LayerNorm?
-
对变长序列的友好性 (核心原因):
- 在 NLP 任务中,一个批次内的句子长度通常是不同的。为了进行批处理,我们会用特殊符号(如
<pad>)将短句补齐到与最长句相同的长度。 - 如果使用 BN,它会在特征维度上(例如,第10个词的 embedding 的第5个维度)对一个批次内的所有样本进行归一化。这意味着,那些由
<pad>填充的无效位置会参与均值和方差的计算,给有效的词元带来噪声,污染统计数据。 - 而 LN 是在每个样本(即每个句子中的每个词元)的特征维度上独立进行归一化。它完全不受其他样本或 padding 的影响,每个词元的归一化只依赖于其自身的 embedding 向量。
- 在 NLP 任务中,一个批次内的句子长度通常是不同的。为了进行批处理,我们会用特殊符号(如
-
对小批量大小的鲁棒性:
- BN 的性能严重依赖于批次大小。当批次较小时,计算出的均值和方差抖动很大,无法准确代表全局数据分布,导致模型性能下降。
Transformer这样的大模型由于显存限制,常常只能使用很小的批次大小(例如 1 或 2)。在这种情况下,BN 的效果会非常差。- LN 的计算完全在单个样本内部完成,与批次大小无关。无论批次大小是 1 还是 128,其计算方式和结果都完全一致,表现稳定。
-
训练与推理的一致性:
- BN 在训练和推理时使用两套不同的逻辑:训练时用当前批次的统计量,推理时用全局的移动平均统计量。这种不一致有时会引入一些细微的性能差异或部署问题。
- LN 在训练和推理时执行完全相同的计算,因为它不需要预先计算或存储任何全局统计信息。这使得模型行为更加一致和可预测。
直观解释
想象一个班级(mini-batch)的学生(样本)参加多门考试(特征)。
BatchNorm:想知道“物理”这门课的平均分和分数标准差。它会拿出所有学生的物理成绩,计算均值和方差,然后对每个学生的物理成绩进行归一化。它关心的是单科在全体学生中的相对位置。LayerNorm:想知道“张三”这位同学的综合表现。它会拿出张三的所有科目成绩,计算他自己所有科目的平均分和方差,然后对他自己的所有成绩进行归一化。它关心的是单人所有科目成绩的分布情况。
对于 Transformer 处理的词元序列,LN 的做法更符合直觉:我们希望对单个词元(比如“apple”)的 embedding 向量(包含其语义、位置等所有信息)进行整体的数值缩放,而不是去比较“apple”的第5维 embedding 和“banana”的第5维 embedding。
代码实现
下面我们用 PyTorch 来展示 BN 和 LN 的计算差异。
1import torch2import torch.nn as nn34def main():5 # 设定随机种子以保证结果可复现6 torch.manual_seed(42)78 # 假设我们有一个 mini-batch9 # N=2: 批次大小为2 (两个句子)10 # C=3: 序列长度为3 (每个句子3个词)11 # D=4: embedding维度为412 # PyTorch的BN和LN通常作用在最后一个维度上,所以我们使用 (N, C, D) 的形状13 # 为了简化,我们先将 N 和 C 合并,模拟一个 (BatchSize, FeatureDim) 的输入14 batch_size = 2 * 315 feature_dim = 41617 # 创建输入数据,形状为 (6, 4)18 # 可以想象成一个批次里有6个词元,每个词元是4维向量19 x = torch.randn(batch_size, feature_dim)20 print("--- 输入数据 (X) ---")21 print(x)22 print("形状:", x.shape)23 print("\n" + "="*50 + "\n")2425 # --- BatchNorm1d ---26 # BN需要知道特征的数量27 bn = nn.BatchNorm1d(num_features=feature_dim)28 bn_output = bn(x)2930 # 手动验证BatchNorm的计算过程31 # 为什么这样做:为了证明BN是沿着批次维度(dim=0)计算统计量的32 mean_bn = x.mean(dim=0, keepdim=True) # 沿着dim=0计算均值,得到 (1, 4)33 var_bn = x.var(dim=0, unbiased=False, keepdim=True) # 沿着dim=0计算方差,得到 (1, 4)34 x_hat_bn = (x - mean_bn) / torch.sqrt(var_bn + bn.eps)35 # bn.weight 和 bn.bias 分别是可学习的 gamma 和 beta36 manual_bn_output = x_hat_bn * bn.weight + bn.bias3738 print("--- BatchNorm1d 输出 ---")39 print("PyTorch BatchNorm1d 输出:\n", bn_output)40 print("手动计算 BatchNorm1d 输出:\n", manual_bn_output)41 print("BN计算的均值 (沿着dim=0):\n", mean_bn)42 print("BN计算的方差 (沿着dim=0):\n", var_bn)43 # 验证误差是否足够小44 assert torch.allclose(bn_output, manual_bn_output, atol=1e-7)45 print("\nBatchNorm1d 验证成功!\n" + "="*50 + "\n")4647 # --- LayerNorm ---48 # LN需要知道需要归一化的特征形状49 ln = nn.LayerNorm(normalized_shape=feature_dim)50 ln_output = ln(x)5152 # 手动验证LayerNorm的计算过程53 # 为什么这样做:为了证明LN是沿着特征维度(dim=1)计算统计量的54 mean_ln = x.mean(dim=1, keepdim=True) # 沿着dim=1计算均值,得到 (6, 1)55 var_ln = x.var(dim=1, unbiased=False, keepdim=True) # 沿着dim=1计算方差,得到 (6, 1)56 x_hat_ln = (x - mean_ln) / torch.sqrt(var_ln + ln.eps)57 # ln.weight 和 ln.bias 分别是可学习的 gamma 和 beta58 manual_ln_output = x_hat_ln * ln.weight + ln.bias5960 print("--- LayerNorm 输出 ---")61 print("PyTorch LayerNorm 输出:\n", ln_output)62 print("手动计算 LayerNorm 输出:\n", manual_ln_output)63 print("LN计算的均值 (沿着dim=1):\n", mean_ln)64 print("LN计算的方差 (沿着dim=1):\n", var_ln)65 # 验证误差是否足够小66 assert torch.allclose(ln_output, manual_ln_output, atol=1e-7)67 print("\nLayerNorm 验证成功!")6869if __name__ == '__main__':70 main()
工程实践
-
使用场景:
LayerNorm:Transformer、RNN、LSTM 等序列模型是其主要阵地。在需要模型对批次大小不敏感的场景(如在线学习、强化学习)也很有用。BatchNorm:CNN(卷积神经网络)是其绝对主场。在图像任务中,一个批次内不同图片在同一空间位置的特征通常具有相似的统计分布,BN能很好地利用这一点。
-
超参数选择:
eps(epsilon):通常使用默认值1e-5即可。这是一个为了数值稳定性而加在分母上的小值。在混合精度训练(FP16)中,如果遇到梯度为NaN的问题,可以适当增大eps到1e-4或1e-3。elementwise_affine:布尔值,决定是否使用可学习的 和 。几乎总是设置为True。如果设为False,层将只进行归一化,不进行仿射变换,这会削弱模型的表达能力,相当于强行让每层的输出都符合标准正态分布。
-
性能/显存/吞吐考量:
- 在现代硬件(如GPU)和深度学习框架(如PyTorch+cuDNN)上,BN和LN的计算都非常高效。对于整个
Transformer模型来说,归一化层的计算开销远小于自注意力机制和前馈网络,通常不是性能瓶颈。 - 显存占用方面,两者都需要存储 和 参数,BN 额外需要存储两个非参数的 moving average buffer。总体而言,它们的显存占用都很小。
- 主要的考量并非计算性能,而是对模型训练动态和最终性能的影响。LN的稳定性、对小批次的友好性是其在
Transformer中胜出的关键。
- 在现代硬件(如GPU)和深度学习框架(如PyTorch+cuDNN)上,BN和LN的计算都非常高效。对于整个
-
调试技巧:
- Pre-LN vs. Post-LN:在
Transformer中,LN层的位置有两种主流设计:放在残差连接之前(Pre-LN)或之后(Post-LN)。原始论文使用的是Post-LN,但后续研究发现Pre-LN能使训练更加稳定,对学习率不那么敏感,且通常不需要warm-up。目前,Pre-LN是更为主流和推荐的做法。 - 梯度消失/爆炸:如果模型训练不稳定,可以检查LN层的输出。在没有仿射变换时,其均值应接近0,标准差应接近1。如果仿射变换后的值域变得极端,可能需要检查权重初始化或学习率。
- Pre-LN vs. Post-LN:在
常见误区与边界情况
-
误区:“LN 就是转置版的 BN”
- 这是一个非常不准确的类比。虽然从计算轴上看有“转置”的感觉,但它们背后的统计假设和工程影响完全不同。核心区别在于BN依赖批次,而LN不依赖。
-
误区:“LN 总是优于 BN”
- 错误。在CNN领域,BN仍然是标准配置且效果优异。图像数据中,跨样本的特征统计是有意义的,BN能有效利用这一信息。选择哪种归一化方法高度依赖于数据模态和模型架构。
-
边界情况:批次大小为 1
- BN:在训练模式下,如果批次大小为1,计算方差时会得到0,导致除零错误。PyTorch的
BatchNorm会直接抛出异常。 - LN:完全不受影响,因为它的计算与批次大小无关。这是它在小批量或在线学习场景下的巨大优势。
- BN:在训练模式下,如果批次大小为1,计算方差时会得到0,导致除零错误。PyTorch的
-
边界情况:特征维度为 1
- LN:如果一个样本只有一个特征维度,计算方差会得到0,同样导致除零错误。但在实践中,尤其是在
Transformer中,embedding维度(特征维度)通常远大于1(例如768),所以这不是一个实际问题。
- LN:如果一个样本只有一个特征维度,计算方差会得到0,同样导致除零错误。但在实践中,尤其是在
-
面试追问:
- 问:除了BN和LN,你还知道哪些归一化方法?它们和LN有什么关系?
- 答:还有 Instance Normalization (IN) 和 Group Normalization (GN)。
- InstanceNorm:常用于风格迁移等图像生成任务。可以看作是应用在图像上的LN,它对每个样本的每个通道(channel)独立进行归一化。如果把一个通道看作一个“层”,它就和LN很像了。
- GroupNorm:介于LN和IN之间。它将特征通道分成若干组(group),在每个组内进行归一化。当group数量为1时,GN等价于LN;当group数量等于通道数时,GN等价于IN。GN同样不依赖批次大小,是对BN在小批次场景下的一种改进。
- 总结:这些方法的根本区别在于选择哪些维度来共同计算统计量。LN选择了单个样本的所有特征,因为它假设这些特征共同构成了一个需要被归一化的“层”。
- 答:还有 Instance Normalization (IN) 和 Group Normalization (GN)。
- 问:除了BN和LN,你还知道哪些归一化方法?它们和LN有什么关系?