§1.2.11

I-JEPA / V-JEPA 2 的预测式自监督?

好的,我们来深入解析 I-JEPA 与 V-JEPA 2 的预测式自监督学习范式。

核心概念

I-JEPA(Image Joint-Embedding Predictive Architecture)是一种非生成式的自监督学习方法。其核心思想是,在一个高维的表征空间中,根据一个上下文(Context)块的表征,来预测同一个样本中其他目标(Target)块的表征。它不预测像素本身,而是预测由一个动量更新的目标编码器(Target Encoder)生成的抽象特征。这种“在表征空间中填空”的方式,迫使模型学习数据中更具语义的、可预测的底层规律,而不是去拟合高频、低层次的像素细节。V-JEPA 2 将此思想从图像域扩展到了视频和音频的多模态领域。

原理与推导

JEPA 的架构主要由三个部分组成:编码器(Encoder)、目标编码器(Target Encoder)和预测器(Predictor)。

  1. 编码器 (Encoder) fθf_\theta: 这是我们主要训练的模型,通常是一个 Vision Transformer (ViT)。它接收一部分可见的输入块(上下文块),并输出它们的表征。
  2. 目标编码器 (Target Encoder) fθˉf_{\bar{\theta}}: 它的网络结构与编码器完全相同,但不通过反向传播进行训练。其权重 θˉ\bar{\theta} 是编码器权重 θ\theta 的指数移动平均(Exponential Moving Average, EMA)。它接收目标块(未被遮蔽的原始图像块)作为输入,生成稳定、高质量的目标表征。
  3. 预测器 (Predictor) gϕg_\phi: 这是一个相对较小的模型(如一个较浅的 ViT),它接收编码器输出的上下文表征,并结合目标块的位置信息,来预测目标块的表征。

数学推导:

假设输入图像为 xx,我们将其分割成一系列不重叠的图像块(patches)。我们从中采样一个上下文块集合 xCx_C 和多个目标块集合 xT1,xT2,...,xTNx_{T_1}, x_{T_2}, ..., x_{T_N}

  1. 上下文表征: 编码器 fθf_\theta 处理上下文块 xCx_C,输出上下文表征 zCz_C。在 ViT 中,这对应于上下文块的输出 token 序列。 zC=fθ(xC)z_C = f_\theta(x_C)

  2. 目标表征: 目标编码器 fθˉf_{\bar{\theta}} 处理未被遮蔽的原始目标块 xTix_{T_i},输出作为“真值”的目标表征 zTiz_{T_i}。这一步不计算梯度zTi=fθˉ(xTi)(stop-gradient)z_{T_i} = f_{\bar{\theta}}(x_{T_i}) \quad (\text{stop-gradient})

  3. 预测: 预测器 gϕg_\phi 接收上下文表征 zCz_C 和目标块 xTix_{T_i} 对应的位置编码 pTip_{T_i},输出预测的目标表征 z^Ti\hat{z}_{T_i}z^Ti=gϕ(zC,pTi)\hat{z}_{T_i} = g_\phi(z_C, p_{T_i}) 预测器需要知道它要预测哪个位置的表征,因此需要位置编码作为输入。

  4. 损失函数: 损失函数是所有目标块上预测表征与真实表征之间的平均L2距离(MSE)。 L(θ,ϕ)=1Ni=1Nz^TizTi22\mathcal{L}(\theta, \phi) = \frac{1}{N} \sum_{i=1}^{N} || \hat{z}_{T_i} - z_{T_i} ||_2^2 这个损失会同时更新编码器 fθf_\theta 和预测器 gϕg_\phi 的参数。

  5. 目标编码器更新: 在每次梯度更新后,目标编码器的权重 θˉ\bar{\theta} 通过 EMA 进行更新。 θˉmθˉ+(1m)θ\bar{\theta} \leftarrow m \cdot \bar{\theta} + (1-m) \cdot \theta 其中 mm 是动量系数,通常是一个接近 1 的值(如 0.996)。这使得目标编码器缓慢地、稳定地跟踪主编码器的变化,为预测任务提供了一致且高质量的目标。

