§1.3.34

残差连接为什么有效(梯度视角、恒等映射视角)?

核心概念

残差连接(Residual Connection),又称快捷连接(Shortcut Connection),是 ResNet 架构的核心思想。它允许网络的某一层的输入信号可以“跳过”一个或多个中间层,直接与这些中间层的输出相加,然后将结果传递给下一层。其核心表达式为 H(x)=F(x)+xH(x) = F(x) + x,其中 xx 是块的输入, F(x)F(x) 是残差块(通常是几个卷积层和激活函数)学习到的函数,而 H(x)H(x) 是块的最终输出。残差连接旨在解决深度神经网络中的“退化问题”(degradation problem)和梯度消失问题,使得训练非常深的网络成为可能。

原理与推导

残差连接的有效性可以从两个核心视角来理解:恒等映射和梯度传播。

1. 恒等映射视角(解决网络退化问题)

在 ResNet 论文提出之前,一个普遍的观察是,当网络深度增加时,其性能会先提升然后达到饱和,最后迅速下降。反直觉的是,这种性能下降(即“退化”)并不仅仅是过拟合,因为在训练集上的误差也会随深度增加而增加。

理论上的最优情况: 假设我们有一个已经训练好的、性能不错的浅层网络。如果我们想构建一个更深的网络,一个简单的方法是在浅层网络后面堆叠新的层。如果这些新添加的层能够学习成为“恒等映射”(Identity Mapping),即输出完全等于输入(H(x)=xH(x)=x),那么这个更深的网络至少应该能达到与原浅层网络相当的性能,而不应该更差。

没有残差连接的困难: 对于一个标准的前馈网络(如 VGG),一个层或块的输出是 H(x)=ReLU(BN(Conv(x)))H(x) = \text{ReLU}(\text{BN}(\text{Conv}(x)))。让这样一堆非线性的、带权重参数的层去精确拟合一个恒等映射 H(x)=xH(x)=x 是非常困难的。优化器很难找到合适的参数组合来实现这一点。

残差连接的解决方案: 残差连接将学习目标从一个完整的映射 H(x)H(x) 重新定义为一个残差映射 F(x)F(x)H(x)=F(x,{Wi})+xH(x) = F(x, \{W_i\}) + x 其中,F(x,{Wi})F(x, \{W_i\}) 是需要学习的残差函数,通常由两个或三个卷积层组成。

  • 几何/直观解释: 如果恒等映射是当前层的最优选择(即增加该层并无益处),那么网络只需要让 F(x)F(x) 趋近于零即可。这比让整个复杂的 F(x)+xF(x)+x 结构去拟合 xx 要容易得多。优化器可以通过将权重 WiW_i 推向零来轻松实现 F(x)0F(x) \approx 0,从而使该层近似于一个恒等映射。这为网络提供了一条“如果这个模块没用,就跳过它”的捷径,从而避免了深度增加带来的退化问题。

2. 梯度视角(缓解梯度消失)

梯度消失是训练深度网络的另一个主要障碍。在反向传播过程中,梯度需要通过链式法则从最后一层逐层传回第一层。如果每层的梯度雅可比矩阵的范数都小于1,那么经过多层传播后,梯度信号会呈指数级衰减,导致靠近输入的层几乎无法更新权重。

