§1.2.6

ViT 的 patch embedding + [CLS] + 位置编码流程?

手写练习
  • 手写 PatchEmbedding(Conv2d kernel=patch_size,stride=patch_size)

核心概念

Vision Transformer (ViT) 的输入处理流程是将一张 2D 图像转换为一个 1D 的向量序列(sequence of tokens),以便送入标准的 Transformer Encoder。这个过程包含三个关键步骤:

  1. Patch Embedding:将图像分割成固定大小的、不重叠的图块(patches),然后将每个图块线性投影(linearly project)成一个向量。这等效于将图像看作一个“句子”,每个图块就是一个“单词”。
  2. [CLS] Token:在序列化后的图块向量(patch tokens)前,拼接一个可学习的特殊向量,即 [CLS](Classification)token。在经过 Transformer Encoder 后,这个 [CLS] token 对应的最终输出向量将被用作整个图像的聚合表示,用于下游的分类任务。
  3. Position Embedding:由于 Transformer 的自注意力机制本身不包含位置信息(即置换不变性),必须显式地为每个输入向量(包括 [CLS] token 和所有 patch tokens)添加一个可学习的位置编码(position embedding),以告知模型每个图块的原始空间位置。

原理与推导

假设我们有一个输入图像 xRH×W×Cx \in \mathbb{R}^{H \times W \times C},其中 HH 是高度,WW 是宽度,CC 是通道数。目标是将其转换为一个向量序列 z0R(N+1)×Dz_0 \in \mathbb{R}^{(N+1) \times D},其中 NN 是图块数量,DDTransformer 的隐藏维度。

1. Patch Embedding

  • 概念拆解

    1. Reshaping:将图像 xx 分割成 NN 个大小为 P×PP \times P 的图块(patches)。N=HP×WPN = \frac{H}{P} \times \frac{W}{P}
    2. Flattening:将每个图块 xpRP×P×Cx_p \in \mathbb{R}^{P \times P \times C} 展平为一个向量,其维度为 P2CP^2C
    3. Linear Projection:使用一个可学习的权重矩阵 ER(P2C)×DE \in \mathbb{R}^{(P^2C) \times D},将每个展平的图块向量线性投影到维度 DD
  • 数学公式: 对于第 ii 个图块 xpix_p^i,其嵌入向量 eie_i 的计算如下: ei=xpiEe_i = x_p^i \cdot E 其中 xpix_p^i 是展平后的向量。所有图块经过投影后,我们得到一个序列 [e1,e2,,eN][e_1, e_2, \dots, e_N]

  • 等效的卷积实现: 上述“分割-展平-投影”操作在数学上等价于一个 2D 卷积。这是一个至关重要的工程洞察,因为它能利用高度优化的 Conv2d GPU/TPU kernel。

    • 卷积核大小 (kernel_size): P×PP \times P
    • 步长 (stride): PP
    • 输入通道 (in_channels): CC
    • 输出通道 (out_channels): DD

    一个卷积核在输入张量上以步长 PP 滑动,每次恰好覆盖一个 P×PP \times P 的图块。卷积操作本身就完成了对图块内所有像素值(跨所有通道 CC)的加权求和,这与线性投影完全等价。

    • 输入张量 (PyTorch 格式): xRB×C×H×Wx \in \mathbb{R}^{B \times C \times H \times W}
    • 经过 Conv2d(in_channels=C, out_channels=D, kernel_size=P, stride=P) 后,输出张量形状为 xRB×D×(H/P)×(W/P)x' \in \mathbb{R}^{B \times D \times (H/P) \times (W/P)}
    • 最后,将 xx' 的后两个维度展平并调换维度,即可得到目标序列形状 zpatchesRB×N×Dz_{patches} \in \mathbb{R}^{B \times N \times D},其中 N=(H/P)×(W/P)N = (H/P) \times (W/P)
  • 复杂度:Patch Embedding 层的计算复杂度主要由卷积操作决定,约为 O(BDNP2C)O(B \cdot D \cdot N \cdot P^2 \cdot C)。由于 NP2HWN \cdot P^2 \approx H \cdot W,复杂度也可近似为 O(BDHWC)O(B \cdot D \cdot H \cdot W \cdot C)。这在计算上是非常高效的。

