DINO / DINOv2 自蒸馏 + 多 crop 策略?
核心概念
DINO (self-DIstillation with NO labels) 是一种自监督学习(SSL)框架,其核心思想是“自蒸馏”。它构建了两个结构相同但权重不同的网络:一个学生网络(student)和一个教师网络(teacher)。教师网络的权重是学生网络权重的指数移动平均(Exponential Moving Average, EMA),因此教师是学生“缓慢”演进的、更稳定的版本。结合多尺度裁剪(Multi-crop)策略,学生网络被训练来预测教师网络在不同(通常是更大)图像视图下的输出分布,从而在没有标签的情况下学习到强大的语义表征。DINOv2 是 DINO 的大规模升级版,通过使用更庞大的精选数据集、更稳定的训练技术和更强的模型架构,成为了一个强大的视觉基础模型。
原理与推导
DINO 的目标是让学生网络 的输出,在给定不同图像视图(crop)时,能够匹配教师网络 的输出。这两个网络都由一个主干网络(如 ViT)和一个投影头(projection head)组成。
1. 输出概率分布
对于一个输入图像视图 ,学生和教师网络分别输出一个 维的特征向量。这些特征向量通过一个带有温度参数 的 softmax 函数转换为概率分布 和 。
- 是网络对输入 输出的第 维 logits。
- 是温度参数。学生网络使用较高的温度 (如 0.1),教师网络使用较低的温度 (如 0.04-0.07)。较低的温度会使输出分布变得“尖锐”(sharpening),让教师的预测更具确定性,为学生提供更强的学习信号。
2. 多尺度裁剪策略 (Multi-crop)
这是 DINO 的关键数据增强策略。对于每张输入图片,会生成一个视图集合 ,包含:
- 2 个分辨率较高的全局视图(global views),例如 224x224。
- 多个分辨率较低的局部视图(local views),例如 96x96。
核心思想: 教师网络只看到全局视图,而学生网络需要看到所有视图(全局+局部)。学生的目标是,无论它看到的是全局视图还是局部视图,其输出都应与教师在某个全局视图上的输出保持一致。这迫使模型学习到“部分-整体”的对应关系和尺度不变的特征。
3. 损失函数
损失函数是学生和教师输出分布之间的交叉熵(cross-entropy)。对于一张图片,其损失计算如下:
- 是送入学生网络的任一视图(全局或局部)。
- 是送入教师网络的全局视图之一,且 与 不是同一个视图。
- 是教师的输出分布(作为伪标签), 是学生的输出分布。
- 动机:这个公式意味着,对于学生看到的每一个 crop(无论是大是小),它都必须预测出教师在另一个“全局” crop 上看到的结果。例如,即使学生只看到一只猫的耳朵(局部视图),它也应该输出与教师看到整只猫(全局视图)时相似的概率分布。
4. 教师网络更新 (EMA)
教师网络的权重 不通过反向传播更新。相反,它是学生网络权重 的指数移动平均值。在每次训练迭代后,教师权重按以下方式更新:
- 是一个动量系数,通常是一个接近 1 的值(例如,从 0.996 逐渐增加到 1)。
- 动机:这种“慢速”更新机制使得教师网络比学生网络更稳定。学生网络在不断探索和学习,而教师网络则提供了一个稳定、可靠的“平均”目标,有效防止了训练崩溃(即学生和教师输出相同但无意义的恒定值)。
5. 防止模型崩溃 (Collapse) 的额外机制
除了 EMA 教师,DINO 还使用了两种关键技术来避免模型崩溃:
- 中心化 (Centering):教师的输出在送入损失函数前会减去一个中心值 。这个中心值 是教师在整个批次(batch)上输出特征的指数移动平均。。这可以防止某一维度长期占据主导地位,鼓励网络利用所有维度。
- 锐化 (Sharpening):如前所述,使用较低的教师温度 会让教师的输出分布更尖锐,避免其输出均匀分布这种平凡解。
DINOv2 的改进
DINOv2 继承了 DINO 的核心思想,并在以下方面进行了大规模扩展和优化:
- 大规模精选数据集:构建了一个包含 1.42 亿张图片的 LVD-142M 数据集。
- 算法增强:结合了 iBOT 的掩码图像建模(MIM)损失,并对 Swin
Transformer结构进行修改以提高大规模训练的稳定性。 - 工程优化:使用
FusedAdamW优化器和高效的 FP16/BFloat16 训练,实现了在海量数据上的高效稳定训练。
代码实现
以下 PyTorch 代码片段展示了 DINO 损失函数和教师更新的核心逻辑。这是一个简化的示例,旨在阐明原理,而非一个完整的训练脚本。
1import torch2import torch.nn as nn3import torch.nn.functional as F4from copy import deepcopy56class DINOHead(nn.Module):7 """8 DINO 的投影头,一个简单的多层感知机 (MLP)。9 """10 def __init__(self, in_dim, out_dim, hidden_dim=2048, n_layers=3):11 super().__init__()12 layers = []13 # 输入层14 layers.append(nn.Linear(in_dim, hidden_dim))15 layers.append(nn.GELU())16 # 隐藏层17 for _ in range(n_layers - 2):18 layers.append(nn.Linear(hidden_dim, hidden_dim))19 layers.append(nn.GELU())20 # 输出层21 layers.append(nn.Linear(hidden_dim, out_dim))22 self.mlp = nn.Sequential(*layers)2324 def forward(self, x):25 return self.mlp(x)2627class DINOLoss(nn.Module):28 """29 DINO 损失函数的核心实现。30 """31 def __init__(self, out_dim, n_global_crops, n_local_crops, student_temp, teacher_temp, center_momentum=0.9):32 super().__init__()33 self.student_temp = student_temp34 self.teacher_temp = teacher_temp35 self.n_crops = n_global_crops + n_local_crops36 self.n_global_crops = n_global_crops37 self.center_momentum = center_momentum38 # 注册一个持久化的 buffer `center`,它不是模型参数,但会随模型状态一起保存39 self.register_buffer("center", torch.zeros(1, out_dim))4041 def forward(self, student_output, teacher_output):42 """43 计算 DINO 损失。44 student_output: (batch_size * n_crops, out_dim)45 teacher_output: (batch_size * n_global_crops, out_dim)46 """47 # 1. 对学生和教师的输出应用温度 softmax48 student_out = student_output / self.student_temp4950 # 2. 对教师的输出进行锐化和中心化51 # 教师不参与反向传播,所以使用 .detach()52 teacher_out = F.softmax((teacher_output - self.center) / self.teacher_temp, dim=-1)53 teacher_out = teacher_out.detach()5455 total_loss = 056 n_loss_terms = 05758 # 3. 计算交叉熵损失59 # 将学生输出按 crop 分割60 student_out_chunks = student_out.chunk(self.n_crops, dim=0)61 # 将教师输出按 global crop 分割62 teacher_out_chunks = teacher_out.chunk(self.n_global_crops, dim=0)6364 for i, teacher_chunk in enumerate(teacher_out_chunks):65 for j, student_chunk in enumerate(student_out_chunks):66 # 教师和学生不能来自同一个原始 crop67 if i == j:68 continue6970 # 计算学生在某个 crop 上的输出与教师在另一个 global crop 上的输出的交叉熵71 loss = torch.sum(-teacher_chunk * F.log_softmax(student_chunk, dim=-1), dim=-1)72 total_loss += loss.mean()73 n_loss_terms += 17475 total_loss /= n_loss_terms7677 # 4. 更新中心值 center78 self.update_center(teacher_output)7980 return total_loss8182 @torch.no_grad()83 def update_center(self, teacher_output):84 """85 使用 EMA 更新中心值。86 """87 batch_center = torch.mean(teacher_output, dim=0, keepdim=True)88 self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)8990@torch.no_grad()91def update_teacher_ema(student, teacher, momentum):92 """93 使用 EMA 更新教师网络的权重。94 """95 for param_s, param_t in zip(student.parameters(), teacher.parameters()):96 param_t.data.mul_(momentum).add_(param_s.data, alpha=1 - momentum)9798# --- 示例运行 ---99if __name__ == '__main__':100 # --- 参数设置 ---101 BATCH_SIZE = 4102 IN_DIM = 768 # ViT-Base 的特征维度103 OUT_DIM = 65536 # DINO 论文中的投影维度104 N_GLOBAL_CROPS = 2105 N_LOCAL_CROPS = 6106 N_CROPS = N_GLOBAL_CROPS + N_LOCAL_CROPS107108 # --- 模型初始化 ---109 # 假设 backbone 是一个 ViT 模型110 student_backbone = nn.Linear(512, IN_DIM) # 伪 ViT111 teacher_backbone = deepcopy(student_backbone)112113 # DINO 投影头114 student_head = DINOHead(IN_DIM, OUT_DIM)115 teacher_head = DINOHead(IN_DIM, OUT_DIM)116117 # 教师网络与学生网络结构完全相同,但初始权重也相同118 teacher_head.load_state_dict(student_head.state_dict())119120 # 冻结教师网络的梯度,因为它通过 EMA 更新121 for p in teacher_backbone.parameters():122 p.requires_grad = False123 for p in teacher_head.parameters():124 p.requires_grad = False125126 # --- 损失函数和优化器 ---127 dino_loss_fn = DINOLoss(128 out_dim=OUT_DIM,129 n_global_crops=N_GLOBAL_CROPS,130 n_local_crops=N_LOCAL_CROPS,131 student_temp=0.1,132 teacher_temp=0.05133 )134 # 优化器只更新学生网络的参数135 params = list(student_backbone.parameters()) + list(student_head.parameters())136 optimizer = torch.optim.AdamW(params, lr=0.0005)137138 # --- 模拟一次训练迭代 ---139 # 模拟多 crop 输入140 # 真实场景中,这是由数据加载器生成的141 # (batch_size * n_crops, channels, height, width)142 dummy_input = torch.randn(BATCH_SIZE * N_CROPS, 3, 224, 224)143 # 伪 backbone 的输入144 dummy_features = torch.randn(BATCH_SIZE * N_CROPS, 512)145146 # 学生网络前向传播 (所有 crops)147 student_feats = student_backbone(dummy_features)148 student_output = student_head(student_feats)149150 # 教师网络前向传播 (仅 global crops)151 with torch.no_grad(): # 明确告知 PyTorch 此处无需计算梯度152 teacher_feats = teacher_backbone(dummy_features[:BATCH_SIZE * N_GLOBAL_CROPS])153 teacher_output = teacher_head(teacher_feats)154155 # 计算损失156 loss = dino_loss_fn(student_output, teacher_output)157158 # 反向传播和优化159 optimizer.zero_grad()160 loss.backward()161 optimizer.step()162163 # 更新教师网络权重164 update_teacher_ema(student_backbone, teacher_backbone, momentum=0.996)165 update_teacher_ema(student_head, teacher_head, momentum=0.996)166167 print(f"单次迭代完成。计算出的损失为: {loss.item():.4f}")168 print("教师网络权重已通过 EMA 更新。")
工程实践
-
使用场景:
- 特征提取器 (Feature Extractor): DINO/DINOv2 最强大的用途是作为通用的、无需微调的特征提取器。预训练好的
ViT主干网络可以直接用于下游任务,如图像分类、语义分割、目标检测、图像检索等,其 [CLS] token 或 patch tokens 具有丰富的语义信息,性能媲美甚至超越有监督预训练模型。 - 模型初始化: 使用 DINO/DINOv2 的权重来初始化模型,再在特定任务上进行微调(fine-tuning),通常能比从零开始或用 ImageNet 监督预训练的模型取得更好的性能和更快的收敛速度。
- 细粒度任务: DINO 学到的特征对物体的局部细节有很好的感知,因此在细粒度识别、实例分割等任务上表现优异。
- 特征提取器 (Feature Extractor): DINO/DINOv2 最强大的用途是作为通用的、无需微调的特征提取器。预训练好的
-
超参数选择:
- EMA 动量 : 这是最敏感的超参数之一。通常从 0.996 开始,在训练过程中通过 cosine schedule 逐渐增加到 1.0。动量太低会导致训练不稳定,太高则教师更新过慢,无法跟上学生的学习步伐。
- 温度 : 教师温度 需足够低(如 0.04-0.07)以产生尖锐的分布。学生温度 相对较高(如 0.1)。两者差距不宜过大。
- Multi-crop 配置: 全局视图通常为 224x224,局部视图为 96x96。局部视图的数量是计算成本和性能的权衡,通常在 4-10 个之间。更多的局部视图能提供更丰富的学习信号,但会显著增加计算量和显存占用。
- 优化器: AdamW 是标准选择。DINOv2 的实践表明,对于超大规模训练,使用 NVIDIA Apex 提供的
FusedAdamW可以提升训练速度。
-
性能 / 显存 / 吞吐 的权衡:
- 多 crop vs. 吞吐量: Multi-crop 是 DINO 性能的关键,但也是计算瓶颈。每个额外的 crop 都会增加一次前向传播的计算量。在资源有限时,可以适当减少局部视图的数量。
- 模型大小: DINO 对大模型(如 ViT-L, ViT-G)的扩展性很好。模型越大,从自监督学习中获益越多。但训练和推理成本也随之急剧上升。DINOv2 表明 ViT-L/14 是一个性能和效率的甜点。
- 混合精度训练: 使用 FP16 或 BFloat16 是训练大模型的标配,可以节省近一半的显存并加速计算。但需要注意数值稳定性,DINOv2 为此特别调整了网络层(如 LayerScale)和优化器。
常见误区与边界情况
-
误区1:DINO 是普通的知识蒸馏
- 辨析: DINO 是自蒸馏。在传统知识蒸馏中,教师是一个固定的、预先训练好的强大模型。而在 DINO 中,教师是学生自身的 EMA,它与学生共同进化。这不是单向的知识传递,而是一个动态的自学习过程。
-
误区2:教师网络也通过梯度下降进行训练
- 辨析: 这是一个核心误解。教师网络的权重完全不通过反向传播更新。它的唯一更新来源是学生权重的 EMA。损失函数的梯度只流向学生网络。
-
失败模式:模型崩溃 (Collapse)
- 现象: 网络的输出变成一个常数向量,或者所有输出维度都均等,导致损失为零或一个固定值,模型学不到任何有效信息。这是所有非对比(non-contrastive)自监督学习方法的主要挑战。
- DINO 的对策:
- EMA 教师: 提供稳定目标,是防止崩溃的第一道防线。
- 中心化 (Centering): 防止输出的某一维度“饱和”或“死亡”,强制网络利用所有输出维度。
- 锐化 (Sharpening): 避免教师输出平凡的均匀分布。
- 边界情况: 在极小的 batch size 下,中心化的估计会非常不稳定,可能损害训练。因此 DINO 通常需要较大的 batch size。
-
常见面试追问:
- 问:为什么不直接用上一个 epoch 的学生作为教师,而要用 EMA?
- 答: 上一个 epoch 的学生权重变化可能非常剧烈(noisy),将其作为目标会导致训练不稳定。EMA 像一个低通滤波器,平滑了学生权重的更新,提供了一个更稳定、可靠的教学信号,这是防止模型崩溃的关键。
- 问:DINO 与 BYOL、MoCo 等其他 SSL 方法有何异同?
- 答:
- 与 MoCo (对比学习): MoCo 依赖大量的负样本对进行对比学习。DINO 属于非对比方法,避免了负样本采样带来的复杂性和潜在偏差。
- 与 BYOL (非对比学习): 两者都使用 EMA 教师。但 BYOL 在学生网络上增加了一个额外的预测头(predictor),并只在这个预测头上计算损失,结构不对称。DINO 的学生和教师网络结构对称(除了不共享权重),通过中心化和锐化来防止崩溃,机制更简洁。
- 答:
- 问:DINOv2 为什么能在下游任务上实现强大的零样本(zero-shot)或少样本(few-shot)性能?
- 答: 这归功于三个因素的结合:1) 海量且高质量的数据 (LVD-142M) 提供了足够的多样性;2) 强大的学习范式 (DINO + iBOT) 迫使模型学习到底层的、可组合的视觉概念,而不仅仅是分类边界;3) 大容量模型 (ViT-L/G) 有能力吸收并存储这些复杂的知识。最终得到的特征具有极强的泛化能力,无需微调就能适应新任务。
- 问:为什么不直接用上一个 epoch 的学生作为教师,而要用 EMA?