算法流程与复杂度:

  • 算法流程:

    1. 从数据集中采样一个样本 xx
    2. 生成一个上下文掩码和多个目标掩码。I-JEPA 使用一种多块掩码策略,即目标块本身可能由多个小图像块组成,以鼓励模型学习更大范围的语义。
    3. 将上下文块输入编码器 fθf_\theta
    4. 将(未遮蔽的)目标块输入目标编码器 fθˉf_{\bar{\theta}}
    5. fθf_\theta 的输出和目标块位置送入预测器 gϕg_\phi
    6. 计算预测表征和目标表征之间的 MSE 损失。
    7. 执行反向传播,更新 θ\thetaϕ\phi
    8. 使用 EMA 更新 θˉ\bar{\theta}
  • 时间/空间复杂度:

    • 时间复杂度: 对于 ViT 实现,主要开销是 Transformer 块的自注意力计算,为 O(L2D)O(L^2 \cdot D),其中 LL 是序列长度, DD 是嵌入维度。由于编码器只处理 LCL_C 个上下文块(LC<LL_C < L),其计算量小于处理完整图像。总时间复杂度大致为 O(LC2Denc)+O(LT2Denc)+O(LC2Dpred)O(L_C^2 \cdot D_{\text{enc}}) + O(L_T^2 \cdot D_{\text{enc}}) + O(L_C^2 \cdot D_{\text{pred}}) 分别对应编码器、目标编码器和预测器的前向传播。
    • 空间复杂度: 主要由模型参数和激活值决定。由于目标编码器不需要存储梯度,相比于需要完整反向传播的架构,JEPA 在内存上更高效。它也无需像 MAE 那样有一个庞大的像素级解码器。

直观解释 (信息论角度):

JEPA 试图学习一个“世界模型”。它假设世界是可预测的,一个场景的局部信息(上下文)包含了推断其他部分(目标)的线索。但它不关心像素级的精确复现(高互信息,但可能语义价值低),而是关心概念级的抽象复现。通过在表征空间进行预测,模型被激励去丢弃不相关的、随机的、高频的细节(如草叶的精确纹理),而专注于捕捉可预测的、具有语义的结构(如“这里应该是一片草地”)。动量目标编码器提供了一个稳定的语义“词典”,防止模型在学习过程中“自说自话”导致坍塌。

V-JEPA 2 的扩展:

V-JEPA 2 将此范式应用于视频和音频。其核心思想不变,但预测任务变为时空预测:

  • 输入: 一段视频(或音频)的上下文片段。
  • 目标: 未来某个时间点的视频(或音频)片段。
  • 掩码: 掩码被应用在未来的目标片段上,使得预测任务更具挑战性(例如,根据视频前2秒,预测第4秒被遮挡区域的内容)。
  • 多模态: V-JEPA 2 可以实现跨模态预测,例如用音频上下文预测视频目标表征,或反之,从而学习到视听之间的关联。

代码实现

下面是一个简化的 I-JEPA 核心逻辑的 PyTorch 实现。为了教学目的,我们使用简单的随机掩码,并用 nn.TransformerEncoder 模拟 ViT 的核心部分。

