Week 15 – infoNCE objective in CLIP

TL;DR

During the past two weeks, I have been following up on the hypothesis mentioned at the end of my last blog post.

  • Experimented on small datasets to validate that the infoNCE objective used to train CLIP converges f(x,y)=exp(score(x,y)) to k\frac{p(x,y)}{p(x)p(y)} rather than k'p(x,y), where k and k' are constants. This suggests that using CLIP directly for classification (i.e. approximating p(y|x) as \frac{f(x,y)}{\sum_{y' \in support(Y)} f(x,y')}) is indeed not principled. This is the main focus of this blog post.

Experiments

I ran experiments on both synthetic and natural datasets to observe what f(x,y)=exp(score(x,y)) converges to under three different objectives, including infoNCE.

To start, let’s use a synthetic example to illustrate the problem. We contrive a dataset in which the joint distribution is simply the product of the marginals; that is, p(x,y) = p(x)p(y) for every pair (x,y) in the support of the random variables X and Y. Here are three visualizations made from the data (a short construction sketch follows the figures):

Joint distribution
p(x,y)

Product of marginals
p(x)p(y)

Joint distribution div. by product of marginals
\frac{p(x,y)}{p(x)p(y)}
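To make this concrete, here is a minimal sketch (my own illustration, not the original experiment code; the support sizes and names are assumptions) of how such an independent joint distribution can be constructed and sampled:

```python
import numpy as np

rng = np.random.default_rng(0)

nx, ny = 20, 20                       # hypothetical support sizes
p_x = rng.dirichlet(np.ones(nx))      # marginal p(x)
p_y = rng.dirichlet(np.ones(ny))      # marginal p(y)
p_xy = np.outer(p_x, p_y)             # joint = product of marginals, by construction

# The three tables visualized above.
joint = p_xy                                                    # p(x,y)
prod_marginals = np.outer(p_xy.sum(axis=1), p_xy.sum(axis=0))   # p(x)p(y)
ratio = joint / prod_marginals                                  # p(x,y)/(p(x)p(y)); all ones here

# Sample (x, y) index pairs from the joint for SGD training.
flat = rng.choice(nx * ny, size=500, p=p_xy.ravel())
xs, ys = np.unravel_index(flat, (nx, ny))
```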

We construct a |support(X)| \times |support(Y)| parameter matrix M, whose entries are denoted m(x,y). We define f(x,y)=exp(m(x,y)) and use SGD to train the parameters under each of the following objectives. At convergence (batch size 500, 1k–2k iterations), we visualize the f(x,y) table for each objective (a code sketch of this setup follows the objective definitions):

Negative log likelihood
\text{Minimize } \mathbb{E}_{x,y \sim D}\left[ -\log \frac{f(x,y)}{\sum_{x' \in support(X),\, y' \in support(Y)} f(x',y')} \right]

Negative pseudo log-likelihood
\text{Minimize } \mathbb{E}_{x,y \sim D}\left[ -\left( \log \frac{f(x,y)}{\sum_{y' \in support(Y)} f(x,y')} + \log \frac{f(x,y)}{\sum_{x' \in support(X)} f(x',y)} \right) \right]

infoNCE used in CLIP
\text{Minimize } \mathbb{E}_{B \sim D}\,\mathbb{E}_{(x,y) \sim B}\left[ -\left( \log \frac{f(x,y)}{\sum_{(\_,y') \in B} f(x,y')} + \log \frac{f(x,y)}{\sum_{(x',\_) \in B} f(x',y)} \right) \right]
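To make the three objectives concrete, here is a minimal PyTorch sketch of how they could be implemented on top of the parameter matrix M (my own reconstruction from the definitions above; variable names and hyperparameters are assumptions, not the exact training code):

```python
import torch
import torch.nn.functional as F

nx, ny = 20, 20                                 # support sizes, as in the earlier sketch
M = torch.zeros(nx, ny, requires_grad=True)     # m(x,y); f(x,y) = exp(m(x,y))
opt = torch.optim.SGD([M], lr=0.1)

def nll(xs, ys):
    # -log of f(x,y) normalized over the entire support(X) x support(Y) table
    return -(M[xs, ys] - torch.logsumexp(M.reshape(-1), dim=0)).mean()

def pseudo_nll(xs, ys):
    # normalize over y' for fixed x, and over x' for fixed y
    row = M[xs, ys] - torch.logsumexp(M[xs, :], dim=1)
    col = M[xs, ys] - torch.logsumexp(M[:, ys], dim=0)
    return -(row + col).mean()

def info_nce(xs, ys):
    # in-batch normalization, as in CLIP: the sums run only over the batch
    logits = M[xs][:, ys]                       # (B, B); the diagonal holds the positive pairs
    labels = torch.arange(len(xs))
    return F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)

# One SGD step on a batch of index pairs (xs, ys) sampled as in the earlier sketch:
xs_t, ys_t = torch.as_tensor(xs), torch.as_tensor(ys)
opt.zero_grad(); info_nce(xs_t, ys_t).backward(); opt.step()
```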

On this contrived dataset, training with either the negative log-likelihood or the negative pseudo log-likelihood objective converges f(x,y) to a table that matches the p(x,y) visualization rather than the \frac{p(x,y)}{p(x)p(y)} visualization. Training with the infoNCE objective, however, produces a noisy patch with no clear structure, which is what we would expect if f(x,y) were tracking the ratio \frac{p(x,y)}{p(x)p(y)}: in this dataset the ratio is constant, so there is no structure for f(x,y) to recover.

To quantitatively measure how well f(x,y) approximates k'p(x,y), we compute the KL divergence between the normalized f(x,y) table and the ground-truth p(x,y) from the data (a sketch of this computation appears after the table).

Objective                                  KL divergence
Minimize Negative Log Likelihood           5.8801e-06
Minimize Negative Log Pseudo Likelihood    3.9604e-06
Minimize Info NCE Loss                     0.0043
It is quite noticeable that, under the infoNCE objective, f(x,y) approximates k'p(x,y) poorly compared to the other two objectives.
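For reference, the KL numbers above could be computed roughly as follows (a sketch under my own assumptions; in particular, I assume the divergence is taken from the ground-truth joint to the normalized f table):

```python
import numpy as np

def kl_to_normalized_f(p_xy, M):
    """KL( p(x,y) || normalized f(x,y) ), with f = exp(M) and M a numpy array."""
    f = np.exp(M - M.max())              # subtract the max for numerical stability
    q = f / f.sum()                      # normalize the f(x,y) table
    mask = p_xy > 0                      # treat 0 * log(0/q) as 0
    return float(np.sum(p_xy[mask] * np.log(p_xy[mask] / q[mask])))
```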

Now, let’s modify the distribution such that p(x,y) \neq p(x)p(y).

From the data:

joint distribution
product of marginals
joint distribution divided by product of marginals

Visualization of f(x,y) tables at convergence:

negative log likelihood
negative pseudo log-likelihood
infoNCE used in CLIP

The visualizations show that, when trained with the infoNCE objective, f(x,y) actually converges to a pattern much closer to \frac{p(x,y)}{p(x)p(y)} than to p(x,y).

Objective                                  KL divergence
Minimize Negative Log Likelihood           9.8529e-06
Minimize Negative Log Pseudo Likelihood    6.7194e-06
Minimize Info NCE Loss                     0.0037

Next, let’s look at an example that is more familiar to readers interested in probability. Consider the following procedure: we roll a six-sided die twice to collect two i.i.d. results, then record the pair (X = sum of the results, Y = max of the results). We repeat this procedure many times to fill up a count table indexed by the supports of the two random variables (X ranges over 2 to 12, Y ranges over 1 to 6). We then look at the same set of visualizations and metrics again.
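A small simulation of this procedure might look like the following (a sketch; the number of trials is my choice):

```python
import numpy as np

rng = np.random.default_rng(0)
n_trials = 100_000

rolls = rng.integers(1, 7, size=(n_trials, 2))    # two i.i.d. d6 rolls per trial
x = rolls.sum(axis=1)                             # X = sum of the two rolls (2..12)
y = rolls.max(axis=1)                             # Y = max of the two rolls (1..6)

# Count table indexed by (X, Y), then the empirical joint p(x, y).
counts = np.zeros((11, 6))                        # 11 possible sums, 6 possible maxima
np.add.at(counts, (x - 2, y - 1), 1)
p_xy = counts / counts.sum()
```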

From the data:

joint distribution
product of marginals
joint distribution divided by product of marginals

Visualization of f(x,y) tables at convergence:

negative log likelihood
negative pseudo log-likelihood
infoNCE used in CLIP

Measuring the difference between the true distribution and the normalized f(x,y) table:

Objective                                  KL divergence
Minimize Negative Log Likelihood           5.5292e-07
Minimize Negative Log Pseudo Likelihood    0.0025
Minimize Info NCE Loss                     0.0761

Finally, let’s look at a dataset from the R datasets repository.

birthwt (188 data points; x-axis: infant’s birth weight in units of 100 g, y-axis: mother’s age in years)
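For readers who want to reproduce the table, the binned (X, Y) counts could be built roughly like this (a sketch; statsmodels’ get_rdataset pulls from the same R datasets repository, and the exact binning below is my assumption):

```python
import numpy as np
import statsmodels.api as sm

# MASS::birthwt -- 'bwt' is birth weight in grams, 'age' is mother's age in years.
df = sm.datasets.get_rdataset("birthwt", "MASS").data

x = (df["bwt"] // 100).astype(int)    # infant's weight binned into 100 g buckets
y = df["age"].astype(int)             # mother's age in years

# Empirical joint distribution over the observed support.
counts, _, _ = np.histogram2d(
    x, y,
    bins=[np.arange(x.min(), x.max() + 2), np.arange(y.min(), y.max() + 2)],
)
p_xy = counts / counts.sum()
```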

From the data:

joint distribution
product of the marginals
joint distribution divided by product of marginals

Visualization of f(x,y) tables at convergence:

negative log likelihood
negative pseudo log-likelihood
infoNCE used in CLIP

Measuring the difference between the true distribution and the normalized f(x,y) table:

Objective                                  KL divergence
Minimize Negative Log Likelihood           5.6581e-07
Minimize Negative Log Pseudo Likelihood    7.3230e-06
Minimize Info NCE Loss                     0.0010

These are just a few of the small-scale experiments I have run; I am skipping the rest in this blog post because their conclusions are all the same. Moving forward, I will be running experiments on datasets and architectures that are progressively more similar to CLIP’s. The end goal of these experiments, along with the effort to write a proof, is to demonstrate that there is a principled method to improve CLIP’s classification performance, which is explained at the end of my last blog post.
