§1.3.15

SigLIP 为何用 sigmoid pairwise loss 替换 softmax?

核心概念

SigLIP (Sigmoid Loss for Language-Image Pre-training) 是一种大规模视觉语言预训练方法。其核心创新在于用成对的 (pairwise) Sigmoid 损失替代了 CLIP 等模型中使用的基于 Softmax 的对比损失。这种改变将原来“从一批候选项中识别唯一正确配对”的多分类问题,解耦为一系列独立的“判断当前图文对是否匹配”的二元分类问题,从而显著降低了对大批量(large batch size)训练的依赖。

原理与推导

为了理解 SigLIP 的动机,我们首先需要回顾其前身 CLIP 所使用的 InfoNCE Loss (a.k.a. Softmax-based contrastive loss)。

假设在一个批次 (batch) 中有 NN 个图文对 (i1,t1),(i2,t2),,(iN,tN)(i_1, t_1), (i_2, t_2), \dots, (i_N, t_N)。我们通过图像编码器和文本编码器得到它们归一化后的嵌入向量。任意图像嵌入 vkv_k 和文本嵌入 uju_j 之间的相似度(通常是点积)记为 skj=vkTujs_{kj} = v_k^T u_j

1. CLIP 的 Softmax 对比损失

CLIP 的目标是对于每个图像 iki_k,在所有 NN 个文本中,正确地将其与文本 tkt_k 匹配。这被构建为一个 NN-分类问题。损失函数(以图像 iki_k 为锚点)如下:

Lkimage=logexp(skk/τ)j=1Nexp(skj/τ)\mathcal{L}_{k}^{\text{image}} = -\log \frac{\exp(s_{kk} / \tau)}{\sum_{j=1}^{N} \exp(s_{kj} / \tau)}

其中:

  • skks_{kk} 是正样本对(匹配的图文对)的相似度。
  • skjs_{kj} (jkj \neq k) 是负样本对(不匹配的图文对)的相似度。
  • τ\tau 是一个可学习的温度超参数,用于缩放相似度的分布。

核心思想

  • 耦合的归一化:分母 j=1Nexp(skj/τ)\sum_{j=1}^{N} \exp(s_{kj} / \tau) 对批次内所有文本的相似度进行了求和。这意味着对一个图文对 (ik,tj)(i_k, t_j) 的损失计算,依赖于所有其他的负样本对 (ik,tm)(i_k, t_m)
  • 竞争机制:Softmax 强制模型在所有候选项中进行“竞争”,将概率“预算”分配给最可能的那个。为了让正样本对的概率 exp(skk/τ)...\frac{\exp(s_{kk} / \tau)}{\sum_{...}} 变大,模型不仅要拉近正样本,还必须同时推开所有负样本。
  • 大批量依赖:为了提供足够多、足够难的负样本,从而让模型学到有意义的表示,这种方法严重依赖于非常大的批次(例如,CLIP 原始论文中使用了 32768 的批次大小)。这带来了巨大的计算和显存开销。

2. SigLIP 的 Sigmoid Pairwise 损失

SigLIP 改变了问题的范式。它不再将问题看作“N选1”的多分类,而是看作 NN 个独立的二分类问题:对于图像 iki_k,文本 tjt_j 是不是它的正确描述?

对于一个给定的图文对 (ik,tj)(i_k, t_j),我们希望在它们匹配时(k=jk=j),模型输出的概率接近1;不匹配时(kjk \neq j),输出的概率接近0。这正是二元逻辑回归的目标。其损失函数(二元交叉熵)如下:

Lkj=pkjlog(σ(zkj))(1pkj)log(1σ(zkj))\mathcal{L}_{kj} = - p_{kj} \log(\sigma(z_{kj})) - (1-p_{kj}) \log(1 - \sigma(z_{kj}))

