§1.3.27

AdamW 与 Adam 的 weight decay 差异?

手写练习
  • 手写 AdamW 一步更新

核心概念

Adam (Adaptive Moment Estimation) 中的 weight_decay 参数实际上实现的是 L2 正则化。这意味着正则化项被加入到损失函数的梯度中,然后才进入动量和自适应学习率的计算流程。这种方式导致权重衰减的效果与梯度的历史幅度相耦合,使得权重较大且梯度也一直较大的参数,其有效的权重衰减会变小。

AdamW (Adam with Decoupled Weight Decay) 修正了这个问题。它将权重衰减步骤与梯度更新步骤解耦。具体来说,AdamW 在计算完基于梯度的更新量之后,独立地对权重进行衰减。这使得权重衰减成为一个与优化器动态无关的、更直接和可控的正则化手段,更符合其最初在 SGD 中提出的“权重衰减”的本意。

原理与推导

为了理解二者的差异,我们首先回顾 L2 正则化和权重衰减在标准 SGD 中的等价性,然后分析在 Adam 中为何它们不再等价。

设损失函数为 L(θ)L(\theta),L2 正则化项为 λ2θ2\frac{\lambda}{2} ||\theta||^2,其中 λ\lambda 是正则化系数。

1. SGD 中的 L2 正则化与权重衰减

  • L2 正则化:总损失为 Ltotal=L(θ)+λ2θt2L_{total} = L(\theta) + \frac{\lambda}{2} ||\theta_t||^2。 梯度为: gt=θtLtotal=θtL(θ)+λθtg_t = \nabla_{\theta_t} L_{total} = \nabla_{\theta_t} L(\theta) + \lambda \theta_t SGD 更新规则为: θt+1=θtηgt=θtη(θtL(θ)+λθt)=(1ηλ)θtηθtL(θ)\theta_{t+1} = \theta_t - \eta g_t = \theta_t - \eta (\nabla_{\theta_t} L(\theta) + \lambda \theta_t) = (1 - \eta \lambda) \theta_t - \eta \nabla_{\theta_t} L(\theta) 这里的 (1ηλ)(1 - \eta \lambda) 项就是权重衰减 (Weight Decay)。它在每次更新时都将权重向量向原点收缩一个固定的比例。

2. Adam 中的 L2 正则化 (标准 Adam 的 weight_decay)

Adam 维护了梯度的一阶矩(动量)mtm_t 和二阶矩(非中心方差)vtv_t

  • 梯度计算:与 SGD 类似,将 L2 正则项加入梯度。 gt=θtL(θ)+λθtg_t = \nabla_{\theta_t} L(\theta) + \lambda \theta_t
  • 动量更新mt=β1mt1+(1β1)gtm_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t vt=β2vt1+(1β2)gt2v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
  • 偏差修正m^t=mt1β1t\hat{m}_t = \frac{m_t}{1 - \beta_1^t} v^t=vt1β2t\hat{v}_t = \frac{v_t}{1 - \beta_2^t}
  • 参数更新θt+1=θtηm^tv^t+ϵ\theta_{t+1} = \theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}

问题所在:请注意,gtg_t 中包含了正则化项 λθt\lambda \theta_t。这个项会影响 mtm_tvtv_t 的累积。特别是 vtv_t,它累积了梯度的平方。如果某个权重 θt,i\theta_{t,i} 本身很大,或者其数据梯度 θt,iL(θ)\nabla_{\theta_{t,i}} L(\theta) 很大,那么 gt,ig_{t,i} 就会很大,导致 vt,iv_{t,i} 持续增大。在最终的更新步骤中,分母 v^t\sqrt{\hat{v}_t} 会变大,从而使得整个更新步长(包括来自 λθt\lambda \theta_t 的部分)变小。

直观解释:对于那些具有较大历史梯度的权重,Adam 会减小其有效学习率。当 L2 正则化项与梯度耦合时,这些权重的有效权重衰减也被减小了。这与我们的初衷相悖:我们希望权重衰减能有力地惩罚大权重,但在这里,大权重反而可能因为历史梯度大而获得较小的衰减惩罚。

3. AdamW 中的解耦权重衰减

