欢迎光临
我们一直在努力

如何通过 tf.config.experimental.set_memory_growth 解决显存被 TF 强行吃光的尴尬

在使用 TensorFlow (TF) 进行深度学习开发时,尤其是涉及到 GPU 资源管理时,许多开发者会遇到一个令人头疼的问题:TensorFlow 默认会在初始化时,预先分配几乎所有可用的 GPU 显存,即使模型非常小。这导致了显存资源的浪费,并且使得同一块 GPU 上无法运行其他模型或进程。

本文将聚焦于一个简单且高效的解决方案:使用 tf.config.experimental.set_memory_growth 来实现按需分配(Memory Growth)。

为什么 TF 默认要吃光所有显存?

TensorFlow 采取这种激进的策略是为了提高性能和避免内存碎片化。当它预先分配了整个池子后,后续的操作就不需要频繁地向操作系统请求内存,从而减少了开销。然而,在端侧推理、多任务服务器或资源受限的环境中,这种默认行为是不可接受的。

解决方案:启用显存增长 (Memory Growth)

tf.config.experimental.set_memory_growth(device, enable) 这个 API 可以告诉 TensorFlow,不要预先分配所有显存,而是在需要时逐渐增加分配量。它类似于操作系统中的动态内存管理。

注意: 该 API 通常需要在模型创建或任何 GPU 操作发生之前调用。

实操步骤与代码示例

以下是如何在 TensorFlow 2.x 环境中配置显存增长的完整步骤。

步骤 1: 检查并获取 GPU 列表

首先,我们需要知道系统中有哪些可用的物理 GPU 设备。

步骤 2: 循环设置 Memory Growth

我们遍历所有找到的 GPU 设备,并对每个设备启用 set_memory_growth(True)

import tensorflow as tf
import os

# 确保使用 TensorFlow 2.x
print(f"TensorFlow 版本: {tf.__version__}")

# -----------------------------------------------------
# 核心配置部分:必须在任何模型实例化之前执行
# -----------------------------------------------------

gpus = tf.config.experimental.list_physical_devices('GPU')

if gpus:
    try:
        # 遍历所有物理 GPU 设备
        for gpu in gpus:
            # 启用显存增长模式
            tf.config.experimental.set_memory_growth(gpu, True)
            print(f"成功为设备 {gpu.name} 启用显存增长")

        # 验证设置是否成功
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(f"物理GPU数量: {len(gpus)}, 逻辑GPU数量: {len(logical_gpus)}")

    except RuntimeError as e:
        # 如果设备已经在运行时调用此函数,会报错
        print(f"设置内存增长失败 (运行时错误): {e}")
else:
    print("未检测到GPU设备。此设置仅对GPU有效。")

# -----------------------------------------------------
# 步骤 3: 运行模型验证效果
# -----------------------------------------------------

if gpus:
    # 只有当 GPU 存在时,我们才运行模型来验证效果
    print("\n--- 运行小型模型进行验证 ---")

    # 创建一个极小型的 Keras 模型
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    # 运行一次推理
    dummy_input = tf.random.normal((1, 784))
    _ = model(dummy_input)

    print("模型已初始化并运行一次。")
    print("请使用 'nvidia-smi' 命令查看当前的显存占用情况。")
    print("如果配置成功,您将看到只占用了数百兆字节 (MB),而不是全部显存。")

    # 模拟长时间运行,方便用户检查 nvidia-smi
    # import time
    # time.sleep(10)

总结与建议

启用 tf.config.experimental.set_memory_growth(True) 是解决 TensorFlow 默认显存占用问题的黄金法则。在现代 TensorFlow 部署中(无论是训练还是推理),特别是在共享资源的 GPU 服务器上,强烈建议总是将此配置放在程序的最开始部分。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 如何通过 tf.config.experimental.set_memory_growth 解决显存被 TF 强行吃光的尴尬
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址