§1.3.17

InternViT-6B 的参数放大策略与训练数据?

核心概念

InternViT-6B 是一个拥有 60 亿参数的巨型视觉 Transformer(Vision Transformer, ViT)模型。其核心创新在于采用了**混合专家模型(Mixture-of-Experts, MoE)**的参数放大策略,从而在不显著增加推理计算量的前提下,极大地扩展了模型的参数容量。与传统地通过增加网络深度或宽度来放大模型不同,MoE 允许模型在每次前向传播时,只激活总参数的一小部分。为了有效训练这个巨型模型,研究团队还构建并使用了一个名为 InternData 的大规模、多源、经过精细清洗的训练数据集。

原理与推导

1. 参数放大策略:混合专家模型 (MoE)

传统的 ViT 模型通过增加层数(Depth)、隐藏层维度(Width)和注意力头数来放大。然而,这种“密集(Dense)”放大方式意味着每次推理都需要动用全部参数,导致计算成本随参数量线性增长,很快变得不可行。

InternViT-6B 采用的 MoE 策略解决了这个问题。其核心思想是将 Transformer Block 中的前馈网络(Feed-Forward Network, FFN)层替换为 MoE 层。

一个 MoE 层包含两部分:

  1. N 个专家网络(Experts):通常是 N 个结构相同但参数独立的 FFN,记为 E1,E2,...,ENE_1, E_2, ..., E_N
  2. 1 个门控网络(Gating Network / Router):一个轻量级网络(如一个线性层),用于为每个输入的 token 决定应该由哪些专家来处理。

工作原理: 对于输入的一个 token 表征 xx,门控网络 gg 会计算出一个 N 维的概率分布,表示将该 token 发往每个专家的权重。

logits=xWg\text{logits} = x \cdot W_g gates=Softmax(logits)\text{gates} = \text{Softmax}(\text{logits})

其中 WgW_g 是门控网络的可学习权重矩阵。

在实践中,为了实现稀疏激活以节省计算,通常不使用所有专家的加权和,而是采用 Top-K 路由。即门控网络为每个 token 选择得分最高的 K 个专家。对于 InternViT-6B,通常采用 Top-2 路由(K=2K=2)。

设门控网络为 token xx 选择的 Top-K 专家索引集合为 T\mathcal{T},对应的门控值为 g(x)ig(x)_i。则 MoE 层的输出 yy 为这些被激活专家的输出的加权和:

y(x)=iTg(x)iEi(x)y(x) = \sum_{i \in \mathcal{T}} g(x)_i \cdot E_i(x)

InternViT-6B 的具体做法是将一个 ViT-Huge 模型作为基础骨架,并将其中的部分 FFN 层替换为 MoE 层。通过增加专家的数量 NN,模型总参数量可以扩展到 60 亿,但由于每次只激活 K 个专家,其推理计算量(FLOPs)仅与一个稍大于 ViT-Huge 的模型相当。

复杂度分析:

  • 参数量(空间复杂度)O(NParamsexpert+Paramsbackbone)O(N \cdot \text{Params}_{\text{expert}} + \text{Params}_{\text{backbone}})。参数量与专家数 NN 成正比。
  • 计算量(时间复杂度)O(KFLOPsexpert+FLOPsbackbone)O(K \cdot \text{FLOPs}_{\text{expert}} + \text{FLOPs}_{\text{backbone}})。计算量与激活专家数 KK 成正比,而与总专家数 NN 无关。这是 MoE 的核心优势。

2. 训练稳定性:负载均衡损失 (Load Balancing Loss)

MoE 训练的一个关键挑战是“专家坍塌”:门控网络可能倾向于总是选择少数几个“受欢迎”的专家,导致其他专家得不到充分训练。为解决此问题,需要引入一个辅助的负载均衡损失 LauxL_{aux}

该损失函数鼓励门控网络将 token 尽可能均匀地分配给所有专家。其定义如下:

