CLIP (Radford et al. 2021) jointly embeds images and text into a shared space by contrastive learning on ~400M (image, caption) pairs scraped from the web. Image encoder $f_v$ (ViT-B/32) and text encoder $f_t$ (small Transformer) each produce a unit-norm $d=512$ embedding. Inside a batch of $N$ pairs, image $i$ and text $i$ are positives; all $N-1$ other texts are negatives.
Logits matrix: $L_{ij} = \tau \cdot \langle f_v(I_i), f_t(T_j) \rangle$, where the temperature $\tau = \exp(s)$ is learned through the scalar $s$, initialised to $s = \log(1/0.07) \approx 2.66$ and clamped so that $\tau \le 100$ (equivalently $s \le \log 100$).
Symmetric InfoNCE loss:
$$\mathcal{L} = \tfrac{1}{2}\!\left[\mathrm{CE}(L, \mathrm{diag}) + \mathrm{CE}(L^\top, \mathrm{diag})\right] = -\tfrac{1}{2N}\sum_{i=1}^{N}\!\left[\log\frac{e^{L_{ii}}}{\sum_j e^{L_{ij}}} + \log\frac{e^{L_{ii}}}{\sum_j e^{L_{ji}}}\right].$$
The first term is image→text retrieval; the second is text→image. Both share the same logits matrix.
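Under the definitions above, the symmetric loss can be sketched in a few lines of PyTorch (toy embeddings, and `tau` fixed here rather than learned):

```python
import torch
import torch.nn.functional as F

def clip_loss(v, t, tau=100.0):
    """Symmetric InfoNCE over a batch of paired embeddings.

    v, t: [N, d] image / text embeddings (L2-normalised below).
    """
    v = F.normalize(v, dim=-1)
    t = F.normalize(t, dim=-1)
    logits = tau * v @ t.T               # L_ij = tau * <v_i, t_j>
    labels = torch.arange(v.size(0))     # positives sit on the diagonal
    return 0.5 * (F.cross_entropy(logits, labels)        # image -> text
                  + F.cross_entropy(logits.T, labels))   # text -> image

v = torch.randn(8, 512)
loss = clip_loss(v, v.clone())  # perfectly aligned pairs -> near-zero loss
```

With perfectly aligned pairs the diagonal logit is $\tau$ while off-diagonal cosines are small, so the loss is essentially zero; mismatched random pairs give a large but finite loss thanks to the stable log-softmax inside `cross_entropy`.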
Data pipeline. Web-scraped image-text pairs (LAION-400M / LAION-2B). Each shard is a WebDataset tar of (jpg, txt) pairs. Pipeline per pair:
- Decode image, RandomResizedCrop(224, scale=(0.9, 1.0)), RandomHorizontalFlip, normalise with CLIP statistics.
- Tokenise text with a 49,408-vocab byte-level BPE; truncate/pad to 77 tokens.
Architecture. (ViT and TextTransformer are the standard encoder stacks, defined elsewhere.)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class CLIP(nn.Module):
    def __init__(self, embed=512):
        super().__init__()
        self.visual = ViT(patch=32, layers=12, width=768, heads=12)
        self.transformer = TextTransformer(layers=12, width=512, heads=8,
                                           vocab=49408, max_len=77)
        self.v_proj = nn.Linear(768, embed, bias=False)
        self.t_proj = nn.Linear(512, embed, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(np.log(1 / 0.07)))  # ~2.66

    def encode_image(self, pixels):
        h = self.visual(pixels)        # CLS-token representation
        return F.normalize(self.v_proj(h), dim=-1)

    def encode_text(self, ids):
        h = self.transformer(ids)      # hidden state at the EOS token
        return F.normalize(self.t_proj(h), dim=-1)
Training loop with large-batch contrastive loss across GPUs.

model = CLIP().cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
opt = torch.optim.AdamW(model.parameters(), lr=5e-4, betas=(0.9, 0.98),
                        eps=1e-6, weight_decay=0.2)
sched = warmup_cosine(opt, warmup=2000, total=epochs * steps_per_epoch)
scaler = torch.amp.GradScaler()

for epoch in range(32):
    for imgs, texts in loader:
        imgs = imgs.cuda(non_blocking=True)
        texts = texts.cuda(non_blocking=True)
        with torch.amp.autocast("cuda", dtype=torch.float16):
            v = model.module.encode_image(imgs)   # [B, 512]
            t = model.module.encode_text(texts)   # [B, 512]
            # Gather across all GPUs to expand the negative set. This must
            # be a gradient-aware gather; plain torch.distributed.all_gather
            # does not backpropagate into the gathered tensors.
            v_all = all_gather(v)                 # [B * world, 512]
            t_all = all_gather(t)
            tau = model.module.logit_scale.exp().clamp(max=100.0)
            logits_v2t = tau * v_all @ t_all.T    # [N, N]
            logits_t2v = logits_v2t.T
            labels = torch.arange(v_all.size(0), device=v_all.device)
            loss = 0.5 * (F.cross_entropy(logits_v2t, labels)
                          + F.cross_entropy(logits_t2v, labels))
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()
        sched.step()
        opt.zero_grad()
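The `all_gather` used in the loop is not standard PyTorch: `torch.distributed.all_gather` breaks the autograd graph, so large-batch contrastive training needs a gradient-aware gather (`torch.distributed.nn.all_gather`, or a custom autograd Function as in OpenCLIP). A minimal sketch with a single-process fallback:

```python
import torch
import torch.distributed as dist

class AllGather(torch.autograd.Function):
    """all_gather that keeps gradients flowing back to the local shard."""

    @staticmethod
    def forward(ctx, x):
        out = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(out, x)
        ctx.rank = dist.get_rank()
        ctx.batch = x.size(0)
        return torch.cat(out, dim=0)

    @staticmethod
    def backward(ctx, grad):
        # Sum every rank's contribution, then keep only the slice that
        # corresponds to this rank's local batch.
        grad = grad.contiguous()
        dist.all_reduce(grad)
        return grad[ctx.rank * ctx.batch:(ctx.rank + 1) * ctx.batch]

def all_gather(x):
    if not (dist.is_available() and dist.is_initialized()):
        return x  # single-process fallback: nothing to gather
    return AllGather.apply(x)
```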
The all_gather trick is what makes CLIP scale: with 256 GPUs at micro-batch 128, every example sees 32,767 negatives (32,768 − 1). Gradient checkpointing on the encoders plus bf16 lets that micro-batch fit in 24-32 GB per GPU.
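Gradient checkpointing recomputes activations during the backward pass instead of storing them; a sketch with `torch.utils.checkpoint` on a toy residual encoder (illustrative names, not the actual ViT):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        return x + self.ff(x)

class Encoder(nn.Module):
    def __init__(self, d=64, n=4, use_ckpt=False):
        super().__init__()
        self.blocks = nn.ModuleList(Block(d) for _ in range(n))
        self.use_ckpt = use_ckpt

    def forward(self, x):
        for b in self.blocks:
            # With checkpointing, activations inside `b` are recomputed
            # during backward instead of kept in memory.
            x = checkpoint(b, x, use_reentrant=False) if self.use_ckpt else b(x)
        return x
```

Outputs and gradients are identical either way; only peak activation memory changes.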
Hyperparameters.
- AdamW, $\beta_1=0.9$, $\beta_2=0.98$, $\epsilon=10^{-6}$, weight decay 0.2.
- Peak LR $5\times 10^{-4}$, linear warmup 2000 steps, cosine decay to 0.
- Effective batch 32,768. 32 epochs over 400M pairs.
- logit_scale learned, clamped at $\log(100) \approx 4.6$ (i.e. $\tau \le 100$).
- Mixed precision (fp16 on V100 / bf16 on A100).
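The `warmup_cosine` helper referenced in the training loop is not a built-in PyTorch scheduler; one plausible implementation via `LambdaLR`:

```python
import math
import torch

def warmup_cosine(opt, warmup, total):
    """Linear warmup for `warmup` steps, then cosine decay to 0 at `total`."""
    def lr_lambda(step):
        if step < warmup:
            return step / max(1, warmup)
        progress = (step - warmup) / max(1, total - warmup)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
```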
Compute. ViT-B/32 forward+backward at 224 ≈ 8.7 GFLOPs; text encoder ~3 GFLOPs. 32 epochs × 400M pairs × ~12 GFLOPs ≈ $1.5\!\times\!10^{20}$ FLOPs. On 256×V100, ~12 days; for reference, the original CLIP paper reports 18 days on 592 V100s for the largest ResNet (RN50x64) and 12 days on 256 V100s for ViT-L/14.
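A quick check of the back-of-envelope arithmetic:

```python
pairs = 400e6                 # training pairs per epoch
epochs = 32
flops_per_pair = 12e9         # ~8.7 GFLOPs vision + ~3 GFLOPs text, fwd+bwd
total = epochs * pairs * flops_per_pair
print(f"{total:.2e}")         # 1.54e+20
```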
Evaluation. Zero-shot ImageNet: encode the 1000 class names as "a photo of a {class}",
compute $\arg\max_c \langle f_v(I), f_t(T_c) \rangle$. Target: ViT-B/32 ≈ 63.2%
top-1 zero-shot.
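Once embeddings are in hand, zero-shot classification reduces to a cosine-similarity argmax. A sketch with stand-in embeddings (no real encoders here; `class_embs` plays the role of $f_t(T_c)$ for each prompted class name):

```python
import torch
import torch.nn.functional as F

def zero_shot_classify(image_emb, class_embs):
    """image_emb: [d]; class_embs: [C, d]. Both assumed L2-normalised."""
    return int((class_embs @ image_emb).argmax())

# Illustrative stand-ins for f_t("a photo of a {class}") and f_v(image):
classes = ["cat", "dog", "car"]
class_embs = F.normalize(torch.randn(len(classes), 512), dim=-1)
image_emb = class_embs[1]  # pretend the image embeds exactly where "dog" does
print(classes[zero_shot_classify(image_emb, class_embs)])  # dog
```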
Pitfalls. Letting $\tau$ climb above 100 collapses the softmax toward a delta and gradients vanish; always clamp. Forgetting all_gather keeps the effective negative pool at the per-GPU batch (~128), costing roughly 10 zero-shot points. Skipping L2-normalisation of the embeddings before the dot product lets their norms absorb the temperature, making it redundant and training unstable. Web data is noisy: filtering on CLIP score (bootstrapped with a smaller model) is worth 5-10% gains.
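A quick demonstration of the normalisation point: raw dot products scale with the embedding norms, while cosine similarities stay in $[-1, 1]$, so only the latter give the temperature a consistent range to act on.

```python
import torch
import torch.nn.functional as F

v = 10.0 * torch.randn(4, 512)   # un-normalised embeddings with large norms
t = 0.1 * torch.randn(4, 512)

raw = v @ t.T                                             # norm-dependent scale
cos = F.normalize(v, dim=-1) @ F.normalize(t, dim=-1).T   # bounded in [-1, 1]

# tau * cos stays in [-tau, tau]; tau * raw has no fixed range.
```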