Variational Bayes and Variational Message Passing


  • Variational Bayes and Variational Message Passing

    Mohammad Emtiyaz Khan

    CS, UBC


  • Variational Inference

    Find a tractable distribution Q(H) that closely approximates the true posterior distribution P(H|V). Using $\sum_H Q(H) = 1$ and Bayes' rule $P(V) = P(H,V)/P(H \mid V)$:

    $$\begin{aligned}
    \log P(V) &= \sum_{H} Q(H) \log P(V) \\
    &= \sum_{H} Q(H) \log \frac{P(H,V)}{P(H \mid V)} \\
    &= \sum_{H} Q(H) \log \left[ \frac{P(H,V)}{Q(H)} \, \frac{Q(H)}{P(H \mid V)} \right] \\
    &= \underbrace{\sum_{H} Q(H) \log \frac{P(H,V)}{Q(H)}}_{\mathcal{L}(Q)} \;+\; \underbrace{\sum_{H} -Q(H) \log \frac{P(H \mid V)}{Q(H)}}_{\mathrm{KL}(Q \,\|\, P)}
    \end{aligned}$$


  • Variational Inference

    $$\log P(V) = \mathcal{L}(Q) + \mathrm{KL}(Q \,\|\, P) \tag{1}$$

    $$\mathcal{L}(Q) = \sum_{H} Q(H) \log \frac{P(H,V)}{Q(H)} \tag{2}$$

    $$\mathrm{KL}(Q \,\|\, P) = -\sum_{H} Q(H) \log \frac{P(H \mid V)}{Q(H)} \tag{3}$$

    Find the Q(H) that maximizes the lower bound L(Q) (and hence minimizes the KL divergence).

    For Q(H) = P(H|V) the KL divergence vanishes, but P(H|V) is intractable (which is why we take the variational approach in the first place).

    Trick: consider a restricted class of distributions Q(H), and then find the member of that class which minimizes the KL divergence.
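    As a sanity check of Eq. (1), here is a minimal numerical sketch (my illustration, not from the slides) for a toy model with one discrete hidden variable; the joint table `p_joint` is made up:

```python
import numpy as np

# Toy check of log P(V) = L(Q) + KL(Q||P): one hidden variable H with
# three states and a fixed observed V. p_joint[h] = P(H=h, V=v_obs).
p_joint = np.array([0.10, 0.25, 0.15])   # illustrative joint values
p_V = p_joint.sum()                      # P(V), marginalizing out H
p_post = p_joint / p_V                   # exact posterior P(H|V)

Q = np.array([0.2, 0.5, 0.3])            # any normalized Q(H)

L = np.sum(Q * np.log(p_joint / Q))      # lower bound L(Q), Eq. (2)
KL = np.sum(Q * np.log(Q / p_post))      # KL(Q||P), Eq. (3)

assert np.isclose(np.log(p_V), L + KL)   # the identity holds exactly
```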


  • Factorized Distributions

    $$Q(H) = \prod_i Q_i(H_i) \tag{4}$$

    Substituting this into the expression for the lower bound,

    $$\begin{aligned}
    \mathcal{L}(Q) &= \sum_{H} \prod_i Q_i(H_i) \log \frac{P(H,V)}{\prod_i Q_i(H_i)} \\
    &= \sum_{H} \prod_i Q_i(H_i) \log P(H,V) - \sum_{H} \prod_i Q_i(H_i) \sum_i \log Q_i(H_i) \\
    &= \sum_{H} \prod_i Q_i(H_i) \log P(H,V) - \sum_i \sum_{H_i} Q_i(H_i) \log Q_i(H_i) \\
    &= \sum_{H} \prod_i Q_i(H_i) \log P(H,V) + \sum_i \mathcal{H}(Q_i)
    \end{aligned}$$

    where $\mathcal{H}(Q_i)$ is the entropy of the factor $Q_i$.


  • Factorized Distributions

    Now separate out all the terms involving a single factor Q_j:

    $$\mathcal{L}(Q) = \sum_{H_j} Q_j(H_j) \underbrace{\langle \log P(H,V) \rangle_{\sim Q_j(H_j)}}_{\log Q_j^*(H_j)} + \mathcal{H}(Q_j) + \sum_{i \neq j} \mathcal{H}(Q_i)$$

    $$= -\mathrm{KL}(Q_j \,\|\, Q_j^*) + \text{terms not involving } Q_j \tag{5}$$

    where $\langle \cdot \rangle_{\sim Q_j(H_j)}$ denotes expectation with respect to all factors except $Q_j$. The bound is maximized with respect to $Q_j$ when

    $$\log Q_j(H_j) = \log Q_j^*(H_j) = \langle \log P(H,V) \rangle_{\sim Q_j(H_j)} + c \tag{6}$$

    Now iterate over the factors; since each update can only increase L(Q), which is bounded above by log P(V), convergence is guaranteed.
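    The iteration in Eq. (6) is plain coordinate ascent. Below is a small sketch (my illustration; the joint table `logP` is arbitrary) for two binary hidden variables under a fully factorized Q, where each update sets a factor to the exponentiated expected log-joint under the other factor:

```python
import numpy as np

# Coordinate ascent (Eq. 6) for a toy model: two binary hidden variables
# H1, H2 with an arbitrary joint table logP[h1, h2] = log P(H1, H2, V=v_obs).
logP = np.log(np.array([[0.08, 0.22],
                        [0.30, 0.10]]))

Q1 = np.array([0.5, 0.5])                # initial factors Q1(H1), Q2(H2)
Q2 = np.array([0.5, 0.5])

def lower_bound(Q1, Q2):
    Q = np.outer(Q1, Q2)                 # factorized Q(H1, H2)
    return np.sum(Q * (logP - np.log(Q)))

for it in range(10):
    # log Q1*(h1) = <log P>_{Q2} + const, then normalize (same for Q2).
    logQ1 = logP @ Q2
    Q1 = np.exp(logQ1 - logQ1.max()); Q1 /= Q1.sum()
    logQ2 = Q1 @ logP
    Q2 = np.exp(logQ2 - logQ2.max()); Q2 /= Q2.sum()
    print(it, lower_bound(Q1, Q2))       # non-decreasing in every iteration
```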


  • Variational Bayes for Bayesian Networks

    $$\begin{aligned}
    \log Q_j^*(H_j) &= \langle \log P(H,V) \rangle_{\sim Q_j(H_j)} + c \\
    &= \sum_i \langle \log P(X_i \mid \mathrm{pa}_i) \rangle_{\sim Q_j(H_j)} + c \\
    &= \langle \log P(H_j \mid \mathrm{pa}_j) \rangle_{\sim Q_j(H_j)} + \sum_{k \in \mathrm{ch}_j} \langle \log P(X_k \mid \mathrm{pa}_k) \rangle_{\sim Q_j(H_j)} + c
    \end{aligned}$$

    All conditionals that do not involve H_j are constant under the expectation and are absorbed into c, so the update for Q_j is local: it depends only on H_j's parents, children, and co-parents (its Markov blanket).


  • Exponential-Conjugate Models

    $$P(Y \mid \theta) = \exp\left[\phi_Y(\theta)^T u(Y) + f(Y) + g(\theta)\right] \tag{7}$$

    u(Y) = natural statistic vector (8)

    φ_Y(θ) = natural parameter vector (9)

    g(θ) = normalization term (10)

    Example I: Bernoulli distribution

    $$p(x \mid \mu) = \mu^x (1-\mu)^{1-x} \tag{11}$$

    $$\log p(x \mid \mu) = x \log \mu + (1-x) \log(1-\mu) \tag{12}$$

    $$= \underbrace{\log \frac{\mu}{1-\mu}}_{\phi(\mu)} \, \underbrace{x}_{u(x)} + \underbrace{\log(1-\mu)}_{g(\mu)} \tag{13}$$
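    A quick sketch (illustrative, not from the slides) evaluating the Bernoulli density through the natural parameterization of Eqs. (11)-(13):

```python
import numpy as np

# Bernoulli in exponential-family form: log p(x|mu) = phi(mu)*u(x) + g(mu),
# with u(x) = x. The value of mu is arbitrary.
mu = 0.3
phi = np.log(mu / (1 - mu))   # natural parameter: the log-odds
g = np.log(1 - mu)            # g(mu) = -log(1 + e^phi)

for x in (0, 1):
    log_p = phi * x + g
    assert np.isclose(np.exp(log_p), mu**x * (1 - mu)**(1 - x))
print("phi =", phi, " mu recovered =", 1 / (1 + np.exp(-phi)))
```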


  • Exponential-Conjugate Models

    $$P(Y \mid \theta) = \exp\left[\phi_Y(\theta)^T u(Y) + f(Y) + g(\theta)\right] \tag{14}$$

    $$P(Y \mid \phi) = \exp\left[\phi^T u(Y) + f(Y) + \tilde{g}(\phi)\right] \quad \text{(re-parametrization)}$$

    Property I: $\langle u(Y) \rangle_{P(Y \mid \theta)} = -\dfrac{d\tilde{g}(\phi)}{d\phi}$

    $$\log p(x \mid \mu) = \underbrace{\log \frac{\mu}{1-\mu}}_{\phi(\mu)} \, \underbrace{x}_{u(x)} + \underbrace{\log(1-\mu)}_{g(\mu)} \tag{15}$$

    $$\phi = \log \frac{\mu}{1-\mu} \;\Rightarrow\; \mu = \frac{e^\phi}{1+e^\phi} \tag{16}$$

    $$g(\mu) = \log(1-\mu) = -\log(1+e^\phi) = \tilde{g}(\phi) \tag{17}$$

    $$E(x) = \langle u(x) \rangle = -\frac{d\tilde{g}(\phi)}{d\phi} = \frac{e^\phi}{1+e^\phi} = \mu \tag{18}$$
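    Property I is easy to verify numerically for this example. The sketch below (my addition; φ is an arbitrary value) compares −dg̃/dφ, computed by central finite differences, with the closed-form mean µ:

```python
import numpy as np

# Check Property I for the Bernoulli example: <u(x)> = -d g~(phi)/d phi,
# with g~(phi) = -log(1 + e^phi). phi is an arbitrary illustrative value.
def g_tilde(phi):
    return -np.log1p(np.exp(phi))

phi, eps = 0.7, 1e-6
deriv = (g_tilde(phi + eps) - g_tilde(phi - eps)) / (2 * eps)
mu = np.exp(phi) / (1 + np.exp(phi))   # closed-form mean, Eq. (16)

assert np.isclose(-deriv, mu)          # agree to finite-difference accuracy
print(-deriv, mu)
```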


  • Exponential-Conjugate Models

    $$P(Y \mid \theta) = \exp\left[\phi_Y(\theta)^T u(Y) + f(Y) + g(\theta)\right] \tag{19}$$

    Example II: Gaussian distribution, with the model θ → Y → X ← β:

    $$p(Y \mid \theta) = (2\pi)^{-1/2} \exp\left[-\tfrac{1}{2}(Y-\theta)^2\right]$$

    $$\log p(Y \mid \theta) = \underbrace{[\theta, \ -1/2]}_{\phi_Y(\theta)^T} \underbrace{\begin{bmatrix} Y \\ Y^2 \end{bmatrix}}_{u_Y(Y)} \underbrace{-\tfrac{1}{2}\theta^2}_{g_Y(\theta)} \underbrace{-\tfrac{1}{2}\log(2\pi)}_{f_Y(Y)}$$

    $$p(X \mid Y, \beta) = (2\pi)^{-1/2}\beta^{1/2} \exp\left[-\tfrac{\beta}{2}(X-Y)^2\right]$$

    $$\log p(X \mid Y, \beta) = \underbrace{[\beta Y, \ -\beta/2]}_{\phi_X(Y,\beta)^T} \underbrace{\begin{bmatrix} X \\ X^2 \end{bmatrix}}_{u_X(X)} \underbrace{-\tfrac{1}{2}\left(\beta Y^2 - \log\beta\right)}_{g_X(Y,\beta)} \underbrace{-\tfrac{1}{2}\log(2\pi)}_{f_X(X)}$$


  • Exponential-Conjugate Models

    Property II: multi-linearity. In the model θ → Y → X ← β, the same conditional can be written as a linear function of the natural statistics of the child or of a parent:

    $$\begin{aligned}
    \log p(X \mid Y, \beta) &= \underbrace{[\beta Y, \ -\beta/2]}_{\phi_X(Y,\beta)^T} \underbrace{\begin{bmatrix} X \\ X^2 \end{bmatrix}}_{u_X(X)} \underbrace{-\tfrac{1}{2}\left(\beta Y^2 - \log\beta\right)}_{g_X(Y,\beta)} \underbrace{-\tfrac{1}{2}\log(2\pi)}_{f_X(X)} \\
    &= \underbrace{[\beta X, \ -\beta/2]}_{\phi_{XY}(X,\beta)^T} \underbrace{\begin{bmatrix} Y \\ Y^2 \end{bmatrix}}_{u_Y(Y)} \underbrace{-\tfrac{1}{2}\left(\beta X^2 - \log\beta\right)}_{g_{XY}(X,\beta)} \underbrace{-\tfrac{1}{2}\log(2\pi)}_{f_Y(Y)}
    \end{aligned}$$

    $$\log p(Y \mid \theta) = \underbrace{[\theta, \ -1/2]}_{\phi_Y(\theta)^T} \underbrace{\begin{bmatrix} Y \\ Y^2 \end{bmatrix}}_{u_Y(Y)} \underbrace{-\tfrac{1}{2}\theta^2}_{g_Y(\theta)} \underbrace{-\tfrac{1}{2}\log(2\pi)}_{f_Y(Y)}$$
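    A numerical check of the multi-linearity above (my addition; the values of X, Y, β are arbitrary): both readings of log p(X|Y, β) agree with direct evaluation of the Gaussian log-density:

```python
import numpy as np

# Both exponential-family readings of log p(X|Y, beta) above must agree
# with direct evaluation of the Gaussian log-density. Values are arbitrary.
X, Y, beta = 1.3, -0.4, 2.0
c = -0.5 * np.log(2 * np.pi)

as_fn_of_uX = np.dot([beta * Y, -beta / 2], [X, X**2]) \
              - 0.5 * (beta * Y**2 - np.log(beta)) + c
as_fn_of_uY = np.dot([beta * X, -beta / 2], [Y, Y**2]) \
              - 0.5 * (beta * X**2 - np.log(beta)) + c
direct = 0.5 * np.log(beta) + c - 0.5 * beta * (X - Y)**2

assert np.isclose(as_fn_of_uX, as_fn_of_uY)
assert np.isclose(as_fn_of_uX, direct)
```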


  • Exponential-Conjugate Models

    Consider the node Y and its children in θ → Y → X ← β:

    $$\log P(Y \mid \theta) = \phi_Y(\theta)^T u_Y(Y) + f_Y(Y) + g_Y(\theta)$$

    $$\begin{aligned}
    \log P(X \mid Y, \beta) &= \phi_X(Y,\beta)^T u_X(X) + f_X(X) + g_X(Y,\beta) \\
    &= \phi_{XY}(X,\beta)^T u_Y(Y) + g_{XY}(X,\beta) + f_Y(Y)
    \end{aligned}$$

    Recall that

    $$\begin{aligned}
    \log Q_Y^*(Y) &= \langle \log P(Y \mid \theta) \rangle_{\sim Q_Y(Y)} + \langle \log P(X \mid Y, \beta) \rangle_{\sim Q_Y(Y)} + c \\
    &= \langle \phi_Y(\theta)^T u_Y(Y) + f_Y(Y) + g_Y(\theta) \rangle_{\sim Q_Y(Y)} \\
    &\quad + \langle \phi_{XY}(X,\beta)^T u_Y(Y) + g_{XY}(X,\beta) \rangle_{\sim Q_Y(Y)} + c \\
    &= \langle \phi_Y(\theta) + \phi_{XY}(X,\beta) \rangle_{\sim Q_Y(Y)}^T \, u_Y(Y) + f_Y(Y) + c_1
    \end{aligned}$$


  • Exponential-Conjugate Models

    $$\log Q_Y^*(Y) = \langle \phi_Y(\theta) + \phi_{XY}(X,\beta) \rangle_{\sim Q_Y(Y)}^T \, u_Y(Y) + f_Y(Y) + c_1$$

    Finally,

    $$\langle \phi_Y(\theta) \rangle = [\theta, \ -1/2]$$

    $$\langle \phi_{XY}(X,\beta) \rangle = \langle [\beta X, \ -\beta/2] \rangle = [\langle \beta \rangle \langle X \rangle, \ -\langle \beta \rangle / 2]$$

    The latter factorizes because Q is fully factorized, and the required expected natural statistics (such as ⟨β⟩) are found using Property I, as derivatives of the corresponding log-normalizers g̃; for an observed node, ⟨X⟩ is simply the observed value.
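    Putting this together for the chain: a minimal sketch (my illustration; it assumes a known θ, a point estimate ⟨β⟩, and an observed X) that sums the two natural-parameter contributions and reads off the Gaussian Q*_Y:

```python
import numpy as np

# Combine the prior and child contributions to the natural parameters of
# Q_Y* for theta -> Y -> X <- beta. Values are illustrative; theta is taken
# as known, <beta> as a point estimate, and X as observed.
theta, E_beta, X = 0.5, 2.0, 1.8

phi_prior = np.array([theta, -0.5])              # <phi_Y(theta)>
phi_child = np.array([E_beta * X, -E_beta / 2])  # <phi_XY(X, beta)>
phi_star = phi_prior + phi_child

# For u_Y(Y) = [Y, Y^2], phi = [prec * mean, -prec / 2]:
precision = -2 * phi_star[1]                     # 1 + <beta> = 3.0
mean = phi_star[0] / precision                   # (theta + <beta>X)/(1+<beta>)
print(mean, precision)                           # ~1.367, 3.0
```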


  • Back to Bayesian Networks

    Take each node and write the expression as a function of the natural statistics of that node:

    $$\log Q_Y^*(Y) = \langle \log P(Y \mid \mathrm{pa}_Y) \rangle_{\sim Q_Y(Y)} + \sum_{k \in \mathrm{ch}_Y} \langle \log P(X_k \mid \mathrm{pa}_k) \rangle_{\sim Q_Y(Y)} + c$$

    $$= \Big\langle \phi_Y(\mathrm{pa}_Y) + \sum_{k \in \mathrm{ch}_Y} \phi_{X_k Y}(X_k, \mathrm{cp}_k) \Big\rangle_{\sim Q_Y(Y)}^T u_Y(Y) + f_Y(Y) + c_1$$

    Then compute the expectations of the natural statistics of each parent and child node, and use these to evaluate the bracketed quantity (here cp_k denotes the co-parents of Y with respect to child X_k).


  • Variational Message Passing

    Message from a parent node Y to a child node X:

    $$m_{Y \to X} = \langle u_Y \rangle \tag{20}$$

    Message from a child node X to a parent node Y:

    $$m_{X \to Y} = \tilde{\phi}_{XY}\left(\langle u_X \rangle, \{m_{i \to X}\}_{i \in \mathrm{cp}_Y}\right) \tag{21}$$

    Node Y updates its posterior Q*_Y:

    $$\phi_Y^* = \tilde{\phi}_Y\left(\{m_{i \to Y}\}_{i \in \mathrm{pa}_Y}\right) + \sum_{j \in \mathrm{ch}_Y} m_{j \to Y} \tag{22}$$
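    Below is a compact sketch of the full message-passing loop (my illustration, not the authors' code) for a chain θ → Y → X with both θ and Y hidden Gaussian nodes, X observed, and the noise precision β held fixed; the messages follow Eqs. (20)-(22):

```python
import numpy as np

# Hedged VMP sketch for theta -> Y -> X, with theta ~ N(m0, 1),
# Y | theta ~ N(theta, 1), X | Y ~ N(Y, 1/beta); X observed, beta fixed.
# Q_theta and Q_Y are Gaussian, tracked via natural parameters for u = [., .^2].
m0, beta, X = 0.0, 2.0, 1.8

def posterior_mean(phi):
    # Mean of a Gaussian with natural parameters phi = [prec*mean, -prec/2].
    return phi[0] / (-2 * phi[1])

phi_theta = np.array([m0, -0.5])   # initialize both factors at the prior
phi_Y = np.array([m0, -0.5])

for _ in range(20):
    # Q_Y update (Eq. 22): parent message gives [<theta>, -1/2]; the
    # observed child contributes phi~_XY(<u_X>) = [beta*X, -beta/2].
    phi_Y = np.array([posterior_mean(phi_theta), -0.5]) \
          + np.array([beta * X, -beta / 2])
    # Q_theta update: prior [m0, -1/2] plus child message [<Y>, -1/2].
    phi_theta = np.array([m0, -0.5]) \
              + np.array([posterior_mean(phi_Y), -0.5])

print("E[theta] =", posterior_mean(phi_theta))   # converges to 0.72
print("E[Y]     =", posterior_mean(phi_Y))       # converges to 1.44
```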



  • Discussion

    Initialization and message-passing schedule

    Calculation of the lower bound

    Allowable models

    VIBES (a variational inference engine for Bayesian networks)

