Bias Correction in Learned Generative Models using Likelihood-free Importance Weighting
Aditya Grover, Stanford University
Joint work with Jiaming Song, Alekh Agarwal, Kenneth Tran, Ashish Kapoor, Eric Horvitz, Stefano Ermon
NeurIPS 2019
Transforming Science & Society
Hwang et al., 2018, Gómez-Bombarelli et al., 2016
What You See Is Not What You Always Get
Odena et al., 2016
Learning Generative Models
Given: samples from a data distribution, x_1, x_2, …, x_n ∼ p_data
Goal: Choose a model family M and approximate the data distribution as closely as possible:

min_{p_θ ∈ M} d(p_data, p_θ)
Challenges
• How to define the distance d(⋅,⋅) between distributions?
• Model mismatch: p_data ∉ M
• Optimization is imperfect
• Finite datasets: the empirical data distribution p̂_data is far from the true data distribution p_data

min_{p_θ ∈ M} d(p_data, p_θ)

p_θ ≠ p_data: generative models are biased w.r.t. p_data!
Evaluating Generative Models is Hard
• Density estimation
§ Not applicable for models with ill-defined/intractable likelihoods, e.g., GANs, VAEs
§ Not correlated with sample quality (Theis et al., 2015)
• Sample quality metrics, e.g., Inception Score (Salimans et al., 2016), FID (Heusel et al., 2017), KID (Binkowski et al., 2018)
• Downstream tasks, e.g., semi-supervised learning
Identifying Bias in Generative Modeling
• Let f: X → ℝ be a real-valued function of interest
• We assume f is unknown during training of the generative model
• Evidence of bias: E_{p_data}[f(X)] ≠ E_{p_θ}[f(X)]
Motivating Use Cases
• Model-based Off-Policy Evaluation
§ How do we safely evaluate a target policy given data from a different source policy?
§ Value estimates are an expectation w.r.t. the estimated generative dynamics model and the target policy
• Model-based Data Augmentation
§ Classifier trained on a mixture of real and generated data
§ Loss is augmented with an expectation w.r.t. generated data
• Fair and sample-efficient generation
E_{p_data}[f(X)] ≠ E_{p_θ}[f(X)]
Bias Mitigation
E_{p_data}[f(X)] ≠ E_{p_θ}[f(X)]
How to correct for bias due to model mismatch?
• Option 1: Train deeper models
§ Increases estimation error
§ Does not correct for distributional assumptions
• Option 2: Non-zero bias is an instance of covariate shift. Can we use importance weighting?
§ Reweight samples x ∼ p_θ by the density ratio

w(x) := p_data(x) / p_θ(x)   ⟹   E_{p_θ}[w(X) f(X)] = E_{p_data}[f(X)]

§ We don't know p_data
§ For many generative models, even p_θ is not known (e.g., VAEs, GANs)
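The reweighting identity in Option 2 follows in one line from the definition of w(x), assuming p_θ(x) > 0 wherever p_data(x) > 0:

```latex
\mathbb{E}_{p_\theta}\left[w(x)\, f(x)\right]
  = \int p_\theta(x)\, \frac{p_{\mathrm{data}}(x)}{p_\theta(x)}\, f(x)\, dx
  = \int p_{\mathrm{data}}(x)\, f(x)\, dx
  = \mathbb{E}_{p_{\mathrm{data}}}\left[f(x)\right]
```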
Importance Weighting via Classification
Importance weights can be estimated via binary classification!
Train a classifier to distinguish real (Y = 1) and generated (Y = 0) data.
• For a Bayes-optimal classifier c*: X × Y → [0, 1],

ŵ(x) = p_data(x) / p_θ(x) = c*(Y = 1 | x) / c*(Y = 0 | x)

• Practical checklist
✓ Calibration
✓ Validation set
• Not the same as GAN training
§ Post-hoc bias correction
§ Do not throw away the discriminator! The new generative model is a function of both ŵ (classifier) and p_θ
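As an illustration (not the paper's code), the classifier-based ratio can be recovered on a toy pair of 1-D Gaussians with a hand-rolled logistic regression on features (1, x, x²), which suffices here because the true log-ratio is quadratic in x. All distributions and constants below are made up for the sketch:

```python
# Sketch (illustrative, not the paper's code): estimate importance weights
# via binary classification between real and generated samples.
import numpy as np

rng = np.random.default_rng(0)
n = 20_000
x_real = rng.normal(0.0, 1.0, size=n)   # stand-in samples "from p_data"
x_gen = rng.normal(0.5, 1.2, size=n)    # stand-in samples "from p_theta"

def feats(x):
    # (1, x, x^2): enough to express the true quadratic log-ratio here
    return np.stack([np.ones_like(x), x, x ** 2], axis=1)

X = feats(np.concatenate([x_real, x_gen]))
y = np.concatenate([np.ones(n), np.zeros(n)])  # Y=1 real, Y=0 generated

theta = np.zeros(3)
for _ in range(3000):                   # plain gradient descent on log-loss
    p = 1.0 / (1.0 + np.exp(-X @ theta))
    theta -= 0.2 * X.T @ (p - y) / len(y)

def w_hat(x):
    # equal class priors: w(x) = c(Y=1|x) / c(Y=0|x) = exp(classifier logit)
    return np.exp(feats(x) @ theta)

def ratio(x):
    # exact density ratio p_data(x) / p_theta(x), for checking the estimate
    pd = np.exp(-x ** 2 / 2) / np.sqrt(2 * np.pi)
    pt = np.exp(-(x - 0.5) ** 2 / (2 * 1.44)) / (1.2 * np.sqrt(2 * np.pi))
    return pd / pt

print(w_hat(np.array([0.0])), ratio(np.array([0.0])))
```

With equal numbers of real and generated points, the Bayes-optimal logit equals the log density ratio, so exponentiating the learned logit gives ŵ directly.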
Synthetic Example
[Figure: results with 100 samples vs. 1000 samples.]
Model-based Monte Carlo Evaluation
Goal: Evaluate E_{p_data}[f(X)] via p_θ
• Default Monte Carlo estimator:
E_{p_data}[f(X)] ≈ (1/T) Σ_i f(x_i), where x_i ∼ p_θ
• Likelihood-Free Importance Weighted (LFIW) estimator:
E_{p_data}[f(X)] ≈ (1/T) Σ_i ŵ(x_i) f(x_i), where x_i ∼ p_θ
• The relative variance of ŵ can be high. Self-normalized LFIW:
E_{p_data}[f(X)] ≈ Σ_i [ŵ(x_i) / Σ_j ŵ(x_j)] f(x_i)
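A minimal sketch of the three estimators on a 1-D Gaussian toy problem, using the exact density ratio in place of the classifier estimate ŵ (both distributions are illustrative assumptions, not from the talk):

```python
# Toy comparison of the default, LFIW, and self-normalized LFIW estimators.
import numpy as np

rng = np.random.default_rng(1)

def gauss_pdf(x, mu, s):
    return np.exp(-(x - mu) ** 2 / (2 * s * s)) / (s * np.sqrt(2 * np.pi))

f = lambda x: x                                  # function of interest
x = rng.normal(0.5, 1.2, size=100_000)           # x_i ~ p_theta (mismatched)
w = gauss_pdf(x, 0.0, 1.0) / gauss_pdf(x, 0.5, 1.2)  # exact density ratio

default = f(x).mean()                    # biased: -> E_{p_theta}[f] = 0.5
lfiw = (w * f(x)).mean()                 # unbiased: -> E_{p_data}[f] = 0
snlfiw = (w * f(x)).sum() / w.sum()      # self-normalized LFIW

print(default, lfiw, snlfiw)
```

The default estimator converges to E_{p_θ}[f] = 0.5, while both importance-weighted estimators recover E_{p_data}[f] = 0; the self-normalized variant trades a small bias for lower variance when the weights are noisy.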
Improved Sample Quality Metrics

Dataset: CIFAR-10

Model      | Variant | Inception Score (↑) | FID (↓)        | KID (↓)
Reference  |         | 11.09 ± 0.1263      | 5.20 ± 0.05    | 0.008 ± 0.0004
PixelCNN++ | Default | 5.16 ± 0.0117       | 58.70 ± 0.0506 | 0.196 ± 0.0001
PixelCNN++ | LFIW    | 6.68 ± 0.0773       | 55.83 ± 0.9695 | 0.126 ± 0.0009
SNGAN      | Default | 8.33 ± 0.0280       | 20.40 ± 0.0747 | 0.094 ± 0.0002
SNGAN      | LFIW    | 8.57 ± 0.0325       | 17.29 ± 0.0698 | 0.073 ± 0.0004

Standard error around the mean computed over 10 runs.
Importance Resampling Distribution
• Define the importance resampling distribution as
p_{θ,ŵ}(x) ∝ ŵ(x) p_θ(x)
• Normalization constant: Z_{θ,ŵ} = E_{x∼p_θ}[ŵ(x)]
• Density estimation and sampling are intractable
• Particle-based approximation:
§ Approximate the induced distribution with finite samples from p_θ
§ Approximate via resampling methods, e.g., rejection sampling (Azadi et al., 2019), MCMC (Turner et al., 2019)
Sampling Importance Resampling
• Choose a finite sampling budget T > 0
• Draw a batch of T points x_1, x_2, …, x_T from p_θ and estimate importance weights ŵ(x_i)
• Define a categorical distribution P(j) ∝ ŵ(x_j)
• Sample j ∼ P(j) and return x_j

Example with T = 3: estimated weights ŵ = (0.6, 0.1, 0.7); normalized ≈ (0.42, 0.08, 0.5), so
P(j = 1) = 0.42, P(j = 2) = 0.08, P(j = 3) = 0.5
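The steps above can be sketched directly; the batch values are arbitrary and the weights are the toy values from the example:

```python
# Minimal sketch of sampling-importance-resampling (SIR).
import numpy as np

rng = np.random.default_rng(0)

def sir_sample(batch, weights, rng):
    # one SIR draw: sample index j from P(j) ∝ w(x_j), return x_j
    probs = np.asarray(weights, dtype=float)
    probs = probs / probs.sum()
    j = rng.choice(len(batch), p=probs)
    return batch[j]

batch = np.array([-1.0, 0.3, 2.1])    # x_1, x_2, x_3 drawn from p_theta
weights = np.array([0.6, 0.1, 0.7])   # estimated importance weights
probs = weights / weights.sum()       # categorical P(j)
print(probs)
print(sir_sample(batch, weights, rng))
```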
Are we guaranteed to do better?
• No
• When is p_{θ,ŵ} a "better" fit than p_θ? Better: the KL divergence reduces,
D_KL[p_data, p_{θ,ŵ}] ≤ D_KL[p_data, p_θ]
• Necessary and sufficient condition: p_{θ,ŵ} is a better fit than p_θ iff:
E_{x∼p_data}[log ŵ(x)] ≥ log Z_{θ,ŵ}
• Necessary conditions:
E_{x∼p_data}[log ŵ(x)] ≥ E_{x∼p_θ}[log ŵ(x)]
E_{x∼p_data}[ŵ(x)] ≥ E_{x∼p_θ}[ŵ(x)]
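The necessary-and-sufficient condition can be seen by expanding the KL divergence to p_{θ,ŵ}(x) = ŵ(x) p_θ(x) / Z_{θ,ŵ}:

```latex
D_{\mathrm{KL}}\!\left[p_{\mathrm{data}}, p_{\theta,\hat{w}}\right]
  = \mathbb{E}_{x \sim p_{\mathrm{data}}}\!\left[\log \frac{p_{\mathrm{data}}(x)\, Z_{\theta,\hat{w}}}{\hat{w}(x)\, p_\theta(x)}\right]
  = D_{\mathrm{KL}}\!\left[p_{\mathrm{data}}, p_\theta\right]
    - \mathbb{E}_{x \sim p_{\mathrm{data}}}\!\left[\log \hat{w}(x)\right]
    + \log Z_{\theta,\hat{w}}
```

so the divergence decreases exactly when E_{x∼p_data}[log ŵ(x)] ≥ log Z_{θ,ŵ}.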
Downstream Applications
• Data Augmentation
• Off-Policy Policy Evaluation
Data Augmentation
• Goal: Augment the training dataset for multi-class classification
• Dataset: Omniglot (1000+ classes, 20 examples/class)
• Procedure
§ Train a conditional generative model on Omniglot
§ Use generated data for training the downstream classifier
Importance Weighted Data Augmentation
[Figure: for Classes 1–3, real samples (random order) shown alongside generated samples sorted by decreasing importance weight.]
Importance Weighted Data Augmentation

Dataset                      | Accuracy
Real data only               | 0.6603 ± 0.0012
Generated data only          | 0.4431 ± 0.0054
Generated data + LFIW        | 0.4481 ± 0.0056
Real + generated data        | 0.6600 ± 0.0040
Real + generated data + LFIW | 0.6818 ± 0.0022
Standard error around the mean computed over 5 runs.
LFIW on the augmented data increases overall test accuracy!
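A minimal sketch of the importance-weighted augmented objective described earlier: real examples get weight 1, generated examples get their estimated importance weight. The probabilities and weights below are made-up toy values, not from the experiments:

```python
# Sketch of an LFIW-weighted augmented classification loss.
import numpy as np

def nll(probs, labels):
    # per-example negative log-likelihood of the correct class
    return -np.log(probs[np.arange(len(labels)), labels])

def augmented_loss(p_real, y_real, p_gen, y_gen, w_gen):
    # mean loss on real data + importance-weighted mean loss on generated data
    return nll(p_real, y_real).mean() + (w_gen * nll(p_gen, y_gen)).mean()

# toy two-class example with made-up predicted probabilities and weights
p_real = np.array([[0.9, 0.1], [0.2, 0.8]])
y_real = np.array([0, 1])
p_gen = np.array([[0.6, 0.4]])
y_gen = np.array([0])
w_gen = np.array([0.5])
loss_val = augmented_loss(p_real, y_real, p_gen, y_gen, w_gen)
print(loss_val)
```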
Downstream Applications
• Data Augmentation
• Off-Policy Policy Evaluation
Off-Policy Policy Evaluation (OPE)
• Easy to obtain logged trajectory data s_0, a_0, r_0, s_1, a_1, r_1, …
s_0 ∼ ρ(s) [ρ(⋅) is the initial state distribution]
a_t ∼ π_b(s_t) [π_b(⋅) is the behavioral policy]
r_t ∼ R(s_t, a_t) [R(⋅) is the rewards model]
s_{t+1} ∼ P(s_t, a_t) [P(⋅) is the transition dynamics model]
Note: π_b(⋅), P(⋅) are unknown
• Goal: Evaluate the value V^{π_e} of a target policy π_e
[Illustration: Treatment 1 (Data: Yes ☺) vs. Treatment 2 (Data: No ☹).]
Debiasing Model-based OPE
• Model-based approach:
§ Estimate P(⋅) as P̂(⋅)
§ Generate target trajectories via P̂(⋅) and π_e
§ Estimate V^{π_e} by Monte Carlo
• Train a classifier to distinguish (s, a, s′) triplets from logged data and model predictions
• Debias distributions over trajectories τ
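A schematic sketch of the procedure, with placeholder dynamics, policy, reward, and classifier-weight functions (all hypothetical; w_hat is fixed to 1 here purely to show where the per-step weights enter the estimator):

```python
# Schematic sketch of debiased model-based OPE; every function here is a
# placeholder, not the paper's implementation.
import numpy as np

rng = np.random.default_rng(0)

def policy(s):                  # stand-in target policy pi_e
    return -0.1 * s

def model_step(s, a):           # stand-in learned dynamics P_hat
    return s + a + rng.normal(0.0, 0.1)

def reward(s, a):               # stand-in reward model
    return -s ** 2

def w_hat(s, a, s_next):        # placeholder classifier-based weight
    return 1.0

def lfiw_value_estimate(n_traj=100, horizon=10):
    returns, weights = [], []
    for _ in range(n_traj):
        s, ret, w = rng.normal(), 0.0, 1.0
        for _ in range(horizon):
            a = policy(s)
            s_next = model_step(s, a)
            ret += reward(s, a)
            w *= w_hat(s, a, s_next)   # trajectory weight: product over steps
            s = s_next
        returns.append(ret)
        weights.append(w)
    weights = np.asarray(weights)
    # self-normalized importance-weighted average of model-based returns
    return float(np.sum(weights / weights.sum() * np.asarray(returns)))

print(lfiw_value_estimate())
```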
Debiased Model-Based OPE

Environment         | HalfCheetah | Swimmer | Humanoid
Model-based         | 37.7        | 63.7    | 5753
Model-based w/ LFIW | 23.9        | 11      | 4798

Mean absolute error. Lower is better.
Summary
• Generative models are biased
• Likelihood-free importance weighting is a simple technique for bias mitigation that works well for many downstream applications
• Bias is a necessary evil for generalization. The key is to be able to control it!
• Future and ongoing work: Fair Generative Modeling via Weak Supervision. Aditya Grover*, Kristy Choi*, Rui Shu, Stefano Ermon. https://arxiv.org/abs/1910.12008