Laux=αi=1NfiPiL_{aux} = \alpha \cdot \sum_{i=1}^{N} f_i \cdot P_i
  • NN 是专家总数。
  • fif_i 是一个 batch 中被分配给第 ii 个专家的 token 数量的比例。
  • PiP_i 是这个 batch 中,所有 token 的门控网络输出的概率值在第 ii 个专家上的均值。
  • α\alpha 是一个超参数,用于权衡主任务损失和该辅助损失。

总损失函数为:

Ltotal=Ltask+LauxL_{total} = L_{task} + L_{aux}

通过最小化 LauxL_{aux},可以惩罚那种将大量 token 集中在少数专家上的行为,从而保证所有专家都能得到有效的训练。

3. 训练数据策略:InternData

训练 60 亿参数的模型需要海量、高质量且多样化的数据。InternViT-6B 使用了自建的 InternData 数据集。

构建策略:

  1. 多源聚合:整合了多个公开数据集(如 ImageNet-21K, LAION, COCO 等)、网络爬取数据以及部分内部数据,构建了一个包含数十亿图像-文本对的原始数据池。
  2. 精细化清洗与过滤
    • 去重:使用图像哈希等技术去除重复或高度相似的图像。
    • 内容过滤:过滤掉低质量、不相关或不适宜(NSFW)的内容。
    • 美学与信息量筛选:可能使用预训练的美学评分模型或 CLIP 等模型对图像-文本对的相关性进行打分,保留高质量的样本。
  3. 数据均衡:对数据来源和类别进行分析,确保数据的多样性和均衡性,避免模型在特定领域上过拟合。

这种精细的数据策略是成功训练出高性能大模型的基石,其重要性不亚于模型结构本身。

代码实现

由于 InternViT-6B 模型巨大,无法在单张消费级 GPU 上完整运行。以下代码以 transformers 库的风格,概念性地展示了如何加载并使用一个 MoE 架构的 ViT 模型。这有助于理解其在工程中的调用方式。

