在深度学习和高性能计算领域,算子(Kernel)的性能往往是模型推理速度的瓶颈。虽然像 cuBLAS 和 cuDNN 这样的厂商原生库已经高度优化,但它们是通用性的。当面对特定维度、数据类型或计算模式时,通过像 Triton 这样的领域特定语言(DSL)手动编写和优化算子,往往能实现更高的性能。
本文将聚焦于如何使用 OpenAI 的 Triton 框架,编写一个比标准的 PyTorch 原生矩阵乘法(Matmul)在特定尺寸下更快的自定义 GPU 算子。
1. 为什么自定义算子可能更快?
厂商库(如 cuBLAS)为了通用性,必须处理各种边缘情况和尺寸。自定义算子则允许开发者:
- 精确控制 Tiling 和 Blocking: 根据目标硬件(如 A100/H100)的 L1/L2 缓存和共享内存布局,选择最佳的块大小。
- 避免不必要的同步或数据转换: 针对特定需求实现最小化的指令集。
- 使用最新的硬件特性: 如 Tensor Cores 的高级模式,可能尚未被通用库广泛采纳。
2. 环境准备
确保您安装了 PyTorch 和 Triton。建议使用支持 CUDA 的环境。
pip install torch triton
3. Triton 矩阵乘法(GEMM)实现
我们将实现一个简单的 $C = A \times B$ 矩阵乘法。Triton 的核心优势在于其 @triton.jit 装饰器,允许我们定义 GPU 程序的执行逻辑,并通过 triton.language (tl) 库来操作 CUDA 线程和内存。
假设我们计算 $C_{M imes N} = A_{M imes K} imes B_{K imes N}$。
import torch
import triton
import triton.language as tl
# 定义Triton JIT算子
@triton.jit
def matmul_kernel(A_ptr, B_ptr, C_ptr, M, N, K,
stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
# 1. 确定当前线程块要处理的C的瓦片(tile)
pid_m = tl.program_id(axis=0) # M维度上的程序ID
pid_n = tl.program_id(axis=1) # N维度上的程序ID
# 2. 初始化累加器
# accumulator 是一个 BLOCK_SIZE_M x BLOCK_SIZE_N 的寄存器块
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# 3. 循环遍历 K 维度 (Reduction)
for k in range(0, K, BLOCK_SIZE_K):
# 3a. 计算A的内存指针并加载 BLOCK_SIZE_M x BLOCK_SIZE_K 的块
A_block_ptrs = A_ptr + \
(pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]) * stride_am + \
(k + tl.arange(0, BLOCK_SIZE_K)[None, :]) * stride_ak
# 3b. 计算B的内存指针并加载 BLOCK_SIZE_K x BLOCK_SIZE_N 的块
B_block_ptrs = B_ptr + \
(k + tl.arange(0, BLOCK_SIZE_K)[:, None]) * stride_bk + \
(pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :]) * stride_bn
# 边界检查 (处理K维度末尾的非对齐部分)
A_mask = (k + tl.arange(0, BLOCK_SIZE_K)[None, :]) < K
B_mask = (k + tl.arange(0, BLOCK_SIZE_K)[:, None]) < K
A_block = tl.load(A_block_ptrs, mask=A_mask)
B_block = tl.load(B_block_ptrs, mask=B_mask)
# 3c. 累加乘积
accumulator += tl.dot(A_block, B_block)
# 4. 将结果存储到 C 矩阵
C_ptrs = C_ptr + \
(pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]) * stride_cm + \
(pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :]) * stride_cn
# 边界检查 (处理M和N维度末尾的非对齐部分)
C_mask = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None] < M) & \
(pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :] < N)
tl.store(C_ptrs, accumulator, mask=C_mask)
# 5. Python 封装函数
def triton_matmul(A, B):
M, K = A.shape
K, N = B.shape
C = torch.empty((M, N), device=A.device, dtype=A.dtype)
# 定义最佳的块大小。这些值通常需要通过实验确定。
BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 32 # 影响内存带宽和L1/共享内存利用率
# 启动网格 (Grid) 大小
grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N))
matmul_kernel[grid](A, B, C, M, N, K,
A.stride(0), A.stride(1),
B.stride(0), B.stride(1),
C.stride(0), C.stride(1),
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K)
return C
# 6. 性能测试
# 注意:为了让自定义Triton算子有机会胜出,我们选择一个相对不常见的、非标准库优化的尺寸。
M, N, K = 1024, 768, 512
A = torch.randn(M, K, device='cuda', dtype=torch.float16)
B = torch.randn(K, N, device='cuda', dtype=torch.float16)
# 确保结果正确性 (可选)
# triton_result = triton_matmul(A, B)
# torch_result = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
# print(f"Max difference: {torch.max(torch.abs(triton_result - torch_result))}")
# 性能比较 (使用GPU计时)
quantiles = [0.5, 0.2, 0.8]
# Triton Benchmark
ms_triton = triton.testing.do_bench(lambda: triton_matmul(A, B), quantiles=quantiles)
print(f"Triton Matmul took: {ms_triton[0]:.3f}ms (50% percentile)")
# PyTorch Benchmark (使用厂商原生库)
ms_torch = triton.testing.do_bench(lambda: torch.matmul(A, B), quantiles=quantiles)
print(f"PyTorch Native Matmul took: {ms_torch[0]:.3f}ms (50% percentile)")
print(f"Triton Speedup: {ms_torch[0] / ms_triton[0]:.2f}x")
4. 结果分析与优化潜力
在上述代码中,如果 Triton 的块大小和 K 维度循环能完美匹配目标 GPU 的 L2 缓存和 Tensor Core 单元,并且 PyTorch/cuBLAS 对该特定维度组合的优化稍逊一筹(例如,为了支持批处理而引入额外开销),Triton 算子将展示出显著的性能优势。
关键优化点:
- 共享内存 (Shared Memory): 上述例子虽然使用了寄存器级别的 Tiling,但高性能 GEMM 必须显式地将 A 和 B 的块预加载到 CUDA 的 Shared Memory 中,以最大化内存重用,避免重复的全局内存访问。Triton 提供了 tl.load(…, cache=tl.load_cache_shared) 等指令来简化这一过程。
- 软件预取 (Prefetching): 在计算当前瓦片的同时,异步加载下一个瓦片的数据。
- 多配置搜索: 使用 Triton 提供的自动调优功能 (triton.autotune),搜索数百种不同的 BLOCK_SIZE_M/N/K 组合和调度策略,找到特定硬件上的全局最优解。这正是超越通用库的关键技术所在。
通过对内存访问模式的精确控制和对硬件架构的深刻理解,Triton 提供了底层黑魔法的工具,使得开发者能够实现比通用厂商库更快的专业化算子。
汤不热吧