Contrastive Representation Learning — A Comprehensive Guide (part 2, scoring representations)

Shairoz Sohail
Dec 31, 2021 · 9 min read

What’s in a name?

In the last section, we did several things:

  • Provided a definition and background for contrastive representation learning
  • Set up a dataset for a contrastive learning task
  • Modified a model’s architecture to produce embeddings instead of classifications
  • Wrote a trainer for the model using a variety of different contrastive loss functions
  • Visualized the resulting features in 2-dimensions using PCA

For this section we will take a deeper dive into what kind of feature spaces are being learned through contrastive learning and how to evaluate the quality of these features. We will do this by understanding the structure in the features (using distances, clustering, and correlation methods) and comparing that with what we expect given the semantic content of the images. Let’s get started!

What Makes a Good Feature Space?

We cannot reach a state we haven't defined yet, so first let's think about what we're looking for when we create features from images. Here are a few "positive" attributes of a feature space that we will state first and then elaborate on:

1) Semantically similar images are close together (i.e. Distance(cat, kitten) ~ 0)

2) Semantically different images are far apart (i.e. Distance(cat, balloon) >> 0)

3) Features are decorrelated (also known as whitened)

4) Semantic classes are linearly separable (i.e. we can use a hyperplane to separate the embedded points for cat images and balloon images)

These are simply basic requirements, and most "good" models will have feature spaces that follow each of them to some degree. There is also a set of more "advanced" properties that feature spaces can have that are generally deemed positive but usually only attainable through specific training protocols and architectures.

5) Smoothness (i.e. a point between the embedding for cat and balloon might be the embedding for a cat holding a balloon)

6) Sampling ability (i.e. we can sample a vector from the space near a known embedded point and decode it into a semantically similar image)

7) Disentangled / interpretable (i.e. each feature corresponds to a single semantically sensible attribute such as color or shape)

There are a number of methods to encourage each of these (Mixup, GANs, and VAEs, respectively, for example), but we won't get into them too deeply. Instead, let's explore why these properties are desirable.

(1, 2) Ensure we can easily perform tasks like image retrieval, where we have a single image and would like to retrieve other images with the same content

(3) Preserves memory, as there is no benefit to having large collinear feature vectors for either prediction or retrieval

(4) This allows the features to be plugged into downstream image classification and object detection tasks

(5) This allows for algebra in the feature space and suggests the model is handling long-tailed data distributions fairly; this behavior will often also lead to improved generalization

(6) This allows us to use the model in a generative way to create new images with known content

(7) This allows for the possibility of user control over attributes when the model is generative (for example, see Nvidia's brilliant GauGAN model) and in general allows for selective pruning of representations for certain tasks.

All of these attributes are obviously interconnected, and having one often leads to having one or more of the others to differing degrees. Measuring some of these attributes, however, is tricky, and this article will only cover a small subset of methods for doing so. More advanced methods can often be found under the research area of Feature Visualization.

Testing Semantic Distances

We will start by testing for properties (1) and (2). We can do this by loading one of our models, forward-propagating examples through it to create feature vectors, then calculating the distances between those feature vectors and grouping them by the labels of the original examples.

Below, we utilize the NT-Xent trained model from the previous section to generate a heatmap of distances between embeddings of MNIST images (to use one of the other models, simply substitute the weight path with your model of choice):
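A minimal sketch of that procedure might look like the following (assuming `model` is the embedding network from the last section, `test_loader` yields (image, label) MNIST batches, and "ntxent_weights.pt" is a placeholder for wherever you saved the NT-Xent weights):

```python
import numpy as np
import seaborn as sns
import torch
from scipy.spatial.distance import cdist

# Load the NT-Xent trained weights and switch to evaluation mode.
model.load_state_dict(torch.load("ntxent_weights.pt"))
model.eval()

# Embed the evaluation images without tracking gradients.
feats, labels = [], []
with torch.no_grad():
    for x, y in test_loader:
        feats.append(model(x).numpy())
        labels.append(y.numpy())
feats, labels = np.concatenate(feats), np.concatenate(labels)

# Average pairwise distance between every pair of digit classes.
classes = np.unique(labels)
avg_dist_array = np.zeros((len(classes), len(classes)))
for i in classes:
    for j in classes:
        avg_dist_array[i, j] = cdist(feats[labels == i], feats[labels == j]).mean()

sns.heatmap(avg_dist_array, annot=True, cmap="viridis")
```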

Here are some heatmaps of the resulting distances using different models from the last section. The ideal case here is that the diagonal is 0 and everything off-diagonal is high (with this coloring, the lighter the off-diagonal cells, the better):

Another thing we can do is look at the distribution of normalized distances between dissimilar data points (i.e. different digits):

```python
import numpy as np
import seaborn as sns

# Strictly lower triangle only (k=-1), i.e. distances between *different* digits.
tril_idx = np.tril_indices_from(avg_dist_array, k=-1)
non_similar_distances_normalized = avg_dist_array[tril_idx] / np.max(avg_dist_array)
sns.histplot(non_similar_distances_normalized)
```

Since these examples are dissimilar, the higher the normalized distances the better.

From these distance distributions, it looks like NT-Xent and Contrastive Pair loss are doing the best, despite looking much worse than Triplet loss when plotted using PCA! Even more surprising is the relatively poor distance distribution of plain Cross Entropy loss trained against the labels.

For sheer class separation by distance, the ranking is:

NT-Xent > Contrastive Pairs > Triplet > Cross-entropy

Testing Feature Decorrelation

Remember when we ran PCA on our feature vectors in the last section and got widely different variance-explained percentages for each representation? One explanation for this might be collinearity: if X is the N x F matrix of features, then Rank(X) < F, or equivalently, one or more of the features can be written as a linear combination of the others. Even approximate collinearity can be a problem, but we will get to that in due time.

We start with a basic correlation matrix. Recall that the correlation of two features X_i and X_j (two coordinates of the embedding, measured across the dataset) is defined by:

Corr(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])] / (σ_i · σ_j)

where E[X] denotes the expected value and σ denotes the standard deviation. Here are the correlation matrices between features of the embedding vectors for each of our contrastive loss functions (Cross Entropy is not shown because its embeddings have a smaller dimension).
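As a quick sketch of how these matrices can be computed (assuming `feats` is the (N, F) array of embeddings from the previous step), NumPy's corrcoef does the heavy lifting:

```python
import numpy as np
import seaborn as sns

# Treat each embedding dimension (column of `feats`) as a variable and
# compute the F x F matrix of pairwise feature correlations.
corr_matrix = np.corrcoef(feats, rowvar=False)
sns.heatmap(corr_matrix, vmin=-1, vmax=1, cmap="coolwarm")
```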

To push this one step further, we're going to examine another method of correlation analysis: the Variance Inflation Factor (VIF). Understanding it only requires being comfortable with linear regression. The key idea is that since linear regression tells us how well one variable can be explained by a set of others, we can turn it around and apply it directly to the independent variables themselves (the embedding features, in our case) to understand their relationships. Thus, the Variance Inflation Factor describes how strongly a particular independent variable is related to groups of other independent variables, unlike correlation, which only measures pairwise relationships. Normally, the VIF equation for feature i is presented as:

VIF_i = 1 / (1 - R_i²)

which is terribly unhelpful on its own. I'm going to introduce some new notation that might make this a little clearer. Let

LR(Y ; X_1, X_2, ..., X_P)

denote a linear regression performed with Y as the dependent variable and X_1, X_2, ..., X_P as the independent variables. Now we can let

R²( LR(Y ; X_1, X_2, ..., X_P) )

be the coefficient of determination (the R²) that results from that regression. Under this notation, we can re-write the VIF as follows:

VIF_i = 1 / (1 - R²( LR(X_i ; X_1, ..., X_{i-1}, X_{i+1}, ..., X_P) ))

Thus, the Variance Inflation Factor is 1 divided by 1 minus the R² of the linear regression of the i-th feature against all the other features. Is this an abuse of notation? Sure. Does it make the equation a little clearer? I think so.
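To make the notation concrete, here is a tiny sketch that computes VIF_i exactly as written above, by regressing the i-th feature on all of the others (sklearn assumed; `X` is a hypothetical (N, F) feature matrix):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

def vif_by_hand(X, i):
    """VIF of feature i: 1 / (1 - R^2 of regressing X[:, i] on the other columns)."""
    others = np.delete(X, i, axis=1)
    r_squared = LinearRegression().fit(others, X[:, i]).score(others, X[:, i])
    return 1.0 / (1.0 - r_squared)
```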

Luckily for us, instead of having to run P linear regressions, there’s an easy function that calculates this for us. We will normalize the VIFs to be between 0 and 1 since the features generated by each of these methods can lie on different scales.
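In practice, this might look something like the following with statsmodels (assuming `embeddings` is the (N, F) feature array for one of the trained models):

```python
import numpy as np
from statsmodels.stats.outliers_influence import variance_inflation_factor

# One VIF per embedding dimension; statsmodels runs the underlying regressions.
vifs = np.array([
    variance_inflation_factor(embeddings, i) for i in range(embeddings.shape[1])
])

# Normalize to [0, 1] so models whose features lie on different scales are comparable.
normalized_vifs = vifs / vifs.max()
```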

Here’s a graph of the normalized VIFs for NT-Xent, Triplet, and Contrastive Pairs loss models:

Testing Linear Separability

Separability is the last thing we're going to cover and, from the standpoint of the literature, one of the most important. Linear separability refers to the ability of a set of points in a vector space to be separated along a meaningful attribute by a hyperplane. Replace "points" with "embedding vectors" and "meaningful attribute" with "label" and you arrive at what we need to do here. High linear separability is good because it allows a simple classifier such as logistic regression to separate the classes of interest. We could of course force such a space to be generated by training our model with a linear output head, but we're studying contrastive methods here, and as we'll see in future sections, assuming we can train directly against ground-truth labels severely limits our options.

Assessing a representation model in this fashion is often called linear probing, so called because we can use it to "probe" different layers of a network to see how good linear separability is at each point. Generally, linear separability increases with the depth of the model, as lower layers learn simple features such as edges and shapes while later layers compose these into meaningful combinations.

linearly separable vs. non-linearly separable data (credit)

Since we will need to actually train a linear classifier here, we first need to split the data into training and validation sets. We then pass our training data through our trained contrastive models again, except this time we won't backpropagate; we simply use the resulting embeddings to fit our Logistic Regression.
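A minimal linear-probe sketch might look like this (assuming `model` is one of the trained embedding networks and `train_loader` / `val_loader` yield (image, label) batches):

```python
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression

@torch.no_grad()
def embed(model, loader):
    # Run images through the frozen embedding network; no backpropagation.
    model.eval()
    feats, labels = [], []
    for x, y in loader:
        feats.append(model(x).numpy())
        labels.append(y.numpy())
    return np.concatenate(feats), np.concatenate(labels)

X_train, y_train = embed(model, train_loader)
X_val, y_val = embed(model, val_loader)

# The "probe": a simple linear classifier on top of the frozen embeddings.
probe = LogisticRegression(max_iter=1000)
probe.fit(X_train, y_train)
print("Linear probe accuracy:", probe.score(X_val, y_val))
```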

Here is the accuracy of each of the different models we trained when assessed with this linear probe method:

Firstly, it's pretty impressive outright that we can get these kinds of accuracies from training on just 3000 MNIST images (especially for those of you who remember the days of SIFT). Secondly, the fact that plain old cross-entropy does the best here is not surprising: it was trained specifically to develop features that linearly predict the class labels. The differences between the other methods are small enough that it's hard to draw any definite conclusions, but the high values indicate that even training simply to separate class embeddings can lead to feature spaces that are linearly separable.

Conclusions

So far, we’ve observed the following phenomena for our contrastive training protocols:

  • NT-Xent and Pairwise Contrastive losses lead to the highest degrees of separation between classes in the feature space, and all contrastive methods force larger distances than cross-entropy
  • Training with different contrastive losses leads to different amounts of correlation and different VIFs within the features
  • All contrastive losses tested generate feature spaces with a high degree of linear separability, allowing for good class discrimination with a simple logistic regression on the features.

For a subset of people, all of this has surely seemed like an overly complicated and only mildly enlightening reformulation: why bother with contrastive losses when discriminative losses like cross-entropy can generate as good or better performance on the thing we care about, accuracy? Well, in the next section the punchline of contrastive learning will be delivered: self-supervised learning! Stay tuned.
