在将深度学习模型部署到移动端或嵌入式设备时,模型的大小和推理速度是至关重要的指标。许多从PyTorch或TensorFlow导出的ONNX模型,在计算图中包含大量冗余节点、不必要的初始化器(Initializers)或可合并的常量操作(如Shape、Squeeze、Unsqueeze等),这些都会影响后续推理引擎(如NCNN、MNN、TFLite)的转换效率和最终运行速度。
ONNX Simplifier (onnx-simplifier) 是一个强大的工具,能够自动识别并优化这些冗余结构,生成更精简、更高效的ONNX模型。
准备工作:安装依赖
您需要安装ONNX库和ONNX Simplifier:
pip install onnx
pip install onnx-simplifier
第一步:导出原始ONNX模型
我们首先创建一个简单的PyTorch模型并将其导出为ONNX格式。为了演示简化效果,我们故意在模型中加入一些可能被Simplifier优化的操作。
import torch
import torch.nn as nn
import onnx
# 1. 定义一个简单的模型
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv = nn.Conv2d(3, 16, 3, padding=1)
self.bn = nn.BatchNorm2d(16)
self.relu = nn.ReLU()
# 注册一个在推理时可能被消除的常量操作
self.register_buffer('scale_factor', torch.tensor(1.0))
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = x * self.scale_factor # 如果 scale_factor 保持为 1.0,此操作理论上可优化
return self.relu(x)
# 2. 导出模型
model = SimpleNet()
dummy_input = torch.randn(1, 3, 32, 32)
original_model_path = "model_original.onnx"
torch.onnx.export(
model,
dummy_input,
original_model_path,
opset_version=11,
input_names=['input'],
output_names=['output']
)
print(f"原始模型已导出到: {original_model_path}")
# 检查原始模型的节点数量
model_orig = onnx.load(original_model_path)
print(f"原始模型节点数: {len(model_orig.graph.node)}")
第二步:使用 ONNX Simplifier 简化模型
ONNX Simplifier 的核心功能是执行常量折叠(Constant Folding)和死代码消除(Dead Code Elimination),将多个操作合并为更少的节点,或直接将常量操作的结果烘焙到模型中。
我们可以使用 onnxsim.simplify 函数进行简化。它会自动处理输入/输出的形状推断。
import onnxsim
import onnx
original_model_path = "model_original.onnx"
simplified_model_path = "model_simplified.onnx"
# 执行简化操作
# input_shapes 参数可选,用于辅助确定动态输入或多输入时的具体形状
model_sim, check = onnxsim.simplify(
original_model_path,
input_shapes={'input': [1, 3, 32, 32]} # 指定输入形状
)
if check:
onnx.save(model_sim, simplified_model_path)
print("\n--- 简化成功 ---")
# 检查简化后模型的节点数量
print(f"简化后的模型已保存到: {simplified_model_path}")
print(f"简化后模型节点数: {len(model_sim.graph.node)}")
# 比较原始节点数和简化后节点数,可以看到冗余操作(如恒等乘法、BN层参数合并)被消除。
else:
print("Simplification failed.")
结果分析
如果原始模型包含BatchNorm层(PyTorch导出的ONNX中,BN层通常由多个节点构成)或冗余的恒等操作,Simplifier 会执行以下优化:
- 融合BN层参数:将BatchNorm操作的权重和偏置合并到其之前的Conv层权重中,消除多个 BN 相关的节点(Scale, Add, Mul, Div)。
- 常量折叠:消除像 x * 1.0 这样的恒等操作,或将所有常量计算的结果直接写入模型初始化器中。
简化后的模型节点更少,数据流更直接,这对于后续使用如 NCNN、MNN 或 TFLite 转换工具时,能够显著提高兼容性和转换速度,同时也有利于端侧推理框架更好地进行硬件加速优化。
汤不热吧