在构建大规模深度学习模型训练系统时,参数服务器(Parameter Server, PS)架构是一种常见的解决方案。它将模型参数的存储和更新集中在专用的PS节点上,而Worker节点(训练器)只负责计算梯度并发送给PS。PyTorch 提供的 torch.distributed.rpc 框架,凭借其简洁的远程过程调用(RPC)和远程引用(RRef)机制,非常适合实现这种跨机器的PS架构。
本文将聚焦如何利用 RPC 框架,定义 PS 节点和 Worker 节点,实现参数的远程管理和梯度更新。
1. 核心概念:RPC 与 RRef
- RPC (Remote Procedure Call): 允许一个进程调用另一个远程进程上的函数,就像调用本地函数一样。
- RRef (Remote Reference): 远程引用,允许一个进程持有一个指向另一个远程进程上对象的句柄。PS 节点创建的模型参数对象,可以通过 RRef 被 Worker 节点安全引用。
2. 环境设置与运行脚本
由于 RPC 需要多进程环境,我们使用 multiprocessing 在一台机器上模拟两个角色:ps 和 worker。在实际跨机器部署时,只需确保所有机器的网络互通,并使用 python your_script.py 配合正确的环境变量即可。
首先,确保 PyTorch 版本 >= 1.6。
# ps_rpc_example.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef, rpc_sync
import time
import multiprocessing
WORLD_SIZE = 2
# 定义参数服务器上的模型和优化器
class ParameterServer(object):
def __init__(self):
# 示例模型:简单的线性层
self.model = nn.Linear(10, 1)
self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)
print("PS Model initialized.")
# Worker 调用此函数获取最新的模型参数
@staticmethod
@rpc.functions.async_execution
def get_parameter_rrefs(ps_rref: RRef):
param_rrefs = []
ps = ps_rref.local_value()
# 遍历模型的参数,并为每个参数创建一个 RRef
for param in ps.model.parameters():
param_rrefs.append(RRef(param))
return param_rrefs
# Worker 调用此函数将梯度发送回来并执行优化步
def apply_gradients(self, gradients_dict):
self.optimizer.zero_grad()
# 假设 Worker 发送的 dict 包含了命名参数的梯度
for name, param in self.model.named_parameters():
if name in gradients_dict:
# 确保梯度张量和参数张量在同一个设备上 (尽管在这个例子中都是CPU)
param.grad = gradients_dict[name].to(param.device)
self.optimizer.step()
print(f"[PS] Parameters updated at time {time.time():.2f}")
return True
# PS 角色初始化函数
def run_ps(rank, world_size):
# 初始化 RPC 进程
rpc.init_rpc(
name="ps",
rank=rank,
world_size=world_size,
rpc_backend=rpc.Backend.TENSORPIPE,
)
# PS 实例化自身,并持有 RRef
ps_instance = ParameterServer()
global_ps_rref = RRef(ps_instance)
# PS 进程持续运行,等待 Worker 节点的调用
print(f"PS running, rank {rank}.")
rpc.shutdown()
# Worker 角色执行函数
def run_worker(rank, world_size):
# 初始化 RPC 进程
rpc.init_rpc(
name=f"worker{rank}",
rank=rank,
world_size=world_size,
rpc_backend=rpc.Backend.TENSORPIPE,
)
# 获取 PS 的 RRef
ps_rref = rpc_sync("ps", lambda: global_ps_rref).to_here()
# 1. 假设 Worker 在本地拥有与 PS 结构相同的模型
local_model = nn.Linear(10, 1)
# 2. 模拟训练步骤:前向传播、计算损失和反向传播
data = torch.randn(5, 10)
target = torch.randn(5, 1)
# 实际场景中,Worker 会先从 PS 获取最新权重同步到 local_model
# 模拟计算梯度
output = local_model(data)
loss = (output - target).pow(2).mean()
loss.backward()
# 3. 提取梯度,准备发送给 PS
gradients_to_send = {}
for name, param in local_model.named_parameters():
if param.grad is not None:
gradients_to_send[name] = param.grad.cpu() # 梯度通常需要拷贝到 CPU 或特定设备以便传输
# 4. 通过 RPC 调用 PS 上的 apply_gradients 方法
print(f"Worker {rank} sending gradients...")
future = rpc.rpc_async("ps", ps_rref.local_value().apply_gradients, args=(gradients_to_send,))
# 等待 PS 确认更新完成
result = future.wait()
print(f"Worker {rank} update confirmed: {result}")
rpc.shutdown()
# 主入口,用于启动多进程模拟
def main():
# 必须设置网络环境
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
processes = []
roles = {'ps': run_ps, 'worker1': run_worker}
for rank, (role_name, fn) in enumerate(roles.items()):
p = multiprocessing.Process(target=fn, args=(rank, WORLD_SIZE))
p.start()
processes.append(p)
# 确保 PS 先启动,避免 Worker 找不到 PS 节点
if role_name == 'ps':
time.sleep(1)
for p in processes:
p.join()
if __name__ == "__main__":
# 警告:此模拟代码在 Python 3.8+ 版本可能需要调整多进程启动方式(使用 forkserver 或 spawn)
# 在实际分布式环境中,通常使用 torch.distributed.launch 或 slurm 进行管理。
main()
3. 代码解析与实操要点
- PS 节点的职责 (ParameterServer** 类):** PS 类负责实例化和维护模型 (self.model) 和优化器 (self.optimizer)。它提供 apply_gradients 方法,这是 Worker 远程调用的核心接口。
- RRef 的使用: 在 PS 启动时,我们创建了 global_ps_rref = RRef(ps_instance)。Worker 通过 rpc_sync(“ps”, lambda: global_ps_rref) 获取到这个远程引用,从而能调用 PS 实例上的方法。
- 异步通信 (rpc_async): Worker 使用 rpc.rpc_async 调用 apply_gradients,这意味着 Worker 不需要阻塞等待梯度应用完成,理论上可以继续下一轮计算,实现异步训练(当然,本示例中为了简化演示,仍使用了 future.wait() 来获取结果)。
- 数据传输: 梯度张量在通过 RPC 网络传输时,PyTorch 会自动处理序列化和反序列化。确保在发送前,数据格式(如 CPU/GPU 状态)符合预期。
通过这种方式,我们成功地将模型参数管理与梯度计算分离,实现了跨机器的高效参数服务器架构。
汤不热吧