I-JEPA / V-JEPA 2 的预测式自监督?
好的,我们来深入解析 I-JEPA 与 V-JEPA 2 的预测式自监督学习范式。
核心概念
I-JEPA(Image Joint-Embedding Predictive Architecture)是一种非生成式的自监督学习方法。其核心思想是,在一个高维的表征空间中,根据一个上下文(Context)块的表征,来预测同一个样本中其他目标(Target)块的表征。它不预测像素本身,而是预测由一个动量更新的目标编码器(Target Encoder)生成的抽象特征。这种“在表征空间中填空”的方式,迫使模型学习数据中更具语义的、可预测的底层规律,而不是去拟合高频、低层次的像素细节。V-JEPA 2 将此思想从图像域扩展到了视频和音频的多模态领域。
原理与推导
JEPA 的架构主要由三个部分组成:编码器(Encoder)、目标编码器(Target Encoder)和预测器(Predictor)。
- 编码器 (Encoder) : 这是我们主要训练的模型,通常是一个 Vision
Transformer(ViT)。它接收一部分可见的输入块(上下文块),并输出它们的表征。 - 目标编码器 (Target Encoder) : 它的网络结构与编码器完全相同,但不通过反向传播进行训练。其权重 是编码器权重 的指数移动平均(Exponential Moving Average, EMA)。它接收目标块(未被遮蔽的原始图像块)作为输入,生成稳定、高质量的目标表征。
- 预测器 (Predictor) : 这是一个相对较小的模型(如一个较浅的
ViT),它接收编码器输出的上下文表征,并结合目标块的位置信息,来预测目标块的表征。
数学推导:
假设输入图像为 ,我们将其分割成一系列不重叠的图像块(patches)。我们从中采样一个上下文块集合 和多个目标块集合 。
-
上下文表征: 编码器 处理上下文块 ,输出上下文表征 。在
ViT中,这对应于上下文块的输出 token 序列。 -
目标表征: 目标编码器 处理未被遮蔽的原始目标块 ,输出作为“真值”的目标表征 。这一步不计算梯度。
-
预测: 预测器 接收上下文表征 和目标块 对应的位置编码 ,输出预测的目标表征 。 预测器需要知道它要预测哪个位置的表征,因此需要位置编码作为输入。
-
损失函数: 损失函数是所有目标块上预测表征与真实表征之间的平均L2距离(MSE)。 这个损失会同时更新编码器 和预测器 的参数。
-
目标编码器更新: 在每次梯度更新后,目标编码器的权重 通过 EMA 进行更新。 其中 是动量系数,通常是一个接近 1 的值(如 0.996)。这使得目标编码器缓慢地、稳定地跟踪主编码器的变化,为预测任务提供了一致且高质量的目标。
算法流程与复杂度:
-
算法流程:
- 从数据集中采样一个样本 。
- 生成一个上下文掩码和多个目标掩码。I-JEPA 使用一种多块掩码策略,即目标块本身可能由多个小图像块组成,以鼓励模型学习更大范围的语义。
- 将上下文块输入编码器 。
- 将(未遮蔽的)目标块输入目标编码器 。
- 将 的输出和目标块位置送入预测器 。
- 计算预测表征和目标表征之间的 MSE 损失。
- 执行反向传播,更新 和 。
- 使用 EMA 更新 。
-
时间/空间复杂度:
- 时间复杂度: 对于
ViT实现,主要开销是Transformer块的自注意力计算,为 ,其中 是序列长度, 是嵌入维度。由于编码器只处理 个上下文块(),其计算量小于处理完整图像。总时间复杂度大致为 分别对应编码器、目标编码器和预测器的前向传播。 - 空间复杂度: 主要由模型参数和激活值决定。由于目标编码器不需要存储梯度,相比于需要完整反向传播的架构,JEPA 在内存上更高效。它也无需像 MAE 那样有一个庞大的像素级解码器。
- 时间复杂度: 对于
直观解释 (信息论角度):
JEPA 试图学习一个“世界模型”。它假设世界是可预测的,一个场景的局部信息(上下文)包含了推断其他部分(目标)的线索。但它不关心像素级的精确复现(高互信息,但可能语义价值低),而是关心概念级的抽象复现。通过在表征空间进行预测,模型被激励去丢弃不相关的、随机的、高频的细节(如草叶的精确纹理),而专注于捕捉可预测的、具有语义的结构(如“这里应该是一片草地”)。动量目标编码器提供了一个稳定的语义“词典”,防止模型在学习过程中“自说自话”导致坍塌。
V-JEPA 2 的扩展:
V-JEPA 2 将此范式应用于视频和音频。其核心思想不变,但预测任务变为时空预测:
- 输入: 一段视频(或音频)的上下文片段。
- 目标: 未来某个时间点的视频(或音频)片段。
- 掩码: 掩码被应用在未来的目标片段上,使得预测任务更具挑战性(例如,根据视频前2秒,预测第4秒被遮挡区域的内容)。
- 多模态: V-JEPA 2 可以实现跨模态预测,例如用音频上下文预测视频目标表征,或反之,从而学习到视听之间的关联。
代码实现
下面是一个简化的 I-JEPA 核心逻辑的 PyTorch 实现。为了教学目的,我们使用简单的随机掩码,并用 nn.TransformerEncoder 模拟 ViT 的核心部分。
1import torch2import torch.nn as nn3import copy45class SimpleViT(nn.Module):6 """一个简化的 Vision Transformer 骨干网络"""7 def __init__(self, dim=256, depth=4, heads=8, patch_size=16, image_size=224):8 super().__init__()9 self.patch_size = patch_size10 num_patches = (image_size // patch_size) ** 21112 # 1. Patch 和位置嵌入13 self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)14 self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))1516 # 2. Transformer 编码器17 encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim * 4, batch_first=True)18 self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)1920 def forward(self, x, patch_indices):21 # 将图像转换为 patch 嵌入22 patches = self.patch_embedding(x).flatten(2).transpose(1, 2) # (B, N, D)2324 # 根据索引选择需要的 patch25 selected_patches = torch.gather(patches, 1, patch_indices.unsqueeze(-1).expand(-1, -1, patches.shape[-1]))26 selected_pos = torch.gather(self.pos_embedding.expand(patches.shape[0], -1, -1), 1, patch_indices.unsqueeze(-1).expand(-1, -1, patches.shape[-1]))2728 # 添加位置编码并输入 Transformer29 return self.transformer_encoder(selected_patches + selected_pos)3031class IJEPA(nn.Module):32 def __init__(self, dim=256, pred_depth=2, enc_depth=4, heads=8, patch_size=16, image_size=224, momentum=0.996):33 super().__init__()34 self.momentum = momentum35 self.patch_size = patch_size36 self.image_size = image_size37 self.num_patches = (image_size // patch_size) ** 23839 # 1. 初始化编码器和预测器40 # 为什么这样做: 这是我们主要训练的模型41 self.encoder = SimpleViT(dim=dim, depth=enc_depth, heads=heads, patch_size=patch_size, image_size=image_size)42 self.predictor = SimpleViT(dim=dim, depth=pred_depth, heads=heads, patch_size=patch_size, image_size=image_size)4344 # 2. 初始化目标编码器45 # 为什么这样做: 目标编码器是编码器的副本,用于生成稳定的目标表征,不参与梯度计算46 self.target_encoder = copy.deepcopy(self.encoder)47 for param in self.target_encoder.parameters():48 param.requires_grad = False4950 self.loss_fn = nn.MSELoss()5152 @torch.no_grad()53 def _update_target_encoder(self):54 """55 使用指数移动平均(EMA)更新目标编码器56 为什么这样做: 这使得目标网络缓慢地跟踪在线网络的权重,提供稳定的学习目标,防止模型坍塌。57 """58 for param_enc, param_target in zip(self.encoder.parameters(), self.target_encoder.parameters()):59 param_target.data = param_target.data * self.momentum + param_enc.data * (1. - self.momentum)6061 def forward(self, x):62 batch_size = x.shape[0]6364 # 3. 生成上下文和目标掩码65 # 为什么这样做: 这是JEPA的核心,将输入划分为可见的上下文和需要预测的目标66 # 为了简化,我们随机选择75%作为上下文,25%作为目标67 all_indices = torch.randperm(self.num_patches, device=x.device).unsqueeze(0).expand(batch_size, -1)68 context_size = int(0.75 * self.num_patches)69 context_indices = all_indices[:, :context_size]70 target_indices = all_indices[:, context_size:]7172 # 4. 计算上下文表征73 # 为什么这样做: 编码器只看到部分图像(上下文),并为其生成表征74 context_representation = self.encoder(x, context_indices)7576 # 5. 计算目标表征 (使用目标编码器)77 # 为什么这样做: 目标编码器看到目标的原始、未遮蔽的图像块,并生成高质量的“真值”表征78 with torch.no_grad():79 self.target_encoder.eval()80 target_representation = self.target_encoder(x, target_indices)8182 # 6. 预测目标表征83 # 为什么这样做: 预测器基于上下文信息,尝试在目标位置上重建出目标表征84 # 预测器的输入是上下文表征序列,加上目标块的位置编码85 # 注意:一个简单的实现是让预测器 attend 到上下文表征上86 # 这里我们简化,将目标位置编码作为 query,上下文表征作为 key/value87 # 为了更简单,我们直接将上下文表征和目标位置编码一起送入预测器88 pred_input_indices = torch.cat([context_indices, target_indices], dim=1)89 # 模拟预测器只使用上下文信息来预测目标90 # 实际实现中,预测器会使用特殊的 maskable attention91 # 这里我们用一个简化逻辑:将上下文表征和目标位置编码送入预测器92 # 提取目标位置的pos embedding93 target_pos_emb = torch.gather(self.predictor.pos_embedding.expand(batch_size, -1, -1), 1, target_indices.unsqueeze(-1).expand(-1, -1, self.predictor.pos_embedding.shape[-1]))9495 # 简单地将上下文表征的平均值与目标位置编码相加作为预测器输入96 # 这是一个非常粗糙的模拟,真实实现会更复杂97 avg_context_rep = context_representation.mean(dim=1, keepdim=True)98 predictor_input = avg_context_rep + target_pos_emb99 predicted_representation = self.predictor.transformer_encoder(predictor_input)100101 # 7. 计算损失102 loss = self.loss_fn(predicted_representation, target_representation)103104 return loss105106# --- 运行示例 ---107if __name__ == '__main__':108 model = IJEPA(109 dim=256,110 enc_depth=4,111 pred_depth=2,112 heads=8,113 patch_size=16,114 image_size=224115 ).cuda()116117 # 创建一个虚拟输入图像118 dummy_image = torch.randn(2, 3, 224, 224).cuda()119120 # 前向传播计算损失121 loss = model(dummy_image)122 print(f"Initial Loss: {loss.item()}")123124 # 模拟一次优化步骤125 # optimizer.zero_grad()126 loss.backward()127 # optimizer.step()128129 # 更新目标编码器130 model._update_target_encoder()131 print("Model forward pass and target encoder update successful.")
工程实践
-
使用场景:
- I-JEPA: 作为强大的视觉预训练模型,其学到的表征在各种下游任务(如图像分类、目标检测、语义分割)上表现出色,通常只需要进行线性探测(Linear Probing)或少量微调。由于其学习的是语义表征,特别适合需要高层理解的任务。
- V-JEPA 2: 在视频理解领域有巨大潜力,如动作识别、视频问答、视频内容生成等。其多模态能力也使其可用于视听联合任务,如根据声音定位视频中的事件。
-
超参数选择的经验法则:
- 掩码策略: 这是效果的关键。I-JEPA 论文中提出了一种多块(multi-block)掩码策略,即生成多个不同尺度和长宽比的目标块。这比简单的随机掩码更有效,因为它创造了从局部纹理到全局结构的多种预测任务。通常会有一个较大的上下文块(保留约75%的图像)和4个尺度不一的目标块(覆盖约15%的图像)。
- 动量
m: 值非常接近1,如0.996到0.999。在训练初期可以使用较低的值,然后随训练进程逐渐退火到接近1。 - 预测器深度/宽度: 预测器应显著小于编码器,例如,编码器12层,预测器可以是2-4层。这构成了一个信息瓶颈,迫使编码器学习更高效、更抽象的表征。
- 学习率与优化器: 通常使用 AdamW 优化器,配合 cosine learning rate schedule 和 warmup。
-
性能 / 显存 / 吞吐 的权衡:
- 显存: JEPA 比 MAE 等像素重建方法更节省显存,因为它没有庞大的像素解码器,且目标表征维度远低于像素数量。目标编码器不存梯度也节省了显存。
- 吞吐: 编码器只处理部分 patch,提升了速度。但额外的预测器和目标编码器前向传播会增加计算开销。总体而言,其训练效率与 MAE 相当或更高。
- 权衡: 增加目标块的数量和大小会增加计算量(主要是目标编码器和损失计算),但可能提升模型性能。需要根据硬件资源进行调整。
-
常见坑和调试技巧:
- 训练不稳定/损失爆炸: 检查梯度裁剪(gradient clipping)是否开启。确保 EMA 更新在优化器步骤之后执行。检查学习率是否过高。
- 模型坍塌 (Loss->0): 尽管动量编码器能有效防止,但如果动量
m设置不当(如太低),或学习率过大,仍可能发生。确保m足够高。 - 性能不佳: 检查掩码策略。过于简单的掩码(如随机单 patch)可能导致任务太简单,模型学不到有用的东西。确保掩码策略能产生有挑战性的预测任务。
常见误区与边界情况
-
初学者容易搞错的点:
- 误区1: JEPA 是在预测被遮住的 patch 的特征。 纠正: 不完全是。预测器确实在预测目标位置的特征,但这个“真值”特征是由目标编码器处理未遮蔽的原始图像块生成的。它不是编码器自己对遮蔽块的“想象”,而是一个更稳定、高质量的外部目标。
- 误区2: 目标编码器和编码器是一起训练的。 纠正: 目标编码器不通过反向传播训练。它的权重是主编码器权重的平滑平均版本。这是一个核心机制,用于防止模型坍塌。
- 误区3: 预测器在下游任务中也有用。 纠正: 预训练结束后,预测器和目标编码器都会被丢弃。只有训练好的主编码器 被用作下游任务的骨干网络。
-
数值稳定性、边界条件、失败模式:
- 数值稳定性: 在混合精度(如
amp)训练下,EMA 更新需要特别注意类型转换,以防精度损失。 - 边界条件: 如果上下文块和目标块有重叠,预测任务会变得过于简单。设计掩码策略时应确保它们在空间上是分离的。
- 失败模式: 如果预测器过于强大(与编码器同等规模),它可能学会一种“捷径”,而不是迫使编码器学习语义。这会削弱预训练的效果。因此,预测器的“不对称性”(更小、更浅)是设计上的一个要点。
- 数值稳定性: 在混合精度(如
-
常见面试追问以及回答要点:
- 问: JEPA 和 MAE (Masked Autoencoders) 的根本区别是什么? 答: 根本区别在于预测目标。MAE 的目标是像素值,它在像素空间进行重建。这使得 MAE 必须关注高频细节和纹理,可能导致模型“浪费”容量在低级信号上。而 JEPA 的目标是表征,它在抽象的语义空间进行预测。这引导模型忽略像素级的噪声和冗余,专注于学习数据中可预测的、更高级的语义结构。
- 问: 为什么需要一个动量更新的目标编码器,而不是直接用主编码器自己来生成目标? 答: 这是为了防止模型坍塌(collapsing)。如果让主编码器自己预测自己的输出(即 , ),模型会找到一个平凡解:为所有输入输出一个常数表征,这样预测误差为零,但模型什么也没学到。引入一个缓慢更新的、独立的动量编码器,打破了这种对称性。目标编码器提供了一个稳定的、非坍塌的回归目标,迫使主编码器去学习有意义的映射。
- 问: JEPA 与对比学习方法(如 SimCLR, MoCo)有何不同? 答: 对比学习通过“拉近正样本,推远负样本”来学习表征。它依赖于构造正负样本对,并且对负样本的数量和质量很敏感。JEPA 是一种预测式/非对比式方法,它不依赖负样本。它的学习信号来自于模型对自身内部结构(一个部分预测另一个部分)的理解能力。这使得 JEPA 在概念上更接近于一个“世界模型”,它通过预测来验证自己对数据生成过程的理解,而不是通过区分不同实例。