ViT 的 patch embedding + [CLS] + 位置编码流程?
- —手写 PatchEmbedding(Conv2d kernel=patch_size,stride=patch_size)
核心概念
Vision Transformer (ViT) 的输入处理流程是将一张 2D 图像转换为一个 1D 的向量序列(sequence of tokens),以便送入标准的 Transformer Encoder。这个过程包含三个关键步骤:
- Patch
Embedding:将图像分割成固定大小的、不重叠的图块(patches),然后将每个图块线性投影(linearly project)成一个向量。这等效于将图像看作一个“句子”,每个图块就是一个“单词”。 - [CLS] Token:在序列化后的图块向量(patch tokens)前,拼接一个可学习的特殊向量,即
[CLS](Classification)token。在经过TransformerEncoder 后,这个[CLS]token 对应的最终输出向量将被用作整个图像的聚合表示,用于下游的分类任务。 - Position
Embedding:由于Transformer的自注意力机制本身不包含位置信息(即置换不变性),必须显式地为每个输入向量(包括[CLS]token 和所有 patch tokens)添加一个可学习的位置编码(position embedding),以告知模型每个图块的原始空间位置。
原理与推导
假设我们有一个输入图像 ,其中 是高度, 是宽度, 是通道数。目标是将其转换为一个向量序列 ,其中 是图块数量, 是 Transformer 的隐藏维度。
1. Patch Embedding
-
概念拆解:
- Reshaping:将图像 分割成 个大小为 的图块(patches)。。
- Flattening:将每个图块 展平为一个向量,其维度为 。
- Linear Projection:使用一个可学习的权重矩阵 ,将每个展平的图块向量线性投影到维度 。
-
数学公式: 对于第 个图块 ,其嵌入向量 的计算如下: 其中 是展平后的向量。所有图块经过投影后,我们得到一个序列 。
-
等效的卷积实现: 上述“分割-展平-投影”操作在数学上等价于一个 2D 卷积。这是一个至关重要的工程洞察,因为它能利用高度优化的
Conv2dGPU/TPU kernel。- 卷积核大小 (kernel_size):
- 步长 (stride):
- 输入通道 (in_channels):
- 输出通道 (out_channels):
一个卷积核在输入张量上以步长 滑动,每次恰好覆盖一个 的图块。卷积操作本身就完成了对图块内所有像素值(跨所有通道 )的加权求和,这与线性投影完全等价。
- 输入张量 (PyTorch 格式):
- 经过
Conv2d(in_channels=C, out_channels=D, kernel_size=P, stride=P)后,输出张量形状为 。 - 最后,将 的后两个维度展平并调换维度,即可得到目标序列形状 ,其中 。
-
复杂度:Patch
Embedding层的计算复杂度主要由卷积操作决定,约为 。由于 ,复杂度也可近似为 。这在计算上是非常高效的。
2. [CLS] Token
- 动机:
TransformerEncoder 对输入序列中的所有 token 进行信息交互和融合。最终,我们需要一个单一的向量来代表整个图像。直接对所有 patch tokens 的输出做平均池化(Global Average Pooling)是一种方法,但ViT采用了BERT中的[CLS]token 策略。 - 原理:在 patch embedding 序列的最前面,插入一个可学习的向量 。这个向量的初始值是随机的,但会随着模型的训练而更新。
- 数学表示: 拼接后,序列长度从 变为 。。
3. Position Embedding
-
动机:自注意力机制是置换不变的(permutation-invariant)。如果打乱 patch tokens 的顺序,输出结果仅仅是对应位置被打乱而已,模型无法感知图块之间的相对空间关系。
-
原理:为序列中的每一个 token(包括
[CLS]token)添加一个独一无二的、可学习的位置编码向量。 -
数学表示: 定义一个可学习的位置编码矩阵 。 的第 行是第 个位置的编码向量。 这里的加法是逐元素相加。 就是最终送入
TransformerEncoder 的输入序列。 -
几何解释:可以想象,Patch
Embedding将每个图块映射到 维空间中的一个点。[CLS]token 是另一个点。PositionEmbedding则为每个点提供一个独特的“偏移量”,将它们移动到能够反映其原始位置的新位置。例如,图像左上角的图块和右上角的图块,它们的 patch embedding 可能很相似(如果内容都是天空),但加上了不同的 position embedding 后,它们在输入空间中的位置就被区分开来。
代码实现
以下代码完整实现了 ViT 的输入处理流程,包括使用 Conv2d 手写 PatchEmbedding 模块。
1import torch2import torch.nn as nn34class PatchEmbedding(nn.Module):5 """6 将图像分割成图块并进行线性投影的核心模块。7 使用 Conv2d 实现,效率高。8 """9 def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):10 super().__init__()11 self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size12 self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size13 self.grid_size = (self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1])14 self.num_patches = self.grid_size[0] * self.grid_size[1]1516 # 关键步骤:使用 Conv2d 实现 Patch Embedding17 # 为什么这样做:这是一种高效的实现方式。卷积核在图像上以 patch_size 为步长滑动,18 # 每次操作的感受野恰好是一个 patch。卷积的输出通道数 embed_dim 即为投影后的向量维度。19 # 输入: [B, C, H, W]20 # 输出: [B, D, H/P, W/P]21 self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)2223 def forward(self, x):24 # x: [B, C, H, W]25 B, C, H, W = x.shape26 assert H == self.img_size[0] and W == self.img_size[1], \27 f"输入图像尺寸 ({H}*{W}) 与模型预设尺寸 ({self.img_size[0]}*{self.img_size[1]}) 不匹配."2829 # 1. 线性投影30 x = self.proj(x) # [B, embed_dim, grid_size[0], grid_size[1]]3132 # 2. 展平并调整维度顺序33 # 为什么这样做:Transformer Encoder 需要的输入格式是 [B, N, D],即 (批次, 序列长度, 维度)。34 # Conv2d 输出是 [B, D, H/P, W/P],需要将 H/P 和 W/P 两个维度合并为序列长度 N,35 # 并将维度 D 移到最后。36 # flatten(2) 将索引为2之后的所有维度展平: [B, D, N]37 # transpose(1, 2) 交换维度1和2: [B, N, D]38 x = x.flatten(2).transpose(1, 2)39 return x4041# --- 模拟完整流程 ---4243# 1. 定义超参数44batch_size = 445img_size = 22446patch_size = 1647in_chans = 348embed_dim = 768 # ViT-Base 的隐藏维度4950# 2. 创建一个假的输入图像张量51dummy_images = torch.randn(batch_size, in_chans, img_size, img_size)52print(f"输入图像形状: {dummy_images.shape}\n")5354# --- 步骤一: Patch Embedding ---55print("--- 步骤一: Patch Embedding ---")56patch_embed_layer = PatchEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)57patch_embeddings = patch_embed_layer(dummy_images)58num_patches = patch_embed_layer.num_patches59print(f"Patch 数量 (N): {num_patches}")60print(f"经过 Patch Embedding 后的形状: {patch_embeddings.shape} (格式: [B, N, D])\n")6162# --- 步骤二: 添加 [CLS] Token ---63print("--- 步骤二: 添加 [CLS] Token ---")64# 为什么这样做:创建一个可学习的参数作为 [CLS] token。它的形状是 [1, 1, D],65# 以便可以利用广播机制轻松地与批次中的每个样本拼接。66cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))67print(f"CLS Token 形状: {cls_token.shape}")6869# 为什么这样做:使用 expand 将 CLS token 复制 B 次,以匹配批次大小。70# 然后在序列维度(dim=1)上与 patch_embeddings 拼接。71cls_tokens = cls_token.expand(batch_size, -1, -1)72tokens_with_cls = torch.cat((cls_tokens, patch_embeddings), dim=1)73print(f"拼接 CLS Token 后的形状: {tokens_with_cls.shape} (格式: [B, N+1, D])\n")7475# --- 步骤三: 添加 Position Embedding ---76print("--- 步骤三: 添加 Position Embedding ---")77# 为什么这样做:创建一个可学习的位置编码矩阵。其序列长度为 N+1,以覆盖 CLS token 和所有 patch tokens。78# 形状为 [1, N+1, D],同样是为了利用广播机制与整个批次相加。79pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))80print(f"位置编码形状: {pos_embed.shape}")8182# 为什么这样做:将位置编码直接加到 token 序列上。PyTorch 的广播机制会自动处理83# `pos_embed` 的第一个维度(从1扩展到B),实现对批次中每个样本应用相同的位置编码。84final_embeddings = tokens_with_cls + pos_embed85print(f"添加位置编码后的最终输入形状: {final_embeddings.shape} (格式: [B, N+1, D])")
工程实践
- 使用场景: 这个输入流程是所有基于
ViT的模型的标准起点,包括图像分类(ViT)、目标检测(DETR, ViT-FRCNN)、语义分割(SETR)等。 - 超参数选择:
patch_size: 这是最关键的超参数之一。常见的有 14, 16, 32。- 小 patch_size (e.g., 14, 16): 产生更长的序列
N,计算量(尤其是自注意力的二次方复杂度)和显存消耗更大。但能捕捉更精细的图像特征,通常性能更好。 - 大 patch_size (e.g., 32): 序列更短,模型更轻、更快。但在小物体或细粒度任务上可能表现不佳。
- 小 patch_size (e.g., 14, 16): 产生更长的序列
embed_dim:Transformer的工作维度。ViT-Base 为 768,ViT-Large 为 1024,ViT-Huge 为 1280。更大的embed_dim意味着更强的模型容量,但也会增加计算和显存。
- 性能/显存权衡:
ViT的主要计算瓶颈在于自注意力,其复杂度为 。由于 ,序列长度 对图像尺寸 是二次方关系,对patch_size是平方反比关系。- 在实际项目中,如果需要处理高分辨率图像,直接降低
patch_size会导致序列过长,显存爆炸。常见的策略是:- 保持
patch_size不变,将高分辨率图像降采样到模型预设的img_size(e.g., 224x224, 384x384)。 - 采用 Swin
Transformer等模型,它们使用分层结构和窗口注意力来处理长序列,降低计算复杂度。
- 保持
- 常见坑和调试技巧:
- 图像尺寸不匹配:
ViT的PatchEmbedding通常要求输入图像尺寸能被patch_size整除。在部署或推理时,必须确保输入图像被正确地预处理(裁剪或缩放)到模型训练时使用的尺寸。 - 位置编码插值: 当推理时使用的图像分辨率与训练时不同时,patch 的数量 会改变,导致预训练好的
pos_embed尺寸不匹配。常见的解决办法是对pos_embed进行2D双线性插值,以适应新的序列长度。 - 维度混淆:
[B, N, D]vs[B, D, N]是初学者常见的错误来源。务必时刻清楚当前张量的维度含义,尤其是在transpose,permute,flatten操作之后。
- 图像尺寸不匹配:
常见误区与边界情况
-
误区1: "
ViT完全没有使用卷积"- 辨析: 这是不准确的。
ViT的 PatchEmbedding步骤最有效的实现方式就是卷积。此外,很多ViT的变体(如 HybridViT)会先用一个小的 CNN(如 ResNet 的前几个 stage)来提取低级特征图,再对这个特征图进行 patch 和 embedding。这证明了卷积在视觉任务中的归纳偏置(inductive bias)仍然很有价值。
- 辨析: 这是不准确的。
-
误区2: "[CLS] token 是唯一的全局表示方法"
- 辨析: 不是。另一种常见的方法是,不使用
[CLS]token,在TransformerEncoder 的输出端对所有 patch tokens 的向量表示进行全局平均池化(Global Average Pooling),然后将得到的池化向量用于分类。实验表明两种方法性能相似,但使用[CLS]token 是原始ViT论文的做法,并已成为一种事实标准。
- 辨析: 不是。另一种常见的方法是,不使用
-
误区3: "位置编码必须是可学习的"
- 辨析: 不一定。原始
Transformer论文使用的是固定的正弦/余弦位置编码。ViT论文实验发现,可学习的1D位置编码、2D位置编码和固定的正弦位置编码性能差异不大。但可学习的编码实现简单,且能让模型自行学习最适合的位置表示,因此在ViT及其变体中被广泛采用。
- 辨析: 不一定。原始
-
边界情况:图像尺寸不能被 patch_size 整除
- 标准
nn.Conv2d在步长大于1时,如果输入尺寸不能被整除,会自动丢弃右侧和/或底部的部分像素。在ViT的上下文中,这意味着图像的边缘信息会丢失。 - 标准做法: 在数据预处理阶段,就将所有图像强制缩放和/或填充到一个固定的、可以被
patch_size整除的尺寸(如 224x224)。这保证了所有输入都有相同的序列长度 ,简化了模型设计和批处理。
- 标准
-
常见面试追问:
- 问: "为什么不直接用
torch.nn.Unfold+torch.nn.Linear来实现 PatchEmbedding?" - 答:
Unfold+Linear在功能上是等价的,它显式地提取图块然后进行线性变换。但Conv2d是一个单一的、高度优化的底层操作,在现代深度学习框架和硬件(GPU/TPU)上通常比Unfold+Linear的组合拳执行得更快。因此,Conv2d是更优雅且性能更优的工程选择。 - 问: "这个
[CLS]token 的初始值是什么?它如何学习?" - 答: 它是一个
nn.Parameter,通常用零或小的随机值初始化。它和其他模型参数(如卷积核权重、注意力矩阵)一样,通过反向传播和梯度下降进行端到端的学习。在训练过程中,模型会学会将[CLS]token 的最终输出调整为对整个图像分类任务最有用的表示。
- 问: "为什么不直接用