AdamW 与 Adam 的 weight decay 差异?
- —手写 AdamW 一步更新
核心概念
Adam (Adaptive Moment Estimation) 中的 weight_decay 参数实际上实现的是 L2 正则化。这意味着正则化项被加入到损失函数的梯度中,然后才进入动量和自适应学习率的计算流程。这种方式导致权重衰减的效果与梯度的历史幅度相耦合,使得权重较大且梯度也一直较大的参数,其有效的权重衰减会变小。
AdamW (Adam with Decoupled Weight Decay) 修正了这个问题。它将权重衰减步骤与梯度更新步骤解耦。具体来说,AdamW 在计算完基于梯度的更新量之后,独立地对权重进行衰减。这使得权重衰减成为一个与优化器动态无关的、更直接和可控的正则化手段,更符合其最初在 SGD 中提出的“权重衰减”的本意。
原理与推导
为了理解二者的差异,我们首先回顾 L2 正则化和权重衰减在标准 SGD 中的等价性,然后分析在 Adam 中为何它们不再等价。
设损失函数为 ,L2 正则化项为 ,其中 是正则化系数。
1. SGD 中的 L2 正则化与权重衰减
- L2 正则化:总损失为 。 梯度为: SGD 更新规则为: 这里的 项就是权重衰减 (Weight Decay)。它在每次更新时都将权重向量向原点收缩一个固定的比例。
2. Adam 中的 L2 正则化 (标准 Adam 的 weight_decay)
Adam 维护了梯度的一阶矩(动量) 和二阶矩(非中心方差)。
- 梯度计算:与 SGD 类似,将 L2 正则项加入梯度。
- 动量更新:
- 偏差修正:
- 参数更新:
问题所在:请注意, 中包含了正则化项 。这个项会影响 和 的累积。特别是 ,它累积了梯度的平方。如果某个权重 本身很大,或者其数据梯度 很大,那么 就会很大,导致 持续增大。在最终的更新步骤中,分母 会变大,从而使得整个更新步长(包括来自 的部分)变小。
直观解释:对于那些具有较大历史梯度的权重,Adam 会减小其有效学习率。当 L2 正则化项与梯度耦合时,这些权重的有效权重衰减也被减小了。这与我们的初衷相悖:我们希望权重衰减能有力地惩罚大权重,但在这里,大权重反而可能因为历史梯度大而获得较小的衰减惩罚。
3. AdamW 中的解耦权重衰减
AdamW 的核心思想是,将权重衰减从梯度更新中分离出来。
- 梯度计算:只使用数据本身的梯度。
- Adam 核心更新:使用 执行标准的 Adam 更新步骤,计算出纯粹由数据梯度驱动的更新量 。
- 参数更新:分为两步,先进行梯度更新,再进行权重衰减。
为了与 PyTorch 等框架的实现保持一致,通常将权重衰减写成一个乘法因子,并与学习率 分开(尽管在 PyTorch 的
AdamW实现中,衰减率实际上是lr * weight_decay)。一个更清晰的表述是: 其中 是每一步的衰减率。在 PyTorch 的AdamW中,这个衰减是在更新步骤之前完成的,等效于 ,其中 是用户传入的weight_decay。这清晰地展示了衰减项 被加在了自适应项的外部,因此它的大小只与 本身有关,而不再受到 的影响。
复杂度分析:Adam 和 AdamW 的时间复杂度和空间复杂度是相同的。它们都需要为每个参数额外存储一阶矩 和二阶矩 。因此空间复杂度为 ,其中 是模型参数量。每一步的计算也都是对参数的逐元素操作,时间复杂度为 。AdamW 只是增加了一个额外的逐元素乘法和减法,计算开销的增加可以忽略不计。
代码实现
下面是一个使用 NumPy 手写 AdamW 单步更新的函数,并与 Adam 的行为进行对比。
1import numpy as np23def adamw_step(params, grads, states, lr, betas, weight_decay, eps):4 """5 执行一步 AdamW 更新。67 Args:8 params (list of np.ndarray): 模型参数列表9 grads (list of np.ndarray): 对应参数的梯度列表10 states (dict): 优化器的状态,包含 't', 'm', 'v'11 lr (float): 学习率12 betas (tuple): (beta1, beta2)13 weight_decay (float): 权重衰减系数14 eps (float): 防止除以零的小值15 """16 beta1, beta2 = betas17 t = states['t']18 m_list = states['m']19 v_list = states['v']2021 # 时间步加一22 t += 123 states['t'] = t2425 # 计算偏差修正系数26 beta1_t = beta1 ** t27 beta2_t = beta2 ** t2829 for i, (p, g) in enumerate(zip(params, grads)):30 m, v = m_list[i], v_list[i]3132 # 关键区别 1: AdamW 直接对参数进行衰减33 # 这步操作等价于在梯度更新后减去 lr * weight_decay * p34 # 这样做更符合 "decoupled" 的思想,即衰减与梯度更新分离35 if weight_decay > 0:36 p *= (1.0 - lr * weight_decay)3738 # Adam 的核心更新逻辑39 # 更新一阶矩(动量)40 m = beta1 * m + (1 - beta1) * g41 # 更新二阶矩(非中心方差)42 v = beta2 * v + (1 - beta2) * (g ** 2)4344 # 保存更新后的状态45 m_list[i], v_list[i] = m, v4647 # 计算偏差修正后的一阶和二阶矩48 m_hat = m / (1 - beta1_t)49 v_hat = v / (1 - beta2_t)5051 # 执行参数更新52 update = lr * m_hat / (np.sqrt(v_hat) + eps)53 p -= update5455# --- 示例使用 ---56# 初始化参数、梯度和优化器状态57np.random.seed(42)58params = [np.random.randn(10, 5).astype(np.float32)]59grads = [np.random.randn(10, 5).astype(np.float32)]60states = {61 't': 0,62 'm': [np.zeros_like(p) for p in params],63 'v': [np.zeros_like(p) for p in params],64}6566# 超参数67lr = 1e-368betas = (0.9, 0.999)69weight_decay = 0.0170eps = 1e-87172# 复制一份用于对比73params_adam = [p.copy() for p in params]74grads_adam = [g.copy() for g in grads]75states_adam = {76 't': 0,77 'm': [np.zeros_like(p) for p in params],78 'v': [np.zeros_like(p) for p in params],79}8081print("原始参数 (部分):")82print(params[0][0, :3])83print("-" * 20)8485# 执行 AdamW 更新86adamw_step(params, grads, states, lr, betas, weight_decay, eps)87print("AdamW 更新后参数 (部分):")88print(params[0][0, :3])89print("-" * 20)909192# --- 对比 Adam 的实现 ---93# 在 Adam 中,weight_decay 是 L2 正则化,需要加到梯度里94grads_adam[0] += weight_decay * params_adam[0]9596# 使用相同的函数,但将 weight_decay 设为 0,因为其效果已经包含在梯度里97adamw_step(params_adam, grads_adam, states_adam, lr, betas, weight_decay=0, eps=eps)98print("Adam (L2) 更新后参数 (部分):")99print(params_adam[0][0, :3])100print("-" * 20)101102# 可以看到,AdamW 和 Adam (L2) 的更新结果是不同的103assert not np.allclose(params[0], params_adam[0])104print("验证:AdamW 和 Adam 的更新结果不同。")
工程实践
-
使用场景:AdamW 已成为现代深度学习,特别是自然语言处理领域(如 Transformers 模型:
BERT, GPT 等)的事实标准优化器。它几乎在所有场景下都优于或等于标准 Adam。对于计算机视觉任务,AdamW 也被广泛采用并取得了 SOTA 效果。除非有特殊理由,否则在需要使用 Adam 的地方,都应优先考虑 AdamW。 -
超参数选择:
weight_decay: 这是 AdamW 最需要仔细调整的超参数。由于其效果是解耦且稳定的,通常可以设置比 Adam 中大得多的值。例如,在 Adam 中weight_decay(L2 正则) 常见值为1e-5到1e-4,而在 AdamW 中,0.01到0.1都是常见且有效的值。learning_rate: 最佳学习率可能与 Adam 不同。从 Adam 切换到 AdamW 时,通常需要重新搜索最佳学习率和weight_decay的组合。betas,eps: 通常使用默认值(0.9, 0.999)和1e-8即可,很少需要调整。
-
性能/显存/吞吐:AdamW 相对于 Adam 的计算和显存开销增加微乎其微,可以忽略不计。它不会成为训练瓶颈。
-
常见坑和调试技巧:
- 从 Adam 迁移:直接将
torch.optim.Adam替换为torch.optim.AdamW时,不要沿用旧的weight_decay值。旧的值(如1e-5)对于 AdamW 来说太小,可能起不到有效的正则化作用。需要根据模型和任务重新调整weight_decay。 - 调试正则化效果:如果模型过拟合,应首先考虑增大
weight_decay。如果模型欠拟合,可以尝试减小weight_decay。由于 AdamW 的解耦特性,weight_decay成为了一个非常可靠和直观的正则化调节旋钮。 - 与学习率调度器结合:权重衰减的效果与学习率有关(在 PyTorch 实现中)。当使用学习率预热(warmup)时,在预热阶段,小的学习率会导致小的权重衰减效应。一些研究和实现会考虑在预热阶段固定或独立调度权重衰减,但这在标准库中不常见。
- 从 Adam 迁移:直接将
常见误区与边界情况
-
误区一:
Adam(weight_decay=...)就是带权重衰减的 Adam 这是最核心的误区。如上文推导,Adam优化器中的weight_decay参数实现的是 L2 正则化,其效果与梯度耦合,不等同于真正的权重衰减。而AdamW才实现了正确的解耦权重衰减。 -
误区二:AdamW 是一个全新的、复杂的优化器 不是。AdamW 只是对 Adam 的一个简单但关键的修正。其核心的自适应动量更新机制与 Adam 完全相同。可以将其视为 Adam 的一个 "bug fix" 版本。
-
边界情况:
weight_decay = 0当weight_decay设置为 0 时,Adam 和 AdamW 的行为是完全一样的。 -
面试追问 1:为什么解耦权重衰减对
Transformer这类模型特别重要? 回答要点:Transformer模型参数的梯度具有多样性和稀疏性。例如,嵌入层(Embedding)的参数可能只在特定 token 出现时才获得梯度,而自注意力(Self-Attention)和前馈网络(FFN)的权重则可能在每一步都更新。在 Adam 中,对于那些不经常更新但可能需要强正则化的参数(如某些 embedding 权重),它们的二阶矩 会很小,导致有效学习率很大。而对于频繁更新的参数(如 FFN), 会很大,导致有效学习率变小。耦合的 L2 正则化会使得前者的正则化效果被不恰当地放大,后者的正则化效果被不恰当地缩小。AdamW 的解耦衰减对所有参数施加了统一的衰减率(乘以当前权重值),与它们的梯度历史无关,这更加公平和稳定,从而获得了更好的泛化性能。 -
面试追问 2:在 PyTorch 中,
Adam和AdamW的weight_decay参数是如何工作的? 回答要点:torch.optim.Adam(params, weight_decay=wd):在内部计算梯度后,会执行grad = grad + p.data * wd,将 L2 正则化项加到梯度上。torch.optim.AdamW(params, weight_decay=wd):在内部,参数更新前会执行一步p.data.mul_(1.0 - lr * wd)(伪代码,实际实现更复杂以处理动量),将权重衰减直接作用于参数。这清晰地体现了二者的实现差异。面试时能说出这一点会非常加分。