欢迎光临
我们一直在努力

tensorflow导出模型的过程

对于个人站长或技术爱好者来说,在VPS或虚拟机上部署机器学习模型是实现高阶应用的关键一步。TensorFlow的SavedModel格式是官方推荐的模型导出和部署标准,适用于TensorFlow Serving、TFLite甚至直接加载到Python环境。

本文将聚焦如何使用Keras API将训练好的模型导出为SavedModel格式,并验证导出的完整性。

步骤一:准备环境与训练模型

我们首先需要一个简单的Keras模型作为示例。如果您已经在VPS上安装了TensorFlow环境,可以直接运行以下代码。

import tensorflow as tf
from tensorflow import keras
import numpy as np
import os

# 1. 创建一个简单的线性回归Keras模型 (y = x + 1)
model = keras.Sequential([
    keras.layers.Dense(units=1, input_shape=[1])
])
model.compile(optimizer='sgd', loss='mean_squared_error')

# 2. 训练模型 (模拟)
xs = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=float)
ys = np.array([2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=float)
print("开始训练模型...")
model.fit(xs, ys, epochs=50, verbose=0)
print("模型训练完成。")

步骤二:导出模型为SavedModel格式

SavedModel格式是跨语言、跨平台部署的基础。在导出时,通常推荐在路径末尾添加版本号(例如/1/2),这对于后续使用TensorFlow Serving进行无缝版本更新至关重要。

我们使用tf.saved_model.save()函数进行导出操作。

# 3. 定义导出路径 (使用版本号 '1')
EXPORT_PATH = './tf_saved_model_example/1'

# 如果目录存在则先删除,确保干净导出
if os.path.exists(EXPORT_PATH):
    import shutil
    shutil.rmtree(EXPORT_PATH)

# 4. 导出模型 (SavedModel格式)
tf.saved_model.save(model, EXPORT_PATH)

print(f"模型已成功导出到: {EXPORT_PATH}")

# 检查导出的目录结构
print("\n--- 导出的目录结构 ---")
!ls -R {EXPORT_PATH}

导出后的目录结构说明:

导出的文件夹(如./tf_saved_model_example/1)包含两个核心文件和子目录:

  1. saved_model.pb: 实际的模型图和元数据。
  2. variables/: 存储模型的权重参数。
  3. assets/: 存储额外的文件(如词汇表),本例中为空。

步骤三:验证SavedModel的完整性

模型导出后,我们必须确保它能被正确加载并进行预测。我们使用tf.saved_model.load()来重新加载模型。

请注意,当模型作为SavedModel加载时,它的行为类似于一个低级Function,而不是Keras Model对象。我们需要通过其签名(Signature)来调用它,默认签名为serving_default

# 5. 验证模型加载
reloaded_model = tf.saved_model.load(EXPORT_PATH)

# 6. 使用原始模型预测 (输入10.0)
original_prediction = model.predict([10.0])

# 7. 使用加载的模型签名进行预测
# 'serving_default' 是Keras模型自动生成的默认签名
reloaded_function = reloaded_model.signatures["serving_default"]

# 输入必须是Tensor类型
reloaded_prediction_tensor = reloaded_function(tf.constant([10.0], dtype=tf.float32))

print(f"\n--- 预测结果验证 ---")
print(f"原始Keras模型预测 (10.0): {original_prediction.flatten()[0]:.4f}")
# 结果是一个字典,键名通常是 'output_0'
print(f"SavedModel加载后预测 (10.0): {reloaded_prediction_tensor['output_0'].numpy().flatten()[0]:.4f}")

if abs(original_prediction.flatten()[0] - reloaded_prediction_tensor['output_0'].numpy().flatten()[0]) < 0.01:
    print("模型导出和加载成功,预测结果一致!")
else:
    print("模型导出或加载可能存在问题。")

通过以上步骤,您已经成功将训练好的TensorFlow Keras模型安全地导出为SavedModel格式,为下一步在生产环境(如云主机或Docker容器)中部署TensorFlow Serving打下了坚实的基础。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » tensorflow导出模型的过程
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址