2. [CLS] Token

  • 动机Transformer Encoder 对输入序列中的所有 token 进行信息交互和融合。最终,我们需要一个单一的向量来代表整个图像。直接对所有 patch tokens 的输出做平均池化(Global Average Pooling)是一种方法,但 ViT 采用了 BERT 中的 [CLS] token 策略。
  • 原理:在 patch embedding 序列的最前面,插入一个可学习的向量 xclsRDx_{cls} \in \mathbb{R}^{D}。这个向量的初始值是随机的,但会随着模型的训练而更新。
  • 数学表示z0=[xcls;e1;e2;;eN]z'_{0} = [x_{cls}; e_1; e_2; \dots; e_N] 拼接后,序列长度从 NN 变为 N+1N+1z0R(N+1)×Dz'_{0} \in \mathbb{R}^{(N+1) \times D}

3. Position Embedding

  • 动机:自注意力机制是置换不变的(permutation-invariant)。如果打乱 patch tokens 的顺序,输出结果仅仅是对应位置被打乱而已,模型无法感知图块之间的相对空间关系。

  • 原理:为序列中的每一个 token(包括 [CLS] token)添加一个独一无二的、可学习的位置编码向量。

  • 数学表示: 定义一个可学习的位置编码矩阵 EposR(N+1)×DE_{pos} \in \mathbb{R}^{(N+1) \times D}EposE_{pos} 的第 ii 行是第 ii 个位置的编码向量。 z0=z0+Epos=[xcls;e1;;eN]+Eposz_0 = z'_{0} + E_{pos} = [x_{cls}; e_1; \dots; e_N] + E_{pos} 这里的加法是逐元素相加。z0z_0 就是最终送入 Transformer Encoder 的输入序列。

  • 几何解释:可以想象,Patch Embedding 将每个图块映射到 DD 维空间中的一个点。[CLS] token 是另一个点。Position Embedding 则为每个点提供一个独特的“偏移量”,将它们移动到能够反映其原始位置的新位置。例如,图像左上角的图块和右上角的图块,它们的 patch embedding 可能很相似(如果内容都是天空),但加上了不同的 position embedding 后,它们在输入空间中的位置就被区分开来。

代码实现

以下代码完整实现了 ViT 的输入处理流程,包括使用 Conv2d 手写 PatchEmbedding 模块。

