熵、条件熵、互信息、Perplexity 的公式与直觉?
- —给定 logits 计算 PPL
核心概念
信息熵(Entropy)是信息论中的一个核心概念,用于度量一个随机变量的不确定性。一个随机变量的熵越大,表示它的不确定性越大,或者说,要搞清楚它的具体取值需要越多的信息。条件熵(Conditional Entropy)则是在已知另一个随机变量的条件下,某个随机变量剩余的不确定性。互信息(Mutual Information)衡量的是两个随机变量之间共享的信息量,即知道一个变量后,另一个变量不确定性减少的程度。困惑度(Perplexity, PPL)是衡量概率模型预测样本好坏的指标,尤其在自然语言处理中用于评估语言模型,它可以被理解为模型在预测下一个词时,平均“困惑”的选项数量。
原理与推导
1. 熵 (Entropy)
熵衡量的是一个随机变量 的平均不确定性。其核心思想是,发生概率越低的事件,其包含的信息量越大。一个事件 的信息量(或称自信息)定义为:
其中 是事件 发生的概率,对数底 通常取 2(单位为比特, bits)、(单位为奈特, nats)或 10。在理论分析中常用 2,在深度学习中常用 。
熵 就是该随机变量所有可能取值的信息量的期望:
对于连续随机变量,熵被称为微分熵:
直观解释: 想象一个系统,如果它只有一种状态(例如,一枚两面都是正面的硬币),那么它的结果是确定的,熵为 ,没有任何不确定性。如果它有两种等概率的状态(一枚均匀的硬币),熵为 bit,不确定性最大。你需要 1 bit 的信息(正面还是反面)来消除这种不确定性。
2. 条件熵 (Conditional Entropy)
条件熵 衡量在已知随机变量 的情况下,随机变量 的剩余不确定性。它被定义为在给定 的各个取值 下 的熵的期望:
其中 是在 这个具体条件下 的熵:
将两者结合,得到完整公式:
直观解释: 假设 是“天气”(晴、雨), 是“天空是否有云”(有、无)。 是天气本身的不确定性。 是指,当你抬头看到天空有云(或无云)后,对于天气是晴是雨这件事,还剩下多少不确定性。通常,知道 会减少 的不确定性,所以 。
3. 互信息 (Mutual Information)
互信息 衡量了两个随机变量 和 之间的相互依赖程度。它可以从三个等价的角度来理解:
角度一:熵的减少量 知道 后, 的不确定性从 减少到了 。这个减少的量就是互信息。
由于对称性,它也等于:
角度二:与联合熵的关系 利用链式法则 ,我们可以推导出:
角度三:KL 散度 互信息衡量了联合分布 与边缘分布乘积 之间的差异程度,即变量 和 离“独立”有多远。这个差异可以用 KL 散度来度量:
几何解释 (Venn 图): 把 和 看作两个信息集合的面积, 是它们的并集面积, 则是它们的交集面积。 是 中独有的部分, 是 中独有的部分。
4. 困惑度 (Perplexity, PPL)
困惑度是交叉熵(Cross-Entropy)的指数形式,用于评估语言模型。首先,定义真实分布 和模型预测分布 之间的交叉熵:
在语言模型中,对于一个语料库(测试集) ,我们希望最大化其概率 。为了计算方便和数值稳定性,我们通常处理其对数形式,并按词数归一化,这正是交叉熵:
这里, 是经验分布,即在位置 处,真实下一个词 的概率为1,其他词为0。 是模型的预测概率分布。
困惑度 PPL 定义为:
在深度学习框架中,通常使用自然对数(),所以公式为 。
直观解释: PPL 可以被看作是模型在预测下一个词时,平均面临的有效选项数。如果 PPL 为 10,意味着模型在预测每个词时的不确定性,等价于从 10 个等概率的选项中进行猜测。PPL 越低,说明模型的预测越准确,不确定性越小,语言模型性能越好。
算法复杂度: 对于一个长度为 的序列和大小为 的词汇表,计算交叉熵和 PPL 的时间复杂度是 ,因为在每个时间步,我们只需要查找真实目标词的预测概率,而不需要遍历整个词汇表。
代码实现
以下代码演示了在 PyTorch 中如何从模型输出的 logits 计算困惑度(PPL)。
1import torch2import torch.nn as nn3import numpy as np45def calculate_ppl_from_logits(logits, labels):6 """7 给定模型的 logits 和真实标签,计算困惑度 (Perplexity)。89 Args:10 logits (torch.Tensor): 模型的原始输出,形状为 (batch_size, sequence_length, vocab_size)。11 labels (torch.Tensor): 真实标签,形状为 (batch_size, sequence_length)。1213 Returns:14 float: 计算出的困惑度。15 """16 # 为什么这样做:PyTorch的CrossEntropyLoss期望的输入形状是:17 # Input: (N, C) 其中 C = number of classes18 # Target: (N)19 # 因此,我们需要将 logits 和 labels 的形状进行调整。20 batch_size, sequence_length, vocab_size = logits.shape2122 # 将 logits 从 (batch, seq_len, vocab) 变形为 (batch * seq_len, vocab)23 logits_reshaped = logits.view(-1, vocab_size)2425 # 将 labels 从 (batch, seq_len) 变形为 (batch * seq_len)26 labels_reshaped = labels.view(-1)2728 # 为什么这样做:使用 torch.nn.CrossEntropyLoss 计算交叉熵。29 # 这个函数内部整合了 LogSoftmax 和 NLLLoss,因此可以直接接受原始的 logits。30 # 它还自动处理了 one-hot 编码的逻辑,我们只需要传入类别索引即可。31 # `ignore_index` 参数用于忽略填充(padding)的 token,在计算损失时不考虑它们。32 # 假设 -100 是用于 padding 的标签索引。33 loss_fn = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')3435 # 计算平均交叉熵损失。36 # 注意:PyTorch 中的交叉熵默认使用自然对数 (base e)。37 cross_entropy_loss = loss_fn(logits_reshaped, labels_reshaped)3839 # 为什么这样做:困惑度的公式是 exp(cross_entropy_loss)。40 # 因为交叉熵使用的是自然对数,所以这里用 torch.exp()。41 # 如果交叉熵使用 log2,则应为 2**cross_entropy_loss。42 perplexity = torch.exp(cross_entropy_loss)4344 return perplexity.item()4546# --- 代码练习:给定 logits 计算 PPL ---4748# 1. 模拟数据49batch_size = 250sequence_length = 551vocab_size = 105253# 模拟模型输出的 logits (随机值)54# logits 是未经 softmax 的原始分数55np.random.seed(42)56mock_logits = torch.from_numpy(np.random.randn(batch_size, sequence_length, vocab_size)).float()5758# 模拟真实标签59# 标签是词汇表中的索引。-100 通常用作 padding 标记,在计算损失时会被忽略。60mock_labels = torch.tensor([61 [1, 5, 2, 7, -100], # 第一个序列,最后一个是 padding62 [3, 8, 4, 1, 9] # 第二个序列63])6465print("模拟 Logits 形状:", mock_logits.shape)66print("模拟 Labels 形状:", mock_labels.shape)67print("-" * 30)6869# 2. 调用函数计算 PPL70ppl = calculate_ppl_from_logits(mock_logits, mock_labels)7172print(f"计算得到的交叉熵损失 (base e): {np.log(ppl):.4f}")73print(f"计算得到的困惑度 (PPL): {ppl:.4f}")7475# --- 手动验证 ---76# 为了教学目的,我们手动分解计算过程77loss_fn_manual = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')78logits_flat = mock_logits.view(-1, vocab_size)79labels_flat = mock_labels.view(-1)80individual_losses = loss_fn_manual(logits_flat, labels_flat)8182# 筛选出非-100的有效损失83valid_losses = individual_losses[labels_flat != -100]84mean_loss = valid_losses.mean()85manual_ppl = torch.exp(mean_loss)8687print("-" * 30)88print("手动计算的平均交叉熵:", mean_loss.item())89print("手动计算的 PPL:", manual_ppl.item())90print("两种计算方式结果是否一致:", np.isclose(ppl, manual_ppl.item()))
工程实践
-
使用场景:
- 熵: 在决策树算法(如 C4.5, ID3)中,用信息增益(基于熵)来选择最佳分裂特征。在主动学习中,可以选择模型最不确定(熵最大)的样本进行标注。
- 互信息: 用于特征选择,选择与目标变量互信息最高的特征。在表示学习中,对比学习的目标(如 InfoNCE loss)可以看作是最大化同一样本不同增广视图之间的互信息。
- Perplexity: 语言模型(LMs)最核心和最常用的评估指标。无论是学术研究还是工业界模型调优(如 GPT 系列、
BERT的MLM任务),PPL 都是衡量模型基础性能的关键。PPL 越低,模型在给定上文时,对下一个词的预测就越准确。
-
超参数选择:
- PPL 本身是一个评估指标,不是超参数。但它受模型架构、训练数据、词汇表大小、优化器等所有超参数的影响。
- 在评估 PPL 时,要确保测试集与训练集没有重叠,并且与验证集分开。
- PPL 对词汇表大小很敏感。使用子词(subword)切分(如 BPE, WordPiece)比使用词(word)切分通常会得到更低的 PPL,因为子词能有效处理未登录词(OOV)问题。
-
性能 / 显存 / 吞吐 的权衡:
- 计算 PPL 的开销相对较小,主要是在推理阶段。其计算复杂度与序列长度和批次大小成正比。
- 在大规模评估时,可以在一部分有代表性的测试集上计算 PPL,而不是全量数据,以节省时间。
- 在训练过程中,监控验证集上的 PPL 是判断模型是否过拟合、决定何时停止训练的关键。
-
常见坑和调试技巧:
- PPL 为
inf或NaN: 这通常意味着模型为测试集中出现的某个词分配了 0 概率。- 原因: 词汇表不匹配、未登录词(OOV)问题。
- 解决: 使用平滑技术(如拉普拉斯平滑,但在深度学习中不常用),或更根本地使用子词切分(BPE/SentencePiece)来避免严格的 OOV 问题。
- PPL 极高: 模型性能差,预测接近于均匀分布。检查训练过程是否有问题,如学习率过高、梯度爆炸、数据加载错误等。
- 不同框架/实现的 PPL 不可比: 一定要确认计算 PPL 时使用的对数底是 还是 2。PyTorch/TensorFlow 默认是 。此外,是否正确处理了 padding、句首/句尾符(BOS/EOS)都会影响最终结果。
- PPL 为
常见误区与边界情况
-
误区一:熵 vs 交叉熵:
- 熵 是一个概率分布 内在不确定性的度量。
- 交叉熵 是度量使用基于分布 的编码方式去编码来自真实分布 的样本所需的平均比特数。它涉及两个分布。当 时,。
-
误区二:PPL 的对数底:
- 理论上,信息论常用
base 2,PPL 的直观解释(如“等价于在 N 个选项中猜测”)也基于此。 - 工程上,深度学习框架的损失函数普遍使用自然对数
base e,因此 PPL 公式为 。面试时能清晰说明这一点是加分项。
- 理论上,信息论常用
-
误区三:PPL < 1?:
- 不可能。交叉熵 。因此 。PPL 的理论最小值为 1,此时模型对测试集每个词的预测概率都为 1,即完美预测。
-
边界情况:确定性分布:
- 如果一个随机变量是确定的(例如 ),它的熵为 0。
- 如果模型对某个词的预测概率为 1,而真实情况也是这个词,那么该词的交叉熵贡献为 。
-
常见面试追问:
- 问: 互信息和相关系数(Correlation)有什么区别?
- 答: 相关系数只能衡量线性关系,而互信息可以捕捉任何类型的统计依赖关系(线性和非线性)。如果两个变量互信息为 0,它们一定是统计独立的;但如果它们相关系数为 0,它们可能仍然存在非线性关系。
- 问: 为什么在语言模型评估中,我们用 PPL 而不是准确率(Accuracy)?
- 答: 语言模型是一个生成任务,词汇表非常大(数万个),在每个位置预测完全正确的词非常困难,导致准确率会非常低,无法有效区分模型好坏。PPL 是一个更平滑的指标,它奖励那些给正确词更高概率的模型,即使该概率不是最高的。它衡量的是模型预测分布与真实分布的接近程度,比非黑即白的准确率信息量更丰富。
- 问: 如何比较两个在不同词汇表上训练的模型的 PPL?
- 答: 直接比较是无意义的。PPL 对词汇表大小敏感。一个拥有更大词汇表的模型天然面临更难的预测任务,其 PPL 可能会更高,但这不代表模型更差。公平比较需要在完全相同的词汇表和切分方式下进行。