§1.2.12

AIM / AIMv2 自回归视觉预训练?

核心概念

AIM (Autoregressive Image Models) 是一种针对视觉 Transformer (ViT) 的自监督预训练框架。其核心思想借鉴了自然语言处理中的自回归模型(如 GPT),将图像预训练任务构建为一个“预测未来”的过程。具体来说,AIM 将图像的一部分(例如,上半部分图像块)作为上下文(“过去”),并训练模型去自回归地预测剩余被遮蔽部分(“未来”)的视觉特征。AIM 不直接预测像素值,而是预测由一个独立的、预训练好的“视觉分词器”(Target Model,如 CLIP 的图像编码器)所提取的特征表示,这使得模型能专注于学习更高层次的语义信息。

原理与推导

AIM 的核心在于其独特的预训练任务设置,它将掩码图像建模(Masked Image Modeling)与自回归预测相结合。

1. 整体流程

给定一张输入图像 xx,其预训练过程如下:

  1. 图像分块 (Patchify): 将图像 xx 分割成 NN 个不重叠的图像块(patches)序列 P={p1,p2,...,pN}P = \{p_1, p_2, ..., p_N\}
  2. 目标生成 (Target Generation): 使用一个冻结的、预训练好的目标编码器 gg (例如 CLIP ViT),对所有图像块 PP 进行编码,得到目标特征序列 T={t1,t2,...,tN}T = \{t_1, t_2, ..., t_N\},其中 ti=g(pi)t_i = g(p_i)。这些特征是模型需要预测的“答案”。
  3. 掩码策略 (Masking): 采用一种特定的策略将图像块序列分为两部分:可见块集合 PvisP_{vis} 和被掩码块集合 PmaskP_{mask}。AIM 采用一种简单的块状掩码(Block-wise Masking),例如,保留前 50% 的块作为可见块,遮蔽后 50% 的块。这种策略天然地定义了“过去”(可见区)和“未来”(掩码区)。
  4. 自回归预测 (Autoregressive Prediction):
    • 将可见块序列 PvisP_{vis} 输入到待训练的 ViT 模型 fθf_\theta 中,得到其上下文表示。
    • 模型 fθf_\theta 的任务是基于 PvisP_{vis} 的信息,预测出被掩码块 PmaskP_{mask} 对应的目标特征 {tipiPmask}\{t_i | p_i \in P_{mask}\}

2. 数学公式与推导

从概率角度看,AIM 的目标是最大化在给定可见块条件下,被掩码块目标特征的对数似然:

maxθxDlogP(TmaskPvis;θ)\max_{\theta} \sum_{x \in \mathcal{D}} \log P(T_{mask} | P_{vis}; \theta)

这里的“自回归”体现在 P(TmaskPvis)P(T_{mask} | P_{vis}) 的分解上。如果严格按照时序自回归,预测顺序很重要:

logP(TmaskPvis)=piPmasklogP(tiPvis,{tj}pjPmask,j<i)\log P(T_{mask} | P_{vis}) = \sum_{p_i \in P_{mask}} \log P(t_i | P_{vis}, \{t_j\}_{p_j \in P_{mask}, j<i})

然而,在 AIM 的原始实现中,为了计算效率,它做了一个简化。它并非逐个生成掩码块的特征,而是并行地预测所有掩码块的特征。其“自回归”性质主要体现在信息流的单向性上:可见区的信息可以流向掩码区,但反之不行。这通过块状掩码策略得以实现。

因此,实际优化的损失函数是一个回归损失,通常是均方误差(MSE Loss),计算模型预测值与真实目标特征之间的差距:

LAIM=piPmaskfθ(Pvis)iti22\mathcal{L}_{AIM} = \sum_{p_i \in P_{mask}} \| f_\theta(P_{vis})_i - t_i \|_2^2

其中 fθ(Pvis)if_\theta(P_{vis})_i 表示模型基于可见块 PvisP_{vis} 对第 ii 个掩码块特征的预测,而 ti=g(pi)t_i = g(p_i) 是由目标编码器生成的真实特征。

AIMv2 的改进

