欢迎光临
我们一直在努力

怎样通过 ONNX Simplifier 消除计算图冗余节点:提升移动端推理效率的第一步

在将深度学习模型部署到移动端或嵌入式设备时,模型的大小和推理速度是至关重要的指标。许多从PyTorch或TensorFlow导出的ONNX模型,在计算图中包含大量冗余节点、不必要的初始化器(Initializers)或可合并的常量操作(如Shape、Squeeze、Unsqueeze等),这些都会影响后续推理引擎(如NCNN、MNN、TFLite)的转换效率和最终运行速度。

ONNX Simplifier (onnx-simplifier) 是一个强大的工具,能够自动识别并优化这些冗余结构,生成更精简、更高效的ONNX模型。

准备工作:安装依赖

您需要安装ONNX库和ONNX Simplifier:

pip install onnx
pip install onnx-simplifier

第一步:导出原始ONNX模型

我们首先创建一个简单的PyTorch模型并将其导出为ONNX格式。为了演示简化效果,我们故意在模型中加入一些可能被Simplifier优化的操作。

import torch
import torch.nn as nn
import onnx

# 1. 定义一个简单的模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.bn = nn.BatchNorm2d(16) 
        self.relu = nn.ReLU()
        # 注册一个在推理时可能被消除的常量操作
        self.register_buffer('scale_factor', torch.tensor(1.0))

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = x * self.scale_factor # 如果 scale_factor 保持为 1.0,此操作理论上可优化
        return self.relu(x)

# 2. 导出模型
model = SimpleNet()
dummy_input = torch.randn(1, 3, 32, 32)
original_model_path = "model_original.onnx"

torch.onnx.export(
    model, 
    dummy_input, 
    original_model_path,
    opset_version=11, 
    input_names=['input'], 
    output_names=['output']
)

print(f"原始模型已导出到: {original_model_path}")

# 检查原始模型的节点数量
model_orig = onnx.load(original_model_path)
print(f"原始模型节点数: {len(model_orig.graph.node)}")

第二步:使用 ONNX Simplifier 简化模型

ONNX Simplifier 的核心功能是执行常量折叠(Constant Folding)和死代码消除(Dead Code Elimination),将多个操作合并为更少的节点,或直接将常量操作的结果烘焙到模型中。

我们可以使用 onnxsim.simplify 函数进行简化。它会自动处理输入/输出的形状推断。

import onnxsim
import onnx

original_model_path = "model_original.onnx"
simplified_model_path = "model_simplified.onnx"

# 执行简化操作
# input_shapes 参数可选,用于辅助确定动态输入或多输入时的具体形状
model_sim, check = onnxsim.simplify(
    original_model_path, 
    input_shapes={'input': [1, 3, 32, 32]} # 指定输入形状
)

if check:
    onnx.save(model_sim, simplified_model_path)
    print("\n--- 简化成功 ---")

    # 检查简化后模型的节点数量
    print(f"简化后的模型已保存到: {simplified_model_path}")
    print(f"简化后模型节点数: {len(model_sim.graph.node)}")

    # 比较原始节点数和简化后节点数,可以看到冗余操作(如恒等乘法、BN层参数合并)被消除。
else:
    print("Simplification failed.")

结果分析

如果原始模型包含BatchNorm层(PyTorch导出的ONNX中,BN层通常由多个节点构成)或冗余的恒等操作,Simplifier 会执行以下优化:

  1. 融合BN层参数:将BatchNorm操作的权重和偏置合并到其之前的Conv层权重中,消除多个 BN 相关的节点(Scale, Add, Mul, Div)。
  2. 常量折叠:消除像 x * 1.0 这样的恒等操作,或将所有常量计算的结果直接写入模型初始化器中。

简化后的模型节点更少,数据流更直接,这对于后续使用如 NCNN、MNN 或 TFLite 转换工具时,能够显著提高兼容性和转换速度,同时也有利于端侧推理框架更好地进行硬件加速优化。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 怎样通过 ONNX Simplifier 消除计算图冗余节点:提升移动端推理效率的第一步
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址