AdamW 的核心思想是,将权重衰减从梯度更新中分离出来。

  • 梯度计算:只使用数据本身的梯度。 gt=θtL(θ)g_t = \nabla_{\theta_t} L(\theta)
  • Adam 核心更新:使用 gtg_t 执行标准的 Adam 更新步骤,计算出纯粹由数据梯度驱动的更新量 Δθt\Delta \theta_tmt=β1mt1+(1β1)gtm_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t vt=β2vt1+(1β2)gt2v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 m^t=mt1β1t,v^t=vt1β2t\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} Adam_update=ηm^tv^t+ϵ\text{Adam\_update} = \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
  • 参数更新:分为两步,先进行梯度更新,再进行权重衰减。 θt+1=θtAdam_updateηλθt\theta_{t+1} = \theta_t - \text{Adam\_update} - \eta \lambda' \theta_t 为了与 PyTorch 等框架的实现保持一致,通常将权重衰减写成一个乘法因子,并与学习率 η\eta 分开(尽管在 PyTorch 的 AdamW 实现中,衰减率实际上是 lr * weight_decay)。一个更清晰的表述是: θt+1=(1λdecay)θtηm^tv^t+ϵ\theta_{t+1} = (1 - \lambda_{decay}) \theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} 其中 λdecay\lambda_{decay} 是每一步的衰减率。在 PyTorch 的 AdamW 中,这个衰减是在更新步骤之前完成的,等效于 θt+1=θtη(m^tv^t+ϵ+λθt)\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right),其中 λ\lambda 是用户传入的 weight_decay。这清晰地展示了衰减项 λθt\lambda \theta_t 被加在了自适应项的外部,因此它的大小只与 θt\theta_t 本身有关,而不再受到 v^t\hat{v}_t 的影响。

复杂度分析:Adam 和 AdamW 的时间复杂度和空间复杂度是相同的。它们都需要为每个参数额外存储一阶矩 mm 和二阶矩 vv。因此空间复杂度为 O(N)O(N),其中 NN 是模型参数量。每一步的计算也都是对参数的逐元素操作,时间复杂度为 O(N)O(N)。AdamW 只是增加了一个额外的逐元素乘法和减法,计算开销的增加可以忽略不计。

代码实现

下面是一个使用 NumPy 手写 AdamW 单步更新的函数,并与 Adam 的行为进行对比。

python
1import numpy as np
2
3def adamw_step(params, grads, states, lr, betas, weight_decay, eps):
4 """
5 执行一步 AdamW 更新。
6
7 Args:
8 params (list of np.ndarray): 模型参数列表
9 grads (list of np.ndarray): 对应参数的梯度列表
10 states (dict): 优化器的状态,包含 't', 'm', 'v'
11 lr (float): 学习率
12 betas (tuple): (beta1, beta2)
13 weight_decay (float): 权重衰减系数
14 eps (float): 防止除以零的小值
15 """
16 beta1, beta2 = betas
17 t = states['t']
18 m_list = states['m']
19 v_list = states['v']
20
21 # 时间步加一
22 t += 1
23 states['t'] = t
24
25 # 计算偏差修正系数
26 beta1_t = beta1 ** t
27 beta2_t = beta2 ** t
28
29 for i, (p, g) in enumerate(zip(params, grads)):
30 m, v = m_list[i], v_list[i]
31
32 # 关键区别 1: AdamW 直接对参数进行衰减
33 # 这步操作等价于在梯度更新后减去 lr * weight_decay * p
34 # 这样做更符合 "decoupled" 的思想,即衰减与梯度更新分离
35 if weight_decay > 0:
36 p *= (1.0 - lr * weight_decay)
37
38 # Adam 的核心更新逻辑
39 # 更新一阶矩(动量)
40 m = beta1 * m + (1 - beta1) * g
41 # 更新二阶矩(非中心方差)
42 v = beta2 * v + (1 - beta2) * (g ** 2)
43
44 # 保存更新后的状态
45 m_list[i], v_list[i] = m, v
46
47 # 计算偏差修正后的一阶和二阶矩
48 m_hat = m / (1 - beta1_t)
49 v_hat = v / (1 - beta2_t)
50
51 # 执行参数更新
52 update = lr * m_hat / (np.sqrt(v_hat) + eps)
53 p -= update
54
55# --- 示例使用 ---
56# 初始化参数、梯度和优化器状态
57np.random.seed(42)
58params = [np.random.randn(10, 5).astype(np.float32)]
59grads = [np.random.randn(10, 5).astype(np.float32)]
60states = {
61 't': 0,
62 'm': [np.zeros_like(p) for p in params],
63 'v': [np.zeros_like(p) for p in params],
64}
65
66# 超参数
67lr = 1e-3
68betas = (0.9, 0.999)
69weight_decay = 0.01
70eps = 1e-8
71
72# 复制一份用于对比
73params_adam = [p.copy() for p in params]
74grads_adam = [g.copy() for g in grads]
75states_adam = {
76 't': 0,
77 'm': [np.zeros_like(p) for p in params],
78 'v': [np.zeros_like(p) for p in params],
79}
80
81print("原始参数 (部分):")
82print(params[0][0, :3])
83print("-" * 20)
84
85# 执行 AdamW 更新
86adamw_step(params, grads, states, lr, betas, weight_decay, eps)
87print("AdamW 更新后参数 (部分):")
88print(params[0][0, :3])
89print("-" * 20)
90
91
92# --- 对比 Adam 的实现 ---
93# 在 Adam 中,weight_decay 是 L2 正则化,需要加到梯度里
94grads_adam[0] += weight_decay * params_adam[0]
95
96# 使用相同的函数,但将 weight_decay 设为 0,因为其效果已经包含在梯度里
97adamw_step(params_adam, grads_adam, states_adam, lr, betas, weight_decay=0, eps=eps)
98print("Adam (L2) 更新后参数 (部分):")
99print(params_adam[0][0, :3])
100print("-" * 20)
101
102# 可以看到,AdamW 和 Adam (L2) 的更新结果是不同的
103assert not np.allclose(params[0], params_adam[0])
104print("验证:AdamW 和 Adam 的更新结果不同。")