其中:

  • zkj=wskj+bz_{kj} = w \cdot s_{kj} + b 是缩放后的相似度,其中 ww (权重) 和 bb (偏置) 是可学习的参数,类似于温度 τ\tau 但更灵活。
  • σ(x)=11+ex\sigma(x) = \frac{1}{1+e^{-x}} 是 Sigmoid 函数。
  • pkjp_{kj} 是目标标签,当 k=jk=j 时为1,否则为0。

整个批次的总损失是所有 N×NN \times N 个可能配对的二元交叉熵损失之和:

LSigLIP=k=1Nj=1NLkj=k=1N(log(σ(zkk))+jklog(1σ(zkj)))\mathcal{L}_{\text{SigLIP}} = \sum_{k=1}^{N} \sum_{j=1}^{N} \mathcal{L}_{kj} = - \sum_{k=1}^{N} \left( \log(\sigma(z_{kk})) + \sum_{j \neq k} \log(1 - \sigma(z_{kj})) \right)

核心思想

  • 解耦的评估:对图文对 (ik,tj)(i_k, t_j) 的损失计算 Lkj\mathcal{L}_{kj} 只依赖于它们自身的相似度 skjs_{kj},而与其他任何负样本无关。分母中不再有对整个批次的求和项。
  • 独立判断:模型为每个图文对独立地做出“匹配/不匹配”的判断。这更符合现实世界的零样本推理场景,因为推理时通常只有一个图像和一个待选文本,没有一个“批次”可供归一化。
  • 批次大小无关性:由于损失计算的解耦特性,模型不再需要一个巨大的批次来提供负样本上下文。每个负样本都独立地贡献损失。这使得 SigLIP 可以在小得多的批次下有效训练,极大地降低了硬件门槛。

直观解释与复杂度

  • 几何/信息论解释:Softmax 损失试图在相似度空间中,为每个锚点(图像)塑造一个尖锐的概率分布,所有概率之和为1。而 Sigmoid 损失则是在每个点(图文对)上独立地拟合一个概率值,这些概率值之间没有直接的归一化约束。
  • 算法复杂度:对于一个大小为 NN 的批次,两种方法都需要计算一个 N×NN \times N 的相似度矩阵,这部分的计算复杂度是相同的,通常为 O(N2d)O(N^2 d),其中 dd 是嵌入维度。损失计算本身的复杂度远低于此。SigLIP 的主要优势在于,它允许使用更小的 NN 来达到相似或更好的性能,从而降低了总训练成本和显存需求。

代码实现

下面我们用 PyTorch 来实现并对比这两种损失函数。

