§1.3.32

Xavier/Kaiming/Truncated-Normal 初始化与激活的匹配?

核心概念

权重初始化(Weight Initialization)是深度学习模型训练前设置网络参数初始值的过程。一个好的初始化策略旨在解决梯度消失/爆炸问题,确保信号(前向传播中的激活值)和梯度(反向传播中的梯度值)在网络层间传递时,其方差能够保持在稳定范围内。Xavier、Kaiming 和 Truncated-Normal 都是具体的初始化方法,它们通过数学推导,将权重的初始方差与网络层的输入/输出维度(以及激活函数特性)关联起来,以维持信号传播的稳定性。

原理与推导

核心思想是分析并控制每一层输出的方差,使其与输入方差保持一致。

基本假设:

  1. 输入 xx 的每个分量 xkx_k 和权重 WW 的每个分量 WikW_{ik} 均独立同分布。
  2. 输入和权重均值为零,即 E[xk]=0E[x_k] = 0, E[Wik]=0E[W_{ik}] = 0
  3. 权重 WW 和输入 xx 相互独立。

考虑一个全连接层,其输出 yiy_i 的计算如下(暂不考虑偏置和激活函数): yi=k=1ninWikxky_i = \sum_{k=1}^{n_{in}} W_{ik} x_k 其中 ninn_{in} 是输入神经元的数量。

方差推导: 我们来计算 yiy_i 的方差 Var(yi)Var(y_i)Var(yi)=Var(k=1ninWikxk)Var(y_i) = Var(\sum_{k=1}^{n_{in}} W_{ik} x_k) 由于 WikW_{ik}xkx_k 相互独立,且各项之间也独立,方差和等于和的方差: Var(yi)=k=1ninVar(Wikxk)Var(y_i) = \sum_{k=1}^{n_{in}} Var(W_{ik} x_k) 利用方差性质 Var(AB)=E[A2B2](E[AB])2Var(AB) = E[A^2B^2] - (E[AB])^2。因为 A,BA, B 独立且均值为0,有 E[AB]=E[A]E[B]=0E[AB] = E[A]E[B] = 0E[A2B2]=E[A2]E[B2]E[A^2B^2] = E[A^2]E[B^2]。又因 Var(A)=E[A2](E[A])2=E[A2]Var(A) = E[A^2] - (E[A])^2 = E[A^2],所以 Var(AB)=Var(A)Var(B)+Var(A)(E[B])2+Var(B)(E[A])2=Var(A)Var(B)Var(AB) = Var(A)Var(B) + Var(A)(E[B])^2 + Var(B)(E[A])^2 = Var(A)Var(B)。 因此: Var(yi)=k=1ninVar(Wik)Var(xk)Var(y_i) = \sum_{k=1}^{n_{in}} Var(W_{ik}) Var(x_k) 假设所有输入 xkx_k 方差相同为 Var(x)Var(x),所有权重 WikW_{ik} 方差相同为 Var(W)Var(W),则: Var(y)=ninVar(W)Var(x)Var(y) = n_{in} \cdot Var(W) \cdot Var(x)

1. Xavier (Glorot) 初始化

  • 目标:维持前向传播时激活值的方差不变,即 Var(y)=Var(x)Var(y) = Var(x)。同时,在反向传播时,梯度的方差也不变。
  • 推导
    • 前向传播:为了使 Var(y)=Var(x)Var(y) = Var(x),我们需要 ninVar(W)=1n_{in} \cdot Var(W) = 1,即 Var(W)=1ninVar(W) = \frac{1}{n_{in}}
    • 反向传播:梯度的反向传播与前向传播类似,但输入输出维度角色互换。为保持梯度方差,需要 noutVar(W)=1n_{out} \cdot Var(W) = 1,即 Var(W)=1noutVar(W) = \frac{1}{n_{out}}
    • 折衷方案:Xavier 初始化采用了一个折衷,取两者的调和平均数: Var(W)=2nin+noutVar(W) = \frac{2}{n_{in} + n_{out}}
  • 激活函数匹配:Xavier 的推导基于一个重要的线性假设:激活函数在其原点附近是线性的,且导数为1。这使得 Tanh 和 Sigmoid 函数在输入接近0时表现良好。对于这些函数,输入输出方差近似保持不变。
  • 具体分布
    • 均匀分布 U[a,a]U[-a, a]:方差为 a23\frac{a^2}{3}。令 a23=2nin+nout\frac{a^2}{3} = \frac{2}{n_{in} + n_{out}},解得 a=6nin+nouta = \sqrt{\frac{6}{n_{in} + n_{out}}}
    • 正态分布 N(0,σ2)N(0, \sigma^2):方差为 σ2\sigma^2。令 σ2=2nin+nout\sigma^2 = \frac{2}{n_{in} + n_{out}},即 σ=2nin+nout\sigma = \sqrt{\frac{2}{n_{in} + n_{out}}}

