§1.2.10

DINO / DINOv2 自蒸馏 + 多 crop 策略?

核心概念

DINO (self-DIstillation with NO labels) 是一种自监督学习(SSL)框架,其核心思想是“自蒸馏”。它构建了两个结构相同但权重不同的网络:一个学生网络(student)和一个教师网络(teacher)。教师网络的权重是学生网络权重的指数移动平均(Exponential Moving Average, EMA),因此教师是学生“缓慢”演进的、更稳定的版本。结合多尺度裁剪(Multi-crop)策略,学生网络被训练来预测教师网络在不同(通常是更大)图像视图下的输出分布,从而在没有标签的情况下学习到强大的语义表征。DINOv2 是 DINO 的大规模升级版,通过使用更庞大的精选数据集、更稳定的训练技术和更强的模型架构,成为了一个强大的视觉基础模型。

原理与推导

DINO 的目标是让学生网络 gθsg_{\theta_s} 的输出,在给定不同图像视图(crop)时,能够匹配教师网络 gθtg_{\theta_t} 的输出。这两个网络都由一个主干网络(如 ViT)和一个投影头(projection head)组成。

1. 输出概率分布

对于一个输入图像视图 xx,学生和教师网络分别输出一个 KK 维的特征向量。这些特征向量通过一个带有温度参数 τ\tausoftmax 函数转换为概率分布 PsP_sPtP_t

P(x)(i)=exp(g(x)(i)/τ)k=1Kexp(g(x)(k)/τ)P(x)^{(i)} = \frac{\exp(g(x)^{(i)} / \tau)}{\sum_{k=1}^K \exp(g(x)^{(k)} / \tau)}
  • g(x)(i)g(x)^{(i)} 是网络对输入 xx 输出的第 ii 维 logits。
  • τ\tau 是温度参数。学生网络使用较高的温度 τs\tau_s(如 0.1),教师网络使用较低的温度 τt\tau_t(如 0.04-0.07)。较低的温度会使输出分布变得“尖锐”(sharpening),让教师的预测更具确定性,为学生提供更强的学习信号。

2. 多尺度裁剪策略 (Multi-crop)

这是 DINO 的关键数据增强策略。对于每张输入图片,会生成一个视图集合 VV,包含:

  • 2 个分辨率较高的全局视图(global views),例如 224x224。
  • 多个分辨率较低的局部视图(local views),例如 96x96。

核心思想: 教师网络只看到全局视图,而学生网络需要看到所有视图(全局+局部)。学生的目标是,无论它看到的是全局视图还是局部视图,其输出都应与教师在某个全局视图上的输出保持一致。这迫使模型学习到“部分-整体”的对应关系和尺度不变的特征。

3. 损失函数

损失函数是学生和教师输出分布之间的交叉熵(cross-entropy)。对于一张图片,其损失计算如下:

L=xsVxt{x1g,x2g},xtxsPt(xt)logPs(xs)\mathcal{L} = \sum_{x_s \in V} \sum_{x_t \in \{x_{1}^g, x_{2}^g\}, x_t \neq x_s} - P_t(x_t) \log P_s(x_s)
  • xsx_s 是送入学生网络的任一视图(全局或局部)。
  • xtx_t 是送入教师网络的全局视图之一,且 xtx_txsx_s 不是同一个视图。
  • Pt(xt)P_t(x_t) 是教师的输出分布(作为伪标签),Ps(xs)P_s(x_s) 是学生的输出分布。
  • 动机:这个公式意味着,对于学生看到的每一个 crop(无论是大是小),它都必须预测出教师在另一个“全局” crop 上看到的结果。例如,即使学生只看到一只猫的耳朵(局部视图),它也应该输出与教师看到整只猫(全局视图)时相似的概率分布。

4. 教师网络更新 (EMA)

教师网络的权重 θt\theta_t 不通过反向传播更新。相反,它是学生网络权重 θs\theta_s 的指数移动平均值。在每次训练迭代后,教师权重按以下方式更新:

θtλθt+(1λ)θs\theta_t \leftarrow \lambda \theta_t + (1 - \lambda) \theta_s
  • λ\lambda 是一个动量系数,通常是一个接近 1 的值(例如,从 0.996 逐渐增加到 1)。
  • 动机:这种“慢速”更新机制使得教师网络比学生网络更稳定。学生网络在不断探索和学习,而教师网络则提供了一个稳定、可靠的“平均”目标,有效防止了训练崩溃(即学生和教师输出相同但无意义的恒定值)。

5. 防止模型崩溃 (Collapse) 的额外机制