标准网络的反向传播: 考虑损失函数 L\mathcal{L} 对某浅层特征 xlx_l 的梯度: Lxl=LxLxLxL1xL1xL2xl+1xl=LxLi=lL1xi+1xi\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \frac{\partial x_L}{\partial x_{L-1}} \frac{\partial x_{L-1}}{\partial x_{L-2}} \cdots \frac{\partial x_{l+1}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \prod_{i=l}^{L-1} \frac{\partial x_{i+1}}{\partial x_i} 其中 xLx_L 是深层特征。如果 xi+1xi\frac{\partial x_{i+1}}{\partial x_i} 涉及的权重矩阵和激活函数导数通常小于1,那么连乘项 i=lL1\prod_{i=l}^{L-1} 会迅速趋近于零。

残差网络的反向传播: 对于一个残差块,我们有 xl+1=F(xl,Wl)+xlx_{l+1} = F(x_l, W_l) + x_l。根据求导法则,其雅可比矩阵为: xl+1xl=F(xl,Wl)xl+xlxl=F(xl,Wl)xl+1\frac{\partial x_{l+1}}{\partial x_l} = \frac{\partial F(x_l, W_l)}{\partial x_l} + \frac{\partial x_l}{\partial x_l} = \frac{\partial F(x_l, W_l)}{\partial x_l} + 1 现在,我们再看从深层 LL 到浅层 ll 的梯度: Lxl=LxLxLxl\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \frac{\partial x_L}{\partial x_l} 根据 xL=xL1+F(xL1,WL1)=xL2+F(xL2,WL2)+F(xL1,WL1)==xl+i=lL1F(xi,Wi)x_L = x_{L-1} + F(x_{L-1}, W_{L-1}) = x_{L-2} + F(x_{L-2}, W_{L-2}) + F(x_{L-1}, W_{L-1}) = \dots = x_l + \sum_{i=l}^{L-1} F(x_i, W_i),我们可以推导出: Lxl=LxLxl(xl+i=lL1F(xi,Wi))=LxL(1+xli=lL1F(xi,Wi))\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \frac{\partial}{\partial x_l} \left( x_l + \sum_{i=l}^{L-1} F(x_i, W_i) \right) = \frac{\partial \mathcal{L}}{\partial x_L} \left( 1 + \frac{\partial}{\partial x_l} \sum_{i=l}^{L-1} F(x_i, W_i) \right)

  • 信息论/物理直观解释: 这个公式揭示了残差连接的强大之处。梯度 LxL\frac{\partial \mathcal{L}}{\partial x_L} 可以通过两条路径回传到 xlx_l
    1. 无衰减的“高速公路”:通过 +1 这一项,来自顶层的梯度可以直接、无任何中间矩阵乘法衰减地传递到浅层。这保证了即使网络很深,浅层也能接收到有效的梯度信号。
    2. 学习调整的“支路”:通过 xli=lL1F(xi,Wi)\frac{\partial}{\partial x_l} \sum_{i=l}^{L-1} F(x_i, W_i) 这一项,梯度流经各个残差块的权重层,进行正常的学习和调整。

这个 +1 的存在,使得即使残差支路 Fxl\frac{\partial F}{\partial x_l} 的梯度很小,总梯度也不会消失。它为梯度提供了一条畅通无阻的传播路径,极大地缓解了梯度消失问题。

代码实现

下面是一个使用 PyTorch 实现的残差块(Residual Block)的例子。它包含了两种情况:

  1. 输入和输出维度相同,直接相加(Identity Shortcut)。
  2. 输入和输出维度不同(例如,通道数或尺寸变化),需要一个“投影快捷连接”(Projection Shortcut)来匹配维度,通常用一个 1x1 卷积实现。
python
1import torch
2import torch.nn as nn
3
4class ResidualBlock(nn.Module):
5 """
6 一个标准的残差块,包含两个卷积层。
7 支持维度匹配的投影快捷连接。
8 """
9 def __init__(self, in_channels, out_channels, stride=1):
10 super(ResidualBlock, self).__init__()
11
12 # 残差函数 F(x) 部分
13 # 第一个卷积层,可能通过 stride 改变特征图尺寸
14 self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
15 self.bn1 = nn.BatchNorm2d(out_channels)
16 self.relu = nn.ReLU(inplace=True)
17
18 # 第二个卷积层,不改变尺寸
19 self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
20 self.bn2 = nn.BatchNorm2d(out_channels)
21
22 # 快捷连接(Shortcut Connection)部分
23 self.shortcut = nn.Sequential()
24 # 为什么需要这个判断?
25 # 如果 stride 不为 1(尺寸减半)或输入输出通道数不同,
26 # 那么输入 x 和 F(x) 的维度会不匹配,无法直接相加。
27 # 此时,需要通过一个 1x1 卷积(投影)来调整 x 的维度。
28 if stride != 1 or in_channels != out_channels:
29 self.shortcut = nn.Sequential(
30 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
31 nn.BatchNorm2d(out_channels)
32 )
33
34 def forward(self, x):
35 # 记录原始输入,用于最后的相加
36 identity = x
37
38 # --- 残差路径 F(x) ---
39 out = self.conv1(x)
40 out = self.bn1(out)
41 out = self.relu(out)
42
43 out = self.conv2(out)
44 out = self.bn2(out)
45
46 # --- 快捷连接路径 ---
47 # 为什么在这里应用 shortcut?
48 # 为了让 identity 的维度与 F(x) 的输出 out 保持一致,以便可以相加。
49 identity = self.shortcut(x)
50
51 # --- 元素级相加 ---
52 # 为什么是相加?
53 # 这就是残差学习的核心:H(x) = F(x) + x
54 out += identity
55
56 # 为什么在相加之后应用 ReLU?
57 # 这是原始 ResNet 论文中的设计(post-activation)。
58 # 也有研究表明 pre-activation (BN-ReLU-Conv) 效果更好。
59 out = self.relu(out)
60
61 return out
62
63# --- code_drills: 演示两种情况 ---
64
65# 1. 维度匹配(Identity Shortcut)
66print("--- 演示维度匹配情况 ---")
67# 输入通道=64, 输出通道=64, stride=1
68block_identity = ResidualBlock(in_channels=64, out_channels=64, stride=1)
69# 创建一个假的输入张量:Batch=4, Channels=64, Height=32, Width=32
70input_tensor_1 = torch.randn(4, 64, 32, 32)
71output_tensor_1 = block_identity(input_tensor_1)
72print(f"输入尺寸: {input_tensor_1.shape}")
73print(f"输出尺寸: {output_tensor_1.shape}")
74# 检查:输出尺寸应与输入尺寸相同
75
76print("\n" + "="*30 + "\n")
77
78# 2. 维度不匹配(Projection Shortcut)
79print("--- 演示维度不匹配情况 ---")
80# 输入通道=64, 输出通道=128, stride=2(尺寸减半)
81block_projection = ResidualBlock(in_channels=64, out_channels=128, stride=2)
82# 创建一个假的输入张量:Batch=4, Channels=64, Height=32, Width=32
83input_tensor_2 = torch.randn(4, 64, 32, 32)
84output_tensor_2 = block_projection(input_tensor_2)
85print(f"输入尺寸: {input_tensor_2.shape}")
86print(f"输出尺寸: {output_tensor_2.shape}")
87# 检查:输出通道数变为128,高和宽变为16

工程实践

  • 使用场景:残差连接是现代深度学习模型的“标配”,几乎所有SOTA的计算机视觉模型(如 EfficientNet, Vision Transformer)、自然语言处理模型(如 Transformer 中的 FeedForward 和 Attention 之后)以及语音识别模型都深度整合了残差连接的思想。只要你需要构建一个深度超过20层的网络,就应该默认使用残差连接。

  • 超参数选择

    • Block 类型:对于较深的网络(如 ResNet-50/101/152),通常使用“瓶颈”结构(Bottleneck Block),即 1x1 Conv -> 3x3 Conv -> 1x1 Conv。这种结构先用 1x1 卷积降维,然后用 3x3 卷积提取特征,最后用 1x1 卷积升维。相比于使用两个 3x3 卷积的 Basic Block,瓶颈结构在增加网络深度的同时,能更有效地控制参数量和计算量。
    • 激活函数位置:后续研究(Identity Mappings in Deep Residual Networks)表明,采用“预激活”(Pre-activation)的方式,即将 BN 和 ReLU 放在卷积层之前(BN -> ReLU -> Conv),可以提供更好的正则化效果,使优化更容易,性能通常也更好。
    • 初始化:在残差块的最后一个 Batch Normalization 层(bn2)之后,可以将其权重初始化为零。这样,在训练初期,整个残差块的输出 F(x)F(x) 接近于零,使得整个模块近似于一个恒等映射,有助于稳定训练初期的学习过程。
  • 性能/显存/吞吐的权衡

    • 瓶颈结构 vs. 基础结构:瓶颈结构是典型的用计算换参数和显存的例子。虽然它有三层卷积,但由于中间的 3x3 卷积是在降维后的特征图上操作的,总计算量(FLOPs)和参数量都比两个直接在原维度上操作的 3x3 卷积要少。
    • 分组卷积/深度可分离卷积:在移动端或对效率要求极高的场景,可以将残差块中的标准卷积替换为分组卷积或深度可分离卷积(如 MobileNet),以进一步降低计算和参数。
  • 常见坑和调试技巧

    • 维度不匹配:这是最常见的错误。在实现 out += identity 时,一定要确保 outidentity 的形状完全一致。使用 print(out.shape, identity.shape) 是调试此问题的金标准。
    • 梯度爆炸:虽然残差连接缓解了梯度消失,但在极深的网络或使用不当的学习率时,梯度也可能爆炸。确保使用 Batch Normalization,并选择合适的学习率和权重衰减(weight decay)。
    • 训练不稳定:如果模型训练发散,检查学习率、数据预处理和网络初始化。尝试使用预激活结构,或者在训练初期使用较小的学习率进行 warm-up。

常见误区与边界情况

  • 误区1:残差连接就是为了解决梯度消失

    • 辨析:这只说对了一半。残差连接更直接的目的是解决“网络退化”问题,即让深度网络的优化变得更容易。缓解梯度消失是其带来的一个重要“副作用”和结果,但其设计的初衷是为了让网络更容易学习恒等映射。面试时能清晰区分这一点是加分项。
  • 误区2:任何跳层连接都是残差连接

    • 辨析:不完全是。残差连接特指 H(x) = F(x) + x 这种通过相加(element-wise addition)合并的方式。像 DenseNet 那样通过拼接(concatenation)的方式连接,虽然也是跳层连接,但其机制和效果有所不同(促进特征复用,但显存消耗大)。
  • 误区3:有了残差连接,网络就可以无限深

    • 辨析:理论上梯度可以传播,但实践中,过深的网络(如超过1000层)会带来收益递减、训练时间过长、过拟合等问题。模型的性能提升并非与深度呈线性关系。
  • 边界情况与失败模式

    • 数值精度问题:在半精度训练(FP16)中,如果 xx 的数值很大而 F(x)F(x) 的数值很小,x + F(x) 的加法可能会导致 F(x)F(x) 的信息被“冲刷掉”(numerical underflow),使得残差块的梯度无法有效回传。
    • 不当的激活函数:如果在残差连接的求和之后使用 tanhsigmoid 等饱和激活函数,会再次限制梯度的传播,削弱残差连接的优势。这也是为什么 ReLU 在深层网络中更受欢迎的原因之一。
  • 常见面试追问

    • “ResNet 和 DenseNet 的 skip connection 有什么区别?”
      • 回答要点:ResNet 是相加,保留了信息流的主干道,侧重于让网络学习“残差”;DenseNet 是拼接,每一层都连接到后续所有层,侧重于“特征复用”,但显存占用随深度线性增长,而 ResNet 的显存占用是常数。
    • Transformer 里的残差连接用在哪里?为什么?”
      • 回答要点:在 Transformer 的每个子层(多头自注意力、前馈网络)之后都使用了残差连接和层归一化(LayerNorm)。即 output = LayerNorm(x + Sublayer(x))。其作用与在 CNN 中类似:允许梯度无障碍传播,使得可以堆叠非常多的 Transformer Encoder/Decoder 层(例如 BERT-large 有24层),从而构建强大的模型。
相关题目