欢迎光临
我们一直在努力

PyTorch 通信原语详解:深度对比 AllReduce、AllGather 与 ReduceScatter

在现代深度学习中,模型和数据集的规模爆炸式增长,使得分布式训练成为常态。PyTorch 的 torch.distributed 包提供了一系列高效的通信原语(Collective Operations),这些原语是实现数据并行(DDP)和模型并行(FSDP)的关键。理解它们之间的差异和适用场景,对于优化训练速度至关重要。

本文将深入对比 PyTorch 中最常用的三个通信原语:AllReduceAllGatherReduceScatter,并通过实操代码展示其用法。

准备工作:初始化分布式环境

所有分布式操作都依赖于进程组(Process Group)的初始化。我们假设环境已配置好 rank 和 world_size。

import torch
import torch.distributed as dist
import os

def setup(rank, world_size):
    # 使用NCCL后端(GPU环境推荐)或Gloo后端
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

1. AllReduce:梯度同步的基石

AllReduce 是最常见的原语,它将所有进程上的输入张量进行归约(如求和、求平均),然后将最终结果分发给所有进程。

适用场景: 分布式数据并行(DDP)中的梯度同步。

工作原理

  • 输入: $N$ 个进程各自拥有 $T$。
  • 输出: $N$ 个进程各自获得 $ ext{Reduce}(T_1, T_2, …, T_N)$。

代码示例

假设我们有两个进程,rank 0 有张量 10,rank 1 有张量 20。我们使用求和操作。

# 假设已在setup函数中初始化环境
# rank = current_process_rank
# world_size = 2

def run_all_reduce(rank, world_size):
    setup(rank, world_size)

    input_data = torch.tensor([10.0 * (rank + 1)], device=rank)
    print(f"Rank {rank} - Initial data: {input_data}")

    # 执行 AllReduce (SUM)
    dist.all_reduce(input_data, op=dist.ReduceOp.SUM)

    # 归约结果:10.0 + 20.0 = 30.0
    print(f"Rank {rank} - After AllReduce: {input_data}")

    cleanup()

# 运行方式 (通常通过 torchrun 或 multiproc)
# Process 0 Output: Initial data: [10.0], After AllReduce: [30.0]
# Process 1 Output: Initial data: [20.0], After AllReduce: [30.0]

2. AllGather:收集全局信息

AllGather 将所有进程的输入张量收集起来,并将其拼接成一个更大的张量,然后把这个完整的结果分发给所有进程。

适用场景: 需要计算全局统计信息(如 LARS 优化器),或者在 BatchNorm 层中收集所有 GPU 上的统计量。

工作原理

  • 输入: $N$ 个进程各自拥有 $T$。
  • 输出: $N$ 个进程各自获得 $ ext{Concat}(T_1, T_2, …, T_N)$。

代码示例

每个进程拥有一个 2×2 的张量,world_size = 2

def run_all_gather(rank, world_size):
    setup(rank, world_size)

    # 每个进程的输入张量
    input_tensor = torch.ones(2, 2, device=rank) * (rank + 1)

    # 准备输出列表,大小等于 world_size,用于存放收集到的张量
    output_list = [torch.zeros_like(input_tensor) for _ in range(world_size)]

    dist.all_gather(output_list, input_tensor)

    # 结果是一个包含所有进程数据的列表
    # Rank 0 input: [[1., 1.], [1., 1.]]
    # Rank 1 input: [[2., 2.], [2., 2.]]
    # All processes receive: [tensor(1s), tensor(2s)]
    print(f"Rank {rank} - Collected data shape: {[o.shape for o in output_list]}")
    print(f"Rank {rank} - First collected item (from rank 0):\n{output_list[0]}")

    cleanup()

3. ReduceScatter:高效的模型并行基石

ReduceScatter 结合了 Reduce(归约)和 Scatter(分散)的功能。它首先对所有进程的输入数据列表进行归约,然后将归约后的结果分散回每个进程,每个进程只接收结果的一部分。

适用场景: 大规模模型(如 FSDP)中的高效梯度同步。它允许在归约过程中同时减少通信带宽,因为每个进程只接收其所需的部分。

工作原理

  • 输入: $N$ 个进程各自拥有一组数据 $L = [T_1, T_2, …, T_N]$ (通常是已经拼接好的大张量,但在 PyTorch API 中通常输入的是列表)。
  • 操作: 计算 $ ext{Reduce}(L_1, L_2, …, L_N)$。
  • 输出: 进程 $i$ 获得归约结果的第 $i$ 部分。

代码示例

我们将使用 reduce_scatter_tensor API,它更符合实际操作中对大张量的处理。

def run_reduce_scatter(rank, world_size):
    setup(rank, world_size)

    # 假设每个进程的输入是一个大张量,包含所有进程所需的数据部分
    # 输入大小必须是世界大小的倍数 (此处 4x2 = 8)
    # 目标是:对所有进程的输入求和,然后将结果分散给每个进程 (2x2)
    input_tensor_full = torch.ones(world_size * 2, 2, device=rank) * (rank + 1)

    # 准备输出张量,大小为输入大小 / world_size
    output_tensor = torch.zeros(2, 2, device=rank)

    # 使用 dist.reduce_scatter_tensor
    dist.reduce_scatter_tensor(output_tensor, input_tensor_full)

    # 归约求和: 假设 Rank 0 和 Rank 1 都有 [[1,1], [1,1], [1,1], [1,1]] 和 [[2,2], [2,2], [2,2], [2,2]]
    # 归约结果 (SUM): [[3,3], [3,3], [3,3], [3,3]]
    # Scatter: Rank 0 收到前一半 [[3,3], [3,3]]
    # Scatter: Rank 1 收到后一半 [[3,3], [3,3]]

    print(f"Rank {rank} - Output (received part of the reduced result):\n{output_tensor}")

    cleanup()

总结对比

原语 功能描述 数据流向 常见用途
AllReduce 归约所有数据,所有进程获得完整结果。 $N$ 到 1 (Reduce) -> 1 到 $N$ (Broadcast) DDP 梯度同步
AllGather 收集所有数据并拼接,所有进程获得完整拼接结果。 $N$ 到 $N$ (收集/广播) 收集 Batch Norm 统计量、全局特征
ReduceScatter 归约所有数据,并将结果分散到每个进程的不同部分。 $N$ 到 1 (Reduce) -> 1 到 $N$ (Scatter) FSDP 梯度高效同步

在选择通信原语时,核心原则是:如果只需要一个全局统计值(如平均梯度),使用 AllReduce;如果需要所有进程的数据,使用 AllGather;如果需要高效地归约大张量并分散结果的不同部分,使用 ReduceScatter。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » PyTorch 通信原语详解:深度对比 AllReduce、AllGather 与 ReduceScatter
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址