§1.3.25

交叉熵 vs MSE,分类为何用 CE?

手写练习
  • 手写带 label smoothing 的 CE

核心概念

交叉熵损失(Cross-Entropy Loss, CE)是衡量两个概率分布之间差异的指标。在机器学习分类任务中,它被用来度量模型预测的概率分布(如 Softmax 输出)与真实的标签分布(通常是 one-hot 编码)之间的“距离”。最小化交叉熵损失,等价于最大化模型预测正确类别的对数似然。

均方误差(Mean Squared Error, MSE)计算的是预测值与真实值之差的平方的均值,它度量的是欧氏空间中的距离。MSE 主要适用于回归问题,即预测一个连续值。虽然理论上可以用于分类,但它存在严重的梯度和优化问题。

原理与推导

核心结论是:对于分类问题,交叉熵(CE)的梯度形式比均方误差(MSE)更优越,能够避免梯度消失,从而实现更高效的训练。

我们以一个多分类问题为例进行推导。假设模型最后一层为线性层,其输出为 logits,记为 z=[z1,z2,...,zC]z = [z_1, z_2, ..., z_C],其中 CC 是类别数。这些 logits 经过 Softmax 函数得到预测概率 y^\hat{y}

y^j=Softmax(z)j=ezjk=1Cezk\hat{y}_j = \text{Softmax}(z)_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}}

真实标签 yy 是一个 one-hot 向量,例如对于第 cc 类, yc=1y_c=1 且对所有 jcj \neq cyj=0y_j=0


1. 交叉熵损失 (Cross-Entropy Loss)

公式定义: 对于单个样本,交叉熵损失为:

LCE(y^,y)=j=1Cyjlog(y^j)L_{CE}(\hat{y}, y) = - \sum_{j=1}^{C} y_j \log(\hat{y}_j)

由于 yy 是 one-hot 向量,假设真实类别是 cc,则 yc=1y_c=1 其他为0,公式可以简化为:

LCE=log(y^c)L_{CE} = - \log(\hat{y}_c)

这直观地表示:我们希望最大化正确类别 cc 的预测概率 y^c\hat{y}_c,也就是最小化它的负对数。

梯度推导(关键步骤): 我们关心的是损失对 logits zjz_j 的梯度,因为这是反向传播更新网络参数的起点。根据链式法则 Lzj=k=1CLy^ky^kzj\frac{\partial L}{\partial z_j} = \sum_{k=1}^{C} \frac{\partial L}{\partial \hat{y}_k} \frac{\partial \hat{y}_k}{\partial z_j}

首先,计算 LCEy^k\frac{\partial L_{CE}}{\partial \hat{y}_k}:

LCEy^k=y^k(i=1Cyilog(y^i))=yky^k\frac{\partial L_{CE}}{\partial \hat{y}_k} = \frac{\partial}{\partial \hat{y}_k} \left( - \sum_{i=1}^{C} y_i \log(\hat{y}_i) \right) = - \frac{y_k}{\hat{y}_k}

其次,计算 Softmax 的导数 y^kzj\frac{\partial \hat{y}_k}{\partial z_j}。这是一个经典结果:

y^kzj=y^k(δkjy^j)\frac{\partial \hat{y}_k}{\partial z_j} = \hat{y}_k (\delta_{kj} - \hat{y}_j)

其中 δkj\delta_{kj} 是克罗内克德尔塔函数(当 k=jk=j 时为1,否则为0)。

将两者结合:

LCEzj=k=1C(yky^k)(y^k(δkjy^j))=k=1Cyk(δkjy^j)=(yjk=1Cyky^j)=(yjy^jk=1Cyk)\begin{aligned} \frac{\partial L_{CE}}{\partial z_j} &= \sum_{k=1}^{C} \left( - \frac{y_k}{\hat{y}_k} \right) \cdot \left( \hat{y}_k (\delta_{kj} - \hat{y}_j) \right) \\ &= - \sum_{k=1}^{C} y_k (\delta_{kj} - \hat{y}_j) \\ &= - \left( y_j - \sum_{k=1}^{C} y_k \hat{y}_j \right) \\ &= - \left( y_j - \hat{y}_j \sum_{k=1}^{C} y_k \right) \end{aligned}

因为 yy 是一个 one-hot 向量,k=1Cyk=1\sum_{k=1}^{C} y_k = 1。所以,我们得到了一个极其简洁和优美的结果:

LCEzj=y^jyj\frac{\partial L_{CE}}{\partial z_j} = \hat{y}_j - y_j