AIMv2 引入了多尺度和分层预测的思想,使其学习的表示更具层次性。

  • 多尺度目标: 目标编码器 gg 不再只提供最后一层的特征,而是提供多个中间层的特征。例如,浅层特征捕捉局部纹理,深层特征捕捉全局语义。
  • 分层预测: 待训练的模型 fθf_\theta 被设计成一个分层结构,其浅层被训练去预测目标编码器的浅层特征,深层去预测深层特征。这强制模型在不同层级学习不同尺度的信息,提升了表示的质量和泛化能力。

算法复杂度

  • 时间复杂度: 主要由 ViT 中的自注意力机制决定。对于 NvisN_{vis} 个可见块和维度为 DD 的模型,编码器的时间复杂度为 O(Nvis2D)O(N_{vis}^2 D)。预测阶段的复杂度取决于具体实现,但总体上与 ViT 的标准计算量级相当。
  • 空间复杂度: O(Nvis2+NvisD)O(N_{vis}^2 + N_{vis}D)

直观解释 可以把 AIM 想象成一个“看图填空”的画家。面试官先给画家看一幅画的上半部分(PvisP_{vis}),然后要求他补全下半部分(PmaskP_{mask})。但不同于要求他画出每个像素(像 MAE),面试官给了他一个更高维度的要求:你补全的画的每个局部,其“艺术风格”和“语义内容”(由 CLIP 模型定义的特征 tit_i)必须和原作一模一样。这种方式迫使画家不仅要学习模仿,更要理解图像的结构、内容和语义。

代码实现

下面是一个使用 PyTorch 实现的 AIM 核心逻辑的简化示例。我们将使用 timm 库来获取一个 ViT 模型作为主干,并创建一个 mock 的目标模型。

