欢迎光临
我们一直在努力

如何通过 torch.distributed.rpc 构建跨机器的参数服务器架构

在构建大规模深度学习模型训练系统时,参数服务器(Parameter Server, PS)架构是一种常见的解决方案。它将模型参数的存储和更新集中在专用的PS节点上,而Worker节点(训练器)只负责计算梯度并发送给PS。PyTorch 提供的 torch.distributed.rpc 框架,凭借其简洁的远程过程调用(RPC)和远程引用(RRef)机制,非常适合实现这种跨机器的PS架构。

本文将聚焦如何利用 RPC 框架,定义 PS 节点和 Worker 节点,实现参数的远程管理和梯度更新。

1. 核心概念:RPC 与 RRef

  • RPC (Remote Procedure Call): 允许一个进程调用另一个远程进程上的函数,就像调用本地函数一样。
  • RRef (Remote Reference): 远程引用,允许一个进程持有一个指向另一个远程进程上对象的句柄。PS 节点创建的模型参数对象,可以通过 RRef 被 Worker 节点安全引用。

2. 环境设置与运行脚本

由于 RPC 需要多进程环境,我们使用 multiprocessing 在一台机器上模拟两个角色:psworker。在实际跨机器部署时,只需确保所有机器的网络互通,并使用 python your_script.py 配合正确的环境变量即可。

首先,确保 PyTorch 版本 >= 1.6。

# ps_rpc_example.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef, rpc_sync
import time
import multiprocessing

WORLD_SIZE = 2

# 定义参数服务器上的模型和优化器
class ParameterServer(object):
    def __init__(self):
        # 示例模型:简单的线性层
        self.model = nn.Linear(10, 1)
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)
        print("PS Model initialized.")

    # Worker 调用此函数获取最新的模型参数
    @staticmethod
    @rpc.functions.async_execution
    def get_parameter_rrefs(ps_rref: RRef):
        param_rrefs = []
        ps = ps_rref.local_value()

        # 遍历模型的参数,并为每个参数创建一个 RRef
        for param in ps.model.parameters():
            param_rrefs.append(RRef(param))
        return param_rrefs

    # Worker 调用此函数将梯度发送回来并执行优化步
    def apply_gradients(self, gradients_dict):
        self.optimizer.zero_grad()

        # 假设 Worker 发送的 dict 包含了命名参数的梯度
        for name, param in self.model.named_parameters():
            if name in gradients_dict:
                # 确保梯度张量和参数张量在同一个设备上 (尽管在这个例子中都是CPU)
                param.grad = gradients_dict[name].to(param.device)

        self.optimizer.step()
        print(f"[PS] Parameters updated at time {time.time():.2f}")
        return True

# PS 角色初始化函数
def run_ps(rank, world_size):
    # 初始化 RPC 进程
    rpc.init_rpc(
        name="ps",
        rank=rank,
        world_size=world_size,
        rpc_backend=rpc.Backend.TENSORPIPE,

    )
    # PS 实例化自身,并持有 RRef
    ps_instance = ParameterServer()
    global_ps_rref = RRef(ps_instance)

    # PS 进程持续运行,等待 Worker 节点的调用
    print(f"PS running, rank {rank}.")
    rpc.shutdown()

# Worker 角色执行函数
def run_worker(rank, world_size):
    # 初始化 RPC 进程
    rpc.init_rpc(
        name=f"worker{rank}",
        rank=rank,
        world_size=world_size,
        rpc_backend=rpc.Backend.TENSORPIPE,
    )

    # 获取 PS 的 RRef
    ps_rref = rpc_sync("ps", lambda: global_ps_rref).to_here()

    # 1. 假设 Worker 在本地拥有与 PS 结构相同的模型
    local_model = nn.Linear(10, 1)

    # 2. 模拟训练步骤:前向传播、计算损失和反向传播
    data = torch.randn(5, 10)
    target = torch.randn(5, 1)

    # 实际场景中,Worker 会先从 PS 获取最新权重同步到 local_model

    # 模拟计算梯度
    output = local_model(data)
    loss = (output - target).pow(2).mean()
    loss.backward()

    # 3. 提取梯度,准备发送给 PS
    gradients_to_send = {}
    for name, param in local_model.named_parameters():
        if param.grad is not None:
            gradients_to_send[name] = param.grad.cpu() # 梯度通常需要拷贝到 CPU 或特定设备以便传输

    # 4. 通过 RPC 调用 PS 上的 apply_gradients 方法
    print(f"Worker {rank} sending gradients...")
    future = rpc.rpc_async("ps", ps_rref.local_value().apply_gradients, args=(gradients_to_send,))

    # 等待 PS 确认更新完成
    result = future.wait()
    print(f"Worker {rank} update confirmed: {result}")

    rpc.shutdown()

# 主入口,用于启动多进程模拟
def main():
    # 必须设置网络环境
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'

    processes = []
    roles = {'ps': run_ps, 'worker1': run_worker}

    for rank, (role_name, fn) in enumerate(roles.items()):
        p = multiprocessing.Process(target=fn, args=(rank, WORLD_SIZE))
        p.start()
        processes.append(p)
        # 确保 PS 先启动,避免 Worker 找不到 PS 节点
        if role_name == 'ps':
            time.sleep(1) 

    for p in processes:
        p.join()

if __name__ == "__main__":
    # 警告:此模拟代码在 Python 3.8+ 版本可能需要调整多进程启动方式(使用 forkserver 或 spawn)
    # 在实际分布式环境中,通常使用 torch.distributed.launch 或 slurm 进行管理。
    main()

3. 代码解析与实操要点

  1. PS 节点的职责 (ParameterServer** 类):** PS 类负责实例化和维护模型 (self.model) 和优化器 (self.optimizer)。它提供 apply_gradients 方法,这是 Worker 远程调用的核心接口。
  2. RRef 的使用: 在 PS 启动时,我们创建了 global_ps_rref = RRef(ps_instance)。Worker 通过 rpc_sync(“ps”, lambda: global_ps_rref) 获取到这个远程引用,从而能调用 PS 实例上的方法。
  3. 异步通信 (rpc_async): Worker 使用 rpc.rpc_async 调用 apply_gradients,这意味着 Worker 不需要阻塞等待梯度应用完成,理论上可以继续下一轮计算,实现异步训练(当然,本示例中为了简化演示,仍使用了 future.wait() 来获取结果)。
  4. 数据传输: 梯度张量在通过 RPC 网络传输时,PyTorch 会自动处理序列化和反序列化。确保在发送前,数据格式(如 CPU/GPU 状态)符合预期。

通过这种方式,我们成功地将模型参数管理与梯度计算分离,实现了跨机器的高效参数服务器架构。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 如何通过 torch.distributed.rpc 构建跨机器的参数服务器架构
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址