SigLIP 为何用 sigmoid pairwise loss 替换 softmax?
核心概念
SigLIP (Sigmoid Loss for Language-Image Pre-training) 是一种大规模视觉语言预训练方法。其核心创新在于用成对的 (pairwise) Sigmoid 损失替代了 CLIP 等模型中使用的基于 Softmax 的对比损失。这种改变将原来“从一批候选项中识别唯一正确配对”的多分类问题,解耦为一系列独立的“判断当前图文对是否匹配”的二元分类问题,从而显著降低了对大批量(large batch size)训练的依赖。
原理与推导
为了理解 SigLIP 的动机,我们首先需要回顾其前身 CLIP 所使用的 InfoNCE Loss (a.k.a. Softmax-based contrastive loss)。
假设在一个批次 (batch) 中有 个图文对 。我们通过图像编码器和文本编码器得到它们归一化后的嵌入向量。任意图像嵌入 和文本嵌入 之间的相似度(通常是点积)记为 。
1. CLIP 的 Softmax 对比损失
CLIP 的目标是对于每个图像 ,在所有 个文本中,正确地将其与文本 匹配。这被构建为一个 -分类问题。损失函数(以图像 为锚点)如下:
其中:
- 是正样本对(匹配的图文对)的相似度。
- () 是负样本对(不匹配的图文对)的相似度。
- 是一个可学习的温度超参数,用于缩放相似度的分布。
核心思想:
- 耦合的归一化:分母 对批次内所有文本的相似度进行了求和。这意味着对一个图文对 的损失计算,依赖于所有其他的负样本对 。
- 竞争机制:Softmax 强制模型在所有候选项中进行“竞争”,将概率“预算”分配给最可能的那个。为了让正样本对的概率 变大,模型不仅要拉近正样本,还必须同时推开所有负样本。
- 大批量依赖:为了提供足够多、足够难的负样本,从而让模型学到有意义的表示,这种方法严重依赖于非常大的批次(例如,
CLIP原始论文中使用了 32768 的批次大小)。这带来了巨大的计算和显存开销。
2. SigLIP 的 Sigmoid Pairwise 损失
SigLIP 改变了问题的范式。它不再将问题看作“N选1”的多分类,而是看作 个独立的二分类问题:对于图像 ,文本 是不是它的正确描述?
对于一个给定的图文对 ,我们希望在它们匹配时(),模型输出的概率接近1;不匹配时(),输出的概率接近0。这正是二元逻辑回归的目标。其损失函数(二元交叉熵)如下:
其中:
- 是缩放后的相似度,其中 (权重) 和 (偏置) 是可学习的参数,类似于温度 但更灵活。
- 是 Sigmoid 函数。
- 是目标标签,当 时为1,否则为0。
整个批次的总损失是所有 个可能配对的二元交叉熵损失之和:
核心思想:
- 解耦的评估:对图文对 的损失计算 只依赖于它们自身的相似度 ,而与其他任何负样本无关。分母中不再有对整个批次的求和项。
- 独立判断:模型为每个图文对独立地做出“匹配/不匹配”的判断。这更符合现实世界的零样本推理场景,因为推理时通常只有一个图像和一个待选文本,没有一个“批次”可供归一化。
- 批次大小无关性:由于损失计算的解耦特性,模型不再需要一个巨大的批次来提供负样本上下文。每个负样本都独立地贡献损失。这使得 SigLIP 可以在小得多的批次下有效训练,极大地降低了硬件门槛。
直观解释与复杂度
- 几何/信息论解释:Softmax 损失试图在相似度空间中,为每个锚点(图像)塑造一个尖锐的概率分布,所有概率之和为1。而 Sigmoid 损失则是在每个点(图文对)上独立地拟合一个概率值,这些概率值之间没有直接的归一化约束。
- 算法复杂度:对于一个大小为 的批次,两种方法都需要计算一个 的相似度矩阵,这部分的计算复杂度是相同的,通常为 ,其中 是嵌入维度。损失计算本身的复杂度远低于此。SigLIP 的主要优势在于,它允许使用更小的 来达到相似或更好的性能,从而降低了总训练成本和显存需求。
代码实现
下面我们用 PyTorch 来实现并对比这两种损失函数。
1import torch2import torch.nn.functional as F34def clip_loss(logits: torch.Tensor) -> torch.Tensor:5 """6 计算基于Softmax的CLIP InfoNCE损失。78 Args:9 logits (torch.Tensor): 相似度矩阵,形状为 [N, N]。对角线为正样本对。1011 Returns:12 torch.Tensor: 计算出的损失值。13 """14 # 为什么这样做:CLIP的目标是N选1的多分类问题。15 # 对于第i个图像,它的正确文本是第i个。16 # 因此,标签是一个从0到N-1的序列,代表每个样本的正确类别索引。17 batch_size = logits.shape[0]18 labels = torch.arange(batch_size, device=logits.device)1920 # 为什么这样做:F.cross_entropy内部集成了log_softmax和nll_loss。21 # 它期望的输入是未经softmax的原始logits和类别索引。22 # 我们分别计算image-to-text和text-to-image的损失,然后取平均。23 loss_i = F.cross_entropy(logits, labels)24 loss_t = F.cross_entropy(logits.t(), labels)2526 return (loss_i + loss_t) / 22728def siglip_loss(logits: torch.Tensor) -> torch.Tensor:29 """30 计算基于Sigmoid的SigLIP Pairwise损失。3132 Args:33 logits (torch.Tensor): 相似度矩阵,形状为 [N, N]。3435 Returns:36 torch.Tensor: 计算出的损失值。37 """38 # 为什么这样做:SigLIP将问题视为N*N个独立的二分类问题。39 # 标签是一个N*N的矩阵,对角线上为1(正样本),其余为0(负样本)。40 # 使用torch.eye创建单位矩阵作为标签。41 batch_size = logits.shape[0]42 labels = torch.eye(batch_size, device=logits.device)4344 # 为什么这样做:F.binary_cross_entropy_with_logits是数值上更稳定的选择。45 # 它将sigmoid和binary_cross_entropy合并在一个函数里,避免了由于sigmoid输出接近0或1时可能出现的数值问题。46 # 它期望的输入是未经sigmoid的原始logits和0/1的标签。47 # 我们对N*N个所有可能配对的损失求和,然后归一化。48 loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='sum')4950 # 为什么这样做:为了让损失值不随batch_size变化而剧烈波动,通常会进行归一化。51 # SigLIP论文中将总损失除以batch_size。52 return loss / batch_size5354# --- 示例 ---55if __name__ == '__main__':56 # 假设的参数57 batch_size = 4 # 使用一个小batch size来演示58 embedding_dim = 1285960 # 随机生成归一化的图像和文本嵌入61 image_embeddings = F.normalize(torch.randn(batch_size, embedding_dim))62 text_embeddings = F.normalize(torch.randn(batch_size, embedding_dim))6364 # 可学习的温度/缩放参数65 # CLIP使用log_t,SigLIP使用w和b,这里为简化统一用一个参数66 logit_scale = torch.nn.Parameter(torch.ones([]) * 2.659)6768 # 计算相似度矩阵 (logits)69 # 为什么这样做:这是计算图文相似度的标准方法,通过矩阵乘法一次性得到所有配对的相似度。70 # 乘以logit_scale是为了在反向传播中学习合适的温度。71 logits = torch.matmul(image_embeddings, text_embeddings.t()) * logit_scale.exp()7273 print("相似度矩阵 (Logits):")74 print(logits)75 print("-" * 30)7677 # 计算并打印两种损失78 loss_clip = clip_loss(logits)79 print(f"CLIP Loss: {loss_clip.item():.4f}")8081 loss_siglip = siglip_loss(logits)82 print(f"SigLIP Loss: {loss_siglip.item():.4f}")
工程实践
- 使用场景:SigLIP 主要用于需要从海量(通常是带噪声的)网络图文数据中进行预训练的视觉语言模型。其训练出的模型可直接用于零样本图像分类、图文检索、图像字幕等任务。
- 超参数选择:
- 批次大小 (Batch Size):SigLIP 的最大优势就是对批次大小不敏感。实验表明,即使在 2048 或 4096 这样相对较小的批次下,SigLIP 也能达到或超过在 32768 批次下训练的
CLIP的性能。这使得在有限的硬件(例如 8-16 个 GPU)上进行有效训练成为可能。 - 缩放权重
w和偏置b:SigLIP 引入了可学习的偏置项b。这个偏置项非常重要,它帮助模型校准正负样本的先验概率。由于一个批次中负样本远多于正样本( vs ),模型容易倾向于预测“不匹配”。偏置b可以补偿这种不平衡,让模型在相似度为0附近时,也能做出合理的初始判断。通常w(logit_scale) 初始化为可学习的参数,b(logit_bias) 也初始化为可学习的参数(例如0)。
- 批次大小 (Batch Size):SigLIP 的最大优势就是对批次大小不敏感。实验表明,即使在 2048 或 4096 这样相对较小的批次下,SigLIP 也能达到或超过在 32768 批次下训练的
- 性能 / 显存 / 吞吐 的权衡:
- 显存:由于可以使用更小的批次,SigLIP 显著降低了单次迭代的显存峰值。这允许研究者在同等硬件上使用更大的模型,或者在更便宜的硬件上进行训练。
- 吞吐:虽然单步计算复杂度相似,但小批次意味着更频繁的梯度更新和同步,可能会稍微降低理论上的硬件利用率。然而,在分布式训练中,小批次也减少了节点间的通信开销(例如 All-Gather 操作的张量尺寸更小),总体上训练效率更高。
- 常见坑和调试技巧:
- 忘记偏置项
b:只使用w * s_ij而不加b会损害性能。偏置项对于处理数据不平衡至关重要。 - 数值稳定性:务必使用
torch.nn.functional.binary_cross_entropy_with_logits而不是手动计算sigmoid再传入binary_cross_entropy,前者在数值上稳定得多。 - 监控
w和b:在训练过程中,应该监控w和b的值。w(logit_scale) 不应发散,b(logit_bias) 通常会收敛到一个负值,以反映负样本占多数的先验。
- 忘记偏置项
常见误区与边界情况
-
误区1:SigLIP 不需要负样本。
- 错误。SigLIP 仍然严重依赖负样本。它的损失函数 明确包含了对所有负样本对 for 的惩罚项。区别在于它处理负样本的方式是“一对一”的,而不是像 Softmax 那样“一对多”地打包处理。
-
误区2:Sigmoid 损失总是优于 Softmax 损失。
- 不一定。这取决于问题的性质。对于互斥的多分类任务(例如,一张图片只能是“猫”、“狗”、“鸟”中的一种),Softmax 是自然且理论上更合适的选择。而对于多标签分类(一张图片可以同时包含“猫”和“狗”)或像 SigLIP 这样的解耦对比学习场景,Sigmoid/二元交叉熵是更优的选择。
-
面试追问1:既然 SigLIP 这么好,为什么不直接用它替代所有对比学习中的 Softmax 损失?
- 回答要点:
- 历史与惯性:InfoNCE 和 Softmax 损失在自监督和对比学习领域有深厚的历史根基,并且在很多场景下工作得很好。
- 任务匹配度:如上所述,对于某些强调“从一堆候选中挑出唯一最佳”的任务,Softmax 的竞争机制可能更有效。
- 负样本挖掘:Softmax 损失对“难负样本”(hard negatives)更敏感,因为一个高相似度的负样本会极大地影响分母,从而产生大的梯度信号。这有时被认为是一种隐式的难负样本挖掘。Sigmoid 损失对所有负样本一视同仁(尽管相似度高的负样本仍会产生更大的损失),可能需要更精巧的负样本采样策略来达到同样的效果(尽管 SigLIP 的实践表明简单的批内负采样已经足够好)。
- SigLIP 的成功关键:SigLIP 的成功不仅仅是换了个损失函数,也与 Google 使用的庞大、高质量的内部数据集和高效的训练基础设施有关。损失函数的改进是关键,但不能孤立地看待。
- 回答要点:
-
面试追问2:SigLIP 论文中提到,它的性能对批次大小不敏感。但代码实现中仍然计算了 N x N 的矩阵,这难道不意味着计算量和显存还是和 N^2 相关吗?
- 回答要点:
- 区分“性能”与“单步计算”:这里的“性能”指的是模型最终的准确率等评估指标,而不是指单步训练的计算速度。SigLIP 的核心发现在于,即使
N很小(如4096),模型也能达到高“性能”。而CLIP必须用巨大的N(如32768)才能达到同等“性能”。 - 总训练成本降低:虽然单步计算中
N x N矩阵的开销确实存在,但因为 SigLIP 可以在小得多的N下训练,所以每次迭代的实际显存占用 () 大大降低。这意味着可以用更少的硬件完成训练,或者在同样硬件上训练更大的模型,最终总训练成本(金钱和时间)被显著削减。 - 优化空间:理论上,由于损失是可分解的,我们甚至不需要显式计算和存储整个
N x N矩阵。可以只计算正样本对和一部分采样出的负样本对的损失。然而,在现代硬件(如 TPU/GPU)上,执行一次大的矩阵乘法通常比执行多次小的、零散的计算更高效。因此,为了实现简单和硬件友好,保留N x N矩阵计算是当前实际的最优策略。
- 区分“性能”与“单步计算”:这里的“性能”指的是模型最终的准确率等评估指标,而不是指单步训练的计算速度。SigLIP 的核心发现在于,即使
- 回答要点: