
Optimization for Machine Learning

Gradient Descent Variants

Stochastic Gradient Descent (SGD)

Update parameters using gradient estimated from a mini-batch:

theta_{t+1} = theta_t - lr * g_t
where g_t = (1/|B|) * sum_{i in B} grad L(theta_t; x_i, y_i)

Properties:

  • Noisy gradients help escape local minima and saddle points
  • Mini-batch size B trades variance for computation: B=32..512 typical
  • Larger B: more stable updates but often worse generalization, commonly attributed to convergence to sharp minima
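
The mini-batch update above can be sketched in a few lines of numpy (`sgd_step` and `grad_fn` are illustrative names, not from any library):

```python
import numpy as np

def sgd_step(theta, grad_fn, X, y, lr=0.01, batch_size=32, rng=None):
    """One SGD step: estimate the gradient from a random mini-batch."""
    if rng is None:
        rng = np.random.default_rng()
    idx = rng.choice(len(X), size=batch_size, replace=False)
    g = grad_fn(theta, X[idx], y[idx])  # mini-batch gradient estimate
    return theta - lr * g
```

`grad_fn` stands in for whatever computes the mean gradient of the loss over the batch.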

SGD with Momentum

Accumulate a velocity vector to dampen oscillations:

v_{t+1} = beta * v_t + g_t               # beta typically 0.9
theta_{t+1} = theta_t - lr * v_{t+1}

Nesterov momentum (look-ahead gradient):

v_{t+1} = beta * v_t + grad L(theta_t - lr * beta * v_t)
theta_{t+1} = theta_t - lr * v_{t+1}

Nesterov achieves the accelerated O(1/t^2) convergence rate on smooth convex functions, versus O(1/t) for plain gradient descent.
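
Both momentum variants can be sketched together (a minimal sketch; `momentum_step` and `grad_fn` are illustrative names):

```python
def momentum_step(theta, v, grad_fn, lr=0.01, beta=0.9, nesterov=False):
    """Classical or Nesterov momentum update, matching the equations above."""
    if nesterov:
        g = grad_fn(theta - lr * beta * v)  # look-ahead gradient
    else:
        g = grad_fn(theta)
    v = beta * v + g
    theta = theta - lr * v
    return theta, v
```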

Adaptive Learning Rate Methods

AdaGrad

Adapt learning rate per parameter based on accumulated squared gradients:

G_{t+1} = G_t + g_t^2              # element-wise
theta_{t+1} = theta_t - lr / sqrt(G_{t+1} + epsilon) * g_t

Problem: learning rate monotonically decreases and eventually becomes too small.

RMSProp

Fix AdaGrad's decay using exponential moving average:

v_t = beta * v_{t-1} + (1 - beta) * g_t^2
theta_{t+1} = theta_t - lr / sqrt(v_t + epsilon) * g_t

beta = 0.9 typical.
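
A minimal numpy sketch of the update (`rmsprop_step` is an illustrative name; the 0.9 decay here follows Hinton's original lecture-note value):

```python
import numpy as np

def rmsprop_step(theta, v, g, lr=1e-3, beta=0.9, eps=1e-8):
    """RMSProp: scale the step by an EMA of squared gradients."""
    v = beta * v + (1 - beta) * g**2
    theta = theta - lr * g / np.sqrt(v + eps)
    return theta, v
```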

Adam (Adaptive Moment Estimation)

Combines momentum (first moment) and RMSProp (second moment):

m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t          # first moment
v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2        # second moment

# Bias correction (critical for early steps)
m_hat_t = m_t / (1 - beta_1^t)
v_hat_t = v_t / (1 - beta_2^t)

theta_{t+1} = theta_t - lr * m_hat_t / (sqrt(v_hat_t) + epsilon)

Default hyperparameters: beta_1=0.9, beta_2=0.999, epsilon=1e-8, lr=3e-4.

import numpy as np

def adam(params, grads, state, lr=3e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """In-place Adam update; `state` persists the moments across calls."""
    if 't' not in state:
        state['t'] = 0
        state['m'] = [np.zeros_like(p) for p in params]
        state['v'] = [np.zeros_like(p) for p in params]

    state['t'] += 1
    t = state['t']

    for i, (p, g) in enumerate(zip(params, grads)):
        # Biased first and second moment estimates
        state['m'][i] = beta1 * state['m'][i] + (1 - beta1) * g
        state['v'][i] = beta2 * state['v'][i] + (1 - beta2) * g**2

        # Bias correction (critical for early steps)
        m_hat = state['m'][i] / (1 - beta1**t)
        v_hat = state['v'][i] / (1 - beta2**t)

        p -= lr * m_hat / (np.sqrt(v_hat) + eps)  # in-place parameter update

AdamW (Decoupled Weight Decay)

Adam's L2 regularization interacts poorly with adaptive learning rates. AdamW decouples weight decay:

theta_{t+1} = theta_t - lr * (m_hat_t / (sqrt(v_hat_t) + epsilon) + lambda * theta_t)

This is NOT equivalent to Adam + L2. AdamW applies weight decay directly to parameters, not through the gradient. Preferred for training transformers.
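
A single-parameter sketch makes the decoupling visible: the `wd * p` term is added after the adaptive scaling, not folded into the gradient (`adamw_step` is an illustrative name):

```python
import numpy as np

def adamw_step(p, g, m, v, t, lr=3e-4, beta1=0.9, beta2=0.999,
               eps=1e-8, wd=0.01):
    """One AdamW step: the weight-decay term bypasses the adaptive scaling."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    p = p - lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * p)  # decay decoupled
    return p, m, v
```

With a zero gradient the parameter still shrinks geometrically by `(1 - lr * wd)` per step, which Adam + L2 would not do once the second moment rescales the decay term.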

Learning Rate Warmup

Start with a very small learning rate and linearly increase to the target:

lr_t = lr_target * min(1, t / warmup_steps)

Why it helps:

  • Early gradients are large and noisy (random weights)
  • Adam's second moment estimates are inaccurate initially
  • Prevents early divergence, especially in large models

Common: warmup for 1-5% of total training steps.
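
The warmup formula above, combined with a cosine decay afterward (one common choice; `lr_schedule` is an illustrative name), looks like:

```python
import math

def lr_schedule(step, total_steps, lr_target, warmup_steps):
    """Linear warmup to lr_target, then cosine decay to zero."""
    if step < warmup_steps:
        return lr_target * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return lr_target * 0.5 * (1 + math.cos(math.pi * progress))
```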

Gradient Clipping

Prevent exploding gradients by capping gradient norms:

Gradient norm clipping (most common):

if ||g|| > max_norm:
    g = g * max_norm / ||g||

Gradient value clipping:

g = clip(g, -max_value, max_value)

A reference implementation of norm clipping in PyTorch style (each parameter carries a `.grad` tensor):

import math

def clip_grad_norm(parameters, max_norm):
    """Rescale all gradients so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(float(p.grad.norm()) ** 2 for p in parameters))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for p in parameters:
            p.grad *= scale
    return total_norm

Typical max_norm: 1.0 for transformers, 5.0 for RNNs.

Second-Order Methods

Newton's Method

theta_{t+1} = theta_t - H^{-1} g_t

where H is the Hessian of the loss. Quadratic convergence near the optimum, but O(d^3) per step and O(d^2) memory for H.

L-BFGS

Approximate the inverse Hessian using the last m gradient differences:

  • Stores m vectors of size d (memory O(md) vs O(d^2))
  • Very effective for convex problems with moderate d
  • Not suitable for stochastic settings (needs full gradients)

Natural Gradient

Use the Fisher information matrix F instead of Hessian:

theta_{t+1} = theta_t - lr * F^{-1} g_t

Accounts for the geometry of the parameter space. K-FAC approximates F^{-1} efficiently using Kronecker structure.

Sharpness-Aware Minimization (SAM)

Seek parameters in flat loss regions for better generalization:

epsilon* = rho * g / ||g||
theta_{t+1} = theta_t - lr * grad L(theta_t + epsilon*)

Compute gradient at a perturbed point, then update from the original point.
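
A minimal sketch of one SAM step (two gradient evaluations per update; `sam_step` and `grad_fn` are illustrative names):

```python
import numpy as np

def sam_step(theta, grad_fn, lr=0.01, rho=0.05):
    """SAM: perturb toward higher loss, take the gradient there,
    then update the original parameters."""
    g = grad_fn(theta)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # ascent perturbation epsilon*
    g_sam = grad_fn(theta + eps)                 # gradient at perturbed point
    return theta - lr * g_sam
```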

Hyperparameter Optimization

Grid Search

Evaluate all combinations from a predefined grid. Exhaustive, but the number of experiments grows exponentially with the number of hyperparameters.

param_grid = {
    'lr': [1e-4, 1e-3, 1e-2],
    'hidden_size': [64, 128, 256],
    'dropout': [0.1, 0.3, 0.5]
}
# 3 * 3 * 3 = 27 experiments

Random Search

Sample hyperparameters from distributions. More efficient than grid search because:

  • Important hyperparameters get better coverage
  • Each trial explores a unique value for each parameter
  • Bergstra & Bengio (2012): random search finds good configurations with fewer trials

param_distributions = {
    'lr': loguniform(1e-5, 1e-1),
    'hidden_size': randint(32, 512),
    'dropout': uniform(0.0, 0.5),
    'weight_decay': loguniform(1e-6, 1e-2)
}
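
The distributions above use scipy.stats naming; a dependency-free sketch of the same sampling with numpy (`sample_config` is an illustrative name):

```python
import numpy as np

def sample_config(rng):
    """Draw one random-search configuration."""
    return {
        'lr': 10 ** rng.uniform(-5, -1),            # log-uniform [1e-5, 1e-1)
        'hidden_size': int(rng.integers(32, 512)),  # uniform integer [32, 512)
        'dropout': rng.uniform(0.0, 0.5),
        'weight_decay': 10 ** rng.uniform(-6, -2),  # log-uniform [1e-6, 1e-2)
    }

rng = np.random.default_rng(0)
trials = [sample_config(rng) for _ in range(20)]
```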

Bayesian Optimization

Model the objective function with a surrogate (typically Gaussian Process), then use an acquisition function to decide where to evaluate next.

Surrogate: f(x) ~ GP(mu(x), k(x, x'))
Acquisition: alpha(x) = E[max(0, f(x) - f_best)]    # Expected Improvement

Loop:

  1. Fit GP to observed (hyperparams, score) pairs
  2. Maximize acquisition function to find next candidate
  3. Evaluate candidate and add to observations
  4. Repeat

Tools: Optuna, Ax, Hyperopt. More sample-efficient than random search for expensive evaluations.
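
The Expected Improvement acquisition has a closed form for a Gaussian posterior. Assuming the surrogate supplies a posterior mean `mu` and standard deviation `sigma` at a candidate point (`expected_improvement` is an illustrative name):

```python
import math

def expected_improvement(mu, sigma, f_best):
    """Closed-form EI for a Gaussian posterior (maximization)."""
    if sigma <= 0:
        return max(0.0, mu - f_best)
    z = (mu - f_best) / sigma
    phi = math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)  # std normal pdf
    Phi = 0.5 * (1 + math.erf(z / math.sqrt(2)))           # std normal cdf
    return (mu - f_best) * Phi + sigma * phi
```

Note that EI grows with `sigma`: uncertain regions get explored even when their mean looks unremarkable.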

Successive Halving and Hyperband

Efficient early stopping of unpromising configurations.

Successive Halving:

  1. Start n configurations with a small budget (e.g., 1 epoch)
  2. Evaluate all, keep top 1/eta fraction
  3. Multiply the budget by eta, repeat until one configuration remains
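
These steps can be sketched directly (here `evaluate(config, budget)` returns a score, higher is better; names are illustrative):

```python
def successive_halving(configs, evaluate, budget=1, eta=3):
    """Evaluate all configs, keep the top 1/eta, grow the budget by eta."""
    while len(configs) > 1:
        scores = [(evaluate(c, budget), c) for c in configs]
        scores.sort(reverse=True, key=lambda s: s[0])  # best score first
        configs = [c for _, c in scores[:max(1, len(configs) // eta)]]
        budget *= eta
    return configs[0]
```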

Hyperband: run Successive Halving with different initial budgets to handle the exploration-exploitation tradeoff. Uses brackets with varying numbers of initial configurations and budgets.

Given: max budget R, reduction factor eta (typically 3)
s_max = floor(log_eta(R))

For s in {s_max, s_max-1, ..., 0}:
    n = ceil((s_max + 1) / (s + 1) * eta^s)     # initial configs
    r = R * eta^{-s}                        # initial budget
    Run Successive Halving with n configs starting at budget r

ASHA (Asynchronous Successive Halving)

Asynchronous variant for distributed settings. Promotes configurations as soon as they finish a rung, without waiting for all configurations to complete.

Optimization Landscape

Loss Surface Properties

  • Saddle points: more common than local minima in high dimensions. Gradient is zero but Hessian has positive and negative eigenvalues.
  • Flat minima: SGD with small batch sizes tends to find flat minima, which generalize better than sharp minima.
  • Mode connectivity: different optima are often connected by low-loss paths.

Batch Size Effects

| Small Batch (32-256)           | Large Batch (1K-64K)          |
|--------------------------------|-------------------------------|
| More noise, better exploration | Less noise, sharper minima    |
| Better generalization          | Faster wall-clock time        |
| Lower throughput               | Needs lr scaling (linear)     |
| Standard lr works              | May need warmup + LARS/LAMB   |

Linear scaling rule: when multiplying batch size by k, multiply learning rate by k.

Practical Optimizer Selection

Default choice: AdamW (lr=3e-4, wd=0.01)
                with cosine schedule + warmup

For CNNs:       SGD + momentum (0.9) + step/cosine lr often competitive
For transformers: AdamW almost always
For fine-tuning:  AdamW with lower lr (1e-5 to 5e-5)
For convex:       L-BFGS if d is moderate and full gradient is cheap

Debugging Optimization

  1. Overfit one batch: if loss doesn't go to zero, the model/code has a bug
  2. Learning rate finder: sweep lr from 1e-7 to 10, plot loss vs lr. Use lr where loss is still decreasing steeply
  3. Gradient histograms: check for vanishing (all near 0) or exploding (very large) gradients
  4. Loss curves: train loss should decrease smoothly. Val loss diverging = overfitting. Both flat = underfitting or lr too low.