抢占式实例(Spot Instance)是云服务商提供的一种基于竞价模式的计算资源,通常价格比按需实例(On-Demand)低50%到90%。对于需要数周甚至数月训练周期的大型语言模型(LLM)而言,Spot Instance是降低训练成本的关键策略。然而,Spot实例最大的挑战是其随时可能被回收(通常有2分钟的通知期)。
本文将深入探讨如何结合持久化存储和智能的容错机制,构建一个能够利用Spot实例,并自动恢复训练的弹性AI基础设施。
Contents
1. 核心挑战:状态持久化与容错
要成功使用Spot实例进行长时间训练,必须解决两个关键问题:
- 持久化存储: 训练数据、模型权重和优化器状态必须存储在外部的、非易失的存储系统上(例如AWS S3、Azure Blob或EFS/FSx)。本地NVMe存储在实例终止后会丢失。
- 优雅退出机制: 必须在实例被回收前的2分钟通知期内,执行最后的检查点保存操作。
2. 环境准备与配置(AWS为例)
假设我们使用AWS S3作为我们的检查点(Checkpoint)存储桶。
2.1 依赖安装
我们需要
1 | boto3 |
(用于S3操作)和标准的深度学习框架(如PyTorch或Hugging Face Accelerate)。
1 pip install torch transformers accelerate boto3
2.2 S3 Checkpoint 路径配置
设置训练脚本能够自动将检查点上传到S3。
3. 实现容错训练的Python代码
在LLM训练中,检查点不仅包括模型权重,还必须包括优化器状态、学习率调度器状态和当前的训练步数(Epoch/Step Count)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49 import torch
import os
import boto3
# 假设配置
BUCKET_NAME = 'llm-training-checkpoints'
CHECKPOINT_KEY = 'latest_checkpoint.pt'
def load_checkpoint(model, optimizer, scheduler):
s3 = boto3.client('s3')
try:
# 尝试从S3下载最新的检查点
s3.download_file(BUCKET_NAME, CHECKPOINT_KEY, '/tmp/checkpoint.pt')
checkpoint = torch.load('/tmp/checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
print(f"--- Checkpoint loaded. Resuming from epoch {start_epoch} ---")
return start_epoch
except Exception as e:
print(f"No checkpoint found or load failed: {e}. Starting fresh.")
return 0
def save_checkpoint(model, optimizer, scheduler, epoch):
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
}
# 1. 保存到本地临时文件
temp_path = '/tmp/temp_checkpoint.pt'
torch.save(checkpoint, temp_path)
# 2. 上传到S3
s3 = boto3.client('s3')
s3.upload_file(temp_path, BUCKET_NAME, CHECKPOINT_KEY)
print(f"Checkpoint saved and uploaded to s3://{BUCKET_NAME}/{CHECKPOINT_KEY}")
# 示例训练循环 (简化)
# ... model, optimizer, scheduler 初始化
# start_epoch = load_checkpoint(model, optimizer, scheduler)
# for epoch in range(start_epoch, total_epochs):
# # Training steps...
# if should_save_periodically():
# save_checkpoint(model, optimizer, scheduler, epoch)
4. 关键:Spot 实例终止通知处理(优雅退出)
在AWS EC2上,可以通过访问特定的元数据服务URL来检查实例是否收到终止通知。这是实现优雅退出的核心。
终止通知URL:
1 | http://169.254.169.254/latest/meta-data/spot/termination-time |
如果该URL返回一个时间戳(而不是404 Not Found),则表示实例将在该时间点被回收。
我们使用一个独立的脚本(或集成到训练启动脚本中)来持续监控此URL。
4.1 监控与保存脚本 (
1
monitor_and_train.sh
)
1 | monitor_and_train.sh |
此脚本首先启动训练进程,然后在一个后台循环中监控终止通知。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46 #!/bin/bash
TRAINING_SCRIPT="/path/to/your/train_llm.py"
# 启动训练进程(必须在后台运行,以便脚本继续监控)
python $TRAINING_SCRIPT &
TRAINING_PID=$!
echo "Training process started with PID: $TRAINING_PID"
# 循环检查终止通知
while true; do
# 尝试获取终止时间
TERMINATION_TIME=$(curl -s -w '%{http_code}' -o /dev/null http://169.254.169.254/latest/meta-data/spot/termination-time)
# AWS文档:如果HTTP状态码是200,表示收到终止通知
if [ "$TERMINATION_TIME" = "200" ]; then
echo "\n--- WARNING: Spot Instance Termination Notice Received! ---"
# 收到通知后,立即执行最后的检查点保存逻辑(例如,通过发送信号给训练脚本)
# 理想情况下,训练脚本应该接收SIGTERM信号,并调用 save_checkpoint()
kill -SIGTERM $TRAINING_PID
# 等待训练进程优雅退出(最多等待90秒,以留出S3上传时间)
wait $TRAINING_PID
EXIT_CODE=$?
if [ $EXIT_CODE -eq 0 ]; then
echo "Training saved checkpoint successfully. Shutting down."
else
echo "Training process failed during forced shutdown. EXIT CODE: $EXIT_CODE"
fi
# 脚本退出,等待实例回收
exit 0
fi
# 检查训练进程是否意外终止(非Spot回收导致)
if ! kill -0 $TRAINING_PID 2> /dev/null; then
echo "Training process ended normally or crashed."
exit 1
fi
# 每隔5秒检查一次
sleep 5
done
5. 部署与弹性恢复
为了确保弹性恢复,您应该使用AWS Auto Scaling Group (ASG) 或一个工作流编排工具(如AWS Step Functions或Kubernetes Job with Spot Nodes)。
- ASG配置: 将您的ASG配置为使用Spot实例,并设置
1Desired Capacity
为1。
- 启动脚本: 将上述
1monitor_and_train.sh
设置为EC2实例的User Data脚本。
当Spot实例被回收时,ASG会自动启动一个新的Spot实例。新实例启动后,它会运行启动脚本。由于训练脚本在启动时调用了
1 | load_checkpoint() |
函数,它将自动从S3下载最新的检查点,恢复训练,从而确保训练任务的连续性,即便经历了多次中断。
汤不热吧