python
1import torch
2import torch.nn as nn
3
4class PatchEmbedding(nn.Module):
5 """
6 将图像分割成图块并进行线性投影的核心模块。
7 使用 Conv2d 实现,效率高。
8 """
9 def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
10 super().__init__()
11 self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size
12 self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
13 self.grid_size = (self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1])
14 self.num_patches = self.grid_size[0] * self.grid_size[1]
15
16 # 关键步骤:使用 Conv2d 实现 Patch Embedding
17 # 为什么这样做:这是一种高效的实现方式。卷积核在图像上以 patch_size 为步长滑动,
18 # 每次操作的感受野恰好是一个 patch。卷积的输出通道数 embed_dim 即为投影后的向量维度。
19 # 输入: [B, C, H, W]
20 # 输出: [B, D, H/P, W/P]
21 self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
22
23 def forward(self, x):
24 # x: [B, C, H, W]
25 B, C, H, W = x.shape
26 assert H == self.img_size[0] and W == self.img_size[1], \
27 f"输入图像尺寸 ({H}*{W}) 与模型预设尺寸 ({self.img_size[0]}*{self.img_size[1]}) 不匹配."
28
29 # 1. 线性投影
30 x = self.proj(x) # [B, embed_dim, grid_size[0], grid_size[1]]
31
32 # 2. 展平并调整维度顺序
33 # 为什么这样做:Transformer Encoder 需要的输入格式是 [B, N, D],即 (批次, 序列长度, 维度)。
34 # Conv2d 输出是 [B, D, H/P, W/P],需要将 H/P 和 W/P 两个维度合并为序列长度 N,
35 # 并将维度 D 移到最后。
36 # flatten(2) 将索引为2之后的所有维度展平: [B, D, N]
37 # transpose(1, 2) 交换维度1和2: [B, N, D]
38 x = x.flatten(2).transpose(1, 2)
39 return x
40
41# --- 模拟完整流程 ---
42
43# 1. 定义超参数
44batch_size = 4
45img_size = 224
46patch_size = 16
47in_chans = 3
48embed_dim = 768 # ViT-Base 的隐藏维度
49
50# 2. 创建一个假的输入图像张量
51dummy_images = torch.randn(batch_size, in_chans, img_size, img_size)
52print(f"输入图像形状: {dummy_images.shape}\n")
53
54# --- 步骤一: Patch Embedding ---
55print("--- 步骤一: Patch Embedding ---")
56patch_embed_layer = PatchEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
57patch_embeddings = patch_embed_layer(dummy_images)
58num_patches = patch_embed_layer.num_patches
59print(f"Patch 数量 (N): {num_patches}")
60print(f"经过 Patch Embedding 后的形状: {patch_embeddings.shape} (格式: [B, N, D])\n")
61
62# --- 步骤二: 添加 [CLS] Token ---
63print("--- 步骤二: 添加 [CLS] Token ---")
64# 为什么这样做:创建一个可学习的参数作为 [CLS] token。它的形状是 [1, 1, D],
65# 以便可以利用广播机制轻松地与批次中的每个样本拼接。
66cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
67print(f"CLS Token 形状: {cls_token.shape}")
68
69# 为什么这样做:使用 expand 将 CLS token 复制 B 次,以匹配批次大小。
70# 然后在序列维度(dim=1)上与 patch_embeddings 拼接。
71cls_tokens = cls_token.expand(batch_size, -1, -1)
72tokens_with_cls = torch.cat((cls_tokens, patch_embeddings), dim=1)
73print(f"拼接 CLS Token 后的形状: {tokens_with_cls.shape} (格式: [B, N+1, D])\n")
74
75# --- 步骤三: 添加 Position Embedding ---
76print("--- 步骤三: 添加 Position Embedding ---")
77# 为什么这样做:创建一个可学习的位置编码矩阵。其序列长度为 N+1,以覆盖 CLS token 和所有 patch tokens。
78# 形状为 [1, N+1, D],同样是为了利用广播机制与整个批次相加。
79pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
80print(f"位置编码形状: {pos_embed.shape}")
81
82# 为什么这样做:将位置编码直接加到 token 序列上。PyTorch 的广播机制会自动处理
83# `pos_embed` 的第一个维度(从1扩展到B),实现对批次中每个样本应用相同的位置编码。
84final_embeddings = tokens_with_cls + pos_embed
85print(f"添加位置编码后的最终输入形状: {final_embeddings.shape} (格式: [B, N+1, D])")

