Conditional Variational Auto-encoder
This tutorial implements the paper Learning Structured Output Representation using Deep Conditional Generative Models, which introduced Conditional Variational Auto-encoders in 2015, using Pyro PPL and PyTorch.
Introduction
“Never disdain to make a verification when opportunity offers.” Henri Poincaré.
Supervised deep learning has been successfully applied to many recognition problems in machine learning and computer vision. Although it can approximate a complex many-to-one function very well when a large amount of training data is provided, the lack of probabilistic inference in current supervised deep learning methods makes it difficult to model complex structured output representations. In the paper, Kihyuk Sohn, Honglak Lee and Xinchen Yan develop a scalable deep conditional generative model for structured output variables using Gaussian latent variables. The model is trained efficiently in the framework of stochastic gradient variational Bayes, and allows fast prediction using stochastic feed-forward inference. They called the model Conditional Variational Auto-encoder (CVAE).
The CVAE is a conditional directed graphical model whose input observations modulate the prior on Gaussian latent variables that generate the outputs. It is trained to maximize the conditional marginal log-likelihood. The authors formulate the variational learning objective of the CVAE in the framework of stochastic gradient variational Bayes (SGVB). In experiments, they demonstrate the effectiveness of the CVAE in comparison to its deterministic neural network counterparts in generating diverse but realistic output predictions using stochastic inference. Here, we will implement their proof of concept: an artificial experimental setting for structured output prediction using the MNIST database.
The problem
Let’s divide each digit image into four quadrants, and take one, two, or three quadrant(s) as an input and the remaining quadrants as an output to be predicted. The image below shows the case where one quadrant is the input:
Our objective is to learn a model that can perform probabilistic inference and make diverse predictions from a single input. Unlike classification tasks, where we model a many-to-one function, here we may need to model a mapping from a single input to many possible outputs. One of the limitations of deterministic neural networks is that they generate only a single prediction. In the example above, the input shows a small part of a digit that might be a three or a five.
Preparing the data
We use the MNIST dataset; the first step is to prepare it. Depending on how many quadrants we use as inputs, we build the datasets and dataloaders, masking the unused pixels with -1:
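As a rough illustration of the masking step, the sketch below splits one image into an input and a target. The function name `split_quadrants` and the order in which quadrants are assigned to the input (bottom-left, then top-left, then top-right) are assumptions made for this example and may differ from the code in the repo.

```python
import torch

MASK_VALUE = -1.0  # value used to mark pixels that are not part of a tensor


def split_quadrants(image, num_input_quadrants):
    """Split a 2-D image into a masked input and its complementary target."""
    if num_input_quadrants not in (1, 2, 3):
        raise ValueError("num_input_quadrants must be 1, 2, or 3")
    h, w = image.shape
    # Assumed quadrant order: bottom-left, top-left, top-right.
    quadrants = [
        (slice(h // 2, h), slice(0, w // 2)),
        (slice(0, h // 2), slice(0, w // 2)),
        (slice(0, h // 2), slice(w // 2, w)),
    ]
    is_input = torch.zeros(h, w, dtype=torch.bool)
    for rows, cols in quadrants[:num_input_quadrants]:
        is_input[rows, cols] = True
    masked = torch.full_like(image, MASK_VALUE)
    inp = torch.where(is_input, image, masked)  # visible quadrants only
    out = torch.where(is_input, masked, image)  # remaining quadrants only
    return inp, out


image = torch.rand(28, 28)  # stand-in for an MNIST digit scaled to [0, 1]
inp, out = split_quadrants(image, num_input_quadrants=1)
print((inp != MASK_VALUE).sum().item(), (out != MASK_VALUE).sum().item())
```

Every pixel ends up visible in exactly one of the two tensors, so the input and the target are complementary views of the same digit.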
Baseline: Deterministic Neural Network
Before we dive into the CVAE implementation, let’s code the baseline model. It is a straightforward implementation:
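A minimal sketch of such a baseline is shown here. The 500 hidden units match the training setup described in this tutorial, while the class name `BaselineNet`, the use of two hidden layers, and the ReLU activations are assumptions of this example.

```python
import torch
import torch.nn as nn


class BaselineNet(nn.Module):
    """Deterministic MLP mapping a masked image to per-pixel probabilities."""

    def __init__(self, hidden_1=500, hidden_2=500, num_pixels=784):
        super().__init__()
        self.num_pixels = num_pixels
        self.net = nn.Sequential(
            nn.Linear(num_pixels, hidden_1),
            nn.ReLU(),
            nn.Linear(hidden_1, hidden_2),
            nn.ReLU(),
            nn.Linear(hidden_2, num_pixels),
            nn.Sigmoid(),  # outputs in [0, 1], one value per pixel
        )

    def forward(self, x):
        # Flatten (batch, 28, 28) images to (batch, 784) before the MLP.
        return self.net(x.reshape(-1, self.num_pixels))


net = BaselineNet()
y_hat = net(torch.full((4, 28, 28), -1.0))  # dummy fully-masked batch
print(y_hat.shape)
```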
In the paper, the authors compare the baseline NN with the proposed CVAE using the negative Conditional Log Likelihood (CLL), averaged per image over the validation set. Because the network output has passed through a Sigmoid layer, computing the negative CLL is equivalent to computing PyTorch's Binary Cross Entropy loss on that output. The code below makes a small adjustment to leverage this: it only computes the loss on the pixels not masked with -1:
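One way to write such a masked loss is sketched below. The function name `masked_bce` and the choice to divide by the batch size (to get a per-image average) are assumptions of this example.

```python
import torch
import torch.nn.functional as F


def masked_bce(prediction, target, mask_value=-1.0):
    """Binary cross entropy over unmasked target pixels, averaged per image."""
    batch_size = target.shape[0]
    prediction = prediction.reshape(batch_size, -1)
    target = target.reshape(batch_size, -1)
    keep = target != mask_value  # True where the target pixel is real
    total = F.binary_cross_entropy(prediction[keep], target[keep], reduction="sum")
    return total / batch_size


pred = torch.full((2, 4), 0.5)  # dummy predictions
target = torch.tensor([[0.0, 1.0, -1.0, -1.0], [-1.0, 0.3, 1.0, -1.0]])
loss = masked_bce(pred, target)
print(loss.item())
```

Selecting the unmasked pixels before calling the loss keeps the -1 placeholders out of the cross entropy entirely, so they contribute neither to the value nor to the gradient.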
The training is very straightforward. We use 500 neurons in each hidden layer, Adam optimizer with 1e-3 learning rate, and early stopping. Please check the Github repo for the full implementation.
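The training recipe can be sketched generically as follows. This is not the repo's training script: the function name `fit`, its arguments, and the choice to restore the best weights at the end are assumptions of this example, and the default patience of 3 epochs and maximum of 50 epochs are taken from the results section of this tutorial.

```python
import copy

import torch
import torch.nn as nn


def fit(model, loss_fn, train_batches, val_batches, lr=1e-3, max_epochs=50, patience=3):
    """Train with Adam and stop once validation loss stops improving."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_loss, best_state, bad_epochs = float("inf"), None, 0
    history = []
    for _ in range(max_epochs):
        model.train()
        for inputs, targets in train_batches:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
        model.eval()
        with torch.no_grad():
            val_loss = sum(loss_fn(model(i), t).item() for i, t in val_batches)
            val_loss /= len(val_batches)
        history.append(val_loss)
        if val_loss < best_loss:
            best_loss, bad_epochs = val_loss, 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break  # no improvement for `patience` consecutive epochs
    if best_state is not None:
        model.load_state_dict(best_state)  # roll back to the best epoch
    return history


# Tiny synthetic regression problem, only to show the call.
torch.manual_seed(0)
xs, ys = torch.randn(32, 3), torch.randn(32, 1)
batches = [(xs[:16], ys[:16]), (xs[16:], ys[16:])]
history = fit(nn.Linear(3, 1), nn.MSELoss(), batches, batches, max_epochs=5)
print(len(history))
```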
Deep Conditional Generative Models for Structured Output Prediction
As illustrated in the image below, there are three types of variables in a deep conditional generative model (CGM): input variables $\bf x$, output variables $\bf y$, and latent variables $\bf z$. The conditional generative process of the model is given in (b) as follows: for a given observation $\bf x$, $\bf z$ is drawn from the prior distribution $p_{\theta}({\bf z} | {\bf x})$, and the output $\bf y$ is generated from the distribution $p_{\theta}({\bf y} | {\bf x, z})$. Compared to the baseline NN (a), the latent variables $\bf z$ allow for modeling multiple modes in the conditional distribution of output variables $\bf y$ given input $\bf x$, making the proposed CGM suitable for modeling one-to-many mappings.
Deep CGMs are trained to maximize the conditional marginal log-likelihood. This objective is often intractable, so we apply the SGVB framework to train the model. The empirical lower bound is written as:
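The generative process described above can be sketched as a short sampling routine. The function `sample_predictions` and the two stand-in networks are hypothetical and exist only to make the example runnable; in a real model they would be the trained prior and generation networks.

```python
import torch


def sample_predictions(x, prior_net, generation_net, num_samples=5):
    """Draw several candidate outputs y for the same input x."""
    loc, scale = prior_net(x)  # parameters of p(z | x)
    samples = []
    for _ in range(num_samples):
        z = loc + scale * torch.randn_like(loc)  # z ~ p(z | x)
        samples.append(generation_net(x, z))     # mean of p(y | x, z)
    return torch.stack(samples)  # (num_samples, batch, output_dim)


# Hypothetical stand-in networks with fixed random weights.
torch.manual_seed(0)
w_z = torch.randn(2, 6)
prior_net = lambda x: (x.mean(dim=-1, keepdim=True).repeat(1, 2), torch.ones(x.shape[0], 2))
generation_net = lambda x, z: torch.sigmoid(z @ w_z)

x = torch.rand(3, 4)
ys = sample_predictions(x, prior_net, generation_net)
print(ys.shape)
```

Because each call draws a fresh $\bf z$, the same input yields a different output every time, which is what lets the model cover several modes instead of averaging them.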
\[\tilde{\mathcal{L}}_{\text{CVAE}}(x, y; \theta, \phi) = -KL(q_{\phi}(z | x, y) || p_{\theta}(z | x)) + \frac{1}{L}\sum_{l=1}^{L}\log p_{\theta}(y | x, z^{(l)})\]
where ${\bf z}^{(l)}$ is a Gaussian latent variable sampled from $q_{\phi}({\bf z} | {\bf x, y})$, and $L$ is the number of samples (or particles in Pyro nomenclature). We call this model conditional variational auto-encoder (CVAE). The CVAE is composed of multiple MLPs: the recognition network $q_{\phi}({\bf z} | {\bf x, y})$, the (conditional) prior network $p_{\theta}({\bf z} | {\bf x})$, and the generation network $p_{\theta}({\bf y} | {\bf x, z})$. In designing the network architecture, we build the network components of the CVAE on top of the baseline NN. Specifically, as shown in (d) above, not only the direct input $\bf x$, but also the initial guess $\hat{y}$ made by the NN are fed into the prior network.
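To make the two terms of the bound concrete, the sketch below evaluates it with plain `torch.distributions`, without Pyro. The function name `cvae_lower_bound` and the `decode` callable (which stands in for the generation network with $\bf x$ already fixed) are assumptions of this example, as is the Bernoulli pixel likelihood written as a negative binary cross entropy.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def cvae_lower_bound(q_loc, q_scale, p_loc, p_scale, decode, y, num_particles=1):
    """Per-example estimate: -KL(q || p) + (1/L) * sum_l log p(y | x, z_l)."""
    q = Normal(q_loc, q_scale)  # recognition distribution q(z | x, y)
    p = Normal(p_loc, p_scale)  # conditional prior p(z | x)
    kl = kl_divergence(q, p).sum(dim=-1)  # closed form for diagonal Gaussians
    log_lik = torch.zeros_like(kl)
    for _ in range(num_particles):
        z = q.rsample()  # reparameterized sample, keeps gradients
        probs = decode(z)
        # Bernoulli log-likelihood of the pixels, written as negative BCE.
        log_lik = log_lik - F.binary_cross_entropy(probs, y, reduction="none").sum(dim=-1)
    return -kl + log_lik / num_particles


# Dummy decoder that predicts 0.5 for each of 4 pixels.
decode = lambda z: torch.full((z.shape[0], 4), 0.5)
y = torch.tensor([[0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]])
zeros, ones = torch.zeros(2, 3), torch.ones(2, 3)
print(cvae_lower_bound(zeros, ones, zeros, ones, decode, y))
```

When the recognition distribution equals the prior, the KL term vanishes and only the reconstruction term remains; moving the two apart lowers the bound by exactly the KL divergence.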
Pyro makes it really easy to translate this architecture into code. The recognition network and the (conditional) prior network are encoders from the traditional VAE setting, while the generation network is the decoder:
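The network components can be sketched in plain PyTorch as below; the Pyro model and guide that wire them together are not reproduced here. The class names, layer sizes, the use of `exp` to keep the scale positive, and the decoder taking the flattened $\bf x$ concatenated with $\bf z$ are all assumptions of this example.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Maps an (x, y) pair to the mean and scale of a diagonal Gaussian over z."""

    def __init__(self, z_dim, hidden_1, hidden_2, num_pixels=784):
        super().__init__()
        self.num_pixels = num_pixels
        self.hidden = nn.Sequential(
            nn.Linear(num_pixels, hidden_1), nn.ReLU(),
            nn.Linear(hidden_1, hidden_2), nn.ReLU(),
        )
        self.loc = nn.Linear(hidden_2, z_dim)
        self.log_scale = nn.Linear(hidden_2, z_dim)

    def forward(self, x, y):
        # Fill the masked (-1) pixels of x with the corresponding pixels of y.
        combined = torch.where(x == -1, y, x).reshape(-1, self.num_pixels)
        h = self.hidden(combined)
        return self.loc(h), torch.exp(self.log_scale(h))


class Decoder(nn.Module):
    """Maps (x, z) to per-pixel Bernoulli probabilities for y."""

    def __init__(self, z_dim, hidden_1, hidden_2, num_pixels=784):
        super().__init__()
        self.num_pixels = num_pixels
        self.net = nn.Sequential(
            nn.Linear(num_pixels + z_dim, hidden_1), nn.ReLU(),
            nn.Linear(hidden_1, hidden_2), nn.ReLU(),
            nn.Linear(hidden_2, num_pixels), nn.Sigmoid(),
        )

    def forward(self, x, z):
        return self.net(torch.cat([x.reshape(-1, self.num_pixels), z], dim=-1))


# The same Encoder class serves as recognition network, fed (x, y),
# and as prior network, fed (x, y_hat) with y_hat from the baseline NN.
x = torch.full((5, 28, 28), -1.0)
x[:, 14:, :14] = torch.rand(5, 14, 14)  # one visible quadrant
y = torch.rand(5, 28, 28)               # dummy target pixels
recognition_net = Encoder(z_dim=8, hidden_1=32, hidden_2=32)
generation_net = Decoder(z_dim=8, hidden_1=32, hidden_2=32)
z_loc, z_scale = recognition_net(x, y)
probs = generation_net(x, z_loc)
print(z_loc.shape, probs.shape)
```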
Training
The training code can be found in the Github repo. Click play in the video below to watch how the CVAE learns throughout approximately 40 epochs.
As we can see, the posterior distribution learned by the model improves steadily as training progresses: not only does the loss go down, but the predictions also get visibly better.
Additionally, here we can already observe the key advantage of CVAEs: the model learns to generate multiple predictions from a single input. In the first digit, the input is clearly a piece of a 7. The model learns it and keeps predicting clearer 7’s, but with different writing styles. In the second and third digits, the inputs are pieces of what could be either a 3 or a 5 (truth is 3), and what could be either a 4 or a 9 (truth is 4). During the first epochs, the CVAE predictions are blurred, and they get clearer as time passes, as expected.
However, unlike for the first digit, it’s hard to tell that the truth is 3 and 4 for the second and third digits, respectively, by observing only one quarter of each digit as input. By the end of the training, the CVAE generates very clear and realistic predictions, but it doesn’t commit to either a 3 or a 5 for the second digit, or to either a 4 or a 9 for the third digit. Sometimes it predicts one option, and sometimes it predicts the other.
Evaluating the results
For qualitative analysis, we visualize the generated output samples in the next figure. As we can see, the baseline NNs can only make a single deterministic prediction, and as a result the output looks blurry and doesn’t look realistic in many cases. In contrast, the samples generated by the CVAE models are more realistic and diverse in shape; sometimes they can even change their identity (digit labels), such as from 3 to 5 or from 4 to 9, and vice versa.
We also provide quantitative evidence by estimating the negative marginal conditional log-likelihoods (CLLs), shown in the next table (lower is better).
Model | 1 quadrant | 2 quadrants | 3 quadrants |
---|---|---|---|
NN (baseline) | 100.4 | 61.9 | 25.4 |
CVAE (Monte Carlo) | 71.8 | 51.0 | 24.2 |
Performance gap | 28.6 | 10.9 | 1.2 |
We achieved results similar to the ones reported by the authors in the paper. We trained for only 50 epochs with an early stopping patience of 3 epochs; to improve the results, we could let the training run for longer. Nevertheless, we can observe the same effect shown in the paper: the CVAE achieves a substantially better estimated CLL than the baseline NN, and the gap is largest when only one quadrant is given as input.
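The Monte Carlo estimate behind the CVAE row averages the likelihood $p_{\theta}(y | x, z)$ over latent samples drawn from the prior $p_{\theta}(z | x)$. A minimal sketch of that estimator is given below; the function name `monte_carlo_nll` and the layout of its input are assumptions of this example, and the averaging is done in log space for numerical stability.

```python
import math

import torch


def monte_carlo_nll(log_liks):
    """Negative CLL estimate from per-sample log-likelihoods.

    log_liks has shape (num_samples, batch); entry [s, i] is
    log p(y_i | x_i, z_s) with z_s drawn from the prior p(z | x_i).
    """
    num_samples = log_liks.shape[0]
    # log of the sample mean of the likelihoods, computed stably.
    log_marginal = torch.logsumexp(log_liks, dim=0) - math.log(num_samples)
    return -log_marginal.mean()  # average over images in the batch


# Dummy values: every sample assigns log-likelihood -70 to every image.
log_liks = torch.full((10, 4), -70.0)
print(monte_carlo_nll(log_liks).item())
```

Averaging probabilities rather than log-probabilities matters here: a few latent samples that explain the image well can dominate the estimate, which a plain mean of log-likelihoods would hide.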
See the full code on Github.
References
- Sohn, K., Lee, H., & Yan, X. (2015). Learning Structured Output Representation using Deep Conditional Generative Models. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, & R. Garnett (Eds.), Advances in Neural Information Processing Systems 28 (pp. 3483–3491). Curran Associates, Inc. http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models.pdf