2.10 Tensors, broadcasting, and einsum
If you spend any time at all writing deep-learning code in PyTorch or JAX, you will quickly notice that almost every variable on the screen is a tensor. A weight matrix is a tensor. A batch of images is a tensor. The activations halfway through a Transformer's attention layer are a tensor. Even a single scalar loss value, the number you are trying to minimise, is held inside a tensor object. Once you accept that everything in modern numerical Python is a tensor, the daily work of a practitioner becomes clear: you spend most of your time arranging tensors of the right shape, combining them in the right order, and checking that the axes line up. Bugs in deep-learning code are very often shape bugs: the mathematics is correct, the optimiser is fine, but the wrong axis is being summed and the loss never falls.
This section introduces the three ideas that, taken together, let you write almost any tensor computation cleanly: broadcasting, which lets arrays of different shapes combine without writing loops; reshaping and transposing, which let you rearrange axes without copying data; and einsum, a single compact operator that expresses matrix multiplication, batched matrix multiplication, transposes, traces, dot products, outer products, and full multi-head attention scoring in one consistent syntax. We finish with a short note on contraction order and on how memory layout affects speed.
Section 2.3 introduced matrices as rank-2 tensors. This section generalises those ideas to higher rank, which is what you actually need from §2.11 onward and throughout chapters 9 to 15, where convolutional networks, recurrent networks, and Transformers all live in four-, five-, and six-dimensional tensor spaces.
What a tensor is
A tensor, in the deep-learning sense of the word, is just a multi-dimensional array of numbers. The number of axes is called the rank (some authors prefer "order"), and the shape is the tuple that lists the size of each axis. A tensor's shape is the single most important piece of information about it. If you can recite the shape, you can usually reason about whether a line of code is correct.
It is worth walking through the ranks one at a time, because each corresponds to a familiar object from earlier in this chapter.
- Rank 0 is a scalar. The shape is the empty tuple, written (). A loss value at the end of a training step is rank 0.
- Rank 1 is a vector. The shape is (D,) for a vector of length $D$. A single word embedding in a language model is rank 1, typically with shape (768,) or (4096,).
- Rank 2 is a matrix. The shape is (M, N). A weight matrix in a fully connected layer is rank 2, with shape (d_out, d_in).
- Rank 3 is, for example, a colour image stored as height by width by channel, shape (H, W, C), or a batch of token embeddings, shape (B, T, D): a (B, D) batch of vectors extended with a sequence axis of length T.
- Rank 4 is the canonical shape of a batch of images in PyTorch: (B, C, H, W) for batch size, channels, height, width. ImageNet training uses this shape on every step. Multi-head attention scores, with shape (B, heads, T, T) for batch, heads, query positions, and key positions, are also rank 4.
- Rank 5 turns up in video models, for example shape (B, T, C, H, W) for batch, time, channels, height, width.
To make the scale concrete, consider a single mini-batch in a vision Transformer. Thirty-two colour photographs at resolution $224 \times 224$ form a tensor of shape (32, 3, 224, 224). Multiplying out gives $32 \cdot 3 \cdot 224 \cdot 224 = 4{,}816{,}896$ individual floating-point numbers. In single precision (fp32, four bytes each) that is roughly nineteen megabytes for one batch, before any computation has happened. In half precision (fp16 or bf16) it halves to about ten megabytes. Multiply by the dozens of intermediate activations that a deep network keeps around for the backward pass and you can see why GPU memory budgets dominate model design.
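As a quick check of that arithmetic, NumPy reports the size of an array's data buffer through its nbytes attribute. The sketch below allocates zero-filled arrays only to read that attribute; np.float16 stands in for half precision, since NumPy itself has no bf16 type.

```python
import numpy as np

# One mini-batch of 32 RGB images at 224 x 224, in PyTorch's (B, C, H, W) order.
shape = (32, 3, 224, 224)

batch_fp32 = np.zeros(shape, dtype=np.float32)  # 4 bytes per element
batch_fp16 = np.zeros(shape, dtype=np.float16)  # 2 bytes per element

print(batch_fp32.size)            # number of elements: 4816896
print(batch_fp32.nbytes / 1e6)    # about 19.3 megabytes
print(batch_fp16.nbytes / 1e6)    # about 9.6 megabytes
```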
Whenever you suspect a bug, the first habit to develop is calling print(x.shape) on every tensor in the offending line. Most deep-learning errors are not subtle: they are mismatched axes that an explicit shape print would reveal in seconds.
Broadcasting
When you operate on two tensors of different shapes, say you add a vector to a matrix, or subtract a per-feature mean from a batch of inputs, NumPy and PyTorch do not raise an error. They quietly broadcast the smaller one up to the shape of the larger one and perform an element-wise operation. Broadcasting is what makes vectorised numerical code possible without writing nested loops by hand, and it is what makes that same code crash with a confusing message six lines later when the shapes happen not to align.
The rule is simple to state. Align the two shapes from the right. For each pair of corresponding axes, they are compatible if they are equal, or if one of them is 1. A size-1 axis is virtually replicated to match the other. If the shapes have different rank, the shorter one is padded on the left with implicit 1s.
Worked through, that rule explains why a tensor of shape (3, 1, 5) and a tensor of shape (4, 5) combine cleanly. Pad the shorter shape on the left so it becomes (1, 4, 5). Compare from the right: the 5s match; 1 against 4 is broadcastable (the 1 expands to 4); 3 against 1 is broadcastable (the 1 expands to 3). The result has shape (3, 4, 5). Neither input is actually copied: the framework only pretends to replicate, reading the same stored values repeatedly along each expanded axis. You get the convenience of replication without its memory cost.
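The worked example can be reproduced in NumPy. np.broadcast_shapes (available from NumPy 1.20) applies the rule without allocating anything, and np.broadcast_to exposes how the replication is faked: the expanded axis is given a stride of zero, so the same bytes are read again and again. The arrays here are arbitrary float64 ones chosen to match the shapes in the text.

```python
import numpy as np

a = np.ones((3, 1, 5))
b = np.ones((4, 5))

# Pad (4, 5) to (1, 4, 5); then each axis pair must match or contain a 1.
print(np.broadcast_shapes(a.shape, b.shape))  # (3, 4, 5)
print((a + b).shape)                          # (3, 4, 5)

# The "virtual" replication: the new leading axis has stride 0 (float64 = 8 bytes).
b_big = np.broadcast_to(b, (3, 4, 5))
print(b_big.strides)                # (0, 40, 8)
print(np.shares_memory(b, b_big))   # True: nothing was copied

# Incompatible shapes raise an error rather than guessing.
try:
    np.ones((3,)) + np.ones((4,))
except ValueError as err:
    print("cannot broadcast:", err)
```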
The most common everyday use is subtracting a per-feature mean from a batch. Imagine a batch of inputs X with shape (B, D), $B$ rows, $D$ features per row. The mean across the batch is X.mean(axis=0) with shape (D,). To centre the data you write X - X.mean(axis=0). Broadcasting pads the mean's shape from (D,) to (1, D), then virtually replicates it $B$ times along axis 0 to match X, then subtracts. There is no loop, no explicit replication, and the resulting code is one short line.
```python
import numpy as np

X = np.random.randn(100, 5)       # 100 rows, 5 features
X_centred = X - X.mean(axis=0)    # mean is shape (5,), broadcasts over rows
```
The standard gotcha is that an (N,) array and an (N, 1) array behave differently when broadcast against an (N, N) matrix. The first is treated as a row, shape (1, N), and replicated down the rows of the matrix; the second is a column and is replicated across the columns. Both run without error, so if you meant one and wrote the other, the code silently computes the wrong thing. Whenever the intent of broadcasting code is not visually obvious, make it explicit: index a 1-D array with [:, None] to turn it into an (N, 1) column or with [None, :] to make it an explicit (1, N) row, or call .reshape(N, 1) or .reshape(1, N) directly. The few extra characters cost nothing and remove a whole category of bug.
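A minimal demonstration of the gotcha, using an illustrative 3 by 3 matrix of zeros so that the broadcast values are easy to read:

```python
import numpy as np

M = np.zeros((3, 3))
v = np.array([0.0, 1.0, 2.0])   # shape (3,)

as_row = M + v            # v aligns with the last axis: treated as (1, 3)
as_col = M + v[:, None]   # explicit column: shape (3, 1)

print(as_row)
# [[0. 1. 2.]
#  [0. 1. 2.]
#  [0. 1. 2.]]
print(as_col)
# [[0. 0. 0.]
#  [1. 1. 1.]
#  [2. 2. 2.]]

# Writing the row intent out explicitly gives the same result as the implicit case.
assert np.array_equal(M + v[None, :], as_row)
```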
Broadcasting is one of those features that quietly makes you ten times more productive once you internalise it. The price you pay is the discipline of always knowing what shape you have, and what shape the framework is going to inflate it to.
Reshaping, transposing, and the contiguity question
Tensors are stored in memory as one long flat strip of numbers, and the shape is just metadata that tells the framework how to interpret that strip. This decoupling of layout from logical shape is what lets transpose, and reshape in most cases, run in essentially zero time: they change the metadata, not the data.
reshape(new_shape) requires the new shape to have the same number of elements as the old. It returns a view onto the same underlying buffer whenever the existing strides allow it, and falls back to a copy when they do not. transpose swaps two axes (in PyTorch, x.transpose(dim0, dim1)), again as a metadata-only change. PyTorch's permute reorders all axes at once: a tensor of shape (B, C, H, W) permuted with permute(0, 2, 3, 1) becomes shape (B, H, W, C).
The subtlety is contiguity. After a transpose or a permute, the tensor's logical order no longer matches the order of its elements in memory: physically the bytes are still laid out as before, and only the metadata labels them differently. Some PyTorch operations need a compatible layout. The most common is view, a strict reshape that never copies: it fails whenever the requested shape cannot be expressed with the tensor's existing strides, which is typically the case when you flatten or merge axes after a transpose. The fix is to call .contiguous(), which copies the tensor into a fresh buffer in row-major order; the data values are unchanged, but the bytes are now in a layout the next operation expects.
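NumPy follows the same memory model, although its reshape copies silently instead of raising an error. This sketch uses np.shares_memory to tell views from copies and the C_CONTIGUOUS flag to check row-major layout; np.ascontiguousarray plays roughly the role of PyTorch's .contiguous().

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)   # row-major (C-contiguous) by construction

b = a.reshape(6, 4)                  # the strides allow it: a view, no copy
print(np.shares_memory(a, b))        # True

t = a.transpose(0, 2, 1)             # shape (2, 4, 3), metadata-only
print(np.shares_memory(a, t))        # True
print(t.flags["C_CONTIGUOUS"])       # False: logical order no longer matches memory

flat = t.reshape(-1)                 # not expressible with strides, so NumPy copies
print(np.shares_memory(a, flat))     # False

tc = np.ascontiguousarray(t)         # fresh row-major buffer, same values
print(tc.flags["C_CONTIGUOUS"])      # True
```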
A worked example. A (B, C, H, W) image tensor is permuted with permute(0, 2, 3, 1) to (B, H, W, C) for some pipeline that prefers the channels-last convention. The permute itself is free. The next line tries to call view(B, -1) to flatten everything except the batch dimension. PyTorch raises RuntimeError: view size is not compatible with input tensor's size and stride. The fix is one line: insert .contiguous() between the permute and the view, accepting the cost of a single full copy of the tensor in exchange for a contiguous memory layout. This pattern (transpose or permute, then contiguous, then view) is extremely common in production deep-learning code. If you understand why you need each step, you will never be surprised by the error.
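The failure and the fix can be reproduced in a few lines of PyTorch. This sketch assumes PyTorch is installed and uses tiny, arbitrary sizes; it catches the RuntimeError rather than relying on its exact wording, which varies between versions. It also shows torch's reshape, which returns a view when it can and a copy when it must, so it performs the contiguous-then-view step in a single call.

```python
import torch

B, C, H, W = 2, 3, 4, 5                  # tiny, arbitrary sizes
x = torch.randn(B, C, H, W)              # channels-first, contiguous

x_last = x.permute(0, 2, 3, 1)           # (B, H, W, C): metadata-only, no copy
assert x_last.shape == (B, H, W, C)
assert not x_last.is_contiguous()

# view must express the new shape with the existing strides; here it cannot.
try:
    x_last.view(B, -1)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)                       # True

# The fix: copy into row-major order first, then view.
flat = x_last.contiguous().view(B, -1)   # shape (B, H * W * C)

# reshape performs the same copy-if-needed step in one call.
flat2 = x_last.reshape(B, -1)
print(torch.equal(flat, flat2))          # True
```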
Einsum
Of all the tools in numerical Python, einsum is the one that most reliably makes new practitioners feel that they have levelled up. It is a single operator that expresses matrix multiplication, batched matrix multiplication, transposes, traces, dot products, outer products, arbitrary axis permutations, and full multi-head attention contractions, all in a consistent syntax. It is named after Einstein's summation convention, and the convention is the whole rule.
Indices that appear on both sides of the -> are kept; indices that appear only on the left are summed over. That is the entire idea. Once you grasp it, you can read and write almost any tensor expression without reaching for a textbook.
A few worked examples make the convention concrete:
- Matrix multiplication. $C_{ij} = \sum_k A_{ik} B_{kj}$ becomes einsum('ik,kj->ij', A, B). The index $k$ appears on the left and not on the right, so it is summed over. The indices $i$ and $j$ appear on both sides, so they are kept.
- Batched matrix multiplication. Add a leading batch axis $b$ that is kept on both sides: einsum('bik,bkj->bij', A, B) performs one matrix multiply per batch element, no Python loop required.
- Dot product. Two vectors with a single shared index, summed: einsum('i,i->', a, b) produces a scalar (note the empty right-hand side).
- Trace. A single tensor with a repeated index that is summed: einsum('ii->', A) returns the trace of a square matrix.
- Outer product. Two vectors with two distinct indices, neither summed: einsum('i,j->ij', a, b) produces a matrix.
- Transpose. No summation at all, just a relabelling of axes: einsum('ij->ji', A).
- Multi-head attention scores. Here einsum is at its best. With queries $Q$ and keys $K$ each of shape (batch, time, heads, per-head-dimension), the attention scores per head per batch are einsum('bthd,bThd->bhtT', Q, K) (subscripts are case-sensitive, so t and T are distinct indices). The shared index $d$, the per-head feature dimension, is summed over; the batch $b$ is kept; the head index $h$ is kept; the query position $t$ and the key position $T$ are kept and become the two axes of the resulting score matrix. Try writing this contraction with matmul and a sequence of permutes, and then compare it to the einsum. The einsum reads almost exactly like the underlying mathematics.
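```python
import numpy as np

# Batched matrix multiplication: one (5, 3) @ (3, 4) product per batch element.
A = np.random.randn(8, 5, 3)   # 8 matrices of shape (5, 3)
B = np.random.randn(8, 3, 4)   # 8 matrices of shape (3, 4)
C = np.einsum('bij,bjk->bik', A, B)
assert C.shape == (8, 5, 4)

# Attention scores, here with a head-first (batch, head, seq, d) layout,
# scaled by 1/sqrt(d) as in scaled dot-product attention.
batch, head, seq, d = 2, 4, 16, 32
Q = np.random.randn(batch, head, seq, d)
K = np.random.randn(batch, head, seq, d)
scores = np.einsum('bhsd,bhtd->bhst', Q, K) / np.sqrt(d)
assert scores.shape == (batch, head, seq, seq)
```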
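Each einsum string in the list of worked examples can be checked against NumPy's dedicated routine, and the attention contraction can be compared with the transpose-and-matmul version the text invites you to write. This sketch uses small random inputs and the (batch, time, heads, per-head-dimension) layout from the list; the variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
B = rng.standard_normal((4, 4))
a = rng.standard_normal(4)
b = rng.standard_normal(4)

# Each einsum string checked against the equivalent NumPy routine.
assert np.allclose(np.einsum('ik,kj->ij', A, B), A @ B)         # matrix multiply
assert np.allclose(np.einsum('i,i->', a, b), np.dot(a, b))      # dot product
assert np.allclose(np.einsum('ii->', A), np.trace(A))           # trace
assert np.allclose(np.einsum('i,j->ij', a, b), np.outer(a, b))  # outer product
assert np.allclose(np.einsum('ij->ji', A), A.T)                 # transpose

# Attention scores with the (batch, time, heads, d) layout used in the list.
bsz, seq, heads, d = 2, 6, 3, 8
Q = rng.standard_normal((bsz, seq, heads, d))
K = rng.standard_normal((bsz, seq, heads, d))
scores_einsum = np.einsum('bthd,bThd->bhtT', Q, K)

# The same contraction written with explicit transposes and a batched matmul.
Qh = Q.transpose(0, 2, 1, 3)    # (b, h, t, d)
Kh = K.transpose(0, 2, 3, 1)    # (b, h, d, T)
scores_matmul = Qh @ Kh         # (b, h, t, T)
assert np.allclose(scores_einsum, scores_matmul)
```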
The same syntax works in NumPy (np.einsum), PyTorch (torch.einsum), and JAX (jax.numpy.einsum). Beyond the readability win, einsum gives the framework's optimiser the freedom to choose a contraction order (in NumPy you must opt in with optimize=True), which, as we shall see, can change the cost of a computation by orders of magnitude. When you write a chain of matmul calls, the order is fixed by the way you wrote it; when you write a single einsum, the optimiser can pick.
If you remember nothing else, remember the convention: shared on both sides means kept, only on the left means summed. You can write almost anything from there.
Tensor contractions and the cost of order
When you contract more than two tensors, the order in which you do the pairwise contractions changes the cost dramatically, not by a constant factor, but by orders of magnitude. This is the same phenomenon as the matrix-chain multiplication problem in classical algorithms.
Take three matrices: $\mathbf{A} \in \mathbb{R}^{m \times k}$, $\mathbf{B} \in \mathbb{R}^{k \times n}$, $\mathbf{C} \in \mathbb{R}^{n \times p}$. The product $\mathbf{A}\mathbf{B}\mathbf{C}$ is mathematically the same regardless of how you parenthesise it, but the count of scalar multiplications is not. Computing $(\mathbf{A}\mathbf{B})\mathbf{C}$ first does an $m \times k$ by $k \times n$ multiply (cost $mkn$), then an $m \times n$ by $n \times p$ multiply (cost $mnp$), giving a total of $mkn + mnp$. The other parenthesisation, $\mathbf{A}(\mathbf{B}\mathbf{C})$, costs $knp + mkp$.
The two are equally expensive only by accident. Take a common case: $\mathbf{A}$ is a row vector ($m = 1$) and $\mathbf{B}, \mathbf{C}$ are square ($k = n = p = 1000$). Left-to-right costs $mkn + mnp = 10^6 + 10^6 = 2 \times 10^6$ multiplications. Right-to-left costs $knp + mkp = 10^9 + 10^6 \approx 10^9$. The difference is a factor of about 500. On a hot inner loop run thousands of times per training step, that is the difference between a model you can train overnight and a model you cannot train at all.
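The counts above follow directly from the two cost formulas. A small helper, chain_costs (hypothetical, written only for this illustration), reproduces them, and a numerical check on smaller matrices confirms that both parenthesisations give the same product and differ only in cost.

```python
import numpy as np

def chain_costs(m, k, n, p):
    """Scalar multiplications for the two parenthesisations of A @ B @ C,
    with A of shape (m, k), B of shape (k, n), C of shape (n, p)."""
    left = m * k * n + m * n * p    # (AB)C
    right = k * n * p + m * k * p   # A(BC)
    return left, right

left, right = chain_costs(m=1, k=1000, n=1000, p=1000)
print(left, right, right / left)    # 2000000 1001000000 500.5

# Same result either way (up to rounding), at very different cost.
rng = np.random.default_rng(0)
A = rng.standard_normal((1, 200))
B = rng.standard_normal((200, 200))
C = rng.standard_normal((200, 200))
assert np.allclose((A @ B) @ C, A @ (B @ C))
```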
Modern frameworks provide an automatic answer. The library opt_einsum treats the choice of contraction order as a search problem: for small expressions it can search exhaustively, or by dynamic programming, for an order that is optimal under a simple cost model, and for large ones it falls back to greedy heuristics. PyTorch uses it for torch.einsum automatically when it is installed, and NumPy's np.einsum(..., optimize=...) uses a path optimiser derived from it. For chains of three or four tensors a careful reader can usually spot the best order by hand, as in the example above; for the kinds of seven- or eight-tensor expressions that show up in tensor-network methods or in some attention variants, the difference between the naive order and the optimal order can be several orders of magnitude, and finding the best order by hand is impractical.
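NumPy exposes its path search through np.einsum_path, which returns the chosen sequence of pairwise contractions together with a printed cost report. This sketch poses a vector-matrix-matrix chain as a single einsum; the size n = 300 is an arbitrary choice that keeps the check fast, and optimize='optimal' requests an exhaustive search, which is only practical for a handful of operands.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 300
a = rng.standard_normal(n)          # plays the role of the row vector A
B = rng.standard_normal((n, n))
C = rng.standard_normal((n, n))

# Ask NumPy's path optimiser which pairwise contraction to perform first.
path, report = np.einsum_path('i,ij,jk->k', a, B, C, optimize='optimal')
print(path)     # ['einsum_path', (0, 1), (0, 1)]: contract a with B first
print(report)   # human-readable summary, including estimated FLOP counts

# optimize=True lets einsum follow such a path; NumPy's default is optimize=False.
result = np.einsum('i,ij,jk->k', a, B, C, optimize=True)
assert np.allclose(result, (a @ B) @ C)
```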
This is the deeper reason to prefer einsum over a chain of matmuls when an expression involves many tensors: einsum exposes the whole expression to the optimiser at once, and the optimiser can see the full graph of contractions and pick a good order. A chain of matmuls hides the structure inside Python's evaluation order and forecloses that choice.
Memory layout and efficiency
Two tensors with identical logical shape can occupy memory in different orders, and the order can matter for speed. PyTorch's default for image tensors is channels-first, (B, C, H, W), in which each channel is stored as a separate H by W plane. The alternative is channels-last, (B, H, W, C), which keeps all of a pixel's channels adjacent in memory and tends to be friendlier to the cache and to vectorised hardware on convolutions, especially at high resolution. The reason is that a convolution kernel sweeping across an image touches a small spatial neighbourhood of every channel at once; if those channels sit next to each other in memory, they can be fetched from a few contiguous cache lines rather than from widely separated planes.
Calling tensor.contiguous(memory_format=torch.channels_last) switches the layout without changing the logical shape; subsequent convolutions can then take a faster code path on supported hardware such as Ampere and Hopper GPUs. The improvement matters most for large tensors (high resolution, large batches) and is negligible on small inputs. As with most performance concerns, measure first; do not assume. Profile a representative training step, identify the genuine bottleneck, then change layout. A change that helps a ResNet-50 at $224 \times 224$ may do nothing for a small Transformer on tabular data.
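A minimal sketch of the layout switch, assuming PyTorch is installed. The tiny shape is chosen only so that the strides are easy to read; it says nothing about performance, which should be measured on real workloads.

```python
import torch

x = torch.randn(2, 3, 4, 5)                         # (B, C, H, W), channels-first
y = x.contiguous(memory_format=torch.channels_last)

print(x.shape == y.shape)                           # True: logical shape unchanged
print(x.stride())                                   # (60, 20, 5, 1)
print(y.stride())                                   # (60, 1, 15, 3): channels adjacent
print(y.is_contiguous(memory_format=torch.channels_last))  # True
print(torch.equal(x, y))                            # True: same values, new layout
```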
What you should take away
- Everything is a tensor. The shape is the most important thing about it. Calling .shape on every variable in a suspect line of code is the fastest debugging habit in deep learning.
- Broadcasting expands smaller tensors to match larger ones. Align shapes from the right; each axis pair must then be equal or contain a 1. Subtle bugs come from mistaking (N,) for (N, 1); reshape explicitly when intent is not obvious.
- Transpose and permute are metadata-only and free, as is reshape whenever the strides allow it, but a transpose or permute leaves the tensor non-contiguous. Insert .contiguous() before view whenever a transpose has happened.
- Einsum is the cleanest single operator for tensor manipulation. Indices on both sides are kept, indices only on the left are summed; matrix multiplication is 'ik,kj->ij', and attention scoring is one line.
- Contraction order can change the cost of a computation by hundreds of times or more. Let einsum and opt_einsum choose; do not write long chains of matmuls when a single einsum will reveal the structure to the optimiser.