Bias Correction in Learned Generative Models using Likelihood-free Importance Weighting



Page 1

Bias Correction in Learned Generative Models using Likelihood-free Importance Weighting

Aditya Grover, Stanford University

Joint work with Jiaming Song, Alekh Agarwal, Kenneth Tran, Ashish Kapoor, Eric Horvitz, Stefano Ermon

NeurIPS 2019

Page 2

Transforming Science & Society

Hwang et al., 2018, Gómez-Bombarelli et al., 2016

Page 3

What You See Is Not What You Always Get

Odena et al., 2016

Page 4

Learning Generative Models

Given: samples from a data distribution $p_{\mathrm{data}}$, i.e., $x_i \sim p_{\mathrm{data}}$ for $i = 1, 2, \ldots, n$

Goal: Choose a model family $\mathcal{M}$ and a model $p_\theta \in \mathcal{M}$ that approximates the data distribution as closely as possible under a distance $d(p_{\mathrm{data}}, p_\theta)$:

$$\min_{p_\theta \in \mathcal{M}} d(p_{\mathrm{data}}, p_\theta)$$

Page 5

Challenges

• How to define the distance $d(\cdot)$ between distributions?
• Model mismatch: $p_{\mathrm{data}} \notin \mathcal{M}$
• Optimization is imperfect
• Finite datasets: the empirical data distribution $\hat{p}_{\mathrm{data}}$ is far from the true data distribution $p_{\mathrm{data}}$

$$\min_{p_\theta \in \mathcal{M}} d(\hat{p}_{\mathrm{data}}, p_\theta)$$

$p_{\mathrm{data}} \neq p_\theta$: generative models are biased w.r.t. $p_{\mathrm{data}}$!

Page 6

Evaluating Generative Models is Hard

• Density estimation
  § Not applicable to models with ill-defined/intractable likelihoods, e.g., GANs, VAEs
  § Not correlated with sample quality (Theis et al., 2015)
• Sample quality metrics, e.g., Inception Score (Salimans et al., 2016), FID (Heusel et al., 2017), KID (Binkowski et al., 2018)
• Downstream tasks, e.g., semi-supervised learning

Page 7

Identifying bias in generative modeling

• Let $f: \mathcal{X} \to \mathbb{R}$ be some real-valued function of interest
• We assume $f$ is unknown during training of the generative model
• Evidence of bias: $\mathbb{E}_{p_{\mathrm{data}}}[f(X)] \neq \mathbb{E}_{p_\theta}[f(X)]$

Page 8

Motivating Use Cases

• Model-based Off-Policy Evaluation
  § How do we safely evaluate a target policy given data from a different source policy?
  § Value estimates are an expectation w.r.t. the estimated generative dynamics model and target policy
• Model-based Data Augmentation
  § Classifier trained on a mixture of real and generated data
  § Loss is augmented with an expectation w.r.t. generated data
• Fair and sample-efficient generation

$$\mathbb{E}_{p_{\mathrm{data}}}[f(X)] \neq \mathbb{E}_{p_\theta}[f(X)]$$

Page 9

Bias Mitigation

$$\mathbb{E}_{p_{\mathrm{data}}}[f(X)] \neq \mathbb{E}_{p_\theta}[f(X)]$$

How to correct for bias due to model mismatch?
• Option 1: Train deeper models
  § Increases estimation error
  § Does not correct for distributional assumptions
• Option 2: Non-zero bias ≡ an instance of covariate shift. Can we use importance weighting?
  § Reweight samples $x \sim p_\theta$ by the density ratio (derivation below)

  $$w(x) := \frac{p_{\mathrm{data}}(x)}{p_\theta(x)} \quad \Rightarrow \quad \mathbb{E}_{p_\theta}[w(X)\, f(X)] = \mathbb{E}_{p_{\mathrm{data}}}[f(X)]$$

  § We don't know $p_{\mathrm{data}}$
  § For many generative models, even $p_\theta$ is not known (e.g., VAEs, GANs)

Page 10

Importance Weighting via Classification

Importance weights can be estimated via binary classification! Train a classifier to distinguish real (Y = 1) from generated (Y = 0) data.

• For a Bayes optimal classifier $c^*: \mathcal{X} \times \mathcal{Y} \to [0, 1]$ (sketched in code below),

$$w(x) = \frac{p_{\mathrm{data}}(x)}{p_\theta(x)} = \frac{c^*(Y = 1 \mid x)}{c^*(Y = 0 \mid x)}$$

• Practical checklist
  ✓ Calibration
  ✓ Validation set
• Not the same as GAN training
  § Post-hoc bias correction
  § Don't throw away the discriminator! The new generative model is a function of both $\hat{w}_\phi$ (classifier) and $p_\theta$
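A minimal sketch of this estimator, not the authors' code; it assumes a calibrated scikit-learn logistic-regression classifier as a stand-in for whatever discriminator is used in practice:

```python
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression

def fit_ratio_classifier(x_real, x_gen):
    """Train a calibrated binary classifier: real data -> Y=1, generated data -> Y=0."""
    X = np.vstack([x_real, x_gen])
    y = np.concatenate([np.ones(len(x_real)), np.zeros(len(x_gen))])
    clf = CalibratedClassifierCV(LogisticRegression(max_iter=1000), cv=5)
    clf.fit(X, y)
    return clf

def lfiw_weights(clf, x, eps=1e-6):
    """Estimated density ratio w_hat(x) = c(Y=1|x) / c(Y=0|x)."""
    p_real = clf.predict_proba(x)[:, 1]
    return p_real / np.clip(1.0 - p_real, eps, None)
```

Calibration matters here: the ratio is only meaningful if the predicted probabilities are well calibrated, which is why the sketch wraps the classifier in CalibratedClassifierCV and why a held-out validation set is on the checklist.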

Page 11

Synthetic Example

[Figure: synthetic example, shown with 100 samples and with 1000 samples]

Page 12

Model-based Monte Carlo Evaluation

Goal: Evaluate $\mathbb{E}_{p_{\mathrm{data}}}[f(X)]$ via $p_\theta$

• Default Monte Carlo estimator:
$$\mathbb{E}_{p_{\mathrm{data}}}[f(X)] \approx \frac{1}{T} \sum_{i=1}^{T} f(x_i), \qquad x_i \sim p_\theta$$
• Likelihood-Free Importance Weighted (LFIW) estimator:
$$\mathbb{E}_{p_{\mathrm{data}}}[f(X)] \approx \frac{1}{T} \sum_{i=1}^{T} \hat{w}_\phi(x_i)\, f(x_i), \qquad x_i \sim p_\theta$$
• The relative variance of $\hat{w}_\phi$ can be high. Self-normalized LFIW (see the sketch below):
$$\mathbb{E}_{p_{\mathrm{data}}}[f(X)] \approx \sum_{i=1}^{T} \frac{\hat{w}_\phi(x_i)}{\sum_{j=1}^{T} \hat{w}_\phi(x_j)}\, f(x_i)$$
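A rough sketch of the three estimators, assuming x_gen holds samples from $p_\theta$, w holds the estimated weights $\hat{w}_\phi(x_i)$ (e.g., from lfiw_weights above), and f is a vectorized function:

```python
import numpy as np

def mc_estimate(f, x_gen):
    """Default Monte Carlo estimate of E_{p_data}[f(X)] using model samples only."""
    return np.mean(f(x_gen))

def lfiw_estimate(f, x_gen, w):
    """LFIW estimate: reweight each model sample by its estimated density ratio."""
    return np.mean(w * f(x_gen))

def self_normalized_lfiw_estimate(f, x_gen, w):
    """Self-normalized LFIW: normalize the weights, trading a small bias for lower variance."""
    return np.sum(w * f(x_gen)) / np.sum(w)
```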

Page 13

Improved Sample Quality Metrics

Model | Inception Score (↑) | FID (↓) | KID (↓)
Reference | 11.09 ± 0.1263 | 5.20 ± 0.05 | 0.008 ± 0.0004
PixelCNN++ (Default) | 5.16 ± 0.0117 | 58.70 ± 0.0506 | 0.196 ± 0.0001
PixelCNN++ (LFIW) | 6.68 ± 0.0773 | 55.83 ± 0.9695 | 0.126 ± 0.0009
SNGAN (Default) | 8.33 ± 0.0280 | 20.40 ± 0.0747 | 0.094 ± 0.0002
SNGAN (LFIW) | 8.57 ± 0.0325 | 17.29 ± 0.0698 | 0.073 ± 0.0004

Standard error around the mean computed over 10 runs. Dataset: CIFAR-10.

Page 14

Importance Resampling Distribution

• Define the importance resampling distribution as
$$p_{\theta,\phi}(x) \propto \hat{w}_\phi(x)\, p_\theta(x)$$
• Normalization constant: $Z_{\theta,\phi} = \mathbb{E}_{x \sim p_\theta}[\hat{w}_\phi(x)]$
• Density estimation and sampling are intractable
• Particle-based approximation
  § Approximate the induced distribution with finite samples from $p_\theta$
  § Approximated via resampling methods, e.g., rejection sampling (Azadi et al., 2019), MCMC (Turner et al., 2019)

Page 15

Sampling Importance Resampling

• Choose a finite sampling budget $T > 0$
• Draw a batch of $T$ points $x_1, x_2, \ldots, x_T$ from $p_\theta$ and estimate the importance weights $\hat{w}_\phi(x_i)$
• Define a categorical distribution $q(i) \propto \hat{w}_\phi(x_i)$
• Sample $i \sim q(i)$ and return $x_i$ (see the sketch below)

Example: weights 0.6, 0.1, 0.7 normalize to $q(i=1) = 0.42$, $q(i=2) = 0.08$, $q(i=3) = 0.5$.
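A minimal sketch of this procedure; sample_model and weight_fn are hypothetical stand-ins for the model's sampler and the classifier-based weight function from the earlier sketch:

```python
import numpy as np

def sir_sample(sample_model, weight_fn, budget, rng=None):
    """Sampling importance resampling: draw `budget` proposals from the model,
    then resample one of them in proportion to its estimated importance weight."""
    rng = rng or np.random.default_rng()
    xs = sample_model(budget)          # batch of T points x_1, ..., x_T ~ p_theta
    w = weight_fn(xs)                  # estimated density ratios w_hat(x_i)
    q = w / w.sum()                    # categorical distribution q(i) proportional to w_hat(x_i)
    i = rng.choice(budget, p=q)        # sample an index i ~ q(i)
    return xs[i]
```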

Page 16

Are we guaranteed to do better?

• No
• When is $p_{\theta,\phi}$ a "better" fit than $p_\theta$?
• Better: the KL divergence decreases, $D_{\mathrm{KL}}[p_{\mathrm{data}} \,\|\, p_{\theta,\phi}] \leq D_{\mathrm{KL}}[p_{\mathrm{data}} \,\|\, p_\theta]$
• Necessary and sufficient condition: $p_{\theta,\phi}$ is a better fit than $p_\theta$ iff (see the check sketched below)
$$\mathbb{E}_{x \sim p_{\mathrm{data}}}[\log \hat{w}_\phi(x)] \geq \log Z_{\theta,\phi}$$
• Necessary conditions:
$$\mathbb{E}_{x \sim p_{\mathrm{data}}}[\log \hat{w}_\phi(x)] \geq \mathbb{E}_{x \sim p_\theta}[\log \hat{w}_\phi(x)], \qquad \mathbb{E}_{x \sim p_{\mathrm{data}}}[\hat{w}_\phi(x)] \geq \mathbb{E}_{x \sim p_\theta}[\hat{w}_\phi(x)]$$
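The condition can be checked empirically on held-out data; a rough sketch, with the caveat that $\log Z_{\theta,\phi}$ is itself a Monte Carlo estimate over model samples, so this is only a noisy diagnostic:

```python
import numpy as np

def kl_improvement_check(w_real, w_gen):
    """Check E_{p_data}[log w_hat(x)] >= log Z_{theta,phi}, where Z is estimated
    as the mean weight over model samples.
    w_real: weights evaluated on held-out real data.
    w_gen:  weights evaluated on fresh model samples."""
    lhs = np.mean(np.log(w_real))
    log_z = np.log(np.mean(w_gen))
    return lhs >= log_z
```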

Page 17

Downstream Applications
• Data Augmentation
• Off-Policy Policy Evaluation

Page 18

Data Augmentation

• Goal: Augment the training dataset for multi-class classification
• Dataset: Omniglot (1000+ classes, 20 examples/class)
• Procedure
  § Train a conditional generative model on Omniglot
  § Use the generated data for training the downstream classifier (a weighted-loss sketch follows below)
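One plausible way to fold the importance weights into the classifier's objective, consistent with the earlier slide on augmenting the loss with an expectation over generated data; the PyTorch code, the mixing coefficient lam, and the exact loss composition are assumptions rather than the paper's exact recipe:

```python
import torch.nn.functional as F

def augmented_loss(model, x_real, y_real, x_gen, y_gen, w_gen, lam=1.0):
    """Cross-entropy on real data plus an LFIW-reweighted cross-entropy on generated data.
    w_gen holds per-example importance weights w_hat(x) for the generated batch."""
    loss_real = F.cross_entropy(model(x_real), y_real)
    per_example = F.cross_entropy(model(x_gen), y_gen, reduction="none")
    loss_gen = (w_gen * per_example).mean()   # downweight low-quality generated samples
    return loss_real + lam * loss_gen
```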

Page 19

Importance Weighted Data Augmentation

[Figure: for each of three classes, real examples in random order shown alongside generated examples sorted by decreasing importance weight]

Page 20

Importance Weighted Data Augmentation

Dataset | Accuracy
Real data only | 0.6603 ± 0.0012
Generated data only | 0.4431 ± 0.0054
Generated data + LFIW | 0.4481 ± 0.0056
Real + generated data | 0.6600 ± 0.0040
Real + generated data + LFIW | 0.6818 ± 0.0022

Standard error around the mean computed over 5 runs.

LFIW on the augmented data increases overall test accuracy!

Page 21

Downstream Applications
• Data Augmentation
• Off-Policy Policy Evaluation

Page 22

Off-Policy Policy Evaluation (OPE)

• Easy to obtain logged trajectory data $s_1, a_1, r_1, s_2, a_2, r_2, \ldots$
  $s_1 \sim \rho(s)$ [$\rho(\cdot)$ is the initial state distribution]
  $a_t \sim \pi_b(s_t)$ [$\pi_b(\cdot)$ is the behavioral policy]
  $r_t \sim R(s_t, a_t)$ [$R(\cdot)$ is the rewards model]
  $s_{t+1} \sim T(s_t, a_t)$ [$T(\cdot)$ is the transition dynamics model]
  Note: $\pi_b(\cdot)$ and $T(\cdot)$ are unknown
• Goal: Evaluate the value $V^{\pi_e}$ of a target policy $\pi_e$

[Figure: motivating example with two treatments, one for which logged data exists and one for which it does not]

Page 23

Debiasing Model-based OPE

• Model-based approach
  § Estimate $T(\cdot)$ as $\hat{T}(\cdot)$
  § Generate target trajectories via $\hat{T}(\cdot)$ and $\pi_e$
  § Estimate $V^{\pi_e}$ by Monte Carlo
• Train a classifier to distinguish triplets from the logged data and the model's predictions
• Debiasing distributions over trajectories $\tau$ (a sketch follows below)
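A highly simplified sketch of the reweighted value estimate; rollout_model, policy, and weight_fn are hypothetical interfaces, and composing per-transition weights into a per-trajectory weight and then self-normalizing is one assumption about how the debiasing could be wired up, not the paper's exact procedure:

```python
import numpy as np

def lfiw_ope_value(rollout_model, policy, weight_fn, n_traj, horizon):
    """Monte Carlo value estimate under the learned dynamics model, with each
    simulated trajectory reweighted by the product of per-transition LFIW weights."""
    returns, weights = [], []
    for _ in range(n_traj):
        s = rollout_model.reset()
        total_return, log_w = 0.0, 0.0
        for _ in range(horizon):
            a = policy(s)
            s_next, r = rollout_model.step(s, a)
            total_return += r
            log_w += np.log(weight_fn(s, a, s_next))   # per-transition density ratio
            s = s_next
        returns.append(total_return)
        weights.append(np.exp(log_w))
    w = np.asarray(weights)
    return float(np.sum(w * np.asarray(returns)) / np.sum(w))   # self-normalized average
```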

Page 24

Debiased Model-Based OPE

Environment | HalfCheetah | Swimmer | Humanoid
Model-based | 37.7 | 63.7 | 5753
Model-based w/ LFIW | 23.9 | 11 | 4798

Mean absolute error. Lower is better.

Page 25

Summary

• Generative models are biased
• Likelihood-free importance weighting is a simple technique for bias mitigation that works well for many downstream applications
• Bias is a necessary evil for generalization. The key is to be able to control it!
• Future and ongoing work: Fair Generative Modeling via Weak Supervision. Aditya Grover*, Kristy Choi*, Rui Shu, Stefano Ermon. https://arxiv.org/abs/1910.12008