2. Kaiming (He) 初始化

  • 问题:Xavier 初始化在 ReLU 激活函数上表现不佳。ReLU (f(x)=max(0,x)f(x) = \max(0, x)) 会将所有负值置为0,这破坏了 Xavier 的零均值和线性假设。
  • 推导
    • 假设输入 xx 是一个关于0对称的分布(例如,前一层输出已经过BN或来自一个合理的初始化)。经过一个线性层 z=Wxz = Wxzz 也是关于0对称的。
    • 应用 ReLU 后,y=ReLU(z)y = ReLU(z)yy 的一半值被置为0。这导致其方差大约是原始方差的一半。
    • 形式上,对于均值为0的 ziz_i,有 Var(yi)=E[yi2](E[yi])2Var(y_i) = E[y_i^2] - (E[y_i])^2。可以证明 E[yi2]=12E[zi2]=12Var(zi)E[y_i^2] = \frac{1}{2}E[z_i^2] = \frac{1}{2}Var(z_i)。由于 E[yi]E[y_i] 不为0,计算会更复杂,但一个关键的近似结果是: Var(y)12Var(z)=12ninVar(W)Var(x)Var(y) \approx \frac{1}{2} Var(z) = \frac{1}{2} n_{in} \cdot Var(W) \cdot Var(x)
    • 目标:为了维持方差不变,即 Var(y)=Var(x)Var(y) = Var(x),我们需要: 12ninVar(W)=1    Var(W)=2nin\frac{1}{2} n_{in} \cdot Var(W) = 1 \implies Var(W) = \frac{2}{n_{in}}
  • 激活函数匹配:Kaiming 初始化专为 ReLU 及其变体(如 Leaky ReLU, PReLU)设计。它补偿了 ReLU 导致的一半信息丢失。
  • 具体分布
    • 正态分布 N(0,σ2)N(0, \sigma^2)σ2=2nin\sigma^2 = \frac{2}{n_{in}},即 σ=2nin\sigma = \sqrt{\frac{2}{n_{in}}}
    • 均匀分布 U[a,a]U[-a, a]a23=2nin\frac{a^2}{3} = \frac{2}{n_{in}},解得 a=6nina = \sqrt{\frac{6}{n_{in}}}

3. Truncated-Normal (截断正态) 初始化

  • 核心概念:这不是一种像 Xavier 或 Kaiming 那样的方差缩放规则,而是一种采样策略。它从一个正态分布 N(μ,σ2)N(\mu, \sigma^2) 中采样,但如果采样值落在某个区间(通常是 [μ2σ,μ+2σ][\mu-2\sigma, \mu+2\sigma])之外,就会被丢弃并重新采样,直到采样值落在该区间内。
  • 动机:标准的正态分布理论上可以产生绝对值非常大的数。在初始化阶段,一个极大的权重值可能导致神经元立即饱和(对于 Sigmoid/Tanh)或产生巨大的激活值,从而引起梯度爆炸和训练不稳定。截断正态通过移除这些“离群值”来提高初始化的稳定性。
  • 与 Xavier/Kaiming 的关系:Truncated-Normal 通常与 Xavier 或 Kaiming 的方差计算结合使用。例如,可以先用 Kaiming 规则计算出标准差 σ=2nin\sigma = \sqrt{\frac{2}{n_{in}}},然后从截断正态分布 TruncatedNormal(0,σ2)TruncatedNormal(0, \sigma^2) 中采样权重。

代码实现

下面是一个 PyTorch 示例,演示如何根据激活函数选择并应用不同的初始化方法。