梯度解释: 梯度就是 预测值与真实值之差

  • 如果模型对正确类别 cc 的预测概率 y^c\hat{y}_c 接近 0(即 y^cyc=y^c11\hat{y}_c - y_c = \hat{y}_c - 1 \approx -1),梯度很大,模型会快速学习。
  • 如果预测概率 y^c\hat{y}_c 已经接近 1(即 y^cyc=y^c10\hat{y}_c - y_c = \hat{y}_c - 1 \approx 0),梯度很小,学习会放缓。 这种线性的、与误差成正比的梯度信号,使得优化过程非常稳定和高效。

2. 均方误差损失 (Mean Squared Error)

公式定义: 对于单个样本,MSE 损失为:

LMSE(y^,y)=j=1C(y^jyj)2L_{MSE}(\hat{y}, y) = \sum_{j=1}^{C} (\hat{y}_j - y_j)^2

梯度推导: 同样,我们计算对 logits zjz_j 的梯度:

LMSEzj=k=1CLMSEy^ky^kzj=k=1C2(y^kyk)(y^k(δkjy^j))\begin{aligned} \frac{\partial L_{MSE}}{\partial z_j} &= \sum_{k=1}^{C} \frac{\partial L_{MSE}}{\partial \hat{y}_k} \frac{\partial \hat{y}_k}{\partial z_j} \\ &= \sum_{k=1}^{C} 2(\hat{y}_k - y_k) \cdot \left( \hat{y}_k (\delta_{kj} - \hat{y}_j) \right) \end{aligned}

这个表达式非常复杂,但我们可以分析一个关键情况来理解其缺陷。

梯度解释(饱和问题): 考虑模型在一个二分类问题中犯了严重错误:真实标签是 1 (y1=1y_1=1),但模型预测其概率 y^1\hat{y}_1 趋近于 0。此时,我们期望一个很大的梯度来纠正这个错误。

  • 在这种情况下,y^10\hat{y}_1 \approx 0
  • LMSEz1\frac{\partial L_{MSE}}{\partial z_1} 的表达式中会包含 y^1\hat{y}_1(1y^1)(1-\hat{y}_1) 这样的项(来自Softmax的导数)。当 y^10\hat{y}_1 \to 0 时,导数项 y^1(1y^1)\hat{y}_1(1-\hat{y}_1) 也趋近于 0。
  • 这导致了 梯度消失:即使模型犯了天大的错误(预测概率为0),但由于 Softmax 输出的饱和特性,传递给 logits 的梯度也几乎为零。模型无法从中学习,陷入了次优的“平坦”区域。

3. 信息论解释

  • 交叉熵 H(y,y^)H(y, \hat{y}) 可以分解为 H(y)+DKL(yy^)H(y) + D_{KL}(y || \hat{y}),其中 H(y)H(y) 是真实分布的信息熵, DKLD_{KL} 是 KL 散度。由于真实标签 yy 是固定的,其熵 H(y)H(y) 是一个常数。因此,最小化交叉熵等价于最小化真实分布与预测分布之间的 KL 散度,这具有清晰的统计意义。
  • MSE 则是假设了一个高斯分布的误差模型,这与分类问题中离散的、由伯努利或多项式分布描述的标签性质不符。

代码实现

以下是使用 PyTorch 手动实现带标签平滑(Label Smoothing)的交叉熵损失。标签平滑是一种正则化技术,可以防止模型变得过于自信。

