PyTorch之所以强大且灵活,很大程度上归功于其动态计算图(Dynamic Computational Graph, DCG)和自动微分系统(Autograd)。与TensorFlow 1.x的静态图不同,PyTorch的计算图是根据代码执行即时构建的,图中的每一个操作(Operation)都是一个节点,而这些节点在底层都由一个核心机制支撑:torch.autograd.Function。
理解Function对象,就是理解PyTorch如何实现反向传播的链式调用本质。
什么是 torch.autograd.Function?
当我们在PyTorch中执行一个张量操作(例如 torch.matmul 或 tensor.relu())时,PyTorch会在后台自动创建一个关联的Function对象。这个对象负责两件事:
- 前向计算 (forward): 执行实际的计算逻辑。
- 反向计算 (backward): 根据链式法则计算输入张量的梯度,并将其传递给上一个节点。
对于自定义的操作或需要手动优化内存和速度的场景,我们就需要自己继承并实现 Function 类。
核心概念:ctx 对象与保存中间变量
链式法则的核心是:要计算 $L$ 对 $x$ 的导数 $\frac{\partial L}{\partial x}$,我们需要知道本地操作的导数 $\frac{\partial y}{\partial x}$ 和上游传来的梯度 $\frac{\partial L}{\partial y}$。
在 forward 方法中,我们通常需要保存一些中间计算结果(即 $y$ 或 $x$ 本身)供 backward 方法计算 $\frac{\partial y}{\partial x}$。这个保存动作是通过 ctx (Context) 对象完成的。
- ctx.save_for_backward(*tensors): 在 forward 中保存张量。
- ctx.saved_tensors: 在 backward 中获取保存的张量。
实践示例:自定义 Tanh 函数
我们以一个自定义的 Tanh (双曲正切) 函数为例。其导数为 $\frac{d}{dx} \tanh(x) = 1 – \tanh^2(x)$。显然,为了计算梯度,我们需要在前向计算时保存 $\tanh(x)$ 的输出值。
import torch
from torch.autograd import Function
# 1. 定义自定义 Function
class CustomTanh(Function):
@staticmethod
def forward(ctx, input_tensor):
# 必须是静态方法
# 1. 执行前向计算
output = torch.tanh(input_tensor)
# 2. 保存中间结果供反向传播使用
# 为了计算 1 - tanh(x)^2, 我们需要保存 output
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
# 必须是静态方法
# grad_output 是上游节点传来的梯度 dL/dy
# 1. 恢复保存的张量
output, = ctx.saved_tensors
# 2. 计算本地梯度 (dy/dx = 1 - y^2)
local_grad = 1 - output ** 2
# 3. 计算最终输入梯度 (dL/dx) = (dL/dy) * (dy/dx)
grad_input = grad_output * local_grad
# 必须返回对应 forward 输入数量的梯度张量
return grad_input
# 2. 封装成易于调用的接口(Function.apply)
custom_tanh = CustomTanh.apply
# 3. 运行演示
x = torch.tensor([2.0, -0.5], requires_grad=True, dtype=torch.float32)
# 构建一个简单的计算链:Custom Tanh -> 乘法 -> 求和
y = custom_tanh(x)
# y = tanh(x)
z = y * 5.0
# z = 5 * tanh(x)
loss = z.sum()
# loss = sum(5 * tanh(x))
# 4. 反向传播
loss.backward()
# 验证结果
print(f"输入 X: {x}")
print(f"输出 Y: {y.data}")
print(f"Z 对 X 的导数 (解析解: 5 * (1 - tanh(x)^2)):\n")
# 手动计算解析解进行对比
analytic_grad = 5 * (1 - y.data ** 2)
print(f"自动微分计算的梯度 (x.grad): {x.grad}")
print(f"解析计算的梯度 (Analytic): {analytic_grad}")
print(f"梯度是否接近? {torch.allclose(x.grad, analytic_grad)}")
运行结果清晰地展示了 backward 方法如何利用 forward 阶段保存的 output (即 $y$),结合上游传来的梯度 grad_output (在本例中是 5.0,来自 y * 5.0 的导数),完成了链式法则的计算,最终得到了正确的输入梯度。
总结
torch.autograd.Function 是 PyTorch 自动微分系统的基本构建块。每个 Function 实例都是动态计算图中的一个节点,它在前向阶段负责计算和保存必要的上下文(ctx.save_for_backward),在反向阶段则严格遵循数学上的链式法则,利用保存的上下文和传入的梯度(grad_output)来计算并传递新的梯度(grad_input)。理解并能够实现自定义 Function,是深入理解 PyTorch DCG 工作原理的关键一步。
汤不热吧