张量并行(Tensor Parallelism, TP)是大型语言模型(LLMs)训练和推理中必不可少的优化技术,它通过在不同设备(如GPU)上切分模型的权重张量来扩展计算能力。在Transformer架构中,Attention层和MLP层都可以进行并行化,但MLP层因其结构特性,最适合进行高效的横向切分(即先列切分再行切分)。
1. 理解Transformer中的MLP层
Transformer的Feed-Forward Network (FFN),或称为MLP层,负责对输入序列的每一个Token进行独立的非线性变换。它通常由两个线性层和一个激活函数(如GeLU)组成:
- 扩展层 (Expansion): $Y_{mid} = \text{GeLU}(X W_{in} + B_{in})$
- 收缩层 (Contraction): $Y_{out} = Y_{mid} W_{out} + B_{out}$
其中,$W_{in}$的维度通常为 $(H, 4H)$,$W_{out}$的维度为 $(4H, H)$。($H$是隐藏维度)。
2. 为什么MLP适合横向切分?
横向切分(Hiding the communication)的效率在于,前一个操作的输出可以直接喂给后一个操作,而无需在中间进行全局同步通信。
2.1 步骤一:扩展层 $W_{in}$ 的列并行(Column Parallelism, TP-C)
我们将 $W_{in}$ 沿着输出维度(列)切分成 $P$ 份:$W_{in} = [W_{in}^1, W_{in}^2, \dots, W_{in}^P]$。每个GPU $i$ 只存储 $W_{in}^i$。
输入张量 $X$ 广播到所有GPU上,每个GPU独立计算局部结果:
$$Y_{mid}^i = X W_{in}^i$$
由于 $W_{in}$ 是列切分的,计算出的中间激活 $Y_{mid}^i$ 也被切分了。这一步不需要任何通信。
2.2 步骤二:收缩层 $W_{out}$ 的行并行(Row Parallelism, TP-R)
为了保持计算效率,切分后的中间激活 $Y_{mid}$ 必须被用于计算 $W_{out}$。我们将 $W_{out}$ 沿着输入维度(行)切分成 $P$ 份:$W_{out} = [W_{out}^1; W_{out}^2; \dots; W_{out}^P]$.
每个GPU $i$ 使用其局部激活 $Y_{mid}^i$ 和局部权重 $W_{out}^i$ 进行计算:
$$Y_{out}^i = Y_{mid}^i W_{out}^i$$
2.3 最终通信:All-Reduce
由于 $Y_{out}$ 是由所有局部计算的和构成的($Y_{out} = Y_{out}^1 + Y_{out}^2 + \dots + Y_{out}^P$),因此在最后,我们需要一个高效的All-Reduce操作来同步并求和所有GPU的局部结果,得到最终的输出 $Y_{out}$。
关键优势: 整个MLP块的两个矩阵乘法操作(TP-C 和 TP-R)之间无需通信。通信被推迟到整个块的末尾,极大地提高了并行效率。
3. PyTorch中的TP实战模拟
以下代码展示了如何在 PyTorch 分布式环境中模拟 MLP 层的 TP 流程,使用 torch.distributed 库进行权重切分和结果聚合。
假设我们有 $P=2$ 个进程,隐藏维度 $H=1024$,扩展维度 $4H=4096$。
import torch
import torch.distributed as dist
# 假设已初始化DDP环境,并设置rank和world_size
# dist.init_process_group(backend='nccl')
# rank = dist.get_rank()
# world_size = dist.get_world_size()
# 为了演示,我们模拟rank 0 和 world_size 2
rank = 0 # 假设只关注一个进程的局部操作
world_size = 2
# 1. 初始化原始参数 (H=10, 4H=40)
H, FH = 10, 40
# 原始输入 (Batch=4, H=10)
X_original = torch.randn(4, H)
# 原始权重
W_in_full = torch.randn(H, FH)
W_out_full = torch.randn(FH, H)
# === 步骤 1: 列并行 (TP-C) on W_in ===
# 切分W_in:沿着列维度切分 (FH / world_size)
W_in_split = torch.chunk(W_in_full, world_size, dim=1)[rank]
# 局部计算 Y_mid (假设X已被广播或使用All-Gather/Reduce)
# 在TP中,X保持完整
Y_mid_local = torch.relu(X_original @ W_in_split)
# Y_mid_local 的形状现在是 (4, FH / world_size) -> (4, 20)
print(f"Rank {rank}: Y_mid_local shape: {Y_mid_local.shape}")
# === 步骤 2: 行并行 (TP-R) on W_out ===
# 切分W_out:沿着行维度切分 (FH / world_size)
W_out_split = torch.chunk(W_out_full, world_size, dim=0)[rank]
# 局部计算 Y_out
Y_out_local = Y_mid_local @ W_out_split
# Y_out_local 的形状现在是 (4, H) -> (4, 10)
print(f"Rank {rank}: Y_out_local shape: {Y_out_local.shape}")
# === 步骤 3: 最终通信 (All-Reduce) ===
# 假设 dist.all_reduce 已经完成,这里用求和模拟
# 实际运行中,所有rank都执行all_reduce(Y_out_local, op=dist.ReduceOp.SUM)
# 模拟另一个进程的局部输出
# (通常需要真正的分布式环境)
# 假设所有局部输出都已知 (仅为演示)
# Y_out_sum = Y_out_local_rank0 + Y_out_local_rank1
# Y_out_final = Y_out_sum
print("TP-C + TP-R 成功将通信延迟到MLP层的末尾。")
4. 总结:TP的高效性
通过将 $W_{in}$ 进行列切分(TP-C)和 $W_{out}$ 进行行切分(TP-R)的组合,我们实现了“切分输入,切分输出”的模式。这种模式确保了 $W_{in}$ 产生的局部激活正好是 $W_{out}$ 所需的局部输入,从而避免了中间通信(如 All-Gather)。只需要在最终的输出阶段执行一次 All-Reduce,这在带宽和延迟上都比在每层操作中频繁通信要高效得多,是实现大规模Transformer模型张量并行的核心技术之一。
汤不热吧