Statistical learning and optimal control: A framework for biological learning and motor control

Lecture 1: Iterative learning and the Kalman filter

Reza Shadmehr

Johns Hopkins School of Medicine

[Figure: framework for biological learning and motor control. Blocks: body + environment; state change; sensory system (proprioception, vision, audition); measured sensory consequences; forward model; predicted sensory consequences; integration; belief about state of body and world; goal selector; motor command generator. Annotations: stochastic optimal control, parameter estimation, Kalman filter.]

Results from classical conditioning

Effect of time on memory: spontaneous recovery

[Figure: performance during training and at test for inter-trial intervals (ITI) of 2, 14, and 98; panels labeled "Test at 1 week" and "Testing at 1 day or 1 week (averaged together)".]

Effect of time on memory: inter-trial interval and retention

Integration of predicted state with sensory feedback

Choice of motor commands: optimality in saccades and reaching movements

[Figure: eye velocity (deg/sec, 0 to 500) versus time (sec, 0 to 0.25) for saccade sizes 5, 10, 15, 30, 40, and 50.]

Helpful reading:

1. Mathematical background

• Raúl Rojas, The Kalman Filter. Freie Universität Berlin.

• N.A. Thacker and A.J. Lacey, Tutorial: The Kalman Filter. University of Manchester.

2. Application to animal learning

• Peter Dayan and Angela J. Yu (2003) Uncertainty and learning. IETE Journal of Research 49:171-182.

3. Application to sensorimotor control

• D. Wolpert, Z. Ghahramani, M.I. Jordan (1995) An internal model for sensorimotor integration. Science.

Linear regression, maximum likelihood, and parameter uncertainty

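The data are generated by a noisy linear process:

$$y^{*(i)} = \mathbf{w}^{*T}\mathbf{x}^{(i)}, \qquad y^{(i)} = y^{*(i)} + \varepsilon^{(i)}, \qquad \varepsilon^{(i)} \sim N\left(0, \sigma^2\right)$$

A noisy process produces n data points and we form an ML estimate of w:

$$D^{(1)} = \left\{\left(\mathbf{x}^{(1)}, y^{(1,1)}\right), \left(\mathbf{x}^{(2)}, y^{(1,2)}\right), \ldots, \left(\mathbf{x}^{(n)}, y^{(1,n)}\right)\right\}, \qquad \mathbf{w}_{ML}^{(1)} = \left(X^TX\right)^{-1}X^T\mathbf{y}^{(1)}$$

We run the noisy process again with the same sequence of x’s and re-estimate w:

$$D^{(2)} = \left\{\left(\mathbf{x}^{(1)}, y^{(2,1)}\right), \left(\mathbf{x}^{(2)}, y^{(2,2)}\right), \ldots, \left(\mathbf{x}^{(n)}, y^{(2,n)}\right)\right\}, \qquad \mathbf{w}_{ML}^{(2)} = \left(X^TX\right)^{-1}X^T\mathbf{y}^{(2)}$$

The distribution of the resulting w has a var-cov matrix that depends only on the sequence of inputs, the bases that encode those inputs, and the noise σ:

$$\mathbf{w}_{ML} \sim N\left(\mathbf{w}^*, \sigma^2\left(X^TX\right)^{-1}\right)$$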
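The repeated-experiment argument can be checked numerically. The sketch below assumes NumPy and reuses the five-trial input history of Example 1 (w* = [0.5, 0.5], σ = 1); the function name `simulate_ml_estimates` is illustrative, not from the lecture. It re-runs the noisy process many times with the same X and compares the empirical mean and var-cov of the ML estimates with w* and σ²(XᵀX)⁻¹.

```python
import numpy as np


def simulate_ml_estimates(X, w_star, sigma, n_runs=20000, seed=0):
    """Re-run y = X w* + noise n_runs times with the same inputs X and
    return the ML (least-squares) estimate of each run as a row."""
    rng = np.random.default_rng(seed)
    y_star = X @ w_star                                        # noise-free outputs
    Y = y_star + sigma * rng.standard_normal((n_runs, X.shape[0]))
    # w_ML = (X^T X)^{-1} X^T y, solved for all runs at once (one column per run)
    W = np.linalg.solve(X.T @ X, X.T @ Y.T)
    return W.T


X = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
w_star = np.array([0.5, 0.5])
sigma = 1.0

estimates = simulate_ml_estimates(X, w_star, sigma)
empirical_mean = estimates.mean(axis=0)
empirical_cov = np.cov(estimates, rowvar=False)
predicted_cov = sigma**2 * np.linalg.inv(X.T @ X)

print("mean of w_ML:", empirical_mean)          # close to w*
print("empirical var-cov:\n", empirical_cov)
print("sigma^2 (X^T X)^-1:\n", predicted_cov)   # [[0.25, 0], [0, 1]]
```

Solving the normal equations for all runs in one call keeps the sketch short; a loop over runs would give the same estimates.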

Bias of the parameter estimates for a given X

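• How does the ML estimate behave in the presence of noise in y?

The “true” underlying process: $\mathbf{y}^* = X\mathbf{w}^*$

What we measured: $\mathbf{y} = \mathbf{y}^* + \boldsymbol{\varepsilon}$, where $\boldsymbol{\varepsilon}$ is an n×1 vector with $\boldsymbol{\varepsilon} \sim N\left(\mathbf{0}, \sigma^2 I\right)$

Our model of the process: $\mathbf{y} = X\mathbf{w} + \boldsymbol{\varepsilon}$

ML estimate:

$$\mathbf{w}_{ML} = \left(X^TX\right)^{-1}X^T\mathbf{y} = \left(X^TX\right)^{-1}X^T\left(X\mathbf{w}^* + \boldsymbol{\varepsilon}\right) = \mathbf{w}^* + \left(X^TX\right)^{-1}X^T\boldsymbol{\varepsilon}$$

Because $\boldsymbol{\varepsilon}$ is normally distributed:

$$\mathbf{w}_{ML} \sim N\left(\mathbf{w}^*, \operatorname{var}\left[\left(X^TX\right)^{-1}X^T\boldsymbol{\varepsilon}\right]\right)$$

In other words, the ML estimate is unbiased:

$$E\left[\mathbf{w}_{ML}\right] = \mathbf{w}^*$$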

Variance of the parameter estimates for a given X

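For a given X, the ML (or least-squares) estimate of our parameter has this normal distribution:

$$\mathbf{w}_{ML} \sim N\left(\mathbf{w}^*, \operatorname{var}\left[\left(X^TX\right)^{-1}X^T\boldsymbol{\varepsilon}\right]\right)$$

To evaluate the variance, use $\operatorname{var}[A\mathbf{x}] = A\operatorname{var}[\mathbf{x}]A^T$, where A is a matrix of constants and x is a vector of random variables, together with the symmetry of $\left(X^TX\right)^{-1}$. Assume $\operatorname{var}[\boldsymbol{\varepsilon}] = E\left[\boldsymbol{\varepsilon}\boldsymbol{\varepsilon}^T\right] = \sigma^2 I$:

$$\begin{aligned}
\operatorname{var}\left[\left(X^TX\right)^{-1}X^T\boldsymbol{\varepsilon}\right]
&= \left(X^TX\right)^{-1}X^T\operatorname{var}[\boldsymbol{\varepsilon}]\,X\left(X^TX\right)^{-1} \\
&= \sigma^2\left(X^TX\right)^{-1}X^TX\left(X^TX\right)^{-1} \\
&= \sigma^2\left(X^TX\right)^{-1}
\end{aligned}$$

Therefore

$$\mathbf{w}_{ML} \sim N\left(\mathbf{w}^*, \sigma^2\left(X^TX\right)^{-1}\right)$$

where $\sigma^2\left(X^TX\right)^{-1}$ is an m×m matrix.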

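The Gaussian distribution and its var-cov matrix

A 1-D Gaussian distribution is defined as

$$p(x) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)$$

In n dimensions, it generalizes to

$$p(\mathbf{x}) = \frac{1}{\sqrt{(2\pi)^n|C|}}\exp\left(-\frac{1}{2}(\mathbf{x}-\boldsymbol{\mu})^TC^{-1}(\mathbf{x}-\boldsymbol{\mu})\right)$$

When x is a vector, the variance is expressed in terms of a covariance matrix C, where $\rho_{ij}$ corresponds to the degree of correlation between variables $x_i$ and $x_j$:

$$c_{ij} = E\left[(x_i-\mu_i)(x_j-\mu_j)\right], \qquad
C = \begin{bmatrix}
\sigma_1^2 & \rho_{12}\sigma_1\sigma_2 & \cdots & \rho_{1n}\sigma_1\sigma_n \\
\rho_{12}\sigma_1\sigma_2 & \sigma_2^2 & \cdots & \rho_{2n}\sigma_2\sigma_n \\
\vdots & \vdots & \ddots & \vdots \\
\rho_{1n}\sigma_1\sigma_n & \rho_{2n}\sigma_2\sigma_n & \cdots & \sigma_n^2
\end{bmatrix}$$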

$$\rho_{xy} = \frac{\sum_i (x_i-\bar{x})(y_i-\bar{y})}{\sqrt{\sum_i (x_i-\bar{x})^2\,\sum_i (y_i-\bar{y})^2}} = \frac{C_{xy}}{\sqrt{C_{xx}C_{yy}}}$$

[Figure: scatter plots of $(x_1, x_2)$ drawn from zero-mean Gaussians $\mathbf{x} \sim N(\boldsymbol{\mu}, C)$ with $\sigma_1^2 = 1$, $\sigma_2^2 = 2$ and $\rho_{12} = 0.9$ (x1 and x2 are positively correlated), $\rho_{12} = 0.1$ (x1 and x2 are not correlated), and $\rho_{12} = -0.9$ (x1 and x2 are negatively correlated).]
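To make the role of $\rho_{ij}$ concrete, here is a small sketch assuming NumPy, with σ1² = 1 and σ2² = 2 as in the scatter plots; the helper names `make_cov` and `correlation` are illustrative. It draws samples for ρ12 = 0.9, 0.1 and −0.9 and recovers each ρ with the sample correlation $C_{xy}/\sqrt{C_{xx}C_{yy}}$.

```python
import numpy as np


def make_cov(var1, var2, rho):
    """2x2 var-cov matrix with variances var1, var2 and correlation rho
    (off-diagonal entry rho * sigma1 * sigma2)."""
    c12 = rho * np.sqrt(var1 * var2)
    return np.array([[var1, c12], [c12, var2]])


def correlation(x, y):
    """Sample correlation rho_xy = C_xy / sqrt(C_xx * C_yy)."""
    dx, dy = x - x.mean(), y - y.mean()
    return np.sum(dx * dy) / np.sqrt(np.sum(dx**2) * np.sum(dy**2))


rng = np.random.default_rng(1)
estimated = {}
for rho in (0.9, 0.1, -0.9):
    C = make_cov(1.0, 2.0, rho)
    samples = rng.multivariate_normal(mean=np.zeros(2), cov=C, size=50_000)
    estimated[rho] = correlation(samples[:, 0], samples[:, 1])
    print(f"rho = {rho:+.1f}, sample correlation = {estimated[rho]:+.3f}")
```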

Parameter uncertainty: Example 1

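• Input history (X holds x1 and x2 in its columns; y* is the noise-free output on each trial):

$$X = \begin{bmatrix}1&0\\1&0\\1&0\\1&0\\0&1\end{bmatrix}, \qquad \mathbf{y}^* = \begin{bmatrix}0.5\\0.5\\0.5\\0.5\\0.5\end{bmatrix}$$

Model and distribution of the ML estimate (with $\sigma^2 = 1$):

$$y = w_1x_1 + w_2x_2, \qquad \hat{\mathbf{w}}_{ML} \sim N\left(E\left[\hat{\mathbf{w}}\right], \begin{bmatrix}\operatorname{var}(w_1) & \operatorname{cov}(w_1,w_2)\\ \operatorname{cov}(w_1,w_2) & \operatorname{var}(w_2)\end{bmatrix}\right) = N\left(\mathbf{w}^*, \sigma^2\left(X^TX\right)^{-1}\right) = N\left(\begin{bmatrix}0.5\\0.5\end{bmatrix}, \begin{bmatrix}0.25 & 0\\0 & 1\end{bmatrix}\right)$$

x1 was “on” most of the time, so I’m pretty certain about w1. However, x2 was “on” only once, so I’m uncertain about w2.

[Figure: distribution of the estimate in the (w1, w2) plane; both axes span −0.5 to 2.]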

Parameter uncertainty: Example 2

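• Input history:

$$X = \begin{bmatrix}1&1\\1&1\\1&1\\1&1\\1&0\end{bmatrix}, \qquad \mathbf{y}^* = \begin{bmatrix}1\\1\\1\\1\\0.5\end{bmatrix}$$

$$\hat{\mathbf{w}}_{ML} \sim N\left(E\left[\hat{\mathbf{w}}\right], \begin{bmatrix}\operatorname{var}(w_1) & \operatorname{cov}(w_1,w_2)\\ \operatorname{cov}(w_1,w_2) & \operatorname{var}(w_2)\end{bmatrix}\right) = N\left(\mathbf{w}^*, \sigma^2\left(X^TX\right)^{-1}\right) = N\left(\begin{bmatrix}0.5\\0.5\end{bmatrix}, \begin{bmatrix}1 & -1\\-1 & 1.25\end{bmatrix}\right)$$

x1 and x2 were “on” mostly together. The weight var-cov matrix shows that what I learned is that $w_1 + w_2 \approx 1$, but I do not know the individual values of w1 and w2 with much certainty. x1 appeared slightly more often than x2, so I’m a little more certain about the value of w1.

[Figure: distribution of the estimate in the (w1, w2) plane, elongated along the line $w_1 + w_2 = 1$; both axes span −0.5 to 2.]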

Parameter uncertainty: Example 3

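• Input history:

$$X = \begin{bmatrix}0&1\\0&1\\0&1\\0&1\\1&1\end{bmatrix}, \qquad \mathbf{y}^* = \begin{bmatrix}0.5\\0.5\\0.5\\0.5\\1\end{bmatrix}$$

$$\hat{\mathbf{w}}_{ML} \sim N\left(E\left[\hat{\mathbf{w}}\right], \begin{bmatrix}\operatorname{var}(w_1) & \operatorname{cov}(w_1,w_2)\\ \operatorname{cov}(w_1,w_2) & \operatorname{var}(w_2)\end{bmatrix}\right) = N\left(\mathbf{w}^*, \sigma^2\left(X^TX\right)^{-1}\right) = N\left(\begin{bmatrix}0.5\\0.5\end{bmatrix}, \begin{bmatrix}1.25 & -0.25\\-0.25 & 0.25\end{bmatrix}\right)$$

x2 was mostly “on”. I’m pretty certain about w2, but I am very uncertain about w1. Occasionally x1 and x2 were on together, so I have some reason to believe that $w_1 + w_2 \approx 1$.

[Figure: distribution of the estimate in the (w1, w2) plane; both axes span −0.5 to 2.]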
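The three var-cov matrices above follow directly from σ²(XᵀX)⁻¹. A minimal sketch assuming NumPy and σ² = 1 (the value that reproduces the numbers on these slides); `ml_var_cov` is an illustrative name. The negative off-diagonal entries in Examples 2 and 3 are what encode the belief that w1 + w2 ≈ 1.

```python
import numpy as np


def ml_var_cov(X, sigma2=1.0):
    """var-cov of the ML weight estimate: sigma^2 (X^T X)^{-1}."""
    return sigma2 * np.linalg.inv(X.T @ X)


# Input histories (columns x1, x2) and noise-free outputs y* for Examples 1-3
examples = {
    "Example 1": (np.array([[1, 0], [1, 0], [1, 0], [1, 0], [0, 1]], dtype=float),
                  np.array([0.5, 0.5, 0.5, 0.5, 0.5])),
    "Example 2": (np.array([[1, 1], [1, 1], [1, 1], [1, 1], [1, 0]], dtype=float),
                  np.array([1.0, 1.0, 1.0, 1.0, 0.5])),
    "Example 3": (np.array([[0, 1], [0, 1], [0, 1], [0, 1], [1, 1]], dtype=float),
                  np.array([0.5, 0.5, 0.5, 0.5, 1.0])),
}

results = {}
for name, (X, y_star) in examples.items():
    w_ml = np.linalg.solve(X.T @ X, X.T @ y_star)   # noise-free fit recovers w*
    P = ml_var_cov(X)
    rho = P[0, 1] / np.sqrt(P[0, 0] * P[1, 1])      # correlation between w1 and w2
    results[name] = (w_ml, P)
    print(name, "w_ML =", w_ml, "var-cov =", P.tolist(), f"corr(w1, w2) = {rho:+.2f}")
```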

Effect of uncertainty on learning rate

• When you observe an error in trial n, the amount that you should change w should depend on how certain you are about w. The more certain you are, the less you should be influenced by the error. The less certain you are, the more you should “pay attention” to the error.

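$$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n)}\right)$$

Here $\mathbf{w}$ and the Kalman gain $\mathbf{k}^{(n)}$ are m×1 vectors, and the term in parentheses is the error.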

Rudolph E. Kalman (1960) A new approach to linear filtering and prediction problems. Transactions of the ASME–Journal of Basic Engineering, 82 (Series D): 35-45.

Research Institute for Advanced Study, 7212 Bellona Ave, Baltimore, MD

Example of the Kalman gain: running estimate of average

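Model, with $x^{(i)} = 1$ on every trial:

$$y^{*(i)} = w^*x^{(i)}, \qquad y^{(i)} = y^{*(i)} + \varepsilon^{(i)}, \qquad \varepsilon^{(i)} \sim N\left(0, \sigma^2\right)$$

The ML estimate after n trials is the sample mean, which can be rewritten recursively:

$$\begin{aligned}
w^{(n)} &= \left(X^TX\right)^{-1}X^T\mathbf{y} = \frac{1}{n}\sum_{i=1}^{n}y^{(i)}, \qquad w^{(n-1)} = \frac{1}{n-1}\sum_{i=1}^{n-1}y^{(i)} \\
w^{(n)} &= \frac{1}{n}\left(y^{(n)} + \sum_{i=1}^{n-1}y^{(i)}\right) = \frac{1}{n}\left(y^{(n)} + (n-1)\,w^{(n-1)}\right) \\
&= \underbrace{\left(1-\frac{1}{n}\right)w^{(n-1)}}_{\text{past estimate}} + \underbrace{\frac{1}{n}\,y^{(n)}}_{\text{new measure}} \\
&= w^{(n-1)} + \frac{1}{n}\left(y^{(n)} - w^{(n-1)}\right)
\end{aligned}$$

Kalman gain: the learning rate 1/n decreases as the number of samples increases. As n increases, we trust our past estimate $w^{(n-1)}$ a lot more than the new observation $y^{(n)}$.

$w^{(n)}$ is the online estimate of the mean of y.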
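The recursion with gain 1/n reproduces the batch sample mean exactly. A minimal sketch assuming NumPy; `running_mean` is an illustrative name, and the data (w* = 3, σ = 1) are synthetic.

```python
import numpy as np


def running_mean(ys):
    """Online estimate w(n) = w(n-1) + (1/n) * (y(n) - w(n-1)).
    Returns the estimate after every trial."""
    w = 0.0                      # initial value is irrelevant: the gain is 1 on trial 1
    history = []
    for n, y in enumerate(ys, start=1):
        k = 1.0 / n              # gain for this problem: shrinks as n grows
        w = w + k * (y - w)      # correct the estimate by a fraction of the error
        history.append(w)
    return np.array(history)


rng = np.random.default_rng(2)
ys = 3.0 + rng.standard_normal(200)            # y = w* + noise with w* = 3, sigma = 1
w_online = running_mean(ys)
batch_means = np.cumsum(ys) / np.arange(1, len(ys) + 1)
print(w_online[-1], batch_means[-1])            # identical up to rounding
```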

Example of the Kalman gain: running estimate of variance

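Same model: $y^{(i)} = w^* + \varepsilon^{(i)}$, $\varepsilon^{(i)} \sim N\left(0, \sigma^2\right)$.

$$\begin{aligned}
\sigma^2 = E\left[\left(y - E[y]\right)^2\right] \approx \hat{\sigma}^2_{(n)} &= \frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - w^{(n)}\right)^2 \\
&\approx \frac{1}{n}\left(\left(y^{(n)} - w^{(n)}\right)^2 + (n-1)\,\hat{\sigma}^2_{(n-1)}\right) \\
&= \left(1-\frac{1}{n}\right)\hat{\sigma}^2_{(n-1)} + \frac{1}{n}\left(y^{(n)} - w^{(n)}\right)^2 \\
&= \hat{\sigma}^2_{(n-1)} + \frac{1}{n}\left(\left(y^{(n)} - w^{(n)}\right)^2 - \hat{\sigma}^2_{(n-1)}\right)
\end{aligned}$$

The second line reuses the earlier variance estimate for the first n-1 terms instead of recomputing them with $w^{(n)}$; the approximation improves as n grows.

$\hat{\sigma}^2_{(n)}$ is the online estimate of the variance of y.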
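The same 1/n gain drives the variance estimate. A minimal sketch assuming NumPy, implementing the approximate recursion on this slide; `running_mean_and_variance` is an illustrative name and the data (mean 3, variance 4) are synthetic. Because the recursion is approximate, it agrees with `np.var` only closely, not exactly.

```python
import numpy as np


def running_mean_and_variance(ys):
    """Online mean and variance, both updated with gain 1/n.
    The variance update is the approximate recursion from this slide."""
    w, s2 = 0.0, 0.0
    for n, y in enumerate(ys, start=1):
        w = w + (y - w) / n                    # running mean w(n)
        s2 = s2 + ((y - w) ** 2 - s2) / n      # running variance sigma_hat^2(n)
    return w, s2


rng = np.random.default_rng(3)
ys = 3.0 + 2.0 * rng.standard_normal(10_000)   # true mean 3, true variance 4
w_hat, s2_hat = running_mean_and_variance(ys)
print(w_hat, s2_hat, np.var(ys))
```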

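Objective: adjust learning gain in order to minimize model uncertainty

Hypothesis about data observation in trial n:

$$y^{(n)} = \mathbf{x}^{(n)T}\mathbf{w}^* + \varepsilon^{(n)}, \qquad \varepsilon^{(n)} \sim N\left(0, \sigma^2\right)$$

• $\hat{\mathbf{w}}^{(n|n-1)}$: my estimate of $\mathbf{w}^*$ before I see y in trial n, given that I have seen y up to trial n-1.

• $y^{(n)} - \mathbf{x}^{(n)T}\hat{\mathbf{w}}^{(n|n-1)}$: error in trial n.

• $\hat{\mathbf{w}}^{(n|n)} = \hat{\mathbf{w}}^{(n|n-1)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\hat{\mathbf{w}}^{(n|n-1)}\right)$: my estimate after I see y in trial n.

• $\tilde{\mathbf{w}}^{(n|n-1)} = \mathbf{w}^* - \hat{\mathbf{w}}^{(n|n-1)}$: parameter error before I saw the data point (a priori error).

• $\tilde{\mathbf{w}}^{(n|n)} = \mathbf{w}^* - \hat{\mathbf{w}}^{(n|n)}$: parameter error after I saw the data point (a posteriori error).

• $P^{(n|n-1)} = E\left[\tilde{\mathbf{w}}^{(n|n-1)}\tilde{\mathbf{w}}^{(n|n-1)T}\right]$: a priori var-cov of parameter error.

• $P^{(n|n)} = E\left[\tilde{\mathbf{w}}^{(n|n)}\tilde{\mathbf{w}}^{(n|n)T}\right]$: a posteriori var-cov of parameter error.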

Some observations about model uncertainty

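$$\begin{aligned}
P^{(n|n)} &= E\left[\left(\mathbf{w}^* - \hat{\mathbf{w}}^{(n|n)}\right)\left(\mathbf{w}^* - \hat{\mathbf{w}}^{(n|n)}\right)^T\right] \\
&= E\left[\mathbf{w}^*\mathbf{w}^{*T} - \mathbf{w}^*\hat{\mathbf{w}}^{(n|n)T} - \hat{\mathbf{w}}^{(n|n)}\mathbf{w}^{*T} + \hat{\mathbf{w}}^{(n|n)}\hat{\mathbf{w}}^{(n|n)T}\right] \\
&= E\left[\hat{\mathbf{w}}^{(n|n)}\hat{\mathbf{w}}^{(n|n)T}\right] - E\left[\hat{\mathbf{w}}^{(n|n)}\right]E\left[\hat{\mathbf{w}}^{(n|n)}\right]^T \\
&= \operatorname{var}\left[\hat{\mathbf{w}}^{(n|n)}\right]
\end{aligned}$$

The third line uses the facts that $\mathbf{w}^*$ is a constant and that the estimate is unbiased, $E\left[\hat{\mathbf{w}}^{(n|n)}\right] = \mathbf{w}^*$.

We note that $P^{(n|n)}$ is simply the var-cov matrix of our model weights. It represents the uncertainty in our model.

We want to update the weights so as to minimize a measure of this uncertainty.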

Trace of parameter var-cov matrix is the sum of squared parameter errors

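For a zero-mean parameter error $\mathbf{w} = [w_1, w_2]^T$:

$$\mathbf{w} \sim N\left(\mathbf{0}, P\right), \qquad P = E\left[\mathbf{w}\mathbf{w}^T\right] = \begin{bmatrix}\operatorname{var}(w_1) & \operatorname{cov}(w_1,w_2)\\ \operatorname{cov}(w_1,w_2) & \operatorname{var}(w_2)\end{bmatrix}$$

$$\operatorname{var}(w_1) = E\left[\left(w_1 - E[w_1]\right)^2\right] = E\left[w_1^2\right] \approx \frac{1}{n}\sum_{i=1}^{n}\left(w_1^{(i)}\right)^2$$

$$\operatorname{tr}(P) = \operatorname{var}(w_1) + \operatorname{var}(w_2) \approx \frac{1}{n}\sum_{i=1}^{n}\left[\left(w_1^{(i)}\right)^2 + \left(w_2^{(i)}\right)^2\right]$$

Our objective is to find the learning rate k (the Kalman gain) that minimizes the sum of the squared errors in our parameter estimates. This sum is the trace of the P matrix. Therefore, given observation $y^{(n)}$, we want to find k such that we minimize the variance of our estimate w.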

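$$\begin{aligned}
\hat{\mathbf{w}}^{(n|n)} &= \hat{\mathbf{w}}^{(n|n-1)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\hat{\mathbf{w}}^{(n|n-1)}\right) \\
&= \hat{\mathbf{w}}^{(n|n-1)} + \mathbf{k}^{(n)}\left(\mathbf{x}^{(n)T}\mathbf{w}^* + \varepsilon^{(n)} - \mathbf{x}^{(n)T}\hat{\mathbf{w}}^{(n|n-1)}\right) \\
&= \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)\hat{\mathbf{w}}^{(n|n-1)} + \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\mathbf{w}^* + \mathbf{k}^{(n)}\varepsilon^{(n)}
\end{aligned}$$

Since $\mathbf{w}^*$ is a constant and $\varepsilon^{(n)}$ is independent of $\hat{\mathbf{w}}^{(n|n-1)}$, whose variance is $P^{(n|n-1)}$:

$$\begin{aligned}
P^{(n|n)} = \operatorname{var}\left[\hat{\mathbf{w}}^{(n|n)}\right] &= \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)P^{(n|n-1)}\left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)^T + \sigma^2\mathbf{k}^{(n)}\mathbf{k}^{(n)T} \\
&= P^{(n|n-1)} - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)} - P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T} + \mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T} + \sigma^2\mathbf{k}^{(n)}\mathbf{k}^{(n)T}
\end{aligned}$$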

Find K to minimize trace of uncertainty

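Take the trace of both sides. Use $\operatorname{tr}\left(A^T\right) = \operatorname{tr}(A)$ together with the symmetry of P, and $\operatorname{tr}(aB) = a\operatorname{tr}(B)$ for a scalar a (note that $\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)}$ is a scalar):

$$\begin{aligned}
\operatorname{tr}P^{(n|n)} &= \operatorname{tr}P^{(n|n-1)} - \operatorname{tr}\left(\mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}\right) - \operatorname{tr}\left(P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T}\right) \\
&\qquad + \operatorname{tr}\left(\mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T}\right) + \sigma^2\operatorname{tr}\left(\mathbf{k}^{(n)}\mathbf{k}^{(n)T}\right) \\
&= \operatorname{tr}P^{(n|n-1)} - 2\operatorname{tr}\left(P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T}\right) + \left(\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2\right)\operatorname{tr}\left(\mathbf{k}^{(n)}\mathbf{k}^{(n)T}\right)
\end{aligned}$$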

The Kalman gain

Using $\frac{d}{dA}\operatorname{tr}(AB) = B^T$, set the derivative of the trace with respect to $\mathbf{k}^{(n)}$ to zero:

$$\frac{d\operatorname{tr}P^{(n|n)}}{d\mathbf{k}^{(n)}} = -2P^{(n|n-1)}\mathbf{x}^{(n)} + 2\left(\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2\right)\mathbf{k}^{(n)} = 0$$

$$\mathbf{k}^{(n)} = \frac{P^{(n|n-1)}\mathbf{x}^{(n)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2}$$

If I have a lot of uncertainty about my model, P is large compared to σ², and I will learn a lot from the current error.

If I am pretty certain about my model, P is small compared to σ², and I will tend to ignore the current error.
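A quick numerical check that this gain really minimizes the trace. The sketch assumes NumPy; `posterior_cov` evaluates $P^{(n|n)}$ for an arbitrary gain using the expression before the derivative was taken, and `kalman_gain` is the closed form. The prior, input, and σ² are arbitrary illustrative values. Because the trace is a quadratic in k with positive curvature, any perturbation of the optimal gain should raise it.

```python
import numpy as np


def posterior_cov(P_prior, x, k, sigma2):
    """P(n|n) for an arbitrary gain k:
    (I - k x^T) P(n|n-1) (I - k x^T)^T + sigma^2 k k^T."""
    A = np.eye(len(x)) - np.outer(k, x)
    return A @ P_prior @ A.T + sigma2 * np.outer(k, k)


def kalman_gain(P_prior, x, sigma2):
    """k = P(n|n-1) x / (x^T P(n|n-1) x + sigma^2)."""
    Px = P_prior @ x
    return Px / (x @ Px + sigma2)


rng = np.random.default_rng(4)
L = rng.standard_normal((3, 3))
P_prior = L @ L.T + 0.1 * np.eye(3)      # arbitrary symmetric positive-definite prior
x = np.array([1.0, 0.5, -0.2])
sigma2 = 0.5

k_opt = kalman_gain(P_prior, x, sigma2)
best_trace = np.trace(posterior_cov(P_prior, x, k_opt, sigma2))
# Nudge the gain in random directions: the trace should only go up.
perturbed_traces = [
    np.trace(posterior_cov(P_prior, x, k_opt + 0.1 * rng.standard_normal(3), sigma2))
    for _ in range(100)
]
print(best_trace, min(perturbed_traces))
```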

Update of model uncertainty

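Substitute the Kalman gain back into the expression for $P^{(n|n)}$, noting that $\left(\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2\right)\mathbf{k}^{(n)} = P^{(n|n-1)}\mathbf{x}^{(n)}$:

$$\begin{aligned}
P^{(n|n)} &= P^{(n|n-1)} - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)} - P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T} + \left(\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2\right)\mathbf{k}^{(n)}\mathbf{k}^{(n)T} \\
&= P^{(n|n-1)} - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)} - P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T} + P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T} \\
&= \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)P^{(n|n-1)} \\
&= P^{(n|n-1)} - \frac{P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2}
\end{aligned}$$

Model uncertainty decreases with every data point that you observe.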
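The shrinking of P can be watched directly. A minimal sketch assuming NumPy: starting from a very uncertain prior (P(1|0) = 10⁶ I is a hypothetical choice) and observing the input history of Example 1 with σ² = 1, the trace of P falls on every trial, and the final P approaches the batch result σ²(XᵀX)⁻¹ = [[0.25, 0], [0, 1]]. `update_uncertainty` is an illustrative name.

```python
import numpy as np


def update_uncertainty(P_prior, x, sigma2):
    """P(n|n) = (I - k x^T) P(n|n-1), with k the Kalman gain."""
    Px = P_prior @ x
    k = Px / (x @ Px + sigma2)
    return P_prior - np.outer(k, Px)     # k x^T P equals outer(k, P x) because P is symmetric


# Observe the input history of Example 1, starting from a very uncertain prior.
X = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
sigma2 = 1.0
P = 1e6 * np.eye(2)                      # hypothetical prior P(1|0)
traces = [np.trace(P)]
for x in X:
    P = update_uncertainty(P, x, sigma2)
    traces.append(np.trace(P))

print(np.round(traces, 3))
print(np.round(P, 3))                    # close to sigma^2 (X^T X)^{-1}
```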

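In this model, we hypothesize that the hidden variables, i.e., the “true” weights, do not change from trial to trial.

$$\mathbf{w}^{*(n+1)} = \mathbf{w}^{*(n)}, \qquad y^{(n)} = \mathbf{x}^{(n)T}\mathbf{w}^{*(n)} + \varepsilon^{(n)}, \qquad \varepsilon^{(n)} \sim N\left(0, \sigma^2\right)$$

[Figure: graphical model in which the hidden variable $\mathbf{w}^*$ is the same on every trial and, together with the inputs x, produces the observed variables y.]

A priori estimate of mean and variance of the hidden variable before I observe the first data point:

$$\hat{\mathbf{w}}^{(1|0)}, \qquad P^{(1|0)}$$

Update of the estimate of the hidden variable after I observed the data point:

$$\mathbf{k}^{(n)} = \frac{P^{(n|n-1)}\mathbf{x}^{(n)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2}, \qquad
\hat{\mathbf{w}}^{(n|n)} = \hat{\mathbf{w}}^{(n|n-1)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\hat{\mathbf{w}}^{(n|n-1)}\right), \qquad
P^{(n|n)} = \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)P^{(n|n-1)}$$

Forward projection of the estimate to the next trial:

$$\hat{\mathbf{w}}^{(n+1|n)} = \hat{\mathbf{w}}^{(n|n)}, \qquad P^{(n+1|n)} = P^{(n|n)}$$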
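The three steps above (prior, update, forward projection) can be sketched in a few lines. This assumes NumPy and synthetic data; `kalman_static`, the true weights, and the broad prior P(1|0) = 10³ I are illustrative choices. With a broad prior, one pass of the filter lands essentially on the batch least-squares answer, and P(n|n) approaches σ²(XᵀX)⁻¹.

```python
import numpy as np


def kalman_static(X, y, w0, P0, sigma2):
    """Kalman filter for y(n) = x(n)^T w* + noise with constant w*.
    Returns the final estimate w(n|n) and its var-cov P(n|n)."""
    w = np.array(w0, dtype=float)
    P = np.array(P0, dtype=float)
    for x, yn in zip(X, y):
        Px = P @ x
        k = Px / (x @ Px + sigma2)      # Kalman gain
        w = w + k * (yn - x @ w)        # correct by a fraction of the prediction error
        P = P - np.outer(k, Px)         # P(n|n) = (I - k x^T) P(n|n-1)
        # Forward projection for a constant w*: w(n+1|n) = w(n|n), P(n+1|n) = P(n|n)
    return w, P


rng = np.random.default_rng(5)
w_true = np.array([0.5, -1.0, 2.0])
X = rng.standard_normal((200, 3))
sigma2 = 0.09
y = X @ w_true + np.sqrt(sigma2) * rng.standard_normal(200)

w_kf, P_kf = kalman_static(X, y, w0=np.zeros(3), P0=1e3 * np.eye(3), sigma2=sigma2)
w_ls = np.linalg.lstsq(X, y, rcond=None)[0]
print(w_kf, w_ls)                         # nearly identical after one pass
```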

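In this model, we hypothesize that the hidden variables change from trial to trial.

$$\mathbf{w}^{*(n+1)} = A\mathbf{w}^{*(n)} + \boldsymbol{\varepsilon}_w^{(n)}, \quad \boldsymbol{\varepsilon}_w \sim N\left(\mathbf{0}, Q\right), \qquad y^{(n)} = \mathbf{x}^{(n)T}\mathbf{w}^{*(n)} + \varepsilon_y^{(n)}, \quad \varepsilon_y \sim N\left(0, \sigma^2\right)$$

[Figure: graphical model in which the hidden variable $\mathbf{w}^*$ evolves from trial to trial and, together with the inputs x, produces the observed variables y.]

A priori estimate of mean and variance of the hidden variable before I observe the first data point:

$$\hat{\mathbf{w}}^{(1|0)}, \qquad P^{(1|0)}$$

Update of the estimate of the hidden variable after I observed the data point:

$$\mathbf{k}^{(n)} = \frac{P^{(n|n-1)}\mathbf{x}^{(n)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2}, \qquad
\hat{\mathbf{w}}^{(n|n)} = \hat{\mathbf{w}}^{(n|n-1)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\hat{\mathbf{w}}^{(n|n-1)}\right), \qquad
P^{(n|n)} = \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)P^{(n|n-1)}$$

Forward projection of the estimate to the next trial:

$$\hat{\mathbf{w}}^{(n+1|n)} = A\hat{\mathbf{w}}^{(n|n)}, \qquad P^{(n+1|n)} = AP^{(n|n)}A^T + Q$$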

• Learning rate is proportional to the ratio between two uncertainties: my model vs. my measurement. In

$$\mathbf{k}^{(n)} = \frac{P^{(n|n-1)}\mathbf{x}^{(n)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2}$$

the numerator reflects uncertainty about my model parameters, and $\sigma^2$ in the denominator reflects uncertainty about my measurement.

• After we observe an input x, the uncertainty associated with the weight of that input decreases: $P^{(n|n)} = \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)P^{(n|n-1)}$.

• Because of state update noise Q, uncertainty increases as we form the prior for the next trial: $P^{(n+1|n)} = AP^{(n|n)}A^T + Q$.
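A sketch of the full filter with state transition A and state noise Q, assuming NumPy. The scenario is hypothetical: the “true” weights drift as a random walk (A = I, Q = 0.01 I), and a filter that models Q is compared with one that assumes the weights are fixed (Q = 0). `kalman_filter` is an illustrative name. Because Q keeps the prior uncertainty from collapsing, the first filter keeps a nonzero learning rate and follows the drift.

```python
import numpy as np


def kalman_filter(X, y, A, Q, sigma2, w0, P0):
    """Kalman filter for w*(n+1) = A w*(n) + eps_w, eps_w ~ N(0, Q), and
    y(n) = x(n)^T w*(n) + eps_y, eps_y ~ N(0, sigma2).
    Returns the posterior estimates w(n|n), one row per trial."""
    w_prior = np.array(w0, dtype=float)
    P_prior = np.array(P0, dtype=float)
    estimates = []
    for x, yn in zip(X, y):
        Px = P_prior @ x
        k = Px / (x @ Px + sigma2)                   # model vs. measurement uncertainty
        w_post = w_prior + k * (yn - x @ w_prior)    # update after observing y(n)
        P_post = P_prior - np.outer(k, Px)           # uncertainty shrinks after the observation
        estimates.append(w_post)
        w_prior = A @ w_post                         # forward projection to trial n+1
        P_prior = A @ P_post @ A.T + Q               # state noise Q inflates the prior
    return np.array(estimates)


# Hypothetical scenario: the "true" weights drift as a random walk (A = I).
rng = np.random.default_rng(6)
n_trials, m = 300, 2
A = np.eye(m)
Q = 0.01 * np.eye(m)
sigma2 = 0.25
w_true = np.cumsum(np.sqrt(0.01) * rng.standard_normal((n_trials, m)), axis=0)
X = rng.standard_normal((n_trials, m))
y = np.sum(X * w_true, axis=1) + np.sqrt(sigma2) * rng.standard_normal(n_trials)

tracking = kalman_filter(X, y, A, Q, sigma2, np.zeros(m), np.eye(m))
static = kalman_filter(X, y, A, np.zeros((m, m)), sigma2, np.zeros(m), np.eye(m))
rmse_tracking = np.sqrt(np.mean((tracking - w_true) ** 2))
rmse_static = np.sqrt(np.mean((static - w_true) ** 2))
print(rmse_tracking, rmse_static)
```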

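Comparison of Kalman gain to LMS

$$\mathbf{w}^{*(n+1)} = \mathbf{w}^{*(n)}, \qquad y^{(n)} = \mathbf{x}^{(n)T}\mathbf{w}^{*(n)} + \varepsilon^{(n)}, \qquad \varepsilon^{(n)} \sim N\left(0, \sigma^2\right)$$

Kalman gain (see derivation of this in homework):

$$\mathbf{k}^{(n)} = \frac{P^{(n|n)}\mathbf{x}^{(n)}}{\sigma^2}$$

$$\hat{\mathbf{w}}^{(n+1|n)} = \hat{\mathbf{w}}^{(n|n-1)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\hat{\mathbf{w}}^{(n|n-1)}\right) = \hat{\mathbf{w}}^{(n|n-1)} + \frac{P^{(n|n)}}{\sigma^2}\mathbf{x}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\hat{\mathbf{w}}^{(n|n-1)}\right)$$

In the Kalman gain approach, the P matrix depends on the history of all previous and current inputs. In LMS, the learning rate η is simply a constant that does not depend on past history:

$$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \eta\,\mathbf{x}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n)}\right)$$

With the Kalman gain, our estimate converges in a single pass over the data set. In LMS, we do not estimate the var-cov matrix P on each trial, but we need multiple passes before our estimate converges.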
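A small comparison, assuming NumPy and synthetic data; the function names, η = 0.01, and the prior 100·I are illustrative choices, not values from the lecture. One pass of the Kalman filter is compared with one and with twenty passes of LMS, and the identity $\mathbf{k}^{(n)} = P^{(n|n)}\mathbf{x}^{(n)}/\sigma^2$ is checked for a single step.

```python
import numpy as np


def kalman_one_pass(X, y, sigma2, P0):
    """Single pass of the static Kalman filter, starting from w = 0."""
    w = np.zeros(X.shape[1])
    P = np.array(P0, dtype=float)
    for x, yn in zip(X, y):
        Px = P @ x
        k = Px / (x @ Px + sigma2)
        w = w + k * (yn - x @ w)
        P = P - np.outer(k, Px)
    return w


def lms(X, y, eta, n_passes):
    """LMS: w <- w + eta * x * (y - x^T w) with a constant learning rate eta."""
    w = np.zeros(X.shape[1])
    for _ in range(n_passes):
        for x, yn in zip(X, y):
            w = w + eta * x * (yn - x @ w)
    return w


rng = np.random.default_rng(7)
w_true = np.array([0.5, -1.0, 2.0])
X = rng.standard_normal((100, 3))
sigma2 = 0.09
y = X @ w_true + np.sqrt(sigma2) * rng.standard_normal(100)

err_kalman = np.linalg.norm(kalman_one_pass(X, y, sigma2, 100.0 * np.eye(3)) - w_true)
err_lms_1 = np.linalg.norm(lms(X, y, eta=0.01, n_passes=1) - w_true)
err_lms_20 = np.linalg.norm(lms(X, y, eta=0.01, n_passes=20) - w_true)
print(err_kalman, err_lms_1, err_lms_20)

# Check the identity on this slide: P(n|n-1) x / (x^T P(n|n-1) x + sigma^2) == P(n|n) x / sigma^2
P_prior = np.diag([2.0, 0.5, 1.0])        # arbitrary prior for the check
x0 = X[0]
k_prior_form = P_prior @ x0 / (x0 @ P_prior @ x0 + sigma2)
P_post = P_prior - np.outer(k_prior_form, P_prior @ x0)
k_post_form = P_post @ x0 / sigma2
```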

Effect of state and measurement noise on the Kalman gain

$$w^{*(n+1)} = a\,w^{*(n)} + \varepsilon_w^{(n)}, \quad \varepsilon_w \sim N\left(0, q^2\right), \quad a = 0.99, \qquad y^{(n)} = x^{(n)}w^{*(n)} + \varepsilon_y^{(n)}, \quad x^{(n)} = 1, \quad \varepsilon_y \sim N\left(0, \sigma^2\right)$$

[Figure: Kalman gain $k^{(n)}$ and prior uncertainty $P^{(n|n-1)}$ over trials n = 1 to 10, for $q^2 = 2, \sigma^2 = 1$ versus $q^2 = 1, \sigma^2 = 1$.]

High noise in the state update model produces increased uncertainty in model parameters. This produces high learning rates.

[Figure: $k^{(n)}$ and $P^{(n|n-1)}$ over trials n = 1 to 10, for $q^2 = 1, \sigma^2 = 1$ versus $q^2 = 1, \sigma^2 = 2$.]

High noise in the measurement also increases parameter uncertainty. But this increase is small relative to measurement uncertainty. Higher measurement noise leads to lower learning rates.
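For this scalar model with x = 1, the gain and the prior uncertainty do not depend on the data, so they can be computed directly. A minimal sketch assuming NumPy; `gain_schedule` is an illustrative name, and P(1|0) = 5 is a hypothetical starting value (the slides do not state it), so the numbers will not match the plotted curves exactly, though the orderings do.

```python
import numpy as np


def gain_schedule(a, q2, sigma2, P_init, n_trials=10):
    """Scalar Kalman filter with x = 1: returns k(n) and the prior P(n|n-1)
    for n = 1..n_trials. Neither depends on the observed data."""
    P_prior = P_init
    gains, priors = [], []
    for _ in range(n_trials):
        k = P_prior / (P_prior + sigma2)     # x = 1, so x^T P x = P
        P_post = (1.0 - k) * P_prior         # P(n|n) = (1 - k x) P(n|n-1)
        gains.append(k)
        priors.append(P_prior)
        P_prior = a * P_post * a + q2        # P(n+1|n) = a P(n|n) a + q^2
    return np.array(gains), np.array(priors)


P_init = 5.0   # hypothetical P(1|0)
k_q2, P_q2 = gain_schedule(a=0.99, q2=2.0, sigma2=1.0, P_init=P_init)
k_base, P_base = gain_schedule(a=0.99, q2=1.0, sigma2=1.0, P_init=P_init)
k_s2, P_s2 = gain_schedule(a=0.99, q2=1.0, sigma2=2.0, P_init=P_init)
print(np.round(k_q2, 3), np.round(k_base, 3), np.round(k_s2, 3))
```

More state noise raises both P and k after the first trial; more measurement noise raises P but lowers k on every trial.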

Effect of state transition auto-correlation on the Kalman gain

$$w^{*(n+1)} = a\,w^{*(n)} + \varepsilon_w^{(n)}, \quad \varepsilon_w \sim N\left(0, q^2\right), \quad q^2 = 1, \qquad y^{(n)} = x^{(n)}w^{*(n)} + \varepsilon_y^{(n)}, \quad x^{(n)} = 1, \quad \varepsilon_y \sim N\left(0, \sigma^2\right), \quad \sigma^2 = 1$$

[Figure: $k^{(n)}$ and $P^{(n|n-1)}$ over trials n = 1 to 10, for a = 0.99, 0.50, and 0.10.]

The learning rate is higher in a state model that has high auto-correlation (larger a). That is, if the learner assumes that the world is changing slowly (a close to 1), then the learner will have a large learning rate.
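The same data-independent computation shows the effect of a. A minimal sketch assuming NumPy, with q² = σ² = 1 as on this slide; `gain_schedule` is an illustrative name and P(1|0) = 5 is a hypothetical starting value. A larger a carries more of the posterior uncertainty into the next prior, so P(n|n-1) and k(n) stay higher.

```python
import numpy as np


def gain_schedule(a, q2, sigma2, P_init, n_trials=10):
    """Scalar Kalman filter with x = 1: returns k(n) and the prior P(n|n-1)
    for n = 1..n_trials."""
    P_prior = P_init
    gains, priors = [], []
    for _ in range(n_trials):
        k = P_prior / (P_prior + sigma2)
        P_post = (1.0 - k) * P_prior
        gains.append(k)
        priors.append(P_prior)
        P_prior = a * P_post * a + q2        # larger a keeps more of the uncertainty
    return np.array(gains), np.array(priors)


P_init = 5.0   # hypothetical P(1|0)
results_a = {a: gain_schedule(a, q2=1.0, sigma2=1.0, P_init=P_init) for a in (0.99, 0.50, 0.10)}
for a, (k, P) in results_a.items():
    print(f"a = {a:.2f}: k(10) = {k[-1]:.3f}, P(10|9) = {P[-1]:.3f}")
```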
