欢迎光临
我们一直在努力

如何通过自定义 Op实现 解决推理库不支持的 Swish 等激活函数:从 C++ 接口到算子注册

背景

在深度学习模型部署过程中,我们经常会遇到由于推理框架(如 MNN、NCNN、TNN)更新较慢,导致某些新出的激活函数(如 Swish、HardSwish)或者自定义算子不被支持的情况。这时,开发者通常面临两个选择:一是修改模型结构,用基础算子拼接;二是手动实现并注册自定义算子(Custom Op)。

本文将以 MNN 推理框架为例,详细演示如何通过 C++ 接口实现一个简单的 Swish 算子($f(x) = x \cdot \text{sigmoid}(x)$)并将其注册到推理引擎中。

1. 算法逻辑实现

首先,我们需要继承推理框架的算子执行基类。在 MNN 中,这个基类是 Execution。我们需要重写 onExecute 方法来编写具体的计算逻辑。

#include <MNN/Interpreter.hpp>
#include <MNN/MNNDefine.h>
#include <cmath>
#include <vector>

using namespace MNN;

class SwishExecution : public Execution {
public:
    SwishExecution(Backend *backend) : Execution(backend) {}

    // 核心计算逻辑
    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
        auto input = inputs[0];
        auto output = outputs[0];

        const float *inputData = input->host<float>();
        float *outputData = output->host<float>();
        int elementSize = input->elementSize();

        // Swish(x) = x * (1 / (1 + exp(-x)))
        for (int i = 0; i < elementSize; i++) {
            float x = inputData[i];
            outputData[i] = x / (1.0f + std::exp(-x));
        }
        return NO_ERROR;
    }
};

2. 定义算子工厂 (Creator)

推理引擎在解析模型文件时,会根据算子名称或类型寻找对应的工厂类来创建实例。我们需要实现一个 Creator 类。

class SwishCreator : public CPUBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, 
                               const MNN::Op *op, Backend *backend) const override {
        return new SwishExecution(backend);
    }
};

3. 注册算子

在 C++ 程序启动时或动态库加载时,需要将算子注册到对应的后端(如 CPU 或 GPU)。MNN 提供了宏来简化这一过程。注意:这里的 OpType_Extra 通常用于处理转换工具无法直接识别的自定义插件算子。

void registerSwishOp() {
    // 这里的 \"Swish\" 字符串必须与模型转换时指定的 Plugin Name 一致
    static std::once_flag s_flag;
    std::call_once(s_flag, []() {
        MNNInsertExtraBackendCreator(MNN_FORWARD_CPU, \"Swish\", new SwishCreator);
    });
}

4. 模型转换时的配合

仅仅在推理端写好代码是不够的。在使用转换工具(如 MNNConvert)将 ONNX 或 PyTorch 模型转为 MNN 格式时,如果遇到不支持的算子,需要确保模型中的该节点被标记为 \”Plugin\” 或 \”Extra\” 类型,并保留算子名称为 \”Swish\”。

在 Python 端(PyTorch 导出)通常可以这样操作:

import torch
import torch.nn as nn

class SwishPlugin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        return i * torch.sigmoid(i)

    @staticmethod
    def symbolical(g, n):
        # 导出为 ONNX 时指定算子名为 Swish
        return g.op(\"Swish\", n)

总结

通过自定义算子,我们可以快速补全推理库缺失的功能,而无需等待官方更新。实现的三个关键步骤是:逻辑实现(Execution)、实例创建(Creator)和全局注册(Register)。在实际生产环境下,建议对 onExecute 循环进行 SIMD(如 NEON 或 AVX2)指令集优化,以获得最佳性能。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 如何通过自定义 Op实现 解决推理库不支持的 Swish 等激活函数:从 C++ 接口到算子注册
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址