python
1import torch
2import torch.nn as nn
3from timm.models.vision_transformer import VisionTransformer
4
5# 1. 定义一个 Mock 的目标模型 (Target Model)
6# 在真实场景中,这会是一个预训练好的、冻结的强大模型,如 CLIP ViT-L/14
7class MockTargetModel(nn.Module):
8 def __init__(self, patch_size=16, embed_dim=768):
9 super().__init__()
10 # 目标模型也使用一个ViT结构,但它是被冻结的
11 self.model = VisionTransformer(patch_size=patch_size, embed_dim=embed_dim, depth=12, num_heads=12)
12 # 假设这个模型已经预训练好了,设为评估模式并冻结参数
13 self.eval()
14 for param in self.parameters():
15 param.requires_grad = False
16
17 def forward(self, x):
18 # 为什么这样做:目标模型的作用是为每个图像块生成一个高质量的特征表示。
19 # 我们只取其输出的特征,不关心分类头。
20 # forward_features方法通常返回[B, N, D]的块特征
21 return self.model.forward_features(x)
22
23# 2. 定义 AIM 模型
24class AIM(nn.Module):
25 def __init__(self, patch_size=16, embed_dim=768, depth=12, num_heads=12):
26 super().__init__()
27 # 为什么这样做:这是我们要训练的主干模型。
28 self.encoder = VisionTransformer(patch_size=patch_size, embed_dim=embed_dim, depth=depth, num_heads=num_heads)
29
30 # 为什么这样做:预测头用于将ViT输出的维度映射到目标特征的维度。
31 # 在这个例子中,维度相同,但保留这个结构更具通用性。
32 self.predictor = nn.Linear(embed_dim, embed_dim)
33
34 # 为什么这样做:损失函数使用MSE来计算预测特征和目标特征之间的L2距离。
35 self.loss_fn = nn.MSELoss()
36
37 def forward(self, x, target_model, mask_ratio=0.5):
38 # x: 输入图像, [B, C, H, W]
39
40 # --- 步骤 1: 掩码策略 ---
41 # 为什么这样做:这是AIM的核心机制之一,通过块状掩码定义“过去”和“未来”。
42 # ViT内部会将[B, C, H, W]的图像转换为[B, N, D]的块嵌入。N = (H/patch_size) * (W/patch_size)。
43 # 我们在这里模拟这个过程来确定要掩码的块。
44 num_patches = self.encoder.patch_embed.num_patches
45 num_masked = int(num_patches * mask_ratio)
46 num_visible = num_patches - num_masked
47
48 # 随机生成掩码索引,但为了简单和复现AIM的块状掩码,我们使用固定的块状掩码
49 # 假设块是按光栅顺序排列的
50 visible_indices = torch.arange(0, num_visible)
51 masked_indices = torch.arange(num_visible, num_patches)
52
53 # --- 步骤 2: 生成目标特征 ---
54 with torch.no_grad():
55 # 为什么这样做:目标特征是预先计算好的“真值”,整个过程不应有梯度流向目标模型。
56 all_target_features = target_model(x) # [B, N, D]
57 target_features_for_masked = all_target_features[:, masked_indices, :] # [B, num_masked, D]
58
59 # --- 步骤 3: 前向传播 ---
60 # 为什么这样做:ViT编码器只处理可见的图像块,这是自监督学习中提高效率的关键。
61 # timm的ViT没有直接提供仅处理部分块的接口,我们通过手动选择块嵌入来模拟
62 x_patches = self.encoder.patch_embed(x) # [B, N, D]
63 cls_token = self.encoder.cls_token.expand(x.shape[0], -1, -1)
64 x_patches = torch.cat((cls_token, x_patches), dim=1)
65
66 # 添加位置编码
67 x_patches = x_patches + self.encoder.pos_embed
68
69 # 只选择可见块送入Transformer编码器
70 # 注意:timm ViT的pos_embed包含了cls_token的位置,所以索引要+1
71 visible_patches = x_patches[:, [0] + (visible_indices + 1).tolist(), :]
72
73 # 编码器处理可见块
74 encoded_features = self.encoder.blocks(visible_patches)
75 encoded_features = self.encoder.norm(encoded_features) # [B, 1+num_visible, D]
76
77 # --- 步骤 4: 预测掩码块特征 ---
78 # 为什么这样做:模型需要从可见块的上下文中推断出被遮蔽块的信息。
79 # 这里简化处理,使用[CLS] token的输出来预测所有被遮蔽的块。
80 # 更复杂的实现可能会使用专门的decoder或query token。
81 cls_output = encoded_features[:, 0, :] # [B, D]
82
83 # 为了预测多个掩码块,我们可以重复CLS token或者使用更复杂的机制
84 # 这里我们简单地让一个线性层从CLS token预测所有掩码块
85 predicted_features = self.predictor(cls_output).unsqueeze(1).repeat(1, num_masked, 1) # [B, num_masked, D]
86
87 # --- 步骤 5: 计算损失 ---
88 loss = self.loss_fn(predicted_features, target_features_for_masked)
89
90 return loss, predicted_features, target_features_for_masked
91
92# --- 运行示例 ---
93if __name__ == '__main__':
94 # 初始化模型
95 target_model = MockTargetModel()
96 aim_model = AIM()
97
98 # 创建伪数据
99 dummy_images = torch.randn(2, 3, 224, 224) # Batch size 2
100
101 # 执行前向传播和损失计算
102 loss, _, _ = aim_model(dummy_images, target_model, mask_ratio=0.5)
103
104 print(f"AIM Pre-training Loss: {loss.item()}")
105
106 # 验证梯度只在AIM模型中
107 print("\nAIM Model Parameters Requiring Gradients:")
108 for name, param in aim_model.named_parameters():
109 if param.requires_grad:
110 print(name)
111
112 print("\nTarget Model Parameters Requiring Gradients:")
113 has_grad = False
114 for name, param in target_model.named_parameters():
115 if param.requires_grad:
116 print(name)
117 has_grad = True
118 if not has_grad:
119 print("None (as expected)")

