KL、JS、交叉熵、Wasserstein 的定义与差异?
- —用 torch 实现 KL(P||Q) 并验证非对称性
核心概念
这四个概念都是用来衡量两个概率分布之间差异的数学工具,但在定义、性质和应用上各有侧重。
- 交叉熵 (Cross-Entropy):主要源于信息论,衡量当使用基于分布 Q 的编码方式去编码来自真实分布 P 的样本时,所需要的平均比特数。在机器学习中,它常作为分类问题的损失函数,衡量模型预测分布 Q 与真实标签分布 P 之间的差异。
- KL 散度 (Kullback-Leibler Divergence):也称为相对熵,衡量用分布 Q 来近似分布 P 时所造成的信息损失。它等于交叉熵 H(P, Q) 减去 P 的信息熵 H(P)。KL 散度是非对称的,不满足距离度量的要求。
- JS 散度 (Jensen-Shannon Divergence):是 KL 散度的一个对称、平滑的版本。它通过引入一个中间分布来计算 P 和 Q 分别与这个中间分布的 KL 散度,从而解决了 KL 散度的非对称性问题。JS 散度有界,且当 P=Q 时为 0。
- Wasserstein 距离 (Wasserstein Distance):也叫“推土机距离”(Earth Mover's Distance),它衡量将一个分布 P “搬运”成另一个分布 Q 所需的最小“成本”。与前三者不同,即使两个分布的支撑集(support)完全不重叠,Wasserstein 距离仍然能提供一个有意义的、平滑的度量,这在生成模型(如 WGAN)中至关重要。
原理与推导
假设我们有两个离散概率分布 和 。
1. 交叉熵 (Cross-Entropy)
交叉熵的定义源于信息论,其数学公式为:
- 推导与解释:
- 根据香农信息论,一个事件 的信息量定义为 。概率越小,信息量越大。
- 分布 P 的信息熵(自信息期望)为 ,表示编码 P 中事件所需的平均最优比特数。
- 当我们用一个错误的分布 Q 来设计编码时(即认为事件 的概率是 ),编码一个来自真实分布 P 的事件 所需的比特数就是 。
- 因此,编码来自真实分布 P 的所有事件,平均所需的比特数就是这些事件的期望,即交叉熵 。
2. KL 散度 (Kullback-Leibler Divergence)
KL 散度衡量用 Q 近似 P 时的信息损失。
- 推导与解释:
- KL 散度可以由交叉熵和信息熵导出:
- 信息论解释: 表示使用基于 Q 的编码方案来编码来自 P 的样本,相比于使用基于 P 的最优编码方案,平均每个样本需要多付出的比特数。
- 非负性:根据吉布斯不等式 (Gibbs' inequality),,当且仅当 时等号成立。
- 非对称性:。
- 关注的是当 时,我们不希望 太小。如果 而 ,则 。这在机器学习中意味着模型(Q)对于一个真实(P)存在的样本给出了极低的概率,应该给予巨大的惩罚。
- 关注的是当 时,我们不希望 太小。如果 而 ,则 。
3. JS 散度 (Jensen-Shannon Divergence)
JS 散度是 KL 散度的对称化版本。
其中 是 P 和 Q 的平均分布。
- 推导与解释:
- 通过引入一个中点分布 M,JS 散度分别衡量 P 和 Q 到这个中点的“距离”,然后取平均。
- 对称性:从公式可以看出 。
- 有界性:JS 散度的值域是 (使用以 2 为底的对数时),或 (使用自然对数时)。这使得它比 KL 散度更稳定。
- 主要缺陷:当两个分布 P 和 Q 的支撑集完全不重叠时,KL 散度可能无定义,而 JS 散度会是一个常数()。这意味着梯度为 0,导致基于梯度的优化方法失效。
4. Wasserstein 距离
Wasserstein 距离衡量将分布 P 变换为 Q 的最小成本。
- 推导与解释:
- 几何解释 (推土机距离):想象 P 和 Q 是两堆形状不同的土堆。 就是将土堆 P 搬运并堆成土堆 Q 形状所需要付出的最小“功”。这里的“功”=“土的质量”ד搬运距离”。
- 是 P 和 Q 所有可能的联合分布 的集合,其边缘分布分别为 P 和 Q。 可以理解为一个搬运方案,表示从位置 搬运多少“土”到位置 。
- 就是在所有可能的搬运方案 中,寻找一个使得总搬运成本 最小的方案。
- Kantorovich-Rubinstein Duality:直接计算上式非常困难。在实践中(如 WGAN),通常使用其对偶形式来计算 1-阶 Wasserstein 距离:
代码实现
以下代码使用 PyTorch 实现 KL 散度,并验证其非对称性。
1import torch2import torch.nn.functional as F34def calculate_kl_divergence(p, q, eps=1e-9):5 """6 直接根据公式计算两个离散概率分布的KL散度 D_KL(P || Q)。78 Args:9 p (torch.Tensor): 第一个概率分布 (P)。10 q (torch.Tensor): 第二个概率分布 (Q)。11 eps (float): 为防止log(0)出现,加入的极小值。1213 Returns:14 torch.Tensor: KL散度的标量值。15 """16 # 为什么这样做:确保输入是合法的概率分布(和为1)17 # 在实践中,softmax的输出已经满足此条件,这里为了严谨性进行检查18 assert torch.isclose(p.sum(), torch.tensor(1.0)), "p必须是概率分布"19 assert torch.isclose(q.sum(), torch.tensor(1.0)), "q必须是概率分布"2021 # 为什么这样做:KL散度的公式是 sum(p * log(p/q))22 # p * (p / q).log() 等价于 p * (p.log() - q.log())23 # 直接计算 p/q 可以更直观地体现公式,但需要处理 q=0 的情况24 # F.kl_div 使用的是 D_KL(p || q) = sum(p * (log_p - log_q)) 的形式,但其输入要求是 log-probabilities25 # 这里我们直接用公式实现,更具教学意义2627 # 为什么这样做:在q中加入eps是为了数值稳定性,防止q中某些元素为0导致log(0) -> -inf28 # 这在工程实践中非常重要29 log_ratio = (p / (q + eps)).log()3031 # 为什么这样做:根据公式 D_KL(P || Q) = Σ p_i * log(p_i / q_i)32 # 只有当 p_i > 0 时,该项才对结果有贡献。如果 p_i = 0,则 p_i * log(...) = 033 # 因此,我们只对 p > 0 的项进行求和,这可以避免 0 * log(0/q) 这种 NaN 情况34 kl_div = (p * log_ratio).sum()3536 return kl_div3738# --- code_drills: 用 torch 实现 KL(P||Q) 并验证非对称性 ---3940# 1. 定义两个不同的概率分布 P 和 Q41# P: 一个偏向左侧的分布42p_dist = torch.tensor([0.7, 0.2, 0.1])43# Q: 一个偏向右侧的分布44q_dist = torch.tensor([0.1, 0.2, 0.7])4546print(f"P 分布: {p_dist}")47print(f"Q 分布: {q_dist}\n")4849# 2. 计算 D_KL(P || Q)50# 衡量用 Q 来近似 P 的信息损失51kl_pq = calculate_kl_divergence(p_dist, q_dist)52print(f"D_KL(P || Q) = {kl_pq:.4f}")53# 解释:P的重心在0.7,而Q只给了0.1的概率,因此用Q来描述P的第一个事件时,"惊讶程度"很高,导致KL散度较大。5455# 3. 计算 D_KL(Q || P)56# 衡量用 P 来近似 Q 的信息损失57kl_qp = calculate_kl_divergence(q_dist, p_dist)58print(f"D_KL(Q || P) = {kl_qp:.4f}")59# 解释:Q的重心在0.7,而P只给了0.1的概率,同样地,用P来描述Q的第三个事件时,"惊讶程度"也很高。6061# 4. 验证非对称性62print(f"\nKL散度是否对称? {'是' if torch.isclose(kl_pq, kl_qp) else '否'}")63assert not torch.isclose(kl_pq, kl_qp), "KL散度应该是非对称的"6465# 5. 使用 PyTorch 内置函数验证66# 注意:F.kl_div 的输入是 (log_q, p),且默认 reduction='sum'67# F.kl_div(q.log(), p) 计算的是 D_KL(P || Q)68kl_pq_torch = F.kl_div(q_dist.log(), p_dist, reduction='sum')69kl_qp_torch = F.kl_div(p_dist.log(), q_dist, reduction='sum')70print(f"PyTorch 内置函数计算 D_KL(P || Q) = {kl_pq_torch:.4f}")71print(f"PyTorch 内置函数计算 D_KL(Q || P) = {kl_qp_torch:.4f}")7273# 验证我们的实现与PyTorch内置函数结果一致74assert torch.isclose(kl_pq, kl_pq_torch)
工程实践
| 度量 | 使用场景 | 超参数/实现细节 | 性能/权衡 |
| -------------- | ----------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| 交叉熵 | 分类问题损失函数:几乎是所有分类任务(二分类、多分类)的默认选择。 | 在 PyTorch 中使用 nn.CrossEntropyLoss,它内部集成了 LogSoftmax 和 NLLLoss,可以提高数值稳定性并避免手动计算 log。 | 计算简单高效。对预测概率的绝对值敏感,当模型对错误类别给出高置信度时,会产生巨大损失,有助于快速修正。 |
| KL 散度 | 1. 变分自编码器(VAE):作为正则项,约束后验分布 逼近先验分布 (通常是标准正态分布)。<br>2. 强化学习(RL):在策略优化中(如 TRPO, PPO),限制新旧策略的差异,保证更新步长不会过大。 | 在 VAE 中,如果先验是标准正态分布,KL 项有解析解,无需数值计算。<br>在 RL 中,KL 散度作为惩罚项或约束项,其系数是一个重要的超参数。 | 计算相对高效。主要问题是非对称性,以及在分布支撑集不重叠时值为无穷大,可能导致梯度爆炸。 |
| JS 散度 | 生成对抗网络(GAN):原始 GAN 论文中,生成器的目标是最小化生成分布与真实分布之间的 JS 散度。 | 实践中很少直接计算,而是通过 GAN 的对抗训练过程隐式地优化。 | 解决了 KL 的非对称问题且有界。但其主要缺点是在分布不重叠时,JS 散度为常数 ,导致梯度消失,使得 GAN 训练困难,这也是 WGAN 等后续工作要解决的核心问题。 |
| Wasserstein | WGAN, WGAN-GP:作为 GAN 的损失函数,用于稳定训练过程,缓解模式崩溃 (mode collapse)。 | 实现其对偶形式需要一个额外的神经网络(称为 Critic),并对 Critic 的权重施加约束(权重裁剪或梯度惩罚)来近似满足 1-Lipschitz 条件。 | 优点:即使分布不重叠也能提供平滑梯度,训练更稳定,损失值与生成图片质量有更好的相关性。<br>缺点:计算成本更高,因为需要训练一个 Critic 网络直到接近最优,并且每次更新生成器前要多次更新 Critic。 |
常见误区与边界情况
-
误区:KL 散度是距离度量。
- 纠正:不是。真正的距离度量必须满足对称性()和三角不等式。KL 散度两者都不满足,它是一种“散度”或“相对熵”,衡量的是信息损失而非空间距离。
-
误区:在分类任务中,最小化交叉熵和最小化 KL 散度是两件不同的事。
- 纠正:在机器学习优化中,它们是等价的。因为 ,而真实分布 P 是固定的(来自数据集的标签),所以其熵 是一个常数。因此,优化 等价于优化 ,梯度是完全一样的。选择交叉熵只是因为它形式更简洁。
-
边界情况:分布支撑集不重叠
- KL/JS:如果存在一点 使得 但 ,那么 会是无穷大。如果两个分布的支撑集完全不重叠,JS 散度会退化成一个常数,导致梯度为零,模型无法学习。这是它们在某些 GAN 场景下失效的根本原因。
- Wasserstein:这是它的优势所在。即使分布完全不重叠,它依然能衡量它们之间的“距离”(比如把数轴上
[0,1]区间的均匀分布搬到[10,11]区间,成本是 10),并提供平滑的梯度信号。
-
面试追问:为什么 WGAN 比原始 GAN 训练更稳定?
- 回答要点:核心在于损失函数的选择。原始 GAN 优化 JS 散度,当生成器分布和真实分布几乎没有重叠时(这在训练初期很常见),JS 散度是常数,判别器梯度消失,生成器无法得到有效信息进行更新。WGAN 采用 Wasserstein 距离,它对不重叠的分布依然能提供平滑的梯度,使得判别器(Critic)能持续为生成器提供有用的、非饱和的梯度信号,从而让训练过程更加稳定,有效避免了模式崩溃。
-
面试追问:在 WGAN 中,梯度惩罚(Gradient Penalty)相比权重裁剪(Weight Clipping)有什么优势?
- 回答要点:两者都是为了让 Critic 满足 1-Lipschitz 约束。权重裁剪(WGAN)方法过于粗暴,它将 Critic 的权重强制限制在一个小范围内(如
[-0.01, 0.01])。这会导致两个问题:1) Critic 的表达能力受限,可能无法学到最优的函数;2) 权重倾向于集中在边界值,导致梯度信息利用不充分。而梯度惩罚(WGAN-GP)通过在损失函数中加入一个惩罚项,鼓励 Critic 在生成样本和真实样本的插值点上的梯度范数接近 1。这是一种更“软”的约束,既能有效实施 Lipschitz 约束,又不会过度限制 Critic 的学习能力,从而获得更好的性能和更稳定的训练。
- 回答要点:两者都是为了让 Critic 满足 1-Lipschitz 约束。权重裁剪(WGAN)方法过于粗暴,它将 Critic 的权重强制限制在一个小范围内(如