Gradient descent & advanced gradient methods
In this article we will dive into the specifics of how gradient descent works and also look at its limitations. At the end we will cover advanced gradient descent methods such as the Adam optimizer and Nesterov momentum that improve on these limitations.
Gradient descent is a method that appears in many different areas of machine learning and data science. It is an important tool for optimization problems such as minimizing a loss function.
The basics of gradient descent
We start with the problem of finding the minimum of a given function $f(x)$. For the gradient descent method we also need the derivative of the function, $\nabla f(x)$. We then pick a starting point $x_0$ on the function and choose a learning rate $\alpha$ to initialize the algorithm. From this starting point $x_0$ we can use the gradient to calculate the next step with
\[ x_{n+1} = x_n - \alpha \nabla f(x_n) \]At every step we evaluate the gradient and move in the opposite direction, scaled by the learning rate. This way the next step is closer to the minimum, as we move down the function against the gradient. Usually one chooses a learning rate around $\alpha < 0.1$, because too large a learning rate causes overshooting and slows down convergence.
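To make the update rule concrete, here is a single step computed by hand for the example function $f(x) = x^2 - 2x$ that we will use below, starting at $x_0 = -2$ with $\alpha = 0.1$, so that $\nabla f(x_0) = 2 \cdot (-2) - 2 = -6$:
\[ x_1 = x_0 - \alpha \nabla f(x_0) = -2 - 0.1 \cdot (-6) = -1.4 \]The step moves us from $-2$ toward the minimum of this function at $x = 1$.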
Let us now go through an explicit example to show how this works in detail. First we write an implementation of the gradient descent algorithm in Python.
import numpy as np
def gradient_descent(learning_rate=0.1, iterations=100, x0=0):
    x = x0
    path = [x]  # To store x values over iterations
    for i in range(iterations):
        grad = gradient_f(x)
        x = x - learning_rate * grad  # Calculate new x for step
        path.append(x)
    return np.array(path)
We control the number of iterations and calculate the next step with the formula from above.
Applying gradient descent to find the minimum
We now choose an example function like $f(x)=x^2-2x$. For simplicity we assume that we know the gradient and write $f'(x)=2x - 2$. In the code we define the functions
# Function to minimize: f(x) = x^2 - 2x
def f(x):
    return x**2 - 2 * x
# Gradient of f(x): df/dx = 2x - 2
def gradient_f(x):
    return 2 * x - 2
We then choose our starting point as $x_0=-2$ and our learning rate as $\alpha=0.1$. Here we can see our result. The algorithm converges nicely and finds the minimum of our function at $x=1$.
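For reference, the run above corresponds to a call like the following sketch, reusing the functions defined earlier:
path = gradient_descent(learning_rate=0.1, iterations=100, x0=-2)
print(path[-1])  # close to the minimum of f(x) at x = 1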
Let us now try a different learning rate to see what impact this has. If we choose $\alpha=0.9$, the learning rate is so large that the algorithm overshoots on every step and takes longer to converge to the minimum. If we choose a learning rate like $\alpha=0.01$, convergence is very slow and within our 100 iterations the algorithm does not even reach the minimum. In practice one can find an optimal or near-optimal learning rate by setting a stopping condition like $|x_n - x_{n+1}|< 10^{-6}$, which means the updates are barely changing anymore. Checking how many steps the algorithm takes until convergence gives a good idea of the optimal learning rate.
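A minimal sketch of such a stopping condition, using a hypothetical helper gradient_descent_with_tolerance (not part of the code above) built on the same gradient_f, could look like this:
def gradient_descent_with_tolerance(learning_rate=0.1, max_iterations=10000, x0=0, tol=1e-6):
    # Run gradient descent until the update |x_n - x_{n+1}| drops below tol
    x = x0
    for i in range(max_iterations):
        x_new = x - learning_rate * gradient_f(x)
        if abs(x_new - x) < tol:
            return x_new, i + 1  # converged value and number of steps taken
        x = x_new
    return x, max_iterations  # no convergence within max_iterations
Running this for several learning rates and comparing the returned step counts gives a simple way to pick a good value of $\alpha$.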
Limitations of gradient descent
In the example above we saw that this works great to find the minimum of a function. However this is not always the case. Let us now look at another example function $g(x)=x^4 - 3x^3 + 2$, where the derivative is $g'(x)=4x^3-9x^2$. If we apply our standard gradient descent method with starting point $x_0=-0.75$ and learning rate $\alpha=0.03$ we get the following outcome.
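In code, this experiment can be sketched as follows; since gradient_descent looks up gradient_f by name, we simply point gradient_f at the new derivative before running (a quick workaround for this illustration, not a polished design):
# Saddle point example: g(x) = x^4 - 3x^3 + 2 with g'(x) = 4x^3 - 9x^2
def g(x):
    return x**4 - 3 * x**3 + 2
def gradient_g(x):
    return 4 * x**3 - 9 * x**2
gradient_f = gradient_g  # make the optimizers above use g's derivative
path = gradient_descent(learning_rate=0.03, iterations=100, x0=-0.75)
print(path[-1])  # stays near the saddle point at x = 0 instead of the minimum at x = 2.25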
The problem here is that the algorithm cannot "see" past the saddle point at $x=0$ and mistakenly converges there because the gradient is zero. We will need more advanced ideas to overcome this issue.
Advanced gradient descent methods
While basic gradient descent is a powerful optimization technique, it has its limitations, especially with large datasets and complex functions or function landscapes in higher dimensions. To address the issues we saw above, more advanced methods have been developed.
RMSprop
RMSprop (Root Mean Square Propagation) is an adaptive learning rate method that divides the learning rate by an exponentially decaying average of squared gradients. This helps to stabilize the training process and can lead to faster convergence.
Algorithm Steps: Initialize parameters $x_0$, learning rate $\alpha$, decay factor $\beta$, and a small constant $\epsilon$. Set the initial moving average of squared gradients: $v = 0$. Then for each iteration:
- First we update the moving average of squared gradients. \[ v_n = \beta v_{n-1} + (1 - \beta) \cdot (\nabla f(x_n))^2 \]
- Then we calculate the next step using this adjusted learning rate: \[ x_{n+1} = x_n - \dfrac{\alpha}{\sqrt{v_n} + \epsilon} \nabla f(x_n) \]
def rmsprop(lr=0.01, iterations=100, x0=0, beta=0.9, epsilon=1e-8):
    x = x0
    path = [x]
    v = 0
    for _ in range(iterations):
        grad = gradient_f(x)
        v = beta * v + (1 - beta) * grad**2  # moving average of squared gradients
        x -= lr * grad / (np.sqrt(v) + epsilon)  # update with adaptive step size
        path.append(x)
    return np.array(path)
Adam
Adam (Adaptive Moment Estimation) combines the benefits of gradient descent with momentum and RMSprop by maintaining a running average of both the gradients and their squared values. This provides an adaptive learning rate for each parameter, which can lead to faster and more stable convergence.
Algorithm Steps: Initialize parameters $x$, learning rate $\eta$, decay rates $\beta_1, \beta_2$, and a small constant $\epsilon$, and set the moment estimates $m_0 = 0$ and $v_0 = 0$. Then for each iteration:
- Update the biased first moment estimate: \[ m_n = \beta_1 m_{n-1} + (1 - \beta_1) \nabla f(x_n) \]
- Update the biased second moment estimate: \[ v_n = \beta_2 v_{n-1} + (1 - \beta_2) (\nabla f(x_n))^2 \]
- Compute bias-corrected estimates: \[ \hat{m}_n = \frac{m_n}{1 - \beta_1^n}, \quad \hat{v}_n = \frac{v_n}{1 - \beta_2^n} \]
- Then we calculate the next step: \[ x_{n+1} = x_n - \frac{\eta}{\sqrt{\hat{v}_n} + \epsilon} \hat{m}_n \]
def adam(learning_rate=0.001,
         beta1=0.9,
         beta2=0.999,
         epsilon=1e-8,
         iterations=50,
         x0=0):
    x = x0
    path = [x]
    m = 0
    v = 0
    for i in range(iterations):
        grad = gradient_f(x)
        m = beta1 * m + (1 - beta1) * grad  # biased first moment estimate
        v = beta2 * v + (1 - beta2) * grad**2  # biased second moment estimate
        m_hat = m / (1 - beta1**(i+1))  # bias correction
        v_hat = v / (1 - beta2**(i+1))
        x = x - (learning_rate / (v_hat**0.5 + epsilon)) * m_hat
        path.append(x)
    return np.array(path)
Nesterov Accelerated Gradient (NAG)
Nesterov Accelerated Gradient (NAG) uses a momentum term and evaluates the gradient at a lookahead point, which lets it anticipate the slope ahead and typically leads to faster convergence.
Algorithm Steps: Initialize parameters $x$, learning rate $\eta$, momentum $\mu$, and velocity $v = 0$. Then for each iteration:
- Compute the lookahead point: \[ x_{\text{lookahead}} = x_n - \mu v_{n-1} \]
- Compute the gradient at the lookahead point: \[ g_n = \nabla f(x_{\text{lookahead}}) \]
- Update velocity: \[ v_n = \mu v_{n-1} + \eta g_n \]
- Update parameters: \[ x_{n+1} = x_n - v_n \]
def nesterov(learning_rate=0.01, momentum=0.9, iterations=50, x0=0):
    x = x0
    path = [x]
    v = 0
    for i in range(iterations):
        grad = gradient_f(x - momentum * v)  # gradient at the lookahead point
        v = momentum * v + learning_rate * grad
        x = x - v
        path.append(x)
    return np.array(path)
Comparison of gradient-based optimizers
Finally we look at our test function $g(x)=x^4 - 3x^3 + 2$ again, where the derivative is still $g'(x)=4x^3-9x^2$. We test the four methods described in this article with starting point $x_0=-0.75$ and learning rate $\alpha=0.01$.
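A sketch of this comparison, reusing the implementations from above (and again pointing gradient_f at the derivative of $g$), might look as follows; we assume 100 iterations as before, and the exact trajectories of course depend on these parameter choices:
# Compare the four optimizers on g(x) = x^4 - 3x^3 + 2, all starting at x0 = -0.75
gradient_f = gradient_g
results = {
    "gradient descent": gradient_descent(learning_rate=0.01, iterations=100, x0=-0.75),
    "RMSprop": rmsprop(lr=0.01, iterations=100, x0=-0.75),
    "Adam": adam(learning_rate=0.01, iterations=100, x0=-0.75),
    "Nesterov": nesterov(learning_rate=0.01, iterations=100, x0=-0.75),
}
for name, path in results.items():
    print(f"{name}: final x = {path[-1]:.4f}")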
We can see that the basic gradient descent method gets stuck at the saddle point. RMSprop is also stuck at the saddle point with this parameter choice, which is probably not optimal; if we choose a learning rate of $\alpha=0.1$, RMSprop is able to overcome the saddle point and converge to the minimum.
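For example, under the same assumptions as the comparison sketch above:
path = rmsprop(lr=0.1, iterations=100, x0=-0.75)
print(path[-1])  # with the larger learning rate RMSprop gets past the saddle point, as described above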
Adam is able to overcome the plateau due to its momentum term. However, with this parameter choice Adam needs a lot of steps on the flat part of the function.
Nesterov Accelerated Gradient is able to overcome the saddle point and converge to the minimum much more quickly than the other methods. Due to its lookahead momentum term it is able to anticipate the gradient and move in the right direction with larger steps. This shows that the choice of optimizer and its parameters can have a big impact on the optimization process.
Of course our results are not universal and depend on the function we are trying to optimize. In general we can say that the more complex the function is, the more advanced methods are needed to overcome saddle points and local minima. In the end it is always a good idea to try out different optimizers and their parameters to find the best solution for your problem.