python
1import torch
2import torch.nn as nn
3import math
4
5class MyModel(nn.Module):
6 def __init__(self):
7 super().__init__()
8 # 场景1: Tanh 激活,使用 Xavier 初始化
9 self.layer1 = nn.Linear(128, 256)
10 self.act1 = nn.Tanh()
11
12 # 场景2: ReLU 激活,使用 Kaiming 初始化
13 self.layer2 = nn.Linear(256, 512)
14 self.act2 = nn.ReLU()
15
16 # 场景3: 自定义截断正态初始化 (与 Kaiming 结合)
17 self.layer3 = nn.Linear(512, 10)
18
19 self.initialize_weights()
20
21 def initialize_weights(self):
22 # 遍历所有模块
23 for m in self.modules():
24 if isinstance(m, nn.Linear):
25 # 根据激活函数选择初始化
26 # PyTorch 的 Kaiming 和 Xavier 初始化函数已经考虑了激活函数的特性
27 # 这里为了教学目的,我们手动判断并调用
28
29 if m == self.layer1:
30 print("Initializing layer1 (for Tanh) with Xavier Uniform.")
31 # 为什么这样做: Xavier 假设激活函数在0点附近是线性的,Tanh符合此特性。
32 # 'gain' 参数可以调整方差,对于 Tanh,PyTorch 推荐的 gain 是 5/3,但默认的 1.0 也很常用。
33 nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('tanh'))
34 if m.bias is not None:
35 nn.init.constant_(m.bias, 0)
36
37 elif m == self.layer2:
38 print("Initializing layer2 (for ReLU) with Kaiming Normal.")
39 # 为什么这样做: Kaiming 专为 ReLU 及其变体设计,补偿了 ReLU 导致的一半方差损失。
40 # a=0 表示 ReLU, mode='fan_in' 表示方差基于输入维度计算 (Var=2/n_in)
41 nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
42 if m.bias is not None:
43 nn.init.constant_(m.bias, 0)
44
45 elif m == self.layer3:
46 print("Initializing layer3 with Truncated Normal (Kaiming variance).")
47 # 为什么这样做: 结合 Kaiming 的理论优势和 Truncated Normal 的稳定性优势,
48 # 防止初始权重过大。
49 self.truncated_normal_init(m.weight, mean=0.0, std_dev_scale=2.0/m.in_features, truncation_limit=2.0)
50 if m.bias is not None:
51 nn.init.constant_(m.bias, 0)
52
53 def truncated_normal_init(self, tensor, mean=0.0, std_dev_scale=1.0, truncation_limit=2.0):
54 """
55 手动实现截断正态分布初始化
56 """
57 # 为什么这样做: PyTorch 核心库没有内置的截断正态初始化,需要手动实现。
58 # 这是 TensorFlow 中常见的默认初始化方式。
59
60 # 1. 计算标准差
61 # std_dev_scale 对应 Kaiming/Xavier 计算出的方差 Var(W),所以标准差是 sqrt(Var(W))
62 std = math.sqrt(std_dev_scale)
63
64 # 2. 创建一个截断分布对象
65 # a, b 是截断区间的下界和上界,相对于 (x-mean)/std
66 lower_bound = -truncation_limit
67 upper_bound = truncation_limit
68 # scipy.stats.truncnorm 是一个更方便的工具,但这里为了展示原理使用 torch
69
70 # 3. 用正态分布填充,然后对超出范围的值进行重采样
71 with torch.no_grad():
72 # 使用 torch.fmod 来循环地将值限制在范围内,这是一种高效的近似
73 # 更严格的实现是循环检查和重采样,但效率较低
74 # tensor.normal_() 会用 N(0,1) 填充,然后我们乘以标准差并加上均值
75 tensor.normal_(mean=0, std=1.0)
76
77 # 计算截断后的标准差,以保持整体方差不变
78 # 这是一个修正因子,确保截断后的分布方差仍接近期望值
79 import scipy.stats as stats
80 l, u = lower_bound, upper_bound
81 correction = stats.truncnorm.std(a=l, b=u, loc=0, scale=1)
82
83 tensor.mul_(std / correction)
84
85 # 严格的重采样方法(教学用,效率低)
86 # while True:
87 # out_of_bounds = torch.abs(tensor) > truncation_limit * std
88 # if not torch.any(out_of_bounds):
89 # break
90 # num_resample = torch.sum(out_of_bounds)
91 # resamples = torch.normal(mean, std, size=(num_resample,), device=tensor.device)
92 # tensor[out_of_bounds] = resamples
93
94 def forward(self, x):
95 x = self.act1(self.layer1(x))
96 x = self.act2(self.layer2(x))
97 x = self.layer3(x)
98 return x
99
100# 运行示例
101model = MyModel()
102dummy_input = torch.randn(10, 128)
103output = model(dummy_input)
104print("\nModel initialized. Output shape:", output.shape)

