欢迎光临
我们一直在努力

谷歌会放弃TensorFlow吗?

谷歌是否会放弃TensorFlow(TF)是一个复杂的生态问题,但对于AI基础设施工程师而言,更实际的挑战是:如何在新模型普遍倾向于使用PyTorch训练的情况下,继续高效利用已经搭建好的TensorFlow Serving(TFS)集群和TFX管道?

答案是:通过中间表示层ONNX(Open Neural Network Exchange)。ONNX提供了一个标准的图表示,允许我们在不同的框架之间平滑过渡。本文将详细指导如何将一个PyTorch模型转换为ONNX格式,进而转换为TFS可识别的SavedModel格式,并最终进行部署。

1. 环境准备与依赖安装

我们需要安装PyTorch用于模型训练和导出,以及onnxonnx-tf库用于转换操作。

pip install torch torchvision onnx numpy tensorflow==2.10 onnx-tf

2. PyTorch模型的定义与ONNX导出

我们以一个简单的线性回归模型为例进行操作。

import torch
import torch.nn as nn
import numpy as np

# 1. 定义一个简单的PyTorch模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# 实例化模型并加载权重(此处使用随机权重)
model = SimpleModel()
model.eval()

# 定义输入数据的张量形状(Batch Size = 1)
dummy_input = torch.randn(1, 10, requires_grad=True)

# 2. 导出到ONNX格式
# 注意:input_names和output_names是部署时的关键,必须定义清晰
onnx_file_path = "simple_model.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    export_params=True,
    opset_version=13, # 推荐使用较新的版本
    do_constant_folding=True,
    input_names=['input_tensor'],
    output_names=['output_tensor'],
    dynamic_axes={'input_tensor': {0: 'batch_size'}}, # 允许动态批处理大小
)
print(f"PyTorch模型已成功导出到: {onnx_file_path}")

3. 将ONNX模型转换为TensorFlow SavedModel

TensorFlow Serving(TFS)的标准输入格式是SavedModel。我们使用onnx-tf工具链进行转换。注意:转换后的SavedModel需要符合TFS的目录结构要求:model_name/version_number/****。

import tensorflow as tf
from onnx_tf.backend import prepare
import onnx
import os

# 定义SavedModel的输出路径
EXPORT_VERSION = 1
SAVED_MODEL_DIR = f"serving_model/{EXPORT_VERSION}"

# 3.1 加载ONNX模型
onnx_model = onnx.load(onnx_file_path)

# 3.2 准备TensorFlow后端
# prepare函数会将ONNX图转换为TensorFlow可执行的图结构
tf_rep = prepare(onnx_model)

# 3.3 导出为SavedModel
tf_rep.export_graph(SAVED_MODEL_DIR)

print(f"SavedModel已成功生成于: {SAVED_MODEL_DIR}")
print("SavedModel签名定义 (用于客户端调用):")
# 打印默认签名,通常为'default_signature'
print(tf_rep.signatures)

4. 使用TensorFlow Serving部署模型

我们将使用Docker容器来运行TensorFlow Serving,将上一步生成的serving_model目录映射到容器内部。

# 确保你位于包含 serving_model 目录的父级目录下
MODEL_BASE_PATH=$(pwd)

docker run -d --rm \
  -p 8501:8501 \
  -v "$MODEL_BASE_PATH/serving_model:/models/simple_model" \
  -e MODEL_NAME=simple_model \
  tensorflow/serving:2.10.0

# 验证服务是否启动成功 (等待几秒)
# curl http://localhost:8501/v1/models/simple_model

5. 客户端调用测试

最后,我们使用Python向TFS发送一个RESTful请求,验证模型是否正确运行。

import requests
import json

# 确保输入数据维度与PyTorch导出时的 dummy_input 维度一致 (1, 10)
input_data = np.random.rand(1, 10).astype(np.float32)

# 构造请求体
# 注意:这里的 'instances' 字段必须是列表形式,且数据类型必须匹配 SavedModel 签名
# 在 ONNX 转换的情况下,输入名称通常会保留为你在导出时定义的 'input_tensor'
request_data = {
    "instances": input_data.tolist()
}

# 发送预测请求
response = requests.post(
    'http://localhost:8501/v1/models/simple_model:predict',
    data=json.dumps(request_data)
)

response.raise_for_status() # 检查HTTP错误
result = response.json()

print("------------------------------------------")
print(f"TFS 预测结果: {result['predictions'][0][0]:.4f}")
print("------------------------------------------")

通过这种方法,我们成功地将一个纯粹的PyTorch模型,部署到了以TensorFlow为核心的服务基础设施中,实现了模型训练框架和部署框架的解耦。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 谷歌会放弃TensorFlow吗?
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址