
Recognizing Handwritten Digits in Python

by yhat

Handwriting recognition is a classic machine learning problem with roots going back at least to the early 1900s. It's a fascinating problem, and one that sits at the center of some magical product experiences, such as Evernote's Penultimate handwriting app for iPhone and the Apple Newton PDA from the 1990s.

This is a post about handwriting recognition and Python.


Problem & Applications

The principal task in handwriting recognition is to convert handwritten text into text that a computer can understand.

The post office, of all organizations, has used text recognition regularly since the 1960s, applying Optical Character Recognition (OCR) to read street addresses and sort mail.

But handwriting recognition was an area of interest and research long before the sixties. Emanuel Goldberg (1881-1970) is an inventor remembered for his contributions to text recognition and information retrieval. He was working on this stuff back in the 1920s and 30s.

Photo: Emanuel Goldberg (left)

"Goldberg... suggested in the late 1920s that statistical equipment could productively apply optical scanning technology for a form of 'data entry.'"


- Michael Buckland
Emanuel Goldberg And His Knowledge Machine, 2006, p. 161

NOTE: Did you know that the USPS employs over 574,000 people and delivers approximately 700 million pieces of mail per day? Check out these slides for some really interesting facts and figures.


Case Study

We'll explore handwriting recognition through a case study built with open source resources.

By the end of this tutorial, you'll know how to classify handwritten numbers and make predictions on the web using HTML5 and node.js.


Sourcing the Data

To build the classifier, we need training data. For this example, we're going to collect the data ourselves. To do that, I built a really small node.js app using Express. The app generates a random number from 0 to 9 and tells the user (me) to draw it on the canvas.

When the user finishes drawing the number, he or she clicks Submit and the canvas gets sent to the server as an image.


On the backend, the server saves each incoming image to a directory called numbers and names it [unique_id]_[digit].png. The unique ID ensures that no two images get the same name, and the digit suffix makes it super easy to tell which digit each image contains.
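Because the label lives in the filename, the Python side never needs a separate label file. Here is a minimal sketch of recovering (path, digit) pairs from the numbers directory. The regular expression and the helper names label_from_filename and list_labeled_images are our own illustration, assuming every file follows the [unique_id]_[digit].png scheme exactly.

```python
import os
import re

# Hypothetical pattern for the [unique_id]_[digit].png naming scheme.
# The greedy uid group lets the unique ID itself contain underscores.
FILENAME_PATTERN = re.compile(r"^(?P<uid>.+)_(?P<digit>[0-9])\.png$")


def label_from_filename(path):
    """Return the digit encoded in a filename such as 'abc123_7.png'."""
    name = os.path.basename(path)
    match = FILENAME_PATTERN.match(name)
    if match is None:
        raise ValueError(f"unexpected filename: {name!r}")
    return int(match.group("digit"))


def list_labeled_images(directory="numbers"):
    """Yield (path, digit) pairs for every PNG in the data directory."""
    for name in sorted(os.listdir(directory)):
        if name.endswith(".png"):
            path = os.path.join(directory, name)
            yield path, label_from_filename(path)
```

Raising on a malformed name, rather than skipping it silently, makes stray files in the directory easy to spot before training.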


OK, now that you've got your data collection app, just draw about 1000 images so you have a large enough dataset for training! Or, if you're lazy, you can download our dataset here.


Building the Classifier

We're going to steal most of the code from our previous post, Content-based image classification in Python. As before, we load all of our images into a numpy array, using PIL to extract the pixel data from each image.


We shrink the images down from 500x500 to 50x50 to make them easier to work with (faster to process).
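A minimal sketch of the loading and shrinking step is below. It is not the code from the previous post. Two choices here are our own assumptions: we convert to grayscale, so each 50x50 image becomes 2500 numbers, and we composite onto a white background first, because a PNG exported from an HTML5 canvas may have a transparent background that would otherwise turn black, just like the pen strokes.

```python
import numpy as np
from PIL import Image

# The post shrinks the 500x500 canvas images down to 50x50.
STANDARD_SIZE = (50, 50)


def image_to_vector(img, size=STANDARD_SIZE):
    """Flatten a PIL image into a 1-D array of grayscale pixel values."""
    # Assumption: transparent canvas pixels should count as white paper.
    img = img.convert("RGBA")
    background = Image.new("RGBA", img.size, (255, 255, 255, 255))
    img = Image.alpha_composite(background, img)
    img = img.convert("L").resize(size)                # grayscale, 50x50
    return np.asarray(img, dtype=np.float64).ravel()   # 2500 values in 0..255


def load_dataset(paths_and_labels, size=STANDARD_SIZE):
    """Build an (n_images, 2500) feature matrix and a matching label vector."""
    rows, labels = [], []
    for path, label in paths_and_labels:
        with Image.open(path) as img:
            rows.append(image_to_vector(img, size))
        labels.append(label)
    return np.array(rows), np.array(labels)
```

load_dataset expects an iterable of (path, digit) pairs, for example pairs built from the [unique_id]_[digit].png filenames.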

Then we run a RandomizedPCA on the data to reduce the dimensionality.

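By reducing dimensionality, we mean that we take an image with 2500 cells (50x50) and project it onto the 10 directions (principal components) that capture the most variation across the training images, so that each image becomes a vector of length 10.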
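Here is a sketch of the reduction step on stand-in data. One caveat: RandomizedPCA has since been removed from scikit-learn (in version 0.20), and the current equivalent is PCA with svd_solver="randomized", which is what we use. The random matrix below only stands in for the real (n_images, 2500) pixel matrix.

```python
import numpy as np
from sklearn.decomposition import PCA


def fit_pca(X_train, n_components=10, random_state=0):
    """Learn the top principal components of the training pixel matrix."""
    pca = PCA(n_components=n_components, svd_solver="randomized",
              random_state=random_state)
    pca.fit(X_train)
    return pca


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.random((200, 2500))          # 200 fake 50x50 images, flattened
    pca = fit_pca(X)
    X_reduced = pca.transform(X)         # each row is now 10 numbers
    print(X.shape, "->", X_reduced.shape)  # (200, 2500) -> (200, 10)
```

The fitted pca object must be kept around: new drawings have to be projected with the same components that were learned from the training set.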

Next we scale the data using the StandardScaler class in scikit-learn. This ensures that each column has a mean of 0 and a standard deviation of 1.

This makes it easier for our classifier to measure the distance between two points, because no single column can dominate the distance just by having larger values.
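A small sketch of the scaling step follows. Fitting the scaler on the training rows only and reusing its mean and standard deviation for test rows is our own design choice, to avoid leaking test statistics into training. The two synthetic columns are assumptions chosen to mimic PCA components with very different magnitudes.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler


def scale_features(X_train, X_test):
    """Standardize columns using statistics from the training set only."""
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)  # learns per-column mean and std
    X_test_scaled = scaler.transform(X_test)        # reuses the training mean and std
    return scaler, X_train_scaled, X_test_scaled


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Stand-ins for two PCA components whose magnitudes differ wildly
    X = np.column_stack([rng.normal(0, 2000, 300), rng.normal(0, 5, 300)])
    # Before scaling, the squared gap in column 0 typically swamps column 1
    print("raw squared gaps:", (X[0] - X[1]) ** 2)
    _, X_train_s, _ = scale_features(X[:200], X[200:])
    print("means after scaling:", X_train_s.mean(axis=0).round(3))
    print("stds after scaling:", X_train_s.std(axis=0).round(3))
```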


What are the PCA and StandardScaler Doing?

Before we go any further, it might be useful to explain exactly what the PCA and StandardScaler are doing.

In the first plot, images have been transformed using a PCA with 2 components (for the model we use 10 but 2 is easier for visualizing). All 2500 pixel values have been "projected" into a 2D plot.

But looking at the values on the x and y axes, you can see that their magnitudes vary wildly. Enter the StandardScaler.

In the second plot you see exactly the same data, but the values on the x and y axes have been scaled so that their means are 0. Scaling data is very important for many machine learning algorithms because it makes distance computations much more consistent.
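The two plots can be reproduced with a sketch like the one below. It assumes matplotlib is installed; the function name pca_scaling_figure, the titles, and the tab10 colormap are our own choices. X is the (n_images, 2500) pixel matrix and y holds the digit labels, which color the points.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend: draw to files, no window needed
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


def pca_scaling_figure(X, y):
    """Plot raw 2-component PCA values (left) and the standardized values (right)."""
    X_2d = PCA(n_components=2, svd_solver="randomized",
               random_state=0).fit_transform(X)
    X_2d_scaled = StandardScaler().fit_transform(X_2d)

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    panels = [(X_2d, "PCA (2 components)"), (X_2d_scaled, "PCA + StandardScaler")]
    for ax, (data, title) in zip(axes, panels):
        points = ax.scatter(data[:, 0], data[:, 1], c=y, cmap="tab10", s=10)
        ax.set_title(title)
        ax.set_xlabel("component 1")
        ax.set_ylabel("component 2")
    fig.colorbar(points, ax=axes, label="digit")
    return fig, X_2d, X_2d_scaled
```

Calling fig.savefig("pca_scaling.png") on the returned figure writes both panels to disk; the point clouds have the same shape, and only the axis ranges change.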


Training the Classifier

Now that all of our data is formatted, it's time to train and evaluate the classifier. We're going to use a K-Nearest Neighbors classifier, then use the confusion_matrix function to evaluate it. Note that this is a much more difficult problem than the one in the previous post. Instead of just 2 categories we have 10, so there are far more ways for the classifier to be wrong (random guessing would be right only 10% of the time instead of 50%).
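The sketch below puts the steps together and evaluates on held-out data. Several details are our assumptions, not the post's: k=5 neighbors, a stratified 75/25 train/test split, and bundling PCA, StandardScaler, and KNN in a scikit-learn pipeline so new images go through exactly the same transforms. For a self-contained demo we use scikit-learn's bundled 8x8 digits (64 features) in place of the 50x50 canvas images.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def train_and_evaluate(X, y, n_neighbors=5, random_state=0):
    """Fit PCA -> StandardScaler -> KNN on a training split; score the held-out split."""
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, stratify=y, random_state=random_state
    )
    model = make_pipeline(
        PCA(n_components=10, svd_solver="randomized", random_state=random_state),
        StandardScaler(),
        KNeighborsClassifier(n_neighbors=n_neighbors),  # k=5 is an assumption
    )
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    # Rows are true digits, columns are predicted digits
    cm = confusion_matrix(y_test, y_pred, labels=list(range(10)))
    accuracy = np.trace(cm) / cm.sum()  # the diagonal holds correct predictions
    return model, cm, accuracy


if __name__ == "__main__":
    # Stand-in data: scikit-learn's bundled 8x8 handwritten digits
    X, y = load_digits(return_X_y=True)
    model, cm, accuracy = train_and_evaluate(X, y)
    print(cm)
    print(f"accuracy: {accuracy:.3f}")
```

Large off-diagonal cells in the confusion matrix show which pairs of digits the classifier mixes up, which is more informative than a single accuracy number when there are 10 classes.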

Deploying to Yhat

Deploying to Yhat is pretty straightforward. Just extend the Yhat BaseModel class and copy the relevant code snippets into its transform and predict methods.
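We don't reproduce Yhat's actual base class or deployment call here. Instead, here is a plain-Python sketch of how the work might be split between transform and predict. The class name DigitModel is hypothetical, and so is the input format: we assume the browser sends the canvas as a base64 data URL (as produced by canvas.toDataURL("image/png")). The pipeline argument is any fitted object with a predict method, such as the PCA, StandardScaler, and KNN steps described above.

```python
import base64
import io

import numpy as np
from PIL import Image


class DigitModel:
    """Hypothetical stand-in for a Yhat BaseModel subclass (not Yhat's real API)."""

    def __init__(self, pipeline, size=(50, 50)):
        self.pipeline = pipeline  # fitted PCA -> StandardScaler -> KNN, for example
        self.size = size

    def transform(self, data_url):
        """Turn 'data:image/png;base64,...' into a 1x2500 feature row."""
        encoded = data_url.split(",", 1)[-1]  # drop the 'data:...;base64' prefix
        img = Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGBA")
        # Assumption: transparent canvas pixels should count as white paper
        background = Image.new("RGBA", img.size, (255, 255, 255, 255))
        img = Image.alpha_composite(background, img).convert("L").resize(self.size)
        return np.asarray(img, dtype=np.float64).reshape(1, -1)

    def predict(self, features):
        """Return the predicted digit as JSON-friendly data."""
        return {"digit": int(self.pipeline.predict(features)[0])}
```

Keeping the image decoding in transform means predict stays a thin wrapper around the trained pipeline, so it can be tested on its own.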

Integration with our node.js app

Now let's integrate our classifier back into our webapp. The first thing you should do is install the Yhat node module using npm and then import Yhat into your app.

$ npm install yhat

Making Predictions

Next we're going to add a button called Predict and some accompanying AJAX that sends images to our server to make predictions. The resulting output will be written to a table.

On the server side, we handle the incoming request inside app.post. We do nearly the same thing as before, except that instead of saving the image to the file system, we send it to our Yhat model to generate a prediction. We then send the resulting data back to the client as JSON.

Play with the App! (Standalone App)

That's it, we're done! As of the publication of this post, we had only about 1000 training samples, so the classifier still struggles with some numbers. As we add more data to the training set, we expect performance to improve.