工程实践

  • 默认选择:在现代深度学习框架(如 PyTorch、TensorFlow)中,nn.Linearnn.Conv2d 的默认初始化通常就是 Kaiming 初始化。这是因为 ReLU 是迄今为止最常用的激活函数。因此,大多数情况下,你甚至不需要手动初始化。
  • 超参数 gain:PyTorch 的初始化函数中有一个 gain 参数。它是一个乘法因子,用于调整计算出的标准差。nn.init.calculate_gain(nonlinearity, param=None) 函数可以帮你计算常用激活函数的推荐 gain 值。例如,对于 Leaky ReLU,其负斜率会影响方差,需要通过 gain 来补偿。
  • 与归一化层的关系:批量归一化(Batch Normalization)和层归一化(Layer Normalization)会重新缩放和中心化激活值,这在很大程度上缓解了不当初始化带来的问题。然而,一个好的初始化仍然至关重要,因为:
    1. 它能加速训练初期的收敛。
    2. 它能帮助归一化层更快地学习到合适的缩放和平移参数。
    3. 在没有使用归一化层的模型中,初始化是决定成败的关键。
  • 性能权衡
    • Xavier/Kaiming:计算开销极小,在模型定义时一次性完成。对训练吞吐量没有影响。
    • Truncated Normal:严格的重采样实现可能比标准采样慢一点,但由于只在训练开始前执行一次,对总体性能的影响可以忽略不计。
  • 调试技巧:如果怀疑初始化有问题(例如,模型不收敛,梯度为 NaN),一个有效的调试方法是:可视化激活值和梯度的统计数据。逐层打印激活值和梯度的均值和标准差。如果标准差在前向传播中急剧缩小(梯度消失)或增大(梯度爆炸),或者在反向传播中梯度的范数变化剧烈,那么很可能是初始化策略不当。

常见误区与边界情况

  • 误区1:初始化方法和分布类型混淆 Xavier 和 Kaiming 是方差缩放规则,它们告诉你权重的方差应该是多少。而 Uniform、Normal、Truncated-Normal 是概率分布,它们描述了如何从具有该方差的分布中采样。你可以有 Kaiming Normal 初始化,也可以有 Kaiming Uniform 初始化。

  • 误区2:对所有层使用相同的初始化 错误地将为 Tanh 设计的 Xavier 初始化用于 ReLU 层,可能会导致训练速度变慢。反之亦然。必须根据层后面的激活函数来选择匹配的初始化方法。

  • 误区3:认为有了 Batch Norm 就不用管初始化了 如上所述,BN 虽强,但好的初始化依然是“锦上添花”,能提供更稳定、更快速的收敛起点。在非常深或者结构复杂的网络中,这种“助推”作用可能非常关键。

  • 边界情况:偏置(Bias)的初始化

    • 通常,偏置被初始化为0。
    • 一个特例是,在 ReLU 网络中,有时会将偏置初始化为一个小的正数(如0.01或0.1)。这被称为 "ReLU bias trick"。其动机是确保所有 ReLU 单元在训练开始时都能获得正输入,从而激活并参与学习,避免所谓的“死亡 ReLU”问题。然而,在实践中,配合 Kaiming 初始化和现代优化器(如 Adam),将偏置初始化为0通常已经足够好。
  • 常见面试追问

    • :“为什么 Kaiming 初始化的方差是 2/nin2/n_{in} 而不是 1/nin1/n_{in}?”
      • :核心在于 ReLU 激活函数会使大约一半的输入变为0,这导致输出的方差减半。为了补偿这个效应,我们需要将权重的初始方差加倍,即从 Xavier 的 1/nin1/n_{in} 变为 2/nin2/n_{in},以确保通过 ReLU 后的激活值方差能恢复到和输入一致的水平。
    • :“这些初始化方法对 Transformer 模型适用吗?”
      • :适用,但有细微差别。Transformer 中的全连接层(FFN部分)通常使用 ReLU 或其变体(如 GeLU),因此 Kaiming 初始化是合适的。对于 Attention 中的 Q, K, V 投影矩阵,它们后面没有直接的非线性激活(Softmax 是在 Attention 分数上做的),但实践中 Kaiming 或 Xavier 仍然是常见的、稳健的选择。原始 Transformer 论文使用的是 Xavier 初始化。
    • :“如果我用 Leaky ReLU,Kaiming 初始化公式需要调整吗?”
      • :是的。Leaky ReLU 的负半轴斜率为 aa(一个小的正数)。这会改变方差的传播。修正后的方差传播变为 Var(y)1+a22ninVar(W)Var(x)Var(y) \approx \frac{1+a^2}{2} n_{in} Var(W) Var(x)。因此,为了保持方差,需要 Var(W)=2(1+a2)ninVar(W) = \frac{2}{(1+a^2)n_{in}}。PyTorch 的 kaiming_normal_ 函数通过 nonlinearity='leaky_relu'a=... 参数自动处理了这种情况。
相关题目