python
1import torch
2import torch.nn.functional as F
3
4def label_smoothing_cross_entropy(logits: torch.Tensor,
5 targets: torch.Tensor,
6 smoothing: float = 0.1) -> torch.Tensor:
7 """
8 手写带标签平滑的交叉熵损失函数。
9
10 Args:
11 logits (torch.Tensor): 模型的原始输出,形状为 (N, C),其中 N 是批量大小,C 是类别数。
12 targets (torch.Tensor): 真实标签,形状为 (N,),值为 0 到 C-1 的整数。
13 smoothing (float): 标签平滑系数 epsilon。
14
15 Returns:
16 torch.Tensor: 计算出的平均损失,一个标量。
17 """
18 # 1. 参数检查与准备
19 if logits.dim() != 2:
20 raise ValueError(f"期望 logits 是 2D 张量 (N, C),但得到 {logits.dim()}D")
21 if targets.dim() != 1:
22 raise ValueError(f"期望 targets 是 1D 张量 (N,),但得到 {targets.dim()}D")
23 if not 0.0 <= smoothing < 1.0:
24 raise ValueError(f"smoothing 值必须在 [0, 1) 范围内,但得到 {smoothing}")
25
26 N, C = logits.shape
27
28 # 2. 创建平滑后的目标分布
29 # 为什么这样做? 这是标签平滑的核心。我们不再使用 [0, 0, 1, 0] 这样的 one-hot 硬标签,
30 # 而是使用一个“软”标签。例如,对于正确类别,我们给它 1-smoothing 的概率,
31 # 然后把 smoothing 的概率均匀地分给所有 C 个类别。
32 # true_dist[i] = (1 - smoothing) * one_hot(target[i]) + smoothing / C
33
34 # 创建一个 (N, C) 的张量,用平滑值填充
35 # off-target 的概率值
36 smooth_value = smoothing / C
37 true_dist = torch.full_like(logits, fill_value=smooth_value)
38
39 # on-target 的概率值
40 confidence = 1.0 - smoothing
41
42 # 为什么用 scatter_? 这是一个高效的 in-place 操作,用于根据 targets 中的索引,
43 # 将 confidence 值填充到 true_dist 的正确位置。
44 # dim=1 表示在第1个维度(类别维度)上进行操作。
45 # targets.unsqueeze(1) 将 (N,) 的 targets 变形为 (N, 1) 以匹配 scatter_ 的 index 参数要求。
46 true_dist.scatter_(1, targets.unsqueeze(1), confidence + smooth_value)
47 # 注意:上面直接用 confidence 填充会更直观,但 scatter_ 是累加操作,所以我们填充 confidence + smooth_value
48 # 更标准的做法是:
49 # true_dist.fill_(smoothing / C)
50 # true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing + (smoothing / C)) # 这样做更严谨
51 # 但为了简洁,我们直接用 scatter_ 覆盖
52 # 修正为更标准的写法:
53 true_dist.fill_(smoothing / (C - 1)) # 将 smoothing 分给 C-1 个错误类别
54 true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
55
56
57 # 3. 计算交叉熵
58 # 为什么用 log_softmax? 直接计算 log(softmax(x)) 在数值上是不稳定的。
59 # 当 logits 中的某个值非常大时,exp(z) 可能会溢出。
60 # F.log_softmax 内部使用了 log-sum-exp 技巧,可以保证数值稳定性。
61 log_probs = F.log_softmax(logits, dim=-1)
62
63 # 4. 计算最终损失
64 # 为什么是 - (true_dist * log_probs).sum(dim=-1)?
65 # 这就是交叉熵 H(p, q) = - sum(p(x) * log(q(x))) 的直接实现。
66 # p 是我们的平滑后标签 true_dist,q 是模型的预测分布 log_probs。
67 # sum(dim=-1) 对每个样本的类别维度求和,得到每个样本的损失。
68 # .mean() 对一个批次内的所有样本损失求平均。
69 loss = - (true_dist * log_probs).sum(dim=-1)
70
71 return loss.mean()
72
73# --- 代码练习和验证 ---
74if __name__ == '__main__':
75 # 模拟数据
76 N, C = 4, 5 # 4个样本,5个类别
77 dummy_logits = torch.randn(N, C)
78 dummy_targets = torch.randint(0, C, (N,))
79
80 print("--- 自定义标签平滑交叉熵 ---")
81 smoothing_factor = 0.1
82 custom_loss = label_smoothing_cross_entropy(dummy_logits, dummy_targets, smoothing=smoothing_factor)
83 print(f"Logits:\n{dummy_logits}")
84 print(f"Targets: {dummy_targets}")
85 print(f"自定义损失 (smoothing={smoothing_factor}): {custom_loss.item():.4f}")
86
87 print("\n--- 与 PyTorch 内置函数对比 (smoothing=0) ---")
88 # 当 smoothing=0 时,我们的函数应该等价于标准的交叉熵损失
89 custom_loss_no_smoothing = label_smoothing_cross_entropy(dummy_logits, dummy_targets, smoothing=0.0)
90
91 # PyTorch 的 CrossEntropyLoss 包含了 log_softmax
92 pytorch_loss_fn = torch.nn.CrossEntropyLoss()
93 pytorch_loss = pytorch_loss_fn(dummy_logits, dummy_targets)
94
95 print(f"自定义损失 (smoothing=0.0): {custom_loss_no_smoothing.item():.4f}")
96 print(f"PyTorch 内置损失: {pytorch_loss.item():.4f}")
97 assert torch.allclose(custom_loss_no_smoothing, pytorch_loss), "当 smoothing=0 时,结果应与 PyTorch 一致"
98 print("测试通过:smoothing=0 时与 PyTorch 结果一致。")
99
100 # PyTorch 2.0+ 也内置了 label_smoothing
101 if hasattr(torch.nn.CrossEntropyLoss, 'label_smoothing'):
102 pytorch_ls_loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=smoothing_factor)
103 pytorch_ls_loss = pytorch_ls_loss_fn(dummy_logits, dummy_targets)
104 print(f"\nPyTorch 内置标签平滑损失 (smoothing={smoothing_factor}): {pytorch_ls_loss.item():.4f}")
105 assert torch.allclose(custom_loss, pytorch_ls_loss), "自定义标签平滑与 PyTorch 内置不一致"
106 print("测试通过:自定义标签平滑与 PyTorch 内置结果一致。")

