3.6 Computational graphs

Up to this point in the chapter, the chain rule has been a piece of algebra: a way of writing down how a small change at the input of a composition propagates through to its output. That algebra is correct, but for any computation more elaborate than a textbook exercise it becomes unwieldy. A neural network may chain together hundreds of operations, share intermediate values between many downstream consumers, and pass tensors of millions of entries between one layer and the next. Writing the chain rule by hand for such a computation, term by term, is impossible in practice and pointless even when possible. We need a representation that does the bookkeeping for us.

That representation is the computational graph. The idea is simple and very general. Any numerical calculation we ever wish to differentiate (adding two numbers, multiplying two matrices, applying a sigmoid, computing the cross-entropy loss of a transformer over a batch of tokens) can be drawn as a directed acyclic graph, or DAG. The leaf nodes are the inputs and the parameters of our model. The internal nodes are elementary operations: add, multiply, exponentiate, take a logarithm, multiply two matrices, apply a non-linearity. Edges record functional dependence: an edge from node $u$ to node $v$ means "the value at $v$ is computed from the value at $u$, among others". The root node, at the top of the graph, is the final output. In machine learning that final output is almost always a single scalar, the loss $\mathcal{L}$.

Drawing the calculation in this way changes nothing about its mathematical content, but it changes everything about how we think. The chain rule becomes a procedure on a data structure: walk the graph backward from the root, and at each node multiply by a known local derivative. The result, computed automatically, is the gradient of the loss with respect to every parameter. This procedure is the basis of automatic differentiation, and §3.7 turns the graph into the backpropagation algorithm.

Symbols Used Here

  • $v_i$: value at node $i$
  • $\mathbf{f}_i$: local function at node $i$
  • $\bar v_i = \partial \mathcal{L}/\partial v_i$: adjoint of node $i$

Definition

A computational graph for a function $f$ is a directed acyclic graph in which:

  • Leaf nodes are inputs and parameters. They have no incoming edges; their values are supplied from outside the computation. In a neural network the leaves are the input vector, the target label, and every weight matrix and bias term.
  • Internal nodes apply elementary operations to the values at their predecessors. Each operation is something the framework knows how to compute and to differentiate: $+$, $-$, $\times$, $/$, $\exp$, $\log$, $\sin$, matrix multiplication, convolution, ReLU, softmax. The set of "elementary" operations is a design choice: the bigger the set, the more efficient the framework can be, but the more derivative formulae it must implement.
  • The root node is the output of the computation. For training, the root is the scalar loss.

The graph is directed because each operation has well-defined inputs and outputs; it is acyclic because a value cannot depend on itself. Acyclicity is what allows us to evaluate the graph in a meaningful order: we can compute leaves first, then any node whose predecessors have all been evaluated, and so on until the root.

Take the small worked example $$ f(x_1, x_2) = (x_1 + x_2) \cdot \sin(x_1). $$ We split it into a sequence of elementary operations and label each result with a node identifier:

  • $v_1 = x_1$ (leaf input)
  • $v_2 = x_2$ (leaf input)
  • $v_3 = v_1 + v_2$ (addition)
  • $v_4 = \sin(v_1)$ (sine, applied to $v_1$ on its own)
  • $v_5 = v_3 \cdot v_4$ (multiplication, the output)

There are five nodes and three operations. Two leaves feed into the graph; the internal nodes $v_3$ and $v_5$ each combine two predecessors; the leaf $v_1$ is consumed twice, once via $v_3$ and once via $v_4$. That last detail will turn out to matter when we differentiate, because the gradient with respect to $v_1$ will receive contributions from both downstream paths and the two contributions will need to be added.

Forward pass

To evaluate the graph we visit its nodes in topological order: each node is processed only after all its predecessors. Leaves come first, the root comes last. With $x_1 = \pi/2$ and $x_2 = 1$ the calculation runs as follows:

  • $v_1 = \pi/2 \approx 1.5708$
  • $v_2 = 1$
  • $v_3 = v_1 + v_2 \approx 2.5708$
  • $v_4 = \sin(v_1) = \sin(\pi/2) = 1$
  • $v_5 = v_3 \cdot v_4 \approx 2.5708 \cdot 1 = 2.5708$

So $f(\pi/2, 1) \approx 2.5708$.
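To make the bookkeeping concrete, here is a minimal sketch in Python of how such a graph might be represented and evaluated. The dictionary layout, node names and forward function are illustrative choices for this section's example, not how any particular framework stores its graph.

```python
import math

# Each node is either a leaf (value supplied from outside) or an elementary
# operation applied to previously defined nodes.
graph = {
    "v1": ("leaf", None),
    "v2": ("leaf", None),
    "v3": ("add",  ("v1", "v2")),
    "v4": ("sin",  ("v1",)),
    "v5": ("mul",  ("v3", "v4")),
}

ops = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "sin": math.sin,
}

def forward(graph, leaf_values, order):
    """Evaluate the graph in topological order, keeping every intermediate value."""
    values = dict(leaf_values)
    for name in order:
        kind, inputs = graph[name]
        if kind != "leaf":
            values[name] = ops[kind](*(values[i] for i in inputs))
    return values

values = forward(graph, {"v1": math.pi / 2, "v2": 1.0},
                 order=["v1", "v2", "v3", "v4", "v5"])
print(values["v5"])   # approximately 2.5708
```

Note that the function returns every intermediate value, not just the root; that is exactly the "tape" the remarks below refer to.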

A few remarks. First, every node now has a numerical value attached to it, not just the root. These intermediate values are not just scratch work; the backward pass we shall meet in §3.7 needs them, because the local derivative at a node is usually a function of the values at that node's parents. This is why automatic differentiation libraries store, in addition to the graph itself, a tape of intermediate results from the most recent forward pass.

Second, notice how the topological order makes the choice of evaluation sequence almost unambiguous. We were free to compute $v_3$ before $v_4$ or after, because they do not depend on each other; but we could not have computed $v_5$ before either of $v_3$ or $v_4$, and we could not have computed $v_3$ before $v_1$ and $v_2$. Acyclicity is what guarantees that a valid topological order always exists.

Third, the forward pass is purely an evaluation; no derivatives have appeared yet. That separation between forward (evaluate) and backward (differentiate) is what makes automatic differentiation modular and efficient.

Why this representation?

Three reasons make the graph view indispensable for machine learning.

Composability. Any calculation we will ever wish to perform (a single neuron, a residual block, a thousand-layer transformer, a diffusion model, a reinforcement-learning policy network) decomposes into elementary operations. The library only needs to know how to handle the elementary operations; everything else is composition. Programmers gain enormous freedom: they write models in ordinary Python, using ordinary arithmetic, and the framework records the graph in the background. The abstraction is so transparent that beginners often fail to realise it is there at all.

Differentiability. Each elementary operation has a known local derivative, the partial of its output with respect to each of its inputs, computed at the values seen during the forward pass. The chain rule then says that the derivative of the whole is the appropriate product of local derivatives along the paths through the graph. Because the graph encodes all the paths, applying the chain rule becomes a mechanical traversal: we never need to write the global formula by hand. This is the conceptual reason backpropagation is so well-behaved: it is just a graph algorithm.
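For the running example $f(x_1, x_2) = (x_1 + x_2)\sin(x_1)$ this sum over paths can be written out by hand, anticipating what §3.7 will do mechanically. There are two paths from $v_1$ to the root, one through $v_3$ and one through $v_4$, and each contributes a product of local derivatives: $$ \frac{\partial f}{\partial x_1} = \frac{\partial v_5}{\partial v_3}\,\frac{\partial v_3}{\partial v_1} + \frac{\partial v_5}{\partial v_4}\,\frac{\partial v_4}{\partial v_1} = v_4 \cdot 1 + v_3 \cdot \cos(v_1) = \sin(x_1) + (x_1 + x_2)\cos(x_1). $$ At $x_1 = \pi/2$, $x_2 = 1$ this evaluates to $1 + 2.5708 \cdot 0 = 1$, using nothing but the values stored during the forward pass and the known local derivative of each node.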

Automation. Modern frameworks (PyTorch, JAX, TensorFlow) build the graph automatically as the user's code runs, then traverse it in reverse to compute gradients. The user writes a forward pass; the framework supplies the backward pass. This separation is what allows researchers to experiment with new architectures by writing a few lines of Python instead of deriving and coding a custom gradient. More than any other piece of infrastructure, automatic differentiation, packaged as a graph traversal, is what made the deep-learning era possible.

Static vs dynamic graphs

There are two strategies for building the graph, and they correspond to two design philosophies that have shaped the deep-learning frameworks.

Static graphs, sometimes called "define-and-run", separate graph construction from graph execution. The user first writes code that describes the graph as a data structure, and only afterwards feeds in numbers and executes it. TensorFlow 1.x and the original Theano worked this way. The advantage is that, once the graph is fully known, the framework can apply aggressive whole-graph optimisations: fusing operations, eliminating dead nodes, allocating memory once for the entire computation, compiling the graph to fast device code. The disadvantage is rigidity. Anything that depends on runtime values (an if that branches on a tensor's content, a for loop whose length changes from batch to batch) cannot be expressed in plain Python; it has to be encoded with special graph primitives such as tf.cond and tf.while_loop. The result is verbose code that often feels like it is written in two languages at once.

Dynamic graphs, also called "define-by-run", build the graph incrementally as the forward pass executes. Each operation is performed immediately, and a record is added to the graph at the moment the operation is invoked. PyTorch, TensorFlow 2.x in eager mode, and JAX with jax.grad all use this style. The advantage is enormous flexibility. Ordinary Python control flow works exactly as expected: you can write if, for, while, recursion, and the graph traced at runtime simply reflects the actual sequence of operations. Debugging is also far easier, because the user can drop a print or set a breakpoint anywhere and see real numbers. The disadvantage is that the framework cannot optimise across the whole graph as aggressively, since the graph is only known one operation at a time.
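As a minimal illustration of the define-by-run style (the function below is a toy written for this section, not drawn from any framework's documentation), ordinary Python branching happens during the forward pass and the recorded graph contains only the branch that was actually taken:

```python
import torch

def f(x):
    # Ordinary Python control flow: the branch taken depends on the value of x
    # at run time, and only that branch's operations end up in the graph.
    if x.item() > 0:
        return x * x
    else:
        return -x

x = torch.tensor(3.0, requires_grad=True)
y = f(x)
y.backward()      # traverse the graph recorded during this forward pass
print(x.grad)     # tensor(6.) -- the derivative of x**2 at x = 3
```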

The trade-off has narrowed considerably in recent years. PyTorch added torch.compile, which traces dynamic code into a static graph for optimisation while preserving the define-by-run programming model. JAX takes the opposite route: it is define-by-run by default but offers jax.jit, which compiles a Python function to a static, XLA-optimised graph the first time it runs. The pragmatic result is that PyTorch dominates research and the bulk of industry, with JAX and TensorFlow holding strong niches in performance-critical training pipelines. From the point of view of the underlying mathematics, the choice is irrelevant: both styles produce the same DAG, they just build it at different times.
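A small sketch of the JAX side of this trade-off, using the running example from earlier in the section: jax.grad produces the gradient function, and jax.jit traces and compiles it the first time it is called.

```python
import jax
import jax.numpy as jnp

# The running example f(x1, x2) = (x1 + x2) * sin(x1), written as plain Python.
def f(x1, x2):
    return (x1 + x2) * jnp.sin(x1)

grad_f = jax.grad(f, argnums=(0, 1))   # gradient with respect to both inputs
fast_grad_f = jax.jit(grad_f)          # compiled to a static graph on first call

print(fast_grad_f(jnp.pi / 2, 1.0))    # approximately (1.0, 1.0)
```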

Worked example: a tiny network's computational graph

Let us draw the graph of something closer to a real model. Take a two-input, one-output linear regression with squared loss: $$ y_{\text{pred}} = w_1 x_1 + w_2 x_2 + b, \qquad \mathcal{L} = \tfrac{1}{2}(y_{\text{pred}} - y)^2. $$ The leaves are the parameters $w_1, w_2, b$ and the data $x_1, x_2, y$. The internal nodes break the loss into elementary pieces:

  • $a_1 = w_1 \cdot x_1$
  • $a_2 = w_2 \cdot x_2$
  • $a_3 = a_1 + a_2$
  • $y_{\text{pred}} = a_3 + b$
  • $r = y_{\text{pred}} - y$ (the residual)
  • $r^2 = r \cdot r$
  • $\mathcal{L} = r^2 / 2$ (the root)

Already the structure of a neural-network forward pass is visible. The two products $a_1$ and $a_2$ are the multiplications of weights and inputs that, when generalised to a vector of inputs, become a matrix–vector product. The sum $a_3$ is the dot product. Adding $b$ supplies the bias term. The residual, its square and the halving give the loss. A full neural network simply repeats and elaborates this pattern: many such products, summed and biased, then passed through a non-linearity, the result fed into another such layer, and so on until a final loss.

A concrete forward pass: take $w_1 = 1$, $w_2 = 2$, $b = 0$, $x_1 = 1$, $x_2 = 1$, $y = 4$. Then $a_1 = 1$, $a_2 = 2$, $a_3 = 3$, $y_{\text{pred}} = 3$, $r = 3 - 4 = -1$, $r^2 = 1$, $\mathcal{L} = 0.5$.
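The same forward pass can be written in a few lines of PyTorch; this is a minimal sketch in which the variable names mirror the node list above and only the parameters are marked as requiring gradients.

```python
import torch

# Parameters (leaves the optimiser would update)
w1 = torch.tensor(1.0, requires_grad=True)
w2 = torch.tensor(2.0, requires_grad=True)
b  = torch.tensor(0.0, requires_grad=True)

# Data (leaves the optimiser leaves alone)
x1, x2, y = torch.tensor(1.0), torch.tensor(1.0), torch.tensor(4.0)

y_pred = w1 * x1 + w2 * x2 + b        # a1, a2, a3 and the bias, built up as graph nodes
loss = 0.5 * (y_pred - y) ** 2        # the residual, its square, and the halving
print(loss)                           # tensor(0.5000, grad_fn=<MulBackward0>)
```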

Notice that we now have a numerical value at every node. When we run the backward pass in §3.7 we will start from $\mathcal{L}$, work back through the graph, and use these stored values to compute the gradient of $\mathcal{L}$ with respect to each of $w_1$, $w_2$ and $b$. The gradient with respect to $x_1$, $x_2$ and $y$ is rarely of interest in supervised learning (those are data, not parameters), but the framework will happily compute them too if we ask. In a generative-adversarial setting, or when computing adversarial examples, the gradient with respect to the input is what we want.

The graph also makes it obvious that the parameters and the data play structurally identical roles. Both are leaves; both flow forward through the same operations; both can in principle be differentiated against. What distinguishes them is only that the optimiser updates parameters and leaves data alone, a convention of training, not of mathematics.

Operator overloading and graph construction

How does the framework actually build the graph? In PyTorch, an expression as ordinary as c = a * b does two things at once. First, it computes the numerical product $a \cdot b$ and stores it in c. Second, because a and b are tensors with requires_grad=True, it sets c.grad_fn to a MulBackward node that holds references back to a and b, along with whatever local information the backward pass will need (in the case of multiplication, the values of the two operands). This is operator overloading: the multiplication operator has been redefined for tensors so that, in addition to its arithmetic effect, it transparently extends the computational graph by one node. Every supported operation in PyTorch behaves the same way, and the linked chain of grad_fn references is the graph.
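A short sketch of what this looks like in practice; the exact printed representation varies between PyTorch versions, but the chain of grad_fn references is visible directly:

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b   # computes 6.0 and records a multiplication node
d = c + a   # consumes c, and consumes the leaf a a second time

print(c.grad_fn)                 # <MulBackward0 ...>
print(d.grad_fn)                 # <AddBackward0 ...>
print(d.grad_fn.next_functions)  # references back to the nodes d was built from
```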

JAX takes a different route. Rather than overload operators on the spot, it uses tracing: a JAX function is first invoked on abstract symbolic inputs, the operations performed during that invocation are recorded as a graph, and the recorded graph is then either differentiated, JIT-compiled or both. The user-facing programming model is similar, write the forward pass in normal Python, but the implementation is more amenable to whole-graph compilation. TensorFlow 2.x straddles both styles, defaulting to operator overloading in eager mode and offering tf.function to trace a Python callable into a static graph. From the user's perspective, all three frameworks share the same essential property: writing the forward pass is enough, and the gradients arrive automatically.
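If you want to see a traced graph, JAX will print its internal representation (a "jaxpr"). The following minimal sketch reuses the running example; the output lists one primitive per internal node (add, sin, mul), in topological order.

```python
import jax
import jax.numpy as jnp

def f(x1, x2):
    return (x1 + x2) * jnp.sin(x1)

# Tracing on abstract inputs records the graph without committing to values.
print(jax.make_jaxpr(f)(1.0, 2.0))
```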

What you should take away

  1. A computational graph is a directed acyclic graph whose leaves are inputs and parameters, whose internal nodes are elementary operations, and whose root is the output. It re-expresses the chain rule as a graph algorithm rather than a manipulation on paper.
  2. The forward pass evaluates the graph in topological order, leaves first and root last, storing the value at every node. Those stored values feed the backward pass.
  3. The same graph view supports any model we shall meet, from a one-line linear regression to a hundred-billion-parameter transformer, because every model decomposes into elementary operations whose local derivatives are known.
  4. Static graphs (TensorFlow 1.x) optimise the whole computation at the cost of inflexibility; dynamic graphs (PyTorch, JAX, TF 2.x) build the graph as code runs and trade some optimisation for very natural Python control flow. Modern frameworks blur the distinction with just-in-time compilation.
  5. Frameworks build the graph automatically, through operator overloading in PyTorch, through tracing in JAX, so the user writes only the forward pass and the backward pass is derived for free. The next section, §3.7, makes that derivation explicit.
