谷歌是否会放弃TensorFlow(TF)是一个复杂的生态问题,但对于AI基础设施工程师而言,更实际的挑战是:如何在新模型普遍倾向于使用PyTorch训练的情况下,继续高效利用已经搭建好的TensorFlow Serving(TFS)集群和TFX管道?
答案是:通过中间表示层ONNX(Open Neural Network Exchange)。ONNX提供了一个标准的图表示,允许我们在不同的框架之间平滑过渡。本文将详细指导如何将一个PyTorch模型转换为ONNX格式,进而转换为TFS可识别的SavedModel格式,并最终进行部署。
1. 环境准备与依赖安装
我们需要安装PyTorch用于模型训练和导出,以及onnx和onnx-tf库用于转换操作。
pip install torch torchvision onnx numpy tensorflow==2.10 onnx-tf
2. PyTorch模型的定义与ONNX导出
我们以一个简单的线性回归模型为例进行操作。
import torch
import torch.nn as nn
import numpy as np
# 1. 定义一个简单的PyTorch模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
# 实例化模型并加载权重(此处使用随机权重)
model = SimpleModel()
model.eval()
# 定义输入数据的张量形状(Batch Size = 1)
dummy_input = torch.randn(1, 10, requires_grad=True)
# 2. 导出到ONNX格式
# 注意:input_names和output_names是部署时的关键,必须定义清晰
onnx_file_path = "simple_model.onnx"
torch.onnx.export(
model,
dummy_input,
onnx_file_path,
export_params=True,
opset_version=13, # 推荐使用较新的版本
do_constant_folding=True,
input_names=['input_tensor'],
output_names=['output_tensor'],
dynamic_axes={'input_tensor': {0: 'batch_size'}}, # 允许动态批处理大小
)
print(f"PyTorch模型已成功导出到: {onnx_file_path}")
3. 将ONNX模型转换为TensorFlow SavedModel
TensorFlow Serving(TFS)的标准输入格式是SavedModel。我们使用onnx-tf工具链进行转换。注意:转换后的SavedModel需要符合TFS的目录结构要求:model_name/version_number/****。
import tensorflow as tf
from onnx_tf.backend import prepare
import onnx
import os
# 定义SavedModel的输出路径
EXPORT_VERSION = 1
SAVED_MODEL_DIR = f"serving_model/{EXPORT_VERSION}"
# 3.1 加载ONNX模型
onnx_model = onnx.load(onnx_file_path)
# 3.2 准备TensorFlow后端
# prepare函数会将ONNX图转换为TensorFlow可执行的图结构
tf_rep = prepare(onnx_model)
# 3.3 导出为SavedModel
tf_rep.export_graph(SAVED_MODEL_DIR)
print(f"SavedModel已成功生成于: {SAVED_MODEL_DIR}")
print("SavedModel签名定义 (用于客户端调用):")
# 打印默认签名,通常为'default_signature'
print(tf_rep.signatures)
4. 使用TensorFlow Serving部署模型
我们将使用Docker容器来运行TensorFlow Serving,将上一步生成的serving_model目录映射到容器内部。
# 确保你位于包含 serving_model 目录的父级目录下
MODEL_BASE_PATH=$(pwd)
docker run -d --rm \
-p 8501:8501 \
-v "$MODEL_BASE_PATH/serving_model:/models/simple_model" \
-e MODEL_NAME=simple_model \
tensorflow/serving:2.10.0
# 验证服务是否启动成功 (等待几秒)
# curl http://localhost:8501/v1/models/simple_model
5. 客户端调用测试
最后,我们使用Python向TFS发送一个RESTful请求,验证模型是否正确运行。
import requests
import json
# 确保输入数据维度与PyTorch导出时的 dummy_input 维度一致 (1, 10)
input_data = np.random.rand(1, 10).astype(np.float32)
# 构造请求体
# 注意:这里的 'instances' 字段必须是列表形式,且数据类型必须匹配 SavedModel 签名
# 在 ONNX 转换的情况下,输入名称通常会保留为你在导出时定义的 'input_tensor'
request_data = {
"instances": input_data.tolist()
}
# 发送预测请求
response = requests.post(
'http://localhost:8501/v1/models/simple_model:predict',
data=json.dumps(request_data)
)
response.raise_for_status() # 检查HTTP错误
result = response.json()
print("------------------------------------------")
print(f"TFS 预测结果: {result['predictions'][0][0]:.4f}")
print("------------------------------------------")
通过这种方法,我们成功地将一个纯粹的PyTorch模型,部署到了以TensorFlow为核心的服务基础设施中,实现了模型训练框架和部署框架的解耦。
汤不热吧