工程实践

  • 使用场景:AdamW 已成为现代深度学习,特别是自然语言处理领域(如 Transformers 模型:BERT, GPT 等)的事实标准优化器。它几乎在所有场景下都优于或等于标准 Adam。对于计算机视觉任务,AdamW 也被广泛采用并取得了 SOTA 效果。除非有特殊理由,否则在需要使用 Adam 的地方,都应优先考虑 AdamW。

  • 超参数选择

    • weight_decay: 这是 AdamW 最需要仔细调整的超参数。由于其效果是解耦且稳定的,通常可以设置比 Adam 中大得多的值。例如,在 Adam 中 weight_decay (L2 正则) 常见值为 1e-51e-4,而在 AdamW 中,0.010.1 都是常见且有效的值。
    • learning_rate: 最佳学习率可能与 Adam 不同。从 Adam 切换到 AdamW 时,通常需要重新搜索最佳学习率和 weight_decay 的组合。
    • betas, eps: 通常使用默认值 (0.9, 0.999)1e-8 即可,很少需要调整。
  • 性能/显存/吞吐:AdamW 相对于 Adam 的计算和显存开销增加微乎其微,可以忽略不计。它不会成为训练瓶颈。

  • 常见坑和调试技巧

    • 从 Adam 迁移:直接将 torch.optim.Adam 替换为 torch.optim.AdamW 时,不要沿用旧的 weight_decay 值。旧的值(如 1e-5)对于 AdamW 来说太小,可能起不到有效的正则化作用。需要根据模型和任务重新调整 weight_decay
    • 调试正则化效果:如果模型过拟合,应首先考虑增大 weight_decay。如果模型欠拟合,可以尝试减小 weight_decay。由于 AdamW 的解耦特性,weight_decay 成为了一个非常可靠和直观的正则化调节旋钮。
    • 与学习率调度器结合:权重衰减的效果与学习率有关(在 PyTorch 实现中)。当使用学习率预热(warmup)时,在预热阶段,小的学习率会导致小的权重衰减效应。一些研究和实现会考虑在预热阶段固定或独立调度权重衰减,但这在标准库中不常见。

常见误区与边界情况

  • 误区一:Adam(weight_decay=...) 就是带权重衰减的 Adam 这是最核心的误区。如上文推导,Adam 优化器中的 weight_decay 参数实现的是 L2 正则化,其效果与梯度耦合,不等同于真正的权重衰减。而 AdamW 才实现了正确的解耦权重衰减。

  • 误区二:AdamW 是一个全新的、复杂的优化器 不是。AdamW 只是对 Adam 的一个简单但关键的修正。其核心的自适应动量更新机制与 Adam 完全相同。可以将其视为 Adam 的一个 "bug fix" 版本。

  • 边界情况:weight_decay = 0weight_decay 设置为 0 时,Adam 和 AdamW 的行为是完全一样的。

  • 面试追问 1:为什么解耦权重衰减对 Transformer 这类模型特别重要? 回答要点Transformer 模型参数的梯度具有多样性和稀疏性。例如,嵌入层(Embedding)的参数可能只在特定 token 出现时才获得梯度,而自注意力(Self-Attention)和前馈网络(FFN)的权重则可能在每一步都更新。在 Adam 中,对于那些不经常更新但可能需要强正则化的参数(如某些 embedding 权重),它们的二阶矩 vtv_t 会很小,导致有效学习率很大。而对于频繁更新的参数(如 FFN),vtv_t 会很大,导致有效学习率变小。耦合的 L2 正则化会使得前者的正则化效果被不恰当地放大,后者的正则化效果被不恰当地缩小。AdamW 的解耦衰减对所有参数施加了统一的衰减率(乘以当前权重值),与它们的梯度历史无关,这更加公平和稳定,从而获得了更好的泛化性能。

  • 面试追问 2:在 PyTorch 中,AdamAdamWweight_decay 参数是如何工作的? 回答要点

    • torch.optim.Adam(params, weight_decay=wd):在内部计算梯度后,会执行 grad = grad + p.data * wd,将 L2 正则化项加到梯度上。
    • torch.optim.AdamW(params, weight_decay=wd):在内部,参数更新前会执行一步 p.data.mul_(1.0 - lr * wd)(伪代码,实际实现更复杂以处理动量),将权重衰减直接作用于参数。这清晰地体现了二者的实现差异。面试时能说出这一点会非常加分。
相关题目