欢迎光临
我们一直在努力

深度解析:PyTorch 的 backward() 是如何通过 Autograd 动态构建计算图的?

PyTorch 的 Autograd 机制是其核心竞争力之一。与 TensorFlow 1.x 等框架使用的静态图不同,PyTorch 采用动态计算图(Define-by-Run),这意味着计算图是在前向传播过程中即时构建的。而 backward() 方法则是触发反向传播,计算所有叶子节点(需要梯度)梯度更新的关键。

本文将通过一个简单的示例,解析 PyTorch 是如何利用 Autograd 和 grad_fn 来动态构建计算图,并执行反向传播的。

核心概念:grad_fn 与动态图

当你在 PyTorch 中对一个设置了 requires_grad=True 的张量执行任何操作时,结果张量会自动获得一个 grad_fn 属性。这个 grad_fn 指向一个函数,它知道如何计算该操作的局部梯度,以及需要引用哪些输入张量来完成反向传播的链式法则。

动态构建意味着: 每次前向传播时,都会生成一个新的计算图。这使得 PyTorch 在处理循环神经网络(RNN)和条件控制流(如 if/else 语句)时,表现得极其灵活。

实操:追踪计算图的构建过程

我们来定义一个简单的线性计算:$z = 3x^2 + 2x$。为了简化,我们只追踪 $y=x*3$ 这一步。

步骤一:创建叶子张量

首先,我们创建一个需要追踪梯度的叶子张量 x。叶子张量是用户定义的,而不是由其他操作生成的。

import torch

# 创建叶子张量,必须设定 requires_grad=True
x = torch.tensor([2.0, 3.0], requires_grad=True)
print(f"X (叶子节点): {x}")
print(f"X 的 grad_fn: {x.grad_fn}\n")

输出解析: 叶子节点 xgrad_fnNone,因为它不是任何操作的结果。

步骤二:前向传播与 grad_fn 的生成

我们执行一个简单的乘法操作 $y = x * 3$:

# 运算操作:y = x * 3
y = x * 3

# 检查 y
print(f"Y: {y}")
print(f"Y 的 grad_fn: {y.grad_fn}")

输出解析:

  • y 是一个新的张量,其值是 $[6.0, 9.0]$。
  • y.grad_fn 变成了 。这个对象记录了如何计算 $d(y)/d(x)$ 的梯度,并存储了必要的上下文信息(比如引用的输入张量 x)。

我们继续操作,将 y 聚合为一个标量 z

# 运算操作:z = y.sum() (聚合为标量)
z = y.sum()

# 检查 z
print(f"Z: {z}")
print(f"Z 的 grad_fn: {z.grad_fn}\n")

输出解析:

  • z 同样拥有自己的 grad_fn,即
  • 此时,计算图已经构建完成:x -> MulBackward0 -> SumBackward0 -> z

步骤三:调用 z.backward()

调用 backward() 方法,PyTorch Autograd 就会从终点 z 开始,沿着 grad_fn 指示的路径,逆向遍历计算图,应用链式法则,并将计算出的梯度累加到叶子节点的 .grad 属性中。

注意: 默认情况下,backward() 只能在标量(或只有一个元素的张量)上调用。如果 z 是一个非标量张量,我们需要向 backward() 传递一个 gradient 参数(通常是与 z 形状相同的全 1 张量)。

# 触发反向传播
z.backward()

# 检查 x 的梯度 d(z)/d(x)
print(f"X 的梯度 (x.grad): {x.grad}")

梯度计算验证

我们的函数是 $z = ext{sum}(y) = ext{sum}(3x)$。

如果 $x = [x_1, x_2]$,那么 $z = 3x_1 + 3x_2$。

我们要求的梯度是 $\frac{\partial z}{\partial x} = [\frac{\partial z}{\partial x_1}, \frac{\partial z}{\partial x_2}]$

$\frac{\partial z}{\partial x_1} = 3$

$\frac{\partial z}{\partial x_2} = 3$

因此,计算结果 x.grad 应该是 tensor([3., 3.]),这与我们的代码输出一致。

总结

backward() 方法本身并不进行梯度计算,它只是一个调度器,触发 Autograd 引擎。Autograd 通过张量上的 grad_fn 属性,在运行时动态地构建和维护计算图。当调用 backward() 时,它按照从后向前的顺序,利用存储在每个 grad_fn 中的信息,高效地应用链式法则,完成整个反向传播过程。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 深度解析:PyTorch 的 backward() 是如何通过 Autograd 动态构建计算图的?
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址