Why are generative models important and useful?
Within the Machine Learning community, generative models have become very popular over the last few years. These are generally based on Generative Adversarial Networks (Goodfellow et al. 2014), Variational Auto-encoders (Kingma & Welling, 2014) and, more recently, density models based on invertible flows, e.g. (Dinh et al, 2016), or autoregressive models, e.g. (Oord et al, 2016). These models all attempt the difficult task of estimating the distribution of high-dimensional objects such as images (millions of dimensions, one for each pixel and RGB channel), audio (each timestamp is a dimension), text (each letter is a dimension) or genetic data (each nucleotide is a dimension). By estimating the distribution, we mean either the ability to generate samples (an image, a corpus of text) or the ability to estimate the density of a given image or object, i.e. $p(image | model)$.
In this blog post I will focus on the question in the title: "Why are generative models important?" And why do we need them? For standard ML problems such as classification or regression, which reduce a complex image to a label or number, the answer is obvious: we deploy the model in a robot that makes decisions based on its predictions, or we provide the prediction to a person who makes a decision based on it (e.g. a doctor decides to give treatment X to a patient). But why would we want to go from low-dimensional objects (the label "cat") to high-dimensional objects (an image of a specific cat), which are harder to work with? We clarify some of these reasons below.
Communication with users
Generated images, text or audio are very useful for communication with users (Fig 1). For example, the movie industry is very interested in generating potential movie scenes from a given text script, or perhaps only parts of a movie such as the special effects. AI-generated text can be used to tell a story, or for educational purposes. In terms of audio, AI speech assistants (OK Google, Alexa) need to generate high-dimensional audio to communicate information to users.
Fig 1. Animations (top) and text (bottom) are popular means to communicate a story to an audience. Top: Lion King movie excerpt, Bottom: Text generated by OpenAI
Communication with users is only useful for stimuli that we can easily perceive (images, text, sound, etc.). For example, one could also generate high-dimensional fake genomic sequences, but these are less useful for communication, since it's hard for us to interpret or visualise them.
Shared cooperation between ML system and experts
Generative models can be used by experts as an assistive tool. In medicine, an ML system based on generative modelling can predict the future evolution of a disease, while the medical professional uses the prediction to recommend a treatment (Fig 2; see Fig 3 for an example in Alzheimer's disease). This is in contrast to a purely end-to-end ML approach that would take the image as input and directly output the recommended treatment. This hybrid approach might be necessary instead of an end-to-end approach because (i) ML systems may be unable to properly solve all sub-tasks, (ii) regulatory approval has not yet been granted for the remaining components, or (iii) data might not be available for those specific sub-tasks.
Fig 2. Shared cooperation between AI and expert on prognosis and treatment recommendation. The ML system might only perform the prognosis (prediction of future disease evolution), while the doctor recommends the treatment. The ML system might be unable to perform the treatment recommendation because the task is too complex (the treatment decision also needs to be explained and communicated to the patient), because regulatory approval is needed, or because training data is not available.
Fig 3. Demonstration of a model that predicts the evolution of Alzheimer’s disease. A medical professional can make more informed decisions based on these predictions. From daniravi.wixsite.com and Ravi et al, 2019
Interpretability
One can use generative models to understand which concept has been learned by an individual neuron (see Fig 4; Bau et al, 2017). By activating that particular neuron or set of neurons, the generator creates samples of that concept (Fig 5).
Fig 4. Images are used to show what class of concepts each neuron has learned. (Bau et al, 2017)
Fig 5. GAN-based generator used for painting and image manipulation – see gandissect.csail.mit.edu (Bau et al, 2017). It works by activating a set of neurons that are related to categories of interest: trees, grass, door, etc …
Debugging
Generating images can help interpret what a trained classifier has learned, and thus debug potential errors. After classifier training is complete, one can use a generative model to create hard-to-classify images (corner cases) and evaluate how the classifier performs on them. As exemplified in Fig 6, one can go a step further and generate a range of input images whose classification score gradually changes from one class to the other, revealing which image features the classifier relies on.
Fig 6. Given a pre-trained classifier f(x) that identifies whether a person is smiling, and an initial query image x (left-most), generative models can modify the query image x to change the classification score f(x) from not-smiling (f(x)=0.0) to smiling (f(x)=1.0) without changing the identity of the subject. If the classifier learned a spurious correlation due to dataset bias (e.g. all smiling images also had a white background), the generator will also learn to modify the background colour, which would be noticeable in the generated images. Data augmentation or disentangled representations can help fix such an error. Image from (Singla et al, 2020)
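The sweep in Fig 6 can be illustrated with a toy Python sketch. It assumes a hypothetical linear generator $G(z) = Az + c$ and a hypothetical logistic classifier $f(x) = \sigma(w \cdot x + b)$, with all weights made up; it is not the method of Singla et al. (2020). Starting from a latent code, it moves along the direction that changes $f(G(z))$ fastest and returns one latent code per requested score, so the generated inputs trace the classifier's transition from "not smiling" to "smiling".

```python
import math

# Hypothetical toy generator G(z) = A z + c mapping a 2-D latent code to a 3-D "image".
A = [[1.0, 0.5],
     [0.0, 1.0],
     [0.5, -1.0]]
c = [0.1, 0.2, 0.3]

# Hypothetical pre-trained "smile" classifier f(x) = sigmoid(w . x + b).
w = [0.8, -0.4, 1.2]
b = -0.5


def generate(z):
    return [sum(a * zj for a, zj in zip(row, z)) + ci for row, ci in zip(A, c)]


def logit(x):
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def classify(x):
    return 1.0 / (1.0 + math.exp(-logit(x)))


def counterfactual_sweep(z0, targets):
    """Latent codes z_t such that f(G(z_t)) equals each target score.

    logit(G(z)) = (A^T w) . z + const, so stepping z along d = A^T w changes the
    logit by ||d||^2 per unit step, giving a closed-form step size.
    Targets must lie strictly between 0 and 1 (a sigmoid never reaches 0.0 or 1.0).
    """
    d = [sum(A[i][j] * w[i] for i in range(len(A))) for j in range(len(z0))]
    d_norm2 = sum(v * v for v in d)
    start = logit(generate(z0))
    sweep = []
    for p in targets:
        step = (math.log(p / (1.0 - p)) - start) / d_norm2
        sweep.append([zj + step * dj for zj, dj in zip(z0, d)])
    return sweep


targets = [0.05, 0.25, 0.5, 0.75, 0.95]
for p, z in zip(targets, counterfactual_sweep([0.0, 0.0], targets)):
    print(f"target {p:.2f} -> classifier score {classify(generate(z)):.2f}")
```

Inspecting how the generated inputs change along such a sweep is what exposes spurious features. With a real, non-linear image generator, the closed-form step would be replaced by a gradient-based search in latent space.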
Communication with other ML systems - The image is the “interface”
Images, audio or other high-dimensional objects provide a good interface for machine learning systems to communicate with each other. This is especially true for incompatible systems that have different underlying representations.
Imagine a transfer learning scenario where we'd like to transfer knowledge from a deep learning (DL) model to a kernel machine (e.g. a Support Vector Machine). Suppose that the DL model learned to classify tree images (input) into specific tree species (categorical output). Now suppose we also trained a kernel machine on the same task, but due to its different inductive bias, it commonly mislabels Ficus microcarpa (curtain fig, native to India, Southern China and Australia) as F. retusa (native to Malaysia). To be fair, we cannot blame the kernel machine, as people often misidentify the two species: it is mostly the length of the leaf blade that can tell them apart.
Nevertheless, how can the DL model teach the kernel machine to tell the two tree species apart? Since the internal representations of the two models are completely different, it might be difficult (if at all possible) to transfer the learned weights of the DL model to the parameters of the kernel machine. However, one can build a generator that produces many possible images of the two tree species, which are then passed through the DL model to obtain (image, $f_{DL}(image) = species$) pairs, i.e. images labelled with the species predicted by the DL model. We can then use these newly generated images to fine-tune the kernel machine. If the kernel machine has a good inductive bias, this form of data augmentation will enable it to learn the new concept. In this case, the image is the interface between the two systems.
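The transfer procedure can be sketched in a few lines of Python. Everything here is a hypothetical stand-in chosen only to keep the example self-contained: the "generator" samples two made-up leaf measurements instead of images, the "DL model" is a fixed invented rule on those measurements, and the kernel machine is a kernel perceptron with an RBF kernel rather than an SVM. The point is the data flow: generate inputs, label them with the teacher, train the student on the teacher's labels.

```python
import math
import random


def generate_leaves(n, rng):
    """Hypothetical generator: samples made-up (blade length, blade width) pairs in cm.

    A real generator would output tree images; two numbers keep the sketch small.
    """
    return [(rng.uniform(3.0, 12.0), rng.uniform(2.0, 6.0)) for _ in range(n)]


def teacher(x):
    """Stand-in for the trained DL model f_DL: +1 = species A, -1 = species B.

    The decision rule is invented for illustration and is not a botanical fact.
    """
    length, width = x
    return 1 if length - 0.5 * width > 5.0 else -1


def rbf(x, y, gamma=0.5):
    """Gaussian (RBF) kernel k(x, y) = exp(-gamma * ||x - y||^2)."""
    return math.exp(-gamma * sum((a - b) ** 2 for a, b in zip(x, y)))


class KernelPerceptron:
    """Minimal kernel machine: predicts sign(sum_i alpha_i * y_i * k(x_i, x))."""

    def __init__(self, kernel):
        self.kernel = kernel
        self.xs, self.ys, self.alphas = [], [], []

    def decision(self, x):
        # Only training points with a non-zero weight contribute to the sum.
        return sum(a * y * self.kernel(xi, x)
                   for a, y, xi in zip(self.alphas, self.ys, self.xs) if a)

    def predict(self, x):
        return 1 if self.decision(x) >= 0 else -1

    def fit(self, xs, ys, epochs=10):
        self.xs, self.ys, self.alphas = list(xs), list(ys), [0.0] * len(xs)
        for _ in range(epochs):
            mistakes = 0
            for i, (x, y) in enumerate(zip(self.xs, self.ys)):
                if y * self.decision(x) <= 0:  # wrong (or undecided) prediction
                    self.alphas[i] += 1.0
                    mistakes += 1
            if mistakes == 0:  # every training point is now classified correctly
                break
        return self


rng = random.Random(0)
synthetic = generate_leaves(200, rng)                   # 1. generate inputs
labels = [teacher(x) for x in synthetic]                # 2. label them with the DL model
student = KernelPerceptron(rbf).fit(synthetic, labels)  # 3. train the kernel machine

held_out = generate_leaves(500, rng)
agreement = sum(student.predict(x) == teacher(x) for x in held_out) / len(held_out)
print(f"kernel machine agrees with the DL model on {agreement:.0%} of fresh samples")
```

Note that neither model ever inspects the other's parameters: the only thing exchanged is generated inputs and predicted labels.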
Generative models can be used for other forms of data augmentation as well. One can induce invariances not already present in the model (e.g. to rotation or scaling), or augment corner cases, such as accidents in self-driving car datasets (Fig 7).
Fig 7. Self-driving car datasets suffer from a long-tail of corner-cases: accidents are extremely rare. Generative models can be used to augment the dataset with rarely-seen events, such as accidents in rainy weather. Image from Carla simulator, Dosovitskiy et al. 2017
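One simple way to augment a rare class is to fit a generative model to the few rare examples and sample from it until the classes are balanced. The sketch below assumes a deliberately crude generative model (an independent Gaussian per feature) and made-up two-dimensional "scene features" in place of images; a real pipeline would use a deep generative model that outputs full sensor data.

```python
import random
import statistics


def fit_feature_gaussians(samples):
    """Crude generative model: one independent Gaussian per feature."""
    return [(statistics.fmean(col), statistics.stdev(col)) for col in zip(*samples)]


def sample_from(model, n, rng):
    """Draw n synthetic feature vectors from the fitted model."""
    return [tuple(rng.gauss(mu, sigma) for mu, sigma in model) for _ in range(n)]


def augment_rare_class(data, rare_label, rng):
    """Add synthetic rare-class samples until that class matches the largest class."""
    by_label = {}
    for x, y in data:
        by_label.setdefault(y, []).append(x)
    target = max(len(xs) for xs in by_label.values())
    rare = by_label[rare_label]
    synthetic = sample_from(fit_feature_gaussians(rare), target - len(rare), rng)
    return data + [(x, rare_label) for x in synthetic]


rng = random.Random(42)
# Made-up scene features: (speed in km/h, distance to the nearest obstacle in m).
normal = [((rng.gauss(50, 10), rng.gauss(30, 8)), "normal") for _ in range(500)]
accidents = [((rng.gauss(80, 15), rng.gauss(3, 1)), "accident") for _ in range(12)]
augmented = augment_rare_class(normal + accidents, "accident", rng)
print(sum(y == "accident" for _, y in augmented), len(augmented))  # → 500 1000
```

A single Gaussian can only interpolate around the rare examples it has seen; the appeal of richer generative models is that they can produce plausible variations (e.g. different weather) that are absent from the rare class.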
Simulation
Generative models can also be used as a world simulator or to generate parts of it. For example, generative models can generate "fake" humans in a 3D environment that could interact with a robot. Moreover, for each human, one could specify the desired age, gender and other demographic attributes. If a robot operates in the simulated world, the generative models can even be used to generate the humans' responses to the robot's actions, i.e. the updated world. As opposed to more structured worlds (e.g. a room with chairs, tables or other objects), I believe current models are particularly suitable for modelling organic entities (humans, animals, plants, brain anatomy), which have complex structures and variations that cannot easily be constructed by composing simpler elements.
Fig 8. Conditional generative models can be used as a simulator of the world. The model can generate the updated virtual environment or parts of it (e.g. trees) based on the actions of the driver (left-right turns, pedals, etc …). Inspired by Jahanian et al, 2019
Perform Bayesian Posterior Optimisation
Some generative models, such as those based on invertible flows (Dinh et al, 2016) or autoregressive models (Oord et al, 2016), can not only sample images but also estimate densities: given an input image, they can evaluate the probability density $f_{\theta}(image)$, which tells how likely the image is under the model family $f$ with parameters $\theta$. For convolutional neural networks, $\theta$ can represent the set of weight parameters of all convolutional layers. This allows us to use the generative model as a prior over the space of possible/realistic images. More precisely, we set the prior to be equal to the model density under a pre-optimised parameter set $\theta$:
\[p(color\_img) = f_{\theta}(color\_img)\]
Now imagine we're interested in performing an image colorization task (Fig 9) by estimating $p(color\_img | gray\_img)$, the distribution of all possible coloured images given a grayscale image. We apply Bayes' rule as follows:
\[p(color\_img|gray\_img) \propto p(gray\_img|color\_img)\, p(color\_img)\]
where the omitted normalising constant $p(gray\_img)$ does not depend on the colour image. The first term, called the likelihood, expresses the forward model and is very easy to compute: we take the colour image, convert it to grayscale, then compute a distance (e.g. L2) between the output and our input gray_img; a smaller distance means a higher likelihood. The second term is the prior, and it is most often assumed to be uniform. However, a uniform prior would allow unrealistic combinations of colours to be assigned to the image (e.g. Fig 9, right-most column), since a particular bird species may only appear with certain combinations of colours (Fig 9, middle). This can be fixed by using a generative density model as the prior, which constrains the solution to realistic combinations of colours.
Fig 9. A single grayscale image can have multiple colorizations: all 9 colorizations correspond exactly to the grayscale image, but only a subset of them are realistic (the 6 images on the left). This might be because, in the wild, the bird never appears with a brown or green neck, or with a blue forehead (the 3 images on the right). A generative model trained on a dataset of birds can be used to estimate the prior term $p(color\_img)$, which constrains the model to plausible colour combinations. Image adapted from Ardizzone et al, 2019
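A toy version of this posterior can be computed by brute force for a single pixel. The sketch assumes Gaussian noise on the grayscale observation, so the log-likelihood is a scaled negative squared L2 distance, and it replaces the trained flow or autoregressive density $f_{\theta}$ with a hypothetical two-component Gaussian mixture over "plausible" colours. Because the posterior is only known up to a constant, it compares unnormalised log-posteriors $\log p(gray\_img|color\_img) + \log f_{\theta}(color\_img)$ over a grid of candidate colours.

```python
import itertools
import math

# ITU-R BT.601 luma weights: the forward model colour -> grayscale.
LUMA = (0.299, 0.587, 0.114)

# Hypothetical stand-in for a trained density model f_theta over pixel colours:
# an equal-weight mixture of isotropic Gaussians given as (mean RGB, std).
PLAUSIBLE_COLOURS = [((0.6, 0.5, 0.3), 0.05), ((0.2, 0.2, 0.9), 0.05)]


def to_gray(rgb):
    return sum(w * c for w, c in zip(LUMA, rgb))


def log_likelihood(rgb, gray_obs, noise_std=0.02):
    """log p(gray | colour) up to a constant, assuming Gaussian observation noise."""
    return -((to_gray(rgb) - gray_obs) ** 2) / (2 * noise_std ** 2)


def log_prior(rgb):
    """log f_theta(colour) up to a constant (both components share the same std)."""
    terms = [-sum((c - m) ** 2 for c, m in zip(rgb, mean)) / (2 * std ** 2)
             for mean, std in PLAUSIBLE_COLOURS]
    top = max(terms)  # log-sum-exp for numerical stability
    return top + math.log(sum(math.exp(t - top) for t in terms))


def candidate_colours(step=0.05):
    grid = [i * step for i in range(int(round(1 / step)) + 1)]
    return list(itertools.product(grid, repeat=3))


def map_colour(gray_obs, use_prior=True):
    """Maximise the unnormalised log-posterior over the candidate grid."""
    def log_post(rgb):
        lp = log_likelihood(rgb, gray_obs)
        return lp + log_prior(rgb) if use_prior else lp  # a uniform prior adds a constant
    return max(candidate_colours(), key=log_post)


gray = 0.5
consistent = [c for c in candidate_colours() if abs(to_gray(c) - gray) < 0.01]
print(len(consistent), "grid colours match the grayscale value almost exactly")
print("MAP colour with the learned prior:", tuple(round(v, 2) for v in map_colour(gray)))
```

With a uniform prior, any of the many consistent colours is an equally good answer; the density prior breaks the tie in favour of the colour it considers plausible, here (0.6, 0.5, 0.3).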
Anomaly detection
Here, we would like to detect whether a new sample is out-of-distribution. For example, we have a dataset of healthy brain images, and we would like to detect whether a new image is also healthy. Since there is a variety of brain pathologies that we might not know a priori, we cannot use a discriminative model that classifies between healthy vs. disease X or disease Y. Instead, we build a density model of healthy images and set a threshold on the density estimate: if the new sample has a density below the threshold, it is considered abnormal and flagged for a doctor to check.
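The thresholding rule can be sketched as follows. A real system would use a deep density model (e.g. a flow) on full brain images; purely to keep the example self-contained, this assumes each scan has already been reduced to two made-up summary features, modelled with an independent Gaussian. The threshold is chosen so that about 1% of the healthy training scans would be flagged, an arbitrary choice for illustration.

```python
import math
import random
import statistics


def fit_density(samples):
    """Fit an independent (diagonal) Gaussian to feature vectors of healthy scans."""
    return [(statistics.fmean(col), statistics.stdev(col)) for col in zip(*samples)]


def log_density(x, model):
    """Log of the fitted Gaussian density at x."""
    return sum(-0.5 * ((xi - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))
               for xi, (mu, sigma) in zip(x, model))


def fit_threshold(samples, model, flag_rate=0.01):
    """Pick the threshold so that roughly `flag_rate` of healthy samples fall below it."""
    scores = sorted(log_density(x, model) for x in samples)
    return scores[int(flag_rate * len(scores))]


def is_abnormal(x, model, threshold):
    return log_density(x, model) < threshold


rng = random.Random(0)
# Made-up summary features of healthy scans.
healthy = [(rng.gauss(0.0, 1.0), rng.gauss(5.0, 2.0)) for _ in range(1000)]
model = fit_density(healthy)
threshold = fit_threshold(healthy, model)
print(is_abnormal((0.0, 5.0), model, threshold))  # → False (typical scan)
print(is_abnormal((6.0, 5.0), model, threshold))  # → True (far outside the healthy range)
```

The model never sees a diseased example, which is exactly why this approach can flag pathologies that were not known when the model was built.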
Sharing the generative model instead of sharing the dataset
There are certain instances where the dataset cannot be shared, but the generative model, or a dataset of fake samples generated with it, can be. This can occur due to privacy issues (e.g. medical data) or trade secrets.
Dataset size can also be a consideration: if the generative model is smaller than the dataset, one might prefer to transfer the smaller generative model over a network. In this case, the generative model acts as a form of compression.