
Content-based image classification in Python

by yhat


Image recognition is a field concerned with identifying objects and entities within images. It's a sub-field of computer vision, a growing practice area broadly encompassing methods and strategies for analyzing digital images computationally rather than by eye.

While computer vision attracts attention from top tech firms (see Instagram's Unshredder challenge and this Facebook job post), its uses and applications are wide-ranging. Image classification is a classical image recognition problem in which the task is to assign labels to images based on their content or metadata.

This is a post about image classification using Python.

This stuff is useful in the real-world

Image classification has uses in lots of verticals, not just social networks. As a simple case-study, let's suppose we're a payroll company processing thousands of employment records each day.

Our customers rely on us for several key functions related to HR, record keeping, payroll, and benefits. To provide these services, customers send us their employee records (name, address, DOB, etc.) along with official documents like copies of gov't IDs, proof of residence, financials/direct deposit info, etc.

Customers provide this data to us by phone, email, the web, or even by fax. The first responsibility of our accounts team is to check these materials for completeness and accuracy.


Documents come in several file formats via several channels (e.g. email, website, etc.)

What's the simplest way to skin the cat?

As with most operational processes, there are very good, low-tech ways to organize and interpret these documents.

You could use a vendor specializing in outsourced document processing. Or, if you have a reliable way to protect sensitive information, you could use Amazon Mechanical Turk, though your compliance team would probably hate you.


Image classification is one of the most common tasks posted on Mechanical Turk

Many times, the manual solution is the right one. But what if your volume is so large that no manual method can keep up? Below, we'll explore how to use Python and scikit-learn to help your team classify documents more efficiently.

Classifying Images with Python

Define the Scope

We're going to write a script to predict whether an image is a check or a driver's license. Then we'll publish the script in a manner suitable for use within your team's software application.

Build a Training Set of Images

Let's start with the quick 'n dirty approach to collecting data: Google Search.

A brief non sequitur: Google Image search has gotten incredibly cool. The feature to filter results on image content is likely powered by some computer vision algorithms, though I don't know to what extent that's true. Check out this search for Vladimir Putin.


Amusing Google Image search for "Vladimir Putin" + badass

UPDATE: some details from Google Research on their blog here.

Search API

So, I searched for "check" and "drivers license" and got a sense for what's out there. Then I wrote a script to make image searching faster and repeatable.

I used the Bing Image API for this, as I've found it to be more generous than Google's w/r/t rate limiting (i.e. so I don't get blocked/banned by Google).

The program grabs the images on the first page of the search results and downloads them. I ran the script 5-6 times for each category with a different search query each time, both for simplicity and to avoid dealing with pagination.
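Here's a rough sketch of what that script might look like. The endpoint, auth scheme, and response fields below reflect the legacy Azure Datamarket flavor of the Bing Image API and are assumptions rather than the exact code used for this post:

import os
import requests

API_URL = "https://api.datamarket.azure.com/Bing/Search/Image"  # placeholder endpoint
API_KEY = "YOUR_BING_API_KEY"

def fetch_image_urls(query):
    # first page of image results; parameter names vary by API version
    resp = requests.get(API_URL,
                        params={"Query": "'%s'" % query, "$format": "json"},
                        auth=("", API_KEY))
    resp.raise_for_status()
    return [r["MediaUrl"] for r in resp.json()["d"]["results"]]

def download_images(query, label, out_dir="images"):
    folder = os.path.join(out_dir, label)
    if not os.path.exists(folder):
        os.makedirs(folder)
    for i, url in enumerate(fetch_image_urls(query)):
        try:
            img = requests.get(url, timeout=10)
            with open(os.path.join(folder, "%s_%03d.jpg" % (label, i)), "wb") as f:
                f.write(img.content)
        except requests.RequestException:
            pass  # skip images that fail to download

download_images("drivers license photo", "license")
download_images("personal check", "check")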

Transforming the images

So now we have a bunch of images with labels (labels are implied by the search query I used to find them on Bing). Next we need to put these images into a format suitable for analysis.

One way to represent images in a numerical format is to convert each image to a series of RGB pixels. For example, a 300x150 px image becomes a series of 45,000 RGB pixels, or 135,000 raw values once each pixel's three color channels are unpacked. You could also explore other numerical measures like hue and saturation, but let's keep it simple for now.
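With PIL, getting at those pixels is straightforward (the file path here is a placeholder):

from PIL import Image

img = Image.open("images/check/check_001.jpg").convert("RGB")
pixels = list(img.getdata())  # one (R, G, B) tuple per pixel
print(len(pixels))            # 45,000 for a 300x150 image
print(pixels[0])              # e.g. (255, 248, 240)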

Normalize the data

We need to normalize the images so they're all the same shape. To do that, we resize each observation to 300x167 px. I chose 300x167 because most of the images were roughly that size or slightly larger to begin with.

Once each image is resized to 300x167, we flatten it using the flatten_image function, so each image is represented as a single row in a 2D array.

We can run this script on our N images and wind up with an N x 150,300 numpy array (300 x 167 pixels x 3 color channels per image).
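A minimal version of that pipeline, reconstructed from the names mentioned above (image_files is an assumed list of paths to the downloaded images):

import numpy as np
from PIL import Image

STANDARD_SIZE = (300, 167)

def img_to_matrix(filename):
    # resize to a common shape, then pull out the raw RGB values
    img = Image.open(filename).convert("RGB").resize(STANDARD_SIZE)
    return np.array(img.getdata())  # shape: (50100, 3)

def flatten_image(img):
    # collapse the (pixels x channels) matrix into one long row
    return img.reshape(img.shape[0] * img.shape[1])

data = np.array([flatten_image(img_to_matrix(f)) for f in image_files])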

Creating features

150,300 features is a lot for many algorithms to deal with, so we need to reduce the number of dimensions somehow. For this we can use an unsupervised learning technique called Randomized PCA to derive a smaller number of features from the raw pixel data.

I won't go into great detail on what exactly Randomized PCA is doing, but the gist is that we tell the function how many "components" we want (i.e. how many features/columns) and it finds correlations in our data and creates that number of feature columns while preserving the original "look and feel" of the dataset.

As an example, let's transform the dataset into just 2 components which we can easily plot in 2 dimensions.
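As a sketch (newer scikit-learn versions fold RandomizedPCA into PCA via the svd_solver argument, while older ones had a separate RandomizedPCA class; data and labels come from the steps above):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

pca = PCA(n_components=2, svd_solver="randomized")
X = pca.fit_transform(data)

labels = np.array(labels)  # e.g. "check" / "license" for each image
for label, color in [("check", "red"), ("license", "blue")]:
    mask = labels == label
    plt.scatter(X[mask, 0], X[mask, 1], c=color, label=label)
plt.legend()
plt.show()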


Reducing the data to 2 features preserves the separation between the classes and lets us visualize it.

Training a Classifier

Since the goal of this exercise is to be able to classify new images without requiring visual inspection by a human, we're going to need to train a predictive model that can tag images automatically.

Based on our previous plot, it looks like there's a strong split between checks and driver's licenses, so we should be able to fit a good classifier. For this, I chose K-Nearest Neighbors, though there's more than one suitable learning method for the job.

We fit Randomized PCA on a training set and then apply the same transformation to the test set. We then train a KNeighborsClassifier on the transformed PCA features.
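In code, that might look something like this (train_test_split lives in sklearn.model_selection in current scikit-learn; older releases kept it in sklearn.cross_validation):

from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier

X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2)

pca = PCA(n_components=5, svd_solver="randomized")
X_train_pca = pca.fit_transform(X_train)  # fit PCA on training data only
X_test_pca = pca.transform(X_test)        # reuse the same transformation

knn = KNeighborsClassifier()
knn.fit(X_train_pca, y_train)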

Evaluating the model

When we make predictions on our holdout set, we find that the classifier performs well.


The holdout set is pretty small, but we're happy that only one observation was incorrectly labeled by our classifier.
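For reference, the numbers behind that table can be pulled with scikit-learn's metrics, using the holdout split from the previous step:

from sklearn.metrics import accuracy_score, confusion_matrix

preds = knn.predict(X_test_pca)
print(accuracy_score(y_test, preds))
print(confusion_matrix(y_test, preds))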

Running your script in production

Let's pause and summarize what we've done. So far we've:

  • Programmatically sourced a bunch of image data and corresponding labels
  • Resized all images to a common width and height
  • Converted the images from unstructured data into numerical features based on their RGB pixel makeup
  • Reduced the high-dimensional data to 5 components
  • Fitted a KNN classifier to the reduced data

This all seems pretty cool, but how could this be useful to an operations team at our theoretical payroll company?

Performing these steps in a production software application would involve translating, porting, and adapting lots of code: image pre-processing, Randomized PCA, and implementing the nearest neighbors algorithm.

You can eliminate the overhead involved with custom coding on that scale using Yhat.

Deploying with Yhat

First, if you haven't already, install the Yhat Python library and then import it.

$ pip install yhat
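Then, in Python (the exact import surface may differ slightly between versions of the yhat package; this line is an assumption):

from yhat import Yhat, BaseModel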

Define a sub-class of Yhat.BaseModel. Extending the BaseModel class gives your code a few extra properties that allow you to deploy to Yhat. You'll define 2 mandatory methods and 1 optional method:

  1. require (this is optional)
  2. transform
  3. predict

Operations that need to be performed on the raw data (in our case, a raw image file) go into the transform method.

The predict method takes the output from transform and executes classification and any post-processing you require. For example, you could have Yhat return a message or notes to be displayed to your payroll team.

You probably remember some intermediate variables/utilities we used earlier: pca, knn, and STANDARD_SIZE. We pass those to our class to give Yhat access to them during the transform and predict steps.

Since we're using the exact same objects from earlier, results in our production environment will match results on our local machine one-to-one.
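Putting that together, a sketch of the class might look like the following. This is reconstructed from the description above rather than copied from the original notebook, and the convention of passing pca, knn, and STANDARD_SIZE as keyword arguments that become attributes is an assumption about the yhat client:

import numpy as np
from yhat import BaseModel

class ImageClassifier(BaseModel):
    def require(self):
        # optional: declare imports the Yhat server will need
        from PIL import Image

    def transform(self, raw_image):
        # raw image file -> the same features the model was trained on
        img = raw_image.convert("RGB").resize(self.STANDARD_SIZE)
        flat = np.array(img.getdata()).reshape(-1)
        return self.pca.transform([flat])

    def predict(self, features):
        # classify, plus a note to display to the payroll team
        label = self.knn.predict(features)[0]
        return {"label": label, "note": "auto-classified; please spot-check"}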

Create an instance of your class and pass it to yhat.upload along with a name.
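For example (the username, API key, and model name are placeholders):

from yhat import Yhat

clf = ImageClassifier(pca=pca, knn=knn, STANDARD_SIZE=STANDARD_SIZE)
yh = Yhat("your-username", "your-api-key")
yh.upload("imageClassifier", clf)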

That's it! Our classifier is now deployed to Yhat and can be used to make predictions in real-time from your other software applications.

You can call it using the Yhat Python client like so:
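A minimal sketch, assuming a predict(model_name, data) call along these lines (in practice you'd likely serialize the image, e.g. base64-encode it, before sending it over the wire):

from yhat import Yhat

yh = Yhat("your-username", "your-api-key")
with open("new_document.jpg", "rb") as f:
    result = yh.predict("imageClassifier", f.read())
print(result)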

Or make predictions from any other programming language via REST calls (check out our Node.js client here or just npm install yhat).

Final Thoughts

For reference, here's the IPython Notebook for this post. This is just scratching the surface of image processing in Python. For more info check out these resources:
