欢迎光临
我们一直在努力

微软内部使用PyTorch还是TensorFlow?

许多关注AI部署的技术人员都会好奇,微软在内部和其AI服务(如Azure ML、Windows ML)中更侧重于哪个深度学习框架?事实是,尽管微软是TensorFlow的早期支持者,但近年来它在PyTorch生态中的投入巨大,特别是在开源贡献和加速库方面。然而,无论是使用PyTorch还是TensorFlow,模型部署的通用标准和性能瓶颈的解决方案都指向同一个核心技术:ONNX (Open Neural Network Exchange) 和 ONNX Runtime (ORT)

ONNX Runtime 是微软内部孵化并维护的高性能推理引擎。它支持跨硬件、跨操作系统平台(如Windows、Linux、Android)的高效推理,并且能利用各种硬件加速器(如CUDA, DirectML, OpenVINO等)。对于 PyTorch 模型而言,通过将其转换为 ONNX 格式,可以显著提升部署效率和推理速度。

本文将深入讲解如何将一个基础的 PyTorch 模型成功导出为 ONNX 格式,并使用 ONNX Runtime 进行高性能推理。

步骤一:环境准备与依赖安装

你需要安装 PyTorch、ONNX 导出工具以及 ONNX Runtime 库。

pip install torch onnx onnxruntime numpy

步骤二:定义并导出 PyTorch 模型到 ONNX

导出过程的关键在于使用 torch.onnx.export 函数。我们必须为模型提供一个示例输入(dummy input),以便追踪图的计算路径并确定输入输出的动态尺寸(如果需要)。

以下是一个简单的 PyTorch 线性分类器示例:

import torch
import torch.nn as nn
import numpy as np

# 1. 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleModel()
model.eval() # 切换到评估模式

# 2. 准备示例输入 (批次大小为1,特征数为10)
dummy_input = torch.randn(1, 10, requires_grad=True)

# 3. 定义 ONNX 导出参数
output_onnx_path = "simple_model.onnx"
input_names = ["input_data"]
output_names = ["output_prediction"]

# 动态轴配置:如果需要支持不同批次大小的输入
dynamic_axes = {
    'input_data': {0: 'batch_size'},
    'output_prediction': {0: 'batch_size'}
}

# 4. 执行导出
print(f"开始导出 PyTorch 模型到 {output_onnx_path}...")
torch.onnx.export(
    model,                           # 待导出的模型
    dummy_input,                     # 示例输入
    output_onnx_path,                # 输出文件名
    export_params=True,              # 导出模型的权重
    opset_version=14,                # ONNX 算子集版本
    do_constant_folding=True,        # 优化常量折叠
    input_names=input_names,         # 输入节点名称
    output_names=output_names,       # 输出节点名称
    dynamic_axes=dynamic_axes        # 动态轴配置
)

print("模型导出成功!")

步骤三:使用 ONNX Runtime 进行高性能推理

模型导出为 ONNX 格式后,我们不再需要 PyTorch 库进行推理。现在可以直接使用 ONNX Runtime 加载模型并执行推理。

import onnxruntime as ort
import numpy as np

# 1. 加载 ONNX 模型
output_onnx_path = "simple_model.onnx"
# 注意:默认情况下 ORT 会自动选择最佳执行提供程序(如CPU, CUDA, DirectML)
sess = ort.InferenceSession(
    output_onnx_path, 
    providers=['CPUExecutionProvider'] # 也可以指定 ['CUDAExecutionProvider'] 等
)

# 2. 准备新的推理数据 (确保数据格式与模型期望的一致,通常为 numpy 数组)
new_input = np.random.randn(5, 10).astype(np.float32) # 使用批次大小为 5

# 3. 准备输入字典
# 键必须与导出时定义的 input_names 匹配
input_name = sess.get_inputs()[0].name
input_dict = {input_name: new_input}

# 4. 执行推理
print("开始使用 ONNX Runtime 进行推理...")
output = sess.run(None, input_dict)

# 5. 处理输出
predictions = output[0]

print(f"输入形状: {new_input.shape}")
print(f"输出形状: {predictions.shape}")
print("推理结果 (前两行):")
print(predictions[:2])

总结:ONNX 在 AI Infra 中的核心地位

虽然微软在研究和开发中大量使用 PyTorch,但它确保了通过 ONNX 这一中间表示层,所有的模型都能高效地运行在 Azure 及其边缘设备上。这种“训练在 PyTorch,部署在 ONNX Runtime”的模式是现代 AI 基础设施的最佳实践。它不仅解决了框架兼容性问题,还通过 ORT 的优化内核和硬件加速器集成,实现了极致的推理性能,极大地提升了模型部署的灵操性和效率。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 微软内部使用PyTorch还是TensorFlow?
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址