
Effects of Learning Rates on the Convergence of Gradient Descent

A Comparative Study on the WNGrad1 Algorithm

Yijun Dong

CS 378 Final Project

1. X. Wu, R. Ward, and L. Bottou. WNGrad: Learn the Learning Rate in Gradient Descent. arXiv preprint arXiv:1803.02865v1 [stat.ML], 2018.

Standard set-up for gradient descent

• For the loss function 𝑓: ℝ𝑑 → ℝ

• We want to find 𝒙∗ ∈ ℝ𝑑 such that

𝑓(𝒙∗) = min𝒙 𝑓(𝒙)

• By iteratively moving 𝒙 in the direction of negative gradient

𝒙𝒌+𝟏 ← 𝒙𝒌 − 𝜂𝑗 ∙ 𝛻𝑓(𝒙𝒌)

• The algorithm only converges for appropriately tuned learning rates 𝜂𝑗
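As a concrete reference (not part of the original slides), here is a minimal Python sketch of the update rule above; the quadratic toy loss, the step count, and the constant learning rate are illustrative assumptions only.

```python
import numpy as np

def gradient_descent(grad_f, x0, eta, n_iters=100):
    """Plain gradient descent: x_{k+1} <- x_k - eta * grad_f(x_k)."""
    x = np.asarray(x0, dtype=float).copy()
    for _ in range(n_iters):
        x = x - eta * grad_f(x)
    return x

# Toy example: f(x) = ||x||^2 / 2 has gradient x and minimizer 0.
x_min = gradient_descent(grad_f=lambda x: x, x0=[3.0, -2.0], eta=0.1)
print(x_min)  # close to [0, 0] for this well-tuned eta
```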

Loss function for linear regression

• Given database {𝒂𝒌, 𝑦𝑘}𝑘=1:𝑁 with N items and 𝒂𝒌 ∈ ℝ𝑚

• Each item 𝒂𝒌 contains m attributes with label 𝑦𝑘

• We want to find a linear model 𝑀: ℝ𝑚 → ℝ

𝑀𝒙(𝒂) = 𝒙𝑇 ∙ 𝒂 + 𝒃

• To minimize the loss function

𝑓(𝒙) = (1/2𝑁) ∑ᵢ₌₁ᴺ (𝑀𝒙(𝒂ᵢ) − 𝑦ᵢ)²

• With gradient

𝛻𝑓(𝒙) = (1/𝑁) ∑ᵢ₌₁ᴺ (𝑀𝒙(𝒂ᵢ) − 𝑦ᵢ) ∙ 𝒂ᵢ
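A small Python sketch of this loss and gradient (illustrative only; the array names A, y, the scalar intercept b, and the use of NumPy are my assumptions, with the rows of A playing the role of the items 𝒂ᵢ):

```python
import numpy as np

def loss(x, b, A, y):
    """f(x) = 1/(2N) * sum_i (x^T a_i + b - y_i)^2."""
    residual = A @ x + b - y           # shape (N,)
    return 0.5 * np.mean(residual ** 2)

def grad(x, b, A, y):
    """grad f(x) = 1/N * sum_i (x^T a_i + b - y_i) * a_i."""
    residual = A @ x + b - y
    return A.T @ residual / len(y)     # shape (m,)
```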

L-Lipschitz continuous and convex function

• We say that a differentiable function 𝑓: ℝ𝑑 → ℝ has an L-Lipschitz continuous gradient, denoted 𝑓 ∈ 𝐶𝐿¹, if

||𝛻𝑓(𝒙) − 𝛻𝑓(𝒚)|| ≤ 𝐿 ∙ ||𝒙 − 𝒚||, ∀ 𝒙, 𝒚 ∈ ℝ𝑑

• L is known as the Lipschitz constant

• L measures the least upper bound of the ‘oscillation’ in 𝜵𝒇

• We also assume that 𝑓 is convex (e.g., linear regression problem)

• The global minimum of the loss function exists

• The convergence of gradient descent is independent of the initial point 𝒙𝟎
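Not stated on the slides, but useful for the experiments that follow: for the linear regression loss above the Hessian is constant, so the Lipschitz constant of the gradient can be computed directly. A sketch under the same assumed names as before:

```python
import numpy as np

def lipschitz_constant(A):
    """For f(x) = 1/(2N) * ||A x + b - y||^2 the Hessian is (1/N) * A^T A,
    so L is its largest eigenvalue."""
    N = A.shape[0]
    return np.linalg.eigvalsh(A.T @ A / N).max()
```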

Three approaches: 𝒙𝒌+𝟏 ← 𝒙𝒌 − 𝜂𝑗 ∙ 𝛻𝑓(𝒙𝒌)

• Batch gradient descent:

• Exact gradient 𝛻𝑓 from the entire training set

• Stochastic gradient descent:

• Stochastic gradient 𝑔𝑘 from single items

• Scan the database in a randomly shuffled order

𝑔𝑘 = (1/𝑁) (𝑀𝒙(𝒂ᵢ) − 𝑦ᵢ) ∙ 𝒂ᵢ, for the single item 𝑖 visited at step 𝑘

• Mini-Batch gradient descent: (with batch size n)

• Stochastic gradient 𝑔𝑘 from a random subset of the training set

• Random shuffle + Partitioning into mini-batches

𝑔𝑘 = (1/𝑛) ∑ᵢ₌₁ⁿ (𝑀𝒙(𝒂ᵢ) − 𝑦ᵢ) ∙ 𝒂ᵢ, summing over the 𝑛 items in the 𝑘-th mini-batch
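A Python sketch covering all three variants (illustrative, reusing the assumed A, y, x, b from above): one epoch of mini-batch gradients after a random shuffle. Setting batch_size = N recovers the exact batch gradient, and batch_size = 1 yields per-item stochastic gradients.

```python
import numpy as np

def minibatch_gradients(A, y, x, b, batch_size, rng=None):
    """Yield one stochastic gradient g_k per mini-batch for a single epoch:
    shuffle the N items, partition them into batches of size n, and average
    (M_x(a_i) - y_i) * a_i over each batch."""
    rng = rng if rng is not None else np.random.default_rng()
    N = len(y)
    order = rng.permutation(N)                 # random shuffle
    for start in range(0, N, batch_size):      # partition into mini-batches
        idx = order[start:start + batch_size]
        residual = A[idx] @ x + b - y[idx]
        yield A[idx].T @ residual / len(idx)   # g_k, averaged over the batch
```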

Testing databases

• EX3

• 47 items, 2 attributes, labeled

• Linear regression

• Online news popularity data set (UCI)

• 39,644 items, 58 attributes, labeled

• Linear regression

• MNIST database of handwritten digits

• 28×28 images of handwritten digits 0–9

• 60,000 images in training set

• 10,000 images in testing set

• Convolutional neural network

Constant learning rate in BGD (EX3): 𝒙𝒌+𝟏 ← 𝒙𝒌 − 𝜂 ∙ 𝛻𝑓(𝒙𝒌)

BGD with a constant learning rate 𝜼 is very sensitive to the Lipschitz constant: it diverges when 𝜼 ≥ 𝟐/𝑳
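The 2/L threshold can be seen on a one-dimensional example (an illustration I am adding, not one of the slides' experiments): on f(x) = (L/2)·x² the update is x_{k+1} = (1 − ηL)·x_k, which shrinks when η < 2/L and blows up when η > 2/L.

```python
# Gradient descent on f(x) = (L/2) * x^2, where grad f(x) = L * x.
L = 4.0                               # so the critical step size is 2/L = 0.5
for eta in (0.40, 0.49, 0.51):
    x = 1.0
    for _ in range(100):
        x -= eta * L * x              # x_{k+1} = (1 - eta * L) * x_k
    print(f"eta = {eta:.2f} -> |x_100| = {abs(x):.3g}")
# eta below 2/L drives x toward 0; eta above 2/L makes |x| grow without bound.
```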

Tuning learning rate according to ‘local Lipschitz constant’

• Goal: keep the learning rate bounded

• When η ≫ 1/L, the algorithm may diverge

• When η ≪ 1/L, the convergence is slow

• WNGrad algorithm

• 𝒃𝒋 ~ local Lipschitz constant

• After finitely many (𝑘) iterations, either

• 𝑓𝑘 → 𝑓∗ converges, or

• 𝑏𝑘 ≥ 𝐿 and 𝑏𝑗 is stabilized: 𝐿 ≤ 𝑏𝑗 ≤ 𝐶𝐿 for all 𝑗 ≥ 𝑘
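A minimal sketch of the batch WNGrad update as I read it from the cited paper: the step uses learning rate 1/𝑏𝑗, and 𝑏𝑗 grows by the squared gradient norm divided by 𝑏𝑗. Treat the exact form and the ordering of the two updates as assumptions to be checked against Wu, Ward, and Bottou (2018).

```python
import numpy as np

def wngrad(grad_f, x0, eta0, n_iters=100):
    """Batch WNGrad-style descent (sketch): b_j tracks a 'local Lipschitz
    constant', so the effective learning rate 1/b_j shrinks automatically
    when the gradient is large or oscillates."""
    x = np.asarray(x0, dtype=float).copy()
    b = 1.0 / eta0                    # the initial learning rate is eta0 = 1/b_1
    for _ in range(n_iters):
        g = grad_f(x)
        x = x - g / b                 # x_{j+1} <- x_j - (1/b_j) * grad f(x_j)
        b = b + np.dot(g, g) / b      # b_{j+1} <- b_j + ||grad f(x_j)||^2 / b_j
    return x
```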

WNGrad is more robust to the choice of 𝜼𝟎 in BGD

Batch gradient descent with a constant learning rate and with WNGrad, tested on a training subset of 500 items from the UCI online news popularity data set

WNGrad is more robust to the choice of 𝜼𝟎 in BGD

Batch gradient descent with WNGrad diverges much more slowly than with a constant learning rate as 𝜂0 increases from 0.05 to 1 (UCI online news popularity data set)

WNGrad is robust to the choice of 𝜼𝟎 in SGD

Stochastic gradient descent with WNGrad oscillates considerably but does not obviously diverge as 𝜂0 increases from 0.05 to 1 (UCI online news popularity data set)

WNGrad is robust to the choice of 𝜼𝟎 in MBGD

Mini-batch gradient descent with WNGrad and a batch size of 10 (out of 500 items) oscillates and gradually increases, without obviously diverging, as 𝜂0 increases from 0.05 to 1 (UCI online news popularity data set)

Summary and next steps

• By testing a convex loss function with an L-Lipschitz continuous gradient, taken from the linear regression problem, we find that

• The WNGrad algorithm is robust to the choice of the initial learning rate 𝜂0, and hence to the Lipschitz constant of the loss function

• This robustness is retained for all three gradient descent variants: BGD, SGD, and MBGD

• Next steps

• Apply WNGrad to backpropagation when training a CNN

• Test on the MNIST database
