欢迎光临
我们一直在努力

如何用tensorflow serving部署pytorch模型

对于个人站长和开发者来说,在VPS或云虚拟机上部署机器学习模型服务是一个常见的需求。虽然PyTorch在训练上灵活强大,但TensorFlow Serving(TFS)在生产环境中的稳定性和批处理能力往往更胜一筹。本文将指导您如何通过ONNX格式作为桥梁,实现PyTorch模型到TFS的无缝部署。

1. 环境准备

首先,您需要确保您的环境中安装了所有必要的库。我们将在一个Python环境中完成模型的转换工作。

pip install torch torchvision onnx tensorflow==2.x onnx-tf
# 安装用于测试Serving的库
pip install requests

2. 编写 PyTorch 模型并导出为 ONNX

我们创建一个简单的示例模型(如一个基础的线性回归)并将其导出为ONNX格式。

注意: 导出时必须指定模型的输入名称和输出名称,这对于后续TensorFlow Serving的签名定义至关重要。

import torch
import torch.nn as nn

# 2.1 定义一个简单的PyTorch模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 3)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()

# 2.2 准备输入数据和导出参数
dummy_input = torch.randn(1, 10)
output_path_onnx = 'model.onnx'

# 2.3 导出为ONNX格式
torch.onnx.export(
    model, 
    dummy_input, 
    output_path_onnx, 
    export_params=True,
    opset_version=12, 
    do_constant_folding=True, 
    input_names=['input_data'],  # 必须定义输入名称
    output_names=['output_result'], # 必须定义输出名称
    dynamic_axes={'input_data': {0: 'batch_size'}} # 可选:支持动态批量大小
)
print(f"PyTorch模型已成功导出到 {output_path_onnx}")

3. ONNX 模型转换为 TensorFlow SavedModel

TensorFlow Serving只接受SavedModel格式。我们使用 onnx-tf 库进行转换。TFS要求模型文件必须存放在 /model_name/version_number/ 的结构中。

import tensorflow as tf
from onnx_tf.backend import prepare
import onnx
import os

# 3.1 加载ONNX模型
onnx_model = onnx.load("model.onnx")

# 3.2 转换为TF后端格式
tf_rep = prepare(onnx_model)

# 3.3 定义SavedModel的路径结构
MODEL_DIR = './tf_serving_model'
VERSION = 1
export_path = os.path.join(MODEL_DIR, str(VERSION))

# 3.4 导出为SavedModel
tf_rep.export_graph(export_path)

print(f"TensorFlow SavedModel已导出到 {export_path}")

此时,您应该有一个名为 tf_serving_model/1/ 的目录,其中包含 saved_model.pbvariables/ 文件夹。

4. 运行 TensorFlow Serving 服务器

最简单的方法是通过Docker来启动TFS服务。如果您在VPS上操作,确保Docker已安装。

# 假设您在与tf_serving_model同级的目录下运行此命令
docker pull tensorflow/serving
docker run -t --rm -p 8501:8501 \ 
    -v "$(pwd)/tf_serving_model:/models/my_pytorch_model" \ 
    -e MODEL_NAME=my_pytorch_model \ 
    tensorflow/serving

这条命令做了以下几件事:
1. 将本地的 tf_serving_model 目录挂载到容器内的 /models/my_pytorch_model
2. 通过 -e MODEL_NAME 指定模型名称为 my_pytorch_model
3. 将容器的8501端口映射到主机的8501端口(用于RESTful API调用)。

5. 测试部署的模型

现在服务已在 http://your_vps_ip:8501 上运行。我们使用一个简单的Python脚本发送推理请求。

重要: TFS REST API的输入字段必须与您在PyTorch导出时定义的 input_names (‘input_data’) 一致。

import requests
import numpy as np
import json

# 5.1 构造测试数据 (必须是列表结构,与模型输入维度匹配)
data = np.random.randn(1, 10).tolist() # 单个样本,10个特征

# 5.2 构造TFS请求体
payload = {
    "signature_name": "serving_default", 
    "instances": [
        {"input_data": data[0]} # 注意这里使用了PyTorch导出时的输入名称 'input_data'
    ]
}

# 5.3 发送请求
url = 'http://localhost:8501/v1/models/my_pytorch_model:predict'
headers = {"Content-Type": "application/json"}

response = requests.post(url, data=json.dumps(payload), headers=headers)

# 5.4 打印结果
if response.status_code == 200:
    result = response.json()
    print("--- 部署成功,推理结果 ---")
    print(json.dumps(result, indent=4))
else:
    print(f"请求失败,状态码: {response.status_code}")
    print(response.text)

通过上述步骤,您成功地将PyTorch模型转换并部署到了高可用的TensorFlow Serving架构上,实现了跨框架的模型服务化。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 如何用tensorflow serving部署pytorch模型
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址