Stein Point Markov Chain Monte Carlo
Wilson Chen, Institute of Statistical Mathematics, Japan
June 11, 2019 @ ICML, Long Beach




Collaborators

Alessandro Barp Francois-Xavier Briol Jackson Gorham

Mark Girolami Lester Mackey Chris Oates


Empirical Approximation Problem

A major problem in machine learning and modern statistics is to approximate some difficult-to-compute density $p$, defined on some domain $\mathcal{X} \subseteq \mathbb{R}^d$, whose normalisation constant is unknown. I.e., $p(x) = \tilde{p}(x)/Z$ and $Z > 0$ is unknown.

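We consider an empirical approximation of $p$ with points $\{x_i\}_{i=1}^n$:

\[ p_n(x) = \frac{1}{n} \sum_{i=1}^{n} \delta(x - x_i), \]

so that for a test function $f : \mathcal{X} \to \mathbb{R}$:

\[ \int_{\mathcal{X}} f(x)\, p(x)\, dx \approx \frac{1}{n} \sum_{i=1}^{n} f(x_i). \]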

A popular approach is Markov chain Monte Carlo.
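
As a hedged illustration of the estimator above, the sketch below runs random-walk Metropolis on an unnormalised density and averages a test function over the chain. The target (a standard normal with its constant dropped), the step size, and the names `log_p_tilde` and `rwm` are assumptions chosen for illustration; they do not come from the talk.

```python
import numpy as np

def log_p_tilde(x):
    # Unnormalised log-density of N(0, 1); the constant Z is never needed.
    return -0.5 * x**2

def rwm(log_p_tilde, x0, n, step=2.4, seed=0):
    """Random-walk Metropolis: returns n states of a chain that leaves p invariant."""
    rng = np.random.default_rng(seed)
    xs = np.empty(n)
    x, lp = x0, log_p_tilde(x0)
    for i in range(n):
        prop = x + step * rng.standard_normal()
        lp_prop = log_p_tilde(prop)
        # Accept with probability min(1, p~(prop) / p~(x)); Z cancels in the ratio.
        if np.log(rng.uniform()) < lp_prop - lp:
            x, lp = prop, lp_prop
        xs[i] = x
    return xs

xs = rwm(log_p_tilde, x0=0.0, n=20000)
estimate = np.mean(xs**2)  # (1/n) * sum_i f(x_i) with f(x) = x^2; the exact integral is 1
```

Only ratios of the unnormalised density enter the accept step, which is why the unknown $Z$ is not an obstacle for MCMC.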


Discrepancy

Idea – construct a measure of discrepancy

D(pn, p)

with desirable features:

• Detect (non)convergence. I.e., $D(p_n, p) \to 0$ only if $p_n \overset{*}{\to} p$.

• Efficiently computable with limited access to p.

Unfortunately, this is not the case for many popular discrepancy measures:

• Kullback-Leibler divergence,

• Wasserstein distance,

• Maximum mean discrepancy (MMD).


Kernel Embedding and MMD

Kernel embedding of a distribution p

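\[ \mu_p(\cdot) = \int k(x, \cdot)\, p(x)\, dx \qquad \text{(a function in the RKHS } \mathcal{K}\text{)} \]

Consider the maximum mean discrepancy (MMD) as an option for $D$:

\[ D(p_n, p) := \|\mu_{p_n} - \mu_p\|_{\mathcal{K}} =: D_{k,p}(\{x_i\}_{i=1}^n) \]

\[ \therefore\; D_{k,p}(\{x_i\}_{i=1}^n)^2 = \|\mu_{p_n} - \mu_p\|_{\mathcal{K}}^2 = \langle \mu_{p_n} - \mu_p,\, \mu_{p_n} - \mu_p \rangle = \langle \mu_{p_n}, \mu_{p_n} \rangle - 2 \langle \mu_{p_n}, \mu_p \rangle + \langle \mu_p, \mu_p \rangle \]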

We are faced with intractable integrals w.r.t. p!

For a Stein kernel k0:

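\[ \mu_p(\cdot) = \int k_0(x, \cdot)\, p(x)\, dx = 0. \]

\[ \therefore\; \|\mu_{p_n} - \mu_p\|_{\mathcal{K}_0}^2 = \|\mu_{p_n}\|_{\mathcal{K}_0}^2 =: D_{k_0,p}(\{x_i\}_{i=1}^n)^2 =: \mathrm{KSD}^2! \]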
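
A short worked step, added here as a sketch, shows why the Stein kernel removes the intractable terms. It uses only the reproducing property $\langle k(x, \cdot), k(x', \cdot) \rangle = k(x, x')$ and the fact that the embedding of the empirical measure is $\mu_{p_n} = \frac{1}{n} \sum_{i=1}^n k(x_i, \cdot)$. The three inner products in the MMD expansion become

\[ \langle \mu_{p_n}, \mu_{p_n} \rangle = \frac{1}{n^2} \sum_{i=1}^n \sum_{j=1}^n k(x_i, x_j), \qquad \langle \mu_{p_n}, \mu_p \rangle = \frac{1}{n} \sum_{i=1}^n \int k(x_i, x)\, p(x)\, dx, \qquad \langle \mu_p, \mu_p \rangle = \iint k(x, x')\, p(x)\, p(x')\, dx\, dx'. \]

Only the first term can be computed from the points alone. With a Stein kernel $k_0$ we have $\mu_p = 0$, so the second and third terms vanish and

\[ \mathrm{KSD}^2 = \frac{1}{n^2} \sum_{i=1}^n \sum_{j=1}^n k_0(x_i, x_j). \]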


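Kernel Stein Discrepancy (KSD)

The kernel Stein discrepancy (KSD) is given by

\[ D_{k_0,p}(\{x_i\}_{i=1}^n) = \frac{1}{n} \sqrt{\sum_{i=1}^n \sum_{j=1}^n k_0(x_i, x_j)}, \]

where $k_0$ is the Stein kernel

\[ k_0(x, x') := \mathcal{T}_p \mathcal{T}'_p k(x, x') = \nabla_x \cdot \nabla_{x'} k(x, x') + \langle \nabla_x \log p(x), \nabla_{x'} k(x, x') \rangle + \langle \nabla_{x'} \log p(x'), \nabla_x k(x, x') \rangle + \langle \nabla_x \log p(x), \nabla_{x'} \log p(x') \rangle\, k(x, x'), \]

with $\mathcal{T}_p f = \nabla(p f)/p$. ($\mathcal{T}_p$ is a Stein operator.)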

• This is computable without the normalisation constant.

• Requires gradient information ∇ log p(xi).

• Detects (non)convergence for an appropriately chosen k.
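
The formulas on this slide can be sketched in NumPy as follows. The slide leaves the base kernel $k$ open; the inverse multiquadric kernel $k(x, x') = (c^2 + \|x - x'\|^2)^{\beta}$ with $c = 1$, $\beta = -1/2$ used below is an assumption, and the function names `stein_kernel_matrix` and `ksd` are hypothetical. The score $\nabla \log p$ is passed in as precomputed values, so no normalisation constant is required.

```python
import numpy as np

def stein_kernel_matrix(X, S, c=1.0, beta=-0.5):
    """K0[i, j] = k0(x_i, x_j) for the IMQ base kernel (c^2 + ||x - y||^2)^beta.

    X: (n, d) points; S: (n, d) scores, S[i] = grad log p(x_i).
    """
    n, d = X.shape
    R = X[:, None, :] - X[None, :, :]            # r_ij = x_i - x_j
    sq = np.sum(R**2, axis=-1)                   # ||r_ij||^2
    s = c**2 + sq
    k = s**beta                                  # base kernel values
    g = 2.0 * beta * s**(beta - 1.0)             # grad_x k = g * r, grad_y k = -g * r
    # Trace term: sum over coordinates of d^2 k / (dx_l dy_l).
    div = -4.0 * beta * (beta - 1.0) * s**(beta - 2.0) * sq - 2.0 * beta * d * s**(beta - 1.0)
    sx_r = np.einsum('id,ijd->ij', S, R)         # <score(x_i), r_ij>
    sy_r = np.einsum('jd,ijd->ij', S, R)         # <score(x_j), r_ij>
    cross = g * (sy_r - sx_r)                    # <s_x, grad_y k> + <s_y, grad_x k>
    return div + cross + (S @ S.T) * k

def ksd(X, S, **kernel_args):
    """KSD = (1/n) * sqrt(sum_ij k0(x_i, x_j))."""
    total = np.sum(stein_kernel_matrix(X, S, **kernel_args))
    return np.sqrt(max(total, 0.0)) / X.shape[0]  # clip tiny negative round-off

# Usage: the target is N(0, I), whose score is -x.
rng = np.random.default_rng(0)
X_good = rng.standard_normal((200, 2))           # drawn from the target
X_bad = X_good + 3.0                             # same shape, wrong location
ksd_good = ksd(X_good, -X_good)
ksd_bad = ksd(X_bad, -X_bad)
```

In this sketch the shifted point set should receive the larger KSD, which is the (non)convergence detection described in the bullet points.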


Stein Points (SP)

The main idea of Stein Points is the greedy minimisation of KSD:

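\[ x_j \mid x_1, \dots, x_{j-1} \leftarrow \operatorname*{arg\,min}_{x \in \mathcal{X}} D_{k_0,p}(\{x_i\}_{i=1}^{j-1} \cup \{x\}) = \operatorname*{arg\,min}_{x \in \mathcal{X}} \; k_0(x, x) + 2 \sum_{i=1}^{j-1} k_0(x, x_i). \]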

A global optimisation step is needed for each iteration.
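
A minimal sketch of the greedy step, assuming a one-dimensional standard normal target and replacing the global optimisation with an exhaustive search over a fixed grid. The grid search, the IMQ base kernel ($c = 1$, $\beta = -1/2$), and the names `k0` and `greedy_stein_points` are assumptions for illustration, not the optimiser used by the authors.

```python
import numpy as np

def k0(x, y, score, c=1.0, beta=-0.5):
    """1D Stein kernel built from the IMQ base kernel (c^2 + (x - y)^2)^beta."""
    r = x - y
    s = c**2 + r**2
    g = 2.0 * beta * s**(beta - 1.0)             # d/dx k = g * r, d/dy k = -g * r
    div = -4.0 * beta * (beta - 1.0) * s**(beta - 2.0) * r**2 - 2.0 * beta * s**(beta - 1.0)
    sx, sy = score(x), score(y)
    return div + g * r * (sy - sx) + sx * sy * s**beta

def greedy_stein_points(score, grid, n):
    """Pick n points, each minimising k0(x, x) + 2 * sum_i k0(x, x_i) over the grid."""
    obj = k0(grid, grid, score)                  # k0(x, x) for every candidate x
    points = []
    for _ in range(n):
        x_new = grid[np.argmin(obj)]
        points.append(x_new)
        # The new point adds 2 * k0(x, x_new) to every candidate's objective.
        obj = obj + 2.0 * k0(grid, x_new, score)
    return np.array(points)

score = lambda x: -x                             # grad log p for N(0, 1)
grid = np.linspace(-4.0, 4.0, 801)
pts = greedy_stein_points(score, grid, n=20)
```

The running objective is updated incrementally, so each iteration costs one kernel evaluation per grid candidate. A grid is only feasible in very low dimension, which is the cost that motivates a cheaper search.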


Stein Point Markov Chain Monte Carlo (SP-MCMC)

We propose to replace the global minimisation at each iteration $j$ of the SP method with a local search based on a $p$-invariant Markov chain of length $m_j$. The proposed SP-MCMC method proceeds as follows:

1. Fix an initial point x1 ∈ X .

2. For j = 2, . . . , n:

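a. Select $i^* \in \{1, \dots, j-1\}$ according to criterion $\mathrm{crit}(\{x_i\}_{i=1}^{j-1})$.

b. Generate $(y_{j,i})_{i=1}^{m_j}$ from a $p$-invariant Markov chain with $y_{j,1} = x_{i^*}$.

c. Set $x_j \leftarrow \operatorname*{arg\,min}_{x \in \{y_{j,i}\}_{i=1}^{m_j}} D_{k_0,p}(\{x_i\}_{i=1}^{j-1} \cup \{x\})$.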

For crit, three different approaches are considered:

• LAST selects the point last added: i∗ := j − 1.

• RAND selects $i^*$ uniformly at random in $\{1, \dots, j-1\}$.

• INFL selects $i^*$ to be the index of the most influential point in $\{x_i\}_{i=1}^{j-1}$.

We call $x_{i^*}$ the most influential point if removing it from the point set creates the greatest increase in KSD.
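
The procedure and the three criteria above can be sketched as follows. This is an illustrative reading under stated assumptions, not the authors' code: the target is a one-dimensional standard normal, the $p$-invariant chain is random-walk Metropolis, all chains have the same length `m`, the base kernel is IMQ with $c = 1$ and $\beta = -1/2$, and every function name is hypothetical.

```python
import numpy as np

def k0(x, y, score, c=1.0, beta=-0.5):
    """1D Stein kernel built from the IMQ base kernel (c^2 + (x - y)^2)^beta."""
    r = x - y
    s = c**2 + r**2
    g = 2.0 * beta * s**(beta - 1.0)             # d/dx k = g * r, d/dy k = -g * r
    div = -4.0 * beta * (beta - 1.0) * s**(beta - 2.0) * r**2 - 2.0 * beta * s**(beta - 1.0)
    sx, sy = score(x), score(y)
    return div + g * r * (sy - sx) + sx * sy * s**beta

def rwm_chain(x0, log_p, m, step, rng):
    """m states of a random-walk Metropolis chain (p-invariant) started at x0."""
    chain = np.empty(m)
    chain[0] = x0
    for l in range(1, m):
        prop = chain[l - 1] + step * rng.standard_normal()
        accept = np.log(rng.uniform()) < log_p(prop) - log_p(chain[l - 1])
        chain[l] = prop if accept else chain[l - 1]
    return chain

def select_index(X, score, crit, rng):
    """Step a: choose the index of the point that starts the next chain."""
    j = len(X)
    if crit == "LAST":
        return j - 1
    if crit == "RAND":
        return int(rng.integers(j))
    if crit == "INFL":
        K = k0(X[:, None], X[None, :], score)
        # Removing x_i subtracts 2 * rowsum_i - K_ii from the kernel sum, so the
        # greatest KSD increase corresponds to the smallest such value.
        return int(np.argmin(2.0 * K.sum(axis=1) - np.diag(K)))
    raise ValueError(f"unknown criterion: {crit}")

def sp_mcmc(log_p, score, x1, n, m, step=1.0, crit="LAST", seed=0):
    rng = np.random.default_rng(seed)
    pts = [float(x1)]
    for _ in range(n - 1):
        X = np.array(pts)
        # Step b: local exploration from the selected point.
        y = rwm_chain(X[select_index(X, score, crit, rng)], log_p, m, step, rng)
        # Step c: greedy KSD objective, evaluated only on the chain states.
        obj = k0(y, y, score) + 2.0 * k0(y[:, None], X[None, :], score).sum(axis=1)
        pts.append(y[np.argmin(obj)])
    return np.array(pts)

log_p = lambda x: -0.5 * x**2                    # unnormalised log-density of N(0, 1)
score = lambda x: -x
pts = sp_mcmc(log_p, score, x1=0.0, n=40, m=20, crit="LAST")
```

Compared with a global search, each iteration here evaluates the objective at only `m` candidate states, and the Markov chain needs nothing beyond the unnormalised log-density.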


Gaussian Mixture Model Experiment

[Figure: four rows of panels, one each for MCMC and for SP-MCMC with the LAST, RAND and INFL criteria. Left column: log KSD against iteration j. Middle column: a second quantity against j (axis label not recovered). Right column: density of Jump², with legend entries SP-MCMC and MCMC.]


IGARCH Experiment (d = 2)

[Figure: log E_P against log n_eval for MALA, RWM, SVGD, MED, SP, SP-MALA LAST, SP-MALA INFL, SP-RWM LAST and SP-RWM INFL.]

SP-MCMC methods are compared against the original SP (Chen et al., 2018), MED (Roshan Joseph et al., 2015) and SVGD (Liu & Wang, 2016), as well as the Metropolis-adjusted Langevin algorithm (MALA) and random-walk Metropolis (RWM).


ODE Experiment (d = 10)

[Figure: log KSD against log n_eval for MALA, RWM, SVGD, MED, SP, SP-MALA LAST, SP-MALA INFL, SP-RWM LAST and SP-RWM INFL.]

SP-MCMC methods are compared against the original SP (Chen et al., 2018), MED (Roshan Joseph et al., 2015) and SVGD (Liu & Wang, 2016), as well as the Metropolis-adjusted Langevin algorithm (MALA) and random-walk Metropolis (RWM).


Theoretical Guarantees

The convergence of the proposed SP-MCMC method is established, with an explicit bound provided on the KSD in terms of the V-uniform ergodicity of the Markov transition kernel.

Example: SP-MALA Convergence

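Let $(m_j)_{j=1}^n \subset \mathbb{N}$ be a fixed sequence and let $\{x_i\}_{i=1}^n$ denote the SP-MALA output, based on Markov chains $(Y_{j,l})_{l=1}^{m_j}$, $j \in \mathbb{N}$. (Under certain regularity conditions) MALA is V-uniformly ergodic for $V(x) = 1 + \|x\|_2$ and $\exists\, C > 0$ such that

\[ \mathbb{E}\left[ D_{k_0,p}(\{x_i\}_{i=1}^n)^2 \right] \le \frac{C}{n} \sum_{i=1}^n \frac{\log(n \wedge m_i)}{n \wedge m_i}. \]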
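
As a worked reading of this bound (simple arithmetic on the stated inequality, not an additional result from the talk), suppose every chain has the same length $m_i = m$. Each of the $n$ summands is then identical, so the bound reduces to

\[ \mathbb{E}\left[ D_{k_0,p}(\{x_i\}_{i=1}^n)^2 \right] \le C\, \frac{\log(n \wedge m)}{n \wedge m}. \]

If $m \ge n$, then $n \wedge m = n$ and the right-hand side is $C \log(n)/n$, which tends to zero as $n$ grows. By Jensen's inequality, the expected KSD itself is then at most $\sqrt{C \log(n)/n}$. If instead $m$ is held fixed while $n$ grows, the right-hand side stays at $C \log(m)/m$ once $n \ge m$, so the bound stops improving; the chain lengths must grow with $n$ for it to vanish.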


Paper, Code and Poster

• Paper is available at: https://arxiv.org/pdf/1905.03673.pdf

• Code is available at: https://github.com/wilson-ye-chen/sp-mcmc

• Check out the poster at Pacific Ballroom #216 from 6:30pm to 8pm!