§1.1.10

凸函数判定与凸优化?深度学习为什么是非凸?

核心概念

凸函数(Convex Function)是指其函数图像上任意两点的连线段,都位于函数图像的上方或与之重合的函数。严格凸函数则要求连线段(除端点外)严格位于函数图像上方。从几何上看,凸函数的图形呈现“碗状”。

凸优化(Convex Optimization)是在约束为凸集、目标函数为凸函数的前提下求解最优化问题。凸优化问题最显著的特点是:任何局部最优解都是全局最优解。这使得其在理论分析和数值求解上都比非凸优化问题简单得多。

深度学习的损失函数之所以是非凸的,是因为它通常是关于模型权重的高度非线性、高维度的复合函数。尽管损失函数本身(如交叉熵)对于模型的直接输出可能是凸的,但由于神经网络的非线性激活函数和多层结构,导致从模型权重到最终损失的映射关系极其复杂,形成了包含大量局部极小值、鞍点和宽阔平坦区域的非凸景观。

原理与推导

判定一个函数 ff 是否为凸函数,主要有三种方法,其严格程度和适用性逐层递增。

1. 零阶条件:定义法 (Jensen's Inequality)

对于函数定义域内的任意两点 x1,x2x_1, x_2 和任意 θ[0,1]\theta \in [0, 1],如果满足:

f(θx1+(1θ)x2)θf(x1)+(1θ)f(x2)f(\theta x_1 + (1-\theta) x_2) \le \theta f(x_1) + (1-\theta) f(x_2)

则函数 ff 是凸函数。

直观解释:这个不等式描述了“弦在图上”的几何性质。函数在两点之间线性插值(弦)的值,不小于这两点函数值的线性插值。

2. 一阶条件:梯度/切线法

如果函数 ff 可微,则 ff 是凸函数的充要条件是,对于其定义域内任意两点 x,yx, y,满足:

f(y)f(x)+f(x)T(yx)f(y) \ge f(x) + \nabla f(x)^T (y-x)

直观解释:函数图像始终位于其任意一点的切线(或切超平面)的上方。该切线是函数的一个全局下界支撑。

3. 二阶条件:Hessian 矩阵法

如果函数 ff 二阶可微,则 ff 是凸函数的充要条件是,其 Hessian 矩阵 2f(x)\nabla^2 f(x) 在整个定义域内是半正定的(Positive Semidefinite, PSD)。

2f(x)0\nabla^2 f(x) \succeq 0

一个矩阵是半正定的,意味着对于任意非零向量 zz,都有 zT(2f(x))z0z^T (\nabla^2 f(x)) z \ge 0。对于一维函数,这退化为二阶导数 f(x)0f''(x) \ge 0

算法复杂度:对于一个 nn 维变量的函数,Hessian 矩阵的大小为 n×nn \times n。计算 Hessian 矩阵需要 O(n2)O(n^2) 的代价。判断一个矩阵是否为半正定(例如,通过计算所有特征值)通常需要 O(n3)O(n^3) 的代价。对于深度学习模型,参数量 nn 可达数百万甚至数十亿,计算和存储 Hessian 矩阵是完全不可行的。

深度学习为什么是非凸的?

深度学习的损失函数 L(W)L(W) 是关于模型权重 WW 的函数。我们可以将其看作一个复合函数: L(W)=Loss(NN(X;W),Y)L(W) = \text{Loss}(\text{NN}(X; W), Y)

  1. 非线性激活函数:像 ReLU, Sigmoid, Tanh 等激活函数是解决线性模型表达能力不足的关键,但它们也引入了非线性。一个简单的带 ReLU 的单神经元网络 f(w,x)=ReLU(wx)f(w, x) = \text{ReLU}(wx) 的损失函数(如MSE)关于权重 ww 就已经是非凸的了。

  2. 网络层级复合:深度网络是多层非线性函数的复合,即 fL(f2(f1(X;W1);W2);WL)f_L(\dots f_2(f_1(X; W_1); W_2) \dots; W_L)。即使每一层都是一个简单的函数,它们的深度复合也会产生极其复杂的函数地形。

  3. 对称性导致的非凸性(权重空间对称性):这是最直观的解释。考虑一个简单的多层感知机(MLP)。如果你训练好了一个网络,得到一组权重 WW^*。现在,将第一个隐藏层中的任意两个神经元(连同它们的所有输入和输出权重)交换位置,你会得到一组新的权重 WW'。显然 WWW^* \neq W',但这个新网络与原网络计算的函数完全相同,因此 L(W)=L(W)L(W^*) = L(W')。 既然存在多个不同的权重配置可以得到相同的最优损失值,那么损失函数必然至少存在多个全局最小值,它不可能是凸的。(一个严格凸函数只有一个全局最小值)。连接 WW^*WW' 的路径上的损失值 L(0.5W+0.5W)L(0.5W^* + 0.5W') 几乎肯定会比 L(W)L(W^*) 大得多,直接违反了凸函数的定义。

代码实现

下面的 PyTorch 代码将演示如何使用一阶和二阶条件来判定函数的凸性。

python
1import torch
2import numpy as np
3import matplotlib.pyplot as plt
4
5# 设置 Matplotlib 支持中文显示
6plt.rcParams['font.sans-serif'] = ['SimHei']
7plt.rcParams['axes.unicode_minus'] = False
8
9# --- 1. 一维函数凸性判定 ---
10print("--- 1. 一维函数凸性判定 ---")
11
12# 定义一个凸函数 f(x) = x^2
13def convex_func(x):
14 return x**2
15
16# 定义一个非凸函数 g(x) = -x^2 * cos(2*pi*x)
17def non_convex_func(x):
18 return -x**2 * torch.cos(2 * np.pi * x)
19
20# a) 一阶条件可视化:函数位于切线上方
21x0 = torch.tensor([-0.5], requires_grad=True)
22y0 = convex_func(x0)
23y0.backward() # 自动计算梯度
24
25x_range = torch.linspace(-2, 2, 100)
26y_range = convex_func(x_range)
27tangent_line = y0.item() + x0.grad.item() * (x_range - x0.item())
28
29plt.figure(figsize=(12, 5))
30plt.subplot(1, 2, 1)
31plt.plot(x_range.numpy(), y_range.numpy(), label='$f(x) = x^2$ (凸函数)')
32plt.plot(x_range.numpy(), tangent_line.numpy(), 'r--', label=f'在 x={x0.item()} 处的切线')
33plt.scatter(x0.item(), y0.item(), c='r')
34plt.title('一阶条件:凸函数在其切线上方')
35plt.legend()
36plt.grid(True)
37
38# b) 二阶条件:Hessian (对于一维即二阶导数)
39x = torch.tensor([2.0], requires_grad=True)
40y = convex_func(x)
41# 计算一阶导数
42grad_x, = torch.autograd.grad(y, x, create_graph=True)
43# 计算二阶导数 (Hessian)
44hessian, = torch.autograd.grad(grad_x, x)
45print(f"凸函数 f(x)=x^2 在 x=2 处的二阶导数: {hessian.item()}")
46# 为什么这样做:二阶导数大于等于0是凸函数的二阶条件。对于 f(x)=x^2, f''(x)=2 >= 0,所以是凸的。
47
48x_non_convex = torch.tensor([0.8], requires_grad=True)
49y_non_convex = non_convex_func(x_non_convex)
50grad_non_convex, = torch.autograd.grad(y_non_convex, x_non_convex, create_graph=True)
51hessian_non_convex, = torch.autograd.grad(grad_non_convex, x_non_convex)
52print(f"非凸函数 g(x) 在 x=0.8 处的二阶导数: {hessian_non_convex.item():.2f}")
53# 为什么这样做:如果能找到某点二阶导数小于0,则函数是非凸的。
54
55plt.subplot(1, 2, 2)
56x_range_non_convex = torch.linspace(-1.5, 1.5, 200)
57plt.plot(x_range_non_convex.numpy(), non_convex_func(x_range_non_convex).detach().numpy(), label='$g(x) = -x^2 \cos(2\pi x)$ (非凸函数)')
58plt.title('一个非凸函数示例')
59plt.legend()
60plt.grid(True)
61plt.show()
62
63
64# --- 2. 二维函数凸性判定 (Hessian矩阵) ---
65print("\n--- 2. 二维函数凸性判定 ---")
66
67def func_2d(x):
68 # 一个凸函数: f(x1, x2) = x1^2 + 2*x2^2
69 return x[0]**2 + 2 * x[1]**2
70
71x = torch.tensor([1.0, 2.0], requires_grad=True)
72
73# 使用 torch.autograd.functional.hessian 计算 Hessian 矩阵
74# 为什么这样做:这是 PyTorch 提供的标准方法,用于精确计算任意可微函数的 Hessian 矩阵。
75hessian_matrix = torch.autograd.functional.hessian(func_2d, x)
76print(f"二维凸函数在点 {x.detach().numpy()} 处的 Hessian 矩阵:\n{hessian_matrix.numpy()}")
77
78# 检查 Hessian 矩阵是否为半正定
79# 为什么这样做:判断 Hessian 矩阵是否半正定是二阶条件的核心。一个实对称矩阵是半正定的,当且仅当它的所有特征值都非负。
80eigenvalues = torch.linalg.eigvalsh(hessian_matrix)
81print(f"Hessian 矩阵的特征值: {eigenvalues.numpy()}")
82if torch.all(eigenvalues >= 0):
83 print("所有特征值非负,Hessian 矩阵是半正定的,函数在该点附近是凸的。")
84else:
85 print("存在负特征值,Hessian 矩阵不是半正定的,函数是非凸的。")

工程实践

  1. 利用凸性:在机器学习领域,一些经典模型是凸优化问题,例如:

    • 线性回归(使用均方误差损失函数)
    • 逻辑回归(使用对数损失函数)
    • 支持向量机 (SVM) 对于这些问题,我们可以使用高效的凸优化算法(如梯度下降、牛顿法、内点法)找到唯一的全局最优解,结果稳定且可复现。
  2. 应对非凸性:深度学习是典型的非凸优化问题。工程实践中,我们不追求找到全局最优解,而是寻找一个“足够好”的局部最优解。

    • 优化器选择:我们不使用需要计算 Hessian 的二阶方法(如牛顿法),而是采用一阶方法,如 SGD (随机梯度下降) 及其变体 Adam, RMSprop, AdaGrad。这些算法计算开销小,并且通过引入随机性(来自 mini-batch)和动量等机制,能够帮助模型跳出不良的局部极小值和逃离鞍点。
    • 超参数调整学习率(Learning Rate) 是最重要的超参数。合适的学习率(以及学习率衰减策略)对于在非凸的损失平面上导航至关重要。太大的学习率可能导致在“山谷”两侧来回震荡无法收敛,太小则可能收敛过慢或陷入离初始点很近的次优解。
    • 初始化策略:权重的初始值决定了优化过程的起点。好的初始化(如 Xavier/Glorot, He 初始化)能将模型置于一个更容易优化的区域,避免梯度消失或爆炸,从而加速收敛。
    • 正则化:L1/L2 正则化、Dropout 等技术会改变损失函数的形状,使其更平滑,有助于泛化,并可能引导优化过程找到更优的解。
  3. 性能权衡

    • 一阶方法 vs. 二阶方法:一阶方法(如 Adam)每次迭代快(只需计算梯度),但可能需要更多次迭代。二阶方法(如牛ton法)收敛快(迭代次数少),但每次迭代极慢(需要计算、存储和求逆 Hessian 矩阵),在深度学习中完全不可行。
    • Batch Size:大的 Batch Size 提供的梯度估计更准,但可能收敛到“尖锐”的最小值,泛化能力可能较差。小的 Batch Size 引入更多噪声,有助于跳出局部最小值,可能收敛到“宽阔”的最小值,泛化能力更好,但训练过程更不稳定。

常见误区与边界情况

  1. 误区:“非凸优化没有理论保证,所以是玄学。” 纠正:虽然非凸优化无法保证找到全局最优,但近年来大量研究表明,高维空间中的非凸优化问题(如深度学习)具有特殊结构。损失平面上的大多数局部最小值都具有相似的、接近全局最优的性能。真正的挑战更多来自于大量的鞍点(Saddle Points),而非“坏”的局部最小值。现代优化器(如 Adam)被设计为能有效逃离鞍点。

  2. 误区:“损失函数(如交叉熵)是凸的,所以深度学习是凸优化。” 纠正:这是一个非常常见的混淆。交叉熵损失函数 L(p,q)=pilog(qi)L(p, q) = -\sum p_i \log(q_i) 对于其直接输入(模型的预测概率分布 qq)是凸的。然而,在神经网络训练中,我们优化的变量是权重 WW,而不是直接优化 qqqq 是一个通过 WW 经过多层非线性变换得到的复杂函数,即 q=softmax(NN(X;W))q = \text{softmax}(\text{NN}(X; W))。一个凸函数和一个非线性函数的复合通常是非凸的。因此,损失 LL 作为权重 WW 的函数是非凸的。

  3. 误区:“只要损失在下降,就说明优化方向是正确的。” 纠正:损失下降只说明你正在沿着梯度的反方向移动,走向一个更低的点。但这并不能保证你走向的是全局最优。你可能正走向一个性能很差的局部最小值,或者在鞍点附近徘徊。

  4. 面试追问:“既然深度学习是非凸的,为什么我们用梯度下降法还能取得这么好的效果?” 回答要点

    • 过参数化(Over-parameterization):现代神经网络通常是过参数化的,即参数数量远大于训练样本数量。这使得损失平面变得相对“友好”,存在大量性能几乎一样好的“好”的局部最小值。
    • 高维空间的祝福:在低维空间,我们容易被局部最小值困住。但在高维空间,一个点要成为局部最小值,需要在所有维度上都是最小值,这是非常苛刻的条件。而鞍点(在某些维度是最小,在其他维度是最大)则变得非常普遍。
    • SGD的随机性:使用 Mini-batch 的 SGD 引入了噪声,这种噪声有助于算法“抖动”,从而有机会越过小的势垒,逃离尖锐的局部最小值和鞍点。
    • 优秀的优化算法:像 Adam 这样的自适应优化算法,结合了动量(帮助冲过鞍点)和自适应学习率(为不同参数调整步长),在实践中被证明能非常有效地在复杂的非凸地形中寻找高质量的解。
  5. 边界情况:ReLU 的不可微点 ReLU 函数在 x=0x=0 处不可微。在实践中,这几乎不成问题。在 x=0x=0 处,我们可以取其次梯度(subgradient)为 [0, 1] 区间内的任意值(通常实现为 0 或 1)。由于浮点数计算的精度问题,输入严格等于 0 的情况极少发生,因此优化算法可以顺利运行。

相关题目