
Why naive cross-validation fails at feature selection

This past year, I took two courses taught by Profs. Emmanuel Candes and Trevor Hastie. Both of them described a surprising but simple problem you can run into when doing feature selection. It gave me another perspective on why the folks at Stanford Statistics have been so interested in regularization over the past decade. But in this post, I’ll just focus on the problem.

Selecting features via cross-validation

Let’s say you have 25 features for your prediction algorithm to use, and you’re trying to find the subset of features that produces the highest prediction accuracy (using all 25 features is not necessarily the right answer, due to potential overfitting).

In total, there are $2^{25}-1=33,554,431$ different subsets (excluding the set with 0 features).

I’ll refer to each one as a model. If your algorithm runs quickly, you might actually be able to test each of the 33,554,431 models using cross-validation and pick the best one. This could be a serious mistake.
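To make the procedure concrete, here is a minimal sketch of exhaustive subset search with cross-validation. The synthetic data, the logistic-regression classifier, and the 5-fold CV are illustrative stand-ins of my own, and the loop is capped at 10 candidate features, since visiting all 33,554,431 subsets would take far too long.

```python
# Minimal sketch: exhaustive "best subset" selection via cross-validation.
# X, y, the classifier, and the CV settings are all illustrative stand-ins.
from itertools import combinations

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 25))                                   # 25 candidate features
y = (X[:, 0] + X[:, 1] + rng.normal(size=200) > 0).astype(int)   # only 2 carry signal

n_candidates = 10   # with all 25 features this loop would visit 33,554,431 subsets
best_score, best_subset = -np.inf, None
for k in range(1, n_candidates + 1):
    for subset in combinations(range(n_candidates), k):
        score = cross_val_score(LogisticRegression(), X[:, list(subset)], y, cv=5).mean()
        if score > best_score:
            best_score, best_subset = score, subset

print(best_subset, best_score)
```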

Bad models way outnumber good models

Suppose that in reality, only 2 features offer any useful signal and the other 23 are useless noise.

Then, there are

$$2^{23}-1=8,388,607$$

bad models which do not include either of the 2 useful features. I’ll casually refer to these as “junk models”.

In contrast, there are just $2^{2}-1=3$ models that use exclusively good features. It’s fair to assume that the algorithm can tolerate one or two extra bad features. In that case, there are still only

$$3\times\sum_{k=0}^{2}{23 \choose k}=831$$

good models.

The junk models outnumber the good models by a factor of roughly 10,000.
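If you want to sanity-check those counts, the arithmetic fits in a few lines (a quick verification of my own, using Python’s `math.comb`):

```python
# Verify the model counts from the example above.
from math import comb

n_junk, n_good = 23, 2

# Nonempty subsets built only from the 23 junk features.
junk_models = 2**n_junk - 1                                              # 8,388,607

# At least one good feature and at most 2 junk features:
# 3 nonempty subsets of the good features, times 0, 1, or 2 junk features.
good_models = (2**n_good - 1) * sum(comb(n_junk, k) for k in range(3))   # 831

print(junk_models, good_models, junk_models / good_models)               # ratio ≈ 10,095
```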

Quantity beats quality (unfortunately)

Because there are so many more junk models, it is very likely that one of the junk models will perform better than all of the good models during cross-validation.

Here’s an example:

Suppose that the label you’re trying to predict is binary (either 0 or 1), and your test set for cross-validation contains 50 examples.

  • Consider the 831 good models:
    • Let’s assume the best one gets 40 out of 50 test examples correct. We just need some plausible number to work with for this example.
  • Consider the 8,388,607 junk models:
    • We’ll make the simplifying assumption that all their predictions are basically just independent random coin flips, i.e. a 50-50 chance of predicting 0 or 1.
      • (This assumption isn’t completely right, see Notes at the bottom.)
      • (If an algorithm is doing worse than 50-50 guessing on average, you can just reverse its predictions to make it do better. So, this is as bad as it gets.)
    • We can now model the number of answers a junk model gets correct as a binomial random variable $B$ with parameters $n=50$ and $p=0.5$.

The probability that a single junk model guesses more than 40 correct is very low:

$$P\left(B>40\right)=0.0000028$$

But the probability that at least one of them guesses more than 40 correct is very high:

$$1-P\left(B\leq40\right)^{8388607}\approx0.9999999999406324$$

It’s almost guaranteed that this will happen! The resulting model you choose based on cross-validation will only have junk features and none of the good features.

(Stats readers will notice that this is the same problem seen in multiple hypothesis testing.)
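Both tail probabilities are easy to reproduce with `scipy.stats.binom` (a small verification script of my own, using the same n = 50, p = 0.5, and 8,388,607 junk models as above):

```python
# Reproduce the two tail probabilities from the example above.
from scipy.stats import binom

n_test, p = 50, 0.5            # 50 test examples, coin-flip junk model
n_junk_models = 8_388_607

# Probability that a single junk model gets more than 40 of 50 correct.
p_single = binom.sf(40, n_test, p)                    # ≈ 0.0000028

# Probability that at least one junk model does, treating them
# (unrealistically -- see the Notes) as independent.
p_any = 1 - binom.cdf(40, n_test, p) ** n_junk_models
print(p_single, p_any)                                # ≈ 2.8e-06, ≈ 0.99999999994
```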

How to fix this?

There are a few solutions, some of them very straightforward:

  1. Decrease the total number of features you are selecting from. In this example, we had to search through 25 features to find the 2 good ones. As you add more junk features to your list of candidates, it becomes increasingly hard to find the good ones!
  2. Increase the size of your test set. This makes it harder for the junk models to “brute force” a good solution.
  3. Add a penalty term to larger models, to compensate for the fact that there are more of them. I describe some well-known penalties such as AIC and BIC here.
  4. Don’t consider all possible subsets, even if you have the computational power to do so. Ironically, a greedy stepwise feature selection strategy would have done better than considering all subsets in the above example (a rough sketch follows this list). Another way to avoid considering all possible subsets is to use regularization!
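As promised in point 4, here is a rough sketch of greedy forward selection, reusing the same kind of made-up data and stand-in classifier as before: at each step it adds whichever single feature most improves cross-validated accuracy, so it evaluates a few hundred models instead of 33 million.

```python
# Rough sketch of greedy forward feature selection.
# X, y, and the classifier are illustrative stand-ins, as in the earlier sketch.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 25))
y = (X[:, 0] + X[:, 1] + rng.normal(size=200) > 0).astype(int)

def cv_score(features):
    return cross_val_score(LogisticRegression(), X[:, features], y, cv=5).mean()

selected, remaining = [], list(range(X.shape[1]))
best_so_far = -np.inf
while remaining:
    # Score each remaining feature when added to the current set; keep the best.
    scores = {f: cv_score(selected + [f]) for f in remaining}
    candidate = max(scores, key=scores.get)
    if scores[candidate] <= best_so_far:
        break                      # stop once adding a feature no longer helps
    selected.append(candidate)
    remaining.remove(candidate)
    best_so_far = scores[candidate]

print(selected, best_so_far)
```

(scikit-learn also ships a ready-made version of this idea as sklearn.feature_selection.SequentialFeatureSelector.)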

The last answer leads to lots of interesting ideas being developed at Stanford Statistics. Hopefully I’ll get a chance to cover those in the future!

Notes

In practice, it isn’t always fair to assume that the junk models are guessing the labels independently at random. For example, the model that uses features 1,2,3 will probably give similar predictions as the model that uses features 2,3,4. Exactly how much correlation exists depends on the algorithm.
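As a small illustration of that point (my own, with made-up noise data): fit two junk models that share two of their three features and check how often their held-out predictions agree. If they really were independent coin flips, the agreement rate would hover around 50%.

```python
# Illustration: two junk models that share features need not behave like
# independent coin flips. Data and models are illustrative stand-ins.
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(400, 25))
y = rng.integers(0, 2, size=400)     # pure-noise labels

X_train, X_test = X[:200], X[200:]
y_train = y[:200]

model_a = LogisticRegression().fit(X_train[:, [1, 2, 3]], y_train)
model_b = LogisticRegression().fit(X_train[:, [2, 3, 4]], y_train)

pred_a = model_a.predict(X_test[:, [1, 2, 3]])
pred_b = model_b.predict(X_test[:, [2, 3, 4]])

print("agreement rate:", (pred_a == pred_b).mean())
```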