在AI模型进入生产环境时,模型的部署和管理是至关重要的一环。直接在Web框架中加载TensorFlow模型会带来性能瓶颈、版本控制困难和缺乏监控等问题。TensorFlow Serving (TFS) 是Google专门为部署机器学习模型设计的灵活、高性能服务系统,它支持模型版本管理、并发处理和标准化的API接口(RESTful/gRPC)。
本文将详细介绍如何将一个Keras模型转换为TFS要求的格式,并通过Docker容器快速启动TensorFlow Serving,最终通过HTTP接口进行推理。
步骤一:准备 SavedModel 格式的模型
TensorFlow Serving 要求模型必须以 SavedModel 格式存储,并且必须按照特定的目录结构组织:/{model_name}/{version_number}/。
我们首先创建一个简单的Keras模型并将其保存为版本1。
import tensorflow as tf
import numpy as np
import os
# 1. 创建一个简单的Keras模型 (输入维度5,输出维度1)
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(10, input_shape=(5,), activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy')
# 2. 定义保存路径:TFS要求路径结构为 /model_name/version_number/
MODEL_DIR = './tf_serving_model'
VERSION = 1
export_path = os.path.join(MODEL_DIR, str(VERSION))
# 3. 导出 SavedModel 格式
tf.saved_model.save(model, export_path)
print(f"Model successfully saved to: {export_path}")
# 此时目录结构为:./tf_serving_model/1/...
步骤二:使用 Docker 部署 TensorFlow Serving
使用官方提供的TensorFlow Serving Docker镜像可以极大地简化部署过程。我们只需要将本地保存的模型目录映射到容器内的指定路径,并通过环境变量告知TFS模型名称。
前置条件: 确保Docker已安装并运行。
# 假设 SavedModel 路径在本地的 ./tf_serving_model
MODEL_NAME="my_classifier"
# 运行 TFS 容器
# -p 8501:8501: 暴露 REST API 端口
# -v ...: 将本地模型目录映射到容器的 /models/my_classifier 路径下
# -e MODEL_NAME: 告诉 TFS 要加载的模型名称
docker run -t --rm -p 8501:8501 -v "$(pwd)/tf_serving_model:/models/${MODEL_NAME}" \
-e MODEL_NAME="${MODEL_NAME}" tensorflow/serving &
echo "TensorFlow Serving已在后台启动,监听8501端口..."
# 等待几秒钟,确保服务完全启动。
当容器启动后,TFS会自动检测 /models/my_classifier 路径下的最高版本模型(即版本1),并开始提供服务。
步骤三:发送 RESTful 推理请求
TensorFlow Serving 默认在 8501 端口提供 RESTful API。我们可以构造一个JSON payload,通过POST请求发送给服务。
请求的URL格式为:http://localhost:8501/v1/models/{model_name}:predict
import requests
import json
import numpy as np
# 准备数据 (需要匹配模型的输入维度)
data = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]])
# 构造请求 payload
payload = {
"instances": data.tolist()
}
# 发送请求到 TFS REST API
SERVER_URL = 'http://localhost:8501/v1/models/my_classifier:predict'
try:
response = requests.post(SERVER_URL, data=json.dumps(payload))
response.raise_for_status() # 确保请求成功
prediction = response.json()
print("--- 原始请求数据 ---")
print(data)
print("\n--- TFS 返回结果 ---")
print(json.dumps(prediction, indent=2))
except requests.exceptions.RequestException as e:
print(f"请求失败:{e}")
高级应用:模型版本控制
TensorFlow Serving最大的优势在于模型版本管理。如果模型迭代到了版本2,你只需在 ./tf_serving_model/ 目录下创建 2/ 文件夹并保存新模型。TFS会在不中断服务的情况下自动加载版本2,同时保留版本1,从而实现无缝滚动更新和回滚能力。如果你需要指定版本进行推理,请求URL可以更改为:http://localhost:8501/v1/models/{model_name}/versions/{version_number}:predict。
汤不热吧