ConvNeXt 如何用卷积追平 Transformer?
核心概念
ConvNeXt 并非一个全新的架构,而是一个通过借鉴 Vision Transformer(特别是 Swin Transformer)的设计思想和现代训练方法,对标准卷积网络(如 ResNet)进行系统性“现代化改造”后的产物。其核心思想在于证明,当配备了与 Transformer 相当的宏观/微观设计以及训练策略后,纯粹的卷积网络在性能上完全可以与顶尖的 Transformer 模型相媲美,甚至超越。ConvNeXt 的贡献不在于发明新的模块,而在于其严谨的、对照实验驱动的“进化”过程,揭示了网络架构、训练策略和最终性能之间的深刻联系。
原理与推导
ConvNeXt 的诞生过程是一场从 ResNet-50 到 ConvNeXt 的“现代化之旅”。作者系统地、一步步地将 Swin Transformer 的设计元素融入到 ResNet 中,并评估每一步带来的影响。
起点:一个标准的 ResNet-50 模型。
1. 宏观架构设计 (Macro Design)
-
阶段计算比例 (Stage Compute Ratio):
- ResNet-50 中 4 个 stage 的 block 数量比例为 (3, 4, 6, 3)。
- Swin-T (Tiny) 的比例为 (1, 1, 3, 1),每个 stage 的 block 数量更接近,且在 stage3 分配了更多计算。
- 改造: 将 ResNet-50 的 block 数量调整为 (3, 3, 9, 3),模仿 Swin-T 将更多计算量移到靠后的阶段。这使得模型能更好地学习高级语义特征。
-
Stem 层 "Patchify":
- ResNet 的 Stem 是一个 7x7、步长为 2 的卷积层,后跟一个最大池化层,共同实现 4 倍下采样。
ViT/Swin 使用一个非重叠的 "Patchify" 层,即一个大卷积核、大步长的卷积层来直接将输入图像切分为块 (patch)。- 改造: 将 ResNet 的 Stem 替换为一个 4x4、步长为 4 的卷积层。这在功能上与
Transformer的 patch embedding 完全等价,实现了更直接、简单的 4 倍下采样。
2. 借鉴 ResNeXt (ResNeXt-ify)
- 分组卷积 (Grouped Convolution): ResNeXt 引入了分组卷积来提高“基数”(cardinality),在不显著增加参数量和计算量的情况下提升模型性能。
- 深度可分离卷积 (Depthwise Separable Convolution): 这是分组卷积的一个极端情况,其中分组数等于通道数 ()。Swin
Transformer的多头自注意力 (MHSA) 也可以被看作是一种分组操作,每个头在通道的一个子集上独立工作。 - 改造: 采用深度可分离卷积。这不仅借鉴了 ResNeXt 的思想,也使得卷积块的结构更接近
Transformerblock 中的 MHSA 模块。
3. 倒置瓶颈设计 (Inverted Bottleneck)
- ResNet Bottleneck: "胖-瘦-胖" 结构。例如,一个 256 维的输入,先通过 1x1 卷积压缩到 64 维,然后 3x3 卷积,最后 1x1 卷积扩张回 256 维。
TransformerFFN Block: "瘦-胖-瘦" 结构。输入 维,通过第一个 MLP 扩展到 维,然后通过第二个 MLP 压缩回 维。MobileNetV2 也采用了这种倒置瓶颈。- 改造: 将 ResNet 的瓶颈结构颠倒过来,采用 "瘦-胖-瘦" 的倒置瓶颈。即 1x1 卷积将通道数从 扩展到 ,深度卷积处理 维特征,最后 1x1 卷积再压缩回 维。
4. 大卷积核 (Large Kernel Sizes)
- 动机:
ViT的一个核心优势是其自注意力机制具有全局感受野。SwinTransformer通过窗口移位机制实现了跨窗口的信息交互,也获得了比传统小核卷积(如 3x3)更大的有效感受野。 - 挑战: 在标准卷积中,增大卷积核尺寸会急剧增加参数量 () 和计算量。
- 解决方案: 由于上一步已经采用了深度可分离卷积,参数量仅为 ,与通道数呈线性关系。这使得使用大卷积核变得可行。
- 改造: 将深度卷积的核尺寸从 3x3 逐步上移。实验发现,将核尺寸从 3x3 提升到 7x7 能带来显著的性能增益,且性能超过了 Swin-T。这证明了大的局部感受野对于视觉任务至关重要。
5. 微观架构设计 (Micro-level Design)
在完成了上述主要修改后,作者进一步对 block 内的细节进行了调整,使其更像 Transformer block。
- 激活函数: 将 ReLU 替换为 GELU。GELU 是
BERT, GPT,ViT等主流Transformer模型使用的激活函数,其非线性和平滑性被认为更优。 - 归一化层: 将
BatchNorm(BN) 全部替换为LayerNorm(LN)。Transformer完全使用 LN。BN 依赖于 batch 的统计量,对 batch size 敏感,且在训练和推理时行为不一;而 LN 在单个样本的通道维度上进行归一化,与 batch size 无关,行为更稳定。 - 减少激活函数和归一化层: ResNet 的一个 block 中包含多个激活函数和 BN 层。
Transformerblock 中通常只有一个或两个归一化层,且激活函数只在 FFN 中使用。 - 改造: 精简 block 结构,每个 ConvNeXt block 只保留一个 GELU 和一个 LN 层,结构更简洁高效。
最终的 ConvNeXt Block 结构:
一个输入张量 经过 ConvNeXt block 的计算流程如下:
其中:
- 是一个 7x7 的深度卷积。
- 是
LayerNorm。 - 第一个 是逐点卷积,用于将通道数扩展 4 倍(倒置瓶颈)。
- 第二个 是逐点卷积,用于将通道数压缩回原始维度。
- 是一种正则化技术,随机“丢弃”整个残差分支,是
Transformer中常用的 Stochastic Depth。
复杂度分析: 对于一个 ConvNeXt block,假设输入特征图尺寸为 ,通道数为 :
- 时间复杂度: 主要由卷积贡献。。对于 ,约为 。
- 空间复杂度: 主要是存储特征图和参数。参数量为 (DW-Conv) + (PW-Conv1) + (PW-Conv2) 。
代码实现
下面是一个 PyTorch 实现的 ConvNeXt Block 和一个简化的 ConvNeXt 模型。
1import torch2import torch.nn as nn3import torch.nn.functional as F45def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):6 """7 Drop paths (Stochastic Depth) per sample.8 这是 Transformer 和 ConvNeXt 中常用的正则化方法,可以看作是 Dropout 的一种结构化形式。9 它会随机地将整个残差分支的输出置为零,而不是像 Dropout 那样随机置零单个元素。10 """11 if drop_prob == 0. or not training:12 return x13 keep_prob = 1 - drop_prob14 shape = (x.shape[0],) + (1,) * (x.ndim - 1) # (B, 1, 1, 1)15 random_tensor = x.new_empty(shape).bernoulli_(keep_prob)16 if keep_prob > 0.0 and scale_by_keep:17 random_tensor.div_(keep_prob)18 return x * random_tensor1920class DropPath(nn.Module):21 """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""22 def __init__(self, drop_prob=None, scale_by_keep=True):23 super(DropPath, self).__init__()24 self.drop_prob = drop_prob25 self.scale_by_keep = scale_by_keep2627 def forward(self, x):28 return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)2930class LayerNorm(nn.Module):31 """32 自定义 LayerNorm 以支持 (N, C, H, W) 格式的输入。33 PyTorch 原生的 nn.LayerNorm 期望通道在最后一个维度。34 这里通过 permute 操作来适配,或者直接在指定维度上计算均值和方差。35 """36 def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):37 super().__init__()38 self.weight = nn.Parameter(torch.ones(normalized_shape))39 self.bias = nn.Parameter(torch.zeros(normalized_shape))40 self.eps = eps41 self.data_format = data_format42 if self.data_format not in ["channels_last", "channels_first"]:43 raise NotImplementedError44 self.normalized_shape = (normalized_shape, )4546 def forward(self, x):47 if self.data_format == "channels_last":48 # (N, H, W, C)49 return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)50 elif self.data_format == "channels_first":51 # (N, C, H, W)52 u = x.mean(1, keepdim=True)53 s = (x - u).pow(2).mean(1, keepdim=True)54 x = (x - u) / torch.sqrt(s + self.eps)55 x = self.weight[:, None, None] * x + self.bias[:, None, None]56 return x5758class ConvNeXtBlock(nn.Module):59 """60 ConvNeXt Block.6162 结构:63 x -> 7x7 DW-Conv -> LayerNorm -> 1x1 PW-Conv -> GELU -> 1x1 PW-Conv -> DropPath -> + -> x_out64 | |65 +-----------------------(残差连接)----------------------+66 """67 def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):68 super().__init__()69 # 深度卷积 (Depthwise Conv), 模仿自注意力的局部性70 self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)7172 # LayerNorm, Transformer 的标配73 self.norm = LayerNorm(dim, eps=1e-6, data_format="channels_first")7475 # 两个逐点卷积 (Pointwise Conv) 构成的倒置瓶颈, 模仿 Transformer 的 FFN76 self.pwconv1 = nn.Conv2d(dim, 4 * dim, kernel_size=1)77 self.act = nn.GELU() # GELU 激活函数78 self.pwconv2 = nn.Conv2d(4 * dim, dim, kernel_size=1)7980 # Layer Scale, 一种额外的正则化/稳定训练技巧81 self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),82 requires_grad=True) if layer_scale_init_value > 0 else None8384 # Stochastic Depth85 self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()8687 def forward(self, x):88 input = x89 x = self.dwconv(x)90 x = self.norm(x)91 x = self.pwconv1(x)92 x = self.act(x)93 x = self.pwconv2(x)9495 if self.gamma is not None:96 # Layer Scale: 按通道对残差分支的输出进行缩放97 x = self.gamma.view(1, -1, 1, 1) * x9899 # 残差连接 + DropPath100 x = input + self.drop_path(x)101 return x102103# 示例:创建一个 ConvNeXt-Tiny 规模的模型104class ConvNeXt(nn.Module):105 def __init__(self, in_chans=3, num_classes=1000,106 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],107 drop_path_rate=0., layer_scale_init_value=1e-6):108 super().__init__()109110 self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers111 # Stem: "Patchify" 层112 stem = nn.Sequential(113 nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),114 LayerNorm(dims[0], eps=1e-6, data_format="channels_first")115 )116 self.downsample_layers.append(stem)117118 # 下采样层119 for i in range(3):120 downsample_layer = nn.Sequential(121 LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),122 nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),123 )124 self.downsample_layers.append(downsample_layer)125126 self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple blocks127 dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]128 cur = 0129 for i in range(4):130 stage = nn.Sequential(131 *[ConvNeXtBlock(dim=dims[i], drop_path=dp_rates[cur + j],132 layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]133 )134 self.stages.append(stage)135 cur += depths[i]136137 self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer138 self.head = nn.Linear(dims[-1], num_classes)139140 def forward(self, x):141 for i in range(4):142 x = self.downsample_layers[i](x)143 x = self.stages[i](x)144145 # 全局平均池化 + 分类头146 x = self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)147 x = self.head(x)148 return x149150# 创建一个 ConvNeXt-Tiny 模型实例151model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])152dummy_input = torch.randn(2, 3, 224, 224) # (Batch, Channels, Height, Width)153output = model(dummy_input)154print(f"输入尺寸: {dummy_input.shape}")155print(f"输出尺寸: {output.shape}")156# 打印模型参数量157num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)158print(f"模型总参数量: {num_params / 1e6:.2f} M")
工程实践
-
使用场景: ConvNeXt 是一个高性能、通用的视觉骨干网络。它可以作为 ResNet、EfficientNet 或 Swin
Transformer的直接替代品,用于图像分类、目标检测、语义分割等各种下游任务。由于其纯卷积的特性,在某些硬件(如标准 GPU、移动端 DSP)上可能比Transformer有更好的部署和推理优化潜力。 -
超参数选择:
- 模型规模: 论文提供了从
Tiny(28M 参数) 到XLarge(350M 参数) 的多个版本。选择哪个版本取决于算力预算和性能要求。ConvNeXt-T/S是性能和效率的甜点。 - 训练策略是关键: 直接将 ConvNeXt 代码套用到旧的训练脚本(如 ResNet 的标准 SGD 训练)上,性能会大打折扣。必须使用现代化的训练策略,包括:
- 优化器: AdamW,配合较高的学习率(如 4e-3)和较长的 warmup。
- 数据增强: Mixup, CutMix, RandAugment, Random Erasing 等强数据增强是必须的。
- 正则化: Label Smoothing, Stochastic Depth (DropPath), Layer Scale。
- 训练周期: 通常需要更长的训练周期(例如 300 epochs on ImageNet-1k)。
- 模型规模: 论文提供了从
-
性能 / 显存 / 吞吐 的权衡:
- 吞吐量: 在同等性能水平下,ConvNeXt 的推理吞吐量通常高于 Swin
Transformer,因为卷积操作在 cuDNN 等库中有极高的优化。 - 显存: 显存占用与 Swin
Transformer相当。大卷积核的深度卷积虽然参数少,但中间激活图尺寸较大,会消耗一定显存。 - 灵活性: 纯卷积结构使其更容易集成到现有的 CNN-based 框架中,无需处理
Transformer特有的 token 操作和位置编码。
- 吞吐量: 在同等性能水平下,ConvNeXt 的推理吞吐量通常高于 Swin
-
常见坑和调试技巧:
- 性能不达标: 最常见的原因是没有使用配套的训练策略。请务必检查优化器、学习率调度、数据增强和正则化设置是否与官方实现一致。
LayerNorm的data_format: 在 PyTorch 中使用nn.LayerNorm时,需要注意它默认处理channels_last的数据。对于 CNN 常用的channels_first(N, C, H, W),需要进行permute操作或使用自定义的 LN 实现(如代码所示),否则会出错。- 迁移学习: 在下游任务(如检测、分割)上微调时,由于 ConvNeXt 预训练时使用了强正则化,其特征可能比传统 ResNet 更“平滑”。微调时可能需要调整学习率或解冻策略。
常见误区与边界情况
-
误区一:ConvNeXt 是一个全新的发明。
- 纠正: ConvNeXt 的核心价值在于其“方法论”——即通过严谨的实验证明了,将
Transformer的设计哲学应用于 CNN,可以使后者达到 SOTA 性能。它的模块(深度卷积、倒置瓶颈)本身都是已知的。
- 纠正: ConvNeXt 的核心价值在于其“方法论”——即通过严谨的实验证明了,将
-
误区二:7x7 是最优的卷积核尺寸。
- 纠正: 7x7 是在 ImageNet 分类任务上找到的一个很好的平衡点。论文中也实验了更大的核(如 11x11),性能有微小提升但延迟增加。在其他任务或不同输入分辨率下,最优核尺寸可能不同。关键思想是“使用比 3x3 大得多的核”。
-
误区三:ConvNeXt 证明了卷积优于自注意力。
- 纠正: ConvNeXt 证明了卷积的“归纳偏置”(局部性、平移等变性)在经过精心设计和训练后,依然极具竞争力。它并没有否定自注意力的价值(如动态权重、全局感受野)。它表明,架构设计的优劣是一个系统工程,单纯比较一个操作(conv vs. attention)是片面的。
-
边界情况与面试追问:
- 追问:为什么将 BN 换成 LN 如此重要?
- 回答要点: 1) 对齐
Transformer设计:这是模仿Transformer的关键一步。2) 小批量问题:在目标检测、分割等任务中,每个 GPU 上的 batch size 可能很小(1-2),BN 的统计量会非常不稳定,而 LN 不受影响。3) 训练/推理不一致:BN 在训练时使用 batch 统计,推理时使用全局移动平均统计,行为不一。LN 在任何时候都只对当前样本操作,行为一致。
- 回答要点: 1) 对齐
- 追问:深度卷积+大核是如何模拟自注意力的?
- 回答要点: 自注意力根据输入动态计算权重,实现大感受野内的加权求和。深度卷积+大核则是在一个大的局部窗口内进行加权求和,但其权重是静态的(由卷积核参数决定)、与输入内容无关的。ConvNeXt 的成功表明,一个足够大的、静态的感受野,配合强大的特征学习能力(倒置瓶颈+GELU),在很多视觉任务上已经足够强大。它用一个强化的“归纳偏置”达到了类似的效果。
- 追问:如果让你继续改进 ConvNeXt,你会从哪个方向入手?
- 回答要点: 1) 动态性:引入某种形式的动态或内容自适应机制,例如动态卷积核,让模型能根据输入调整卷积核权重,进一步弥合与自注意力的差距。2) 多尺度融合:在 block 内部融合不同大小的卷积核,类似 Inception 的思想,以更高效地捕捉多尺度特征。3) 更优的训练策略:探索自监督或半监督预训练方法,减少对大规模标注数据的依赖。
- 追问:为什么将 BN 换成 LN 如此重要?