Thursday, September 23, 2010

MNIST for ever....

[update]This post is a bit old, but many people still seem interested. So just a short update: Nowadays I would use Python and scikit-learn to do this. Here is an example of how to do cross-validation for SVMs in scikit-learn. Scikit-learn even downloads MNIST for you. [/update]


MNIST is, for better or worse, one of the standard benchmarks for machine learning and is also widely used in the neural networks community as a toy vision problem.

Just for the unlikely case that anyone is not familiar with it:
It is a dataset of handwritten digits, 0-9, in black on a white background.
It looks something like this:



There are 60000 training and 10000 test images, each 28x28 gray scale.
There are roughly the same number of examples of each category in the test and training datasets.
I used it in some papers myself even though there are some reasons why it is a little weird.

Some not-so-obvious (or maybe they are) facts are:
  • The images actually contain a 20x20 patch of digit and were padded to 28x28 for use in LeNet in Yann LeCun's classic 1998 paper on convolutional neural networks.
  • Even within the central 20x20 patch, there are some pixels of zero variance.
  • The dataset as a whole has very little variance, as can be seen from looking at the spectrum using a PCA.
  • The dataset is very easy: random guessing is at 10% correct, a naive Bayes classifier scores about 90% correct and K nearest neighbors about 96.9% (I got that with K=3).
  • It is very easy to exploit the special structure of the dataset: a lot of variation is caused by skewed digits and scaling. A relatively recent paper using a blown-up training set and an MLP scored an excellent 99.6%.
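
The padding and zero-variance points above are easy to check on any stack of 28x28 images. Here is a minimal sketch in plain Python, assuming the images are given as nested lists of pixel intensities. The helper names `crop_center` and `zero_variance_pixels` are made up for illustration, and the two toy images only stand in for real MNIST data:

```python
def crop_center(image, size=20):
    """Return the central size x size patch of a square image (list of rows)."""
    offset = (len(image) - size) // 2
    return [row[offset:offset + size] for row in image[offset:offset + size]]


def zero_variance_pixels(images):
    """Return (row, col) positions whose value is identical in every image."""
    first = images[0]
    constant = []
    for r in range(len(first)):
        for c in range(len(first[0])):
            if all(img[r][c] == first[r][c] for img in images):
                constant.append((r, c))
    return constant


# Toy stand-in for MNIST: two 28x28 images that differ only at one pixel
# inside the central 20x20 region.
blank = [[0] * 28 for _ in range(28)]
dotted = [row[:] for row in blank]
dotted[14][14] = 255

patch = crop_center(dotted)                       # removes the 4 pixel border
constant = zero_variance_pixels([blank, dotted])  # pixels that never change
print(len(patch), len(patch[0]))                  # 20 20
print(len(constant))                              # 783, all but (14, 14)
```

On the real data, dropping the constant pixels (or the whole border) shrinks the feature vectors without losing any information.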

A thread in the kernel-machines forum motivated me to try and reproduce some results listed on the MNIST webpage using support vector machines with rbf kernel.
I am relatively new to that area and I thought that this would be a nice thing to try, since the website gives no source for the reported performance of 98.6%.

I worked pretty long and hard on this. Since the dataset is quite big, doing a grid search took quite a while. Also, I did not remove the padding, which could have sped up the process.

I used the excellent LIBSVM implementation with rbf kernel and the provided grid search tool. I distributed the work on 10 local processors which was quite easy.

I tried several scalings of the original dataset:
Scaling between 0 and 1, between -1 and 1 and normalizing to unit Euclidean length.
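
To make the three variants concrete, here is a minimal sketch in plain Python. It assumes raw intensities in 0..255 and a global scaling by that maximum, which is not necessarily what I used back then (LIBSVM's svm-scale tool, for instance, scales each feature separately); the function names are made up for illustration:

```python
import math


def scale_unit_interval(pixels, max_value=255):
    """Map raw intensities in [0, max_value] to [0, 1]."""
    return [p / max_value for p in pixels]


def scale_symmetric(pixels, max_value=255):
    """Map raw intensities in [0, max_value] to [-1, 1]."""
    return [2 * p / max_value - 1 for p in pixels]


def scale_unit_length(pixels):
    """Divide by the Euclidean norm so the vector has length 1."""
    norm = math.sqrt(sum(p * p for p in pixels))
    # An all-zero image has no direction, so it is returned unchanged.
    return [p / norm for p in pixels] if norm > 0 else list(pixels)


image = [0, 51, 255]  # tiny stand-in for the 784 pixel values of one digit
print(scale_unit_interval(image))
print(scale_symmetric(image))
print(scale_unit_length(image))
```

Whatever scaling is chosen, the same transformation has to be applied to the training and the test set.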
First I did a coarse grid search using 5-fold cross-validation on a 5000-sample subset of the training set. I also tried 10000 samples and got similar results.
The grid search looked something like:


The scaling between -1 and 1 seemed to work best, so I did a finer grid-search using this scaling and arrived at

gamma =  0.00728932024638  and C = 2.82842712475
using

python grid.py -log2c 0.5,2.5,0.2 -log2g -7.5,-6.5,0.2 -v 5 -h
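
The ranges are given as begin,end,step in log2 space, so the odd-looking optimum is simply C = 2^1.5 and gamma = 2^-7.1. A small sketch of that arithmetic in plain Python; `log2_grid` is a made-up helper, and the real grid.py may treat the end point slightly differently:

```python
import math


def log2_grid(begin, end, step):
    """Exponents begin, begin + step, ... up to and including end."""
    count = int(round((end - begin) / step)) + 1
    return [begin + i * step for i in range(count)]


log2c_values = log2_grid(0.5, 2.5, 0.2)     # from -log2c 0.5,2.5,0.2
log2g_values = log2_grid(-7.5, -6.5, 0.2)   # from -log2g -7.5,-6.5,0.2

# Every (C, gamma) pair is scored with 5-fold cross-validation (-v 5).
pairs = [(2.0 ** c, 2.0 ** g) for c in log2c_values for g in log2g_values]
print(len(pairs), "parameter pairs,", 5 * len(pairs), "SVM trainings")

# The reported optimum sits on this grid.
print(math.log2(2.82842712475))      # close to 1.5
print(math.log2(0.00728932024638))   # close to -7.1
```

This fine grid has 11 values for C and 6 for gamma, so 66 pairs and 330 trainings on the subset, which is why spreading the work over several processors pays off.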

Using these parameters I trained on the whole training set and tested on the test-set to obtain....

98.56%

... which is close enough so that I can finally rest ...

[edit] By the way, I used the default tolerance of the libsvm command line interface, which is 0.001. [/edit]

10 comments:

  1. Dear Andy,

    I am really happy that you did all this work. But how did you separate all ten classes? Did you use a multiclass SVM? Or did you train ten different SVMs? Maybe you wrote it (implicitly), but I'm not familiar with LIBSVM.

    Thanks in advance.
    Frerk Saxen

  2. Hi Frerk.

    Actually I am not so sure about all the details of what LibSVM does there. As far as I know it uses a one-vs-one scheme: it trains a binary SVM for each pair of classes (45 for the ten digits) and lets them vote. For the details you would have to look into the original paper on the LibSVM website.
    It definitely uses the same parameters for all of the binary SVMs, so you could probably do better by searching parameters for each SVM separately... But I was happy to reproduce something close to the result reported on the MNIST website.

    Cheers,
    Andy

  3. Thank you Thank you Thank you Thank you!!!
    I was struggling to reproduce the 98.6 accuracy or at least something close, and now thanks to you I did it.
    Thank you, you're awesome!!!

  4. Thanks a lot for this post! But as I'm new to Matlab, Libsvm and MNIST, I have a couple of questions: first, could you please show me some code that uses cross-validation to find the optimal value for the parameter 'c'? And second, I have heard about an algorithm called 'whitening' the images. Do you know how it could be applied to MNIST?

    Replies
    1. By now I would do it with Python and scikit-learn.
      I might write a new version of this post some time soon.
      With scikit-learn, the complete code looks something like this:
      https://gist.github.com/2594372
      Be aware, though, this might take a while and will use all available CPUs.

    2. Oh, and about whitening: I doubt that would help with MNIST. Though you can try it in scikit-learn with PCA(whiten=True).

  5. Sort of old thread here sorry, came across it while googling...

    I actually got worse results with whitening. I think it sort of introduced grey regions where there were previously none. As you noted, the dataset has very little variance - which is what whitening tried to maximize.

    I really don't get why this dataset is such a benchmark. Just because it's small grey-scale images... easy to use. *frown*.
