Reinforcement Learning from Human Feedback (Christiano et al. 2017; Ouyang et al. 2022) aligns a base LLM in three sequential stages.
Stage 1, Supervised fine-tuning (SFT). Take a pretrained base model $\pi_{\text{base}}$ and fine-tune it on ~10-100k high-quality demonstrations $(x, y^*)$, minimising
$$\mathcal{L}_{\text{SFT}} = -\mathbb{E}_{(x, y^*)} \sum_{t} \log \pi_\theta(y^*_t \mid x, y^*_{\lt t}).$$
Standard causal-LM training (see [train-gpt2-recipe]), but with a lower learning rate than pretraining ($1\times 10^{-5}$, AdamW), 1-3 epochs, and the prompt tokens masked out of the loss, as sketched below. Output: $\pi_{\text{SFT}}$.
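A minimal sketch of the prompt-masked SFT loss, assuming a Hugging Face-style causal LM and (for simplicity) a shared `prompt_len` across the batch:

```python
import torch.nn.functional as F

def sft_loss(model, input_ids, prompt_len):
    """Causal-LM loss over demonstration tokens only; prompt tokens are masked."""
    logits = model(input_ids).logits      # [B, T, V]
    labels = input_ids.clone()
    labels[:, :prompt_len] = -100         # ignore_index: no loss on the prompt
    # Shift so position t predicts token t+1.
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )
```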
Stage 2, Reward model (RM). Collect pairwise preferences: for each prompt $x$, sample two completions $y_w, y_l$ (winner, loser) and have humans label which is better. Train a reward model $r_\phi(x, y) \in \mathbb{R}$, typically the SFT model with a scalar head, under the Bradley-Terry likelihood:
$$\mathcal{L}_{\text{RM}} = -\mathbb{E}_{(x, y_w, y_l)}\big[\log \sigma(r_\phi(x, y_w) - r_\phi(x, y_l))\big].$$
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

class RewardModel(nn.Module):
    def __init__(self, base):
        super().__init__()
        self.base = base  # SFT backbone
        self.v = nn.Linear(base.config.hidden_size, 1)  # scalar head

    def forward(self, ids, mask):
        h = self.base(ids, attention_mask=mask).last_hidden_state
        last = (mask.sum(1) - 1).long()  # index of the final non-pad token
        return self.v(h[torch.arange(h.size(0)), last]).squeeze(-1)

rm = RewardModel(load_sft())
opt = AdamW(rm.parameters(), lr=1e-5)
for batch in pref_loader:
    r_w = rm(batch.win_ids, batch.win_mask)    # reward for the preferred completion
    r_l = rm(batch.lose_ids, batch.lose_mask)  # reward for the rejected completion
    loss = -F.logsigmoid(r_w - r_l).mean()     # Bradley-Terry negative log-likelihood
    opt.zero_grad(); loss.backward(); opt.step()
```
Hyperparameters: 1 epoch, lr $1\times 10^{-5}$, batch 64, ~50k pairs. Validate on held-out preferences; expect ~70% pairwise accuracy.
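A quick held-out check of that pairwise accuracy, assuming a `val_loader` shaped like `pref_loader` above:

```python
# Fraction of held-out pairs where the RM scores the human-preferred completion higher.
rm.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch in val_loader:
        r_w = rm(batch.win_ids, batch.win_mask)
        r_l = rm(batch.lose_ids, batch.lose_mask)
        correct += (r_w > r_l).sum().item()
        total += r_w.numel()
print(f"pairwise accuracy: {correct / total:.3f}")  # ~0.70 is typical
```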
Stage 3, PPO with KL penalty. Initialise the policy $\pi_\theta \leftarrow \pi_{\text{SFT}}$ and the reference $\pi_{\text{ref}} \leftarrow \pi_{\text{SFT}}$ (frozen). For each prompt $x$, sample $y \sim \pi_\theta(\cdot \mid x)$ and compute a per-token reward
$$R_t = r_\phi(x, y) \cdot \mathbb{1}[t = T] - \beta \log \frac{\pi_\theta(y_t \mid x, y_{\lt t})}{\pi_{\text{ref}}(y_t \mid x, y_{\lt t})},$$
so the reward model fires only at the EOS, while the per-token KL keeps the policy near $\pi_{\text{SFT}}$. Compute advantages $\hat A_t$ via GAE, then optimise PPO's clipped surrogate
$$\mathcal{L}_{\text{PPO}} = -\mathbb{E}\Big[\min\big(\rho_t \hat A_t,\;\mathrm{clip}(\rho_t, 1\!-\!\varepsilon, 1\!+\!\varepsilon)\, \hat A_t\big)\Big], \quad \rho_t = \frac{\pi_\theta(y_t \mid x, y_{\lt t})}{\pi_{\text{old}}(y_t \mid x, y_{\lt t})}.$$
```python
import torch
from torch.optim import AdamW

# load_sft / load_rm / ValueHead / sample_prompts and the token-level helpers
# (generate_with_logp, token_logp, gae) are assumed utilities.
policy = load_sft()                          # trainable policy, init from SFT
ref = load_sft(); ref.requires_grad_(False)  # frozen reference pi_ref
rm = load_rm(); rm.requires_grad_(False)     # frozen reward model
value = ValueHead(policy)
opt = AdamW([
    {"params": policy.parameters(), "lr": 1e-6},  # policy lr
    {"params": value.parameters(), "lr": 1e-5},   # value lr
])
BETA, CLIP, GAMMA, LAMB = 0.05, 0.2, 1.0, 0.95
B, num_iters = 256, 5000  # rollout prompts per iteration, total iterations

for step in range(num_iters):
    # 1. Rollout: sample completions with the current policy. The log-probs
    #    recorded here serve as pi_old in the PPO ratio, so no separate
    #    policy snapshot is needed.
    prompts = sample_prompts(B)
    with torch.no_grad():
        ys, logp_old = policy.generate_with_logp(prompts, max_new=256)
        logp_ref = ref.token_logp(prompts, ys)
        r_term = rm(prompts, ys)  # scalar reward per trajectory

    # 2. Per-token reward = -beta * KL, plus the terminal RM reward.
    kl = logp_old - logp_ref  # [B, T] per-token KL estimate
    rewards = -BETA * kl
    # Assumes the terminal token sits in the last column; with padding,
    # index each sequence's true EOS position instead.
    rewards[:, -1] += r_term

    # 3. Values + GAE advantages.
    with torch.no_grad():
        V = value(prompts, ys)  # [B, T]
    A = gae(rewards, V, gamma=GAMMA, lam=LAMB)
    returns = A + V

    # 4. PPO update: 4 epochs over the rollout buffer.
    for _ in range(4):
        logp = policy.token_logp(prompts, ys)
        ratio = torch.exp(logp - logp_old)
        unclipped = ratio * A
        clipped = torch.clamp(ratio, 1 - CLIP, 1 + CLIP) * A
        policy_loss = -torch.min(unclipped, clipped).mean()
        # Clipped value loss (see Pitfalls): keep updates near the
        # rollout-time value estimates, mirroring the policy clip.
        v_pred = value(prompts, ys)
        v_clip = V + torch.clamp(v_pred - V, -CLIP, CLIP)
        value_loss = 0.5 * torch.max(
            (v_pred - returns).pow(2), (v_clip - returns).pow(2)
        ).mean()
        loss = policy_loss + 0.1 * value_loss
        opt.zero_grad(); loss.backward()
        torch.nn.utils.clip_grad_norm_(
            [p for g in opt.param_groups for p in g["params"]], 1.0
        )
        opt.step()
```
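The `gae` helper used above is assumed; a minimal sketch of generalised advantage estimation, swept right-to-left over the token dimension with a terminal value of zero:

```python
import torch

def gae(rewards, values, gamma=1.0, lam=0.95):
    """Generalised advantage estimation over per-token rewards."""
    B, T = rewards.shape
    adv = torch.zeros_like(rewards)
    last = torch.zeros(B, device=rewards.device)
    for t in reversed(range(T)):
        next_v = values[:, t + 1] if t + 1 < T else torch.zeros(B, device=rewards.device)
        delta = rewards[:, t] + gamma * next_v - values[:, t]  # TD error
        last = delta + gamma * lam * last
        adv[:, t] = last
    return adv
```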
Hyperparameters. $\beta \in [0.01, 0.1]$ (start at 0.05; raise if the policy drifts, lower if it stays glued to SFT). PPO clip $\varepsilon = 0.2$. Policy lr $1\times 10^{-6}$, value lr $1\times 10^{-5}$. Rollout batch 256-512 prompts × 256 tokens; 4 PPO epochs per rollout.
Compute. Rollouts dominate: sampling ~1B tokens over ~5k iterations takes roughly 8×A100 for 1-3 days with a 7B policy. Reward-model forward passes are smaller but non-trivial.
Pitfalls. Reward hacking: the policy finds RM exploits (verbose, sycophantic, bullet-pointed outputs). Mitigate by raising $\beta$, retraining the RM on adversarial examples, or switching to DPO (Rafailov et al. 2023), which optimises the preference objective in closed form with no PPO loop, as sketched below. A persistently negative KL estimate usually means $\pi_{\text{ref}}$ was not actually frozen; re-freeze it. Forgetting to clip the value-function loss (step 4 above) invites instability. PPO is famously unstable, so log the per-token KL, the RM mean reward, and the response length at every step.
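For reference, a minimal sketch of the DPO loss, reusing the Stage 2 preference pairs and the (assumed) `token_logp` helper from the PPO code; $\beta$ here plays the same KL-strength role as in PPO:

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy, ref, batch, beta=0.1):
    """DPO: -log sigmoid(beta * (policy log-ratio margin over the frozen ref))."""
    lp_w = policy.token_logp(batch.prompts, batch.wins).sum(-1)   # sum over tokens
    lp_l = policy.token_logp(batch.prompts, batch.loses).sum(-1)
    with torch.no_grad():  # frozen reference, as in Stage 3
        ref_w = ref.token_logp(batch.prompts, batch.wins).sum(-1)
        ref_l = ref.token_logp(batch.prompts, batch.loses).sum(-1)
    margin = (lp_w - ref_w) - (lp_l - ref_l)
    return -F.logsigmoid(beta * margin).mean()
```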
Related terms: RLHF, Transformer, Gradient Descent, Adam
Discussed in:
- Chapter 13: Attention & Transformers, Alignment and RLHF
- Chapter 12: Sequence Models, Reinforcement Learning