
Distributed Optimization of CNNs and RNNs
GTC 2015

William Chan
williamchan.ca
[email protected]

March 19, 2015


Outline

1. Motivation

2. Distributed ASGD

3. CNNs

4. RNNs

5. Conclusion

Carnegie Mellon University 2


Motivation

• Why do we need distributed training?

Carnegie Mellon University 3


Motivation

• More data → better models
• More data → longer training times

Example: Baidu Deep Speech
• Synthetic training data generated by overlapping noise with clean speech
• Synthetic training data → effectively unlimited training data

Carnegie Mellon University 4


Motivation

• Complex models (e.g., CNNs and RNNs) are better than simple models (DNNs)
• Complex models → longer training times

Example: GoogLeNet
• A 22-layer-deep CNN

Carnegie Mellon University 5


GoogLeNet

[Figure: the full GoogLeNet architecture diagram, a 22-layer CNN of stacked convolution, pooling, and depth-concatenation (Inception) blocks with a main softmax output and two auxiliary softmax classifiers; captioned in the original paper as "GoogLeNet network with all the bells and whistles".]

Carnegie Mellon University 6


Distributed Asynchronous Stochastic Gradient Descent

• Google Cats, DistBelief, 32,000 CPU cores, and more...

Figure 1: Google showed we can apply ASGD to Deep Learning.

Carnegie Mellon University 7


Distributed Asynchronous Stochastic Gradient Descent

• CPUs are expensive
• PhD students are poor :(
• Let us use GPUs!

Carnegie Mellon University 8


Distributed Asynchronous Stochastic Gradient Descent

Stochastic Gradient Descent (SGD):

θ ← θ − η ∇θ L(θ)    (1)

Distributed Asynchronous Stochastic Gradient Descent (ASGD):

θ ← θ − η ∇θ Lᵢ(θᵢ)    (2)

Here Lᵢ is the minibatch loss computed by shard i from its local, possibly stale, parameter copy θᵢ; the resulting gradient is still applied to the current global θ.
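To make equation (2) concrete, here is a minimal NumPy sketch (my own illustration, not code from the talk; the class and method names are made up) of a parameter server that applies whatever gradient a shard sends, even though that gradient was computed from an older parameter copy:

```python
import numpy as np

class ParameterServer:
    """Holds the global parameters and applies shard gradients as they arrive."""

    def __init__(self, theta, lr=0.01):
        self.theta = theta          # global parameters (the θ of Equation 2)
        self.lr = lr                # learning rate η

    def pull(self):
        # A shard copies the current parameters; its copy θ_i may go stale
        # while other shards keep pushing updates.
        return self.theta.copy()

    def push(self, grad_i):
        # Asynchronous update: grad_i was computed from a possibly stale θ_i,
        # but it is applied to the *current* global θ.
        self.theta -= self.lr * grad_i
```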

Carnegie Mellon University 9


Distributed Asynchronous Stochastic Gradient Descent

CMU SPEECH3:

• 1x GPU Master Parameter Server
• Nx GPU ASGD Shards

Carnegie Mellon University 10


Distributed Asynchronous Stochastic Gradient Descent

[Figure: one GPU parameter server connected to multiple independent SGD GPU shards; synchronization with the parameter server is done via PCIe DMA, bypassing the CPU.]

Figure 2: CMU SPEECH3 GPU ASGD.

Carnegie Mellon University 11


Distributed Asynchronous Stochastic Gradient Descent

SPEECH3 ASGD Shard ↔ Parameter Server sync:

• Compute a minibatch (e.g., size 128).
• If the Parameter Server is free, sync.
• Else compute another minibatch.
• Easy to implement, < 300 lines of code (see the sketch below).
• Works surprisingly well.
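A minimal sketch of that shard loop, assuming a parameter server object with pull()/push() methods (as in the earlier sketch) and a shared lock used to test whether the server is busy; accumulating local gradients between syncs is my reading of "compute another minibatch", not something the slides specify:

```python
import numpy as np

def shard_loop(server, server_lock, batches, grad_fn, steps=1000):
    """One ASGD shard: always keep computing; only sync when the server is free."""
    theta_local = server.pull()               # local (possibly stale) parameter copy
    pending = np.zeros_like(theta_local)      # gradients accumulated since last sync
    for _ in range(steps):
        batch = next(batches)                 # e.g., a minibatch of size 128
        pending += grad_fn(theta_local, batch)
        if server_lock.acquire(blocking=False):   # is the parameter server free?
            try:
                server.push(pending)          # send the accumulated gradient
                theta_local = server.pull()   # fetch fresh parameters
                pending[:] = 0.0
            finally:
                server_lock.release()
        # else: server is busy with another shard -> just compute another minibatch

# Hypothetical wiring: one shared threading.Lock(), one thread per GPU shard, e.g.
#   threading.Thread(target=shard_loop, args=(server, lock, batches, grad_fn)).start()
```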

Carnegie Mellon University 12


Distributed Asynchronous Stochastic Gradient Descent

Minor tricks:

• Momentum / gradient projection on the Parameter Server
• Gradient decay on the Parameter Server
• Tunable max distance limit between the Parameter Server and a Shard (sketched below)
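The slides do not define these tricks precisely, so the following is only a guess at the server side: "distance" is read as the number of global updates since a shard last pulled, "gradient decay" as down-weighting stale gradients, and momentum is applied on the server (gradient projection is omitted). All names and constants are illustrative:

```python
import numpy as np

class ParameterServerWithTricks:
    """Parameter server with momentum, staleness-based gradient decay, and a max-distance limit."""

    def __init__(self, theta, lr=0.01, momentum=0.9, decay=0.5, max_distance=4):
        self.theta = theta
        self.velocity = np.zeros_like(theta)
        self.lr, self.momentum = lr, momentum
        self.decay = decay                    # down-weighting factor per step of staleness
        self.max_distance = max_distance      # reject gradients that are too stale
        self.version = 0                      # counts global updates

    def pull(self):
        return self.theta.copy(), self.version

    def push(self, grad, shard_version):
        distance = self.version - shard_version
        if distance > self.max_distance:
            return False                      # too stale: drop it, the shard must re-pull
        scaled = grad * (self.decay ** distance)            # gradient decay
        self.velocity = self.momentum * self.velocity - self.lr * scaled
        self.theta += self.velocity                         # momentum applied on the server
        self.version += 1
        return True
```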

Carnegie Mellon University 13


CNNs

Convolutional Neural Networks (CNNs)

• Computer Vision
• Automatic Speech Recognition
• CNNs are typically ≈ 5% relative Word Error Rate (WER) better than DNNs

Carnegie Mellon University 14


CNNs

[Figure: a spectrum-over-time input fed through 2D convolution, max pooling, a second 2D convolution, and fully connected layers that output posteriors.]

Figure 3: CNN for Acoustic Modelling.
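For illustration only, here is a small PyTorch sketch of a CNN acoustic model in that spirit; the filter counts, kernel sizes, and output dimension are placeholders, not the configuration used in the talk:

```python
import torch
import torch.nn as nn

class CNNAcousticModel(nn.Module):
    """Spectral patches in, acoustic-state posteriors out (all sizes are illustrative)."""

    def __init__(self, n_outputs=4000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(9, 9)),   # 2D convolution over time x frequency
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 3)),       # max pooling along frequency
            nn.Conv2d(64, 128, kernel_size=(3, 4)), # second 2D convolution
            nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(1024),                    # fully connected
            nn.ReLU(),
            nn.Linear(1024, n_outputs),             # logits; posteriors via softmax at decode time
        )

    def forward(self, x):                           # x: (batch, 1, time, freq)
        return self.classifier(self.features(x))

# Example: logits = CNNAcousticModel()(torch.randn(8, 1, 15, 40))
```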

Carnegie Mellon University 15


CNNs

[Figure: Test Frame Accuracy (roughly 30-50%) vs. training time (0-20 hours) for the SGD baseline and the 2x, 3x, 4x, and 5x ASGD worker configurations.]

Figure 4: SGD vs. ASGD.

Carnegie Mellon University 16


CNNs

Workers   40% FA         43% FA          44% FA
1         5:50 (100%)    14:36 (100%)    19:29 (100%)
2         3:36 (81.0%)    8:59 (81.3%)   11:58 (81.4%)
3         2:48 (69.4%)    5:59 (81.3%)    7:58 (81.5%)
4         2:05 (70.0%)    4:28 (81.7%)    6:32 (74.6%)
5         1:40 (70.0%)    3:49 (76.5%)    5:43 (68.2%)

Table 1: Time (hh:mm) and scaling efficiency (in brackets) comparison for convergence to 40%, 43% and 44% Frame Accuracy (FA).
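Scaling efficiency here appears to be single-worker time divided by (number of workers × multi-worker time), which is consistent with the numbers in the table: for example, at 43% FA with 2 workers, 14:36 / (2 × 8:59) = 876 / 1078 minutes ≈ 81.3%.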

Carnegie Mellon University 17


RNNs

Recurrent Neural Networks (RNNs)

• Machine Translation
• Automatic Speech Recognition
• RNNs are typically ≈ 5-10% relative WER better than DNNs

Minor tricks:
• Long Short-Term Memory (LSTM)
• Cell activation clipping (see the sketch below)
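To illustrate cell activation clipping, here is a single hand-written LSTM step in PyTorch that clamps the cell state to a fixed range; the clipping threshold and weight layout are my choices, not details from the talk:

```python
import torch

def lstm_step(x, h, c, W_x, W_h, b, clip=50.0):
    """One LSTM time step with cell activation clipping.

    x: (batch, n_in), h/c: (batch, n_hidden),
    W_x: (n_in, 4*n_hidden), W_h: (n_hidden, 4*n_hidden), b: (4*n_hidden,)
    """
    gates = x @ W_x + h @ W_h + b
    i, f, g, o = gates.chunk(4, dim=1)                       # input, forget, cell, output gates
    c_new = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    c_new = torch.clamp(c_new, -clip, clip)                  # cell activation clipping
    h_new = torch.sigmoid(o) * torch.tanh(c_new)
    return h_new, c_new
```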

Carnegie Mellon University 18


RNNs


Figure 5: DNN vs. RNN.

Carnegie Mellon University 19


RNNs

Workers   46.5% FA       47.5% FA       48.5% FA
1         1:51 (100%)    3:42 (100%)    7:41 (100%)
2         1:00 (92.5%)   2:00 (92.5%)   3:01 (128%)
5         -              -              1:15 (122%)

Table 2: Time (hh:mm) and scaling efficiency (in brackets) comparison for convergence to 46.5%, 47.5% and 48.5% Frame Accuracy (FA).

• RNNs seem to really like distributed training!

Carnegie Mellon University 20


RNNs

Workers   WER    Time (hh:mm)
1         3.95   18:37
2         4.11    8:04
5         4.06    5:24

Table 3: WERs.

• No (major) difference in WER!

Carnegie Mellon University 21


Conclusion

• Distributed ASGD on GPUs is easy to implement!
• Speed up your training!
• Minor difference in loss against the SGD baseline!

Carnegie Mellon University 22