
Train a CLIP-style Multimodal Model

CLIP (Radford et al. 2021) jointly embeds images and text into a shared space by contrastive learning on ~400M (image, caption) pairs scraped from the web. The image encoder $f_v$ (ViT-B/32) and the text encoder $f_t$ (a small Transformer) each produce a unit-norm $d=512$ embedding. Within a batch of $N$ pairs, image $i$ and text $i$ form the positive pair; for image $i$ the other $N-1$ texts are negatives, and likewise for text $i$ the other $N-1$ images.

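Logits matrix: $L_{ij} = \tau \cdot \langle f_v(I_i), f_t(T_j) \rangle$, where $\tau = \exp(s)$ is a learned logit scale (an inverse temperature). The trainable parameter $s$ (logit_scale in the code) is initialised to $\log(1/0.07) \approx 2.66$, so $\tau$ starts at $1/0.07 \approx 14.3$, and $s$ is clamped at $\log(100) \approx 4.6$ so that $\tau \le 100$.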
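To make the parameterisation concrete, the short sketch below evaluates $\tau = \exp(s)$ at the initial value of the learned log-scale $s$ and at a value past the clamp. It is plain Python mirroring the clamp-then-exponentiate rule; the function name `logit_scale` is illustrative, not part of any library.

```python
import math

def logit_scale(s: float, max_scale: float = 100.0) -> float:
    """tau = exp(s), with s clamped at log(max_scale) so that tau <= max_scale."""
    return math.exp(min(s, math.log(max_scale)))

s0 = math.log(1 / 0.07)                           # initial value of the learned parameter s
print(round(s0, 2), round(logit_scale(s0), 2))    # 2.66 14.29
print(round(logit_scale(5.0), 6))                 # 100.0, since 5.0 > log(100) ≈ 4.61
```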

Symmetric InfoNCE loss:

$$\mathcal{L} = \tfrac{1}{2}\!\left[\mathrm{CE}(L, \mathrm{diag}) + \mathrm{CE}(L^\top, \mathrm{diag})\right] = -\tfrac{1}{2N}\sum_{i=1}^{N}\!\left[\log\frac{e^{L_{ii}}}{\sum_j e^{L_{ij}}} + \log\frac{e^{L_{ii}}}{\sum_j e^{L_{ji}}}\right].$$

The first term is image→text retrieval; the second is text→image. Both share the same logits matrix.
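A dependency-free sketch of the symmetric loss, written directly from the equation: it takes a precomputed logits matrix $L$ as nested lists, applies a numerically stable log-softmax along rows (image→text) and along columns (text→image), and averages the two. Plain Python is used only so the arithmetic is easy to check; in a training loop the same computation is two `F.cross_entropy` calls on a tensor and its transpose.

```python
import math

def log_softmax_at(row, k):
    """log(exp(row[k]) / sum_j exp(row[j])), shifted by max(row) for stability."""
    m = max(row)
    return row[k] - m - math.log(sum(math.exp(x - m) for x in row))

def symmetric_infonce(L):
    """Symmetric InfoNCE for an N x N logits matrix L; positives lie on the diagonal."""
    n = len(L)
    cols = [[L[j][i] for j in range(n)] for i in range(n)]       # rows of L^T
    img2txt = sum(log_softmax_at(L[i], i) for i in range(n))     # softmax over texts
    txt2img = sum(log_softmax_at(cols[i], i) for i in range(n))  # softmax over images
    return -(img2txt + txt2img) / (2 * n)

# All-equal logits carry no signal, so the loss equals log N (chance level).
print(symmetric_infonce([[0.0] * 4 for _ in range(4)]))  # ≈ 1.3863 = log 4
```

Strongly diagonal logits drive the loss towards 0, while uniform logits give $\log N$, which is a useful sanity check for a fresh training run.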

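Data pipeline. OpenAI's original model was trained on a private 400M-pair dataset (WIT); open reproductions use web-scraped image-text pairs from LAION-400M / LAION-2B. Each shard is a WebDataset tar of (jpg, txt) pairs. Pipeline per pair: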

  1. Decode the image, apply RandomResizedCrop(224, scale=(0.9, 1.0)) and RandomHorizontalFlip, then normalise with the CLIP per-channel statistics (mean (0.4815, 0.4578, 0.4082), std (0.2686, 0.2613, 0.2758)).
  2. Tokenise text with a 49,408-vocab byte-level BPE; truncate/pad to 77 tokens.
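Step 2 also fixes which position the text encoder later pools. The sketch below assumes the BPE step has already produced integer ids and that, as in the released CLIP tokenizer, `<|startoftext|>` and `<|endoftext|>` are ids 49406 and 49407 (the last two of the 49,408) with zero padding; `pack_tokens` and `eos_position` are hypothetical helper names. Because the end-of-text id is the largest in the vocabulary, its position is simply the argmax of the id sequence, which is how the reference CLIP code locates the EOS state.

```python
SOT, EOT, PAD = 49406, 49407, 0   # assumed special-token ids (released CLIP tokenizer)
CONTEXT_LEN = 77

def pack_tokens(bpe_ids, context_len=CONTEXT_LEN):
    """Wrap BPE ids in SOT/EOT, truncate to context_len (keeping EOT last), then pad."""
    tokens = [SOT] + list(bpe_ids) + [EOT]
    if len(tokens) > context_len:
        tokens = tokens[:context_len]
        tokens[-1] = EOT                      # a truncated caption still ends in EOT
    return tokens + [PAD] * (context_len - len(tokens))

def eos_position(tokens):
    """Index of EOT; it has the largest id, so this is the argmax of the sequence."""
    return max(range(len(tokens)), key=tokens.__getitem__)

short = pack_tokens([320, 1125, 539])         # made-up BPE ids for a 3-token caption
print(len(short), short[:5], eos_position(short))  # 77 [49406, 320, 1125, 539, 49407] 4
```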

Architecture.

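```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# ViT and TextTransformer are assumed encoder modules (not defined here): ViT returns
# the CLS-token state [B, 768]; TextTransformer returns the EOS-token state [B, 512].
class CLIP(nn.Module):
    def __init__(self, embed=512):
        super().__init__()
        self.visual = ViT(patch=32, layers=12, width=768, heads=12)
        self.transformer = TextTransformer(layers=12, width=512, heads=8, vocab=49408,
                                           max_len=77)
        self.v_proj = nn.Linear(768, embed, bias=False)
        self.t_proj = nn.Linear(512, embed, bias=False)
        # Learned log-scale s: log(1/0.07) ≈ 2.66, so tau = exp(s) starts at ≈ 14.3
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07)))
    def encode_image(self, pixels):
        h = self.visual(pixels)                  # [B, 768] CLS-token state
        return F.normalize(self.v_proj(h), dim=-1)
    def encode_text(self, ids):
        h = self.transformer(ids)                # [B, 512] EOS-token state
        return F.normalize(self.t_proj(h), dim=-1)
    def forward(self, pixels, ids):
        # Running the encoders inside forward() lets the DDP wrapper sync their gradients
        return self.encode_image(pixels), self.encode_text(ids), self.logit_scale.exp()
```

Training loop with a large-batch contrastive loss across GPUs. The encoders are called through the DDP wrapper, model(imgs, texts), because calling model.module directly skips DDP's gradient synchronisation.

```python
import torch.distributed as dist

def all_gather(x):
    """Concatenate x from every rank along dim 0. dist.all_gather outputs carry no
    autograd history, so the local slot is replaced by x to keep its gradient."""
    out = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(out, x.contiguous())
    out[dist.get_rank()] = x
    return torch.cat(out, dim=0)

# Assumes the process group is already initialised (e.g. by torchrun) and that
# local_rank, loader and warmup_cosine (linear warmup + cosine decay scheduler)
# are defined elsewhere; they are placeholders, not library APIs.
epochs, steps_per_epoch = 32, 400_000_000 // 32_768      # ≈ 12,207 steps per epoch
model = nn.parallel.DistributedDataParallel(CLIP().cuda(), device_ids=[local_rank])
decay = [p for p in model.parameters() if p.ndim >= 2]      # weight matrices
no_decay = [p for p in model.parameters() if p.ndim < 2]    # biases, LN gains, logit_scale
opt = torch.optim.AdamW([{"params": decay, "weight_decay": 0.2},
                         {"params": no_decay, "weight_decay": 0.0}],
                        lr=5e-4, betas=(0.9, 0.98), eps=1e-6)
sched = warmup_cosine(opt, warmup=2000, total=epochs * steps_per_epoch)
scaler = torch.amp.GradScaler("cuda")        # loss scaling for fp16 (not needed with bf16)

for epoch in range(epochs):
    for imgs, texts in loader:
        imgs, texts = imgs.cuda(non_blocking=True), texts.cuda(non_blocking=True)
        with torch.amp.autocast("cuda", dtype=torch.float16):
            v, t, tau = model(imgs, texts)                  # [B, 512], [B, 512], scalar

            # Gather across all GPUs to expand the negative set
            v_all, t_all = all_gather(v), all_gather(t)     # [N, 512], N = B * world
            logits_v2t = tau * v_all @ t_all.T              # [N, N]; rows are images
            labels = torch.arange(v_all.size(0), device=v.device)
            loss = 0.5 * (F.cross_entropy(logits_v2t, labels)
                          + F.cross_entropy(logits_v2t.T, labels))

        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt); scaler.update(); sched.step(); opt.zero_grad()
        with torch.no_grad():                    # keep s <= log(100), i.e. tau <= 100
            model.module.logit_scale.clamp_(0, math.log(100))
```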

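The all_gather trick is what makes CLIP scale: with 256 GPUs at a per-GPU micro-batch of 128, the global batch is 256 × 128 = 32,768, so every example is contrasted against 32,767 negatives. Gradient checkpointing on both encoders plus mixed precision (fp16 or bf16) lets a micro-batch of 128 fit in 24-32 GB of GPU memory.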
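The figures in this paragraph follow from simple arithmetic, and the same arithmetic shows the memory cost of the full-matrix loss in the training loop: every GPU materialises an $N \times N$ logits matrix, before counting the same-size softmax and gradient buffers. The sketch assumes fp32 logits; `contrastive_batch_stats` is a hypothetical helper. Open-source reproductions such as OpenCLIP offer a "local loss" option in which each rank computes only its own $B \times N$ rows, with labels offset by rank × B, which is the second figure computed here.

```python
def contrastive_batch_stats(gpus, micro_batch, bytes_per_logit=4):
    """Global batch, negatives per example, and logits memory per GPU (fp32 by default)."""
    n = gpus * micro_batch                           # global batch N after all_gather
    negatives = n - 1                                # every other caption (or image)
    full_bytes = n * n * bytes_per_logit             # one N x N logits matrix
    local_bytes = micro_batch * n * bytes_per_logit  # only this rank's B x N rows
    return n, negatives, full_bytes, local_bytes

n, neg, full_b, local_b = contrastive_batch_stats(gpus=256, micro_batch=128)
print(n, neg, full_b / 2**30, "GiB", local_b / 2**20, "MiB")  # 32768 32767 4.0 GiB 16.0 MiB
```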

Hyperparameters.

  • AdamW, $\beta_1=0.9$, $\beta_2=0.98$, $\epsilon=10^{-6}$, weight decay 0.2 (applied to all weights except gains and biases).
  • Peak LR $5\times 10^{-4}$, linear warmup 2000 steps, cosine decay to 0.
  • Effective batch 32,768. 32 epochs over 400M pairs.
  • logit_scale learned, clamped at $\log(100) \approx 4.6$.
  • Mixed precision (fp16 on V100 / bf16 on A100).
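The schedule in these bullets can be written as a plain function of the step index. The step count follows from the bullets: 32 epochs × 400M pairs / 32,768 pairs per step = 390,625 optimizer steps. The sketch below is one possible shape for the `warmup_cosine` helper that the training loop assumes (that helper is not a PyTorch API, and `lr_at` is a hypothetical name). In PyTorch, a multiplier such as `lambda step: lr_at(step) / PEAK_LR` could be passed to `torch.optim.lr_scheduler.LambdaLR`.

```python
import math

PAIRS, EPOCHS, BATCH = 400_000_000, 32, 32_768
TOTAL_STEPS = EPOCHS * PAIRS // BATCH          # 390,625 optimizer steps
WARMUP, PEAK_LR = 2_000, 5e-4

def lr_at(step, peak=PEAK_LR, warmup=WARMUP, total=TOTAL_STEPS):
    """Linear warmup from 0 to peak, then cosine decay to 0 at `total`."""
    if step < warmup:
        return peak * step / warmup
    progress = min(1.0, (step - warmup) / (total - warmup))
    return 0.5 * peak * (1.0 + math.cos(math.pi * progress))

print(TOTAL_STEPS, lr_at(WARMUP), lr_at(TOTAL_STEPS))  # 390625 0.0005 0.0
```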

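Compute. A ViT-B/32 forward pass at 224 px costs ≈ 4.4 GMACs (≈ 8.8 GFLOPs) and the 77-token text encoder ≈ 3 GMACs (≈ 6 GFLOPs). Counting the backward pass as roughly twice the forward, each pair costs ≈ 3 × 14.8 ≈ 44 GFLOPs, so 32 epochs × 400M pairs × 44 GFLOPs ≈ $5.7\!\times\!10^{20}$ FLOPs. For scale, the original CLIP paper reports that its largest ResNet (RN50x64) took 18 days on 592 V100s and its largest ViT (ViT-L/14) took 12 days on 256 V100s; ViT-B/32 is far cheaper than either.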
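The per-pair cost can be estimated with the usual back-of-envelope count for a Transformer forward pass: about $12w^2$ multiply-accumulates per token per layer (QKV and output projections $4w^2$, 4×-wide MLP $8w^2$), plus $2T^2w$ per layer for the attention scores and the weighted sum over values, plus the patch projection for the ViT. Biases, norms, softmax and the final 512-d projections are ignored, so treat the result as an estimate. `transformer_macs` is a hypothetical helper, and the 3× training multiplier is the common rule of thumb that the backward pass costs about twice the forward.

```python
def transformer_macs(layers, tokens, width, patch_in=0):
    """Approximate multiply-accumulates (MACs) for one forward pass."""
    dense = 12 * width * width * tokens * layers   # QKV + output proj (4w^2) + MLP (8w^2)
    attn = 2 * tokens * tokens * width * layers    # QK^T scores and attention-weighted V
    patch = (tokens - 1) * patch_in * width        # ViT patch embedding (0 for text)
    return dense + attn + patch

img = transformer_macs(layers=12, tokens=50, width=768, patch_in=3 * 32 * 32)  # 49 patches + CLS
txt = transformer_macs(layers=12, tokens=77, width=512)
fwd_flops = 2 * (img + txt)            # 1 MAC = 2 FLOPs
train_flops = 3 * fwd_flops            # forward + backward (backward ~2x forward)
total = 32 * 400_000_000 * train_flops
print(f"{img / 1e9:.2f} {txt / 1e9:.2f} GMACs; {train_flops / 1e9:.1f} GFLOPs/pair; {total:.1e} FLOPs")
# 4.41 2.98 GMACs; 44.3 GFLOPs/pair; 5.7e+20 FLOPs
```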

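Evaluation. Zero-shot ImageNet: encode each of the 1000 class names with the prompt "a photo of a {class}" to get $T_c$, then predict $\hat{c} = \arg\max_c \langle f_v(I), f_t(T_c) \rangle$. Target: ViT-B/32 ≈ 63.2% top-1 zero-shot. That is the CLIP paper's figure, which ensembles many prompt templates; a single template scores a few points lower.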
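A minimal sketch of the zero-shot rule, with `encode_text` standing in for $f_t$. It is an assumed callable, here a toy dictionary of made-up 3-d vectors rather than a trained model, and `zero_shot_classify` is a hypothetical name. Both sides are L2-normalised so the dot product is a cosine similarity; the logit scale $\tau$ multiplies every score by the same positive factor and so does not change the argmax.

```python
import math

def l2_normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def zero_shot_classify(image_emb, class_names, encode_text, template="a photo of a {}"):
    """Return the class whose prompt embedding has the highest cosine similarity."""
    img = l2_normalize(image_emb)
    best, best_score = None, -math.inf
    for name in class_names:
        txt = l2_normalize(encode_text(template.format(name)))
        score = sum(a * b for a, b in zip(img, txt))
        if score > best_score:
            best, best_score = name, score
    return best

# Toy stand-in for f_t: fixed vectors per prompt (purely illustrative, not learned).
toy = {"a photo of a cat": [1.0, 0.1, 0.0], "a photo of a dog": [0.0, 1.0, 0.2],
       "a photo of a car": [0.1, 0.0, 1.0]}
print(zero_shot_classify([0.9, 0.2, 0.05], ["cat", "dog", "car"], toy.__getitem__))  # cat
```

In a real evaluation the 1000 class embeddings are computed once and reused for every image, so the per-image cost is a single image-encoder pass plus a matrix-vector product.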

Pitfalls.

  • Letting $\tau$ exceed 100 pushes the softmax towards a one-hot distribution, so gradients vanish; always clamp it.
  • Forgetting all_gather shrinks the negative pool to the per-GPU batch (~128 instead of 32,768), which costs roughly 10 zero-shot points.
  • Skipping L2 normalisation before the dot product lets embedding norms grow freely, which makes the learned temperature redundant and training unstable.
  • Web data is noisy: filtering pairs on CLIP score (bootstrapped from a smaller model) gives 5-10% gains.
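The CLIP-score filter keeps a pair only if a smaller pretrained CLIP model rates its image and caption embeddings as similar enough. A sketch in plain Python over precomputed embeddings; the 0.28 default is an assumption, chosen because LAION reports cut-offs of roughly 0.28 to 0.3 with CLIP ViT-B/32, and `filter_by_clip_score` is a hypothetical name.

```python
import math

def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def filter_by_clip_score(pairs, threshold=0.28):
    """Keep (image_emb, text_emb, record) triples whose CLIP score is >= threshold."""
    return [p for p in pairs if cosine(p[0], p[1]) >= threshold]

pairs = [([1.0, 0.0], [0.9, 0.1], "matching caption"),
         ([1.0, 0.0], [0.1, 0.9], "unrelated caption")]
print([rec for _, _, rec in filter_by_clip_score(pairs)])   # ['matching caption']
```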

Related terms: CLIP, InfoNCE, Transformer, Attention Mechanism, Adam, Mixed Precision Training


This site is currently in Beta. Contact: Chris Paton



AI tools used: Claude (research, coding, text), ChatGPT (diagrams, images), Grammarly (editing).