除了 EMA 教师,DINO 还使用了两种关键技术来避免模型崩溃:

  • 中心化 (Centering):教师的输出在送入损失函数前会减去一个中心值 cc。这个中心值 cc 是教师在整个批次(batch)上输出特征的指数移动平均。gt(x)gt(x)cg_t(x) \leftarrow g_t(x) - c。这可以防止某一维度长期占据主导地位,鼓励网络利用所有维度。
  • 锐化 (Sharpening):如前所述,使用较低的教师温度 τt\tau_t 会让教师的输出分布更尖锐,避免其输出均匀分布这种平凡解。

DINOv2 的改进

DINOv2 继承了 DINO 的核心思想,并在以下方面进行了大规模扩展和优化:

  • 大规模精选数据集:构建了一个包含 1.42 亿张图片的 LVD-142M 数据集。
  • 算法增强:结合了 iBOT 的掩码图像建模(MIM)损失,并对 Swin Transformer 结构进行修改以提高大规模训练的稳定性。
  • 工程优化:使用 FusedAdamW 优化器和高效的 FP16/BFloat16 训练,实现了在海量数据上的高效稳定训练。

代码实现

以下 PyTorch 代码片段展示了 DINO 损失函数和教师更新的核心逻辑。这是一个简化的示例,旨在阐明原理,而非一个完整的训练脚本。

