DeiT 蒸馏 token 的训练技巧?
核心概念
DeiT (Data-efficient Image Transformer) 蒸馏是一种专门为 Vision Transformer (ViT) 设计的知识蒸馏训练策略。其核心思想是引入一个额外的、可学习的蒸馏令牌 (distillation token)。这个令牌与标准的 [CLS] 令牌和图像块令牌 (patch tokens) 一同进入 Transformer 网络进行学习。[CLS] 令牌用于拟合真实的硬标签 (ground truth labels),而蒸馏令牌则专门用于拟合教师模型的预测输出,从而将教师模型的“知识”高效地迁移到学生模型中。
原理与推导
DeiT 蒸馏的核心在于其独特的网络结构和损失函数设计,旨在让学生模型同时从真实数据标签和强大的教师模型中学习。
1. 架构修改:引入蒸馏令牌
标准的 ViT 模型在输入序列前加入一个 [CLS] 令牌,其在 Transformer Encoder 最后的输出被用于分类。DeiT 在此基础上,额外加入一个 distill 令牌。
- 输入序列构建:
假设图像被切分成 个 patch,其线性投影后的 embedding 为 。
- 标准
ViT输入: - DeiT 输入: 其中 是可学习的 embedding,而 是位置编码,其长度也相应地从 增加到 。
- 标准
这两个特殊令牌都会通过所有的 Transformer 层,与其他所有令牌进行自注意力计算,从而在全局感受野上汇聚信息。
2. 损失函数设计
DeiT 的总损失函数是两个部分的加权和:一部分来自 [CLS] 令牌与真实标签的交叉熵,另一部分来自 distill 令牌与教师模型预测的蒸馏损失。
-
学生模型输出:
[CLS]令牌通过最终的分类头得到 logits: .distill令牌通过另一个独立的分类头得到 logits: . 其中 代表TransformerEncoder 的输出。
-
教师模型输出: 对于同一个输入图像,教师模型输出其预测的 logits 。
-
总损失函数:
其中:
- 是真实标签的 one-hot 向量。
- 是标准的交叉熵损失。
- 是一个超参数,用于平衡两个损失项,通常设为 。
- 是蒸馏损失,DeiT 主要探讨了两种形式:
a) 硬蒸馏 (Hard Distillation) 这是 DeiT 论文中采用的主要方法。教师模型的预测被视为一个“硬”的伪标签。
- 教师模型的预测标签:
- 蒸馏损失:
这种方法的直观解释是:让
distill令牌学会模仿教师模型的最终决策。b) 软蒸馏 (Soft Distillation) 这是知识蒸馏的经典形式,使用 KL 散度来匹配学生和教师的输出概率分布。
- 蒸馏损失:
其中 是蒸馏温度,较高的 会产生更软的概率分布,从而让学生模型学习到教师模型关于“类间相似性”的暗知识。
DeiT 的成功主要归功于硬蒸馏,它实现起来更简单,且效果非常出色,证明了让学生模型直接学习教师模型的决策结果是一种高效的策略。
3. 推理阶段
在推理时,DeiT 结合了 [CLS] 令牌和 distill 令牌的预测结果,以获得更鲁棒的性能。
最终的预测概率是两个头部输出的 Softmax 结果的平均值:
最终预测类别为 。
4. 复杂度分析
- 时间复杂度: 增加一个令牌对
Transformer的计算复杂度影响很小。自注意力机制的复杂度为 ,其中 是序列长度。序列长度从 变为 ,对于典型的 (224x224 图像,16x16 patch),这个变化是微不足道的。 - 空间复杂度: 仅增加了
distill令牌、其对应的位置编码和一个分类头的参数,参数量增加极少。
代码实现
下面是一个简化的 PyTorch 实现,展示了如何在 ViT 中加入蒸馏令牌以及如何计算相应的损失。
1import torch2import torch.nn as nn3import torch.nn.functional as F45class DistillableViT(nn.Module):6 def __init__(self, *, image_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12, mlp_dim=3072):7 super().__init__()8 num_patches = (image_size // patch_size) ** 29 patch_dim = 3 * patch_size ** 2 # 3 channels1011 self.patch_size = patch_size12 self.patch_to_embedding = nn.Linear(patch_dim, dim)1314 # 核心修改:位置编码的长度为 N+215 self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 2, dim))1617 # 核心修改:同时定义 cls 和 distill 令牌18 self.cls_token = nn.Parameter(torch.randn(1, 1, dim))19 self.dist_token = nn.Parameter(torch.randn(1, 1, dim))2021 # 标准的 Transformer Encoder22 encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim, batch_first=True)23 self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)2425 self.to_latent = nn.Identity()2627 # 核心修改:为 cls 和 distill 令牌分别设置分类头28 self.head = nn.Linear(dim, num_classes)29 self.head_dist = nn.Linear(dim, num_classes)3031 def forward(self, img):32 # 1. 将图像切分为 patches 并进行线性投影33 p = self.patch_size34 x = F.unfold(img, kernel_size=p, stride=p).transpose(1, 2)35 x = self.patch_to_embedding(x)36 b, n, _ = x.shape3738 # 2. 准备 cls 和 distill 令牌39 cls_tokens = self.cls_token.expand(b, -1, -1)40 dist_tokens = self.dist_token.expand(b, -1, -1)4142 # 3. 核心修改:在序列最前面拼接 cls 和 distill 令牌43 # 输入序列: [CLS, DISTILL, PATCH_1, PATCH_2, ...]44 x = torch.cat((cls_tokens, dist_tokens, x), dim=1)4546 # 4. 添加位置编码47 x += self.pos_embedding[:, :(n + 2)]4849 # 5. 通过 Transformer Encoder50 x = self.transformer_encoder(x)5152 # 6. 核心修改:分离 cls 和 distill 令牌的输出53 # x[:, 0] 是 cls 令牌的输出, x[:, 1] 是 distill 令牌的输出54 x_cls = self.to_latent(x[:, 0])55 x_dist = self.to_latent(x[:, 1])5657 # 7. 通过各自的分类头得到 logits58 logits_cls = self.head(x_cls)59 logits_dist = self.head_dist(x_dist)6061 # 在训练时,返回两个 logits62 # 在推理时,可以根据需要返回两者或其组合63 return logits_cls, logits_dist6465# --- 模拟训练过程 ---66if __name__ == '__main__':67 # 1. 初始化模型和数据68 # 学生模型69 student_model = DistillableViT(num_classes=10, image_size=32, patch_size=4, dim=128, depth=4, heads=4, mlp_dim=256)7071 # 模拟一个教师模型 (在实际应用中,这是一个预训练好的、更强大的模型)72 # 这里用一个函数简单模拟73 def teacher_model_mock(images):74 # 教师模型通常更强大,这里仅为演示返回一个随机的 logits75 return torch.randn(images.shape[0], 10)7677 # 模拟输入数据和真实标签78 dummy_images = torch.randn(4, 3, 32, 32) # batch_size=479 true_labels = torch.randint(0, 10, (4,))8081 # 2. 前向传播82 # 学生模型输出两个 logits83 logits_cls, logits_dist = student_model(dummy_images)8485 # 获取教师模型的输出86 with torch.no_grad(): # 教师模型不参与梯度更新87 logits_teacher = teacher_model_mock(dummy_images)8889 # 3. 计算损失 (硬蒸馏)90 loss_ce = F.cross_entropy(logits_cls, true_labels)9192 # 教师模型的硬标签93 teacher_labels = torch.argmax(logits_teacher, dim=1)94 loss_distill = F.cross_entropy(logits_dist, teacher_labels)9596 # 加权组合损失97 lambda_param = 0.598 total_loss = (1 - lambda_param) * loss_ce + lambda_param * loss_distill99100 print(f"真实标签: {true_labels}")101 print(f"教师预测标签: {teacher_labels}")102 print(f"CLS Token Loss (L_CE): {loss_ce.item():.4f}")103 print(f"Distill Token Loss (L_distill): {loss_distill.item():.4f}")104 print(f"Total Loss: {total_loss.item():.4f}")105106 # 4. 反向传播和优化 (标准流程)107 # total_loss.backward()108 # optimizer.step()109 # ...110111 # --- 模拟推理过程 ---112 student_model.eval()113 with torch.no_grad():114 logits_cls_eval, logits_dist_eval = student_model(dummy_images)115116 # 结合两个头的预测117 final_probs = 0.5 * (F.softmax(logits_cls_eval, dim=1) + F.softmax(logits_dist_eval, dim=1))118 final_predictions = torch.argmax(final_probs, dim=1)119120 print("\n--- 推理阶段 ---")121 print(f"最终预测结果: {final_predictions}")
工程实践
-
使用场景:
- 数据量不足时: DeiT 的核心价值在于 "Data-efficient"。当训练数据有限时,从一个在大规模数据集(如 JFT-300M 或 ImageNet-21k)上预训练的教师模型中蒸馏知识,能显著提升
ViT在中小型数据集(如 ImageNet-1k)上的性能,甚至超越在同样数据上训练的 CNN。 - 模型压缩与加速: 当需要部署一个轻量级的
ViT模型时,可以先训练一个更大、更强的ViT或 CNN 作为教师,然后通过蒸馏将其知识迁移到小模型上,使小模型达到远超其独立训练所能达到的性能。 - 利用无标签数据: 如果有大量无标签数据,可以先用教师模型为这些数据打上伪标签,然后用
(无标签数据, 伪标签)对学生模型进行纯蒸馏训练(即设置 )。
- 数据量不足时: DeiT 的核心价值在于 "Data-efficient"。当训练数据有限时,从一个在大规模数据集(如 JFT-300M 或 ImageNet-21k)上预训练的教师模型中蒸馏知识,能显著提升
-
超参数选择:
lambda: 论文中0.5是一个鲁棒的选择,意味着同等重视真实标签和教师指导。可以根据教师模型的准确度和数据标签的噪声程度进行微调。如果教师模型极强,可以适当增加lambda。tau(软蒸馏): 如果选择软蒸馏,tau通常取1到10之间的值。较高的tau关注更广泛的类间关系,较低的tau则接近硬蒸馏。硬蒸馏在 DeiT 中被证明非常有效且无需调参,是首选。- 教师模型选择: 教师模型应显著优于学生模型。DeiT 论文成功地使用了 CNN (如 RegNet) 作为教师来教
ViT学生,证明了跨架构蒸馏的可行性。
-
性能权衡:
- 训练开销: 蒸馏训练需要额外运行教师模型的前向传播,会增加训练时间和计算成本。一个常见的优化是:提前将教师模型对整个训练集的预测结果缓存下来,训练学生时直接读取,避免重复计算。
- 推理开销: 推理时,学生模型增加的蒸馏令牌和第二个分类头带来的开销极小,几乎不影响吞吐量和延迟。最终的 Softmax 平均操作计算量也可以忽略不计。
-
调试技巧:
- 监控两个损失: 训练时应同时监控
L_CE和L_distill。如果L_distill下降缓慢或不下降,可能意味着教师模型质量不高,或者学生模型没有足够的容量学习教师的行为。 - 检查教师伪标签: 随机抽查一些样本,看看教师模型的预测
y_teacher是否合理。如果教师模型在很多简单样本上都出错,那么蒸馏的效果会大打折扣。 - 确认推理逻辑: 确保推理时正确地组合了两个头的输出。一个常见的错误是只用了
[CLS]头,这会浪费蒸馏带来的收益。
- 监控两个损失: 训练时应同时监控
常见误区与边界情况
-
误区一:蒸馏令牌只是另一个
[CLS]令牌 不完全是。虽然它们在结构上类似(都是可学习的向量),但它们在功能上是解耦的。[CLS]令牌的梯度完全来自与真实标签的比较,而distill令牌的梯度完全来自与教师模型输出的比较。这种分离避免了单个令牌接收冲突的监督信号(例如,当真实标签与教师预测不一致时),被认为是 DeiT 成功的关键之一。 -
误区二:硬蒸馏信息量比软蒸馏少,效果一定更差 直觉上软蒸馏传递了更丰富的“暗知识”,但 DeiT 的实验表明,对于分类任务,一个高度准确的教师模型提供的硬标签本身就是非常强的正则化信号。它迫使学生模型在决策边界上与教师模型保持一致,这种方法简单、高效且效果惊人。在实践中,硬蒸馏往往是更优先尝试的选项。
-
边界情况与面试追问:
- 问:如果教师模型和真实标签冲突怎么办?
答:这是蒸馏中常见的情况。DeiT 的双令牌设计优雅地处理了这一点。
[CLS]令牌学习拟合真实标签,distill令牌学习拟合教师标签。通过损失函数的加权和,优化器会找到一个平衡点。最终模型在推理时融合两者的预测,相当于一个微型的模型集成,能够综合两方面的信息,通常会比只学习任何一方更鲁棒。 - 问:为什么不直接将蒸馏损失应用在
[CLS]令牌上? 答:可以这样做,这属于传统的知识蒸馏。但如上所述,这会导致[CLS]令牌同时接收两个可能冲突的监督信号(来自 的梯度和来自 的梯度),可能会干扰学习过程。DeiT 引入distill令牌,创建了一个专门的“通道”来接收教师的知识,使得梯度路径更清晰,优化目标更明确。 - 问:当 时,这个方法用在什么场景? 答:当 时,模型完全忽略真实标签,只学习教师的输出。这在半监督或自监督学习中非常有用。例如,你有一个在海量数据上训练的昂贵教师模型,和大量无标签数据。你可以用教师模型给这些无标签数据打上伪标签,然后用纯蒸馏()的方式训练一个轻量级的学生模型。
- 问:如果教师模型和真实标签冲突怎么办?
答:这是蒸馏中常见的情况。DeiT 的双令牌设计优雅地处理了这一点。