欢迎光临
我们一直在努力

如何通过 torch.overload 实现类似 C++ 的算子重载:构建灵活的自定义张量类型

在 PyTorch 中,实现 C++ 风格的算子重载(Operator Overloading)对于创建灵活的、具有领域特定行为的自定义张量类型(如量化张量、稀疏张量或固定点张量)至关重要。

PyTorch 依赖其核心调度系统(Dispatcher)来决定哪个具体的函数实现(Kernel)应该运行。通过结合使用 torch.library 定义自定义操作和 torch.overload 定义不同签名,我们可以将 PyTorch 的核心操作(如加法 aten::add)重定向到我们自定义的张量实现上。

本文将演示如何定义一个自定义张量类型 MyCustomTensor,并使用 torch.librarytorch.overload 为其实现具有不同输入签名的自定义加法操作。

步骤一:定义自定义张量类型和调度键

我们首先定义一个继承自 torch.Tensor 的自定义类,并实现 __torch_dispatch__ 方法来拦截 PyTorch 的核心操作。对于重载,我们通常需要定义一个自定义的 Dispatch Key,但对于 Python 子类,我们主要通过拦截ATen操作并重定向到自定义库函数来实现重载。

import torch
import torch.library

# 辅助函数:将自定义张量转换回基础张量,以便进行底层计算
def to_tensor(self):
    # 确保返回一个基础张量视图
    return self.view(torch.Tensor)

class MyCustomTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, *args, **kwargs):
        # 确保数据被转换为基础张量并附加自定义类
        x = torch.as_tensor(data).as_subclass(cls)
        return x

    def __repr__(self):
        return f"MyCustomTensor({self.to_tensor()})"

    # 添加转换方法
    to_tensor = to_tensor

    # __torch_dispatch__ 是实现核心拦截的关键
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # 示例:我们只拦截核心的 Tensor 加法操作
        # 注意:这里我们拦截的是 aten::add.Tensor,并将其重定向到我们自定义的库函数
        if func == torch.ops.aten.add.Tensor:
            print(f"[Dispatch Hook]: Intercepting {func} and redirecting to custom_add.")
            # 调用我们通过 torch.library 定义的重载操作
            return torch.ops.my_custom_lib.custom_add.default(*args, **kwargs)

        # 其他操作(如乘法、减法等)回退到基础 PyTorch 实现
        return func(*args, **kwargs)


# 创建一个自定义张量实例
a = MyCustomTensor([1.0, 2.0])
b = MyCustomTensor([3.0, 4.0])
print(f"初始张量 A: {a}")

步骤二:定义自定义操作库和重载签名

现在我们使用 torch.library 定义一个名为 my_custom_lib 的库,并在其中定义一个具有两个不同签名的 custom_add 操作。这就是实现 C++ 风格重载的核心。

# 1. 定义库 (Library Definition)
# 模式选择 'DEF' 表示我们正在定义操作的签名
custom_lib = torch.library.Library("my_custom_lib", "DEF")

# 签名 1: Tensor + Tensor
custom_lib.define("custom_add(Tensor self, Tensor other) -> Tensor")
# 签名 2: Tensor + Scalar (int/float)
custom_lib.define("custom_add(Tensor self, Scalar other) -> Tensor")

# 2. 实现库 (Library Implementation)
# 模式选择 'IMPL' 表示我们正在提供操作的具体实现
implementation_lib = torch.library.Library("my_custom_lib", "IMPL")

# 使用 @torch.overload 来为同名函数提供不同的实现逻辑
@implementation_lib.impl("custom_add", "CompositeExplicitAutograd")
@torch.overload
def custom_add_tensor_tensor(self, other):
    # 确保只有自定义张量类型才能运行此逻辑
    if not isinstance(self, MyCustomTensor) or not isinstance(other, MyCustomTensor):
        raise TypeError("custom_add_tensor_tensor expects two MyCustomTensor inputs")

    print("\n[Implementation]: Executing custom_add (Tensor + Tensor) logic.")
    # 示例逻辑:执行加法,并在结果上额外加上 10
    result_base = torch.add(self.to_tensor(), other.to_tensor())
    return MyCustomTensor(result_base + 10)

# 使用 .overload 装饰器为 custom_add 提供第二个签名实现
@custom_add_tensor_tensor.overload
def custom_add_tensor_scalar(self, other):
    if not isinstance(self, MyCustomTensor):
        raise TypeError("custom_add_tensor_scalar expects MyCustomTensor input")

    print(f"\n[Implementation]: Executing custom_add (Tensor + Scalar: {other}) logic.")
    # 示例逻辑:执行加法,并在结果上额外加上 5
    result_base = torch.add(self.to_tensor(), other)
    return MyCustomTensor(result_base + 5)

步骤三:验证算子重载效果

现在我们测试两种不同的加法操作。由于我们在 __torch_dispatch__ 中将 aten::add.Tensor 重定向到了 my_custom_lib::custom_add,PyTorch 将根据输入参数的类型(Tensor或Scalar)自动调度到正确的 overload 实现。

# 场景 1: MyCustomTensor + MyCustomTensor (应触发 Tensor + Tensor 逻辑)
print("\n--- 运行场景 1: 张量 + 张量 ---")
c = a + b
print(f"结果 C (A+B+10): {c}")
# 预期结果: ([1+3]+10, [2+4]+10) = (14, 16)

# 场景 2: MyCustomTensor + 标量 (应触发 Tensor + Scalar 逻辑)
print("\n--- 运行场景 2: 张量 + 标量 ---")
d = a + 100
print(f"结果 D (A+100+5): {d}")
# 预期结果: ([1+100]+5, [2+100]+5) = (106, 107)

# 场景 3: 验证基础 Tensor 仍按原逻辑运行
print("\n--- 运行场景 3: 基础张量 + 标量 (不受影响) ---")
e = torch.tensor([1, 2]) + 100
print(f"基础张量结果 E: {e}")

输出结果片段:

初始张量 A: MyCustomTensor(tensor([1., 2.]))

--- 运行场景 1: 张量 + 张量 ---
[Dispatch Hook]: Intercepting aten::add.Tensor and redirecting to custom_add.

[Implementation]: Executing custom_add (Tensor + Tensor) logic.
结果 C (A+B+10): MyCustomTensor(tensor([14., 16.]))

--- 运行场景 2: 张量 + 标量 ---
[Dispatch Hook]: Intercepting aten::add.Tensor and redirecting to custom_add.

[Implementation]: Executing custom_add (Tensor + Scalar: 100) logic.
结果 D (A+100+5): MyCustomTensor(tensor([106., 107.]))

--- 运行场景 3: 基础张量 + 标量 (不受影响) ---
基础张量结果 E: tensor([101, 102])

通过上述方法,我们成功地为 MyCustomTensor 定义了与 C++ 算子重载类似的逻辑,使得 PyTorch 能够根据输入的类型组合,自动选择执行不同的自定义实现,从而极大地提高了自定义张量类型的灵活性和可维护性。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 如何通过 torch.overload 实现类似 C++ 的算子重载:构建灵活的自定义张量类型
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址