python
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4from copy import deepcopy
5
6class DINOHead(nn.Module):
7 """
8 DINO 的投影头,一个简单的多层感知机 (MLP)。
9 """
10 def __init__(self, in_dim, out_dim, hidden_dim=2048, n_layers=3):
11 super().__init__()
12 layers = []
13 # 输入层
14 layers.append(nn.Linear(in_dim, hidden_dim))
15 layers.append(nn.GELU())
16 # 隐藏层
17 for _ in range(n_layers - 2):
18 layers.append(nn.Linear(hidden_dim, hidden_dim))
19 layers.append(nn.GELU())
20 # 输出层
21 layers.append(nn.Linear(hidden_dim, out_dim))
22 self.mlp = nn.Sequential(*layers)
23
24 def forward(self, x):
25 return self.mlp(x)
26
27class DINOLoss(nn.Module):
28 """
29 DINO 损失函数的核心实现。
30 """
31 def __init__(self, out_dim, n_global_crops, n_local_crops, student_temp, teacher_temp, center_momentum=0.9):
32 super().__init__()
33 self.student_temp = student_temp
34 self.teacher_temp = teacher_temp
35 self.n_crops = n_global_crops + n_local_crops
36 self.n_global_crops = n_global_crops
37 self.center_momentum = center_momentum
38 # 注册一个持久化的 buffer `center`,它不是模型参数,但会随模型状态一起保存
39 self.register_buffer("center", torch.zeros(1, out_dim))
40
41 def forward(self, student_output, teacher_output):
42 """
43 计算 DINO 损失。
44 student_output: (batch_size * n_crops, out_dim)
45 teacher_output: (batch_size * n_global_crops, out_dim)
46 """
47 # 1. 对学生和教师的输出应用温度 softmax
48 student_out = student_output / self.student_temp
49
50 # 2. 对教师的输出进行锐化和中心化
51 # 教师不参与反向传播,所以使用 .detach()
52 teacher_out = F.softmax((teacher_output - self.center) / self.teacher_temp, dim=-1)
53 teacher_out = teacher_out.detach()
54
55 total_loss = 0
56 n_loss_terms = 0
57
58 # 3. 计算交叉熵损失
59 # 将学生输出按 crop 分割
60 student_out_chunks = student_out.chunk(self.n_crops, dim=0)
61 # 将教师输出按 global crop 分割
62 teacher_out_chunks = teacher_out.chunk(self.n_global_crops, dim=0)
63
64 for i, teacher_chunk in enumerate(teacher_out_chunks):
65 for j, student_chunk in enumerate(student_out_chunks):
66 # 教师和学生不能来自同一个原始 crop
67 if i == j:
68 continue
69
70 # 计算学生在某个 crop 上的输出与教师在另一个 global crop 上的输出的交叉熵
71 loss = torch.sum(-teacher_chunk * F.log_softmax(student_chunk, dim=-1), dim=-1)
72 total_loss += loss.mean()
73 n_loss_terms += 1
74
75 total_loss /= n_loss_terms
76
77 # 4. 更新中心值 center
78 self.update_center(teacher_output)
79
80 return total_loss
81
82 @torch.no_grad()
83 def update_center(self, teacher_output):
84 """
85 使用 EMA 更新中心值。
86 """
87 batch_center = torch.mean(teacher_output, dim=0, keepdim=True)
88 self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
89
90@torch.no_grad()
91def update_teacher_ema(student, teacher, momentum):
92 """
93 使用 EMA 更新教师网络的权重。
94 """
95 for param_s, param_t in zip(student.parameters(), teacher.parameters()):
96 param_t.data.mul_(momentum).add_(param_s.data, alpha=1 - momentum)
97
98# --- 示例运行 ---
99if __name__ == '__main__':
100 # --- 参数设置 ---
101 BATCH_SIZE = 4
102 IN_DIM = 768 # ViT-Base 的特征维度
103 OUT_DIM = 65536 # DINO 论文中的投影维度
104 N_GLOBAL_CROPS = 2
105 N_LOCAL_CROPS = 6
106 N_CROPS = N_GLOBAL_CROPS + N_LOCAL_CROPS
107
108 # --- 模型初始化 ---
109 # 假设 backbone 是一个 ViT 模型
110 student_backbone = nn.Linear(512, IN_DIM) # 伪 ViT
111 teacher_backbone = deepcopy(student_backbone)
112
113 # DINO 投影头
114 student_head = DINOHead(IN_DIM, OUT_DIM)
115 teacher_head = DINOHead(IN_DIM, OUT_DIM)
116
117 # 教师网络与学生网络结构完全相同,但初始权重也相同
118 teacher_head.load_state_dict(student_head.state_dict())
119
120 # 冻结教师网络的梯度,因为它通过 EMA 更新
121 for p in teacher_backbone.parameters():
122 p.requires_grad = False
123 for p in teacher_head.parameters():
124 p.requires_grad = False
125
126 # --- 损失函数和优化器 ---
127 dino_loss_fn = DINOLoss(
128 out_dim=OUT_DIM,
129 n_global_crops=N_GLOBAL_CROPS,
130 n_local_crops=N_LOCAL_CROPS,
131 student_temp=0.1,
132 teacher_temp=0.05
133 )
134 # 优化器只更新学生网络的参数
135 params = list(student_backbone.parameters()) + list(student_head.parameters())
136 optimizer = torch.optim.AdamW(params, lr=0.0005)
137
138 # --- 模拟一次训练迭代 ---
139 # 模拟多 crop 输入
140 # 真实场景中,这是由数据加载器生成的
141 # (batch_size * n_crops, channels, height, width)
142 dummy_input = torch.randn(BATCH_SIZE * N_CROPS, 3, 224, 224)
143 # 伪 backbone 的输入
144 dummy_features = torch.randn(BATCH_SIZE * N_CROPS, 512)
145
146 # 学生网络前向传播 (所有 crops)
147 student_feats = student_backbone(dummy_features)
148 student_output = student_head(student_feats)
149
150 # 教师网络前向传播 (仅 global crops)
151 with torch.no_grad(): # 明确告知 PyTorch 此处无需计算梯度
152 teacher_feats = teacher_backbone(dummy_features[:BATCH_SIZE * N_GLOBAL_CROPS])
153 teacher_output = teacher_head(teacher_feats)
154
155 # 计算损失
156 loss = dino_loss_fn(student_output, teacher_output)
157
158 # 反向传播和优化
159 optimizer.zero_grad()
160 loss.backward()
161 optimizer.step()
162
163 # 更新教师网络权重
164 update_teacher_ema(student_backbone, teacher_backbone, momentum=0.996)
165 update_teacher_ema(student_head, teacher_head, momentum=0.996)
166
167 print(f"单次迭代完成。计算出的损失为: {loss.item():.4f}")
168 print("教师网络权重已通过 EMA 更新。")

