9.15 Computational graphs and reverse-mode autograd

In §9.6 and §9.7 we derived backpropagation for a multilayer perceptron, working through the chain rule one weight matrix at a time. That derivation is correct, but it gives a slightly misleading impression: it suggests that backprop is a peculiar trick for neural networks, something you have to re-derive whenever the architecture changes. In fact backprop is a special case of a far more general algorithm called reverse-mode automatic differentiation, often shortened to autodiff or autograd. The same algorithm that gives gradients for a tiny MLP also gives gradients for a 175-billion-parameter transformer, a physics simulator, a weather model, or any other program built from differentiable pieces.

Modern deep learning frameworks (PyTorch, JAX and TensorFlow chief among them) automate this entirely. You write the forward pass in ordinary Python, using whatever control flow and data structures you like, and the framework records what you did in a computational graph. When you call loss.backward() (PyTorch) or jax.grad(f)(x) (JAX), the framework walks that graph in reverse and hands you the gradient. You never write $\partial \mathcal{L} / \partial \mathbf{W}$ by hand again. This is one of the great labour-saving inventions of the field, and it is the reason a research student in 2026 can prototype a new architecture in an afternoon when the same idea would have taken a fortnight in 2010.

That convenience comes with a hidden cost: the abstraction occasionally leaks. When training is slow, when memory blows up, when gradients silently disagree with what you expected, when you have to write a custom CUDA kernel, at those moments you need to know what is happening underneath. The aim of this section is to give you that mental model. We will define computational graphs precisely, contrast forward-mode and reverse-mode autodiff, work a small example numerically by hand, and explain why deep learning needs reverse mode in particular. The next two sections then connect the picture to working code, first in NumPy from scratch (§9.16) and then in PyTorch (§9.17).

Symbols Used Here
  • $f$ : a differentiable function
  • $\mathbf{x}$ : input
  • $\mathbf{y}$ : output
  • $\mathcal{L}$ : final scalar loss
  • $v_i$ : value at node $i$ in the graph
  • $\bar v_i = \partial \mathcal{L} / \partial v_i$ : adjoint (gradient with respect to $v_i$)
  • $J_f$ : Jacobian matrix of $f$
  • $\nabla \mathcal{L}$ : gradient vector

What a computational graph is

A computational graph is a directed acyclic graph (DAG) in which each node represents a value, usually a tensor, sometimes just a scalar, and each edge represents the application of an elementary differentiable function. The graph is directed because data flows in one direction, from inputs through intermediate computations to outputs. It is acyclic because there are no loops: the value of any node is determined by the values upstream of it, never by itself. During the forward pass, every node stores the value computed for it. During the backward pass, every node will additionally store its adjoint, which is the partial derivative of the final output with respect to that node.

To make this concrete, take a deliberately tiny function:

$$ f(x_1, x_2) = (x_1 + x_2) \cdot x_1. $$

We can break this into elementary operations. Let $a = x_1 + x_2$ and then $f = a \cdot x_1$. Each elementary operation becomes a node, and edges connect each operation to the values it consumes. The resulting graph looks like this:

   x1 -----+----------------------+
           |                      |
           v                      v
          [+] ----> a ----------> [*] ----> f
           ^
           |
   x2 -----+

Four nodes in total: two input nodes ($x_1$, $x_2$), one node for the addition producing $a$, and one node for the multiplication producing $f$. Notice that $x_1$ feeds two downstream operations: once as a summand and once as a factor. That kind of fan-out, where a single value is consumed by more than one downstream operation, is what makes naive symbolic differentiation explode and what makes the disciplined bookkeeping of autodiff worthwhile.

If we plug in $x_1 = 3$ and $x_2 = 4$, the forward pass fills in numerical values: $a = 3 + 4 = 7$ and $f = 7 \cdot 3 = 21$. So far we have nothing more than a careful trace of an arithmetic expression. The interesting question, the question autodiff answers, is: given those forward values, what are $\partial f / \partial x_1$ and $\partial f / \partial x_2$?
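
The same trace can be written out as straight-line Python, one elementary operation per line; a minimal sketch:

x1 = 3.0
x2 = 4.0
a = x1 + x2      # addition node:        a = 7.0
f = a * x1       # multiplication node:  f = 21.0
print(a, f)      # 7.0 21.0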

Real computational graphs are built out of the same ingredients, just much larger. A forward pass through a transformer layer might involve thousands of nodes: matrix multiplications, additions, layer normalisations, softmaxes, dropout masks. The graph is constructed automatically by the framework as you call into its tensor operations, and it is the framework's job to remember enough about each node to compute gradients later.

Forward-mode autodiff

Forward-mode autodiff carries derivatives along with the forward pass, rather than computing them afterwards. Alongside each value $v_i$ it carries a tangent $\dot v_i$, which is interpreted as the derivative of $v_i$ with respect to one chosen input. The tangents start as a one-hot vector at the inputs: if we are differentiating with respect to $x_1$, we set $\dot x_1 = 1$ and $\dot x_2 = 0$. Each elementary operation then propagates its tangent forward via the chain rule: for an operation $v_i = g(v_{k_1}, v_{k_2}, \dots)$ with multiple inputs,

$$ \dot v_i = \sum_k \frac{\partial v_i}{\partial v_k} \dot v_k. $$

The local partial derivatives $\partial v_i / \partial v_k$ are read off the operation: for addition they are both 1, for multiplication they are the other operand, for $\sin$ they are $\cos$, and so on. The framework knows these for every primitive it supports.
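
In effect the framework ships a table mapping each primitive to its local derivative rule. A toy sketch of such a table (the names and layout here are illustrative, not any framework's actual registry):

import math

# primitive name -> tuple of local partial derivatives, one per input
local_partials = {
    "add": lambda x, y: (1.0, 1.0),      # d(x+y)/dx, d(x+y)/dy
    "mul": lambda x, y: (y, x),          # d(x*y)/dx, d(x*y)/dy
    "sin": lambda x: (math.cos(x),),     # d(sin x)/dx
}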

Let us run this on $f(x_1, x_2) = (x_1 + x_2) \cdot x_1$ at $x_1 = 3$, $x_2 = 4$, differentiating with respect to $x_1$.

Step 1. Initialise tangents at the inputs: $\dot x_1 = 1$, $\dot x_2 = 0$.

Step 2. Forward through the addition. Since $a = x_1 + x_2$, the local partials are $\partial a / \partial x_1 = 1$ and $\partial a / \partial x_2 = 1$, so

$$ \dot a = 1 \cdot \dot x_1 + 1 \cdot \dot x_2 = 1 \cdot 1 + 1 \cdot 0 = 1. $$

Step 3. Forward through the multiplication. Since $f = a \cdot x_1$, the local partials are $\partial f / \partial a = x_1 = 3$ and $\partial f / \partial x_1 = a = 7$, so

$$ \dot f = 3 \cdot \dot a + 7 \cdot \dot x_1 = 3 \cdot 1 + 7 \cdot 1 = 10. $$

The final tangent $\dot f = 10$ is, by construction, the directional derivative of $f$ in the direction of the seed vector $(\dot x_1, \dot x_2) = (1, 0)$. Because the seed picked out the $x_1$ axis, this is exactly $\partial f / \partial x_1 = 10$.

To get $\partial f / \partial x_2$, we have to start over with a different seed: $\dot x_1 = 0$, $\dot x_2 = 1$. Re-running the algorithm gives $\dot a = 0 + 1 = 1$ and $\dot f = 3 \cdot 1 + 7 \cdot 0 = 3$, so $\partial f / \partial x_2 = 3$. That is the price of forward mode in a nutshell: one pass per input you want to differentiate with respect to. For a function with $n$ inputs, computing the full gradient costs $n$ forward passes, each one as expensive as the original computation.
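
Both seeded passes fit in a few lines of Python, carrying a (value, tangent) pair through each operation (a sketch for this one function, not a general implementation):

def forward_mode(x1, x2, dx1, dx2):
    """Carry (value, tangent) pairs through f(x1, x2) = (x1 + x2) * x1."""
    a, da = x1 + x2, dx1 + dx2           # da/dx1 = da/dx2 = 1
    f, df = a * x1, x1 * da + a * dx1    # df/da = x1, df/dx1 = a
    return f, df

print(forward_mode(3.0, 4.0, 1.0, 0.0))  # (21.0, 10.0): seed picks out df/dx1
print(forward_mode(3.0, 4.0, 0.0, 1.0))  # (21.0, 3.0):  seed picks out df/dx2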

Forward mode is conceptually clean: derivatives travel in lockstep with values. It has the further advantage that you never need to store intermediate values, since you never revisit them. It is the natural choice when $n$ is small, for example when computing the sensitivity of a complicated simulator output to a single tunable parameter. Many scientific computing libraries default to forward mode for that reason.

Reverse-mode autodiff

Reverse mode flips the direction of the chain rule. It runs the forward pass once, storing the value of every intermediate node, and then walks the graph backwards from the output, accumulating adjoints $\bar v_i = \partial \mathcal{L} / \partial v_i$. The seed at the top is $\bar f = 1$, on the principle that $\partial f / \partial f = 1$. Each node then pushes its adjoint to its inputs via the local Jacobian: if $v_i$ is consumed by $v_j$, then

$$ \bar v_i \mathrel{+}= \frac{\partial v_j}{\partial v_i} \bar v_j. $$

The accumulation (+=, not =) matters whenever a value is used in more than one place downstream, as $x_1$ is in our example. The contributions from each downstream user must be summed, which is exactly the chain rule for a function with multiple paths.

Let us run reverse mode on the same example: $f(x_1, x_2) = (x_1 + x_2) \cdot x_1$ at $x_1 = 3$, $x_2 = 4$.

Forward pass (recorded for use during the backward pass):

  • $x_1 = 3$
  • $x_2 = 4$
  • $a = x_1 + x_2 = 7$
  • $f = a \cdot x_1 = 21$

Backward pass.

Step 1. Initialise the output adjoint: $\bar f = 1$. Initialise all other adjoints to 0: $\bar a = 0$, $\bar x_1 = 0$, $\bar x_2 = 0$.

Step 2. Reverse through the multiplication $f = a \cdot x_1$. The local Jacobian gives $\partial f / \partial a = x_1 = 3$ and $\partial f / \partial x_1 = a = 7$, so

$$ \bar a \mathrel{+}= 3 \cdot \bar f = 3 \cdot 1 = 3, $$ $$ \bar x_1 \mathrel{+}= 7 \cdot \bar f = 7 \cdot 1 = 7. $$

After this step $\bar a = 3$ and $\bar x_1 = 7$.

Step 3. Reverse through the addition $a = x_1 + x_2$. The local Jacobian is $(1, 1)$, so

$$ \bar x_1 \mathrel{+}= 1 \cdot \bar a = 1 \cdot 3 = 3, $$ $$ \bar x_2 \mathrel{+}= 1 \cdot \bar a = 1 \cdot 3 = 3. $$

After this step $\bar x_1 = 7 + 3 = 10$ and $\bar x_2 = 3$.

Result. $\partial f / \partial x_1 = \bar x_1 = 10$ and $\partial f / \partial x_2 = \bar x_2 = 3$. The forward-mode answer for $\partial f / \partial x_1$ was also 10, as it must be: the two algorithms compute the same derivatives by different routes. The gradient with respect to $x_2$ now drops out for free.

Three things deserve emphasis. First, one forward pass and one backward pass deliver the gradient of $f$ with respect to every input. There is no need to re-run anything for additional inputs. Second, the bookkeeping at the multiplication node uses the forward values $a = 7$ and $x_1 = 3$, which is why we had to remember them; this is the source of the memory cost we discuss below. Third, the += at $\bar x_1$ correctly summed the two contributions from the two paths through which $x_1$ influenced $f$ (the additive path through $a$ and the direct multiplicative path), giving 7 + 3 = 10. If we had used = instead of +=, we would have overwritten one contribution with the other and produced a silently wrong gradient. This is the single most common bug in hand-rolled autograd, and it is why every framework's tensor adjoint defaults to accumulation rather than assignment.
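
The same bookkeeping, written as straight-line Python for this one graph (a sketch, not a general autograd); note the += on the adjoint of $x_1$, which sums the contributions from its two uses:

# Forward pass, recording the values the backward pass will need.
x1, x2 = 3.0, 4.0
a = x1 + x2                  # 7.0
f = a * x1                   # 21.0

# Backward pass: seed the output adjoint, accumulate everything else with +=.
f_bar = 1.0
a_bar = x1_bar = x2_bar = 0.0

a_bar  += x1 * f_bar         # reverse through f = a * x1: df/da = x1 = 3
x1_bar += a  * f_bar         # df/dx1 = a = 7

x1_bar += 1.0 * a_bar        # reverse through a = x1 + x2: second contribution
x2_bar += 1.0 * a_bar

print(x1_bar, x2_bar)        # 10.0 3.0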

In a deep network the same algorithm runs over a graph with millions or billions of nodes, but the local rule is unchanged. Backpropagation through an MLP, derived in §9.6, is simply this algorithm specialised to the case where the graph is a linear chain of matrix multiplications and elementwise non-linearities.

Why reverse-mode is right for deep learning

The key fact is the asymmetry in cost. Forward mode runs in time

$$ \text{cost}_\text{fwd-AD} = O(\text{cost}_\text{forward}) \cdot O(\text{number of inputs}), $$

because each input demands its own seeded forward pass. Reverse mode runs in time

$$ \text{cost}_\text{rev-AD} = O(\text{cost}_\text{forward} + \text{cost}_\text{backward}), $$

independent of the number of inputs. The backward pass is itself only a small constant times more expensive than the forward, typically between 1.5x and 3x in practice.

For deep learning this asymmetry is decisive. A neural network has one scalar loss as its output and very many parameters as its inputs: a small image classifier might have $10^7$ parameters, a large language model in 2026 has $10^{11}$ to $10^{12}$. With forward mode, computing the gradient would require one forward pass per parameter: $10^{12}$ forward passes for a frontier model, each one taking a multi-million-dollar GPU cluster a measurable fraction of a second, for a total measured in thousands of years. With reverse mode, one forward pass plus one backward pass, a few hundred milliseconds, gives the gradient for all $10^{12}$ parameters at once. Without this efficiency, modern deep learning would simply be impossible.
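
The arithmetic behind that claim, with assumed, order-of-magnitude numbers:

n_params = 1e12             # parameters, i.e. inputs to differentiate with respect to
secs_per_forward = 0.1      # assumed time for one forward pass on a large cluster
total_seconds = n_params * secs_per_forward
print(total_seconds / (3600 * 24 * 365))   # ~3.2e3 years of pure forward passes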

The general principle behind the asymmetry is clean. A function $f: \mathbb{R}^n \to \mathbb{R}^m$ has a Jacobian of shape $m \times n$. Forward mode computes Jacobian-vector products $J v$, taking time roughly proportional to one forward pass per column of $J$ that you need (equivalently, per input direction). Reverse mode computes vector-Jacobian products $v^\top J$, taking time roughly proportional to one backward pass per row of $J$ that you need (per output direction). For a scalar-valued function, $m = 1$, so reverse mode finishes in a single sweep regardless of $n$.
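
PyTorch exposes both products directly, which makes the asymmetry visible on our toy function (a sketch using torch.autograd.functional; jax.jvp and jax.vjp are the JAX equivalents):

import torch
from torch.autograd.functional import jvp, vjp

def f(x1, x2):
    return (x1 + x2) * x1

inputs = (torch.tensor(3.0), torch.tensor(4.0))

# Reverse mode: one vector-Jacobian product, seeded with 1.0 at the scalar
# output, returns the gradient with respect to every input at once.
_, grads = vjp(f, inputs, torch.tensor(1.0))
print(grads)                                   # (tensor(10.), tensor(3.))

# Forward mode: one Jacobian-vector product per input direction.
_, d_dx1 = jvp(f, inputs, (torch.tensor(1.0), torch.tensor(0.0)))
_, d_dx2 = jvp(f, inputs, (torch.tensor(0.0), torch.tensor(1.0)))
print(d_dx1, d_dx2)                            # tensor(10.) tensor(3.)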

When does forward mode win? When $m \gg n$. The classic case is a scientific simulator with many outputs but only one or two tunable parameters: forward mode computes all output sensitivities in one pass per parameter, whereas reverse mode would need one backward pass per output. Hessian-vector products $H v$ live in a hybrid sweet spot: they can be computed by running forward mode over a reverse-mode gradient (forward-over-reverse), which is how curvature-aware methods such as Hessian-free optimisation and Newton-CG obtain second-order information without ever forming the Hessian.

The takeaway for deep learning is simple and unconditional: reverse mode is the right default, every major framework defaults to it, and you only ever escape from it for the rare narrow purposes mentioned above.

Storing intermediate values: the memory cost

Reverse mode pays for its time efficiency with a memory cost. The backward pass needs the forward values at every node whose local Jacobian depends on them: the multiplication node above needed $a = 7$ and $x_1 = 3$, the layer-norm node in a transformer needs the means and variances it computed, the matmul node needs the input tensor it consumed. So every framework, by default, saves all of these activations during the forward pass and frees them only after the backward pass has used them.

The memory bill scales with the size of the network and the size of the batch. Take a 100-layer transformer with batch size 32, sequence length 2048, and hidden size 4096. Each layer holds an activation tensor of shape $32 \times 2048 \times 4096$, which is roughly $2.7 \times 10^8$ floats. At 4 bytes per float (single precision), that is around 1 GB per layer, so 100 GB across the network. A single H100 GPU has 80 GB of memory; an A100 has 80 GB; even an H200 has only 141 GB. The activations alone do not fit, never mind the parameters and gradients.
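
The back-of-envelope arithmetic, assuming for simplicity one stored activation tensor per layer as in the text:

batch, seq, hidden, layers = 32, 2048, 4096, 100

floats_per_layer = batch * seq * hidden        # ~2.7e8 values
gb_per_layer = floats_per_layer * 4 / 1e9      # fp32: 4 bytes per value, ~1.07 GB
print(gb_per_layer, gb_per_layer * layers)     # ~1.07 GB per layer, ~107 GB total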

There are three standard remedies, used singly or in combination:

  • Gradient checkpointing. Save activations only at every $k$-th layer, and during the backward pass re-run the forward computation between checkpoints to regenerate the missing activations on demand. This trades compute for memory: the forward pass effectively runs twice (once in the original pass, partially again during backward), but the peak memory drops from $O(N)$ activations to roughly $O(\sqrt{N})$ when $k$ is chosen optimally. PyTorch ships torch.utils.checkpoint.checkpoint for this (a short sketch follows below).
  • Mixed precision. Store activations and parameters in 16-bit formats, fp16 or, more often these days, bf16, instead of fp32. This roughly halves the memory needed for activations. A small number of numerically sensitive computations (the loss accumulator, the optimiser state) are kept in fp32 to avoid overflow and underflow. Mixed precision also runs faster on tensor-core hardware, so it is a near-universal default in 2026.
  • Activation offloading. Push activations to CPU memory or even NVMe between forward and backward, fetching them back when needed. This makes very large models trainable on smaller hardware at the cost of PCIe bandwidth.

In practice, training a frontier-scale model uses all three at once, plus model and data parallelism (FSDP, ZeRO, tensor parallelism, pipeline parallelism). The memory cost of reverse-mode autodiff is the single most important reason these techniques exist: they are the engineering response to the awkward fact that the gradient algorithm wants to remember everything.
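
A minimal sketch of gradient checkpointing with torch.utils.checkpoint (the block sizes and module names here are illustrative, standing in for transformer layers):

import torch
from torch.utils.checkpoint import checkpoint

blocks = torch.nn.ModuleList(
    [torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.GELU()) for _ in range(8)]
)

def forward_with_checkpointing(x):
    for block in blocks:
        # Activations inside each block are discarded after the forward pass
        # and recomputed during backward; use_reentrant=False is the modern variant.
        x = checkpoint(block, x, use_reentrant=False)
    return x

x = torch.randn(4, 512, requires_grad=True)
loss = forward_with_checkpointing(x).sum()
loss.backward()                     # each block's forward re-runs here
print(x.grad.shape)                 # torch.Size([4, 512])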

How PyTorch and JAX implement this

PyTorch builds the computational graph dynamically, on the fly, as your forward pass runs. Every torch.Tensor produced by an operation on gradient-requiring tensors carries a .grad_fn attribute pointing to the operation that produced it, and every such operation carries pointers to its inputs. When you call loss.backward(), PyTorch walks the graph backwards from the loss, calling each operation's saved backward function in reverse topological order, and accumulating gradients into the leaf tensors' .grad fields. Because the graph is rebuilt every iteration, you can write arbitrary Python control flow (if, for, recursion, even reading from disk) and it just works. This is what people mean when they describe PyTorch as define-by-run or as having an eager mode.
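
On the toy function from earlier, the recorded graph and the accumulated gradients are directly visible:

import torch

x1 = torch.tensor(3.0, requires_grad=True)
x2 = torch.tensor(4.0, requires_grad=True)

a = x1 + x2                 # a.grad_fn records the addition
f = a * x1                  # f.grad_fn records the multiplication

print(f.grad_fn)            # something like <MulBackward0 ...>
f.backward()                # walk the recorded graph in reverse from f
print(x1.grad, x2.grad)     # tensor(10.) tensor(3.)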

JAX takes a different approach. You write a pure Python function, and JAX traces it by running it once with abstract tracer values that record every operation. The result is a static representation of the function (a jaxpr, which jax.jit lowers to XLA), which JAX can then transform: jax.grad(f) produces a new function that returns the gradient, jax.vmap(f) produces a batched version, jax.jit(f) compiles the traced graph through XLA for aggressive optimisation. JAX is therefore trace-and-compile: less flexible than PyTorch with respect to runtime control flow (data-dependent shapes do not trace well), but capable of fusing operations, eliminating intermediates, and producing extremely fast compiled kernels. The same model expressed in PyTorch and in JAX will produce the same gradients to within numerical noise; the difference is in convenience and performance characteristics.
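
The same toy gradient in JAX, showing how the transformations compose:

import jax
import jax.numpy as jnp

def f(x):
    x1, x2 = x
    return (x1 + x2) * x1

grad_f = jax.grad(f)                         # a new function: x -> df/dx
print(grad_f(jnp.array([3.0, 4.0])))         # [10.  3.]

# Transformations compose: a jit-compiled, per-example-batched gradient.
batched_grad_f = jax.jit(jax.vmap(grad_f))
print(batched_grad_f(jnp.array([[3.0, 4.0], [1.0, 2.0]])))   # [[10. 3.] [4. 1.]]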

In practice, PyTorch dominates research and industry by a wide margin in 2026. Its dynamic graph is more forgiving for prototyping, its ecosystem (Hugging Face, torchvision, torchaudio, distributed training libraries) is enormous, and torch.compile, introduced in PyTorch 2.0, closes much of the performance gap with JAX by capturing subgraphs of your Python code and compiling them. JAX retains a foothold in scientific computing, in some Google internal work, and in research that needs fine-grained transformation composition (per-sample gradients, Hessian-vector products, complex parallelism). TensorFlow, the historical heavyweight, has largely receded into a deployment-and-mobile niche.

For everyday work the lesson is: pick PyTorch unless you have a specific reason not to. The autograd machinery underneath is the same algorithm you traced by hand above, just industrialised.

Custom backward functions

Sometimes the default autograd path is wasteful or numerically unstable. An iterative algorithm (say, a Newton solver inside the forward pass) would by default have autograd record every iteration and differentiate through the whole loop. That is correct but expensive: often the analytic gradient at the converged point is cheaper and more stable than the chain through hundreds of iterations. A function with a known closed-form derivative that simplifies algebraically can also benefit. And custom CUDA kernels, written in C++ for speed, need their backward written by hand because there is no Python operation graph for autograd to walk.

PyTorch exposes this as torch.autograd.Function. You subclass it and provide two static methods: forward(ctx, *inputs), which computes the output and stashes anything the backward will need into ctx, and backward(ctx, *grad_outputs), which returns the gradients with respect to each input. The skeleton looks like this:

import torch

class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Stash whatever backward will need; here, just the input.
        ctx.save_for_backward(x)
        return f(x)                  # f: your forward computation (placeholder)

    @staticmethod
    def backward(ctx, grad_out):
        # Return one gradient per input of forward, via the chain rule.
        x, = ctx.saved_tensors
        return grad_out * df_dx(x)   # df_dx: its analytic derivative (placeholder)
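
To make the skeleton concrete, here is a hypothetical custom cube operation (chosen purely for illustration) with its analytic backward, invoked through .apply and checked against finite differences with torch.autograd.gradcheck:

import torch

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * 3 * x ** 2            # chain rule: dL/dx = dL/dy * 3x^2

x = torch.randn(5, dtype=torch.double, requires_grad=True)
Cube.apply(x).sum().backward()
print(torch.allclose(x.grad, 3 * x.detach() ** 2))        # True

# gradcheck compares the hand-written backward against numerical differentiation.
print(torch.autograd.gradcheck(Cube.apply, (x,)))         # True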

The most prominent recent example is FlashAttention, which fuses the softmax-attention computation into a single CUDA kernel that never materialises the full $N \times N$ attention matrix in memory. Because attention is fused, autograd cannot differentiate through it node by node; FlashAttention ships its own backward kernel, also fused, which recomputes the necessary quantities on the fly. Mainstream libraries (xFormers, the FlashAttention package, Triton-based kernels) and custom layers in research papers all rely on this mechanism. JAX exposes the analogous capability via jax.custom_vjp (custom vector-Jacobian product) and jax.custom_jvp (custom Jacobian-vector product).

For everyday network code you will not need to write one of these. But when you read a paper that promises a 3x speedup over standard attention, or when you need to differentiate through a physical simulator that wraps a Fortran solver, the custom backward is the door through which you escape the framework's defaults without losing the rest of autograd.

What you should take away

  1. A computational graph is a directed acyclic graph in which nodes are tensor values and edges are differentiable operations; reverse-mode autodiff walks it backwards from a scalar output to compute the gradient with respect to every input.
  2. Forward mode costs one forward pass per input direction; reverse mode costs one forward plus one backward pass total, independent of the number of inputs. For neural networks, with one scalar loss and billions of parameters, reverse mode is the only feasible choice.
  3. Reverse mode requires storing the forward activations needed by each node's backward function; for large models this memory cost is the binding constraint, addressed in practice by gradient checkpointing, mixed precision, and offloading.
  4. Modern frameworks automate the bookkeeping: PyTorch builds the graph dynamically as your code runs, JAX traces and compiles via XLA, both produce identical gradients to numerical precision; PyTorch is the default for most research and production work in 2026.
  5. When the default path is wasteful or numerically awkward, drop down to a custom backward via torch.autograd.Function (PyTorch) or jax.custom_vjp (JAX); this is how kernels like FlashAttention coexist with the rest of autograd.