python
1import torch
2from torch import nn
3from transformers import ViTConfig, ViTModel
4
5# 这是一个概念性的演示,实际的 InternViT-6B 可能需要特定的库(如 mmpretrain)来加载
6# 这里我们模拟一个带有 MoE 层的 ViT 模型的使用流程
7
8class MoEExpert(nn.Module):
9 """一个简单的专家网络,通常是一个 FFN"""
10 def __init__(self, config):
11 super().__init__()
12 self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
13 self.act_fn = nn.GELU()
14 self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
15
16 def forward(self, hidden_states):
17 return self.dense2(self.act_fn(self.dense1(hidden_states)))
18
19class MoELayer(nn.Module):
20 """一个简化的 MoE 层,用于演示"""
21 def __init__(self, config, num_experts=8, top_k=2):
22 super().__init__()
23 self.num_experts = num_experts
24 self.top_k = top_k
25 self.experts = nn.ModuleList([MoEExpert(config) for _ in range(num_experts)])
26
27 # 门控网络,决定每个 token 去哪个 expert
28 self.gate = nn.Linear(config.hidden_size, num_experts, bias=False)
29
30 def forward(self, hidden_states):
31 # hidden_states 的形状: (batch_size, seq_len, hidden_size)
32 batch_size, seq_len, hidden_dim = hidden_states.shape
33
34 # 将 token 展平以进行路由
35 hidden_states_flat = hidden_states.view(-1, hidden_dim) # (batch_size * seq_len, hidden_dim)
36
37 # 计算门控 logits
38 gate_logits = self.gate(hidden_states_flat) # (batch_size * seq_len, num_experts)
39
40 # 选择 Top-K 专家
41 # weights 是门控值,indices 是专家的索引
42 weights, indices = torch.topk(gate_logits, self.top_k, dim=-1)
43 weights = nn.functional.softmax(weights, dim=-1, dtype=torch.float).to(hidden_states.dtype)
44
45 # 实际工程中,这里会有复杂的 dispatch 和 combine 操作,以实现稀疏计算
46 # 为简化演示,我们这里用一个循环来模拟,这在真实场景中效率很低
47 final_hidden_states = torch.zeros_like(hidden_states_flat)
48 for i in range(self.num_experts):
49 # 找到被分配给当前专家 i 的 token
50 token_indices = (indices == i).any(dim=-1)
51 if token_indices.any():
52 # 提取对应的 token 和门控权重
53 tokens_for_expert = hidden_states_flat[token_indices]
54
55 # 找到对应的门控权重
56 gate_values = weights[token_indices]
57 expert_mask = (indices[token_indices] == i)
58 gate_values_for_expert = (gate_values * expert_mask).sum(dim=-1, keepdim=True)
59
60 # 将 token 输入专家网络,并用门控值加权
61 expert_output = self.experts[i](tokens_for_expert)
62 final_hidden_states[token_indices] += expert_output * gate_values_for_expert
63
64 return final_hidden_states.view(batch_size, seq_len, hidden_dim)
65
66
67# --- 模拟使用流程 ---
68# 1. 定义模型配置
69config = ViTConfig(
70 hidden_size=768,
71 num_hidden_layers=12,
72 num_attention_heads=12,
73 intermediate_size=3072,
74 image_size=224,
75 patch_size=16
76)
77
78# 2. 假设我们有一个预训练的 ViT 模型
79# 在实际应用中,这将是 InternViT-6B 的加载入口
80# model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
81# 假设我们用 MoE 层替换了其中一个 FFN
82# model.encoder.layer[11].output = MoELayer(config) # 概念性替换
83
84# 3. 准备输入数据
85# 模拟一张 224x224 的 RGB 图像
86dummy_image = torch.randn(1, 3, 224, 224)
87print(f"输入图像尺寸: {dummy_image.shape}")
88
89# 4. 模型推理
90# 在实际场景中,你需要一个完整的 ViT 模型实例
91# 这里我们只演示 MoE 层本身
92moe_layer = MoELayer(config)
93dummy_input_tokens = torch.randn(1, 197, 768) # 模拟 ViT patch embedding + [CLS] token
94output_tokens = moe_layer(dummy_input_tokens)
95
96print(f"MoE 层输入 token 尺寸: {dummy_input_tokens.shape}")
97print(f"MoE 层输出 token 尺寸: {output_tokens.shape}")
98# 输出尺寸应与输入尺寸相同,因为 MoE 替代的是 FFN 的功能
99assert dummy_input_tokens.shape == output_tokens.shape

