
A Tutorial On Backward Propagation Through Time (BPTT)
In The Gated Recurrent Unit (GRU) RNN

Minchen Li
Department of Computer Science
The University of British Columbia
[email protected]

Abstract

In this tutorial, we provide a thorough explanation of how BPTT in the GRU¹ is conducted. A MATLAB program which implements the entire BPTT for the GRU, together with pseudo-code describing the algorithms explicitly, will be presented. We provide two algorithms for BPTT: a direct but quadratic-time algorithm for easy understanding, and an optimized linear-time algorithm. This tutorial starts with a specification of the problem, followed by a mathematical derivation, before the computational solutions.

1 Specification

We want to use a dataset containing $n_s$ sentences, each with $n_w$ words, to train a GRU language model, and our vocabulary size is $n_v$. Namely, we have input $x \in \mathbb{R}^{n_v \times n_w \times n_s}$ and label $y \in \mathbb{R}^{n_v \times n_w \times n_s}$, both representing $n_s$ sentences.

For simplicity, let's look at one sentence at a time. In one sentence, the one-hot vector $x_t \in \mathbb{R}^{n_v \times 1}$ represents the $t$-th word. For time step $t$, the GRU unit computes the output $\hat{y}_t$ using the input $x_t$ and the previous internal state $s_{t-1}$ as follows:

$$
\begin{aligned}
z_t &= \sigma(U_z x_t + W_z s_{t-1} + b_z) \\
r_t &= \sigma(U_r x_t + W_r s_{t-1} + b_r) \\
h_t &= \tanh\!\left(U_h x_t + W_h (s_{t-1} \odot r_t) + b_h\right) \\
s_t &= (1 - z_t) \odot h_t + z_t \odot s_{t-1} \\
\hat{y}_t &= \mathrm{softmax}(V s_t + b_V)
\end{aligned}
\tag{1}
$$

Here $\odot$ is the vector element-wise multiplication, $\sigma(\cdot)$ is the element-wise sigmoid function, and $\tanh(\cdot)$ is the element-wise hyperbolic tangent function. The dimensions of the parameters are as follows:

$$
\begin{aligned}
U_z, U_r, U_h &\in \mathbb{R}^{n_i \times n_v} \\
W_z, W_r, W_h &\in \mathbb{R}^{n_i \times n_i} \\
b_z, b_r, b_h &\in \mathbb{R}^{n_i \times 1} \\
V &\in \mathbb{R}^{n_v \times n_i}, \quad b_V \in \mathbb{R}^{n_v \times 1}
\end{aligned}
$$

where $n_i$ is the internal memory size set by the user.

¹ GRU is an improved version of the traditional RNN (Recurrent Neural Network); see WildML.com for an introduction. That site also provides an introduction to GRU and some general discussion on BPTT and beyond.


Then for step $t$, we can calculate the cross-entropy loss $L_t$ as:

$$
L_t = \mathrm{sumOfAllElements}\big( -y_t \odot \log(\hat{y}_t) \big)
\tag{2}
$$

Here log is also an element-wise function.
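To make Eq. 1 and Eq. 2 concrete, here is a minimal MATLAB sketch of one forward step and its loss (a sketch only; the variable names, including s_prev for $s_{t-1}$, are ours, and a numerically stabilized softmax is assumed):

% One GRU forward step (Eq. 1) plus the step loss (Eq. 2).
% x_t, y_t are one-hot column vectors; s_prev is the previous internal state.
sigmoid  = @(a) 1 ./ (1 + exp(-a));
softmaxf = @(a) exp(a - max(a)) / sum(exp(a - max(a)));  % stable softmax
z_t = sigmoid(U_z*x_t + W_z*s_prev + b_z);
r_t = sigmoid(U_r*x_t + W_r*s_prev + b_r);
h_t = tanh(U_h*x_t + W_h*(s_prev .* r_t) + b_h);
s_t = (1 - z_t) .* h_t + z_t .* s_prev;
y_hat_t = softmaxf(V*s_t + b_V);
L_t = sum(-y_t .* log(y_hat_t));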

To train the GRU, we want to find the values of all parameters that minimize the total loss $L = \sum_{t=1}^{n_w} L_t$:

$$
\arg\min_{\Theta} L
$$

where $\Theta = \{U_z, U_r, U_h, W_z, W_r, W_h, b_z, b_r, b_h, V, b_V\}$. This is a non-convex problem with huge input data, so people usually use the Stochastic Gradient Descent² method to solve it, which means we need to calculate $\partial L/\partial U_z$, $\partial L/\partial U_r$, $\partial L/\partial U_h$, $\partial L/\partial W_z$, $\partial L/\partial W_r$, $\partial L/\partial W_h$, $\partial L/\partial b_z$, $\partial L/\partial b_r$, $\partial L/\partial b_h$, $\partial L/\partial V$, and $\partial L/\partial b_V$ given a batch of sentences. (Note that within each step, these parameters stay the same.) In this tutorial we consider using only one sentence at a time to keep the presentation concise.
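For concreteness, a single SGD iteration then moves each parameter against its gradient. A minimal sketch, assuming a learning rate eta (the value and names here are ours, not part of the program in Chapter 4):

% One (hypothetical) SGD step after BPTT has produced the gradients.
eta = 0.01;              % learning rate, an assumed value
U_z = U_z - eta*dU_z;    % and likewise for U_r, U_h, W_z, W_r, W_h,
b_z = b_z - eta*db_z;    % b_r, b_h, and b_V
V   = V   - eta*dV;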

2 Derivation

The best way to calculate gradients using the Chain Rule from output to input is to first draw the expression graph of the entire model in order to figure out the relations between the output, intermediate results, and the input³. Here we draw part of the expression graph of GRU in Fig. 1.

Figure 1: The upper part of the expression graph describing the operations of GRU. Note that the sub-graph which $s_{t-1}$ depends on is just like the sub-graph of $s_t$. This is what the red dashed lines mean.

With this expression graph, the Chain Rule works if you go backwards along the edges (top-down). If a node $X$ has multiple outgoing edges connecting it to the target node $T$, you need to sum the partial derivatives over each of those outgoing edges to derive the gradient $\partial T/\partial X$. We will illustrate these rules in the following paragraphs.
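For instance, in Eq. 1 the node $z_t$ reaches $s_t$ through both $(1 - z_t) \odot h_t$ and $z_t \odot s_{t-1}$, so summing over the two edges gives $\partial s_t/\partial z_t = -h_t + s_{t-1} = s_{t-1} - h_t$, which is exactly the factor that will appear in Eq. 4.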

Let's take $\partial L/\partial U_z$ as the example here; the others are similar. Since $L = \sum_{t=1}^{n_w} L_t$ and the parameters stay the same in each step, we also have $\partial L/\partial U_z = \sum_{t=1}^{n_w} (\partial L_t/\partial U_z)$, so let's calculate each $\partial L_t/\partial U_z$ independently and sum them up.

² See Wikipedia for background on Stochastic Gradient Descent.
³ See colah's blog and the Stanford CS231n course notes for some general introductions.


With the Chain Rule, we have:

$$
\frac{\partial L_t}{\partial U_z} = \frac{\partial L_t}{\partial s_t} \frac{\partial s_t}{\partial U_z}
\tag{3}
$$

The first part is trivial if you know how to differentiate the cross-entropy loss function composed with the softmax function:

$$
\frac{\partial L_t}{\partial s_t} = V^T (\hat{y}_t - y_t)
$$
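This derivative is easy to verify numerically. Here is a minimal MATLAB sketch of such a check using the centered difference formula (all sizes and names are ours; the program in Chapter 4 performs the same kind of check for every parameter):

% Numerically check dL_t/ds_t = V'*(y_hat - y) on random data.
n_v = 5; n_i = 3;
V = rand(n_v, n_i); b_V = rand(n_v, 1);
s = rand(n_i, 1);
y = zeros(n_v, 1); y(2) = 1;                             % a one-hot label
softmaxf = @(a) exp(a - max(a)) / sum(exp(a - max(a)));  % stable softmax
lossf = @(sv) sum(-y .* log(softmaxf(V*sv + b_V)));      % L_t as a function of s_t
g_analytic = V' * (softmaxf(V*s + b_V) - y);
h = 1e-5; g_numeric = zeros(n_i, 1);
for k = 1:n_i
    e = zeros(n_i, 1); e(k) = h;
    g_numeric(k) = (lossf(s + e) - lossf(s - e)) / (2*h);
end
disp(max(abs(g_analytic - g_numeric)));                  % on the order of 1e-10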

For $\partial s_t/\partial U_z$, similarly, some people might directly derive (if they know how to differentiate the sigmoid function):

$$
\overline{\frac{\partial s_t}{\partial U_z}} = \big( (s_{t-1} - h_t) \odot z_t \odot (1 - z_t) \big)\, x_t^T
\tag{4}
$$

Here there are two expressions, $1 - z_t$ and $z_t \odot s_{t-1}$, influencing $\partial s_t/\partial z_t$, as shown in our expression graph. The solution is to derive the partial derivatives through each edge and then add them up, which is exactly how we deal with $\partial s_t/\partial s_{t-1}$, as you will see in the following paragraphs. However, Eq. 4 only calculates one part of the gradient (treating $s_{t-1}$ as a constant), so we put a bar on top of it; you will find this partial gradient very useful in our following calculations.

Note that $s_{t-1}$ also depends on $U_z$, so we cannot treat it as a constant here. Moreover, this $s_{t-1}$ will also introduce the influence of $s_i$, where $i = 1, \ldots, t-2$. So for clarity, we should expand Eq. 3 as:

$$
\frac{\partial L_t}{\partial U_z}
= \frac{\partial L_t}{\partial s_t} \frac{\partial s_t}{\partial U_z}
= \frac{\partial L_t}{\partial s_t} \sum_{i=1}^{t} \left( \frac{\partial s_t}{\partial s_i} \overline{\frac{\partial s_i}{\partial U_z}} \right)
= \frac{\partial L_t}{\partial s_t} \sum_{i=1}^{t} \left( \left( \prod_{j=i}^{t-1} \frac{\partial s_{j+1}}{\partial s_j} \right) \overline{\frac{\partial s_i}{\partial U_z}} \right)
\tag{5}
$$

where $\overline{\partial s_i/\partial U_z}$ is the gradient of $s_i$ with respect to $U_z$ while taking $s_{i-1}$ as a constant; a similar example for step $t$ has been shown in Eq. 4.

The derivation of $\partial s_t/\partial s_{t-1}$ is similar to the derivation of $\partial s_t/\partial z_t$ discussed above. Since there are four outgoing edges from $s_{t-1}$, reaching $s_t$ directly and indirectly through $z_t$, $r_t$, and $h_t$ in the expression graph, we need to sum all four partial derivatives together:

$$
\begin{aligned}
\frac{\partial s_t}{\partial s_{t-1}}
&= \frac{\partial s_t}{\partial h_t} \frac{\partial h_t}{\partial s_{t-1}}
 + \frac{\partial s_t}{\partial z_t} \frac{\partial z_t}{\partial s_{t-1}}
 + \overline{\frac{\partial s_t}{\partial s_{t-1}}} \\
&= \frac{\partial s_t}{\partial h_t} \left( \frac{\partial h_t}{\partial r_t} \frac{\partial r_t}{\partial s_{t-1}} + \overline{\frac{\partial h_t}{\partial s_{t-1}}} \right)
 + \frac{\partial s_t}{\partial z_t} \frac{\partial z_t}{\partial s_{t-1}}
 + \overline{\frac{\partial s_t}{\partial s_{t-1}}}
\end{aligned}
\tag{6}
$$

where $\overline{\partial s_t/\partial s_{t-1}}$ is the gradient of $s_t$ with respect to $s_{t-1}$ while taking $h_t$ and $z_t$ as constants. Similarly, $\overline{\partial h_t/\partial s_{t-1}}$ is the gradient of $h_t$ with respect to $s_{t-1}$ while taking $r_t$ as a constant.

Plugging the intermediate results into the above formula, we get:

$$
\begin{aligned}
\frac{\partial s_t}{\partial s_{t-1}}
={}& (1 - z_t) \odot \Big( W_r^T \big( (W_h^T (1 - h_t \odot h_t)) \odot s_{t-1} \odot r_t \odot (1 - r_t) \big) \\
&\qquad\qquad + (W_h^T (1 - h_t \odot h_t)) \odot r_t \Big) \\
&+ W_z^T \big( (s_{t-1} - h_t) \odot z_t \odot (1 - z_t) \big) + z_t
\end{aligned}
$$

Note that this expression is written in the element-wise vector notation used by the backward pass: it is applied to an incoming gradient vector rather than formed as an explicit Jacobian matrix.
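Concretely, one back-propagates the vector $\delta_t = \partial L_t/\partial s_t$ through one step to obtain $\delta_{t-1}$. A minimal MATLAB sketch of that one-step back-propagation, consistent with the backward() function in Chapter 4 (names are ours; as in the program, the candidate $h_t$ is stored in c and $s_{t-1}$ in s_prev):

% Back-propagate delta (= dL/ds_t) one step to delta_prev (= dL/ds_{t-1}),
% applying Eq. 6 as vector operations.
dtanhInput  = delta .* (1 - z) .* (1 - c.*c);    % into the tanh pre-activation of h_t
ds_r        = W_c' * dtanhInput;                 % gradient w.r.t. (s_prev .* r)
delta_prev  = ds_r .* r;                         % edge s_{t-1} -> h_t (r fixed)
dsigInput_r = ds_r .* s_prev .* r .* (1 - r);
delta_prev  = delta_prev + W_r' * dsigInput_r;   % edge s_{t-1} -> r_t -> h_t
dsigInput_z = (delta .* (s_prev - c)) .* z .* (1 - z);
delta_prev  = delta_prev + W_z' * dsigInput_z;   % edge s_{t-1} -> z_t -> s_t
delta_prev  = delta_prev + delta .* z;           % direct edge s_{t-1} -> s_t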

Till now, we have covered all the components needed to calculate $\partial L_t/\partial U_z$. The gradients of $L_t$ with respect to the other parameters are similar. In the next chapter, we will provide a more mechanical view of the calculation: the pseudo-code describing the algorithms that calculate the gradients. In the last chapter of this tutorial, we will provide the pure machine representation: a MATLAB program which implements the calculation and verification of BPTT. If you just want to understand the idea behind BPTT and decide to use fully supported auto-differentiation packages (like Theano⁴) to build your own GRU, you can stop here. If you need to implement the exact chain rule like us, or are just curious about what happens next, get ready to proceed!

⁴ Theano is a Python library that allows you to define, optimize, and evaluate mathematical expressions involving multi-dimensional arrays efficiently.


3 Algorithm

Here we also only take $\partial L/\partial U_z$ as the example. We will provide the calculation of all the gradients in the next chapter.

We present two algorithms: a direct one that, as derived previously, calculates each $\partial L_t/\partial U_z$ and sums them up, taking $O(n_w^2)$ time; and an optimized $O(n_w)$ time algorithm, which we will see later.

Algorithm 1 A direct but $O(n_w^2)$ time algorithm to calculate $\partial L/\partial U_z$ (and beyond)

Input: The training data $X, Y \in \mathbb{R}^{n_v \times n_w}$, composed of the one-hot column vectors $x_t, y_t \in \mathbb{R}^{n_v \times 1}$, $t = 1, 2, \ldots, n_w$, representing the words in the sentence.
Input: A vector $s_0 \in \mathbb{R}^{n_i \times 1}$ representing the initial internal state of the model (usually set to 0).
Input: The parameters $\Theta = \{U_z, U_r, U_h, W_z, W_r, W_h, b_z, b_r, b_h, V, b_V\}$ of the model.
Output: The total loss gradient $\partial L/\partial U_z$.

1: % forward propagate to calculate the internal states S ∈ R^{n_i×n_w}, the predictions Ŷ ∈ R^{n_v×n_w}, the losses L_mtr ∈ R^{n_w×1}, and the intermediate results Z, R, H ∈ R^{n_i×n_w} of each step:
2: [S, Ŷ, L_mtr, Z, R, H] = forward(X, Y, Θ, s_0)  % forward() can be implemented easily according to Eq. 1 and Eq. 2
3: dU_z = zeros(n_i, n_v)  % initialize a variable dU_z
4: ∂L_mtr/∂S = Vᵀ(Ŷ − Y)  % calculate ∂L_t/∂s_t for t = 1, 2, ..., n_w with one matrix operation
5: for t ← 1 to n_w  % calculate each ∂L_t/∂U_z and accumulate
6:     for j ← t to 1  % calculate each (∂L_t/∂s_j)(∂s_j/∂U_z) and accumulate
7:         ∂L_t/∂z_j = ∂L_t/∂s_j ⊙ (s_{j−1} − h_j)  % ∂s_j/∂z_j is (s_{j−1} − h_j); ∂L_t/∂s_j was calculated in the last inner-loop iteration or in Line 4
8:         ∂L_t/∂(U_z x_j + W_z s_{j−1} + b_z) = ∂L_t/∂z_j ⊙ z_j ⊙ (1 − z_j)  % ∂σ(x)/∂x = σ(x) ⊙ (1 − σ(x))
9:         dU_z += (∂L_t/∂(U_z x_j + W_z s_{j−1} + b_z)) x_jᵀ  % accumulate
10:        calculate ∂L_t/∂s_{j−1} using ∂L_t/∂s_j and Eq. 6  % for the next inner-loop iteration
11:    end
12: end
13: return dU_z  % ∂L/∂U_z

The above direct algorithm actually follows Eq. 5 to calculate each $\partial L_t/\partial U_z$ and then adds them up to form $\partial L/\partial U_z$:

$$
\frac{\partial L}{\partial U_z}
= \sum_{t=1}^{n_w} \frac{\partial L_t}{\partial U_z}
= \sum_{t=1}^{n_w} \left( \frac{\partial L_t}{\partial s_t} \sum_{i=1}^{t} \left( \frac{\partial s_t}{\partial s_i} \overline{\frac{\partial s_i}{\partial U_z}} \right) \right)
= \sum_{t=1}^{n_w} \left( \frac{\partial L_t}{\partial s_t} \sum_{i=1}^{t} \left( \left( \prod_{j=i}^{t-1} \frac{\partial s_{j+1}}{\partial s_j} \right) \overline{\frac{\partial s_i}{\partial U_z}} \right) \right)
$$

If we expand $\partial L_t/\partial U_z$ only up to the second form of the above equation and do some reordering, we can get:

$$
\frac{\partial L}{\partial U_z}
= \sum_{t=1}^{n_w} \left( \frac{\partial L_t}{\partial s_t} \sum_{i=1}^{t} \left( \frac{\partial s_t}{\partial s_i} \overline{\frac{\partial s_i}{\partial U_z}} \right) \right)
= \sum_{t=1}^{n_w} \left( \sum_{i=1}^{t} \left( \frac{\partial L_t}{\partial s_t} \frac{\partial s_t}{\partial s_i} \overline{\frac{\partial s_i}{\partial U_z}} \right) \right)
= \sum_{t=1}^{n_w} \left( \sum_{i=1}^{t} \left( \frac{\partial L_t}{\partial s_i} \overline{\frac{\partial s_i}{\partial U_z}} \right) \right)
$$


Right now the inner summation keeps the subscript of $\partial L_t$ fixed and iterates over $\partial s_i$. If we further expand the inner summation and then sort the terms to iterate over $\partial L_i$ instead, we get:

$$
\frac{\partial L}{\partial U_z}
= \sum_{t=1}^{n_w} \left( \left( \sum_{i=t}^{n_w} \frac{\partial L_i}{\partial s_t} \right) \overline{\frac{\partial s_t}{\partial U_z}} \right)
\tag{7}
$$

For the inner summation of Eq. 7, we have:

$$
\sum_{i=t}^{n_w} \frac{\partial L_i}{\partial s_t}
= \left( \sum_{i=t+1}^{n_w} \left( \frac{\partial L_i}{\partial s_{t+1}} \frac{\partial s_{t+1}}{\partial s_t} \right) \right) + \frac{\partial L_t}{\partial s_t}
= \left( \sum_{i=t+1}^{n_w} \frac{\partial L_i}{\partial s_{t+1}} \right) \frac{\partial s_{t+1}}{\partial s_t} + \frac{\partial L_t}{\partial s_t}
\tag{8}
$$

This just gives us an updating formula to calculate this inner summation for each step $t$ incrementally rather than executing another for-loop, thus making it possible for us to implement an $O(n_w)$ time algorithm!
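In code, Eq. 8 amounts to maintaining a single running vector across the backward loop. A minimal sketch, where ds_single(:, t) holds $\partial L_t/\partial s_t$ from Line 4 of the algorithms and backprop_one_step is a hypothetical helper applying Eq. 6 (as sketched after that equation):

% Running accumulator for Eq. 8: delta_acc holds sum_{i=t}^{n_w} dL_i/ds_t.
delta_acc = zeros(n_i, 1);
for t = n_w:-1:1
    delta_acc = delta_acc + ds_single(:, t);      % add dL_t/ds_t
    % ... use delta_acc to accumulate dU_z for step t (Lines 6-8) ...
    delta_acc = backprop_one_step(delta_acc, t);  % multiply by ds_t/ds_{t-1} (Eq. 6)
end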

Algorithm 2 An optimized $O(n_w)$ time algorithm to calculate $\partial L/\partial U_z$ (and beyond)

Input: The training data $X, Y \in \mathbb{R}^{n_v \times n_w}$, composed of the one-hot column vectors $x_t, y_t \in \mathbb{R}^{n_v \times 1}$, $t = 1, 2, \ldots, n_w$, representing the words in the sentence.
Input: A vector $s_0 \in \mathbb{R}^{n_i \times 1}$ representing the initial internal state of the model (usually set to 0).
Input: The parameters $\Theta = \{U_z, U_r, U_h, W_z, W_r, W_h, b_z, b_r, b_h, V, b_V\}$ of the model.
Output: The total loss gradient $\partial L/\partial U_z$.

1: % forward propagate to calculate the internal states S ∈ R^{n_i×n_w}, the predictions Ŷ ∈ R^{n_v×n_w}, the losses L_mtr ∈ R^{n_w×1}, and the intermediate results Z, R, H ∈ R^{n_i×n_w} of each step:
2: [S, Ŷ, L_mtr, Z, R, H] = forward(X, Y, Θ, s_0)  % forward() can be implemented easily according to Eq. 1 and Eq. 2
3: dU_z = zeros(n_i, n_v)  % initialize a variable dU_z
4: ∂L_mtr/∂S = Vᵀ(Ŷ − Y)  % calculate ∂L_t/∂s_t for t = 1, 2, ..., n_w with one matrix operation
5: for t ← n_w to 1  % calculate each (Σ_{i=t}^{n_w} ∂L_i/∂s_t) ∂s_t/∂U_z and accumulate
6:     Σ_{i=t}^{n_w}(∂L_i/∂z_t) = (Σ_{i=t}^{n_w}(∂L_i/∂s_t)) ⊙ (s_{t−1} − h_t)  % ∂s_t/∂z_t is (s_{t−1} − h_t); Σ_{i=t}^{n_w}(∂L_i/∂s_t) was calculated in the last iteration or in Line 4 (when t = n_w, Σ_{i=t}^{n_w}(∂L_i/∂s_t) = ∂L_t/∂s_t)
7:     Σ_{i=t}^{n_w}(∂L_i/∂(U_z x_t + W_z s_{t−1} + b_z)) = (Σ_{i=t}^{n_w}(∂L_i/∂z_t)) ⊙ z_t ⊙ (1 − z_t)  % ∂σ(x)/∂x = σ(x) ⊙ (1 − σ(x))
8:     dU_z += (Σ_{i=t}^{n_w}(∂L_i/∂(U_z x_t + W_z s_{t−1} + b_z))) x_tᵀ  % accumulate
9:     calculate Σ_{i=t−1}^{n_w}(∂L_i/∂s_{t−1}) using Eq. 6 and Eq. 8  % for the next iteration
10: end
11: return dU_z  % ∂L/∂U_z


4 Implementation

Here we provide the MATLAB program which calculates the gradients with respect to all the parameters of GRU using our two proposed algorithms. It also checks the gradients against numerical results. We divide our code into two parts: the first part, presented below, contains the core functions implementing the BPTT of GRU we just derived; the second part is composed of some functions that are less important to the topic of this tutorial. (Note that the program names the candidate state $h_t$ and its parameters $U_h$, $W_h$, $b_h$ as c, U_c, W_c, and b_c.)

Core Functions

% This program tests the BPTT process we manually developed for GRU.
% We calculate the gradients of GRU parameters with the chain rule, and then
% compare them to the numerical gradients to check whether our chain rule
% derivation is correct.
%
% Here, we provide 2 versions of BPTT, backward_direct() and backward().
% The former is the direct idea to calculate the gradient within each step
% and add them up (O(sentence_size^2) time). The latter is optimized to
% calculate the contribution of each step to the overall gradient, which
% takes only O(sentence_size) time.
%
% This is very helpful for people who want to implement GRU in Caffe since
% Caffe doesn't support auto-differentiation. It is also very helpful for
% people who want to know the details of the Backpropagation Through Time
% algorithm in Recurrent Neural Networks (such as GRU and LSTM) and to get
% a sense of how auto-differentiation is possible.
%
% NOTE: We didn't involve SGD training here. With SGD training, this
% program would become a complete implementation of GRU which can be
% trained with sequence data. However, since this is only a CPU serial
% Matlab version of GRU, applying it on large datasets will be
% dramatically slow.
%
% by Minchen Li, at The University of British Columbia. 2016-04-21

function testBPTT_GRU
    % set GRU and data scale
    vocabulary_size = 64;
    iMem_size = 4;
    sentence_size = 20;  % number of words in a sentence
                         % (including start and end symbol);
                         % since we will only use one sentence for training,
                         % this is also the total number of steps during training.

    [x, y] = getTrainingData(vocabulary_size, sentence_size);

    % initialize parameters:
    % multipliers for input x_t of intermediate variables
    U_z = rand(iMem_size, vocabulary_size);
    U_r = rand(iMem_size, vocabulary_size);
    U_c = rand(iMem_size, vocabulary_size);
    % multipliers for previous s of intermediate variables
    W_z = rand(iMem_size, iMem_size);
    W_r = rand(iMem_size, iMem_size);
    W_c = rand(iMem_size, iMem_size);
    % bias terms of intermediate variables
    b_z = rand(iMem_size, 1);
    b_r = rand(iMem_size, 1);
    b_c = rand(iMem_size, 1);
    % decoder for generating output
    V = rand(vocabulary_size, iMem_size);
    b_V = rand(vocabulary_size, 1);  % bias of decoder
    % previous s of step 1
    s_0 = rand(iMem_size, 1);

    % calculate and check gradients
    tic
    [dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0] = ...
        backward_direct(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
    toc
    tic
    checkGradient_GRU(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0, ...
        dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0);
    toc

    tic
    [dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0] = ...
        backward(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
    toc
    tic
    checkGradient_GRU(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0, ...
        dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0);
    toc
end

% Forward propagate to calculate s, y_hat, loss and intermediate variables
% for each step
function [s, y_hat, L, z, r, c] = forward(x, y, ...
        U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0)
    % count sizes
    [vocabulary_size, sentence_size] = size(x);
    iMem_size = size(V, 2);

    % initialize results
    s = zeros(iMem_size, sentence_size);
    y_hat = zeros(vocabulary_size, sentence_size);
    L = zeros(sentence_size, 1);
    z = zeros(iMem_size, sentence_size);
    r = zeros(iMem_size, sentence_size);
    c = zeros(iMem_size, sentence_size);

    % calculate result for step 1 since s_0 is not in s
    % (softmax() here is the column-wise Neural Network Toolbox transfer function)
    z(:,1) = sigmoid(U_z*x(:,1) + W_z*s_0 + b_z);
    r(:,1) = sigmoid(U_r*x(:,1) + W_r*s_0 + b_r);
    c(:,1) = tanh(U_c*x(:,1) + W_c*(s_0.*r(:,1)) + b_c);
    s(:,1) = (1-z(:,1)).*c(:,1) + z(:,1).*s_0;
    y_hat(:,1) = softmax(V*s(:,1) + b_V);
    L(1) = sum(-y(:,1).*log(y_hat(:,1)));
    % calculate results for steps 2 - sentence_size similarly
    for wordI = 2:sentence_size
        z(:,wordI) = sigmoid(U_z*x(:,wordI) + W_z*s(:,wordI-1) + b_z);
        r(:,wordI) = sigmoid(U_r*x(:,wordI) + W_r*s(:,wordI-1) + b_r);
        c(:,wordI) = tanh(U_c*x(:,wordI) + W_c*(s(:,wordI-1).*r(:,wordI)) + b_c);
        s(:,wordI) = (1-z(:,wordI)).*c(:,wordI) + z(:,wordI).*s(:,wordI-1);
        y_hat(:,wordI) = softmax(V*s(:,wordI) + b_V);
        L(wordI) = sum(-y(:,wordI).*log(y_hat(:,wordI)));
    end
end

% Backward propagate to calculate gradient using chain rule
% (O(sentence_size) time)
function [dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0] = ...
        backward(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0)
    % forward propagate to get the intermediate and output results
    [s, y_hat, L, z, r, c] = forward(x, y, U_z, U_r, U_c, W_z, W_r, W_c, ...
        b_z, b_r, b_c, V, b_V, s_0);
    % count sentence size
    [~, sentence_size] = size(x);

    % calculate gradient using chain rule
    delta_y = y_hat - y;
    db_V = sum(delta_y, 2);

    dV = zeros(size(V));
    for wordI = 1:sentence_size
        dV = dV + delta_y(:,wordI)*s(:,wordI)';
    end

    ds_0 = zeros(size(s_0));
    dU_c = zeros(size(U_c));
    dU_r = zeros(size(U_r));
    dU_z = zeros(size(U_z));
    dW_c = zeros(size(W_c));
    dW_r = zeros(size(W_r));
    dW_z = zeros(size(W_z));
    db_z = zeros(size(b_z));
    db_r = zeros(size(b_r));
    db_c = zeros(size(b_c));
    ds_single = V'*delta_y;
    % calculate the derivative contribution of each step and add them up
    ds_cur = zeros(size(ds_single, 1), 1);
    for wordJ = sentence_size:-1:2
        ds_cur = ds_cur + ds_single(:,wordJ);
        ds_cur_bk = ds_cur;

        dtanhInput = (ds_cur.*(1-z(:,wordJ)).*(1-c(:,wordJ).*c(:,wordJ)));
        db_c = db_c + dtanhInput;
        dU_c = dU_c + dtanhInput*x(:,wordJ)';   % could be accelerated by avoiding add 0
        dW_c = dW_c + dtanhInput*(s(:,wordJ-1).*r(:,wordJ))';
        ds_r = W_c'*dtanhInput;
        ds_cur = ds_r.*r(:,wordJ);
        dsigInput_r = ds_r.*s(:,wordJ-1).*r(:,wordJ).*(1-r(:,wordJ));
        db_r = db_r + dsigInput_r;
        dU_r = dU_r + dsigInput_r*x(:,wordJ)';  % could be accelerated by avoiding add 0
        dW_r = dW_r + dsigInput_r*s(:,wordJ-1)';
        ds_cur = ds_cur + W_r'*dsigInput_r;

        ds_cur = ds_cur + ds_cur_bk.*z(:,wordJ);
        dz = ds_cur_bk.*(s(:,wordJ-1)-c(:,wordJ));
        dsigInput_z = dz.*z(:,wordJ).*(1-z(:,wordJ));
        db_z = db_z + dsigInput_z;
        dU_z = dU_z + dsigInput_z*x(:,wordJ)';  % could be accelerated by avoiding add 0
        dW_z = dW_z + dsigInput_z*s(:,wordJ-1)';
        ds_cur = ds_cur + W_z'*dsigInput_z;
    end

    % s_1
    ds_cur = ds_cur + ds_single(:,1);

    dtanhInput = (ds_cur.*(1-z(:,1)).*(1-c(:,1).*c(:,1)));
    db_c = db_c + dtanhInput;
    dU_c = dU_c + dtanhInput*x(:,1)';   % could be accelerated by avoiding add 0
    dW_c = dW_c + dtanhInput*(s_0.*r(:,1))';
    ds_r = W_c'*dtanhInput;
    ds_0 = ds_0 + ds_r.*r(:,1);
    dsigInput_r = ds_r.*s_0.*r(:,1).*(1-r(:,1));
    db_r = db_r + dsigInput_r;
    dU_r = dU_r + dsigInput_r*x(:,1)';  % could be accelerated by avoiding add 0
    dW_r = dW_r + dsigInput_r*s_0';
    ds_0 = ds_0 + W_r'*dsigInput_r;

    ds_0 = ds_0 + ds_cur.*z(:,1);
    dz = ds_cur.*(s_0-c(:,1));
    dsigInput_z = dz.*z(:,1).*(1-z(:,1));
    db_z = db_z + dsigInput_z;
    dU_z = dU_z + dsigInput_z*x(:,1)';  % could be accelerated by avoiding add 0
    dW_z = dW_z + dsigInput_z*s_0';
    ds_0 = ds_0 + W_z'*dsigInput_z;
end

% A more direct view of backward propagation to calculate gradient using
% chain rule. (O(sentence_size^2) time)
% Instead of calculating how much derivative contribution each step has,
% here we calculate the gradient within every step.
function [dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0] = ...
        backward_direct(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0)
    % forward propagate to get the intermediate and output results
    [s, y_hat, L, z, r, c] = forward(x, y, U_z, U_r, U_c, W_z, W_r, W_c, ...
        b_z, b_r, b_c, V, b_V, s_0);
    % count sentence size
    [~, sentence_size] = size(x);

    % calculate gradient using chain rule
    delta_y = y_hat - y;
    db_V = sum(delta_y, 2);

    dV = zeros(size(V));
    for wordI = 1:sentence_size
        dV = dV + delta_y(:,wordI)*s(:,wordI)';
    end

    ds_0 = zeros(size(s_0));
    dU_c = zeros(size(U_c));
    dU_r = zeros(size(U_r));
    dU_z = zeros(size(U_z));
    dW_c = zeros(size(W_c));
    dW_r = zeros(size(W_r));
    dW_z = zeros(size(W_z));
    db_z = zeros(size(b_z));
    db_r = zeros(size(b_r));
    db_c = zeros(size(b_c));
    ds_single = V'*delta_y;
    % calculate the derivatives in each step and add them up
    for wordI = 1:sentence_size
        ds_cur = ds_single(:,wordI);
        % since in each step t, the derivatives depend on s_0 - s_t,
        % we need to trace back from t to 0 each time
        for wordJ = wordI:-1:2
            ds_cur_bk = ds_cur;

            dtanhInput = (ds_cur.*(1-z(:,wordJ)).*(1-c(:,wordJ).*c(:,wordJ)));
            db_c = db_c + dtanhInput;
            dU_c = dU_c + dtanhInput*x(:,wordJ)';   % could be accelerated by avoiding add 0
            dW_c = dW_c + dtanhInput*(s(:,wordJ-1).*r(:,wordJ))';
            ds_r = W_c'*dtanhInput;
            ds_cur = ds_r.*r(:,wordJ);
            dsigInput_r = ds_r.*s(:,wordJ-1).*r(:,wordJ).*(1-r(:,wordJ));
            db_r = db_r + dsigInput_r;
            dU_r = dU_r + dsigInput_r*x(:,wordJ)';  % could be accelerated by avoiding add 0
            dW_r = dW_r + dsigInput_r*s(:,wordJ-1)';
            ds_cur = ds_cur + W_r'*dsigInput_r;

            ds_cur = ds_cur + ds_cur_bk.*z(:,wordJ);
            dz = ds_cur_bk.*(s(:,wordJ-1)-c(:,wordJ));
            dsigInput_z = dz.*z(:,wordJ).*(1-z(:,wordJ));
            db_z = db_z + dsigInput_z;
            dU_z = dU_z + dsigInput_z*x(:,wordJ)';  % could be accelerated by avoiding add 0
            dW_z = dW_z + dsigInput_z*s(:,wordJ-1)';
            ds_cur = ds_cur + W_z'*dsigInput_z;
        end

        % s_1
        dtanhInput = (ds_cur.*(1-z(:,1)).*(1-c(:,1).*c(:,1)));
        db_c = db_c + dtanhInput;
        dU_c = dU_c + dtanhInput*x(:,1)';   % could be accelerated by avoiding add 0
        dW_c = dW_c + dtanhInput*(s_0.*r(:,1))';
        ds_r = W_c'*dtanhInput;
        ds_0 = ds_0 + ds_r.*r(:,1);
        dsigInput_r = ds_r.*s_0.*r(:,1).*(1-r(:,1));
        db_r = db_r + dsigInput_r;
        dU_r = dU_r + dsigInput_r*x(:,1)';  % could be accelerated by avoiding add 0
        dW_r = dW_r + dsigInput_r*s_0';
        ds_0 = ds_0 + W_r'*dsigInput_r;

        ds_0 = ds_0 + ds_cur.*z(:,1);
        dz = ds_cur.*(s_0-c(:,1));
        dsigInput_z = dz.*z(:,1).*(1-z(:,1));
        db_z = db_z + dsigInput_z;
        dU_z = dU_z + dsigInput_z*x(:,1)';  % could be accelerated by avoiding add 0
        dW_z = dW_z + dsigInput_z*s_0';
        ds_0 = ds_0 + W_z'*dsigInput_z;
    end
end

% Sigmoid function for neural network
% (sigmf is from the Fuzzy Logic Toolbox; 1./(1+exp(-x)) is an equivalent
% alternative if the toolbox is not available)
function val = sigmoid(x)
    val = sigmf(x, [1 0]);
end

testBPTT_GRU.m

Less Important Functions

% Fake a training dataset: generate only one sentence for training.
% !!! Only for testing. Needs to be changed to read in training data from files.
function [x_t, y_t] = getTrainingData(vocabulary_size, sentence_size)
    assert(vocabulary_size > 2);  % for start and end of sentence symbols
    assert(sentence_size > 0);

    % define start and end of sentence in the vocabulary
    SENTENCE_START = zeros(vocabulary_size, 1);
    SENTENCE_START(1) = 1;
    SENTENCE_END = zeros(vocabulary_size, 1);
    SENTENCE_END(2) = 1;

    % generate sentence:
    x_t = zeros(vocabulary_size, sentence_size-1);  % leave one slot for SENTENCE_START
    for wordI = 1:sentence_size-1
        % generate a random word excluding start and end symbols
        x_t(randi(vocabulary_size-2, 1, 1)+2, wordI) = 1;
    end
    y_t = [x_t, SENTENCE_END];        % training output
    x_t = [SENTENCE_START, x_t];      % training input
end

% Use numerical differentiation to approximate the gradient of each
% parameter and calculate the difference between these numerical results
% and our results calculated by applying the chain rule.
function checkGradient_GRU(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0, ...
        dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0)
    % Here we use the centered difference formula:
    %   df(x)/dx = (f(x+h) - f(x-h)) / (2h)
    % It is a second-order accurate method with error bounded by O(h^2).

    h = 1e-5;
    % NOTE: h couldn't be too large or too small since a large h will
    % introduce bigger truncation error and a small h will introduce bigger
    % roundoff error.

    dV_numerical = zeros(size(dV));
    % calculate the partial derivative element by element
    for rowI = 1:size(dV_numerical, 1)
        for colI = 1:size(dV_numerical, 2)
            V_plus = V;
            V_plus(rowI, colI) = V_plus(rowI, colI) + h;
            V_minus = V;
            V_minus(rowI, colI) = V_minus(rowI, colI) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V_plus, b_V, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V_minus, b_V, s_0);
            dV_numerical(rowI, colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
    end
    display(sum(sum(abs(dV_numerical-dV)./(abs(dV_numerical)+h))), ...
        'dV relative error');  % prevent dividing by 0 by adding h

    dU_c_numerical = zeros(size(dU_c));
    for rowI = 1:size(dU_c_numerical, 1)
        for colI = 1:size(dU_c_numerical, 2)
            U_c_plus = U_c;
            U_c_plus(rowI, colI) = U_c_plus(rowI, colI) + h;
            U_c_minus = U_c;
            U_c_minus(rowI, colI) = U_c_minus(rowI, colI) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r, U_c_plus, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r, U_c_minus, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
            dU_c_numerical(rowI, colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
    end
    display(sum(sum(abs(dU_c_numerical-dU_c)./(abs(dU_c_numerical)+h))), ...
        'dU_c relative error');

    dW_c_numerical = zeros(size(dW_c));
    for rowI = 1:size(dW_c_numerical, 1)
        for colI = 1:size(dW_c_numerical, 2)
            W_c_plus = W_c;
            W_c_plus(rowI, colI) = W_c_plus(rowI, colI) + h;
            W_c_minus = W_c;
            W_c_minus(rowI, colI) = W_c_minus(rowI, colI) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c_plus, b_z, b_r, b_c, V, b_V, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c_minus, b_z, b_r, b_c, V, b_V, s_0);
            dW_c_numerical(rowI, colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
    end
    display(sum(sum(abs(dW_c_numerical-dW_c)./(abs(dW_c_numerical)+h))), ...
        'dW_c relative error');

    dU_r_numerical = zeros(size(dU_r));
    for rowI = 1:size(dU_r_numerical, 1)
        for colI = 1:size(dU_r_numerical, 2)
            U_r_plus = U_r;
            U_r_plus(rowI, colI) = U_r_plus(rowI, colI) + h;
            U_r_minus = U_r;
            U_r_minus(rowI, colI) = U_r_minus(rowI, colI) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r_plus, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r_minus, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
            dU_r_numerical(rowI, colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
    end
    display(sum(sum(abs(dU_r_numerical-dU_r)./(abs(dU_r_numerical)+h))), ...
        'dU_r relative error');

    dW_r_numerical = zeros(size(dW_r));
    for rowI = 1:size(dW_r_numerical, 1)
        for colI = 1:size(dW_r_numerical, 2)
            W_r_plus = W_r;
            W_r_plus(rowI, colI) = W_r_plus(rowI, colI) + h;
            W_r_minus = W_r;
            W_r_minus(rowI, colI) = W_r_minus(rowI, colI) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r_plus, W_c, b_z, b_r, b_c, V, b_V, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r_minus, W_c, b_z, b_r, b_c, V, b_V, s_0);
            dW_r_numerical(rowI, colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
    end
    display(sum(sum(abs(dW_r_numerical-dW_r)./(abs(dW_r_numerical)+h))), ...
        'dW_r relative error');

    dU_z_numerical = zeros(size(dU_z));
    for rowI = 1:size(dU_z_numerical, 1)
        for colI = 1:size(dU_z_numerical, 2)
            U_z_plus = U_z;
            U_z_plus(rowI, colI) = U_z_plus(rowI, colI) + h;
            U_z_minus = U_z;
            U_z_minus(rowI, colI) = U_z_minus(rowI, colI) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z_plus, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z_minus, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
            dU_z_numerical(rowI, colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
    end
    display(sum(sum(abs(dU_z_numerical-dU_z)./(abs(dU_z_numerical)+h))), ...
        'dU_z relative error');

    dW_z_numerical = zeros(size(dW_z));
    for rowI = 1:size(dW_z_numerical, 1)
        for colI = 1:size(dW_z_numerical, 2)
            W_z_plus = W_z;
            W_z_plus(rowI, colI) = W_z_plus(rowI, colI) + h;
            W_z_minus = W_z;
            W_z_minus(rowI, colI) = W_z_minus(rowI, colI) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r, U_c, W_z_plus, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r, U_c, W_z_minus, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
            dW_z_numerical(rowI, colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
    end
    display(sum(sum(abs(dW_z_numerical-dW_z)./(abs(dW_z_numerical)+h))), ...
        'dW_z relative error');

    db_z_numerical = zeros(size(db_z));
    for i = 1:length(db_z_numerical)
        b_z_plus = b_z;
        b_z_plus(i) = b_z_plus(i) + h;
        b_z_minus = b_z;
        b_z_minus(i) = b_z_minus(i) - h;
        [~, ~, L_plus] = forward(x, y, ...
            U_z, U_r, U_c, W_z, W_r, W_c, b_z_plus, b_r, b_c, V, b_V, s_0);
        [~, ~, L_minus] = forward(x, y, ...
            U_z, U_r, U_c, W_z, W_r, W_c, b_z_minus, b_r, b_c, V, b_V, s_0);
        db_z_numerical(i) = (sum(L_plus) - sum(L_minus)) / 2 / h;
    end
    display(sum(abs(db_z_numerical-db_z)./(abs(db_z_numerical)+h)), ...
        'db_z relative error');

    db_r_numerical = zeros(size(db_r));
    for i = 1:length(db_r_numerical)
        b_r_plus = b_r;
        b_r_plus(i) = b_r_plus(i) + h;
        b_r_minus = b_r;
        b_r_minus(i) = b_r_minus(i) - h;
        [~, ~, L_plus] = forward(x, y, ...
            U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r_plus, b_c, V, b_V, s_0);
        [~, ~, L_minus] = forward(x, y, ...
            U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r_minus, b_c, V, b_V, s_0);
        db_r_numerical(i) = (sum(L_plus) - sum(L_minus)) / 2 / h;
    end
    display(sum(abs(db_r_numerical-db_r)./(abs(db_r_numerical)+h)), ...
        'db_r relative error');

    db_c_numerical = zeros(size(db_c));
    for i = 1:length(db_c_numerical)
        b_c_plus = b_c;
        b_c_plus(i) = b_c_plus(i) + h;
        b_c_minus = b_c;
        b_c_minus(i) = b_c_minus(i) - h;
        [~, ~, L_plus] = forward(x, y, ...
            U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c_plus, V, b_V, s_0);
        [~, ~, L_minus] = forward(x, y, ...
            U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c_minus, V, b_V, s_0);
        db_c_numerical(i) = (sum(L_plus) - sum(L_minus)) / 2 / h;
    end
    display(sum(abs(db_c_numerical-db_c)./(abs(db_c_numerical)+h)), ...
        'db_c relative error');

    db_V_numerical = zeros(size(db_V));
    for i = 1:length(db_V_numerical)
        b_V_plus = b_V;
        b_V_plus(i) = b_V_plus(i) + h;
        b_V_minus = b_V;
        b_V_minus(i) = b_V_minus(i) - h;
        [~, ~, L_plus] = forward(x, y, ...
            U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V_plus, s_0);
        [~, ~, L_minus] = forward(x, y, ...
            U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V_minus, s_0);
        db_V_numerical(i) = (sum(L_plus) - sum(L_minus)) / 2 / h;
    end
    display(sum(abs(db_V_numerical-db_V)./(abs(db_V_numerical)+h)), ...
        'db_V relative error');

    ds_0_numerical = zeros(size(ds_0));
    for i = 1:length(ds_0_numerical)
        s_0_plus = s_0;
        s_0_plus(i) = s_0_plus(i) + h;
        s_0_minus = s_0;
        s_0_minus(i) = s_0_minus(i) - h;
        [~, ~, L_plus] = forward(x, y, ...
            U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0_plus);
        [~, ~, L_minus] = forward(x, y, ...
            U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0_minus);
        ds_0_numerical(i) = (sum(L_plus) - sum(L_minus)) / 2 / h;
    end
    display(sum(abs(ds_0_numerical-ds_0)./(abs(ds_0_numerical)+h)), ...
        'ds_0 relative error');
end

testBPTT_GRU.m
