SigLIP 2 相比 SigLIP 的改进(多 teacher、native 分辨率)?
核心概念
SigLIP 2 是对 SigLIP(Sigmoid Loss for Language-Image Pre-training)模型的重大升级,旨在提升视觉语言预训练的效率和性能。其核心改进主要有两点:1) 多 teacher 知识蒸馏 (Multi-teacher Distillation),即利用多个强大的、预训练好的模型(teachers)产生的文本特征作为监督信号,为图像编码器提供更丰富、更鲁棒的语义目标;2) 原生分辨率训练 (Native Resolution Training),即直接在图像的原始、可变分辨率和长宽比上进行训练,而不是将所有图像强制缩放到固定的正方形,从而保留更多图像细节并减少信息损失。
原理与推导
回顾:SigLIP 原理
SigLIP 的核心是使用 Sigmoid 损失替代了传统 CLIP 模型中的 Softmax 损失。对于一个 batch 内的图像-文本对,Softmax 需要计算所有可能的配对()的相似度,并通过归一化将正样本与其他所有负样本进行对比。而 Sigmoid 损失将问题解耦为一系列独立的二分类问题:判断任意一个图像-文本对是正样本还是负样本。
给定图像特征 和文本特征 ,它们的相似度得分(通常是缩放后的余弦相似度)为 ,其中 是温度超参, 是偏置项。
SigLIP 的损失函数为:
其中 是正样本对集合(匹配的图像和文本), 是负样本对集合, 是 Sigmoid 函数。这本质上是所有样本对的二元交叉熵损失之和。
改进 1: 多 Teacher 知识蒸馏
动机: 网络上的文本描述(Alt-text)质量参差不齐,单个文本描述可能存在噪声、不完整或有偏见。例如,一张“狗在沙滩上”的图片,其文本可能是“一只金毛在玩飞盘”,也可能是“日落下的海滩”。单一的文本无法完全捕捉图像的丰富语义。
原理:
为了提供更强大的监督信号,SigLIP 2 引入了多个(个)预训练好的、强大的“teacher”模型(例如 PaLI-3、LLaVA 等的文本编码器)。对于同一张图片,使用这些 teacher 模型生成 个不同的、高质量的文本嵌入 。
学生模型(即正在训练的 SigLIP 2 模型)的图像编码器产生的图像嵌入 ,其目标不再是仅仅匹配单一的文本嵌入,而是要同时与这 个 teacher 文本嵌入对齐。
数学推导: 损失函数被相应地修改。对于一个正样本图像 ,它现在对应 个 teacher 文本嵌入 。正样本部分的损失被修改为对所有 teacher 的期望或平均:
负样本部分的损失保持不变,即图像 与 batch 内其他所有图像的 teacher 文本嵌入(或原始文本嵌入)构成负样本。
完整的 SigLIP 2 多 teacher 损失(简化形式)如下:
其中 是图像 与其第 个 teacher 文本的相似度,而 是图像 与其他图像 的文本(可以是原始文本或某个 teacher 文本)的相似度。这种方式相当于从多个专家的视角“提炼”知识,让学生模型学习到一个更全面、更泛化的图像表示。
改进 2: 原生分辨率训练
动机:
传统的 ViT (Vision Transformer) 要求输入图像被缩放到固定的正方形分辨率(如 224x224, 336x336)。这种操作会:
- 扭曲长宽比: 对于非正方形的图像(如风景照、人像照),强制缩放会严重扭曲物体和场景的几何形状。
- 丢失信息: 将高分辨率图像下采样会丢失大量细节;将低分辨率图像上采样则会引入不必要的计算和伪影。
- 训练-推理不一致: 推理时,我们往往希望在原生分辨率上获得最佳效果,而训练时的固定尺寸导致了不一致。
原理:
为了处理可变分辨率,SigLIP 2 对 ViT 的架构进行了关键调整:
-
可变数量的图像块 (Patches): 对于一个尺寸为 的图像和大小为 的 patch,
ViT会将其切分为 个图像块。在原生分辨率训练中, 和 是可变的,因此每个图像的 patch 数量 也是可变的。 -
动态位置编码 (Dynamic Positional Embeddings): 标准
ViT使用一个可学习的 1D 位置编码表,其长度与 patch 数量 绑定。当 可变时,这个表就失效了。SigLIP 2 采用 2D 傅里叶特征 (2D Fourier Features) 作为位置编码。对于每个 patch 在其 2D 网格中的坐标 ,其位置编码是基于正弦和余弦函数生成的,可以推广到任意大小的网格。例如,一个简化的傅里叶特征可以表示为:其中 是一组固定的频率。这种方式生成的位置编码不依赖于固定的图像尺寸,可以为任意位置 计算编码。
-
批处理与填充 (Batching & Padding): 为了将不同尺寸的图像(即不同数量的 patch 序列)组合成一个 batch,需要进行填充。具体做法是:在一个 batch 中,找到最大的高度 和宽度 (以 patch 数量计)。然后将所有其他图像的 patch 序列填充到这个最大尺寸,形成一个统一的张量。至关重要的是,必须同时生成一个 注意力掩码 (Attention Mask),以确保在
Transformer的自注意力计算中,模型不会关注到这些填充的 "dummy" patches。
算法复杂度:
- 时间复杂度: 对于单个图像,
ViT的复杂度为 ,其中 是 patch 数量, 是模型维度。在原生分辨率下,,所以复杂度与图像面积的平方成正比。由于批处理需要填充到 batch 内最大尺寸,实际计算开销由该 batch 中最大的图像决定。 - 空间复杂度: 主要由存储注意力矩阵 和激活值 决定,同样受 batch 内最大图像尺寸影响。
代码实现
以下 PyTorch 代码片段示意了 SigLIP 2 两大改进的核心逻辑。这是一个教学性的简化实现,并非完整的模型。
1import torch2import torch.nn as nn3import torch.nn.functional as F4from typing import List56# -----------------------------------------------------------------------------7# 改进 1: 多 Teacher 损失函数的示意实现8# -----------------------------------------------------------------------------9def calculate_siglip_multiteacher_loss(10 image_features: torch.Tensor, # 学生模型产出的图像特征, [B, D]11 teacher_text_features: List[torch.Tensor], # K个teacher产出的文本特征列表, 每个元素是 [B, D]12 temperature: float = 10.0,13 bias: float = -10.014):15 """16 计算 SigLIP 2 的多 teacher 损失。17 为了简化,这里只计算 batch 内的对比损失。18 """19 B, D = image_features.shape20 K = len(teacher_text_features)2122 # 将所有 teacher 的文本特征堆叠起来,方便计算23 # all_teacher_features 的形状为 [K, B, D]24 all_teacher_features = torch.stack(teacher_text_features)2526 # 1. 计算正样本损失27 # image_features [B, D] -> [1, B, D]28 # all_teacher_features [K, B, D]29 # 逐元素相乘后按维度 D 求和,得到 [K, B] 的相似度矩阵30 positive_logits = torch.einsum('bd,kbd->kb', image_features, all_teacher_features) / temperature + bias3132 # 对 K 个 teacher 的结果取平均33 # 这是核心:一个图像特征要同时与 K 个 teacher 文本特征相似34 loss_pos = -F.logsigmoid(positive_logits).mean()3536 # 2. 计算负样本损失37 # 每个图像特征与 batch 内所有其他图像的 teacher 文本特征都应不相似38 # 为了简化,我们只用第一个 teacher 的文本特征来构造负样本39 negative_teacher_features = teacher_text_features[0] # [B, D]4041 # 计算所有图像特征与所有文本特征的相似度矩阵 [B, B]42 all_pairs_logits = torch.einsum('id,jd->ij', image_features, negative_teacher_features) / temperature + bias4344 # 创建一个对角线为 True,其他为 False 的掩码,用于忽略正样本对45 mask = torch.eye(B, device=image_features.device, dtype=torch.bool)4647 # 应用掩码,只保留负样本对的 logits48 negative_logits = all_pairs_logits[~mask]4950 # log(1 - sigmoid(x)) 等价于 logsigmoid(-x)51 loss_neg = -F.logsigmoid(-negative_logits).mean()5253 total_loss = (loss_pos + loss_neg) / 254 return total_loss5556# -----------------------------------------------------------------------------57# 改进 2: 支持原生分辨率的 ViT 模块示意58# -----------------------------------------------------------------------------59class ViTWithNativeResolution(nn.Module):60 def __init__(self, patch_size=16, embed_dim=768, num_heads=12, num_layers=12):61 super().__init__()62 self.patch_size = patch_size63 self.embed_dim = embed_dim6465 # 图像块嵌入层66 self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)6768 # Transformer 编码器层69 encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)70 self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)7172 def generate_fourier_pos_embed(self, h, w, device):73 """动态生成 2D 傅里叶位置编码"""74 # 这是一个简化的实现,实际实现会更复杂75 half_dim = self.embed_dim // 476 freqs = torch.exp(torch.arange(half_dim, device=device) * -(torch.log(torch.tensor(10000.0)) / half_dim))7778 x_coords = torch.arange(w, device=device).unsqueeze(0) * freqs.unsqueeze(1)79 y_coords = torch.arange(h, device=device).unsqueeze(0) * freqs.unsqueeze(1)8081 x_sin = torch.sin(x_coords)82 x_cos = torch.cos(x_coords)83 y_sin = torch.sin(y_coords)84 y_cos = torch.cos(y_coords)8586 pos_embed_x = torch.cat([x_sin, x_cos], dim=0).T.unsqueeze(0).repeat(h, 1, 1) # [h, w, D/2]87 pos_embed_y = torch.cat([y_sin, y_cos], dim=0).T.unsqueeze(1).repeat(1, w, 1) # [h, w, D/2]8889 pos_embed = torch.cat([pos_embed_y, pos_embed_x], dim=2) # [h, w, D]90 return pos_embed.reshape(h * w, self.embed_dim) # [h*w, D]9192 def forward(self, images: List[torch.Tensor]):93 """94 前向传播,输入是一个包含不同尺寸图像张量的列表。95 """96 # 1. 填充与创建掩码97 # 找到 batch 中最大的 H 和 W98 max_h = max(img.shape[1] for img in images)99 max_w = max(img.shape[2] for img in images)100101 # 确保 H 和 W 是 patch_size 的整数倍102 max_h = (max_h // self.patch_size) * self.patch_size103 max_w = (max_w // self.patch_size) * self.patch_size104105 padded_images = []106 masks = []107 pos_embeds = []108109 for img in images:110 # 为什么这样做:将每张图片填充到 batch 内的最大尺寸,形成统一的张量111 c, h, w = img.shape112 pad_h = max_h - h113 pad_w = max_w - w114 padded_img = F.pad(img, (0, pad_w, 0, pad_h))115 padded_images.append(padded_img)116117 # 为什么这样做:创建注意力掩码,让模型忽略填充部分118 num_patches_h = h // self.patch_size119 num_patches_w = w // self.patch_size120 num_patches_total_padded = (max_h // self.patch_size) * (max_w // self.patch_size)121122 mask = torch.zeros(num_patches_total_padded, dtype=torch.bool, device=img.device)123 valid_indices = num_patches_w * (max_h // self.patch_size) + num_patches_h124 # 这是一个简化的掩码逻辑,实际掩码会更精细125 mask[:num_patches_h * num_patches_w] = True # 假设 patch 按行展开126 masks.append(mask)127128 # 为什么这样做:为每个图像动态生成其原始尺寸对应位置编码129 pos_embed = self.generate_fourier_pos_embed(h // self.patch_size, w // self.patch_size, img.device)130 # 填充位置编码以匹配填充后的 patch 序列长度131 padded_pos_embed = F.pad(pos_embed, (0, 0, 0, num_patches_total_padded - pos_embed.shape[0]))132 pos_embeds.append(padded_pos_embed)133134 # 2. 批处理与 patch 嵌入135 batch_images = torch.stack(padded_images) # [B, C, max_H, max_W]136 patch_embeds = self.patch_embed(batch_images) # [B, D, max_H/P, max_W/P]137 patch_embeds = patch_embeds.flatten(2).transpose(1, 2) # [B, N_padded, D]138139 # 3. 添加位置编码140 batch_pos_embeds = torch.stack(pos_embeds) # [B, N_padded, D]141 x = patch_embeds + batch_pos_embeds142143 # 4. 通过 Transformer 编码器144 # 注意力掩码在这里传入,防止在填充区域上计算注意力145 # PyTorch TransformerEncoder 需要一个 (N, N) 的 mask 或 (B*num_heads, N, N) 的 mask146 # 此处 src_key_padding_mask 更合适,形状为 (B, N),标记哪些是 padding147 attention_mask = torch.stack(masks) # [B, N_padded]148149 # TransformerEncoder 期望 padding 位置为 True150 output = self.transformer_encoder(x, src_key_padding_mask=~attention_mask)151152 # 通常使用 [CLS] token 或对所有非填充 token 的输出进行平均池化153 # 此处简化,返回所有 token 的输出154 return output, attention_mask
工程实践
-
使用场景: SigLIP 2 特别适用于需要从大规模、含噪声的网络数据(如网页图文对)进行预训练的场景。其产出的模型对各种图像尺寸和长宽比具有很强的鲁棒性,非常适合作为下游各种视觉任务(如图像分类、目标检测、图像字幕生成)的通用骨干网络。
-
超参数选择:
- Teacher 模型: 选择的 teacher 应该足够强大且具有多样性。例如,可以组合一个基于
CLIP架构的 teacher 和一个基于生成式模型(如 PaLI)的 teacher。teacher 的数量(K)是一个权衡:更多的 teacher 提供更丰富的信号,但也增加了生成训练数据的计算成本。通常 K=2 或 3 是一个不错的起点。 - 最大分辨率: 在原生分辨率训练中,需要设定一个内存可承受的最大分辨率上限。超过此上限的图像仍需被下采样。这个值的设定直接影响 GPU 显存占用。
- 批处理策略: 为了最小化填充带来的计算浪费,可以采用“尺寸分桶”(bucketing)策略。即,将尺寸或长宽比相近的图像分到同一个 batch 中,这样
max_h和max_w与 batch 内其他图像的尺寸差异较小,填充区域也较少。
- Teacher 模型: 选择的 teacher 应该足够强大且具有多样性。例如,可以组合一个基于
-
性能/显存/吞吐权衡:
- 显存: 原生分辨率训练的主要挑战是显存。一个 batch 中若包含一张超高分辨率的图像,会导致整个 batch 的显存占用飙升。分桶策略是缓解此问题的关键。
- 吞吐量: 相比固定尺寸训练,如果数据集中包含大量小尺寸图像,原生分辨率训练反而可能提升吞吐量,因为它避免了将小图上采样到大尺寸所带来的不必要计算。
- 性能: 原生分辨率训练通常能带来显著的性能提升,尤其是在需要细粒度识别或对物体形状敏感的任务上。
-
常见坑和调试技巧:
- 注意力掩码错误: 这是最常见的坑。如果掩码不正确,模型会“关注”到填充的零值区域,导致梯度消失或模型学到错误的模式。调试时,可以从
Transformer层中提取注意力权重矩阵并将其可视化,确保在填充区域的注意力权重接近于零。 - 位置编码问题: 检查动态生成的位置编码是否正确应用到了非填充部分,并且在不同尺寸的图像上表现一致。
- Teacher 嵌入对齐: 在使用多 teacher 时,需要确保所有 teacher 的嵌入空间是归一化或对齐的。如果一个 teacher 的嵌入模长远大于其他 teacher,它将在损失计算中占据主导地位。在使用前对所有 teacher 嵌入进行 L2 归一化是一个好习惯。
- 注意力掩码错误: 这是最常见的坑。如果掩码不正确,模型会“关注”到填充的零值区域,导致梯度消失或模型学到错误的模式。调试时,可以从
常见误区与边界情况
-
误区 1: "多 teacher 就是模型集成(Ensemble)": 这是不准确的。模型集成是在推理时组合多个模型的预测结果。而 SigLIP 2 的多 teacher 是一种知识蒸馏,在训练阶段使用多个 teacher 模型来生成更优质的监督信号,目的是训练出一个单一的、更强大的学生模型。推理时只使用这个学生模型。
-
误区 2: "原生分辨率训练就是不对图像做任何处理": 并非如此。它指的是“尽可能保持原始分辨率和长宽比”,但仍然存在一个由硬件(主要是 GPU 显存)决定的上限。超过这个上限的图像还是需要被下采样。此外,patch化本身也是一种图像处理。
-
边界情况 1: Teacher 之间意见不合: 如果多个 teacher 对同一张图片的语义理解存在巨大分歧(例如,一个认为是“猫”,另一个认为是“狗”),它们的嵌入会指向空间中完全不同的方向。这会给学生模型带来矛盾的梯度信号,可能导致训练不稳定或学习到一个“四不像”的折衷表示。这强调了选择高质量、语义一致的 teacher 的重要性。
-
边界情况 2: 极端长宽比的图像: 对于如全景图(例如 10000x500)或条幅广告(例如 500x10000)这类极端长宽比的图像,即使采用原生分辨率训练,填充所带来的计算浪费依然非常巨大。在实践中,可能需要对这类图像进行特殊的裁剪或处理策略。
-
常见面试追问:
- 问: 除了傅里叶特征,还有哪些方法可以实现动态位置编码?
- 答: 另一个优秀的选择是旋转位置编码 (
RoPE),它通过在注意力计算中旋转 query 和 key 向量来注入相对位置信息,天然支持可变序列长度。此外,还可以使用可学习的、可插值的 2D 位置编码,但傅里叶特征/RoPE因其无需学习和良好的泛化性而更受欢迎。
- 答: 另一个优秀的选择是旋转位置编码 (
- 问: 如果给你无限的计算资源,你会如何选择和使用 teacher 模型?
- 答: 我会选择尽可能多样化的SOTA(State-of-the-Art)模型。多样性体现在:1) 架构多样性(如 Transformer-based, CNN-based, 混合架构);2) 训练数据多样性(在不同数据集上训练的 teacher);3) 模态多样性(不仅用文本 teacher,甚至可以引入能理解布局或分割的视觉 teacher)。我会增加 teacher 的数量,并可能设计一个加权方案,根据学生模型与各 teacher 的一致性动态调整它们在损失函数中的权重。
- 问: 原生分辨率训练相比固定尺寸训练,对模型的泛化性有什么具体影响?
- 答: 显著提升泛化性。模型见过了各种真实世界的长宽比和分辨率,对几何形变不再那么敏感。例如,一个只在正方形图像上训练的模型,在面对一个被拉伸的物体时可能会识别失败。而原生分辨率训练的模型由于在训练中已经适应了各种形状,表现会更鲁棒。这使得模型从“适应训练数据分布”向“理解普适视觉概念”更近了一步。
- 问: 除了傅里叶特征,还有哪些方法可以实现动态位置编码?