§1.2.20

数据不平衡的处理(过采样、欠采样、Focal Loss、类权重)?

手写练习
  • 手写 Focal Loss

核心概念

数据不平衡(Data Imbalance)指在分类任务中,不同类别的样本数量差距悬殊的现象。例如,在欺诈检测中,99.9% 的交易是正常的,只有 0.1% 是欺诈交易。若不加处理,模型会倾向于预测样本为多数类,忽略少数类,导致在少数类(通常是更重要的类)上表现很差,即使总体准确率很高。处理数据不平衡的核心思想是:在训练过程中,以某种方式提升少数类样本的重要性,使得模型能够充分学习其特征。

原理与推导

处理数据不平衡的方法主要分为两大类:数据层面(Data-level)和算法层面(Algorithm-level)。

1. 数据层面方法 (Data-level Methods)

这类方法通过改变训练集的样本分布来解决不平衡问题。

(a) 欠采样 (Undersampling)

原理:随机或有策略地减少多数类样本的数量,使其与少数类样本数量相当。 优点:减少了训练数据量,可以降低训练时间和存储开销。 缺点:可能会丢失多数类样本中包含的重要信息,导致模型泛化能力下降。 复杂度:随机欠采样的时间复杂度主要在于数据筛选,约为 O(Nmajor)O(N_{major}),其中 NmajorN_{major} 是多数类样本数。

(b) 过采样 (Oversampling)

原理:增加少数类样本的数量。最简单的方法是随机复制少数类样本,但这容易导致模型对特定样本过拟合。 SMOTE (Synthetic Minority Over-sampling Technique) 是一种更高级的过采样方法。其核心思想是在特征空间中为少数类样本人工合成新的、相似的样本。

SMOTE 算法推导: 对于少数类集合中的每一个样本 xix_i

  1. 找到它在少数类集合中的 kk 个最近邻(通常 k=5k=5)。
  2. 从这 kk 个近邻中随机选择一个样本 xknnx_{knn}
  3. xix_ixknnx_{knn} 之间的连线上生成一个新的合成样本 xnewx_{new}

数学公式如下: xnew=xi+λ(xknnxi)x_{new} = x_i + \lambda \cdot (x_{knn} - x_i) 其中 λ\lambda 是一个在 [0,1][0, 1] 之间均匀分布的随机数。

直观解释:SMOTE 不是简单地复制样本,而是在少数类样本的“附近”创造新的、合理的样本,从而扩大了少数类的决策区域,有助于模型学习到更鲁棒的决策边界。

复杂度:对于每个少数类样本,需要计算与其他所有少数类样本的距离以找到k近邻,所以基础实现的复杂度约为 O(Nminor2D)O(N_{minor}^2 \cdot D),其中 NminorN_{minor} 是少数类样本数,D是特征维度。使用 KD-Tree 等数据结构可以优化近邻搜索。

2. 算法层面方法 (Algorithm-level Methods)

这类方法不改变数据分布,而是修改学习算法本身,使其对少数类样本更“敏感”。

(a) 类别权重 (Class Weights)

原理:在计算损失函数时,为不同类别的样本赋予不同的权重。通常,为少数类样本赋予更高的权重,为多数类样本赋予较低的权重。 推导:以带权重的交叉熵损失函数(Weighted Cross-Entropy Loss)为例。对于一个多分类问题,其标准交叉熵损失为: LCE=i=1Cyilog(pi)L_{CE} = - \sum_{i=1}^{C} y_i \log(p_i) 其中 CC 是类别数,yiy_i 是一个 one-hot 向量,表示真实标签,pip_i 是模型预测为第 ii 类的概率。

带权重的交叉熵损失为: LCE_weighted=i=1Cwiyilog(pi)L_{CE\_weighted} = - \sum_{i=1}^{C} w_i \cdot y_i \log(p_i) 其中 wiw_i 是第 ii 类的权重。如果第 kk 类是少数类,则可以设置一个较大的 wkw_k。一个常见的设置是 wiw_i 与类别频率成反比。

直观解释:通过加大少数类样本的损失权重,使得模型在这些样本上犯错时会受到更“严厉”的惩罚。这迫使模型更加关注少数类,努力将其正确分类。

(b) Focal Loss

原理:Focal Loss 是对标准交叉熵损失的改进,它旨在解决“难易样本”不平衡的问题。在类别不平衡的场景下,多数类样本通常是“容易”分类的样本,它们数量庞大,即使每个样本的损失很小,累加起来也会主导总损失和梯度,从而掩盖了少数类(通常是“难”样本)的信号。

推导:

  1. 标准二分类交叉熵 (CE): 首先定义 ptp_t

    pt={pif y=11pif y=0p_t = \begin{cases} p & \text{if } y=1 \\ 1-p & \text{if } y=0 \end{cases}

    其中 y{0,1}y \in \{0, 1\} 是真实标签,pp 是模型预测 y=1y=1 的概率。则 CE(pt)=log(pt)CE(p_t) = -\log(p_t)

  2. 引入平衡因子 αt\alpha_t: 为了处理类别不平衡(正负样本不均),可以引入一个权重因子 αt\alpha_t,类似于类别权重。

    αt={αif y=11αif y=0\alpha_t = \begin{cases} \alpha & \text{if } y=1 \\ 1-\alpha & \text{if } y=0 \end{cases}

    加权的交叉熵为 LCE_weighted=αtlog(pt)L_{CE\_weighted} = -\alpha_t \log(p_t)

  3. 引入调制因子 (Modulating Factor): 这是 Focal Loss 的核心。为了降低易分样本的权重,引入了调制因子 (1pt)γ(1-p_t)^\gamma

    • 对于一个易分样本(well-classified example),pt1p_t \to 1,此时调制因子 (1pt)γ0(1-p_t)^\gamma \to 0。这使得该样本的损失贡献变得非常小。
    • 对于一个难分样本(hard example),pt0p_t \to 0,此时调制因子 (1pt)γ1(1-p_t)^\gamma \to 1。该样本的损失基本不受影响。
  4. Focal Loss 最终形式: 结合平衡因子和调制因子,得到 Focal Loss 的完整公式: FL(pt)=αt(1pt)γlog(pt)FL(p_t) = -\alpha_t (1-p_t)^\gamma \log(p_t) 其中 γ0\gamma \ge 0 称为聚焦参数 (focusing parameter)γ\gamma 越大,对易分样本的抑制作用越强,使得训练更加聚焦于难分的样本。当 γ=0\gamma=0 时,Focal Loss 退化为带权重的交叉熵损失。

信息论解释:Focal Loss 可以看作是根据样本的“意外程度”(surprisal, 即 log(pt)-\log(p_t))和“分类置信度”(ptp_t)来动态调整其对总损失的贡献。高置信度的正确分类(低意外)被大幅降权,而低置信度的分类(高意外)则保留其原始权重,从而使模型专注于学习那些不确定的、困难的边界情况。

代码实现

下面是 Focal Loss 的一个手写 PyTorch 实现,它包含了处理多分类问题的逻辑,并且数值稳定。

python
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4import numpy as np
5
6class FocalLoss(nn.Module):
7 """
8 手写的、数值稳定的多分类 Focal Loss
9 """
10 def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
11 """
12 :param alpha: (Tensor) 类别权重的列表,shape 为 [C]。如果为 None,则不使用 alpha 权重。
13 例如,对于一个三分类问题,可以是 torch.tensor([0.25, 0.5, 0.25])
14 :param gamma: (float) 聚焦参数,用于调节难易样本的权重。
15 :param reduction: (str) 指定对输出的缩减方式: 'mean', 'sum', 'none'
16 """
17 super(FocalLoss, self).__init__()
18 self.gamma = gamma
19 self.alpha = alpha
20 self.reduction = reduction
21
22 # 为什么这样做:如果 alpha 是列表,需要转换为 tensor
23 if isinstance(alpha, (list, np.ndarray)):
24 self.alpha = torch.tensor(alpha, dtype=torch.float32)
25
26 def forward(self, inputs, targets):
27 """
28 :param inputs: (Tensor) 模型的原始输出 (logits),shape 为 [N, C]
29 :param targets: (Tensor) 真实标签,shape 为 [N],值在 [0, C-1] 之间
30 """
31 # 为什么这样做:使用 log_softmax 和 nll_loss 组合比直接计算 log(softmax(x)) 更数值稳定
32 # inputs 是 logits,形状 [N, C]
33 # targets 是标签,形状 [N]
34 log_pt = F.log_softmax(inputs, dim=1)
35
36 # 为什么这样做:使用 gather 从 log_pt 中根据 targets 提取对应类别的 log 概率
37 # 这等价于 one-hot 编码后做点积,但效率更高
38 # log_pt.gather(1, targets.view(-1, 1)) 的结果是每个样本对应正确类别的 log 概率
39 log_pt = log_pt.gather(1, targets.view(-1, 1)).squeeze(1)
40
41 # 为什么这样做:从 log 概率计算标准概率 pt
42 pt = torch.exp(log_pt)
43
44 # 为什么这样做:这是 Focal Loss 的核心,计算调制因子 (1-pt)^gamma
45 modulating_factor = (1 - pt) ** self.gamma
46
47 # 计算最终的 loss
48 loss = -1 * modulating_factor * log_pt
49
50 # 为什么这样做:如果提供了 alpha,需要为每个样本应用对应的类别权重
51 if self.alpha is not None:
52 # 确保 alpha 和 loss 在同一个设备上
53 if self.alpha.device != loss.device:
54 self.alpha = self.alpha.to(loss.device)
55
56 # 为什么这样做:使用 gather 根据 targets 从 alpha tensor 中选取每个样本对应的权重
57 alpha_t = self.alpha.gather(0, targets)
58 loss = alpha_t * loss
59
60 # 为什么这样做:根据 reduction 参数对 loss 进行聚合
61 if self.reduction == 'mean':
62 return loss.mean()
63 elif self.reduction == 'sum':
64 return loss.sum()
65 else: # 'none'
66 return loss
67
68# --- 代码练习:使用示例 ---
69if __name__ == '__main__':
70 # 模拟一个5个样本,4个类别的场景
71 N = 5
72 C = 4
73
74 # 模拟模型输出的 logits
75 # 假设类别3是少数类,模型对其预测置信度较低
76 logits = torch.randn(N, C)
77 print("模拟模型输出 (logits):\n", logits)
78
79 # 模拟真实标签,其中类别3是少数类 (只出现1次)
80 labels = torch.tensor([0, 1, 1, 2, 3])
81 print("\n真实标签:\n", labels)
82
83 # --- 场景1: 不带 alpha 的 Focal Loss ---
84 print("\n--- 1. 使用 Focal Loss (gamma=2, 无 alpha) ---")
85 focal_loss_func = FocalLoss(gamma=2.0)
86 loss1 = focal_loss_func(logits, labels)
87 print(f"Focal Loss: {loss1.item()}")
88
89 # --- 场景2: 带 alpha 的 Focal Loss ---
90 # 假设类别3是少数类,给它更高的权重
91 # 权重通常与类别频率成反比
92 class_weights = torch.tensor([0.1, 0.2, 0.3, 0.4])
93 print("\n--- 2. 使用 Focal Loss (gamma=2, 带 alpha) ---")
94 print("类别权重 (alpha):", class_weights)
95 focal_loss_func_alpha = FocalLoss(alpha=class_weights, gamma=2.0)
96 loss2 = focal_loss_func_alpha(logits, labels)
97 print(f"带 alpha 的 Focal Loss: {loss2.item()}")
98
99 # --- 对比标准交叉熵 ---
100 print("\n--- 3. 对比标准交叉熵损失 ---")
101 ce_loss_func = nn.CrossEntropyLoss()
102 loss_ce = ce_loss_func(logits, labels)
103 print(f"标准交叉熵损失: {loss_ce.item()}")

工程实践

  • 使用场景:数据不平衡广泛存在于欺诈检测、医疗诊断(如癌症识别)、工业品控(次品检测)、广告点击率(CTR)预估等领域。在这些场景,少数类往往是业务上最关心的。

  • 超参数选择

    • 采样策略:不建议将数据平衡到 1:1。可以先尝试对少数类进行过采样,同时对多数类进行轻微的欠采样,例如调整到 1:3 或 1:5 的比例,然后通过交叉验证观察效果。
    • SMOTE的k值:通常取5,但如果少数类样本非常稀疏,可以适当减小k值。
    • 类别权重:一个简单有效的策略是设置权重为类别频率的倒数:wj=NtotalCNjw_j = \frac{N_{total}}{C \cdot N_j},其中 NtotalN_{total} 是总样本数,CC 是类别数,NjN_j 是第 jj 类的样本数。许多框架(如Scikit-learn, PyTorch)支持自动计算这种权重。
    • Focal Loss的 γ\gammaα\alpha:原论文推荐 γ=2,α=0.25\gamma=2, \alpha=0.25(对于正样本)。在实践中,γ\gamma 是最关键的参数,通常在 [1, 5] 范围内调优。γ\gamma 越大,模型越关注难分样本。α\alpha 可以根据类别比例设置,或也作为超参数进行搜索。
  • 性能/显存/吞吐的权衡

    • 过采样:会增加数据集大小,导致每个 epoch 的训练时间变长,消耗更多内存。
    • 欠采样:会减少数据集大小,训练速度快,内存占用小,但有丢失信息的风险。
    • 类权重/Focal Loss:在训练时仅增加极小的计算开销,不影响数据加载和显存占用,是最高效、最受欢迎的方法之一,尤其在深度学习中。
  • 常见坑和调试技巧

    • 数据泄露:在使用SMOTE或任何其他采样方法时,必须先划分训练集、验证集和测试集,然后只对训练集进行采样。如果在划分前采样,会导致合成的样本(源于训练数据)泄露到验证/测试集中,造成评估指标虚高。
    • 评估指标:对于不平衡问题,准确率(Accuracy)是极具误导性的指标。应使用 Precision (精确率), Recall (召回率), F1-Score, AUC-ROC (ROC曲线下面积), 以及 AUC-PR (精确率-召回率曲线下面积)。对于极端不平衡的数据集,AUC-PR 通常比 AUC-ROC 更能反映模型的真实性能。
    • 组合使用:可以组合使用多种策略。例如,先使用 SMOTE 对少数类进行轻微的过采样,然后再用 Focal Loss 进行训练,可能会取得更好的效果。

常见误区与边界情况

  • 误区1:只要数据不平衡就一定要处理。

    • 纠正:不一定。如果少数类的样本本身特征非常清晰,与多数类有明显区别,即使不平衡,模型也可能学得很好。处理前应先用基线模型评估,确认不平衡确实是性能瓶颈。
  • 误区2:SMOTE 总是优于随机过采样。

    • 纠正:不总是。如果少数类样本本身存在大量噪声,或者类别之间边界高度重叠,SMOTE 可能会在噪声点和边界模糊区域生成更多“坏”样本,反而降低模型性能。此时,更复杂的变体如 Borderline-SMOTE 或 ADASYN 可能更合适。
  • 误区3:Focal Loss 和类别权重是互斥的。

    • 纠正:不是。Focal Loss 公式本身就包含了平衡因子 αt\alpha_t,它扮演的角色就和类别权重一样。Focal Loss 是对类别权重的进一步增强,它不仅解决了类间不平衡(通过 αt\alpha_t),更重要的是解决了难易样本不平衡(通过 (1pt)γ(1-p_t)^\gamma)。
  • 边界情况与面试追问

    • 问:当少数类样本极少(比如只有几个)时,SMOTE 合适吗?
      • :不合适。当样本极少时,k近邻会非常受限,SMOTE 只能在极小的特征空间区域内生成高度相似的样本,几乎等同于随机复制,且有过拟合风险。此时,可以考虑数据增强(如图像旋转、裁剪)、迁移学习或异常检测的方法(将少数类视为异常点)。
    • 问:Focal Loss 中的 γ\gamma 参数设为0会怎样?设得非常大会怎样?
        • γ=0\gamma=0 时,(1pt)0=1(1-p_t)^0 = 1,Focal Loss 退化为带 αt\alpha_t 权重的标准交叉熵损失。
        • γ\gamma 非常大时,模型会极度关注那些 ptp_t 稍小于1的样本(即,几乎分对但置信度不是100%的样本),可能会导致训练不稳定,或者对噪声和标注错误的样本过于敏感。
    • 问:在多标签分类任务中,如何处理不平衡问题?
      • :多标签分类中的不平衡更复杂,可能存在标签共现不平衡、每个标签本身正负样本不平衡等。Focal Loss 同样适用,可以对其进行修改,对每个标签独立计算 Focal Loss 然后求和或求平均。此外,还有一些专门为多标签设计的损失函数,如 Asymmetric Loss (ASL)。
相关题目