python
1import torch
2import torch.nn.functional as F
3
4def clip_loss(logits: torch.Tensor) -> torch.Tensor:
5 """
6 计算基于Softmax的CLIP InfoNCE损失。
7
8 Args:
9 logits (torch.Tensor): 相似度矩阵,形状为 [N, N]。对角线为正样本对。
10
11 Returns:
12 torch.Tensor: 计算出的损失值。
13 """
14 # 为什么这样做:CLIP的目标是N选1的多分类问题。
15 # 对于第i个图像,它的正确文本是第i个。
16 # 因此,标签是一个从0到N-1的序列,代表每个样本的正确类别索引。
17 batch_size = logits.shape[0]
18 labels = torch.arange(batch_size, device=logits.device)
19
20 # 为什么这样做:F.cross_entropy内部集成了log_softmax和nll_loss。
21 # 它期望的输入是未经softmax的原始logits和类别索引。
22 # 我们分别计算image-to-text和text-to-image的损失,然后取平均。
23 loss_i = F.cross_entropy(logits, labels)
24 loss_t = F.cross_entropy(logits.t(), labels)
25
26 return (loss_i + loss_t) / 2
27
28def siglip_loss(logits: torch.Tensor) -> torch.Tensor:
29 """
30 计算基于Sigmoid的SigLIP Pairwise损失。
31
32 Args:
33 logits (torch.Tensor): 相似度矩阵,形状为 [N, N]。
34
35 Returns:
36 torch.Tensor: 计算出的损失值。
37 """
38 # 为什么这样做:SigLIP将问题视为N*N个独立的二分类问题。
39 # 标签是一个N*N的矩阵,对角线上为1(正样本),其余为0(负样本)。
40 # 使用torch.eye创建单位矩阵作为标签。
41 batch_size = logits.shape[0]
42 labels = torch.eye(batch_size, device=logits.device)
43
44 # 为什么这样做:F.binary_cross_entropy_with_logits是数值上更稳定的选择。
45 # 它将sigmoid和binary_cross_entropy合并在一个函数里,避免了由于sigmoid输出接近0或1时可能出现的数值问题。
46 # 它期望的输入是未经sigmoid的原始logits和0/1的标签。
47 # 我们对N*N个所有可能配对的损失求和,然后归一化。
48 loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='sum')
49
50 # 为什么这样做:为了让损失值不随batch_size变化而剧烈波动,通常会进行归一化。
51 # SigLIP论文中将总损失除以batch_size。
52 return loss / batch_size
53
54# --- 示例 ---
55if __name__ == '__main__':
56 # 假设的参数
57 batch_size = 4 # 使用一个小batch size来演示
58 embedding_dim = 128
59
60 # 随机生成归一化的图像和文本嵌入
61 image_embeddings = F.normalize(torch.randn(batch_size, embedding_dim))
62 text_embeddings = F.normalize(torch.randn(batch_size, embedding_dim))
63
64 # 可学习的温度/缩放参数
65 # CLIP使用log_t,SigLIP使用w和b,这里为简化统一用一个参数
66 logit_scale = torch.nn.Parameter(torch.ones([]) * 2.659)
67
68 # 计算相似度矩阵 (logits)
69 # 为什么这样做:这是计算图文相似度的标准方法,通过矩阵乘法一次性得到所有配对的相似度。
70 # 乘以logit_scale是为了在反向传播中学习合适的温度。
71 logits = torch.matmul(image_embeddings, text_embeddings.t()) * logit_scale.exp()
72
73 print("相似度矩阵 (Logits):")
74 print(logits)
75 print("-" * 30)
76
77 # 计算并打印两种损失
78 loss_clip = clip_loss(logits)
79 print(f"CLIP Loss: {loss_clip.item():.4f}")
80
81 loss_siglip = siglip_loss(logits)
82 print(f"SigLIP Loss: {loss_siglip.item():.4f}")

工程实践

  • 使用场景:SigLIP 主要用于需要从海量(通常是带噪声的)网络图文数据中进行预训练的视觉语言模型。其训练出的模型可直接用于零样本图像分类、图文检索、图像字幕等任务。
  • 超参数选择
    • 批次大小 (Batch Size):SigLIP 的最大优势就是对批次大小不敏感。实验表明,即使在 2048 或 4096 这样相对较小的批次下,SigLIP 也能达到或超过在 32768 批次下训练的 CLIP 的性能。这使得在有限的硬件(例如 8-16 个 GPU)上进行有效训练成为可能。
    • 缩放权重 w 和偏置 b:SigLIP 引入了可学习的偏置项 b。这个偏置项非常重要,它帮助模型校准正负样本的先验概率。由于一个批次中负样本远多于正样本(N2NN^2-N vs NN),模型容易倾向于预测“不匹配”。偏置 b 可以补偿这种不平衡,让模型在相似度为0附近时,也能做出合理的初始判断。通常 w (logit_scale) 初始化为可学习的参数,b (logit_bias) 也初始化为可学习的参数(例如0)。
  • 性能 / 显存 / 吞吐 的权衡
    • 显存:由于可以使用更小的批次,SigLIP 显著降低了单次迭代的显存峰值。这允许研究者在同等硬件上使用更大的模型,或者在更便宜的硬件上进行训练。
    • 吞吐:虽然单步计算复杂度相似,但小批次意味着更频繁的梯度更新和同步,可能会稍微降低理论上的硬件利用率。然而,在分布式训练中,小批次也减少了节点间的通信开销(例如 All-Gather 操作的张量尺寸更小),总体上训练效率更高。
  • 常见坑和调试技巧
    • 忘记偏置项 b:只使用 w * s_ij 而不加 b 会损害性能。偏置项对于处理数据不平衡至关重要。
    • 数值稳定性:务必使用 torch.nn.functional.binary_cross_entropy_with_logits 而不是手动计算 sigmoid 再传入 binary_cross_entropy,前者在数值上稳定得多。
    • 监控 wb:在训练过程中,应该监控 wb 的值。w (logit_scale) 不应发散,b (logit_bias) 通常会收敛到一个负值,以反映负样本占多数的先验。

