In this tutorial, we’ll talk about ADAM, an optimization algorithm we frequently use to train machine-learning models.
2. Optimization in Learning
When training models such as neural networks or support vector machines, we search for the parameters that minimize a cost function, which quantifies how far the model's predictions deviate from the correct labels.
Over the years, many optimization algorithms have been proposed. Stochastic gradient descent (SGD) with mini-batches updates the parameters by evaluating the gradient only on a small random sample (a mini-batch) of the training data.
Momentum algorithms keep track of the gradient history and use the sequence of gradients at previous iterations to update the parameter vector at each new iteration.
Some of these algorithms are adaptive. That means they change the learning rate in each iteration according to the current gradient or the history of gradients recorded so far.
ADAM, whose name is derived from adaptive moment estimation, combines all those ideas and is currently one of the most widely used training algorithms.
The distinctive features of ADAM are:
- mini-batch gradient updates
- adaptive momentum
- adaptive learning rate
- bias correction
The first one on the list means we use a sample of training data to update the parameters in each iteration.
We'll denote by $B^{(t)}$ the mini-batch we use in the $t$-th iteration.
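As a minimal Python sketch of this step, assume the training data fits in a list of `(x, y)` pairs and that each mini-batch is drawn uniformly without replacement. The names `sample_minibatch` and `data` are hypothetical. Shuffling the data once per epoch and slicing consecutive batches is another common scheme.

```python
import random

def sample_minibatch(data, batch_size, rng):
    """Return batch_size distinct examples drawn uniformly from data."""
    return rng.sample(data, batch_size)

# Hypothetical toy dataset of (x, y) pairs.
data = [(float(x), 2.0 * x) for x in range(10)]
rng = random.Random(42)  # fixed seed for reproducibility
batch = sample_minibatch(data, batch_size=4, rng=rng)
print(batch)
```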
3.1. Adaptive Momentum
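Let $w^{(t)}$ be the parameter vector in the $t$-th iteration. The gradient of the cost function $J$ with respect to the parameters, evaluated at $w^{(t)}$ on the mini-batch $B^{(t)}$, is:
$$g^{(t)} = \nabla_{w} J\left(w^{(t)}; B^{(t)}\right)$$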
In mini-batch SGD, we update $w^{(t)}$ by stepping in the direction opposite to $g^{(t)}$:
$$w^{(t+1)} = w^{(t)} - \alpha^{(t)} g^{(t)}$$
where $\alpha^{(t)}$ is the learning rate at iteration $t$. Often, we let $\alpha^{(t)}$ decay exponentially with $t$.
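The following sketch runs mini-batch SGD with an exponentially decaying learning rate $\alpha^{(t)} = \alpha_0 \gamma^{t}$ on a hypothetical one-parameter least-squares problem (fitting $y = w x$ with the mean squared error as the cost). The data, the helper names, and the constants `alpha0` and `decay` (playing the role of $\gamma$) are illustrative assumptions, not values prescribed by the text.

```python
import random

def minibatch_gradient(w, batch):
    """Gradient w.r.t. w of the mean squared error (1/|B|) * sum (w*x - y)^2."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

def sgd(data, w0, alpha0, decay, batch_size, steps, seed=0):
    """Mini-batch SGD with learning rate alpha0 * decay**t at iteration t."""
    rng = random.Random(seed)
    w = w0
    for t in range(steps):
        batch = rng.sample(data, batch_size)  # mini-batch B^(t)
        g = minibatch_gradient(w, batch)      # gradient g^(t) on the batch
        alpha_t = alpha0 * decay ** t         # exponentially decaying rate
        w = w - alpha_t * g                   # step opposite to the gradient
    return w

# Hypothetical noise-free data generated by y = 3x, so the optimum is w = 3.
xs = [i / 10 for i in range(1, 21)]
data = [(x, 3.0 * x) for x in xs]
w = sgd(data, w0=0.0, alpha0=0.1, decay=0.999, batch_size=5, steps=500)
print(round(w, 3))
```

Because the data is noise-free, every mini-batch gradient vanishes at $w = 3$, so the iterates settle there despite the sampling.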
ADAM follows a different strategy. Instead of using only $g^{(t)}$, it uses all the gradients computed so far: $g^{(1)}, g^{(2)}, \ldots, g^{(t)}$. The idea is to treat $w^{(t)}$ as a particle moving through the parameter space. Its previous gradients indicate its velocity and direction of motion. Even if it gets stuck in a region with small gradients, the momentum accumulated in previous steps is expected to push it in the right direction. In contrast, ordinary SGD stops moving if the gradient reaches zero.
So, this is how ADAM defines the momentum vector:
$$m^{(t)} = \beta_1 m^{(t-1)} + (1 - \beta_1)\, g^{(t)}$$
where $\beta_1 \in [0, 1)$ is a decay rate. It's a linear combination of all the gradients computed up to iteration $t$, with the coefficients of the gradients computed in the distant past decaying exponentially.
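To check that the recursive definition really produces an exponentially weighted sum of past gradients, the sketch below computes $m^{(t)}$ both ways for a single coordinate. The gradient sequence, which drops to zero in the last two iterations, and $\beta_1 = 0.9$ are hypothetical choices.

```python
def momentum_recursive(grads, beta1):
    """m^(t) = beta1 * m^(t-1) + (1 - beta1) * g^(t), starting from m^(0) = 0."""
    m = 0.0
    history = []
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        history.append(m)
    return history

def momentum_expanded(grads, beta1):
    """The same m^(t) written as (1 - beta1) * sum_i beta1^(t-i) * g^(i)."""
    t = len(grads)
    return (1 - beta1) * sum(beta1 ** (t - i) * g
                             for i, g in enumerate(grads, start=1))

grads = [1.0, 0.5, 0.25, 0.0, 0.0]  # the gradient vanishes at t = 4 and t = 5
beta1 = 0.9
ms = momentum_recursive(grads, beta1)
print(ms[-1], momentum_expanded(grads, beta1))
```

Both computations give approximately 0.12231 at $t = 5$. The momentum is still nonzero even though $g^{(4)} = g^{(5)} = 0$, which is exactly the situation in which plain SGD would stop moving.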
3.2. Bias Correction
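The vector $m^{(t)}$ estimates the expected value of the gradient. We initialize it to the zero vector, $m^{(0)} = \mathbf{0}$, so its computation history is as follows:
$$\begin{aligned} m^{(1)} &= (1 - \beta_1)\, g^{(1)} \\ m^{(2)} &= \beta_1 (1 - \beta_1)\, g^{(1)} + (1 - \beta_1)\, g^{(2)} \\ &\;\;\vdots \\ m^{(t)} &= (1 - \beta_1) \sum_{i=1}^{t} \beta_1^{t-i}\, g^{(i)} \end{aligned}$$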
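Taking the expectations of both sides of the last equation, we get:
$$\mathbb{E}\left[m^{(t)}\right] = (1 - \beta_1) \sum_{i=1}^{t} \beta_1^{t-i}\, \mathbb{E}\left[g^{(i)}\right]$$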
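Each expectation $\mathbb{E}\left[g^{(i)}\right]$ can differ from $\mathbb{E}\left[g^{(t)}\right]$. We can account for all the (weighted) deviations with a constant $\zeta$:
$$\mathbb{E}\left[m^{(t)}\right] = \mathbb{E}\left[g^{(t)}\right] (1 - \beta_1) \sum_{i=1}^{t} \beta_1^{t-i} + \zeta$$
The constant $\zeta$ is zero if the expected gradient doesn't change across iterations, and otherwise we can keep it small by choosing $\beta_1$ so that it assigns small coefficients to less recent gradients.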
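Since $(1 - \beta_1) \sum_{i=1}^{t} \beta_1^{t-i} = 1 - \beta_1^{t}$ (a finite geometric series), the equation comes down to:
$$\mathbb{E}\left[m^{(t)}\right] = \mathbb{E}\left[g^{(t)}\right] \left(1 - \beta_1^{t}\right) + \zeta$$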
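So, to correct for the bias, we divide $m^{(t)}$ by $1 - \beta_1^{t}$:
$$\hat{m}^{(t)} = \frac{m^{(t)}}{1 - \beta_1^{t}}$$
The update step is then:
$$w^{(t+1)} = w^{(t)} - \alpha^{(t)} \hat{m}^{(t)}$$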
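A quick numerical check of the bias, assuming a constant gradient so that $\zeta = 0$: the raw $m^{(t)}$ equals $g\left(1 - \beta_1^{t}\right)$ and starts far below $g$, while the corrected $\hat{m}^{(t)}$ recovers $g$ at every step. The gradient value 2.0 and $\beta_1 = 0.9$ are arbitrary choices for illustration.

```python
def biased_and_corrected(g, beta1, steps):
    """Feed a constant gradient g; return (t, m^(t), m_hat^(t)) for t = 1..steps."""
    m = 0.0  # m^(0) = 0 is the source of the bias toward zero
    rows = []
    for t in range(1, steps + 1):
        m = beta1 * m + (1 - beta1) * g
        m_hat = m / (1 - beta1 ** t)  # bias-corrected estimate
        rows.append((t, m, m_hat))
    return rows

for t, m, m_hat in biased_and_corrected(g=2.0, beta1=0.9, steps=5):
    print(f"t={t}: m={m:.4f}, m_hat={m_hat:.4f}")
```

At $t = 1$, the raw estimate is only 0.2, while the corrected one is already 2.0.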
3.3. Adaptive Learning Rate
The final ingredient is the adaptive learning rate $\alpha^{(t)}$. In ADAM, each parameter has its own learning rate, so $\alpha^{(t)}$ is a vector of rates that multiply $\hat{m}^{(t)}$ one element at a time.
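What do we want to achieve with adaptive rates? The idea is that if a dimension is rarely updated and has many zero gradients in its history, we should allow it a higher learning rate. In ADAM, the learning rate of each parameter is inversely proportional to the square root of an exponential moving average of that parameter's squared gradients:
$$v^{(t)} = \beta_2 v^{(t-1)} + (1 - \beta_2)\, g^{(t)} \odot g^{(t)}$$
where $\beta_2 \in [0, 1)$ is another decay rate, $v^{(0)} = \mathbf{0}$, and $\odot$ is element-wise multiplication.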
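More specifically, $v^{(t)}$ estimates the uncentered second moment $\mathbb{E}\left[g^{(t)} \odot g^{(t)}\right]$ rather than the variance, so each parameter's rate is inversely proportional to the root mean square of its recent gradients.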
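The sketch below illustrates this with two hypothetical parameters: one receives a unit gradient at every iteration, the other only at every tenth iteration. It computes the bias-corrected $\hat{v}^{(t)}$ after 100 iterations and the resulting per-parameter rates $\alpha / \left(\sqrt{\hat{v}^{(t)}} + \varepsilon\right)$, using $\alpha = 0.001$, $\beta_2 = 0.999$, and $\varepsilon = 10^{-8}$, the defaults suggested in the original ADAM paper. The gradient history and the function name are assumptions for illustration.

```python
import math

def effective_rates(grad_history, alpha=0.001, beta2=0.999, eps=1e-8):
    """Per-parameter rates alpha / (sqrt(v_hat) + eps) after all of grad_history."""
    dim = len(grad_history[0])
    v = [0.0] * dim  # v^(0) = 0
    for g in grad_history:
        # v^(t) = beta2 * v^(t-1) + (1 - beta2) * g * g, element-wise
        v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    t = len(grad_history)
    v_hat = [vi / (1 - beta2 ** t) for vi in v]  # bias correction
    return [alpha / (math.sqrt(vh) + eps) for vh in v_hat]

# Parameter 0 gets a unit gradient every step; parameter 1 only every 10th step.
history = [[1.0, 1.0 if t % 10 == 0 else 0.0] for t in range(1, 101)]
rates = effective_rates(history)
print(rates)
```

The rarely updated parameter ends up with a rate roughly three times larger than the frequently updated one.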
Just as we did with $m^{(t)}$, we need to correct $v^{(t)}$:
$$\hat{v}^{(t)} = \frac{v^{(t)}}{1 - \beta_2^{t}}$$
The derivation is the same as for $m^{(t)}$.
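Finally, the update step is:
$$w^{(t+1)} = w^{(t)} - \frac{\alpha}{\sqrt{\hat{v}^{(t)}} + \varepsilon} \odot \hat{m}^{(t)}$$
where $\alpha$ is the step size, $\varepsilon$ is a small constant we use to avoid division by zero, and the square root and division are applied element-wise. In other words, the rate vector is $\alpha^{(t)} = \alpha / \left(\sqrt{\hat{v}^{(t)}} + \varepsilon\right)$.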
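Here is a minimal sketch of a single update step on plain Python lists, so that every operation is visibly element-wise. The function name `adam_step` and its default hyperparameters ($\alpha = 0.001$, $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\varepsilon = 10^{-8}$, the values suggested in the original ADAM paper) are assumptions for illustration. A practical implementation would typically use an array library instead of lists.

```python
import math

def adam_step(w, g, m, v, t, alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One ADAM update on lists of floats; t is the 1-based iteration counter.

    Returns the new parameters and the updated moment estimates (m, v)."""
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]       # first moment
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]  # second raw moment
    m_hat = [mi / (1 - beta1 ** t) for mi in m]                       # bias corrections
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    w = [wi - alpha * mh / (math.sqrt(vh) + eps)                      # element-wise step
         for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w, m, v

w, m, v = [1.0, -2.0], [0.0, 0.0], [0.0, 0.0]
w, m, v = adam_step(w, g=[0.5, -4.0], m=m, v=v, t=1)
print(w)
```

At $t = 1$, the bias corrections give $\hat{m}^{(1)} = g^{(1)}$ and $\hat{v}^{(1)} = g^{(1)} \odot g^{(1)}$, so each coordinate moves by almost exactly $\alpha$ in the direction opposite to its gradient, regardless of the gradient's magnitude. Here, `w` becomes approximately `[0.999, -1.999]`.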
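Here's the pseudocode of ADAM:
- Input: initial parameters $w^{(1)}$, step size $\alpha$, decay rates $\beta_1, \beta_2 \in [0, 1)$, and a small constant $\varepsilon > 0$
- Set $m^{(0)} = \mathbf{0}$, $v^{(0)} = \mathbf{0}$, and $t = 1$
- Repeat until a termination criterion is met:
  - Sample a mini-batch $B^{(t)}$ and compute $g^{(t)} = \nabla_{w} J\left(w^{(t)}; B^{(t)}\right)$
  - $m^{(t)} = \beta_1 m^{(t-1)} + (1 - \beta_1)\, g^{(t)}$
  - $v^{(t)} = \beta_2 v^{(t-1)} + (1 - \beta_2)\, g^{(t)} \odot g^{(t)}$
  - $\hat{m}^{(t)} = m^{(t)} / \left(1 - \beta_1^{t}\right)$ and $\hat{v}^{(t)} = v^{(t)} / \left(1 - \beta_2^{t}\right)$
  - $w^{(t+1)} = w^{(t)} - \alpha\, \hat{m}^{(t)} / \left(\sqrt{\hat{v}^{(t)}} + \varepsilon\right)$, element-wise
  - $t \leftarrow t + 1$
- Return $w^{(t)}$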
There are various termination criteria we can use. For instance, we can set a maximum number of iterations, or stop the algorithm when the difference between two consecutive parameter vectors $w^{(t)}$ and $w^{(t+1)}$ is smaller than a predefined constant.
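Putting the pieces together, the sketch below runs ADAM with both termination criteria: a cap `max_iter` on the number of iterations and a tolerance `tol` on the Euclidean distance between consecutive parameter vectors. To keep it deterministic, it assumes a hypothetical quadratic cost $f(w) = (w_0 - 3)^2 + 10(w_1 + 1)^2$ with an exact gradient instead of mini-batch gradients. The choice of the Euclidean norm and all constants are illustrative.

```python
import math

def adam(grad_fn, w, alpha=0.01, beta1=0.9, beta2=0.999, eps=1e-8,
         max_iter=5000, tol=1e-6):
    """Minimal ADAM loop that stops after max_iter iterations or when
    consecutive parameter vectors differ by less than tol (Euclidean norm)."""
    m = [0.0] * len(w)
    v = [0.0] * len(w)
    for t in range(1, max_iter + 1):
        g = grad_fn(w)
        m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
        v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
        m_hat = [mi / (1 - beta1 ** t) for mi in m]
        v_hat = [vi / (1 - beta2 ** t) for vi in v]
        w_new = [wi - alpha * mh / (math.sqrt(vh) + eps)
                 for wi, mh, vh in zip(w, m_hat, v_hat)]
        # Distance between w^(t) and w^(t+1) for the tolerance criterion.
        step = math.sqrt(sum((a - b) ** 2 for a, b in zip(w_new, w)))
        w = w_new
        if step < tol:
            return w, t
    return w, max_iter

def grad(w):
    # Exact gradient of f(w) = (w0 - 3)^2 + 10 * (w1 + 1)^2, minimized at (3, -1).
    return [2 * (w[0] - 3), 20 * (w[1] + 1)]

w, iters = adam(grad, [0.0, 0.0])
print(w, iters)
```

With a constant step size, ADAM may keep taking small oscillating steps near the minimum, so the tolerance test might never fire. The iteration cap guarantees that the loop ends either way.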
In this article, we explained how ADAM works.
ADAM is an adaptive optimization algorithm we use for training machine-learning models. It uses the history of mini-batch gradients to find the direction of its update steps and tune its learning rates.