ResNet 残差思想为什么解决退化问题?
核心概念
ResNet(Residual Network,残差网络)的核心思想是引入“快捷连接”(Shortcut Connection)或称“残差连接”(Residual Connection),让神经网络的一部分层去学习输入与输出之间的“残差”,而不是直接学习从输入到输出的完整映射 。网络的最终输出是 。这种结构使得网络在需要时可以轻易地学习一个恒等映射(即 ),从而解决了深度神经网络中出现的“退化问题”(Degradation Problem)——即网络层数增加,训练集上的准确率反而下降的现象。
原理与推导
退化问题的核心是,当网络变得非常深时,让多层非线性网络去拟合一个恒等映射(Identity Mapping, )都变得非常困难。如果一个更深的模型至少可以和其较浅的版本表现得一样好(通过在多余的层学习恒等映射),那么它的性能就不应该下降。ResNet正是基于这个“恒等映射”假设来设计的。
数学公式与推导
假设一个残差块的输入为 ,输出为 。
- 普通网络层: ,其中 是一个复杂的非线性变换,例如
Conv -> BN -> ReLU。 - 残差网络层: ,其中 是需要学习的残差函数,通常由两到三个卷积层构成。
为什么这能解决退化问题?关键在于反向传播时的梯度流。
考虑损失函数 对某浅层 的梯度 。根据链式法则,它与更深层 的梯度关系如下:
对于一个普通的深度网络, (为简化,此处忽略激活函数,将变换视为矩阵乘法)。在深度网络中,这个连乘项很容易导致梯度消失(乘积项小于1)或梯度爆炸(乘积项大于1)。
现在看 ResNet。由于 ,我们有:
将其代入链式法则的展开式:
这个公式是递归展开后的简化形式,更严谨的表达是:
这个 +1 项至关重要。它创建了一个“梯度高速公路”,使得梯度可以直接从深层 传播到浅层 ,而不会被中间层的权重矩阵完全“吞噬”。即使残差路径上的梯度 非常小(接近于0),梯度仍然可以通过这个 +1 的路径回传。这极大地缓解了梯度消失问题,使得非常深的网络也能得到有效训练。
直观解释
-
恒等映射易学性: 假设对于某些层,最优的函数就是一个恒等映射(即原封不动地传递信息)。对于传统网络,它需要通过调整复杂的卷积核权重来拟合 。这对于非线性激活函数(如ReLU)和权重初始化(通常是均值为0的小随机数)来说非常困难。而对于 ResNet,网络只需要学习让残差函数 的输出为0即可,这比拟合恒等映射容易得多。权重衰减(Weight Decay)等正则化项本身就会鼓励权重趋向于0。
-
信息流视角: 快捷连接提供了一条“干净”的信息通道。输入信号 可以直接流向网络的更深处。网络层 的作用更像是对主干道上的信息进行微调和修正,而不是完全重构。这保证了即使深层网络学得不好(例如 引入了噪声),原始信息 也不会丢失。
算法复杂度
- 时间复杂度: 一个残差块(如基础块,包含两个3x3卷积)的计算量与两个普通卷积层相当。因此,一个 N 层的 ResNet 的时间复杂度与一个 N 层的普通卷积网络在同一数量级,即 。
- 空间复杂度: 主要由特征图和模型参数决定。与同等深度的普通网络相比,ResNet 仅增加了存储快捷连接输出的少量额外显存,空间复杂度也在同一数量级。
代码实现
下面是一个 PyTorch 实现的 ResNet 基础块(BasicBlock)和瓶颈块(Bottleneck),它们是构成 ResNet-18/34 和 ResNet-50/101/152 的基本单元。
1import torch2import torch.nn as nn3import torch.nn.functional as F45class BasicBlock(nn.Module):6 """7 ResNet-18/34 使用的基础残差块8 结构: 3x3 Conv -> BN -> ReLU -> 3x3 Conv -> BN9 """10 expansion = 1 # expansion因子,表示输出通道数相对于输入通道数的倍数1112 def __init__(self, in_channels, out_channels, stride=1):13 super(BasicBlock, self).__init__()1415 # 主路径 (残差函数 F(x))16 self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)17 self.bn1 = nn.BatchNorm2d(out_channels)18 self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)19 self.bn2 = nn.BatchNorm2d(out_channels)2021 # 快捷连接 (Shortcut Connection)22 self.shortcut = nn.Sequential()23 # 为什么需要这个if判断?24 # 当输入和输出的维度不匹配时(通道数或图像尺寸),快捷连接无法直接相加。25 # 1. stride != 1: 经过conv1后,图像尺寸减半,x的尺寸与F(x)的尺寸不匹配。26 # 2. in_channels != out_channels * self.expansion: 输入通道数与输出通道数不匹配。27 # 此时,需要通过一个1x1卷积对x进行变换,使其维度与F(x)一致。28 if stride != 1 or in_channels != self.expansion * out_channels:29 self.shortcut = nn.Sequential(30 nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False),31 nn.BatchNorm2d(self.expansion * out_channels)32 )3334 def forward(self, x):35 # 主路径前向传播36 out = F.relu(self.bn1(self.conv1(x)))37 out = self.bn2(self.conv2(out))3839 # 加上快捷连接的输出40 out += self.shortcut(x)4142 # 最后应用ReLU激活函数43 out = F.relu(out)44 return out4546class Bottleneck(nn.Module):47 """48 ResNet-50/101/152 使用的瓶颈残差块49 结构: 1x1 Conv -> 3x3 Conv -> 1x1 Conv50 """51 expansion = 4 # 瓶颈块的输出通道数是输入的4倍5253 def __init__(self, in_channels, out_channels, stride=1):54 super(Bottleneck, self).__init__()5556 # 为什么叫瓶颈?57 # 1x1卷积先将通道数从 in_channels 降到 out_channels(瓶颈),58 # 然后3x3卷积在更小的通道维度上进行,减少计算量,59 # 最后1x1卷积再将通道数恢复并扩展到 out_channels * self.expansion。60 self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)61 self.bn1 = nn.BatchNorm2d(out_channels)62 self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)63 self.bn2 = nn.BatchNorm2d(out_channels)64 self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size=1, bias=False)65 self.bn3 = nn.BatchNorm2d(self.expansion * out_channels)6667 self.shortcut = nn.Sequential()68 if stride != 1 or in_channels != self.expansion * out_channels:69 self.shortcut = nn.Sequential(70 nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False),71 nn.BatchNorm2d(self.expansion * out_channels)72 )7374 def forward(self, x):75 out = F.relu(self.bn1(self.conv1(x)))76 out = F.relu(self.bn2(self.conv2(out)))77 out = self.bn3(self.conv3(out))78 out += self.shortcut(x)79 out = F.relu(out)80 return out8182# --- 示例使用 ---83if __name__ == '__main__':84 # 模拟一个输入张量 (Batch, Channels, Height, Width)85 input_tensor = torch.randn(64, 64, 56, 56)8687 # 1. BasicBlock 示例88 # 场景: 维度不变89 basic_block_same_dim = BasicBlock(in_channels=64, out_channels=64, stride=1)90 output_tensor_basic_same = basic_block_same_dim(input_tensor)91 print(f"BasicBlock (维度不变) 输入尺寸: {input_tensor.shape}")92 print(f"BasicBlock (维度不变) 输出尺寸: {output_tensor_basic_same.shape}\n")9394 # 场景: 维度变化 (尺寸减半,通道数加倍)95 basic_block_diff_dim = BasicBlock(in_channels=64, out_channels=128, stride=2)96 output_tensor_basic_diff = basic_block_diff_dim(input_tensor)97 print(f"BasicBlock (维度变化) 输入尺寸: {input_tensor.shape}")98 print(f"BasicBlock (维度变化) 输出尺寸: {output_tensor_basic_diff.shape}\n")99100 # 2. Bottleneck 示例101 input_tensor_bottle = torch.randn(64, 256, 56, 56)102 # 场景: 维度变化 (ResNet-50中常见的stage转换)103 # 输入256通道,瓶颈内部64通道,输出256通道,但尺寸减半104 bottleneck_block = Bottleneck(in_channels=256, out_channels=64, stride=2)105 output_tensor_bottle = bottleneck_block(input_tensor_bottle)106 print(f"Bottleneck (维度变化) 输入尺寸: {input_tensor_bottle.shape}")107 print(f"Bottleneck (维度变化) 输出尺寸: {output_tensor_bottle.shape}")108 # 注意输出通道数是 out_channels * expansion = 64 * 4 = 256
工程实践
- 使用场景: ResNet 及其变体(如 ResNeXt, Wide ResNet)是计算机视觉领域应用最广泛的骨干网络。几乎所有的现代图像分类、目标检测(如 Faster R-CNN, YOLOv3+)、语义分割(如 DeepLabV3+)等模型的特征提取部分都基于 ResNet。
- 超参数选择:
- 深度: ResNet-18/34/50/101/152 是标准选项。ResNet-50 在性能和效率之间取得了很好的平衡,是许多研究和项目的首选基线。ResNet-18/34 更快,适用于对速度要求高的场景。ResNet-101/152 提供更高的精度,但计算成本也更高。
- Bottleneck vs. BasicBlock: 对于 ResNet-50 及更深的网络,必须使用 Bottleneck 结构。它通过 1x1 卷积来降低和恢复通道数,极大地减少了参数量和计算量,使得训练深层网络成为可能。
- 预训练权重: 在实际项目中,几乎总是使用在 ImageNet 上预训练的 ResNet 权重进行微调(Fine-tuning),这能显著加快收敛速度并提高模型性能,尤其是在目标数据集较小的情况下。
- 性能/显存/吞吐的权衡:
- 深度 vs. 宽度: 增加深度(如从 ResNet-50 到 ResNet-101)可以提升精度,但收益会递减,且推理速度变慢。增加宽度(增加通道数,即 Wide ResNet)有时比增加深度更有效。
- 输入分辨率: 提高输入图像的分辨率通常能提升小目标的检测精度,但会急剧增加显存占用和计算量。
- Batch Size: 在显存允许的情况下,使用较大的 Batch Size 配合
BatchNorm通常能获得更稳定和快速的训练。
- 常见坑和调试技巧:
- 维度不匹配: 最常见的 bug 是在
out += self.shortcut(x)处发生运行时错误。务必仔细检查stride和通道数变化时,self.shortcut是否被正确定义和使用。 - Pre-activation vs. Post-activation: 原始论文使用的是 post-activation(ReLU 在相加之后)。后续研究《Identity Mappings in Deep Residual Networks》提出了 pre-activation(BN 和 ReLU 移到卷积之前),理论上能提供更“干净”的梯度路径,有时能带来微小的性能提升。在实践中,两者效果相差不大,但 pre-activation 的正则化效果可能更好。
- 冻结
BatchNorm: 在微调时,如果下游任务的数据集很小,冻结预训练 ResNet 的BatchNorm层的统计数据(均值和方差)通常是一个好主意,因为小批量数据可能导致统计数据不稳定。
- 维度不匹配: 最常见的 bug 是在
常见误区与边界情况
-
误区一:ResNet 彻底解决了梯度消失问题。 辨析: ResNet 并非彻底“解决”,而是极大地“缓解”了梯度消失。从公式 可以看出,虽然
+1项保证了梯度流的下限,但如果一系列残差块的 累加起来是一个很大的负数(例如-1),梯度依然可能消失。但这种情况在实践中很少发生,因为权重初始化和BN的存在使得网络倾向于学习平滑的函数。正确的说法是 ResNet 提供了梯度传播的“高速公路”,避免了梯度在深度传播中因连乘效应而快速衰减。 -
误区二:退化问题就是过拟合。 辨析: 这是两个完全不同的概念。过拟合指模型在训练集上表现好,但在测试集上表现差。退化指模型在训练集上的性能随着网络深度的增加而下降。退化问题表明,更深的模型甚至无法学习到与较浅模型相当的解,说明优化过程本身遇到了困难。
-
误区三:快捷连接必须是恒等映射。 辨析: 只有在输入和输出维度完全相同时,快捷连接才是恒等映射。当维度发生变化时(如通道数增加或空间尺寸减小),必须使用一个线性投影(通常是 1x1 卷积)来匹配维度,如代码实现中所示。这个投影也是网络可学习的一部分。
-
边界情况与面试追问:
- 问: 如果把 ResNet 的激活函数 ReLU 换成 Sigmoid 或 Tanh 会怎么样?
答: Sigmoid 和 Tanh 存在饱和区,它们的导数在输入绝对值较大时会趋近于0。这会使得残差路径 的梯度 更容易消失,从而削弱了残差学习的优势。网络将更加依赖于
+1的恒等路径,学习能力会受限。ReLU 在正区间的导数恒为1,不存在饱和问题,因此与残差结构配合得更好。 - 问: ResNet 和 Highway Network 有什么区别? 答: Highway Network 提出了一个更通用的门控机制:,其中 和 是“变换门”和“携带门”,它们的和为1。ResNet 可以看作是 Highway Network 的一个简化特例,其中 恒为1, 恒为1(在维度匹配时)。ResNet 的设计更简洁,且实践证明非常有效。
- 问: 为什么 ResNet-50 要用 Bottleneck 结构? 答: 为了计算效率。一个标准的基础块(两个3x3卷积),从256通道到256通道,参数量约为 。而一个 Bottleneck 块,从256通道 -> 64通道 -> 256通道,参数量为 。参数量和计算量都大大减少,使得构建更深的网络成为可能。
- 问: 如果把 ResNet 的激活函数 ReLU 换成 Sigmoid 或 Tanh 会怎么样?
答: Sigmoid 和 Tanh 存在饱和区,它们的导数在输入绝对值较大时会趋近于0。这会使得残差路径 的梯度 更容易消失,从而削弱了残差学习的优势。网络将更加依赖于