在分布式训练,尤其是深度学习模型的分布式训练中,高效的节点间通信是性能的关键。All-Reduce、All-Gather 和 Reduce-Scatter 是最核心的三种集体通信原语(Collective Communication Primitives)。理解它们的执行逻辑,有助于我们更好地优化和调试分布式系统。
什么是集体通信原语?
集体通信(Collective Communication)指的是集群中所有或部分进程(节点)都参与的数据交换操作。与点对点通信(如Send/Recv)不同,集体通信通常由优化的算法实现,以最小化通信延迟和带宽占用。
以下我们将详细解析这三个关键原语的执行逻辑。
1. All-Reduce:全局数据的归约与同步
目的: 每个进程都拥有所有进程的输入数据经过某种操作(如求和、求平均、求最大值)后的最终结果。
在分布式深度学习中,All-Reduce 最常用于同步梯度(求平均梯度)。
执行逻辑:Ring All-Reduce
在大规模集群中,为了避免带宽瓶颈,现代框架(如NCCL、Gloo)通常采用基于环形拓扑的 Ring All-Reduce 算法。
假设有 $P$ 个进程(P0, P1, P2, P3),输入数据 $D$ 被分割成 $P$ 个块($D_0, D_1, D_2, D_3$)。
Ring All-Reduce 分为两个主要阶段:Reduce-Scatter 和 All-Gather。
阶段 A: Reduce-Scatter (局部归约)
- 初始化: 每个进程将自己的数据块 $D$ 环形发送给邻居,同时接收上游邻居的数据块。进程 $i$ 将 $D$ 发送给 $(i+1) \pmod P$,从 $(i-1) \pmod P$ 接收。
- 迭代: 经过 $P-1$ 步,每个进程都接收到了所有其他进程针对特定数据块的贡献,并完成局部归约。
示例 (P0):P0 最终只保留了 $D_0$ 部分的归约结果 $R_0 = D_0^{P0} + D_0^{P1} + D_0^{P2} + D_0^{P3}$。
阶段 B: All-Gather (结果同步)
- 同步: 进程 $i$ 将它在阶段 A 中计算出的归约结果 $R_i$ 环形发送给邻居。
- 迭代: 经过 $P-1$ 步,每个进程都收集了所有的归约结果 $R_0, R_1, R_2, R_3$。
最终,所有进程都拥有完整的全局归约结果。
2. All-Gather:数据扩展与聚合
目的: 每个进程将其局部数据发送给所有其他进程,使得每个进程最终拥有所有进程的完整数据集。
在分布式训练中,All-Gather 常用于同步批归一化(Batch Normalization)的统计信息,或是在Transformer模型中收集局部特征向量。
执行逻辑:基于环形或树形
与 All-Reduce 类似,All-Gather 也可以通过 Ring 算法高效实现,但它只涉及数据传输,不涉及计算。
假设进程 $i$ 有数据 $L_i$。
- 发送: 进程 $i$ 将 $L_i$ 发送给 $(i+1) \pmod P$。
- 接收与聚合: 进程 $i$ 接收 $(i-1) \pmod P$ 发来的数据块 $A$,并将 $A$ 和自己的数据 $L_i$ 一起,发送给 $(i+1) \pmod P$。
- 迭代: 经过 $P-1$ 步后,所有进程都收集了所有 $L_0, L_1, L_2, ext{…}, L_{P-1}$。
3. Reduce-Scatter:归约与分散的结合
目的: 这是 All-Reduce 算法的第一步。它将输入数据进行归约操作后,然后将归约后的结果分散(Scatter)到各个进程,每个进程只保留最终归约结果的一个子集。
输入: 进程 $i$ 有数据 $D_i$。
输出: 进程 $i$ 接收到所有输入数据的归约结果的子集 $R_i$(即上文 All-Reduce 阶段 A 的结果)。
执行逻辑
Reduce-Scatter 的执行逻辑与 Ring All-Reduce 的第一阶段完全相同:
- 分块: 数据 $D$ 被分成 $P$ 块 $D_0, ext{…}, D_{P-1}$。
- 迭代归约: 进程 $i$ 负责计算最终归约结果的 $i$ 块(在 $P-1$ 步内完成环形传输和局部求和)。
注意: Reduce-Scatter 并非一个独立的常用操作,它通常作为高效实现 All-Reduce 的关键中间步骤。
实践示例:PyTorch Distributed 使用
在 PyTorch 中,这些操作封装在 torch.distributed 模块中。以下示例展示了如何在两个进程间执行 All-Reduce 和 All-Gather。
假设我们使用 nccl 或 gloo 作为后端,需要启动多进程环境(例如使用 torchrun)。
import torch
import torch.distributed as dist
import os
def run(rank, world_size):
# 1. 初始化进程组
dist.init_process_group("gloo", rank=rank, world_size=world_size)
# All-Reduce 示例
# 进程 0 的初始张量是 1.0, 进程 1 的初始张量是 2.0
tensor = torch.tensor([1.0 + rank * 1.0])
print(f"Rank {rank}: Initial tensor value: {tensor}")
# 执行 All-Reduce (求和操作)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
# 预期结果: 1.0 + 2.0 = 3.0
print(f"Rank {rank}: All-Reduce result (SUM): {tensor}\n")
# All-Gather 示例
# 准备要收集的数据
local_data = torch.tensor([10.0 + rank])
output_list = [torch.zeros_like(local_data) for _ in range(world_size)]
# 执行 All-Gather
dist.all_gather(output_list, local_data)
# 预期结果: 进程 0 和 进程 1 都会得到 [10.0, 11.0]
print(f"Rank {rank}: All-Gather result: {output_list}")
if __name__ == '__main__':
# 实际运行通常通过命令行工具如 torchrun 启动,这里模拟 rank 和 world_size
# 实际部署时请勿在单进程内运行此代码块,需要 setup 多进程环境。
print("--- Simulation requires multi-process environment ---")
# 假设 rank = 0, world_size = 2
# run(0, 2)
通过理解这些原语的内部执行(尤其是高效的 Ring 算法),我们可以更好地设计分布式系统,并在 GPU 资源紧张时,通过调整模型分块和通信策略来优化训练速度。
汤不热吧