Navigating the Path to Model Performance: From Gradient Descent to Adam
Authored by: Loveleen Narang
Date: December 2, 2024
At the heart of training most machine learning models lies an optimization problem. The goal is to find the set of model parameters (weights and biases, collectively denoted by \( \theta \)) that minimizes a predefined loss function (or cost function), \( J(\theta) \) (Formula 1: \( \theta^* = \arg\min_\theta J(\theta) \)). This loss function quantifies the discrepancy between the model's predictions and the actual target values in the training data. The process of minimizing this loss is essentially how a model "learns" from data.
This is often framed as Empirical Risk Minimization (ERM), where we aim to minimize the average loss over the training dataset \( \{(x_i, y_i)\}_{i=1}^N \) (Formula 2): \( J(\theta) = \frac{1}{N} \sum_{i=1}^{N} L(f(x_i; \theta), y_i) \).
Here, \( L \) is the loss for a single data point, and \( f(x_i; \theta) \) is the model's prediction for input \( x_i \) using parameters \( \theta \). Common loss functions include Mean Squared Error (MSE) for regression (Formula 3: \( L = (y - \hat{y})^2 \)) and Cross-Entropy Loss for classification (Formula 4: \( L = -\sum y_k \log(\hat{y}_k) \)). The choice of optimization technique significantly impacts training speed, convergence, and the final performance of the model.
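For concreteness, the following is a minimal NumPy sketch of both losses; the arrays are illustrative example values, not data from this article.

```python
import numpy as np

# Mean Squared Error for regression: L = (y - y_hat)^2, averaged over samples.
y_true = np.array([3.0, -0.5, 2.0])
y_pred = np.array([2.5, 0.0, 2.0])
mse = np.mean((y_true - y_pred) ** 2)

# Cross-Entropy for classification: L = -sum_k y_k * log(y_hat_k),
# with y one-hot and y_hat a predicted probability distribution over classes.
y_onehot = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
y_prob = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
cross_entropy = np.mean(-np.sum(y_onehot * np.log(y_prob), axis=1))

print(mse, cross_entropy)  # approximately 0.167 and 0.290
```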
Fig 1: Analogy of minimizing loss as finding the lowest point on an error surface, showing potential pitfalls like local minima and saddle points.
Most optimization in modern ML relies on variants of Gradient Descent.
Batch Gradient Descent (GD) is the most basic form. It calculates the gradient of the loss function \( J(\theta) \) with respect to the parameters \( \theta \) using the entire training dataset in each iteration, and then updates the parameters in the direction opposite to the gradient (the update rule is shown after the pros and cons below).
Pros: Stable updates; for convex functions (with a suitable learning rate) it converges to the global minimum, and otherwise to a stationary point.
Cons: Computationally very expensive for large datasets, since every single update requires processing all of the data; can get stuck in poor local minima or saddle points.
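With learning rate \( \eta \), the standard batch update is
\[ \theta_{t+1} = \theta_t - \eta \, \nabla_\theta J(\theta_t), \]
where the gradient is computed over the full training set.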
SGD addresses the computational cost of GD by updating parameters using the gradient calculated from only one randomly chosen training sample \( (x_i, y_i) \) at each iteration (update rule below).
Pros: Much faster updates (computationally cheap), introduces noise that can help escape shallow local minima and saddle points.
Cons: High variance in updates (noisy convergence path), requires careful tuning of the learning rate (often needs decay schedule), doesn't utilize hardware vectorization effectively.
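For a single randomly drawn sample \( (x_i, y_i) \), the standard SGD update is
\[ \theta_{t+1} = \theta_t - \eta \, \nabla_\theta L(f(x_i; \theta_t), y_i). \]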
Mini-Batch Gradient Descent is a compromise between GD and SGD. It updates parameters using the gradient calculated from a small, randomly selected subset (mini-batch) of \( B \) training samples at each iteration.
Pros: Reduces variance compared to SGD, leading to more stable convergence; allows efficient computation using vectorized operations on modern hardware (GPUs/TPUs); still benefits from some noise to escape local minima.
Cons: Introduces a new hyperparameter (batch size), still requires careful learning rate tuning.
Mini-Batch GD is the most common approach used in deep learning.
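In the standard formulation, the mini-batch update averages the gradient over a batch \( B \):
\[ \theta_{t+1} = \theta_t - \frac{\eta}{|B|} \sum_{i \in B} \nabla_\theta L(f(x_i; \theta_t), y_i). \]
The sketch below implements this loop for linear regression with MSE; the synthetic data, learning rate, batch size, and epoch count are illustrative choices, not values from this article.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))                      # 1000 samples, 5 features
true_w = np.array([1.0, -2.0, 0.5, 3.0, 0.0])
y = X @ true_w + 0.1 * rng.normal(size=1000)        # noisy linear targets

w = np.zeros(5)                                     # parameters theta
eta, batch_size, epochs = 0.05, 32, 20              # learning rate, |B|, passes over data

for _ in range(epochs):
    perm = rng.permutation(len(X))                  # reshuffle each epoch
    for start in range(0, len(X), batch_size):
        idx = perm[start:start + batch_size]
        Xb, yb = X[idx], y[idx]
        grad = 2 * Xb.T @ (Xb @ w - yb) / len(idx)  # gradient of the mini-batch MSE
        w -= eta * grad                             # descend against the gradient

print(np.round(w, 2))                               # should approach true_w
```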
| Optimizer | Data per Update | Update Speed | Variance | Computation | Common Use |
|---|---|---|---|---|---|
| Batch GD | Entire dataset | Slow | Low | Very high (per update) | Small datasets, theoretical analysis |
| SGD | Single sample | Fast | High | Very low (per update) | Online learning; can escape local minima |
| Mini-Batch GD | Small batch (\( B \)) | Moderate | Moderate | Efficiently vectorized | Standard for deep learning |
Fig 2: Conceptual paths taken by different Gradient Descent variants towards the minimum.
Momentum methods aim to accelerate GD convergence, especially in directions of persistent descent, and dampen oscillations.
The momentum method introduces a "velocity" vector \( v \) that accumulates an exponentially decaying moving average of past gradients; the parameter update then incorporates this velocity (see the update rule below).
Think of it as a ball rolling down a hill – it gathers momentum in consistent directions and smooths out bumps (oscillations).
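With momentum coefficient \( \gamma \) (commonly around 0.9), the standard updates are
\[ v_t = \gamma v_{t-1} + \eta \, \nabla_\theta J(\theta_t), \qquad \theta_{t+1} = \theta_t - v_t. \]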
Nesterov Accelerated Gradient (NAG) is a modification of momentum that provides a "lookahead" capability: it calculates the gradient not at the current position \( \theta_t \), but at an approximated future position \( \theta_t - \gamma v_{t-1} \) based on the current velocity.
NAG often converges faster than standard momentum as it can correct its course more effectively before overshooting.
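In the standard formulation, the only change from momentum is where the gradient is evaluated:
\[ v_t = \gamma v_{t-1} + \eta \, \nabla_\theta J(\theta_t - \gamma v_{t-1}), \qquad \theta_{t+1} = \theta_t - v_t. \]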
These methods automatically adjust the learning rate \( \eta \) during training, often on a per-parameter basis, eliminating the need for extensive manual tuning of a global learning rate schedule.
AdaGrad adapts the learning rate for each parameter \( \theta_j \) in inverse proportion to the square root of the sum of all historical squared gradients for that parameter: parameters with large accumulated gradients receive smaller effective learning rates, while parameters with small accumulated gradients receive larger ones (update rule below).
Pros: Good for sparse data (e.g., NLP) as infrequent features get larger updates.
Cons: The accumulated squared gradients \( G_t \) grow monotonically, causing the learning rate to eventually become infinitesimally small, potentially stopping learning prematurely.
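With per-parameter gradient \( g_t \), accumulator \( G_t \), and a small constant \( \epsilon \) for numerical stability (all operations element-wise), the standard update is
\[ G_t = G_{t-1} + g_t^2, \qquad \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{G_t} + \epsilon} \, g_t. \]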
RMSprop addresses AdaGrad's diminishing learning rate by using an exponentially decaying moving average of squared gradients instead of accumulating all past squared gradients.
RMSprop keeps the denominator from growing indefinitely, allowing learning to continue.
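With decay rate \( \beta \) (typically around 0.9), the standard update is
\[ E[g^2]_t = \beta \, E[g^2]_{t-1} + (1-\beta) \, g_t^2, \qquad \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{E[g^2]_t} + \epsilon} \, g_t. \]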
Adam (Adaptive Moment Estimation) is perhaps the most popular adaptive method. It combines the ideas of Momentum (an exponentially decaying average of past gradients, the first moment) and RMSprop (an exponentially decaying average of past squared gradients, the second moment), and corrects both estimates for their initialization bias.
Adam often works well with default hyperparameter settings across a wide range of problems.
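In the standard formulation, with gradient \( g_t \) and the common defaults \( \beta_1 = 0.9 \), \( \beta_2 = 0.999 \), \( \epsilon = 10^{-8} \), the updates are
\[ m_t = \beta_1 m_{t-1} + (1-\beta_1) \, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2) \, g_t^2, \]
\[ \hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_{t+1} = \theta_t - \frac{\eta \, \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}. \]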
AdamW is a modification of Adam that decouples the weight decay (L2 regularization) from the adaptive learning rate mechanism, often leading to better generalization.
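In one common formulation (implementations differ slightly in how the decay term is scheduled), the decay with coefficient \( \lambda \) is applied directly in the parameter update rather than being folded into the gradient:
\[ \theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \, \theta_t \right). \]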
| Optimizer | Key Idea | Pros | Cons |
|---|---|---|---|
| AdaGrad | Per-parameter rate based on accumulated squared gradients | Good for sparse data | Learning rate shrinks too fast |
| RMSprop | Per-parameter rate based on decaying average of squared gradients | Solves AdaGrad's shrinking-rate problem | Requires tuning decay rate \( \beta \) |
| Adam | Combines Momentum (1st moment) and RMSprop (2nd moment) with bias correction | Often fast convergence, widely used, good default parameters | Can sometimes converge to suboptimal points; AdamW variant often preferred with regularization |
| AdamW | Adam with decoupled weight decay | Often better generalization than Adam when using L2 regularization | Slightly different implementation |
While gradient descent relies on first-order derivatives (gradients), second-order methods also use second derivatives (the Hessian matrix \( H \)), which capture curvature information.
Pros: Can converge much faster than first-order methods near a minimum, especially in ill-conditioned problems.
Cons: Calculating/storing/inverting the Hessian (or its approximation) is computationally prohibitive for large models (\( O(d^2) \) or \( O(d^3) \) cost for \( d \) parameters), making them impractical for most deep learning applications, although L-BFGS sees some use.
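The prototypical second-order step is Newton's method, which preconditions the gradient with the inverse Hessian:
\[ \theta_{t+1} = \theta_t - H^{-1} \nabla_\theta J(\theta_t). \]
Quasi-Newton methods such as L-BFGS approximate \( H^{-1} \) from recent gradient history to reduce this cost.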
Regularization techniques, like L1 and L2 weight decay, are often incorporated directly into the optimization objective to prevent overfitting. They add a penalty term \( \lambda R(\theta) \) to the loss function.
The optimizer then minimizes this combined objective, balancing empirical risk with model complexity.
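With penalty weight \( \lambda \), the regularized objective takes the form
\[ J_{\text{reg}}(\theta) = J(\theta) + \lambda R(\theta), \qquad R(\theta) = \|\theta\|_2^2 \ \text{(L2)} \quad \text{or} \quad R(\theta) = \|\theta\|_1 \ \text{(L1)}. \]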
Optimization is the engine that drives machine learning model training. While basic Gradient Descent provides the foundation, its limitations in large-scale, complex settings have spurred the development of more sophisticated techniques. Stochastic and Mini-Batch variants handle large datasets, Momentum methods accelerate convergence, and Adaptive Learning Rate methods like Adam automatically tune step sizes, becoming the default choice for many deep learning tasks. While challenges like navigating complex loss landscapes and hyperparameter tuning persist, the continuous evolution of optimization algorithms is crucial for pushing the boundaries of machine learning performance and enabling the training of increasingly complex models. Understanding these techniques is essential for any practitioner aiming to effectively train and deploy machine learning solutions.