交叉熵 vs MSE,分类为何用 CE?
- —手写带 label smoothing 的 CE
核心概念
交叉熵损失(Cross-Entropy Loss, CE)是衡量两个概率分布之间差异的指标。在机器学习分类任务中,它被用来度量模型预测的概率分布(如 Softmax 输出)与真实的标签分布(通常是 one-hot 编码)之间的“距离”。最小化交叉熵损失,等价于最大化模型预测正确类别的对数似然。
均方误差(Mean Squared Error, MSE)计算的是预测值与真实值之差的平方的均值,它度量的是欧氏空间中的距离。MSE 主要适用于回归问题,即预测一个连续值。虽然理论上可以用于分类,但它存在严重的梯度和优化问题。
原理与推导
核心结论是:对于分类问题,交叉熵(CE)的梯度形式比均方误差(MSE)更优越,能够避免梯度消失,从而实现更高效的训练。
我们以一个多分类问题为例进行推导。假设模型最后一层为线性层,其输出为 logits,记为 ,其中 是类别数。这些 logits 经过 Softmax 函数得到预测概率 :
真实标签 是一个 one-hot 向量,例如对于第 类, 且对所有 有 。
1. 交叉熵损失 (Cross-Entropy Loss)
公式定义: 对于单个样本,交叉熵损失为:
由于 是 one-hot 向量,假设真实类别是 ,则 其他为0,公式可以简化为:
这直观地表示:我们希望最大化正确类别 的预测概率 ,也就是最小化它的负对数。
梯度推导(关键步骤): 我们关心的是损失对 logits 的梯度,因为这是反向传播更新网络参数的起点。根据链式法则 。
首先,计算 :
其次,计算 Softmax 的导数 。这是一个经典结果:
其中 是克罗内克德尔塔函数(当 时为1,否则为0)。
将两者结合:
因为 是一个 one-hot 向量,。所以,我们得到了一个极其简洁和优美的结果:
梯度解释: 梯度就是 预测值与真实值之差。
- 如果模型对正确类别 的预测概率 接近 0(即 ),梯度很大,模型会快速学习。
- 如果预测概率 已经接近 1(即 ),梯度很小,学习会放缓。 这种线性的、与误差成正比的梯度信号,使得优化过程非常稳定和高效。
2. 均方误差损失 (Mean Squared Error)
公式定义: 对于单个样本,MSE 损失为:
梯度推导: 同样,我们计算对 logits 的梯度:
这个表达式非常复杂,但我们可以分析一个关键情况来理解其缺陷。
梯度解释(饱和问题): 考虑模型在一个二分类问题中犯了严重错误:真实标签是 1 (),但模型预测其概率 趋近于 0。此时,我们期望一个很大的梯度来纠正这个错误。
- 在这种情况下,。
- 的表达式中会包含 和 这样的项(来自Softmax的导数)。当 时,导数项 也趋近于 0。
- 这导致了 梯度消失:即使模型犯了天大的错误(预测概率为0),但由于 Softmax 输出的饱和特性,传递给 logits 的梯度也几乎为零。模型无法从中学习,陷入了次优的“平坦”区域。
3. 信息论解释
- 交叉熵 可以分解为 ,其中 是真实分布的信息熵, 是 KL 散度。由于真实标签 是固定的,其熵 是一个常数。因此,最小化交叉熵等价于最小化真实分布与预测分布之间的 KL 散度,这具有清晰的统计意义。
- MSE 则是假设了一个高斯分布的误差模型,这与分类问题中离散的、由伯努利或多项式分布描述的标签性质不符。
代码实现
以下是使用 PyTorch 手动实现带标签平滑(Label Smoothing)的交叉熵损失。标签平滑是一种正则化技术,可以防止模型变得过于自信。
1import torch2import torch.nn.functional as F34def label_smoothing_cross_entropy(logits: torch.Tensor,5 targets: torch.Tensor,6 smoothing: float = 0.1) -> torch.Tensor:7 """8 手写带标签平滑的交叉熵损失函数。910 Args:11 logits (torch.Tensor): 模型的原始输出,形状为 (N, C),其中 N 是批量大小,C 是类别数。12 targets (torch.Tensor): 真实标签,形状为 (N,),值为 0 到 C-1 的整数。13 smoothing (float): 标签平滑系数 epsilon。1415 Returns:16 torch.Tensor: 计算出的平均损失,一个标量。17 """18 # 1. 参数检查与准备19 if logits.dim() != 2:20 raise ValueError(f"期望 logits 是 2D 张量 (N, C),但得到 {logits.dim()}D")21 if targets.dim() != 1:22 raise ValueError(f"期望 targets 是 1D 张量 (N,),但得到 {targets.dim()}D")23 if not 0.0 <= smoothing < 1.0:24 raise ValueError(f"smoothing 值必须在 [0, 1) 范围内,但得到 {smoothing}")2526 N, C = logits.shape2728 # 2. 创建平滑后的目标分布29 # 为什么这样做? 这是标签平滑的核心。我们不再使用 [0, 0, 1, 0] 这样的 one-hot 硬标签,30 # 而是使用一个“软”标签。例如,对于正确类别,我们给它 1-smoothing 的概率,31 # 然后把 smoothing 的概率均匀地分给所有 C 个类别。32 # true_dist[i] = (1 - smoothing) * one_hot(target[i]) + smoothing / C3334 # 创建一个 (N, C) 的张量,用平滑值填充35 # off-target 的概率值36 smooth_value = smoothing / C37 true_dist = torch.full_like(logits, fill_value=smooth_value)3839 # on-target 的概率值40 confidence = 1.0 - smoothing4142 # 为什么用 scatter_? 这是一个高效的 in-place 操作,用于根据 targets 中的索引,43 # 将 confidence 值填充到 true_dist 的正确位置。44 # dim=1 表示在第1个维度(类别维度)上进行操作。45 # targets.unsqueeze(1) 将 (N,) 的 targets 变形为 (N, 1) 以匹配 scatter_ 的 index 参数要求。46 true_dist.scatter_(1, targets.unsqueeze(1), confidence + smooth_value)47 # 注意:上面直接用 confidence 填充会更直观,但 scatter_ 是累加操作,所以我们填充 confidence + smooth_value48 # 更标准的做法是:49 # true_dist.fill_(smoothing / C)50 # true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing + (smoothing / C)) # 这样做更严谨51 # 但为了简洁,我们直接用 scatter_ 覆盖52 # 修正为更标准的写法:53 true_dist.fill_(smoothing / (C - 1)) # 将 smoothing 分给 C-1 个错误类别54 true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)555657 # 3. 计算交叉熵58 # 为什么用 log_softmax? 直接计算 log(softmax(x)) 在数值上是不稳定的。59 # 当 logits 中的某个值非常大时,exp(z) 可能会溢出。60 # F.log_softmax 内部使用了 log-sum-exp 技巧,可以保证数值稳定性。61 log_probs = F.log_softmax(logits, dim=-1)6263 # 4. 计算最终损失64 # 为什么是 - (true_dist * log_probs).sum(dim=-1)?65 # 这就是交叉熵 H(p, q) = - sum(p(x) * log(q(x))) 的直接实现。66 # p 是我们的平滑后标签 true_dist,q 是模型的预测分布 log_probs。67 # sum(dim=-1) 对每个样本的类别维度求和,得到每个样本的损失。68 # .mean() 对一个批次内的所有样本损失求平均。69 loss = - (true_dist * log_probs).sum(dim=-1)7071 return loss.mean()7273# --- 代码练习和验证 ---74if __name__ == '__main__':75 # 模拟数据76 N, C = 4, 5 # 4个样本,5个类别77 dummy_logits = torch.randn(N, C)78 dummy_targets = torch.randint(0, C, (N,))7980 print("--- 自定义标签平滑交叉熵 ---")81 smoothing_factor = 0.182 custom_loss = label_smoothing_cross_entropy(dummy_logits, dummy_targets, smoothing=smoothing_factor)83 print(f"Logits:\n{dummy_logits}")84 print(f"Targets: {dummy_targets}")85 print(f"自定义损失 (smoothing={smoothing_factor}): {custom_loss.item():.4f}")8687 print("\n--- 与 PyTorch 内置函数对比 (smoothing=0) ---")88 # 当 smoothing=0 时,我们的函数应该等价于标准的交叉熵损失89 custom_loss_no_smoothing = label_smoothing_cross_entropy(dummy_logits, dummy_targets, smoothing=0.0)9091 # PyTorch 的 CrossEntropyLoss 包含了 log_softmax92 pytorch_loss_fn = torch.nn.CrossEntropyLoss()93 pytorch_loss = pytorch_loss_fn(dummy_logits, dummy_targets)9495 print(f"自定义损失 (smoothing=0.0): {custom_loss_no_smoothing.item():.4f}")96 print(f"PyTorch 内置损失: {pytorch_loss.item():.4f}")97 assert torch.allclose(custom_loss_no_smoothing, pytorch_loss), "当 smoothing=0 时,结果应与 PyTorch 一致"98 print("测试通过:smoothing=0 时与 PyTorch 结果一致。")99100 # PyTorch 2.0+ 也内置了 label_smoothing101 if hasattr(torch.nn.CrossEntropyLoss, 'label_smoothing'):102 pytorch_ls_loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=smoothing_factor)103 pytorch_ls_loss = pytorch_ls_loss_fn(dummy_logits, dummy_targets)104 print(f"\nPyTorch 内置标签平滑损失 (smoothing={smoothing_factor}): {pytorch_ls_loss.item():.4f}")105 assert torch.allclose(custom_loss, pytorch_ls_loss), "自定义标签平滑与 PyTorch 内置不一致"106 print("测试通过:自定义标签平滑与 PyTorch 内置结果一致。")
工程实践
- 使用场景:交叉熵是几乎所有分类问题的默认和首选损失函数,包括图像分类(ImageNet)、自然语言处理(文本分类、机器翻译)、推荐系统(预测用户点击的下一个物品)等。
- 超参数选择:
- 标签平滑
smoothing:通常取0.1左右。它是一种正则化手段,可以防止模型对训练数据过拟合,提高泛化能力。对于噪声较大的数据集尤其有效。 - 类别权重
weight:在处理类别不平衡问题时,torch.nn.CrossEntropyLoss的weight参数非常重要。可以为样本数少的类别赋予更高的权重,以平衡它们对总损失的贡献。
- 标签平滑
- 性能/显存/吞吐权衡:
- CE 本身计算非常高效。
- 对于拥有海量类别(如词汇表大小为几十万的语言模型)的场景,计算完整的 Softmax 和 CE 会非常耗时。此时会采用**负采样(Negative Sampling)**或 **噪声对比估计(Noise Contrastive Estimation, NCE)**等方法来近似计算损失,或者使用 Hierarchical Softmax。
- 常见坑和调试技巧:
- 输入类型错误:
torch.nn.CrossEntropyLoss期望的输入是 原始 logits,而不是已经经过 Softmax 的概率。它内部会高效且稳定地执行log_softmax。将概率传入会等同于对概率又做了一次 Softmax,导致错误。 - Loss 为 NaN:如果手动实现
log(softmax(x)),当 logits 过大或过小,可能导致数值溢出或log(0),产生NaN。始终使用F.log_softmax或nn.CrossEntropyLoss来保证数值稳定性。 - 检查标签范围:确保
targets的值在[0, C-1]范围内。如果标签越界,会导致CUDA错误或 index out of bounds 错误。
- 输入类型错误:
常见误区与边界情况
-
误区一:CE 和 MSE 只是选择不同,效果差不多。
- 纠正:这是最根本的误区。如原理部分所述,MSE 在分类问题中存在梯度饱和问题,会导致训练极其缓慢或停滞。CE 的梯度形式则完美匹配分类任务的优化需求。
-
误区二:CE 必须搭配 one-hot 标签。
- 纠正:不完全是。标准的 CE 确实是基于 one-hot 假设,但其思想可以扩展。标签平滑就是一种使用“软”标签的例子。在知识蒸馏中,教师模型的软目标(概率分布)也可以用来训练学生模型,同样使用交叉熵(或KL散度)。
-
误区三:
nn.CrossEntropyLoss和nn.NLLLoss没区别。- 纠正:
nn.CrossEntropyLoss(logits, target)等价于nn.NLLLoss(F.log_softmax(logits), target)。NLLLoss(Negative Log Likelihood Loss) 期望的输入是对数概率。CrossEntropyLoss是一个更方便的封装,直接接受原始 logits。
- 纠正:
-
边界情况:二分类问题
- 对于二分类问题,交叉熵退化为二元交叉熵(Binary Cross-Entropy, BCE)。
- 此时,模型最后一层通常是一个神经元,输出一个 logit,然后通过 Sigmoid 函数得到概率 。
- 在 PyTorch 中,
nn.BCELoss对应此场景,但同样推荐使用更稳定的nn.BCEWithLogitsLoss,它将 Sigmoid 和 BCE 合并在一起。
-
面试追问:
- 问:“既然 CE 这么好,MSE 在分类中就一无是处吗?”
- 答:基本是的。在标准的监督分类范式下,几乎没有理由选择 MSE。它的统计假设不匹配,优化特性也差。但在一些特殊领域,如某些自编码器或度量学习的变体中,可能会在特征空间中使用 MSE,但这已经脱离了“预测类别”的范畴。
- 问:“标签平滑为什么能起作用?”
- 答:1) 正则化:它向真实标签中注入了噪声,降低了模型对训练标签的置信度,防止其在训练集上过拟合,鼓励模型学习更鲁棒的特征。2) 改善 logits 分布:它惩罚了那些使得类间 logit 差异过大的模型,使得最终的 logits 更加紧凑,这被认为有助于提高模型的泛化和校准能力。
- 问:“交叉熵和KL散度是什么关系?”
- 答:。在机器学习分类中,真实分布 是固定的(one-hot),所以它的熵 是一个常数(对于 one-hot,H(p)=0)。因此,最小化交叉熵 就等价于最小化 和 之间的 KL 散度 。KL散度是衡量两个分布差异的“标准”方法。
- 问:“既然 CE 这么好,MSE 在分类中就一无是处吗?”