工程实践

  • 使用场景:InternViT-6B 这类巨型模型主要用作视觉基础模型(Foundation Model)。它们在海量数据上预训练后,可以通过微调(Fine-tuning)或提示(Prompting)等方式,在广泛的下游任务上取得卓越性能,如图像分类、目标检测、语义分割、医学影像分析等,尤其擅长处理开放世界和零样本/少样本学习问题。
  • 超参数选择
    • K (Top-K):通常选择 K=2K=2K=1K=1(如 Switch Transformer)计算效率最高,但 K=2K=2 通常能带来更好的性能,因为它允许 token 信息被两个专家处理,增加了模型的表达能力和鲁棒性。
    • N (专家数):N 的选择是模型容量和硬件资源的权衡。越大的 N 意味着越高的参数量和内存占用,但可能带来更强的性能。
    • 负载均衡损失权重 α\alpha:这是一个关键的调优参数,通常从一个小值开始(如 0.01),根据训练过程中专家利用率的监控情况进行调整。
  • 性能 / 显存 / 吞吐 的权衡
    • 显存MoE 模型参数巨大,训练时需要借助 3D 并行策略:
      1. 数据并行 (Data Parallelism, DP):如 DeepSpeed ZeRO,将模型参数、梯度和优化器状态分片到不同 GPU 上,极大降低单卡显存占用。
      2. 张量并行 (Tensor Parallelism, TP):将单个大矩阵(如 Attention 或 FFN 的权重)切分到多个 GPU 上并行计算。
      3. 流水线并行 (Pipeline Parallelism, PP):将模型的不同层放置在不同 GPU 上,形成计算流水线。
    • 吞吐MoE 引入了 All-to-All 通信(用于在 GPU 间分发和收集 token),这会成为训练的瓶颈。需要高性能的计算集群(如 NVLink, InfiniBand)来降低通信延迟。
  • 常见坑和调试技巧
    • 监控专家利用率:训练 MoE 模型时,必须持续监控每个 expert 处理的 token 数量。可以使用 TensorBoard 或 W&B 绘制专家负载的直方图。如果负载极不均衡,说明 LauxL_{aux} 的权重 α\alpha 可能太小,或者学习率设置不当。
    • 数值稳定性:巨型模型容易出现梯度爆炸或消失。使用 BF16 混合精度训练通常比 FP16 更稳定。此外,一些研究(如 GShard)提出对门控网络的 logits 做随机噪声注入或使用 Router Z-loss 来增加稳定性。

常见误区与边界情况

  • 误区一:60亿参数意味着推理速度极慢
    • 澄清:这是对 MoE 最常见的误解。InternViT-6B 的推理计算量(FLOPs)并不与总参数量成正比,而是与激活参数量成正比。由于每次只激活 Top-K 个专家,其计算量远小于一个 60 亿参数的密集模型,大致相当于一个几亿参数的密集模型。但是,它的显存占用确实与总参数量成正比。
  • 误区二:MoE 就是模型集成(Ensemble)
    • 澄清:完全不同。模型集成是训练多个独立模型,推理时对同一输入并行计算,最后聚合结果,计算成本是模型数量的倍数。MoE单个模型内部的条件计算路径,一个输入的不同部分(tokens)被动态路由到不同的内部组件(专家),总计算量仅略微增加。
  • 边界情况:小批量(Small Batch Size)训练
    • MoE 的负载均衡机制依赖于在一个 batch 中有足够多的 token,以便在统计上可以均匀地分配给各个专家。如果 batch size 过小,token 分布会充满噪声,负载均衡损失可能失效或起反作用,导致训练不稳定。
  • 失败模式:专家坍塌(Expert Collapse)
    • 这是 MoE 最经典的失败模式。如果负载均衡机制失效,门控网络会收敛到只使用一个或少数几个专家。这使得模型退化为一个远小于预期的非 MoE 模型,浪费了大量参数,性能也会急剧下降。
  • 常见面试追问
    • :“既然 MoE 这么有效,为什么不把所有层都换成 MoE?”
    • :1) 参数与计算的权衡:FFN 层占据了 Transformer 中约 2/3 的计算和参数,是替换为 MoE 的最高性价比选择。注意力层虽然也重要,但其计算模式(token 间交互)与 FFN 的“知识存储”功能不同,将其专家化带来的收益尚不明确。2) 信息流:在层与层之间保留一些密集的 FFN 或注意力层,可能有助于整合和传播由不同专家处理过的信息,避免信息流过于稀疏。通常采用间隔替换的方式,如每隔一个 FFN 层替换为 MoE 层。
    • :“如何为特定任务选择合适的专家数量 N 和激活数量 K?”
    • :这是一个经验性的问题。K 通常固定为 1 或 2。N 的选择则是一个典型的模型缩放(Scaling Law)问题,取决于任务的复杂度和可用的计算/显存预算。可以从一个较小的 N 开始实验,逐步增加 N,观察模型在验证集上性能的提升曲线。当性能提升饱和或显存达到瓶颈时,就找到了一个合适的 N。
相关题目