工程实践

  • 使用场景: AIM预训练出的模型是强大的视觉特征提取器。它们在各种下游任务中表现出色,尤其是需要理解图像空间关系和密集预测的任务,如:
    • 图像分类 (Image Classification): 在ImageNet等数据集上进行微调。
    • 目标检测 (Object Detection): 作为Faster R-CNN, Mask R-CNN等检测器的骨干网络。
    • 语义分割 (Semantic Segmentation): 作为U-Net, DeepLab等分割模型的编码器。
  • 超参数选择:
    • 目标模型: 这是最重要的选择之一。通常,目标模型越强大(如CLIP ViT-L/14),提供的监督信号质量越高,预训练效果越好。但这也意味着更大的计算开销来生成目标。
    • 掩码率 (Mask Ratio): AIM论文中发现50%到75%的掩码率效果较好。较高的掩码率使任务更难,迫使模型学习更鲁棒的表示。
    • 目标特征归一化: 目标模型(如CLIP)输出的特征向量可能没有固定的范围。在计算损失之前,对目标特征和预测特征进行L2归一化,可以使训练过程更稳定。
    • 学习率与优化器: 通常使用AdamW优化器,配合cosine学习率衰减策略。由于是自监督学习,预训练过程通常需要较长的训练周期(数百个epochs)。
  • 性能/显存/吞吐的权衡:
    • 预训练成本: AIM的预训练成本非常高,需要大规模无标签数据集(如ImageNet-22K, JFT-300M)和大量的TPU/GPU小时。
    • 目标生成: 如果在数据加载时动态生成目标特征,会增加数据预处理的CPU/GPU负担。一种常见的优化是提前计算好所有图像的目标特征,并存储起来,但这需要巨大的磁盘空间。
    • 模型大小: 更大的ViT模型(如ViT-L, ViT-H)通常能获得更好的性能,但训练和推理的成本也更高。
  • 常见坑和调试技巧:
    • 损失不下降: 检查目标特征是否正确归一化。检查学习率是否过高或过低。
    • 可视化是关键: 尽管AIM预测的是特征,但可以通过一个“解码器”(例如,找到特征空间中最近邻的真实图像块)来可视化模型的预测结果。如果模型能重建出有意义的结构,说明训练在正确的轨道上。
    • 模型实现细节: ViT的实现细节(如位置编码、LayerNorm的位置)对最终性能有影响。尽量与经过验证的开源实现(如timm)保持一致。

常见误区与边界情况

  • 误区1: AIM预测像素值
    • 错误点: 认为AIM像MAE(Masked Autoencoders)一样,直接重建图像的原始像素。
    • 纠正: AIM预测的是由一个强大教师模型(如CLIP)提取的抽象特征。这使得模型不必浪费容量去拟合高频、低语义的像素细节,而是专注于学习与教师模型对齐的语义概念。
  • 误区2: AIM与BEiT/MAE没有区别
    • 错误点: 将所有掩码图像建模方法混为一谈。
    • 纠正:
      • AIM vs. MAE: AIM预测特征,MAE预测像素。AIM是自回归框架,MAE是并行解码。
      • AIM vs. BEiT: AIM的目标是连续特征,BEiT的目标是离散的视觉词元(token)(来自dVAE)。AIM的框架类似GPT(自回归),BEiT的框架类似BERT(掩码语言模型)。
  • 误区3: “自回归”意味着逐像素/逐块的缓慢生成
    • 错误点: 将其与文本生成中的严格串行过程等同。
    • 纠正: AIM的“自回归”更多是概念上的,体现在信息流的单向性(从可见区到掩码区)。其实际实现(特别是块状掩码)允许对所有被掩码块进行并行预测,从而保持了计算效率。它是一种区域级别(region-level)的自回归。
  • 边界情况与面试追问:
    • 问: 如果目标模型很差会怎么样?
      • : 预训练的效果会大打折扣。AIM的性能上限在很大程度上受限于教师模型的“知识水平”。如果教师模型本身不能很好地理解图像,它提供的目标特征就是“垃圾”,学生模型自然也学不到有用的东西(Garbage in, garbage out)。
    • 问: 为什么不直接用一个随机初始化的模型当目标编码器?
      • : 这会导致“盲人引路”的问题。一个未经训练的模型输出的特征是随机且无意义的,无法提供有效的学习信号。整个训练过程可能会崩溃或学到平凡解。
    • 问: AIMv2的多尺度设计有什么好处?
      • : 它迫使模型在不同深度学习不同层次的抽象。浅层负责重建纹理细节(对应教师模型的浅层特征),深层负责理解全局结构(对应教师模型的深层特征)。这使得最终学到的特征表示更加丰富和全面,对下游任务的适应性更强。
相关题目