
Page 1: Recurrent Networks and LSTM deep dive

Recurrent Neural Networks

Alex Kalinin [email protected]

Page 2: Recurrent Networks and LSTM deep dive

Content

1. Example of Vanilla RNN
2. RNN Forward pass
3. RNN Backward pass
4. RNN Training problem
5. LSTM design

Page 3: Recurrent Networks and LSTM deep dive

Feed-forward ("vanilla") network

[Diagram: feed-forward network mapping a fixed-size input vector, e.g. (1, 0, 0, 1, 0), to an output in a single pass; no state is carried between inputs]

Page 4: Recurrent Networks and LSTM deep dive

[Diagram: input X -> RNN with hidden state h (weights W_hx, W_hh) -> output y (weights W_hy)]

Vanilla recurrent network:

(1) h_t = tanh(W_hh · h_{t−1} + W_hx · x + b_h)

(2) y = W_hy · h_t + b_y
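These two equations translate directly into code. A minimal NumPy sketch of a single recurrent step; the dimensions and random weights below are illustrative assumptions, not values from the slides:

import numpy as np

def rnn_step(x, h_prev, W_hh, W_hx, W_hy, b_h, b_y):
    # (1) update the hidden state from the previous state and the current input
    h = np.tanh(W_hh @ h_prev + W_hx @ x + b_h)
    # (2) read the output off the new hidden state
    y = W_hy @ h + b_y
    return h, y

# Illustrative sizes: 4-dimensional one-hot input, 3-dimensional hidden state.
rng = np.random.default_rng(0)
W_hh = rng.standard_normal((3, 3))
W_hx = rng.standard_normal((3, 4))
W_hy = rng.standard_normal((4, 3))
b_h, b_y = np.zeros(3), np.zeros(4)

h, y = rnn_step(np.array([0., 1., 0., 0.]), np.zeros(3), W_hh, W_hx, W_hy, b_h, b_y)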

Page 5: Recurrent Networks and LSTM deep dive

Example: character-level language processing

[Diagram: input X -> RNN (weights W_hx, W_hh) -> output y (weights W_hy)]

Training sequence: "hello"

Vocabulary: [e, h, l, o]

One-hot encoding: "h" = [0, 1, 0, 0], "e" = [1, 0, 0, 0], "l" = [0, 0, 1, 0], "o" = [0, 0, 0, 1]
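A sketch of building these one-hot vectors in Python (the helper name one_hot is mine, not from the slides):

# Vocabulary from the slide: [e, h, l, o]
vocab = ['e', 'h', 'l', 'o']
char_to_ix = {ch: i for i, ch in enumerate(vocab)}

def one_hot(ch):
    # "h" -> [0, 1, 0, 0], "e" -> [1, 0, 0, 0], ...
    v = [0] * len(vocab)
    v[char_to_ix[ch]] = 1
    return v

print([one_hot(ch) for ch in "hello"])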

Page 6: Recurrent Networks and LSTM deep dive

[Diagram: input X -> hidden state h -> output Y -> softmax probabilities P]

Training sequence: "hello"

W_hx = [3.6, −4.8, 0.35, −0.26]

W_hy = [−12., −0.67, −0.85, 14.]

b_y = [−0.2, −2.9, 6.1, −3.4]

Pages 7–12: Recurrent Networks and LSTM deep dive

Step 1: feed "h"

x = [0, 1, 0, 0] ("h"), starting from h_0 = 0

h = tanh(W_hh · h_{t−1} + W_hx · x + b_h) = −0.99

y = W_hy · h + b_y = [11., −2.2, 6.9, −17.]

p = softmax(y) = [0.99, 0, 0.01, 0] -> the most likely next character is "e"
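The P column is the softmax of the raw scores y. A small check with the y vector from this step; the slide's y is rounded, so the probabilities come out slightly differently, roughly [0.98, 0, 0.02, 0]:

import numpy as np

def softmax(y):
    e = np.exp(y - np.max(y))  # subtract the max for numerical stability
    return e / e.sum()

y = np.array([11., -2.2, 6.9, -17.])
print(softmax(y).round(2))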

Pages 13–19: Recurrent Networks and LSTM deep dive

Step 2: feed "e"

x = [1, 0, 0, 0] ("e"), carrying over h = −0.99

h = tanh(W_hh · h_{t−1} + W_hx · x + b_h) = −0.09

y = W_hy · h + b_y = [0.86, −2.8, 6.2, −4.6]

p = softmax(y) = [0, 0, 0.99, 0] -> next character: "l"

Pages 20–24: Recurrent Networks and LSTM deep dive

Step 3: feed "l"

x = [0, 0, 1, 0] ("l"), carrying over h = −0.09

h = tanh(W_hh · h_{t−1} + W_hx · x + b_h) = 0.38

y = W_hy · h + b_y = [−4.7, −3.2, 5.8, 1.9]

p = softmax(y) = [0, 0, 0.98, 0.02] -> next character: "l"

Pages 25–30: Recurrent Networks and LSTM deep dive

Step 4: feed "l" again

x = [0, 0, 1, 0] ("l"), carrying over h = 0.38

h = tanh(W_hh · h_{t−1} + W_hx · x + b_h) = 0.98

y = W_hy · h + b_y = [−12., −3.6, 5.3, 10.]

p = softmax(y) = [0, 0, 0.01, 0.99] -> next character: "o" = [0, 0, 0, 1]

Page 31: Recurrent Networks and LSTM deep dive

[Diagram: final hidden state h = 0.98]

The network has produced "e", "l", "l", "o" from the inputs "h", "e", "l", "l": the full training sequence "hello".

Page 32: Recurrent Networks and LSTM deep dive

Summary of the forward pass:

input "h", h_0 = 0 -> "e"

input "e", h = −0.99 -> "l"

input "l", h = −0.09 -> "l"

input "l", h = 0.38 -> "o"
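The whole table is one loop. A minimal NumPy sketch of the forward pass over "hello"; the slides do not give W_hh or b_h, so all weights here are random placeholders and the predictions will not match the table until the network is trained:

import numpy as np

vocab = ['e', 'h', 'l', 'o']
char_to_ix = {ch: i for i, ch in enumerate(vocab)}

def softmax(y):
    e = np.exp(y - y.max())
    return e / e.sum()

rng = np.random.default_rng(0)
n = 1  # the worked example uses a scalar hidden state
W_hh = rng.standard_normal((n, n))
W_hx = rng.standard_normal((n, 4))
W_hy = rng.standard_normal((4, n))
b_h, b_y = np.zeros(n), np.zeros(4)

h = np.zeros(n)    # h_0 = 0
for ch in "hell":  # inputs; the targets are "ello"
    x = np.eye(4)[char_to_ix[ch]]
    h = np.tanh(W_hh @ h + W_hx @ x + b_h)  # equation (1)
    p = softmax(W_hy @ h + b_y)             # equation (2) + softmax
    print(ch, "->", vocab[int(p.argmax())], p.round(2))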

Page 33: Recurrent Networks and LSTM deep dive

"hello" -> "hello"

"hello ben" -> "hello ben"

"hello world" -> "hello world"

Page 34: Recurrent Networks and LSTM deep dive

"it was" -> "it was"

"it was the" -> "it was the"

"it was the best" -> "it was the best"

Training text: "It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness…", A Tale of Two Cities, Charles Dickens

50,000 iterations

300,000 iterations (loss = 1.6066)

1,000,000 iterations (loss = 1.8197)

2,000,000 iterations (loss = 4.0844): "it was the best of" -> "it wes the best of"

Page 35: Recurrent Networks and LSTM deep dive

…epoch 500000, loss: 6.447782290456328
…epoch 1000000, loss: 5.290576956983398
…epoch 1800000, loss: 4.267105168323299
epoch 1900000, loss: 4.175163586546514
epoch 2000000, loss: 4.0844739848413285

Page 36: Recurrent Networks and LSTM deep dive

[Diagram: input X -> RNN with hidden state h (weights W_hx, W_hh) -> output y (weights W_hy)]

Vanilla recurrent network:

(1) h_t = tanh(W_hh · h_{t−1} + W_hx · x + b_h)

(2) y = W_hy · h_t + b_y

Page 37: Recurrent Networks and LSTM deep dive

Input and target are the same text shifted by one character:

Input:  i t " " w a s " " t

Target: t " " w a s " " t h
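Building the shifted (input, target) pairs is a one-liner per sequence. A sketch over the opening of the training text:

text = "it was the best of times, it was the worst of times"
inputs, targets = text[:-1], text[1:]  # target = input shifted one character left
for x, t in zip(inputs[:8], targets[:8]):
    print(repr(x), "->", repr(t))
# 'i' -> 't', 't' -> ' ', ' ' -> 'w', 'w' -> 'a', ...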

Page 38: Recurrent Networks and LSTM deep dive

RNNs for Different Problems

Vanilla Neural Network

Page 39: Recurrent Networks and LSTM deep dive

RNNs for Different Problems

Image Captioning: image -> sequence of words

Page 40: Recurrent Networks and LSTM deep dive

RNNs for Different Problems

Sentiment Analysis: sequence of words -> class

Page 41: Recurrent Networks and LSTM deep dive

RNNs for Different Problems

Translation: sequence of words -> sequence of words

Page 42: Recurrent Networks and LSTM deep dive

Training is hard with vanilla RNNs

[Diagram: RNN unrolled over three steps: inputs x_0 = 1, x_1 = 1, x_2 = 2, hidden states h_0, h_1, h_2, target output 3; forward pass runs left to right, backward pass right to left]

The loss is a function of all three weights:

L = f(W_hx, W_hh, W_hy)

∇L = [∂L/∂w_hx, ∂L/∂w_hh, ∂L/∂w_hy]

Gradient-descent updates:

w_hx := w_hx − 0.01 · ∂L/∂w_hx

w_hh := w_hh − 0.01 · ∂L/∂w_hh

w_hy := w_hy − 0.01 · ∂L/∂w_hy

W_hh = 0.024
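The update rule itself is a few lines. A sketch of one gradient-descent step with learning rate 0.01, plugging in the gradient values worked out by hand on the following pages:

lr = 0.01
params = {"w_hx": 0.078, "w_hh": 0.024, "w_hy": 0.051}
grads = {"w_hx": -0.00017, "w_hh": -0.0005, "w_hy": -0.936}  # from the backward pass below

for name in params:
    params[name] -= lr * grads[name]  # w := w - 0.01 * dL/dw

print(params)  # w_hx = 0.0780017, w_hh = 0.024005, ...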

Page 43: Recurrent Networks and LSTM deep dive

[Diagram: the same unrolled network: inputs x_0 = 1, x_1 = 1, x_2 = 2, hidden states h_0, h_1, h_2, output y, target 3]

L = ?

∂L/∂w_hh = ?

Page 44: Recurrent Networks and LSTM deep dive

Compute gradient

Recursive application of the chain rule: if L = f(g(h(k(l(m(n(w))))))) with f = f(g), g = g(h), h = h(k), …, then

∂L/∂w = ∂f/∂g · ∂g/∂h · ∂h/∂k · ∂k/∂l · ∂l/∂m · ∂m/∂n · ∂n/∂w

For the unrolled network:

L = (W_hy · tanh(W_hh · tanh(W_hh · tanh(W_hx · x_0) + W_hx · x_1) + W_hx · x_2) − 3)²

∂L/∂w_hh = ?

Page 45: Recurrent Networks and LSTM deep dive

Gradient by hand

Pages 46–64: Recurrent Networks and LSTM deep dive

Forward Pass

[Computation graph, built node by node across these slides; w_hx = 0.078, w_hh = 0.024, w_hy = 0.051; inputs x_0 = 1, x_1 = 1, x_2 = 2; target 3]

w_hx · x_0 = 0.078 · 1 = 0.078

h_0 = tanh(0.078) = 0.0778

w_hh · h_0 = 0.024 · 0.0778 = 0.00187

w_hx · x_1 = 0.078 · 1 = 0.078

0.00187 + 0.078 = 0.07987

h_1 = tanh(0.07987) = 0.0797

w_hh · h_1 = 0.024 · 0.0797 = 0.0019

w_hx · x_2 = 0.078 · 2 = 0.156

0.0019 + 0.156 = 0.1579

h_2 = tanh(0.1579) = 0.1566

y = w_hy · h_2 = 0.051 · 0.1566 = 0.0080

y − 3 = −2.99

L = (−2.99)² = 8.95
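The whole forward pass fits in a few lines of Python; this sketch reproduces the numbers above (math.tanh keeps full precision, the slides round to 3–4 digits):

from math import tanh

w_hx, w_hh, w_hy = 0.078, 0.024, 0.051
x0, x1, x2, target = 1.0, 1.0, 2.0, 3.0

h0 = tanh(w_hx * x0)              # 0.0778
h1 = tanh(w_hh * h0 + w_hx * x1)  # 0.0797
h2 = tanh(w_hh * h1 + w_hx * x2)  # 0.1566
y = w_hy * h2                     # 0.0080
L = (y - target) ** 2             # 8.95
print(h0, h1, h2, y, L)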

Page 65: Recurrent Networks and LSTM deep dive

Compute gradient

Recursive application of the chain rule: L = f(g(h(k(l(m(n(w)))))))

∂L/∂w = ∂f/∂g · ∂g/∂h · ∂h/∂k · ∂k/∂l · ∂l/∂m · ∂m/∂n · ∂n/∂w

∂L/∂w_hh = ?

Pages 66–80: Recurrent Networks and LSTM deep dive

Backward Pass

[The same computation graph, now traversed right to left, multiplying local derivatives along the chain:]

∂L/∂w_hx = ∂f/∂f · ∂f/∂g · ∂g/∂h · ∂h/∂k · ∂k/∂l · ∂l/∂m · ∂m/∂n · ∂n/∂w_hx

∂f/∂f = 1

∂f/∂g = ∂(g²)/∂g = 2g = 2 · (−2.99) = −5.98

∂g/∂h = 1 (the "+(−3)" node); running product: −5.98

∂h/∂k = w_hy = 0.051; running product: −0.304
(branch: ∂h/∂w_hy = k = h_2 = 0.1566, so ∂L/∂w_hy = −5.98 · 0.1566 ≈ −0.936)

∂k/∂l = 1 − k² = 1 − 0.1566² = 0.975 (tanh node); running product: −0.297

∂l/∂m = w_hh = 0.024; running product: −0.0071

∂m/∂n = 1 − 0.0797² = 0.993 (tanh node); running product: −0.0071
(branch: multiplying by h_0 = 0.0778 at this w_hh node gives ∂L/∂w_hh ≈ −0.0005)

Continuing through the first step (· w_hh, then · (1 − 0.0778²) = 0.993, then ∂n/∂w_hx = x_0 = 1):

∂L/∂w_hx ≈ −0.00017

Update each weight with w_a := w_a − 0.01 · ∂L/∂w_a:

w_hx := 0.078 − 0.01 · (−0.00017) = 0.0780017

w_hh := 0.024 − 0.01 · (−0.0005) = 0.024005
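The same chain of local derivatives in code. Like the slides, this follows the single deepest path through the graph; the contributions from the other occurrences of w_hx and w_hh are not accumulated, which is the slides' simplification rather than the full gradient:

from math import tanh

w_hx, w_hh, w_hy = 0.078, 0.024, 0.051
x0, x1, x2, target = 1.0, 1.0, 2.0, 3.0

# forward pass (pages 46-64)
h0 = tanh(w_hx * x0)
h1 = tanh(w_hh * h0 + w_hx * x1)
h2 = tanh(w_hh * h1 + w_hx * x2)
y = w_hy * h2
g = y - target

# backward pass: multiply local derivatives right to left
d = 1.0            # df/df
d *= 2 * g         # df/dg = 2g               -> -5.98
dL_dwhy = d * h2   # branch into w_hy         -> -0.937
d *= w_hy          # dh/dk = w_hy             -> -0.305
d *= 1 - h2 ** 2   # tanh' at h2              -> -0.297
dL_dwhh = d * w_hh * (1 - h1 ** 2) * h0  # earliest w_hh branch -> -0.0005
d *= w_hh          #                          -> -0.0071
d *= 1 - h1 ** 2   # tanh' at h1
d *= w_hh          #                          -> -0.00017
d *= 1 - h0 ** 2   # tanh' at h0
dL_dwhx = d * x0   # dn/dw_hx = x0 = 1        -> -0.00017
print(dL_dwhx, dL_dwhh, dL_dwhy)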

Page 81: Recurrent Networks and LSTM deep dive

The gradient flowing back through the unrolled network is multiplied by w_hh at every step:

∂L/∂x = w_hh · … · w_hh · … · w_hh = w_hh^n · C(w)

With w_hh = 0.024, the powers w_hh^n vanish almost immediately:

1. 0.024
2. 0.000576
3. 1.382e-05
4. 3.318e-07
5. 7.963e-09
6. 1.911e-10
7. 4.586e-12
8. 1.101e-13
9. 2.642e-15
10. 6.340e-17
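The table above is just successive powers of w_hh:

w_hh = 0.024
for n in range(1, 11):
    print(n, w_hh ** n)  # 0.024, 0.000576, 1.382e-05, ... 6.340e-17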

Page 83: Recurrent Networks and LSTM deep dive

Long Short-Term Memory (LSTM)

[Diagram: W is a 4n × 2n matrix applied to the 2n-vector (x; h_{t−1}); the 4n-dimensional result is split into the gates i, f, o and the candidate g, each of size n]

(i, f, o, g) = (σ, σ, σ, tanh) · W · (x; h_{t−1})

c_t = f · c_{t−1} + i · g

h_t = o · tanh(c_t)

For comparison, the vanilla RNN: h_t = tanh(W · (x; h_{t−1}))
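A minimal NumPy sketch of one LSTM step following these equations. W is a single 4n × 2n matrix applied to the concatenation of x and h_{t−1}; here x is assumed to have the same size n as the hidden state, and the random W is a placeholder:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W):
    n = h_prev.shape[0]
    z = W @ np.concatenate([x, h_prev])  # shape (4n,)
    i = sigmoid(z[0 * n:1 * n])          # input gate
    f = sigmoid(z[1 * n:2 * n])          # forget gate
    o = sigmoid(z[2 * n:3 * n])          # output gate
    g = np.tanh(z[3 * n:4 * n])          # candidate cell values
    c = f * c_prev + i * g               # additive cell-state update
    h = o * np.tanh(c)
    return h, c

n = 3
rng = np.random.default_rng(0)
W = rng.standard_normal((4 * n, 2 * n))
h, c = lstm_step(rng.standard_normal(n), np.zeros(n), np.zeros(n), W)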

Page 84: Recurrent Networks and LSTM deep dive

RNN:
h_t = tanh(W_hh · h_{t−1} + W_hx · x)

LSTM:
(i, f, o, g) = (σ, σ, σ, tanh) · W · (x; h_{t−1})
c_t = f · c_{t−1} + i · g
h_t = o · tanh(c_t)

f: forget gate, 0/1

i: input gate, 0/1

Page 85: Recurrent Networks and LSTM deep dive

Long Short-Term Memory (LSTM)

[Diagram: the incoming cell state c_{t−1} is multiplied by the forget gate f, the gated candidate i · g is added to it, and the new cell state passes through tanh and the output gate o to produce h_t]

(i, f, o, g) = (σ, σ, σ, tanh) · W · (x; h_{t−1})

c_t = f · c_{t−1} + i · g

h_t = o · tanh(c_t)

Page 86: Recurrent Networks and LSTM deep dive

Flow of gradient

RNN: at every step (t−1, t, t+1, …) the gradient is multiplied by w_hh:

∂L/∂x = w_hh · … · w_hh · … · w_hh = w_hh^n · C(w)

LSTM: the cell state is updated additively, c_t = f · c_{t−1} + i · g, so the gradient flows backward through the "+" nodes from step to step, scaled only by the forget gates f.
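Along the LSTM cell state, the local derivative ∂c_t/∂c_{t−1} is just the forget gate f, so the gradient surviving n steps is a product of forget gates rather than w_hh^n. An illustrative comparison; the forget-gate value 0.99 is an assumption, standing in for a gate the network keeps open:

w_hh = 0.024  # the RNN weight from the worked example
f = 0.99      # a forget gate held close to 1

print(w_hh ** 50)  # ~1e-81: the RNN gradient vanishes after 50 steps
print(f ** 50)     # ~0.61:  the LSTM gradient survives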

Page 89: Recurrent Networks and LSTM deep dive

References

1. Long Short-Term Memory (Hochreiter & Schmidhuber, 1997), http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

2. Learning Long-Term Dependencies with Gradient Descent is Difficult (Bengio et al., 1994), http://www.dsi.unifi.it/~paolo/ps/tnn-94-gradient.pdf

3. Neural Networks and Deep Learning (Michael Nielsen), chapter 5, http://neuralnetworksanddeeplearning.com/chap5.html

4. Deep Learning, Ian Goodfellow et al., The MIT Press

5. Recurrent Neural Networks, LSTM, Andrej Karpathy, Stanford lectures, https://www.youtube.com/watch?v=iX5V1WpxxkY

Alex Kalinin [email protected]