从单卡突围到万卡集群:大模型训练的分布式策略完全指南(DP/PP/TP与ZeRO)
引言:大模型时代的“内存墙”与“算力墙”
自 ChatGPT 横空出世以来,大语言模型(LLM)已经成为人工智能领域的绝对核心。从早期的 GPT-3 的 1750 亿参数,到如今 Llama-3、Mixtral 等模型的各种千亿、万亿级参数规模,模型的能力在飞速进化。然而,对于从事大模型训练的工程师来说,一个残酷的物理现实始终横亘在面前:显存永远不够用,单卡算力永远不够强。
以一个 1750 亿参数的模型为例,仅仅使用 FP32(32位浮点数)来存储模型权重,就需要约 700 GB 的显存。如果加上训练时的优化器状态、梯度和激活值,总显存消耗轻松突破 TB 级别。而目前地表最强的商业 GPU(如 NVIDIA H100),其单卡显存也仅为 80 GB。
显存捉襟见肘,训练时间更是难以接受。如果用单张卡去训练 GPT-3,可能需要数十年的时间。
为了打破“内存墙”和“算力墙”,我们必须依赖分布式训练技术,将原本无法装进单张卡的大模型,拆解并分配到成百上千张 GPU 上协同计算。在当下的工业界,主流的分布式训练策略主要分为三大类:数据并行、模型并行(主要指张量并行)以及流水线并行。
本文将从底层原理出发,结合 PyTorch 代码实战,为你深度剖析这些分布式策略是如何运转的,以及它们如何共同编织出大模型训练的“万卡集群”网络。
一、 数据并行:最朴素的思想与极限优化
数据并行是目前应用最广泛、最基础的分布式策略。它的核心思想非常简单:既然模型太大放不进多张卡,那我们把模型复制多份,每个卡放一个完整的模型副本,然后把数据拆分给这些模型同时计算。
1.1 朴素的数据并行 (DP)
在标准的 PyTorch DataParallel (DP) 中,训练流程如下:
- Scatter: 将一个 Batch 的数据切分成 N 份,发送给 N 张 GPU。
- Replicate: 将完整的模型参数复制到 N 张 GPU 上。
- Forward: 各个 GPU 独立进行前向传播,计算 Loss。
- Gather: 将所有 GPU 的 Loss 收集到主 GPU(通常是 GPU 0)。
- Backward: 主 GPU 计算梯度并分发(Broadcast)给所有 GPU,各卡进行反向传播。
- Reduce: 收集所有卡的梯度,在主 GPU 上求平均,然后更新模型参数。
- Broadcast: 将更新后的模型参数再次广播给所有 GPU。
致命缺陷: DP 存在严重的负载不均衡问题。主 GPU 需要承担梯度汇总、参数更新和广播的任务,其显存占用和计算时间远高于其他 GPU,导致整体 GPU 利用率(MFU)极低。目前工业界已基本淘汰 DP 模式。
1.2 分布式数据并行 (DDP)
为了解决 DP 的单点瓶颈,PyTorch 推出了 DistributedDataParallel (DDP)。DDP 采用的是单进程多卡或多进程的架构,每个 GPU 都由一个独立的进程控制。
DDP 的核心改进在于Ring-AllReduce算法。在前向传播时,各 GPU 独立计算;在反向传播时,一旦某一层的梯度计算完毕,DDP 就会启动 AllReduce 操作,让所有 GPU 通过环形网络互相交换梯度数据。
最终结果就是:所有 GPU 在任何时刻计算出的梯度都是完全一致的。每个 GPU 都在本地独立完成参数更新,省去了主节点的瓶颈。
1.3 显存刺客与破局者:ZeRO 优化
虽然 DDP 解决了算力瓶颈,但它引入了巨大的显存浪费:N 张卡上存储了 N 份一模一样的模型权重、优化器状态和梯度。如果我们能把这些冗余数据切分开来呢?
微软 DeepSpeed 团队提出的 ZeRO (Zero Redundancy Optimizer) 技术是对 DDP 的终极进化。ZeRO 分为三个阶段:
- ZeRO-1 (优化器状态切分):每张卡只保留 的优化器状态(如 Adam 的动量和方差)。显存大幅下降,通信开销基本不增加。
- ZeRO-2 (梯度切分):在 ZeRO-1 的基础上,每张卡只保留 的梯度。计算完梯度后直接通过 Reduce-Scatter 丢弃不需要的部分。
- ZeRO-3 (参数切分):又称为 FSDP (Fully Sharded Data Parallel)。模型权重也被切分到 N 张卡上。在进行前向或反向计算某层时,通过 All-Gather 动态从其他卡“借用”这部分权重;计算完毕后立刻丢弃。
代码示例:PyTorch FSDP (ZeRO-3 的官方实现)
1 | import torch |
通过 ZeRO/FSDP,原本需要单卡装下整个模型的限制被彻底打破,只要总卡数足够,哪怕是一个万亿参数的模型,也能通过数据并行的方式训练起来。
二、 模型并行:拆解“巨人”的解剖刀
尽管 ZeRO-3 解决了显存冗余问题,但它本质上依然是数据并行。当模型进行前向和反向传播时,它依然需要通过 All-Gather 在单张卡上临时拼凑出完整的模型层。
如果某一层的参数量极大(例如 GPT-3 中维度高达 12288 的隐藏层),即使只临时拼凑一层,单卡显存也会被撑爆。此外,如果模型单层非常宽,计算过程中的激活值也会把显存耗尽。
这时候,我们需要的是真正的模型并行。在当今的语境下,我们通常称其为张量并行。
2.1 张量并行的数学原理
张量并行的核心思想是:将模型单个层内部的矩阵乘法切分到不同的 GPU 上计算。
以全连接层(Linear Layer)为例,假设我们有一个矩阵乘法 。
其中 是输入矩阵, 是权重矩阵。我们将权重矩阵 沿着列切分为两部分:。
那么这个乘法可以改写为:
- GPU 1 负责计算
- GPU 2 负责计算
这个操作被称为 列并行。此时,两张卡各自拥有一般输出。如果我们在后面接一个激活函数(如 ReLU/GELU),两张卡可以独立对各自的结果进行非线性激活,不需要任何通信。
接下来,如果还要接下一层 ,我们可以将矩阵 按行切分为 。
那么:
- GPU 1 负责计算
- GPU 2 负责计算
- 最后,通过一次 All-Reduce 通信,得到最终结果 。
这个操作被称为 行并行。
2.2 Megatron-LM 的伟大贡献
英伟达的 Megatron-LM 框架将上述数学原理巧妙地应用在了 Transformer 架构上。
一个 Transformer Block 包含一个自注意力层和一个 MLP(多层感知机)层。
- MLP 层:先用列并行将第一个线性层拆开,经过激活函数后,再用行并行将第二个线性层合并。首尾只需要一次 All-Reduce。
- 自注意力层:将注意力头切分给不同的 GPU。例如 32 个头分给 8 张卡,每张卡计算 4 个头。这也是一种列并行,计算完后再通过行并行将结果拼接。首尾同样只需要一次 All-Reduce。
特点与局限:
张量并行极大地降低了单层的显存需求和激活值显存占用,且计算效率极高。但是,由于每一层的计算都极度依赖其他卡的协作(前向和反向传播中每一层都需要一次 All-Reduce),它的通信开销极大。
因此,张量并行只能在机器内部(节点内)使用,依靠 GPU 之间极高的 NVLink 带宽(通常 600 GB/s 以上)来掩盖通信延迟。一台 8 卡机器,最高只能做 TP=8 的张量并行。
三、 流水线并行:机器之间的“接力学”
如果我们已经用张量并行(TP)把一台机器内的 8 张卡榨干了,模型还是太大,我们需要跨机器(节点间)进行并行,该怎么办?
节点之间的网络带宽(如 InfiniBand 或 RoCE)通常只有几十 GB/s,如果在这里用张量并行,通信延迟会让 GPU 绝大多数时间都在等待,毫无算力可言。
这时候,流水线并行 就派上用场了。
3.1 基本思想:按层切分
流水线并行的思想非常直观:把模型按层切开。
假设一个模型有 24 层 Transformer,我们有 4 台机器(每台算一个 Stage):
- Stage 0 负责 Layer 1-6
- Stage 1 负责 Layer 7-12
- Stage 2 负责 Layer 13-18
- Stage 3 负责 Layer 19-24
数据从 Stage 0 流向 Stage 3,就像工厂流水线一样。这被称为模型并行(按层)。这种切分方式在不同 Stage 之间只需要传输中间的隐藏状态(Activations,即张量),通信量非常小,非常适合跨节点网络。
3.2 致命的流水线气泡
朴素流水线存在致命问题:气泡。
如果 Stage 0 正在计算 Batch 1,那么 Stage 1、2、3 都处于闲置状态,等待 Stage 0 传数据。这导致 GPU 利用率极低。
为了减少气泡,GPipe 和 PipeDream 等算法被提了出来。核心思想是微批次技术:
我们将一个大 Batch 拆分成多个 Micro-batch(比如拆成 4 个)。Stage 0 处理完 Micro-batch 1 后,立刻发给 Stage 1,同时自己开始处理 Micro-batch 2。
1F1B 调度策略 (One Forward, One Backward):
为了防止显存因为积压过多的激活值而爆炸,现代流水线并行通常采用 1F1B 策略:
- 起初,各个 Stage 逐渐预热,连续做多个 Micro-batch 的前向传播(1F)。
- 预热完成后,每个 Stage 严格遵循:做一次前向传播(1F),紧接着做一次反向传播(1B),并立刻把产生的梯度反向传递给前一个 Stage,丢弃不再需要的激活值释放显存。
- 最后进行收尾冷却,完成剩下的反向传播。
代码示例:PyTorch 中应用 Pipeline 并行 (基于 torch.distributed.pipeline)
PyTorch 提供了非常易用的 pipeline_sync API(基于 GPipe 逻辑,为了简洁展示核心概念):
1 | import os |
流水线并行的优点是通信量小,适合跨节点;缺点是依然无法完全消除流水线气泡(通常会有 10%-20% 的算力浪费),并且需要开发者精心调整每个 Stage 包含的层数,以平衡各节点的计算时间(负载均衡)。
四、 终极奥义:多维混合并行与 3D 并行
在训练千亿、万亿级别的大模型(如 GPT-4、GLM-130B)时,单纯依靠上述某一种策略是远远不够的。工业界通常采用3D 并行策略,即巧妙地将 DP、TP 和 PP 融合在一起。
想象一个由几千张 GPU 组成的超级计算集群,我们将其划分为多个“组”:
- 最内层:张量并行 (TP)。在一台物理机的 8 张 GPU 之间进行张量并行,利用机内极高带宽的 NVLink 解决单层过大、激活值过大的问题。
- 中间层:流水线并行 (PP)。将几十层 Transformer 按层切分给不同的物理机。利用节点间相对较低但足够支撑隐藏状态传输的网络带宽,进一步压缩单卡上的模型体积。
- 最外层:数据并行 (DP / ZeRO)。当上述 TP 和 PP 把单卡上的模型参数压缩到足够小(比如能装下一个 Stage 里的一部分参数)时,我们在外层套上大规模的数据并行(通常结合 ZeRO-3),进一步切分优化器状态、梯度和剩余参数,从而利用几千张卡的超强算力加速训练。
通信拓扑视角下的 3D 并行:
- TP Network:NVLink (Node-Internal)
- PP Network:InfiniBand/RoCE (Inter-Node)
- DP Network:InfiniBand/RoCE (Global)
通过 3D 并行配合,我们不仅能把最大的模型塞进显存,还能保持极高的硬件利用率(MFU)。例如,Megatron-LM 就是在这套 3D 并行基础上,实现了在 3072 张 A100 GPU 上高效训练万亿参数模型的壮举。
总结与展望
从单卡训练到万卡集群,分布式训练是大模型工程化落地的绝对基石。回顾本文的核心内容:
- 数据并行 (DP / ZeRO / FSDP):通过复制模型、切分数据提升算力。结合 ZeRO 技术,可以有效消除数据并行中的显存冗余。
- 张量并行 (TP):将单个矩阵乘法拆解到多张 GPU,解决单层参数量和激活值过大导致单卡显存溢出的问题,极度依赖机器内的高速互联(NVLink)。
- 流水线并行 (PP):按层切分模型,降低通信量,适合跨节点的模型拆分,通过微批次和 1F1B 策略努力减少“气泡”。
- 3D 并行:结合 TP + PP + DP,是目前工业界训练大模型的最优解。
随着 AI 硬件的飞速迭代,比如 NVIDIA Blackwell 架构对张量内存的进一步优化,以及 NVLink 4.0 和 InfiniBand 带宽的成倍提升,分布式训练的边界正在不断被拓宽。不仅如此,诸如 MoE(混合专家模型)、异构计算(CPU Offload)、序列并行(Sequence Parallelism,针对超长上下文)等更细分的并行技术也正在蓬勃发展。
理解这些分布式策略,不仅能帮助 AI 工程师在遇到 OOM(Out Of Memory)时找到破局之法,更是深入理解现代大模型底层运行逻辑的必经之路。希望这篇文章能成为你在探索大模型训练之路上的坚实垫脚石。