
Transfer Learning via Bayesian Latent Factor Analysis

Preliminary Exam Presentation

Liz Lorenzi¹

Joint work with Katherine Heller¹ and Ricardo Henao²

¹Duke University Department of Statistical Science, ²Duke University Department of Electrical and Computer Engineering

November 30, 2016


The Problem:

- Average surgical complication rate is 15%

- 50% of these complications are avoidable

- Average cost of a complication is $11,626 [Dimick et al., 2004]

Our mission:

- Predict post-operative complications using surgery patient electronic health records (EHRs)

- Enhance decision making of clinicians by suggesting appropriate interventions


Leveraging information across databases

- National Surgical Quality Improvement Program (NSQIP)
  - 3.7 million patients, >700 contributing hospitals
- Duke Medical Center
  - 13,711 patients

The programs collect the same information but have different populations:

- Duke:
  - teaching hospital
  - higher variability in outcomes and complications
  - more experimental surgeries
- NSQIP:
  - wide variety in hospital types
  - different patient care and patient cohorts

Our goal:

1. Predict complications for patients at Duke

2. Leverage information in NSQIP

3. Discern important factors in predicting complications

Transfer Learning

In machine learning, we define a problem with an additional source of information (NSQIP), apart from the standard training data (Duke), to be transfer learning [Pan and Yang, 2010].

- Goal: improve learning in the target task by leveraging knowledge from related tasks

- Our approach: hierarchical latent factor models
  - Learn one set of latent factors that accounts for the distributional differences across populations
  - Appropriately model a separate covariance structure for each population


Latent Factor Model (LFM)

LFM explains underlying variability among observed, correlated covariates in lower-dimensional unobserved “factors.”

- Relate the observed data, X_i, to a K-vector of random variables, f_i:

X_i = Λ f_i + ε_i,   ε_i ∼ N(0, Σ),   Σ = diag(σ_1^2, ..., σ_P^2)

- Λ is the P × K factor loadings matrix

[Graphical model: two latent factors f_i1, f_i2 loading onto the observed variables X_i1, ..., X_i5]


Properties of LFM

We assume the P variables of X are distributed as a multivariate zero-mean normal distribution:

X ∼ Normal(0, Ω),   where Ω = ΛΛ′ + Σ

- Dependence between observed variables is induced by marginalizing over the distribution of the factors

- Allows for direct modeling of the covariance matrix (see the sketch below)
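As a quick illustration of this marginalization, here is a minimal numpy sketch (not the presented implementation; all dimensions and values are made up). The empirical covariance of data simulated from the LFM approaches Ω = ΛΛ′ + Σ:

```python
import numpy as np

rng = np.random.default_rng(0)
P, K, n = 10, 3, 50_000                          # illustrative sizes

Lambda = rng.normal(size=(P, K))                 # factor loadings
sigma2 = rng.uniform(0.5, 1.5, size=P)           # idiosyncratic variances, Sigma = diag(sigma2)

F = rng.normal(size=(n, K))                      # latent factors f_i ~ N(0, I_K)
eps = rng.normal(scale=np.sqrt(sigma2), size=(n, P))  # noise eps_i ~ N(0, Sigma)
X = F @ Lambda.T + eps                           # X_i = Lambda f_i + eps_i

Omega = Lambda @ Lambda.T + np.diag(sigma2)      # implied marginal covariance
print(np.abs(np.cov(X, rowvar=False) - Omega).max())  # small for large n
```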


Transfer Learning via Latent Factor Model (TL-LFM)

- X^t_i, i = 1, ..., n_t, represent the predictors of the target data

- X^s_i, i = 1, ..., n_s, represent the predictors of the source data

X^j_i = Λ^j f_i + ε_i

- where j ∈ {s, t} represents the different populations

We facilitate sharing between groups via the prior setup:

m_p ∼ N(0, (1/φ) I_K)

Λ^s_p ∼ N(m_p, (1/φ_s) I_K),   Λ^t_p ∼ N(m_p, (1/φ_t) I_K)
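A minimal sketch of how this prior ties the two loadings matrices (assuming numpy; the sizes and the precision values φ, φ_s, φ_t are illustrative, not taken from the deck): each row p of Λ^s and Λ^t is drawn around a common mean m_p.

```python
import numpy as np

rng = np.random.default_rng(1)
P, K = 10, 4
phi, phi_s, phi_t = 1.0, 5.0, 5.0                # illustrative precisions

m = rng.normal(0.0, np.sqrt(1.0 / phi), size=(P, K))               # shared row means m_p
Lambda_s = m + rng.normal(0.0, np.sqrt(1.0 / phi_s), size=(P, K))  # source loadings rows
Lambda_t = m + rng.normal(0.0, np.sqrt(1.0 / phi_t), size=(P, K))  # target loadings rows
# Larger phi_s / phi_t pull each population's loadings toward the shared mean,
# i.e. more information is transferred between source and target.
```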


Properties of TL-LFM

Marginalizing over the factors, X has the following form:

X_i ∼ N_P(0, Ω^j),   Ω^j = V(X_i | Λ^j, Σ) = Λ^j Λ^j′ + Σ

- Results in separate modeling of each population's covariance


TL-LFM Regression

Let Z = (Y, X) represent the full data.

- The joint model implies that E(y_i | x_i) = x_i′ θ^j, where θ^j = (Ω^j_XX)^{-1} Ω^j_XY

The posterior predictive distribution is easily found by solving

f(y_{n+1} | y_1, ..., y_n, x_1, ..., x_{n+1}) = ∫ f(y_{n+1} | x_{n+1}, Ω) π(Ω | y_1, ..., y_n, x_1, ..., x_n) dΩ
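A minimal numpy sketch of the implied regression (illustrative only): partition a joint covariance for Z = (Y, X) and form θ^j = (Ω^j_XX)^{-1} Ω^j_XY. The posterior predictive above additionally averages this over the posterior of Ω rather than plugging in a single value.

```python
import numpy as np

def theta_from_joint_cov(Omega, y_index=0):
    """Given a joint covariance for Z = (Y, X) with zero means, return theta
    such that E(y | x) = x' theta."""
    keep = np.arange(Omega.shape[0]) != y_index
    Omega_XX = Omega[np.ix_(keep, keep)]          # covariance of the predictors
    Omega_XY = Omega[keep, y_index]               # cross-covariance with the outcome
    return np.linalg.solve(Omega_XX, Omega_XY)    # (Omega_XX)^{-1} Omega_XY
```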


Simulation Experiments

Goal: mimic transfer learning across two populations

- different sample ratios (target:source)

- 35 binary predictors, 35 continuous predictors

- repeated ten times

Simulate Z_i, for i = 1, ..., 5000, from a 70-dimensional normal distribution with zero mean and covariance equal to Ω^j (see the sketch after the list below).

For each population:

- Sample each row of Λ^j from a Normal(0, I_K) with K = 20

- Randomly select two locations in the first row of Λ and set them to -1 and 1, with the rest 0

- Draw the diagonal of Σ from an InvGamma(1, 0.5) with prior mean equal to 2
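A minimal numpy sketch of this simulation for one population. The deck does not spell out every detail (for example, how the 35 binary predictors are obtained), so those parts below are assumptions:

```python
import numpy as np

rng = np.random.default_rng(2)
P, K, n = 70, 20, 5000

Lambda = rng.normal(size=(P, K))                  # rows of Lambda^j ~ N(0, I_K)
first = np.zeros(K)
loc = rng.choice(K, size=2, replace=False)        # two random locations in the first row
first[loc] = [-1.0, 1.0]
Lambda[0] = first

sigma2 = 1.0 / rng.gamma(1.0, 1.0 / 0.5, size=P)  # diag(Sigma) ~ InvGamma(1, 0.5)
Omega = Lambda @ Lambda.T + np.diag(sigma2)

Z = rng.multivariate_normal(np.zeros(P), Omega, size=n)  # Z_i ~ N_70(0, Omega^j)
Z[:, :35] = (Z[:, :35] > 0).astype(float)         # assumption: dichotomize 35 columns
```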


Visualizing TL-LFM

Plot the K-dimensional latent factors using t-SNE (van der Maaten and Hinton, 2008), comparing the hierarchical and non-hierarchical models.

[Figure: t-SNE embeddings of the learned factors (axes Y1, Y2), by method and target:source ratio — (a) TL 700:2800, (b) TL 500:2500, (c) TL 100:2000, (d) NoTL 700:2800, (e) NoTL 500:2500, (f) NoTL 100:2000]
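A minimal sketch of this visualization (assuming scikit-learn and matplotlib; `F_hat` and `group` below are stand-ins for the posterior factor scores and population labels, not outputs of the deck's code):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
F_hat = rng.normal(size=(600, 20))     # stand-in for learned K-dimensional factor scores
group = rng.integers(0, 2, size=600)   # stand-in for target/source labels

emb = TSNE(n_components=2, random_state=0).fit_transform(F_hat)  # K dims -> 2-D embedding
plt.scatter(emb[:, 0], emb[:, 1], c=group, s=5)
plt.xlabel("Y1"); plt.ylabel("Y2")
plt.show()
```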

Evaluating TL-LFM Prediction Results

We report the area under the ROC curve with standard errors

Target:Source   TL-LFM         LFM            Lasso
700:2800        0.809 (.007)   0.587 (.005)   0.723 (.006)
500:2500        0.790 (.005)   0.594 (.008)   0.732 (.005)
200:2000        0.744 (.005)   0.547 (.005)   0.585 (.004)

Table: Tested on a target-only held-out test set


Surgery Data Results

Each record in the NSQIP/Duke data contains information for a single patient undergoing surgery, with covariates describing:

- demographic information

- preoperative and intraoperative variables

- outcomes of surgery (e.g. cardiac arrest, pneumonia, infection)

Results show Lasso outperforms our TL-LFM.

TL-LFM LFM Lasso

0.73 0.60 0.76

Table: Prediction on Duke-only patients for any-morbidity


Results of TL-LFM: How to improve performance?

We focus on 3 areas:

1. Modeling modal structure

2. Allowing more flexible transferring of information

3. Inducing stronger sparsity


1-2: Hierarchical Dirichlet Process

G^0 | H ∼ DP(α_0, H)

The base measure of each child DP is itself a DP draw:

G^j | G^0 ∼ DP(α_j, G^0),   ∀ j ∈ 1, ..., J

G^0 = Σ_{k=1}^∞ π^0_k δ_{λ^0_k},   λ^0_k ∼ Normal(0, Σ_λ)

and for each group j ∈ 1, ..., J,

G^j = Σ_{k=1}^∞ π^j_k δ_{λ^0_k}

[Teh et al., 2012]

Graphical model for HDP

[Teh et al., 2012]


Finite HDP Conversion

The HDP is the infinite limit of a finite mixture model formulation:

π^0 | α_0 ∼ Dir(α_0/K, ..., α_0/K)

π^j | α_j, π^0 ∼ Dir(α_j π^0)

λ^0 ∼ Normal(0, Σ_λ)

G^j = Σ_{k=1}^K π^j_k δ_{λ^0_k}
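A minimal numpy sketch of the finite construction (all hyperparameter values are illustrative): every group mixes over the same atoms λ^0_k, but with its own weights π^j centered on the shared π^0.

```python
import numpy as np

rng = np.random.default_rng(3)
K, J, P = 20, 2, 5
alpha0, alpha_j = 5.0, 10.0

pi0 = rng.dirichlet(np.full(K, alpha0 / K))              # top-level weights pi^0
pi = [rng.dirichlet(alpha_j * pi0) for _ in range(J)]    # group weights pi^j ~ Dir(alpha_j pi^0)
lam0 = rng.normal(size=(K, P))                           # shared atoms lambda^0_k ~ N(0, I)
# G^j places mass pi[j][k] on the shared atom lam0[k]; groups differ only in their weights.
```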


Adapting HDP for factor model

Instead of drawing from the discrete mixture

G^j = Σ_{k=1}^K π^j_k δ_{λ^0_k},

consider the P × K loadings matrix λ^0 weighted by the stick-breaking weights, Λ^j = [π^j_1 λ^0_1, ..., π^j_K λ^0_K].


HDP as scale mixture

We use the stick-breaking proportions from the HDP as a weighting scheme for the rows of the loadings matrix:

√(π^j) λ^0_p,   where λ^0_p ∼ N(0, (1/φ) I_K).

This results in

√(π^j_k) λ^0_k ∼ N(0, π^j_k (1/φ))

Can we formulate this as a sparse prior to address our third goal?


3: Sparse modeling

From a Bayesian-learning perspective, there are two main sparse-estimation options:

- Discrete mixtures, e.g. spike and slab [Mitchell and Beauchamp, 1988; George and McCulloch, 1993]:

  β_j ∼ w · g(β_j) + (1 − w) · δ_0

- Shrinkage priors, e.g. horseshoe, L1/Laplace prior [Carvalho et al., 2009; Tibshirani, 1996; Mohamed et al., 2011]:

  β_j | τ^2, λ_j ∼ Normal(0, τ^2 λ_j^2)
  λ_j^2 ∼ π(λ_j^2)
  τ^2 ∼ π(τ^2)


Examples of Scale Mixture Priors

Marginal distributions for β:

- Student-t with λ_j ∼ InvGam(ν/2, ν/2)
- Double exponential/Laplace with λ_j ∼ Exp(2)
- Horseshoe with half-Cauchy, λ_j ∼ C^+(0, 1)

[Carvalho et al., 2009]
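A minimal Monte Carlo sketch (assuming numpy) of these three scale mixtures, each drawing a local variance and then β_j ∼ N(0, λ_j^2). The parameterizations below follow the usual scale-mixture conventions (mixing on the variance) rather than anything stated on the slide:

```python
import numpy as np

rng = np.random.default_rng(5)
n, nu = 100_000, 3

lam2_t   = 1.0 / rng.gamma(nu / 2, 2.0 / nu, size=n)  # InvGamma(nu/2, nu/2) -> Student-t marginal
lam2_lap = rng.exponential(scale=0.5, size=n)          # Exp(rate 2)          -> Laplace marginal
lam2_hs  = np.abs(rng.standard_cauchy(size=n)) ** 2    # half-Cauchy scale    -> horseshoe marginal

beta_t   = rng.normal(0.0, np.sqrt(lam2_t))
beta_lap = rng.normal(0.0, np.sqrt(lam2_lap))
beta_hs  = rng.normal(0.0, np.sqrt(lam2_hs))
# Histograms of beta_hs show heavier tails and a sharper spike at zero than beta_t or beta_lap.
```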

How to choose a sparse prior?

[Polson and Scott, 2010] presents criteria for evaluating different sparsity priors. They focus on two guidelines:

- π(λ_j^2) should have heavy tails

- π(τ^2) should have substantial mass at zero

“Strong global shrinkage handles the noise; the local λ_j's act to detect the signals.”


Motivating new model: TL-SLFM

Back to the original model. Let j represent separate populations (expanding from just s or t):

X_{ji} = Λ^j f_{ji} + ε_{ji}

ε_{ji} ∼ N(0, Σ_j),   Σ_j = diag(σ_{j1}^2, ..., σ_{jP}^2)

f_{ji} ∼ N(0, I_K)

How can we change the prior on Λ^j to result in a covariance structure that adjusts to our goals?


Constructing Λ

For each k in 1, ..., K, we weight a global λ^0_k by √(π^j_k), such that λ^j_k | π^j_k := √(π^j_k) λ^0_k.

- The global parameter, π^j, controls shrinkage of λ_p:

  λ^j_p | π^j ∼ Normal(0, (1/φ_p) Π^j),   Π^j = diag(π^j_1, ..., π^j_K)
  π^j | π^0 ∼ Dir(α_j π^0)
  π^0 ∼ Dir(α_0/K, ..., α_0/K)

- The local parameter, φ_p, will have heavy tails. For each φ_{pk} ∈ φ_p,

  φ_{pk} ∼ Gamma(τ/2, τ/2)


Properties of resulting factor model

The model learns a marginal covariance of Ω^j = λ^j λ^j′ + Σ_j, where λ^j is the resulting sparse loadings matrix.

Results in a partitioned covariance that adjusts to each population:

Ω^j = λ^0 Π^j λ^0′ + Σ_j,   where Π^j = diag(π^j_1, ..., π^j_K)
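A minimal numpy sketch of this construction (illustrative sizes; the τ and α values are assumptions), checking that the weighted loadings reproduce the partitioned covariance Ω^j = λ^0 Π^j λ^0′ + Σ_j:

```python
import numpy as np

rng = np.random.default_rng(4)
P, K = 10, 6
tau, alpha0, alpha_j = 1.0, 5.0, 10.0

phi = rng.gamma(tau / 2, 2.0 / tau, size=(P, K))    # local scales phi_pk ~ Gamma(tau/2, tau/2)
lam0 = rng.normal(0.0, 1.0 / np.sqrt(phi))          # global loadings lambda^0_pk ~ N(0, 1/phi_pk)

pi0 = rng.dirichlet(np.full(K, alpha0 / K))         # pi^0 ~ Dir(alpha_0 / K)
pi_j = rng.dirichlet(alpha_j * pi0)                 # pi^j ~ Dir(alpha_j pi^0)

lam_j = np.sqrt(pi_j) * lam0                        # column k weighted by sqrt(pi^j_k)
Sigma_j = np.diag(rng.uniform(0.5, 1.5, size=P))
Omega_j = lam0 @ np.diag(pi_j) @ lam0.T + Sigma_j   # lambda^0 Pi^j lambda^0' + Sigma_j
print(np.allclose(Omega_j, lam_j @ lam_j.T + Sigma_j))  # identical to lam_j lam_j' + Sigma_j
```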


Choosing number of factors

Choosing the correct number of factors is difficult computationally and conceptually.

- Early work chooses the number of factors by maximizing the marginal likelihood, AIC, or BIC

- [Lopes and West, 2004] suggest a reversible-jump MCMC method to learn K

- [Lucas et al., 2006] and [Carvalho et al., 2012] choose the number of factors by using model selection priors to zero out parts of the loadings matrix

- [Bhattacharya and Dunson, 2011] propose a multiplicative gamma shrinkage prior to allow the number of factors to approach infinity while the columns of the loadings matrix increasingly shrink towards zero


Model is robust to choosing number of factors

Comparing data simulated under a 10-factor model (red) to the 20 largest weights learned from models with K = 20 : 100 (black).

[Figure: the first 20 weights (y-axis, 0 to 0.4) versus index (x-axis, 1 to 20) from models learned with K = 20, 30, ..., 100 factors]

Models appropriately shrink weights for models with K > 10.

Deriving Inference for stick-breaking scale mixture

We use the following identity to decompose the weights π^j:

w^j_k ∼ Gamma(α_j π^0_k, 1),   π^j = (w^j_1 / Σ_k w^j_k, ..., w^j_K / Σ_k w^j_k) ∼ Dir(α_j π^0_1, ..., α_j π^0_K)

We rewrite the generative model using the unnormalized w j .

X_{ji} = λ^j f_i + ε_i
ε_i ∼ N(0, Σ_j),   Σ_j = diag(σ_{jp}^2), p ∈ 1, ..., P
σ_{jp}^2 ∼ InvGam(ν/2, ν s^2 / 2)
f_i ∼ N(0, I_K)
λ^j_p | w^j ∼ Normal(0, (1/φ_p) W^j),   W^j = diag(w^j_1, ..., w^j_K)
w^j_k | α_j, π^0 ∼ Gamma(α_j π^0_k, 1)
π^0 ∼ Dirichlet(α_0/K, ..., α_0/K)
φ_{pk} ∼ Gamma(τ/2, τ/2)

where i = 1, ..., n; p = 1, ..., P; k = 1, ..., K; j = 1, ..., J.

Resulting Full Conditionals for stick-breaking scale mixture

Results in the following tractable full conditionals:

(λ^j_p | −) ∼ N(m = V (σ_{jp}^{-2} F′ X_{jp}),   V = (φ_p (W^j)^{-1} + σ_{jp}^{-2} F′F)^{-1})

(w^j_k | −) ∼ GIG(p = α_j π^0_k − P/2,   a = 2,   b = λ^j_k′ Φ_k λ^j_k)

(φ_{pk} | −) ∼ Gamma(τ/2 + J/2,   τ/2 + Σ_{j=1}^J λ_{pk}^{j2} / (2 w^j_k))

Note: capital parameters represent diagonal matrices.
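A minimal sketch of one Gibbs sweep built from these full conditionals, written for a single population (J = 1) to keep it short. It assumes numpy/scipy; the mapping of GIG(p, a, b) onto scipy.stats.geninvgauss, and all variable names, are mine rather than the deck's:

```python
import numpy as np
from scipy.stats import geninvgauss

def gibbs_sweep(X, F, lam, w, phi, sigma2, alpha_j, pi0, tau, rng=np.random.default_rng()):
    """One pass over (lambda_p | -), (w_k | -), (phi_pk | -) for one group (J = 1)."""
    n, P = X.shape
    K = F.shape[1]
    FtF = F.T @ F
    W_inv = np.diag(1.0 / w)

    for p in range(P):  # (lambda_p | -): normal with precision phi_p W^-1 + sigma_p^-2 F'F
        prec = np.diag(phi[p]) @ W_inv + FtF / sigma2[p]
        V = np.linalg.inv(prec)
        m = V @ (F.T @ X[:, p]) / sigma2[p]
        lam[p] = rng.multivariate_normal(m, V)

    for k in range(K):  # (w_k | -): generalized inverse Gaussian
        p_par = alpha_j * pi0[k] - P / 2.0
        a_par, b_par = 2.0, float(lam[:, k] @ (phi[:, k] * lam[:, k]))
        # GIG(p, a, b) corresponds to geninvgauss(p, sqrt(a*b), scale=sqrt(b/a))
        w[k] = geninvgauss.rvs(p_par, np.sqrt(a_par * b_par),
                               scale=np.sqrt(b_par / a_par), random_state=rng)

    for p in range(P):  # (phi_pk | -): conjugate Gamma update
        for k in range(K):
            rate = tau / 2.0 + lam[p, k] ** 2 / (2.0 * w[k])
            phi[p, k] = rng.gamma(tau / 2.0 + 0.5, 1.0 / rate)

    return lam, w, phi
```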


Inference: Learning π0

For drawing π^0 we use a Metropolis-Hastings sampling scheme. We propose π^*_{0k}:

π^*_{0k} ∼ LogNormal(log(π^{t−1}_{0k}), C)

and normalize. Then accept according to the acceptance ratio:

A(π^*_0 | π^{t−1}_0) = min{ 1,  [P(π^*_0 | w^1, w^2) / P(π^{t−1}_0 | w^1, w^2)] · [g(π^{t−1}_0 | π^*_0) / g(π^*_0 | π^{t−1}_0)] }

Results in better mixing.
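A minimal sketch of this Metropolis-Hastings step (assuming numpy/scipy). The target below is the Dir(α_0/K) prior times the Gamma(α_j π^0_k, 1) likelihood of the unnormalized weights, and folding the renormalization into a simple proposal correction is a simplification of the scheme on the slide:

```python
import numpy as np
from scipy.stats import gamma

def mh_step_pi0(pi0, w_groups, alphas, alpha0, C=0.1, rng=np.random.default_rng()):
    """One MH update of pi^0 given unnormalized group weights w^j (list) and alpha_j's."""
    K = pi0.size
    prop = np.exp(np.log(pi0) + C * rng.standard_normal(K))  # component-wise log-normal proposal
    prop /= prop.sum()                                        # renormalize onto the simplex

    def log_post(p):
        lp = np.sum((alpha0 / K - 1.0) * np.log(p))           # Dirichlet(alpha0/K) prior
        for a_j, w_j in zip(alphas, w_groups):                # w^j_k ~ Gamma(alpha_j p_k, 1)
            lp += np.sum(gamma.logpdf(w_j, a=a_j * p, scale=1.0))
        return lp

    # log-normal proposal correction g(old|new)/g(new|old) reduces to prod(new)/prod(old)
    log_ratio = log_post(prop) - log_post(pi0) + np.sum(np.log(prop) - np.log(pi0))
    return prop if np.log(rng.uniform()) < log_ratio else pi0
```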


Initial Results: Simulation

We set up TL-SLFM as a regression model.

- Reporting the area under the ROC curve with standard errors from 10 simulations

Target:Source   TL-SLFM         TL-LFM          LFM             Lasso
700:2800        0.812 (0.008)   0.788 (0.010)   0.754 (0.012)   0.783 (0.009)
500:2500        0.791 (0.010)   0.765 (0.012)   0.694 (0.008)   0.762 (0.006)
200:2000        0.795 (0.008)   0.744 (0.010)   0.668 (0.011)   0.698 (0.010)

Table: Prediction on a target-only held-out set.


Initial Results: Real Data

Until inference is scaled to evaluate the full data, we test on subsets of the data by surgery.

- Hernia surgeries (5000 in NSQIP to 362 in Duke)

TL-SLFM TL-LFM Lasso

0.876 0.733 0.838

Table: Prediction on Duke-only patients for any-morbidity

- Breast Mastectomy (5000 in NSQIP to 680 in Duke)

TL-SLFM TL-LFM Lasso

0.747 0.698 0.706

Table: Prediction on Duke-only patients for any-morbidity


Final words:

Overview/Takeaways

- Presented a transfer learning framework using latent factor models

- Extended the framework for more complicated relationships between populations through TL-SLFM

- Created a novel way to use stick-breaking weights in a scale mixture


Next steps

Transfer Learning

- Scale the inference method using stochastic variational Bayes or stochastic gradient MCMC

- Extend SLFM to be nonparametric (infinite number of factors)

- Apply to different problems for multiple populations with varying types of information


Next steps

Causal Inference

- kelaHealth
  - Measure effectiveness of kelaHealth in reducing complications
  - Consider more tuned interventions based on expected individual treatment effect
- MS Mosaic
  - Learn sequential treatment effects for MS patients for varying types of treatments


Thank you!

Prelim Committee:

- Katherine Heller, Ph.D.
- Ricardo Henao, Ph.D.
- Fan Li, Ph.D.
- Surya Tokdar, Ph.D.

kelaHealth Team:

- Bora Chang, M.D. Candidate
- Erich Huang, M.D./Ph.D.
- Ouwen Huang, M.D./Ph.D. Candidate
- Jeff Sun, M.D.


Works Cited: I

Bhattacharya, A. and Dunson, D. B. (2011). Sparse Bayesian infinite factor models. Biometrika, 98(2):291–306.

Carvalho, C. M., Chang, J., Lucas, J. E., Nevins, J. R., Wang, Q., and West, M. (2012). High-dimensional sparse factor modeling: Applications in gene expression genomics. Journal of the American Statistical Association.

Carvalho, C. M., Polson, N. G., and Scott, J. G. (2009). Handling sparsity via the horseshoe. In AISTATS, volume 5, pages 73–80.

Dimick, J. B., Chen, S. L., Taheri, P. A., Henderson, W. G., Khuri, S. F., and Campbell, D. A. (2004). Hospital costs associated with surgical complications: a report from the private-sector National Surgical Quality Improvement Program. Journal of the American College of Surgeons, 199(4):531–537.

George, E. I. and McCulloch, R. E. (1993). Variable selection via Gibbs sampling. Journal of the American Statistical Association, 88(423):881–889.


Works Cited: II

Lopes, H. F. and West, M. (2004). Bayesian model assessment in factor analysis. Statistica Sinica, pages 41–67.

Lucas, J., Carvalho, C., Wang, Q., Bild, A., Nevins, J., and West, M. (2006). Sparse statistical modelling in gene expression genomics. Bayesian Inference for Gene Expression and Proteomics, 1:0–1.

Mitchell, T. J. and Beauchamp, J. J. (1988). Bayesian variable selection in linear regression. Journal of the American Statistical Association, 83(404):1023–1032.

Mohamed, S., Heller, K., and Ghahramani, Z. (2011). Bayesian and L1 approaches to sparse unsupervised learning. arXiv preprint arXiv:1106.1157.

Pan, S. J. and Yang, Q. (2010). A survey on transfer learning. IEEE Transactions on Knowledge and Data Engineering, 22(10):1345–1359.

Polson, N. G. and Scott, J. G. (2010). Shrink globally, act locally: Sparse Bayesian regularization and prediction. Bayesian Statistics, 9:501–538.


Works Cited: III

Teh, Y. W., Jordan, M. I., Beal, M. J., and Blei, D. M. (2012). Hierarchical Dirichlet processes. Journal of the American Statistical Association.

Tibshirani, R. (1996). Regression shrinkage and selection via the lasso. Journal of the Royal Statistical Society, Series B (Methodological), pages 267–288.


Appendix: More on Stick-breaking prior

Alternatively, the prior can be written as a product of two random variables:

Λ^j_k := √(w^j_k) λ^0_k

- Results in a marginal covariance of Ω^j = λ^0 W^j λ^0′ + Σ_j

- This product results in the following distribution for the element Λ^j_{hk}, where h = 1, ..., P:

f(Λ^j_{hk}) = f(√(w^j_k) λ^0_k) = [φ^{−1/2+απ_k} / (2^{1/2−απ_k−1} π^{1/2} Γ(απ_k))] (Λ_{hk})^{3απ_k} K_{απ_k}(√((2/φ) Λ_{hk}^2))

- where K_{απ_k} is a modified Bessel function of the second kind.

- Connection: this product distribution is very similar to the marginal distribution of f(β_k) from the generalized double Pareto scale mixture (Caron and Doucet, 2008).

