MultiWorkerMirroredStrategy (MWMS) 是 TensorFlow 2.x 中用于多机多卡同步训练的首选策略。它通过在每个 Worker 的 GPU 上复制模型权重,并在梯度计算后使用 All-reduce 操作同步更新,实现了高效的分布式训练。然而,在实际操作中,尤其是在初始化阶段或自定义训练循环中,开发者经常会遇到集群通信死锁(Deadlock)的问题。
集群死锁通常发生在以下情况:部分 Worker 进程卡在等待通信的屏障(Barrier)上,而另一部分 Worker 进程由于某种原因(如初始化不一致、I/O阻塞或未进入分布式上下文)没有执行相应的通信操作。
解决 MWMS 死锁的核心在于确保所有 Worker 进程在正确的分布式作用域 (strategy.scope()) 内同步执行操作。
关键步骤:避免 MWMS 死锁
1. 正确配置 TF_CONFIG
MWMS 依赖于 TF_CONFIG 环境变量来识别集群中的角色(Worker, Chief, Parameter Server等)以及各自的地址。配置不当是死锁的首要原因。
假设我们有两台机器,每台机器有两个GPU:
Worker 0 上的 TF_CONFIG (Chief):
export TF_CONFIG='{
"cluster": {
"worker": ["host1:12345", "host2:12345"]
},
"task": {
"type": "worker",
"index": 0
}
}'
Worker 1 上的 TF_CONFIG:
export TF_CONFIG='{
"cluster": {
"worker": ["host1:12345", "host2:12345"]
},
"task": {
"type": "worker",
"index": 1
}
}'
2. 在分布式作用域内定义所有变量和模型
MWMS 在初始化时需要同步所有 Worker 上的变量状态。如果模型、优化器或任何需要跨设备同步的 tf.Variable 在 strategy.scope() 外部定义,将导致变量无法被策略跟踪,从而引发死锁或错误同步。
3. 使用 strategy.run() 执行同步操作
对于训练的每一步,必须使用 strategy.run() 来确保所有副本(包括不同机器上的副本)同时执行相同的计算图。这是 MWMS 实现 All-reduce 同步的关键入口。
实例代码:避免 MWMS 死锁的最小化示例
以下代码展示了一个使用 MWMS 避免死锁的最小化训练骨架。
import tensorflow as tf
import os
# 假设 TF_CONFIG 已经在环境中设置
# 1. 初始化 MultiWorkerMirroredStrategy
strategy = tf.distribute.MultiWorkerMirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")
# 定义全局参数
GLOBAL_BATCH_SIZE = 64
# 2. 在 strategy.scope() 内定义模型、优化器和分布式训练步骤
with strategy.scope():
# 模型定义 (确保在 scope 内)
def create_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(10, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10, activation='softmax')
])
return model
model = create_model()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# 定义损失函数和指标
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False, reduction=tf.keras.losses.Reduction.NONE)
# 包装损失函数以处理分布式批次大小
def compute_loss(labels, predictions):
# 计算每个样本的损失
per_example_loss = loss_object(labels, predictions)
# 必须对每个Worker/Replica的损失求和,然后除以全局批次大小
return tf.nn.compute_average_loss(
per_example_loss,
global_batch_size=GLOBAL_BATCH_SIZE)
# 3. 定义分布式训练步骤 (使用 tf.function 优化)
@tf.function
def distributed_train_step(dist_inputs):
# 使用 strategy.run 来执行所有副本上的计算
per_replica_losses = strategy.run(train_step, args=(dist_inputs,))
# 聚合所有副本的损失并返回平均值
return strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
def train_step(inputs):
images, labels = inputs
with tf.GradientTape() as tape:
predictions = model(images, training=True)
loss = compute_loss(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
return loss
# 4. 数据集准备
# 确保数据集在 Workers 之间正确分片
def create_datasets():
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
y_train = y_train.astype('int64')
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(60000).repeat()
# 使用 Strategy 的 **distribute_datasets_from_function** 进行分片
def dataset_fn(input_context):
ds = dataset.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
return ds.batch(GLOBAL_BATCH_SIZE // input_context.num_inputpipelines)
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
return dist_dataset
# 5. 启动训练循环
dist_dataset = create_datasets()
iterator = iter(dist_dataset)
if tf.distribute.get_replica_context() is None or tf.distribute.get_replica_context().is_chief:
print("Starting distributed training...")
EPOCHS = 3
STEPS_PER_EPOCH = 100
for epoch in range(EPOCHS):
for step in range(STEPS_PER_EPOCH):
total_loss = distributed_train_step(next(iterator))
if step % 10 == 0:
# 只有 Chief Worker 打印日志,避免日志冲突导致的I/O死锁或混乱
if tf.distribute.get_replica_context() is None or tf.distribute.get_replica_context().is_chief:
print(f"Epoch {epoch}, Step {step}, Loss: {total_loss.numpy():.4f}")
# 注意:在多机环境下运行此代码,必须保证所有 Worker 同时启动,且 TF_CONFIG 配置正确。
总结
MWMS 死锁通常不是由 TensorFlow 本身错误引起的,而是由于集群设置或代码流不一致导致的。要避免死锁,请始终确保:
1. TF_CONFIG 在所有 Worker 上正确且一致地设置。
2. 模型、优化器和关键变量定义在 strategy.scope() 内。
3. 使用 strategy.run() 和 tf.function 组合来封装同步的训练步骤。
4. 使用 strategy.distribute_datasets_from_function 确保数据分片正确,避免 I/O 成为瓶颈或同步障碍。
5. 在日志打印或文件操作时,仅允许 Chief Worker 执行,使用 tf.distribute.get_replica_context().is_chief 进行判断,避免I/O竞争导致的潜在阻塞。
汤不热吧