常见误区与边界情况

  • 误区1:SigLIP 不需要负样本。

    • 错误。SigLIP 仍然严重依赖负样本。它的损失函数 LSigLIP\mathcal{L}_{\text{SigLIP}} 明确包含了对所有负样本对 (ik,tj)(i_k, t_j) for jkj \neq k 的惩罚项。区别在于它处理负样本的方式是“一对一”的,而不是像 Softmax 那样“一对多”地打包处理。
  • 误区2:Sigmoid 损失总是优于 Softmax 损失。

    • 不一定。这取决于问题的性质。对于互斥的多分类任务(例如,一张图片只能是“猫”、“狗”、“鸟”中的一种),Softmax 是自然且理论上更合适的选择。而对于多标签分类(一张图片可以同时包含“猫”和“狗”)或像 SigLIP 这样的解耦对比学习场景,Sigmoid/二元交叉熵是更优的选择。
  • 面试追问1:既然 SigLIP 这么好,为什么不直接用它替代所有对比学习中的 Softmax 损失?

    • 回答要点
      1. 历史与惯性:InfoNCE 和 Softmax 损失在自监督和对比学习领域有深厚的历史根基,并且在很多场景下工作得很好。
      2. 任务匹配度:如上所述,对于某些强调“从一堆候选中挑出唯一最佳”的任务,Softmax 的竞争机制可能更有效。
      3. 负样本挖掘:Softmax 损失对“难负样本”(hard negatives)更敏感,因为一个高相似度的负样本会极大地影响分母,从而产生大的梯度信号。这有时被认为是一种隐式的难负样本挖掘。Sigmoid 损失对所有负样本一视同仁(尽管相似度高的负样本仍会产生更大的损失),可能需要更精巧的负样本采样策略来达到同样的效果(尽管 SigLIP 的实践表明简单的批内负采样已经足够好)。
      4. SigLIP 的成功关键:SigLIP 的成功不仅仅是换了个损失函数,也与 Google 使用的庞大、高质量的内部数据集和高效的训练基础设施有关。损失函数的改进是关键,但不能孤立地看待。
  • 面试追问2:SigLIP 论文中提到,它的性能对批次大小不敏感。但代码实现中仍然计算了 N x N 的矩阵,这难道不意味着计算量和显存还是和 N^2 相关吗?

    • 回答要点
      1. 区分“性能”与“单步计算”:这里的“性能”指的是模型最终的准确率等评估指标,而不是指单步训练的计算速度。SigLIP 的核心发现在于,即使 N 很小(如4096),模型也能达到高“性能”。而 CLIP 必须用巨大的 N(如32768)才能达到同等“性能”。
      2. 总训练成本降低:虽然单步计算中 N x N 矩阵的开销确实存在,但因为 SigLIP 可以在小得多的 N 下训练,所以每次迭代的实际显存占用 (O(Nd+N2)O(N d + N^2)) 大大降低。这意味着可以用更少的硬件完成训练,或者在同样硬件上训练更大的模型,最终总训练成本(金钱和时间)被显著削减。
      3. 优化空间:理论上,由于损失是可分解的,我们甚至不需要显式计算和存储整个 N x N 矩阵。可以只计算正样本对和一部分采样出的负样本对的损失。然而,在现代硬件(如 TPU/GPU)上,执行一次大的矩阵乘法通常比执行多次小的、零散的计算更高效。因此,为了实现简单和硬件友好,保留 N x N 矩阵计算是当前实际的最优策略。
相关题目