Intuitive Classification using KNN and Python

by yhat


K-nearest neighbors, or KNN, is a supervised learning algorithm for either classification or regression. It's super intuitive and has been applied to many types of problems.

It's great for many applications, with personalization tasks being among the most common. To make a personalized offer to one customer, you might employ KNN to find similar customers and base your offer on their purchase behaviors. KNN has also been applied to medical diagnosis and credit scoring.

This is a post about the K-nearest neighbors algorithm and Python.

What is K-nearest neighbors?

Conceptually, KNN is very simple. Given a dataset for which class labels are known, you want to predict the class of a new data point.

The strategy is to compare the new observation to those observations already labeled. The predicted class will be based on the known classes of the nearest k neighbors (i.e. based on the class labels of the other data points most similar to the one you're trying to predict).
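To make the idea concrete, here is a minimal from-scratch sketch in plain Python. It assumes Euclidean distance and a simple majority vote. The function name `knn_predict` and the sample points are hypothetical, made up for illustration; this is not code from the original post.

```python
import math
from collections import Counter


def knn_predict(train_points, train_labels, new_point, k=3):
    """Predict the label of new_point by majority vote of its k nearest labeled points."""
    # Euclidean distance from the new point to every labeled point
    distances = [
        (math.dist(point, new_point), label)
        for point, label in zip(train_points, train_labels)
    ]
    # Keep only the k closest labeled points
    nearest = sorted(distances, key=lambda pair: pair[0])[:k]
    # Majority vote among their labels
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


# Hypothetical labeled widgets: (x, y) coordinates and their colors
points = [(1.0, 1.0), (1.5, 2.0), (3.0, 4.0), (5.0, 7.0), (3.5, 5.0)]
labels = ["Blue", "Blue", "Red", "Red", "Red"]

print(knn_predict(points, labels, (1.2, 1.4), k=3))  # two Blue votes, one Red
```

`math.dist` requires Python 3.8 or later.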

An example

Imagine a bunch of widgets which belong to either the "Blue" or the "Red" class. Each widget has 2 variables associated with it: x and y. We can plot our widgets in 2D.

Now let's say I get another widget that also has x and y variables. How can we tell whether this widget is blue or red?
New observation represented by the black dot.

The KNN approach to classification calls for comparing this new point to the other nearby points. If we were using KNN with 3 neighbors, we'd grab the 3 nearest dots to our black dot and look at their colors. The nearest dots would then "vote", and the majority color is the one we'd assign to our new black dot.

Prediction with 5 Neighbors

If we had chosen 5 neighbors instead of 3 neighbors, things would have turned out differently. Looking at the plot below, we can see that the vote would tally blue: 3, red: 2. So we would classify our new dot as blue.
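The same flip can be reproduced with scikit-learn's `KNeighborsClassifier`. The coordinates below are hypothetical, chosen so that the 3 nearest neighbors of the new point are red, red, blue and the 5 nearest tally blue: 3, red: 2, mirroring the vote described above. They are not the data behind the original plot.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical widgets: (x, y) coordinates and their known colors
X = np.array([[0.5, 0.0], [0.0, 0.6], [-0.7, 0.0], [0.0, -0.9],
              [1.0, 0.0], [3.0, 3.0], [-3.0, -3.0]])
y = np.array(["Red", "Red", "Blue", "Blue", "Blue", "Red", "Blue"])

new_widget = np.array([[0.0, 0.0]])  # the "black dot"

for k in (3, 5):
    clf = KNeighborsClassifier(n_neighbors=k).fit(X, y)
    print(k, clf.predict(new_widget)[0])  # k=3 -> Red, k=5 -> Blue
```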

Choosing the right value of k

KNN requires us to specify a value for k. Logically, the next question is "how do we choose k?" It turns out that choosing the right number of neighbors matters. A lot.

But choosing the right k is as challenging as it is important.

...choosing the best value for k is "not easy but laborious"

- Phillips and Lee, Mining Positive Associations of Urban Criminal Activities

Generally it's good to try an odd number for k to start out. This helps avoid situations where your classifier "ties" as a result of having the same number of votes for two different classes. This is particularly true if your dataset has only two classes (e.g. if k=4 and an observation has nearest neighbors ['blue', 'blue', 'red', 'red'], you've got a tie on your hands).
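A tiny vote-counting helper shows the tie concretely. The name `vote` and its choice to return `None` on a tie are hypothetical, for illustration only. Note that an odd k only guarantees no ties when there are exactly two classes: with three classes, k=3 can still split one vote each.

```python
from collections import Counter


def vote(neighbor_labels):
    """Return the majority label, or None if the top two classes are tied."""
    counts = Counter(neighbor_labels).most_common()
    if len(counts) > 1 and counts[0][1] == counts[1][1]:
        return None  # tie: no single majority class
    return counts[0][0]


print(vote(["blue", "blue", "red", "red"]))          # k=4, two classes -> None (tie)
print(vote(["blue", "blue", "red", "red", "blue"]))  # k=5, two classes -> blue
print(vote(["red", "green", "blue"]))                # k=3, three classes -> None (tie)
```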


The good news is that scikit-learn does a lot to help you find the best value for k.

Choose the right k example

Let's take a look at the Wine Data Set from the UCI Machine Learning Repository. Each record describes a particular wine, including its color (red or white). We're going to use density, sulphates, and residual_sugar to predict color.

Using scikit-learn, we can vary the parameter n_neighbors by just looping through a range of values, calculating the accuracy against a holdout set, and then plotting the results.
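Here is a sketch of that loop. Since the wine data isn't bundled here, a synthetic dataset from `make_classification` stands in for it (3 numeric features, 2 classes). With the real data you would swap in the density, sulphates, and residual_sugar columns for `X` and the color column for `y`. The 70/30 split, the range of k, and the random seeds are assumptions, not the post's settings.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in for the wine data: 3 numeric features, 2 classes (red/white)
X, y = make_classification(n_samples=1000, n_features=3, n_informative=3,
                           n_redundant=0, n_classes=2, random_state=0)

# Hold out 30% of the rows to measure accuracy on data the model hasn't seen
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0)

results = []
for k in range(1, 51, 2):  # odd values of k only, to limit ties
    clf = KNeighborsClassifier(n_neighbors=k)
    clf.fit(X_train, y_train)
    accuracy = clf.score(X_test, y_test)  # fraction of holdout rows predicted correctly
    results.append((k, accuracy))

best_k, best_accuracy = max(results, key=lambda pair: pair[1])
print(f"best k = {best_k}, holdout accuracy = {best_accuracy:.3f}")

# To plot accuracy against k, e.g. with matplotlib:
# import matplotlib.pyplot as plt
# plt.plot(*zip(*results)); plt.xlabel("k"); plt.ylabel("accuracy"); plt.show()
```

Because KNN relies on raw distances, features on very different scales (residual sugar versus density, for instance) can dominate the vote, so standardizing them first (e.g. with scikit-learn's `StandardScaler`) is often worth trying.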

Looking at the plot, you can see that the classifier's accuracy peaks at around 23 neighbors and then gradually deteriorates. Although the classifier maxes out at K=23, that doesn't necessarily mean you should select 23 neighbors.

Take a look at K=13. With 13 neighbors, the classifier performs almost as well as it does with 23. It's also interesting to note that K values 15-21 actually perform worse than K=13. That dip suggests the accuracy curve is noisy, and that picking K=23 on the strength of its peak risks overfitting, i.e. paying too much attention to the noise in this particular holdout set. Given both the comparable performance and the risk of overfitting, I would select 13 neighbors for this classifier.

Remember: Simpler is often better!

Calculating the Neighbors

Just like there are many options for K, you can also change how the neighbors' votes are weighted and how the distance between points is calculated. The most natural approach (and the one I used in the red/blue dots example) is to weight each of the K nearest points equally in the voting. This might be egalitarian, but in many instances it doesn't really make sense. Why should the 13th closest point get an equal vote to the nearest one?

Using scikit-learn, you can change how votes are weighted with the weights parameter. Two weighting options come out of the box: 'uniform', which weights every neighbor equally, and 'distance', which weights each neighbor by the inverse of its distance, so closer neighbors count more.
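A small hypothetical example shows how the two options can disagree: one red point sits very close to the query while two blue points sit much farther away.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical training points: one red point very close to the query,
# two blue points much farther away
X = np.array([[0.1, 0.0], [2.0, 0.0], [0.0, 2.1]])
y = np.array(["red", "blue", "blue"])
query = np.array([[0.0, 0.0]])

for weights in ("uniform", "distance"):
    clf = KNeighborsClassifier(n_neighbors=3, weights=weights).fit(X, y)
    # predict_proba columns follow clf.classes_, i.e. ['blue', 'red']
    print(weights, clf.predict(query)[0], clf.predict_proba(query)[0])
```

With uniform weights blue wins two votes to one. With distance weights the red point's weight of 1/0.1 = 10 outweighs the combined blue weight of 1/2 + 1/2.1 (about 0.98), so the prediction flips to red.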

You can also create your own distance function. In the example below I created a distance measure that just takes the log of the measures. When I run it through the same script with 3 neighbors, I actually get the best performance using my custom function.
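The original custom function isn't reproduced here. As one hypothetical way to plug in your own distance function, scikit-learn's `metric` parameter accepts a Python callable that takes two 1-D arrays and returns a single distance. The `log_euclidean` metric below (Euclidean distance after a `log1p` transform) and the data are assumptions for illustration only. Callable metrics are evaluated in Python, so expect them to be much slower than the built-in ones on large datasets.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


def log_euclidean(a, b):
    """Hypothetical metric: Euclidean distance between log1p-transformed vectors.

    Assumes all feature values are non-negative.
    """
    return float(np.sqrt(np.sum((np.log1p(a) - np.log1p(b)) ** 2)))


# Hypothetical positive-valued features on very different scales
X = np.array([[1.0, 10.0], [1.2, 12.0], [8.0, 300.0], [9.0, 280.0]])
y = np.array(["white", "white", "red", "red"])

clf = KNeighborsClassifier(n_neighbors=3, metric=log_euclidean).fit(X, y)
print(clf.predict(np.array([[1.1, 11.0], [8.5, 290.0]])))
```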

Other Levers

Of course, there are other levers, or hyper-parameters, you can pull that will affect the results of your classifier. But above all, the most critical factor in determining whether your classifier will work is the features you put into it. Great features will compensate for lousy parameter tuning, but it's much more difficult to compensate for poor features by optimizing your classifier's parameters.

So just remember, this isn't magic. Garbage in, garbage out.

A Quick Example

One use case for K-nearest neighbors is in satellite and topographic imaging. Often some of the pixels in an image are randomly distorted and you wind up with missing data. To fix this, you can train a KNN classifier on the intact pixels and then use the classifier to fill in the missing values. Unfortunately I don't have my own personal satellite, but what I do have is a smiley.

Making a smiley

To make a smiley face, we're going to use numpy, pandas, matplotlib, and middle school math. Remember way back when your 7th grade teacher made you learn how to plot circles and ellipses, and you thought to yourself, "When am I ever going to use this?"

Well today is that day. Feel free to dust off your TI-89 and customize your smiley.


You can see the equation for a circle in our code below (see, Mrs. Lynch was right).
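As a stand-in for the original code, here is a minimal numpy-only sketch that builds a labeled smiley on a pixel grid. It uses the circle equation (x - cx)^2 + (y - cy)^2 <= r^2 for the face and mouth and the ellipse equation ((x - cx)/a)^2 + ((y - cy)/b)^2 <= 1 for the eyes. The grid size, radii, positions, and the 0/1/2 label scheme are all hypothetical choices.

```python
import numpy as np

# Grid of pixel coordinates covering [-1, 1] x [-1, 1]
n = 100
xs, ys = np.meshgrid(np.linspace(-1, 1, n), np.linspace(-1, 1, n))


def in_circle(x, y, cx, cy, r):
    """True where (x - cx)^2 + (y - cy)^2 <= r^2."""
    return (x - cx) ** 2 + (y - cy) ** 2 <= r ** 2


def in_ellipse(x, y, cx, cy, a, b):
    """True where ((x - cx) / a)^2 + ((y - cy) / b)^2 <= 1."""
    return ((x - cx) / a) ** 2 + ((y - cy) / b) ** 2 <= 1


face = in_circle(xs, ys, 0.0, 0.0, 0.9)
eyes = in_ellipse(xs, ys, -0.35, 0.3, 0.1, 0.18) | in_ellipse(xs, ys, 0.35, 0.3, 0.1, 0.18)
# Mouth: the lower part of a ring between two concentric circles
mouth = in_circle(xs, ys, 0.0, 0.0, 0.6) & ~in_circle(xs, ys, 0.0, 0.0, 0.5) & (ys < -0.1)

# Label every pixel: 0 = background, 1 = face, 2 = eyes and mouth
image = np.zeros((n, n), dtype=int)
image[face] = 1
image[eyes | mouth] = 2

# To view it, e.g. with matplotlib:
# import matplotlib.pyplot as plt
# plt.imshow(image, origin="lower"); plt.show()
```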

What does this have to do with k-nearest neighbors?

Suppose we have a smiley (or any image for that matter) that's missing some parts or is otherwise damaged. Can we use KNN to "fix" it?

To illustrate, we'll need to simulate a broken version of our smiley image by randomly removing half of the pixels.
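A sketch of the damage step, assuming the image is a 2-D numpy array of integer labels. A simple filled circle stands in for the smiley so the snippet runs on its own, and marking lost pixels with -1 is a hypothetical convention, not necessarily the post's.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical stand-in image: a filled circle (1) on a background (0)
n = 100
xs, ys = np.meshgrid(np.linspace(-1, 1, n), np.linspace(-1, 1, n))
image = ((xs ** 2 + ys ** 2) <= 0.8 ** 2).astype(int)

# Randomly choose exactly half of the pixels to "lose"
flat_idx = rng.choice(image.size, size=image.size // 2, replace=False)
missing = np.zeros(image.size, dtype=bool)
missing[flat_idx] = True
missing = missing.reshape(image.shape)

# Mark lost pixels with -1 so they're easy to tell apart from real labels
broken = image.copy()
broken[missing] = -1
```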

Looks like a cross between a smiley and a scene from White Noise (ouch... only 9% on the Tomatometer).

Fixing a broken image using KNN

Now for the fun part. We're going to use the remaining data to train a classifier. Again, we're going to bust out KNeighborsClassifier from scikit-learn. We then classify the missing points and recreate the image.
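Here is a self-contained sketch of that repair step. The pixel coordinates serve as features and the pixel labels as the target. A ring stands in for the smiley, roughly half the pixels are dropped at random, and `n_neighbors=3` is an assumed setting rather than the post's exact one.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)

# Hypothetical stand-in image: a ring (label 1) on a background (label 0)
n = 100
xs, ys = np.meshgrid(np.linspace(-1, 1, n), np.linspace(-1, 1, n))
r2 = xs ** 2 + ys ** 2
image = ((r2 <= 0.8 ** 2) & (r2 >= 0.5 ** 2)).astype(int)

# Simulate damage: drop roughly half of the pixels at random
missing = rng.random(image.shape) < 0.5

# Features are pixel coordinates; the target is each pixel's label
coords = np.column_stack([xs.ravel(), ys.ravel()])
labels = image.ravel()
known = ~missing.ravel()

# Train on the surviving pixels, then predict a label for each missing one
clf = KNeighborsClassifier(n_neighbors=3)
clf.fit(coords[known], labels[known])

repaired = labels.copy()
repaired[~known] = clf.predict(coords[~known])
repaired = repaired.reshape(image.shape)

accuracy = (repaired[missing] == image[missing]).mean()
print(f"recovered {accuracy:.1%} of the missing pixels correctly")
```

Only pixels near the ring's edges are likely to be misclassified, since every neighbor of an interior pixel shares its label.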

Looks pretty good, right? This technique is part of a larger area of statistics and machine learning called imputation (but that's a topic for another day).

