How I increased the accuracy of MNIST prediction from 84% to 99.41%
I had never done image prediction before; all I had done was make predictions on categorical or numerical data using classification algorithms. So when I started exploring the Kaggle MNIST data, I found that images are simply numerical data in a slightly different format.
How can images be represented as numeric data?
When I loaded the data, I found that it had 784 pixel columns: each data point was a 28 x 28 image flattened into a single row. All the values ranged between 0 (black) and 255 (white), which meant the images were grayscale. From this I understood that images are no different from numeric data.
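To make the "images are just numbers" point concrete, here is a minimal NumPy sketch that turns one 784-value row back into a 28 x 28 image and scales it to [0, 1]. It assumes the pixel columns are stored in row-major order (the first 28 values form the top row of the image), which is how the Kaggle `train.csv` pixel columns are usually described. The helper name `row_to_image` and the synthetic row are hypothetical and used only for illustration.

```python
import numpy as np

def row_to_image(row):
    """Turn one flattened 784-pixel row into a 28x28 grayscale image scaled to [0, 1]."""
    pixels = np.asarray(row, dtype=np.float32)
    if pixels.size != 28 * 28:
        raise ValueError(f"expected 784 pixel values, got {pixels.size}")
    # Row-major reshape: values 0-27 become image row 0, values 28-55 image row 1, ...
    return pixels.reshape(28, 28) / 255.0

# A synthetic row standing in for one row of pixel columns (hypothetical data)
row = np.zeros(784)
row[28 * 14 : 28 * 15] = 255  # paint image row 14 white
image = row_to_image(row)
print(image.shape)      # (28, 28)
print(image[14].max())  # 1.0
```

Dividing by 255 is one common way to "normalize" this data; it keeps the relative brightness of the pixels while putting every feature on the same 0 to 1 scale.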
My prediction journey
1. As usual, I started with one of the classification algorithms. I knew that Support Vector Machines tend to perform well on complex data and on multi-class classification problems, so I normalized the data and fed it into an SVM model. The accuracy I got was 84%. I thought this was pretty good, but when I looked at the Kaggle results for the MNIST data, many people had scored 100% accuracy.
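The post does not show the SVM code, so the following is only a sketch of the "normalize, then fit an SVM" step. It assumes that normalization means dividing by 255 and that the model is scikit-learn's `SVC` with its default RBF kernel. The helper names `fit_svm` and `svm_accuracy` are hypothetical. To keep it runnable without downloading the Kaggle files, the usage part swaps in scikit-learn's bundled 8 x 8 digits dataset (pixel values 0 to 16), rescaled to the 0 to 255 range. Its accuracy therefore says nothing about the 84% figure above.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

def fit_svm(X_train, y_train):
    """Normalize 0-255 pixel values to [0, 1] and fit an RBF-kernel SVM (assumed setup)."""
    X = np.asarray(X_train, dtype=np.float64) / 255.0
    clf = SVC(kernel="rbf")  # SVC handles the 10 classes internally with one-vs-one classifiers
    clf.fit(X, y_train)
    return clf

def svm_accuracy(clf, X_test, y_test):
    """Apply the same normalization as in training, then return mean accuracy."""
    X = np.asarray(X_test, dtype=np.float64) / 255.0
    return clf.score(X, y_test)

# Stand-in data: sklearn's small digits set, rescaled from 0-16 to MNIST's 0-255 range
digits = load_digits()
X_255 = digits.data * (255.0 / 16.0)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_255, digits.target, test_size=0.25, random_state=0)

clf = fit_svm(X_tr, y_tr)
acc = svm_accuracy(clf, X_te, y_te)
print(f"test accuracy: {acc:.3f}")
```

The key design point is that the test data must go through exactly the same scaling as the training data. Otherwise, the SVM sees inputs on a different scale from the ones it was fitted on.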
2. So I wanted to try neural networks next. I used Keras to create the neural network model below, and my accuracy went up to 97.4%. I was quite shocked to see the accuracy increase by more than 13 percentage points. It showed me why neural networks perform so much better on image classification problems.
```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout

# Fully connected network on the flattened 784-pixel input
nn_model = Sequential()
nn_model.add(Dense(45, input_dim=784, activation='relu'))
nn_model.add(Dropout(0.3))  # randomly drop 30% of units during training to reduce overfitting
nn_model.add(Dense(35, activation='relu'))
nn_model.add(Dense(23, activation='relu'))
nn_model.add(Dense(10, activation='softmax'))  # one probability per digit class 0-9
```
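Each `Dense` layer computes `activation(x @ W + b)`. The NumPy sketch below runs an inference-time forward pass through the same layer widths (784, 45, 35, 23, 10) to show what the Keras model computes. The weights are random and untrained, an assumption made purely for illustration, so the predicted probabilities are meaningless. Dropout only acts during training, so it is skipped here. The sketch also counts parameters, which should match the trainable-parameter total Keras reports for this architecture.

```python
import numpy as np

rng = np.random.default_rng(0)
layer_sizes = [784, 45, 35, 23, 10]  # same widths as nn_model above

# Random, untrained weights and zero biases (illustration only)
weights = [rng.normal(0.0, 0.05, size=(n_in, n_out))
           for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
biases = [np.zeros(n_out) for n_out in layer_sizes[1:]]

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(x):
    """Inference-time forward pass: ReLU hidden layers, softmax output (Dropout is inactive)."""
    h = x
    for W, b in zip(weights[:-1], biases[:-1]):
        h = relu(h @ W + b)
    return softmax(h @ weights[-1] + biases[-1])

# Every weight matrix and bias vector is a trainable parameter
n_params = sum(W.size + b.size for W, b in zip(weights, biases))
batch = rng.random((5, 784))  # 5 fake images with pixels already scaled to [0, 1]
probs = forward(batch)
print(n_params)     # 38003
print(probs.shape)  # (5, 10)
```

The parameter count works out to 784·45 + 45 + 45·35 + 35 + 35·23 + 23 + 23·10 + 10 = 38,003. Almost all of it sits in the first layer, because every one of the 784 pixels connects to every hidden unit.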
3. I was still at position 2700+ on the Kaggle leaderboard, and I wanted to improve my accuracy even more. So I started using convolutional neural networks (CNNs) to train my model. I didn’t know what they were or how to use them, so it was a bit of a challenge: I had to understand how a CNN works before I could design the convolution layers. The biggest difference between a CNN and the fully connected network above is that a CNN first uses convolution layers to extract important local features, such as edges and strokes, from the image. It then feeds those features into fully connected layers, which helps increase the accuracy. The code structure looks like this:
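```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Dense, Dropout, Flatten, MaxPool2D)

model = Sequential()
# Block 1: two 5x5 convolutions with 28 filters each.
# input_shape is x_train.shape[1:], assumed to be (28, 28, 1) after reshaping x_train to (n, 28, 28, 1)
model.add(Conv2D(28, (5, 5), padding='same', input_shape=(28, 28, 1)))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(Conv2D(28, (5, 5)))
model.add(Activation('relu'))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
# Block 2: two 5x5 convolutions with 32 filters each (input_shape is only needed on the first layer)
model.add(Conv2D(32, (5, 5), padding='same'))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(Conv2D(32, (5, 5)))
model.add(Activation('relu'))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
# Classifier: flatten the feature maps and feed them into fully connected layers
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.25))
model.add(Dense(10))
model.add(Activation('softmax'))
```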
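It helps to trace how the image shrinks as it passes through these layers. With stride 1, a `padding='same'` convolution keeps the spatial size, a default (`'valid'`) 5 x 5 convolution removes 4 pixels per side length, and a 2 x 2 max pool halves it. Activation, BatchNormalization, and Dropout leave the shape unchanged. The pure-Python sketch below is an illustration, and the helper names are hypothetical. It shows that the `Flatten` layer receives 4 x 4 x 32 = 512 values.

```python
def conv_out(size, kernel, padding):
    """Spatial size after a stride-1 convolution."""
    if padding == "same":
        return size               # zero padding keeps the size unchanged
    if padding == "valid":
        return size - kernel + 1  # Keras Conv2D default: no padding
    raise ValueError(f"unknown padding: {padding}")

def pool_out(size, pool):
    """Spatial size after max pooling with stride equal to the pool size."""
    return size // pool

# (kind, kernel or pool size, padding, filters) for the layers that change the shape
LAYERS = [
    ("conv", 5, "same", 28),
    ("conv", 5, "valid", 28),
    ("pool", 2, None, None),
    ("conv", 5, "same", 32),
    ("conv", 5, "valid", 32),
    ("pool", 2, None, None),
]

def trace_shapes(side=28, channels=1):
    """Return the (height, width, channels) shape after each layer in LAYERS."""
    shapes = [(side, side, channels)]
    for kind, k, padding, filters in LAYERS:
        if kind == "conv":
            side = conv_out(side, k, padding)
            channels = filters
        else:
            side = pool_out(side, k)
        shapes.append((side, side, channels))
    return shapes

shapes = trace_shapes()
for s in shapes:
    print(s)
h, w, c = shapes[-1]
print("Flatten ->", h * w * c)  # 4 * 4 * 32 = 512
```

The printed sequence is (28, 28, 1), (28, 28, 28), (24, 24, 28), (12, 12, 28), (12, 12, 32), (8, 8, 32), (4, 4, 32). This is why the dense part of the network works with a 512-value vector rather than the 784 raw pixels.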
With this model, I achieved 99.3% accuracy and jumped to position 1200+ on the Kaggle leaderboard.
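To see what "using convolutions to extract features" means, here is a small NumPy sketch of a single convolution followed by ReLU. A 3 x 3 kernel slides over the image with stride 1 and no padding, and a weighted sum is taken at each position. Like Keras's `Conv2D`, this computes cross-correlation, meaning the kernel is not flipped. In a real CNN the kernel values are learned during training. Here they are hand-picked as a vertical-edge detector, and the 5 x 5 image is made up for illustration. The helper name `conv2d_valid` is hypothetical.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Stride-1, no-padding 2D cross-correlation (what a Conv2D filter computes, minus the bias)."""
    ih, iw = image.shape
    kh, kw = kernel.shape
    out = np.zeros((ih - kh + 1, iw - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # Weighted sum of the image patch under the kernel
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# Dark left columns, bright right columns: a vertical edge between columns 1 and 2
image = np.array([[0, 0, 1, 1, 1]] * 5, dtype=float)
# Hand-made vertical-edge detector (a trained CNN would learn its own kernels)
kernel = np.array([[-1, 0, 1]] * 3, dtype=float)

feature_map = np.maximum(conv2d_valid(image, kernel), 0)  # ReLU, as after each Conv2D above
print(feature_map)  # every row is [3. 3. 0.]
```

The feature map responds strongly wherever the kernel overlaps the edge and gives 0 over the flat bright region. A convolution layer with 28 or 32 filters produces that many such maps, each tuned to a different pattern.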
4. I hadn’t done any hyperparameter tuning, such as changing the learning rate, so I wanted to give that a try as well. I tried two optimizers, Adam and SGD, and found that SGD performed better. Then I tuned the learning rate and found that a learning rate of 0.01 gave the best results. With that, I achieved 99.41% accuracy and jumped to position 710 on the leaderboard.
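The post doesn’t show the tuning code, so here is only a sketch of the idea of comparing optimizers and learning rates on a held-out split. To stay small and runnable, it swaps in scikit-learn's `MLPClassifier` and its bundled 8 x 8 digits dataset instead of the Keras CNN on Kaggle MNIST. Which setting wins here is not evidence for or against the result above. Note also that scikit-learn's `sgd` solver uses momentum 0.9 by default, whereas Keras's `SGD` optimizer defaults to no momentum. In Keras, the equivalent knob is the optimizer passed to `model.compile`, for example `keras.optimizers.SGD(learning_rate=0.01)`.

```python
import warnings
from itertools import product

from sklearn.datasets import load_digits
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

digits = load_digits()
X = digits.data / 16.0  # load_digits pixels are 0-16; scale to [0, 1]
X_tr, X_val, y_tr, y_val = train_test_split(
    X, digits.target, test_size=0.25, random_state=0)

results = {}
# Small grid: two optimizers x two learning rates (stand-in for the Keras experiment)
for solver, lr in product(["adam", "sgd"], [0.01, 0.001]):
    clf = MLPClassifier(hidden_layer_sizes=(45,), solver=solver,
                        learning_rate_init=lr, max_iter=100, random_state=0)
    with warnings.catch_warnings():
        # Short runs may stop before fully converging; that is fine for a comparison
        warnings.simplefilter("ignore", ConvergenceWarning)
        clf.fit(X_tr, y_tr)
    results[(solver, lr)] = clf.score(X_val, y_val)  # accuracy on the held-out split

best = max(results, key=results.get)
for config, acc in sorted(results.items()):
    print(config, round(acc, 4))
print("best:", best)
```

Scoring each configuration on a validation split, rather than on the final test data, keeps the comparison honest. Otherwise, the chosen learning rate is tuned to the very data used to report accuracy.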
Conclusion
I know that with more hyperparameter tuning I could achieve even higher accuracy. But for people who think image classification is very difficult, I want you to know that it is just like any other classification problem, except for a little extra learning about convolutions. I also recommend looking at my code on GitHub, where I have MNIST prediction code for different ML algorithms. You can also try the Fashion MNIST data for practice.
Note: I used Google Colab to write and execute my code on a GPU. It is much faster, and Google lets you use it for free.
Thanks for reading. I hope you gain confidence in creating an image classification model using neural networks! Leave your comments below and let me know if you have any questions. Contact me via LinkedIn.