
在TensorFlow 2.x中,Keras提供了高层的model.fit()接口,大多数场景下使用起来非常方便。但当我们需要更精细地控制训练过程时——比如实现梯度裁剪、多优化器交替更新、对抗训练(GAN)或者自定义学习率调度——就需要用到自定义训练循环(Custom Training Loop)。本文将从零开始讲解如何使用tf.GradientTape构建自定义训练循环,并介绍几个实战中常用的优化技巧。
为什么需要自定义训练循环
model.fit()封装了训练的所有细节,这在标准分类和回归任务中非常好用。但在以下场景中,我们需要更多的灵活性:
• GAN训练需要交替更新生成器和判别器
• 需要在每个step对梯度做裁剪或变换
• 需要同时使用多个优化器
• 需要在训练过程中动态计算和记录自定义指标
• 需要实现Meta-Learning等前沿算法
TensorFlow通过tf.GradientTape提供了自动微分能力,让我们可以精确控制前向传播和反向传播的每一步。
基础:使用GradientTape构建训练循环
下面是一个最基本的自定义训练循环示例,使用MNIST数据集训练一个简单的全连接网络:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import numpy as np
# 加载数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
# 构建模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(256, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
# 创建优化器和损失函数
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# 构建Dataset管道
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(10000).batch(64)
# 训练循环
for epoch in range(5):
epoch_loss = 0.0
num_batches = 0
for x_batch, y_batch in train_dataset:
with tf.GradientTape() as tape:
logits = model(x_batch, training=True)
loss = loss_fn(y_batch, logits)
# 计算梯度并更新参数
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
epoch_loss += loss.numpy()
num_batches += 1
print(f'Epoch {epoch+1}, Loss: {epoch_loss / num_batches:.4f}')
技巧一:梯度裁剪防止梯度爆炸
在训练RNN或深层网络时,梯度爆炸是常见问题。自定义训练循环让我们可以在apply_gradients之前对梯度进行裁剪:
for x_batch, y_batch in train_dataset:
with tf.GradientTape() as tape:
logits = model(x_batch, training=True)
loss = loss_fn(y_batch, logits)
gradients = tape.gradient(loss, model.trainable_variables)
# 全局梯度裁剪:将梯度范数限制在最大值以内
gradients, global_norm = tf.clip_by_global_norm(gradients, clip_norm=1.0)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
tf.clip_by_global_norm会先计算所有梯度向量拼接后的全局范数,如果超过阈值则按比例缩放,这样既防止了梯度爆炸,又保持了梯度方向的一致性。

技巧二:封装为tf.function加速训练
默认情况下,Python原生的训练循环每次迭代都有Python解释器的开销。通过@tf.function装饰器将训练步骤编译为计算图,可以获得显著的速度提升:
@tf.function
def train_step(x_batch, y_batch):
with tf.GradientTape() as tape:
logits = model(x_batch, training=True)
loss = loss_fn(y_batch, logits)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
return loss
@tf.function
def eval_step(x_batch, y_batch):
logits = model(x_batch, training=False)
loss = loss_fn(y_batch, logits)
predictions = tf.argmax(logits, axis=1)
accuracy = tf.reduce_mean(
tf.cast(tf.equal(predictions, tf.cast(y_batch, tf.int64)), tf.float32)
)
return loss, accuracy
# 训练(速度明显更快)
for epoch in range(10):
for x_batch, y_batch in train_dataset:
loss = train_step(x_batch, y_batch)
# 验证
test_loss, test_acc = 0.0, 0.0
test_batches = 0
test_dataset = tf.data.Dataset.from_tensor_slices(
(x_test, y_test)
).batch(64)
for x_batch, y_batch in test_dataset:
l, a = eval_step(x_batch, y_batch)
test_loss += l.numpy()
test_acc += a.numpy()
test_batches += 1
print(f'Epoch {epoch+1}, Test Loss: {test_loss/test_batches:.4f}, '
f'Test Acc: {test_acc/test_batches:.4f}')
使用@tf.function后,首次调用会触发Trace编译,后续调用直接执行优化后的计算图。注意:编译后函数内部不能使用Python原生的print调试,需要用tf.print替代。
技巧三:自定义学习率调度
自定义训练循环让我们可以灵活地实现各种学习率调度策略,比如Warmup + 余弦退火:
class WarmupCosineSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
def __init__(self, base_lr, warmup_steps, total_steps):
super().__init__()
self.base_lr = base_lr
self.warmup_steps = warmup_steps
self.total_steps = total_steps
def __call__(self, step):
step = tf.cast(step, tf.float32)
warmup = self.base_lr * (step / self.warmup_steps)
cosine_decay = 0.5 * (1.0 + tf.cos(
3.14159 * (step - self.warmup_steps) /
(self.total_steps - self.warmup_steps)
))
cosine = self.base_lr * cosine_decay
return tf.where(step < self.warmup_steps, warmup, cosine)
# 使用
schedule = WarmupCosineSchedule(base_lr=1e-3, warmup_steps=500, total_steps=10000)
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)

技巧四:整合TensorBoard可视化
自定义训练循环同样可以使用TensorBoard记录训练过程中的各项指标,方便对比和调优:
import datetime
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary_writer = tf.summary.create_file_writer(log_dir)
for epoch in range(10):
for step, (x_batch, y_batch) in enumerate(train_dataset):
loss = train_step(x_batch, y_batch)
# 每100步记录一次
if step % 100 == 0:
with summary_writer.as_default():
tf.summary.scalar('train_loss', loss, step=optimizer.iterations)
tf.summary.scalar('learning_rate',
optimizer.learning_rate(optimizer.iterations),
step=optimizer.iterations)
# 每个epoch记录验证指标
test_loss, test_acc = 0.0, 0.0
test_batches = 0
for x_batch, y_batch in test_dataset:
l, a = eval_step(x_batch, y_batch)
test_loss += l.numpy()
test_acc += a.numpy()
test_batches += 1
with summary_writer.as_default():
tf.summary.scalar('test_loss', test_loss/test_batches, step=epoch)
tf.summary.scalar('test_acc', test_acc/test_batches, step=epoch)
# 启动TensorBoard: tensorboard --logdir logs/fit
总结
TensorFlow自定义训练循环给了我们对训练过程的完全控制权。核心要点:
1. 使用tf.GradientTape记录前向传播,然后调用tape.gradient()计算梯度
2. 通过tf.clip_by_global_norm做梯度裁剪,防止梯度爆炸
3. 用@tf.function装饰训练步骤函数,获得2-5倍的速度提升
4. 继承tf.keras.optimizers.schedules.LearningRateSchedule实现自定义学习率策略
5. 结合TensorBoard的tf.summaryAPI记录训练指标
在实际项目中,建议先尝试model.fit(),当遇到需要精细控制的场景时再切换到自定义训练循环。两者并不矛盾,甚至可以在同一个项目中混合使用。
汤不热吧