工程实践

  • 使用场景:交叉熵是几乎所有分类问题的默认和首选损失函数,包括图像分类(ImageNet)、自然语言处理(文本分类、机器翻译)、推荐系统(预测用户点击的下一个物品)等。
  • 超参数选择
    • 标签平滑 smoothing:通常取 0.1 左右。它是一种正则化手段,可以防止模型对训练数据过拟合,提高泛化能力。对于噪声较大的数据集尤其有效。
    • 类别权重 weight:在处理类别不平衡问题时,torch.nn.CrossEntropyLossweight 参数非常重要。可以为样本数少的类别赋予更高的权重,以平衡它们对总损失的贡献。
  • 性能/显存/吞吐权衡
    • CE 本身计算非常高效。
    • 对于拥有海量类别(如词汇表大小为几十万的语言模型)的场景,计算完整的 Softmax 和 CE 会非常耗时。此时会采用**负采样(Negative Sampling)**或 **噪声对比估计(Noise Contrastive Estimation, NCE)**等方法来近似计算损失,或者使用 Hierarchical Softmax
  • 常见坑和调试技巧
    • 输入类型错误torch.nn.CrossEntropyLoss 期望的输入是 原始 logits,而不是已经经过 Softmax 的概率。它内部会高效且稳定地执行 log_softmax。将概率传入会等同于对概率又做了一次 Softmax,导致错误。
    • Loss 为 NaN:如果手动实现 log(softmax(x)),当 logits 过大或过小,可能导致数值溢出或 log(0),产生 NaN。始终使用 F.log_softmaxnn.CrossEntropyLoss 来保证数值稳定性。
    • 检查标签范围:确保 targets 的值在 [0, C-1] 范围内。如果标签越界,会导致 CUDA 错误或 index out of bounds 错误。

常见误区与边界情况

  • 误区一:CE 和 MSE 只是选择不同,效果差不多。

    • 纠正:这是最根本的误区。如原理部分所述,MSE 在分类问题中存在梯度饱和问题,会导致训练极其缓慢或停滞。CE 的梯度形式则完美匹配分类任务的优化需求。
  • 误区二:CE 必须搭配 one-hot 标签。

    • 纠正:不完全是。标准的 CE 确实是基于 one-hot 假设,但其思想可以扩展。标签平滑就是一种使用“软”标签的例子。在知识蒸馏中,教师模型的软目标(概率分布)也可以用来训练学生模型,同样使用交叉熵(或KL散度)。
  • 误区三:nn.CrossEntropyLossnn.NLLLoss 没区别。

    • 纠正nn.CrossEntropyLoss(logits, target) 等价于 nn.NLLLoss(F.log_softmax(logits), target)NLLLoss (Negative Log Likelihood Loss) 期望的输入是对数概率。CrossEntropyLoss 是一个更方便的封装,直接接受原始 logits。
  • 边界情况:二分类问题

    • 对于二分类问题,交叉熵退化为二元交叉熵(Binary Cross-Entropy, BCE)
    • LBCE=[ylog(y^)+(1y)log(1y^)]L_{BCE} = -[y \log(\hat{y}) + (1-y) \log(1-\hat{y})]
    • 此时,模型最后一层通常是一个神经元,输出一个 logit,然后通过 Sigmoid 函数得到概率 y^\hat{y}
    • 在 PyTorch 中,nn.BCELoss 对应此场景,但同样推荐使用更稳定的 nn.BCEWithLogitsLoss,它将 Sigmoid 和 BCE 合并在一起。
  • 面试追问

    • :“既然 CE 这么好,MSE 在分类中就一无是处吗?”
      • :基本是的。在标准的监督分类范式下,几乎没有理由选择 MSE。它的统计假设不匹配,优化特性也差。但在一些特殊领域,如某些自编码器或度量学习的变体中,可能会在特征空间中使用 MSE,但这已经脱离了“预测类别”的范畴。
    • :“标签平滑为什么能起作用?”
      • :1) 正则化:它向真实标签中注入了噪声,降低了模型对训练标签的置信度,防止其在训练集上过拟合,鼓励模型学习更鲁棒的特征。2) 改善 logits 分布:它惩罚了那些使得类间 logit 差异过大的模型,使得最终的 logits 更加紧凑,这被认为有助于提高模型的泛化和校准能力。
    • :“交叉熵和KL散度是什么关系?”
      • H(p,q)=H(p)+DKL(pq)H(p, q) = H(p) + D_{KL}(p || q)。在机器学习分类中,真实分布 pp 是固定的(one-hot),所以它的熵 H(p)H(p) 是一个常数(对于 one-hot,H(p)=0)。因此,最小化交叉熵 H(p,q)H(p, q) 就等价于最小化 ppqq 之间的 KL 散度 DKL(pq)D_{KL}(p || q)。KL散度是衡量两个分布差异的“标准”方法。
相关题目