工程实践

  • 使用场景: 这个输入流程是所有基于 ViT 的模型的标准起点,包括图像分类(ViT)、目标检测(DETR, ViT-FRCNN)、语义分割(SETR)等。
  • 超参数选择:
    • patch_size: 这是最关键的超参数之一。常见的有 14, 16, 32。
      • 小 patch_size (e.g., 14, 16): 产生更长的序列 N,计算量(尤其是自注意力的二次方复杂度)和显存消耗更大。但能捕捉更精细的图像特征,通常性能更好。
      • 大 patch_size (e.g., 32): 序列更短,模型更轻、更快。但在小物体或细粒度任务上可能表现不佳。
    • embed_dim: Transformer 的工作维度。ViT-Base 为 768,ViT-Large 为 1024,ViT-Huge 为 1280。更大的 embed_dim 意味着更强的模型容量,但也会增加计算和显存。
  • 性能/显存权衡:
    • ViT 的主要计算瓶颈在于自注意力,其复杂度为 O(N2D)O(N^2 D)。由于 N=(HW)/P2N = (H \cdot W) / P^2,序列长度 NN 对图像尺寸 H,WH, W 是二次方关系,对 patch_size 是平方反比关系。
    • 在实际项目中,如果需要处理高分辨率图像,直接降低 patch_size 会导致序列过长,显存爆炸。常见的策略是:
      1. 保持 patch_size 不变,将高分辨率图像降采样到模型预设的 img_size (e.g., 224x224, 384x384)。
      2. 采用 Swin Transformer 等模型,它们使用分层结构和窗口注意力来处理长序列,降低计算复杂度。
  • 常见坑和调试技巧:
    1. 图像尺寸不匹配: ViTPatchEmbedding 通常要求输入图像尺寸能被 patch_size 整除。在部署或推理时,必须确保输入图像被正确地预处理(裁剪或缩放)到模型训练时使用的尺寸。
    2. 位置编码插值: 当推理时使用的图像分辨率与训练时不同时,patch 的数量 NN 会改变,导致预训练好的 pos_embed 尺寸不匹配。常见的解决办法是对 pos_embed 进行2D双线性插值,以适应新的序列长度。
    3. 维度混淆: [B, N, D] vs [B, D, N] 是初学者常见的错误来源。务必时刻清楚当前张量的维度含义,尤其是在 transpose, permute, flatten 操作之后。

常见误区与边界情况

  • 误区1: "ViT 完全没有使用卷积"

    • 辨析: 这是不准确的。ViT 的 Patch Embedding 步骤最有效的实现方式就是卷积。此外,很多 ViT 的变体(如 Hybrid ViT)会先用一个小的 CNN(如 ResNet 的前几个 stage)来提取低级特征图,再对这个特征图进行 patch 和 embedding。这证明了卷积在视觉任务中的归纳偏置(inductive bias)仍然很有价值。
  • 误区2: "[CLS] token 是唯一的全局表示方法"

    • 辨析: 不是。另一种常见的方法是,不使用 [CLS] token,在 Transformer Encoder 的输出端对所有 patch tokens 的向量表示进行全局平均池化(Global Average Pooling),然后将得到的池化向量用于分类。实验表明两种方法性能相似,但使用 [CLS] token 是原始 ViT 论文的做法,并已成为一种事实标准。
  • 误区3: "位置编码必须是可学习的"

    • 辨析: 不一定。原始 Transformer 论文使用的是固定的正弦/余弦位置编码。ViT 论文实验发现,可学习的1D位置编码、2D位置编码和固定的正弦位置编码性能差异不大。但可学习的编码实现简单,且能让模型自行学习最适合的位置表示,因此在 ViT 及其变体中被广泛采用。
  • 边界情况:图像尺寸不能被 patch_size 整除

    • 标准 nn.Conv2d 在步长大于1时,如果输入尺寸不能被整除,会自动丢弃右侧和/或底部的部分像素。在 ViT 的上下文中,这意味着图像的边缘信息会丢失。
    • 标准做法: 在数据预处理阶段,就将所有图像强制缩放和/或填充到一个固定的、可以被 patch_size 整除的尺寸(如 224x224)。这保证了所有输入都有相同的序列长度 NN,简化了模型设计和批处理。
  • 常见面试追问:

    • : "为什么不直接用 torch.nn.Unfold + torch.nn.Linear 来实现 Patch Embedding?"
    • : Unfold + Linear 在功能上是等价的,它显式地提取图块然后进行线性变换。但 Conv2d 是一个单一的、高度优化的底层操作,在现代深度学习框架和硬件(GPU/TPU)上通常比 Unfold + Linear 的组合拳执行得更快。因此,Conv2d 是更优雅且性能更优的工程选择。
    • : "这个 [CLS] token 的初始值是什么?它如何学习?"
    • : 它是一个 nn.Parameter,通常用零或小的随机值初始化。它和其他模型参数(如卷积核权重、注意力矩阵)一样,通过反向传播和梯度下降进行端到端的学习。在训练过程中,模型会学会将 [CLS] token 的最终输出调整为对整个图像分类任务最有用的表示。
相关题目