python
1import torch
2import torch.nn as nn
3import copy
4
5class SimpleViT(nn.Module):
6 """一个简化的 Vision Transformer 骨干网络"""
7 def __init__(self, dim=256, depth=4, heads=8, patch_size=16, image_size=224):
8 super().__init__()
9 self.patch_size = patch_size
10 num_patches = (image_size // patch_size) ** 2
11
12 # 1. Patch 和位置嵌入
13 self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
14 self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
15
16 # 2. Transformer 编码器
17 encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim * 4, batch_first=True)
18 self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
19
20 def forward(self, x, patch_indices):
21 # 将图像转换为 patch 嵌入
22 patches = self.patch_embedding(x).flatten(2).transpose(1, 2) # (B, N, D)
23
24 # 根据索引选择需要的 patch
25 selected_patches = torch.gather(patches, 1, patch_indices.unsqueeze(-1).expand(-1, -1, patches.shape[-1]))
26 selected_pos = torch.gather(self.pos_embedding.expand(patches.shape[0], -1, -1), 1, patch_indices.unsqueeze(-1).expand(-1, -1, patches.shape[-1]))
27
28 # 添加位置编码并输入 Transformer
29 return self.transformer_encoder(selected_patches + selected_pos)
30
31class IJEPA(nn.Module):
32 def __init__(self, dim=256, pred_depth=2, enc_depth=4, heads=8, patch_size=16, image_size=224, momentum=0.996):
33 super().__init__()
34 self.momentum = momentum
35 self.patch_size = patch_size
36 self.image_size = image_size
37 self.num_patches = (image_size // patch_size) ** 2
38
39 # 1. 初始化编码器和预测器
40 # 为什么这样做: 这是我们主要训练的模型
41 self.encoder = SimpleViT(dim=dim, depth=enc_depth, heads=heads, patch_size=patch_size, image_size=image_size)
42 self.predictor = SimpleViT(dim=dim, depth=pred_depth, heads=heads, patch_size=patch_size, image_size=image_size)
43
44 # 2. 初始化目标编码器
45 # 为什么这样做: 目标编码器是编码器的副本,用于生成稳定的目标表征,不参与梯度计算
46 self.target_encoder = copy.deepcopy(self.encoder)
47 for param in self.target_encoder.parameters():
48 param.requires_grad = False
49
50 self.loss_fn = nn.MSELoss()
51
52 @torch.no_grad()
53 def _update_target_encoder(self):
54 """
55 使用指数移动平均(EMA)更新目标编码器
56 为什么这样做: 这使得目标网络缓慢地跟踪在线网络的权重,提供稳定的学习目标,防止模型坍塌。
57 """
58 for param_enc, param_target in zip(self.encoder.parameters(), self.target_encoder.parameters()):
59 param_target.data = param_target.data * self.momentum + param_enc.data * (1. - self.momentum)
60
61 def forward(self, x):
62 batch_size = x.shape[0]
63
64 # 3. 生成上下文和目标掩码
65 # 为什么这样做: 这是JEPA的核心,将输入划分为可见的上下文和需要预测的目标
66 # 为了简化,我们随机选择75%作为上下文,25%作为目标
67 all_indices = torch.randperm(self.num_patches, device=x.device).unsqueeze(0).expand(batch_size, -1)
68 context_size = int(0.75 * self.num_patches)
69 context_indices = all_indices[:, :context_size]
70 target_indices = all_indices[:, context_size:]
71
72 # 4. 计算上下文表征
73 # 为什么这样做: 编码器只看到部分图像(上下文),并为其生成表征
74 context_representation = self.encoder(x, context_indices)
75
76 # 5. 计算目标表征 (使用目标编码器)
77 # 为什么这样做: 目标编码器看到目标的原始、未遮蔽的图像块,并生成高质量的“真值”表征
78 with torch.no_grad():
79 self.target_encoder.eval()
80 target_representation = self.target_encoder(x, target_indices)
81
82 # 6. 预测目标表征
83 # 为什么这样做: 预测器基于上下文信息,尝试在目标位置上重建出目标表征
84 # 预测器的输入是上下文表征序列,加上目标块的位置编码
85 # 注意:一个简单的实现是让预测器 attend 到上下文表征上
86 # 这里我们简化,将目标位置编码作为 query,上下文表征作为 key/value
87 # 为了更简单,我们直接将上下文表征和目标位置编码一起送入预测器
88 pred_input_indices = torch.cat([context_indices, target_indices], dim=1)
89 # 模拟预测器只使用上下文信息来预测目标
90 # 实际实现中,预测器会使用特殊的 maskable attention
91 # 这里我们用一个简化逻辑:将上下文表征和目标位置编码送入预测器
92 # 提取目标位置的pos embedding
93 target_pos_emb = torch.gather(self.predictor.pos_embedding.expand(batch_size, -1, -1), 1, target_indices.unsqueeze(-1).expand(-1, -1, self.predictor.pos_embedding.shape[-1]))
94
95 # 简单地将上下文表征的平均值与目标位置编码相加作为预测器输入
96 # 这是一个非常粗糙的模拟,真实实现会更复杂
97 avg_context_rep = context_representation.mean(dim=1, keepdim=True)
98 predictor_input = avg_context_rep + target_pos_emb
99 predicted_representation = self.predictor.transformer_encoder(predictor_input)
100
101 # 7. 计算损失
102 loss = self.loss_fn(predicted_representation, target_representation)
103
104 return loss
105
106# --- 运行示例 ---
107if __name__ == '__main__':
108 model = IJEPA(
109 dim=256,
110 enc_depth=4,
111 pred_depth=2,
112 heads=8,
113 patch_size=16,
114 image_size=224
115 ).cuda()
116
117 # 创建一个虚拟输入图像
118 dummy_image = torch.randn(2, 3, 224, 224).cuda()
119
120 # 前向传播计算损失
121 loss = model(dummy_image)
122 print(f"Initial Loss: {loss.item()}")
123
124 # 模拟一次优化步骤
125 # optimizer.zero_grad()
126 loss.backward()
127 # optimizer.step()
128
129 # 更新目标编码器
130 model._update_target_encoder()
131 print("Model forward pass and target encoder update successful.")

工程实践

  • 使用场景:

    • I-JEPA: 作为强大的视觉预训练模型,其学到的表征在各种下游任务(如图像分类、目标检测、语义分割)上表现出色,通常只需要进行线性探测(Linear Probing)或少量微调。由于其学习的是语义表征,特别适合需要高层理解的任务。
    • V-JEPA 2: 在视频理解领域有巨大潜力,如动作识别、视频问答、视频内容生成等。其多模态能力也使其可用于视听联合任务,如根据声音定位视频中的事件。
  • 超参数选择的经验法则:

    • 掩码策略: 这是效果的关键。I-JEPA 论文中提出了一种多块(multi-block)掩码策略,即生成多个不同尺度和长宽比的目标块。这比简单的随机掩码更有效,因为它创造了从局部纹理到全局结构的多种预测任务。通常会有一个较大的上下文块(保留约75%的图像)和4个尺度不一的目标块(覆盖约15%的图像)。
    • 动量 m: 值非常接近1,如0.996到0.999。在训练初期可以使用较低的值,然后随训练进程逐渐退火到接近1。
    • 预测器深度/宽度: 预测器应显著小于编码器,例如,编码器12层,预测器可以是2-4层。这构成了一个信息瓶颈,迫使编码器学习更高效、更抽象的表征。
    • 学习率与优化器: 通常使用 AdamW 优化器,配合 cosine learning rate schedule 和 warmup。
  • 性能 / 显存 / 吞吐 的权衡:

    • 显存: JEPA 比 MAE 等像素重建方法更节省显存,因为它没有庞大的像素解码器,且目标表征维度远低于像素数量。目标编码器不存梯度也节省了显存。
    • 吞吐: 编码器只处理部分 patch,提升了速度。但额外的预测器和目标编码器前向传播会增加计算开销。总体而言,其训练效率与 MAE 相当或更高。
    • 权衡: 增加目标块的数量和大小会增加计算量(主要是目标编码器和损失计算),但可能提升模型性能。需要根据硬件资源进行调整。
  • 常见坑和调试技巧:

    • 训练不稳定/损失爆炸: 检查梯度裁剪(gradient clipping)是否开启。确保 EMA 更新在优化器步骤之后执行。检查学习率是否过高。
    • 模型坍塌 (Loss->0): 尽管动量编码器能有效防止,但如果动量 m 设置不当(如太低),或学习率过大,仍可能发生。确保 m 足够高。
    • 性能不佳: 检查掩码策略。过于简单的掩码(如随机单 patch)可能导致任务太简单,模型学不到有用的东西。确保掩码策略能产生有挑战性的预测任务。

常见误区与边界情况

  • 初学者容易搞错的点:

    • 误区1: JEPA 是在预测被遮住的 patch 的特征纠正: 不完全是。预测器确实在预测目标位置的特征,但这个“真值”特征是由目标编码器处理未遮蔽的原始图像块生成的。它不是编码器自己对遮蔽块的“想象”,而是一个更稳定、高质量的外部目标。
    • 误区2: 目标编码器和编码器是一起训练的纠正: 目标编码器不通过反向传播训练。它的权重是主编码器权重的平滑平均版本。这是一个核心机制,用于防止模型坍塌。
    • 误区3: 预测器在下游任务中也有用纠正: 预训练结束后,预测器和目标编码器都会被丢弃。只有训练好的主编码器 fθf_\theta 被用作下游任务的骨干网络。
  • 数值稳定性、边界条件、失败模式:

    • 数值稳定性: 在混合精度(如 amp)训练下,EMA 更新需要特别注意类型转换,以防精度损失。
    • 边界条件: 如果上下文块和目标块有重叠,预测任务会变得过于简单。设计掩码策略时应确保它们在空间上是分离的。
    • 失败模式: 如果预测器过于强大(与编码器同等规模),它可能学会一种“捷径”,而不是迫使编码器学习语义。这会削弱预训练的效果。因此,预测器的“不对称性”(更小、更浅)是设计上的一个要点。
  • 常见面试追问以及回答要点:

    • 问: JEPA 和 MAE (Masked Autoencoders) 的根本区别是什么? : 根本区别在于预测目标。MAE 的目标是像素值,它在像素空间进行重建。这使得 MAE 必须关注高频细节和纹理,可能导致模型“浪费”容量在低级信号上。而 JEPA 的目标是表征,它在抽象的语义空间进行预测。这引导模型忽略像素级的噪声和冗余,专注于学习数据中可预测的、更高级的语义结构。
    • 问: 为什么需要一个动量更新的目标编码器,而不是直接用主编码器自己来生成目标? : 这是为了防止模型坍塌(collapsing)。如果让主编码器自己预测自己的输出(即 z^=fθ(xC)\hat{z} = f_\theta(x_C), z=fθ(xT)z = f_\theta(x_T)),模型会找到一个平凡解:为所有输入输出一个常数表征,这样预测误差为零,但模型什么也没学到。引入一个缓慢更新的、独立的动量编码器,打破了这种对称性。目标编码器提供了一个稳定的、非坍塌的回归目标,迫使主编码器去学习有意义的映射。
    • 问: JEPA 与对比学习方法(如 SimCLR, MoCo)有何不同? : 对比学习通过“拉近正样本,推远负样本”来学习表征。它依赖于构造正负样本对,并且对负样本的数量和质量很敏感。JEPA 是一种预测式/非对比式方法,它不依赖负样本。它的学习信号来自于模型对自身内部结构(一个部分预测另一个部分)的理解能力。这使得 JEPA 在概念上更接近于一个“世界模型”,它通过预测来验证自己对数据生成过程的理解,而不是通过区分不同实例。
相关题目