In this article, we’ll be evaluating image classifiers by exploring the StyleGAN2 latent space. In our case, we’ll look at the performance of the hairstyle classifier we use in Ready Player Me.
StyleGAN2 is a generative model architecture demonstrating state-of-the-art image generation. The model’s code and pre-trained weights can be found, for example, in this repo, alongside generated examples across various domains.
First, what are we going to do? The idea is pretty simple. Given a hairstyle classifier, we’ll search the StyleGAN2 latent space for directions corresponding to each of our classes.
It’s not the best way to test your model if you need good old metrics like accuracy or F1-score. But this way, you get to see what kind of pictures actually trigger your model to predict a certain class, so the evaluation is much more visual.
There are, of course, some limitations. Evaluating your classifier by sampling pictures from a GAN only makes sense if that GAN can actually produce pictures containing your classes. Using a StyleGAN pre-trained on human faces to evaluate a classifier that detects cats will not produce any adequate results.
A bit of intro to StyleGAN2
Let’s not go too deep into the technical details, but here’s what you need to know. If you look at the StyleGAN2 architecture below, you’ll see that the noise input Z is first mapped to the style space W, which is then fed to the generator.
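To make this concrete, here’s a minimal sampling sketch. It assumes the generator interface of NVIDIA’s stylegan2-ada-pytorch repo (a G.mapping network and a G.synthesis network); other implementations name these parts differently.

```python
import torch

# Two-stage generation, assuming a pre-trained generator G with the
# stylegan2-ada-pytorch interface: G.mapping (Z -> W) and G.synthesis (W -> image).
z = torch.randn(1, G.z_dim)   # noise input Z
w = G.mapping(z, None)        # style vectors, one per synthesis layer (W+)
img = G.synthesis(w)          # generated image tensor, values roughly in [-1, 1]
```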
Manipulating StyleGAN’s input in the W+ space can be used to generate images from a certain domain, as is done, for example, in StyleCLIP:
CNN Classifier + StyleGAN
Let’s take our hairstyle classifier and consider one of its classes. Say, “curly bob with fringe”. The goal is to check whether the classifier is triggered by the right high-level features present in the images.
What we want to do is make our StyleGAN generate pictures that maximize the probability of the “curly bob with fringe” class predicted by our hairstyle model. And hopefully, we’re going to see images similar to the ground-truth ones.
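Here’s a minimal sketch of that part, assuming the G.mapping/G.synthesis interface from above, a classifier that returns logits, and a hypothetical identity_embedding helper wrapping a pre-trained face-recognition model; the class index and loss weight are made up for illustration:

```python
import torch
import torch.nn.functional as F

TARGET_CLASS = 3   # hypothetical index of the "curly bob with fringe" class
ID_WEIGHT = 0.5    # hypothetical weight of the identity loss
NUM_STEPS = 200

z = torch.randn(1, G.z_dim)
w = G.mapping(z, None).detach()                      # starting point in W+
direction = torch.zeros_like(w, requires_grad=True)  # learned class direction
optimizer = torch.optim.Adam([direction], lr=0.01)

target = torch.tensor([TARGET_CLASS])
base_id = identity_embedding(G.synthesis(w))         # embedding of the original face

for step in range(NUM_STEPS):
    img = G.synthesis(w + direction)

    # 1. Classifier loss: cross-entropy towards the desired hairstyle class.
    cls_loss = F.cross_entropy(classifier(img), target)

    # 2. Identity loss: keep the face embedding close to the original one.
    id_loss = 1 - F.cosine_similarity(identity_embedding(img), base_id).mean()

    loss = cls_loss + ID_WEIGHT * id_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```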
We’re optimizing two losses here:
1. Classifier loss, which is just a cross-entropy between what the hairstyle classifier predicts and the label of the desired class;
2. Identity loss, to preserve the person’s identity while changing only their hairstyle; it acts as a kind of regularization. For that, we can use, for example, a pre-trained InsightFace model.
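For reference, here’s how a face embedding can be obtained with the insightface package. Note that this inference API works on numpy images and isn’t differentiable, so for the identity loss above you’d wrap a PyTorch port of the recognition backbone (e.g. ArcFace) instead.

```python
import cv2
from insightface.app import FaceAnalysis

# Detect a face and extract its identity embedding (inference only;
# not differentiable, so unsuitable as-is for the optimization loop above).
app = FaceAnalysis()
app.prepare(ctx_id=0, det_size=(640, 640))

img = cv2.imread("face.jpg")
faces = app.get(img)
embedding = faces[0].embedding  # 512-dim identity vector
```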
Results
After running the optimization for multiple classes, we can now check how well we managed to capture those style directions. For that, we’ll generate a sequence of images, adding the learned direction vector to the latent vector W with increasing weight.
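A sketch of that interpolation, including assembling the frames into a GIF; it reuses G, w, and direction from the sketch above and assumes generated images with values in [-1, 1]:

```python
import torch
from PIL import Image

def to_pil(img):
    # Convert a [1, 3, H, W] tensor with values in [-1, 1] to a PIL image.
    arr = (img[0].permute(1, 2, 0).clamp(-1, 1) + 1) * 127.5
    return Image.fromarray(arr.to(torch.uint8).cpu().numpy())

with torch.no_grad():
    frames = [to_pil(G.synthesis(w + alpha * direction))
              for alpha in torch.linspace(0.0, 1.0, steps=10)]

frames[0].save("interpolation.gif", save_all=True,
               append_images=frames[1:], duration=100, loop=0)
```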
Combining these sequences of images into GIFs, we got something like this for the “curly bob with fringe” class:
Remember, we got these results simply by maximizing our classifier’s probabilities for a certain class. This shows quite visually what kind of high-level features trigger the classifier to predict a particular class.
Let’s see the same kind of visualization for one more class, “long straight hair with fringe”:
These are both classes that our hairstyle model seems to have learned correctly. Since facial features change little during the interpolation in the W+ space, the hairstyles in our examples change exactly as they are supposed to.
But of course, models sometimes overfit or learn to detect something other than the real target features. To illustrate one such case, let’s look at interpolation examples for the “long African braids” class.
Here, we see that maximizing the hairstyle model’s confidence for this hairstyle class also causes changes in skin color. This may suggest, for example, that the dataset is imbalanced for this particular class. Racial imbalance in the dataset can easily cause such behavior and should be addressed by collecting more images with a more diverse representation.
Conclusion
You could use this method whenever you’ve got a classifier and want to understand what exactly is driving its decision-making. It’s about making sure that the intended high-level features trigger your classifier and that it isn’t overfitting to something else.