工程实践

  • 使用场景:

    • 特征提取器 (Feature Extractor): DINO/DINOv2 最强大的用途是作为通用的、无需微调的特征提取器。预训练好的 ViT 主干网络可以直接用于下游任务,如图像分类、语义分割、目标检测、图像检索等,其 [CLS] token 或 patch tokens 具有丰富的语义信息,性能媲美甚至超越有监督预训练模型。
    • 模型初始化: 使用 DINO/DINOv2 的权重来初始化模型,再在特定任务上进行微调(fine-tuning),通常能比从零开始或用 ImageNet 监督预训练的模型取得更好的性能和更快的收敛速度。
    • 细粒度任务: DINO 学到的特征对物体的局部细节有很好的感知,因此在细粒度识别、实例分割等任务上表现优异。
  • 超参数选择:

    • EMA 动量 λ\lambda: 这是最敏感的超参数之一。通常从 0.996 开始,在训练过程中通过 cosine schedule 逐渐增加到 1.0。动量太低会导致训练不稳定,太高则教师更新过慢,无法跟上学生的学习步伐。
    • 温度 τt,τs\tau_t, \tau_s: 教师温度 τt\tau_t 需足够低(如 0.04-0.07)以产生尖锐的分布。学生温度 τs\tau_s 相对较高(如 0.1)。两者差距不宜过大。
    • Multi-crop 配置: 全局视图通常为 224x224,局部视图为 96x96。局部视图的数量是计算成本和性能的权衡,通常在 4-10 个之间。更多的局部视图能提供更丰富的学习信号,但会显著增加计算量和显存占用。
    • 优化器: AdamW 是标准选择。DINOv2 的实践表明,对于超大规模训练,使用 NVIDIA Apex 提供的 FusedAdamW 可以提升训练速度。
  • 性能 / 显存 / 吞吐 的权衡:

    • 多 crop vs. 吞吐量: Multi-crop 是 DINO 性能的关键,但也是计算瓶颈。每个额外的 crop 都会增加一次前向传播的计算量。在资源有限时,可以适当减少局部视图的数量。
    • 模型大小: DINO 对大模型(如 ViT-L, ViT-G)的扩展性很好。模型越大,从自监督学习中获益越多。但训练和推理成本也随之急剧上升。DINOv2 表明 ViT-L/14 是一个性能和效率的甜点。
    • 混合精度训练: 使用 FP16 或 BFloat16 是训练大模型的标配,可以节省近一半的显存并加速计算。但需要注意数值稳定性,DINOv2 为此特别调整了网络层(如 LayerScale)和优化器。

常见误区与边界情况

  • 误区1:DINO 是普通的知识蒸馏

    • 辨析: DINO 是蒸馏。在传统知识蒸馏中,教师是一个固定的、预先训练好的强大模型。而在 DINO 中,教师是学生自身的 EMA,它与学生共同进化。这不是单向的知识传递,而是一个动态的自学习过程。
  • 误区2:教师网络也通过梯度下降进行训练

    • 辨析: 这是一个核心误解。教师网络的权重完全不通过反向传播更新。它的唯一更新来源是学生权重的 EMA。损失函数的梯度只流向学生网络。
  • 失败模式:模型崩溃 (Collapse)

    • 现象: 网络的输出变成一个常数向量,或者所有输出维度都均等,导致损失为零或一个固定值,模型学不到任何有效信息。这是所有非对比(non-contrastive)自监督学习方法的主要挑战。
    • DINO 的对策:
      1. EMA 教师: 提供稳定目标,是防止崩溃的第一道防线。
      2. 中心化 (Centering): 防止输出的某一维度“饱和”或“死亡”,强制网络利用所有输出维度。
      3. 锐化 (Sharpening): 避免教师输出平凡的均匀分布。
    • 边界情况: 在极小的 batch size 下,中心化的估计会非常不稳定,可能损害训练。因此 DINO 通常需要较大的 batch size。
  • 常见面试追问:

    • 问:为什么不直接用上一个 epoch 的学生作为教师,而要用 EMA?
      • : 上一个 epoch 的学生权重变化可能非常剧烈(noisy),将其作为目标会导致训练不稳定。EMA 像一个低通滤波器,平滑了学生权重的更新,提供了一个更稳定、可靠的教学信号,这是防止模型崩溃的关键。
    • 问:DINO 与 BYOL、MoCo 等其他 SSL 方法有何异同?
      • :
        • 与 MoCo (对比学习): MoCo 依赖大量的负样本对进行对比学习。DINO 属于非对比方法,避免了负样本采样带来的复杂性和潜在偏差。
        • 与 BYOL (非对比学习): 两者都使用 EMA 教师。但 BYOL 在学生网络上增加了一个额外的预测头(predictor),并只在这个预测头上计算损失,结构不对称。DINO 的学生和教师网络结构对称(除了不共享权重),通过中心化和锐化来防止崩溃,机制更简洁。
    • 问:DINOv2 为什么能在下游任务上实现强大的零样本(zero-shot)或少样本(few-shot)性能?
      • : 这归功于三个因素的结合:1) 海量且高质量的数据 (LVD-142M) 提供了足够的多样性;2) 强大的学习范式 (DINO + iBOT) 迫使模型学习到底层的、可组合的视觉概念,而不仅仅是分类边界;3) 大容量模型 (ViT-L/G) 有能力吸收并存储这些复杂的知识。最终得到的特征具有极强的泛化能力,无需微调就能适应新任务。
相关题目