背景
在深度学习模型部署过程中,我们经常会遇到由于推理框架(如 MNN、NCNN、TNN)更新较慢,导致某些新出的激活函数(如 Swish、HardSwish)或者自定义算子不被支持的情况。这时,开发者通常面临两个选择:一是修改模型结构,用基础算子拼接;二是手动实现并注册自定义算子(Custom Op)。
本文将以 MNN 推理框架为例,详细演示如何通过 C++ 接口实现一个简单的 Swish 算子($f(x) = x \cdot \text{sigmoid}(x)$)并将其注册到推理引擎中。
1. 算法逻辑实现
首先,我们需要继承推理框架的算子执行基类。在 MNN 中,这个基类是 Execution。我们需要重写 onExecute 方法来编写具体的计算逻辑。
#include <MNN/Interpreter.hpp>
#include <MNN/MNNDefine.h>
#include <cmath>
#include <vector>
using namespace MNN;
class SwishExecution : public Execution {
public:
SwishExecution(Backend *backend) : Execution(backend) {}
// 核心计算逻辑
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
auto input = inputs[0];
auto output = outputs[0];
const float *inputData = input->host<float>();
float *outputData = output->host<float>();
int elementSize = input->elementSize();
// Swish(x) = x * (1 / (1 + exp(-x)))
for (int i = 0; i < elementSize; i++) {
float x = inputData[i];
outputData[i] = x / (1.0f + std::exp(-x));
}
return NO_ERROR;
}
};
2. 定义算子工厂 (Creator)
推理引擎在解析模型文件时,会根据算子名称或类型寻找对应的工厂类来创建实例。我们需要实现一个 Creator 类。
class SwishCreator : public CPUBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const override {
return new SwishExecution(backend);
}
};
3. 注册算子
在 C++ 程序启动时或动态库加载时,需要将算子注册到对应的后端(如 CPU 或 GPU)。MNN 提供了宏来简化这一过程。注意:这里的 OpType_Extra 通常用于处理转换工具无法直接识别的自定义插件算子。
void registerSwishOp() {
// 这里的 \"Swish\" 字符串必须与模型转换时指定的 Plugin Name 一致
static std::once_flag s_flag;
std::call_once(s_flag, []() {
MNNInsertExtraBackendCreator(MNN_FORWARD_CPU, \"Swish\", new SwishCreator);
});
}
4. 模型转换时的配合
仅仅在推理端写好代码是不够的。在使用转换工具(如 MNNConvert)将 ONNX 或 PyTorch 模型转为 MNN 格式时,如果遇到不支持的算子,需要确保模型中的该节点被标记为 \”Plugin\” 或 \”Extra\” 类型,并保留算子名称为 \”Swish\”。
在 Python 端(PyTorch 导出)通常可以这样操作:
import torch
import torch.nn as nn
class SwishPlugin(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
return i * torch.sigmoid(i)
@staticmethod
def symbolical(g, n):
# 导出为 ONNX 时指定算子名为 Swish
return g.op(\"Swish\", n)
总结
通过自定义算子,我们可以快速补全推理库缺失的功能,而无需等待官方更新。实现的三个关键步骤是:逻辑实现(Execution)、实例创建(Creator)和全局注册(Register)。在实际生产环境下,建议对 onExecute 循环进行 SIMD(如 NEON 或 AVX2)指令集优化,以获得最佳性能。
汤不热吧