在 PyTorch 中,当内置操作无法满足性能或功能需求时,我们需要自定义高性能的 C++/CUDA 算子。要让 PyTorch 的自动求导机制(Autograd)识别并正确计算这些自定义算子的梯度,我们必须使用 torch.autograd.Function 来封装我们的前向和后向 CUDA 逻辑。
本文将以一个简单的元素级平方操作 ($Y=X^2$) 为例,演示如何编写 CUDA Kernel、C++ 接口以及 Python 端的 autograd.Function。
1. 理论基础:自定义算子的梯度
对于 $Y=X^2$ 这个操作,我们需要实现其前向和后向逻辑。
- 前向 (Forward): $Y_i = X_i^2$
- 后向 (Backward): 我们需要计算输入梯度 $\nabla X$。根据链式法则:$$\nabla X = \nabla Y \cdot \frac{\partial Y}{\partial X}$$其中,$\frac{\partial Y}{\partial X} = 2X$。因此,后向操作为:$$\nabla X_i = \nabla Y_i \cdot 2X_i$$
2. CUDA Kernel 实现 (square_kernel.cu)
我们将 CUDA 代码分为前向和后向两个核函数。
#include <cuda.h>
#include <cuda_runtime.h>
// 辅助宏定义,用于计算线程索引
#define THREADS_PER_BLOCK 512
// -------------------------------------
// 前向核函数: Y = X^2
// -------------------------------------
__global__ void square_forward_kernel(const float* input, float* output, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
output[idx] = input[idx] * input[idx];
}
}
// -------------------------------------
// 后向核函数: grad_X = grad_Y * 2X
// -------------------------------------
__global__ void square_backward_kernel(const float* grad_output, const float* input, float* grad_input, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
// 应用链式法则: dL/dX = dL/dY * dY/dX
// dY/dX = 2 * X
grad_input[idx] = grad_output[idx] * 2.0f * input[idx];
}
}
// 计算 CUDA 启动参数
static dim3 get_grid(int n) {
int blocks = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
return dim3(blocks);
}
// C++ 接口函数,供 ATen 调用
void square_forward_cuda(const at::Tensor& input, at::Tensor& output) {
int n = input.numel();
dim3 grid = get_grid(n);
dim3 block(THREADS_PER_BLOCK);
square_forward_kernel<<<grid, block>>>(input.data_ptr<float>(), output.data_ptr<float>(), n);
}
void square_backward_cuda(const at::Tensor& grad_output, const at::Tensor& input, at::Tensor& grad_input) {
int n = input.numel();
dim3 grid = get_grid(n);
dim3 block(THREADS_PER_BLOCK);
square_backward_kernel<<<grid, block>>>(grad_output.data_ptr<float>(), input.data_ptr<float>(), grad_input.data_ptr<float>(), n);
}
3. C++ PyTorch 接口注册 (square_op.cpp)
我们需要一个 C++ 文件来定义 PyTorch ATen 接口,并使用 Pybind11 (或 Torch C++ Extension) 将其绑定到 Python。
#include <torch/extension.h>
#include "square_kernel.cu" // 假设我们把 CUDA kernel 放在头文件中或编译时链接
// C++ Wrapper for Forward Pass
// 检查输入并分配输出张量
at::Tensor square_forward(const at::Tensor& input) {
// 确保输入是 CUDA float 张量
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
at::Tensor output = torch::empty_like(input);
square_forward_cuda(input, output);
return output;
}
// C++ Wrapper for Backward Pass
at::Tensor square_backward(const at::Tensor& grad_output, const at::Tensor& input) {
TORCH_CHECK(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
at::Tensor grad_input = torch::empty_like(input);
square_backward_cuda(grad_output, input, grad_input);
return grad_input;
}
// 注册模块到 Python
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("square_forward", &square_forward, "Custom Square Forward (CUDA)");
m.def("square_backward", &square_backward, "Custom Square Backward (CUDA)");
}
4. 编译设置 (setup.py)
使用 torch.utils.cpp_extension 编译 CUDA 扩展。
# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='square_op',
ext_modules=[
CUDAExtension(
'square_extension',
[
'square_op.cpp',
'square_kernel.cu' # 包含 CUDA kernels
],
extra_compile_args={'nvcc': ['-O3']}
),
],
cmdclass={
'build_ext': BuildExtension
}
)
执行编译命令:
python setup.py install
5. Python Autograd 封装和测试
编译成功后,即可在 Python 中导入 square_extension 并使用 torch.autograd.Function 封装它。
import torch
# 导入编译好的扩展库 (假设编译为 square_extension)
import square_extension
class CustomSquareFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
# 必须确保输入张量是 CUDA 张量
if not input.is_cuda:
input = input.cuda()
# 1. 保存必要的数据供 backward 使用
# 对于 Y=X^2,backward 需要用到 X
ctx.save_for_backward(input)
# 2. 调用 C++/CUDA 前向接口
output = square_extension.square_forward(input)
return output
@staticmethod
def backward(ctx, grad_output):
# 1. 取出保存的数据
input, = ctx.saved_tensors
# 2. 调用 C++/CUDA 后向接口
# grad_output 是上游传来的梯度 (dL/dY)
grad_input = square_extension.square_backward(grad_output, input)
return grad_input
# 实例化并测试
custom_square = CustomSquareFunction.apply
# 验证梯度正确性
X = torch.tensor([1.0, 2.0, 3.0], requires_grad=True, dtype=torch.float32).cuda()
# 使用自定义算子
Y = custom_square(X)
# 假设上游梯度为 [10.0, 10.0, 10.0]
# 理论梯度: grad_X = grad_Y * 2X
# [10*2*1, 10*2*2, 10*2*3] = [20.0, 40.0, 60.0]
grad_Y_upstream = torch.tensor([10.0, 10.0, 10.0], dtype=torch.float32).cuda()
Y.backward(grad_Y_upstream)
print("输入 X:\n", X)
print("输出 Y:\n", Y)
print("计算得到的梯度 dX:\n", X.grad)
# 验证是否等于理论值 [20.0, 40.0, 60.0]
assert torch.allclose(X.grad, torch.tensor([20.0, 40.0, 60.0]).cuda())
print("\n梯度验证成功!")
汤不热吧