Note: This article was originally published on Oct 10, 2014 and updated on Mar 27th, 2018
- Understand k nearest neighbor (KNN) – one of the most popular machine learning algorithms
- Learn how kNN works in Python
- Choose the right value of k in simple terms
In the four years of my data science career, more than 80% of the models I have built have been classification models, and just 15-20% have been regression models. These ratios hold more or less across the industry. The reason for this bias towards classification models is that most analytical problems involve making a decision.
For instance: will a customer attrite or not, should we target customer X for digital campaigns, does a customer have high potential or not, and so on. These analyses are more insightful and directly linked to an implementation roadmap.
In this article, we will talk about another widely used machine learning classification technique called K-nearest neighbors (KNN). Our focus will be primarily on how the algorithm works and how the input parameter affects the output/prediction.
Table of Contents
- When do we use KNN algorithm?
- How does the KNN algorithm work?
- How do we choose the factor K?
- Breaking it Down – Pseudo Code of KNN
- Implementation in Python from scratch
- Comparing our model with scikit-learn
When do we use KNN algorithm?
KNN can be used for both classification and regression predictive problems. However, it is more widely used in classification problems in the industry. To evaluate any technique we generally look at 3 important aspects:
1. Ease of interpreting the output
2. Calculation time
3. Predictive power
Let us take a few examples to place KNN on the scale:
How does the KNN algorithm work?
Let’s take a simple case to understand this algorithm. Below is a spread of red circles (RC) and green squares (GS):
You intend to find out the class of the blue star (BS). BS can either be RC or GS and nothing else. The “K” in the KNN algorithm is the number of nearest neighbors we wish to take a vote from. Let’s say K = 3. Hence, we will now draw a circle with BS as its center, just big enough to enclose only three data points on the plane. Refer to the following diagram for more details:
The three closest points to BS are all RC. Hence, with a good confidence level, we can say that BS should belong to the class RC. Here, the choice became very obvious, as all three votes from the closest neighbors went to RC. The choice of the parameter K is crucial in this algorithm. Next, we will look at the factors to consider when choosing the best K.
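To make the voting step concrete, here is a minimal sketch of the K = 3 case. The coordinates are hypothetical (the original figure does not give any), chosen only so that the three points closest to the blue star are red circles.

```python
import math
from collections import Counter

# Hypothetical coordinates: three red circles (RC) and three green squares (GS).
training_points = [
    ((1.0, 1.0), "RC"),
    ((1.5, 2.0), "RC"),
    ((2.0, 1.5), "RC"),
    ((5.0, 5.0), "GS"),
    ((5.5, 4.0), "GS"),
    ((6.0, 5.5), "GS"),
]
blue_star = (2.0, 2.0)  # the point whose class we want to find
k = 3

# Rank the training points by their straight-line distance to the blue star.
ranked = sorted(training_points, key=lambda item: math.dist(item[0], blue_star))

# Keep the labels of the k closest points and count the votes.
votes = Counter(label for _, label in ranked[:k])
predicted_class, vote_count = votes.most_common(1)[0]

print(predicted_class, f"{vote_count}/{k} votes")
```

The share of votes won by the predicted class (here 3 out of 3) is a rough indication of how confident the prediction is.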
How do we choose the factor K?
First, let us try to understand exactly what K influences in the algorithm. In the last example, given that all 6 training observations remain constant, a given K value lets us draw the boundaries of each class. These boundaries segregate RC from GS. In the same way, let’s look at the effect of the value of K on the class boundaries. Below are the boundaries separating the two classes for different values of K.
If you look carefully, you can see that the boundary becomes smoother as K increases. As K grows towards the total number of training points, the plane finally becomes all blue or all red, depending on the overall majority. The training error rate and the validation error rate are two parameters we need to assess for different values of K. Below is the curve for the training error rate with varying values of K:
As you can see, the error rate at K=1 is always zero for the training sample. This is because the closest point to any training data point is itself. Hence the prediction is always accurate with K=1. If the validation error curve had been similar, our choice of K would have been 1. Below is the validation error curve with varying values of K:
This makes the story clearer. At K=1, we were overfitting the boundaries. Hence, the error rate initially decreases and reaches a minimum. After the minimum, it increases with increasing K. To get the optimal value of K, you can split the initial dataset into training and validation sets. Now plot the validation error curve to get the optimal value of K. This value of K should be used for all predictions.
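The procedure described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the data is synthetic (two overlapping Gaussian clouds generated with a fixed seed), and the helper names `knn_predict` and `error_rate` are made up for this example. The numbers it prints are not the curves from the original figures.

```python
import math
import random
from collections import Counter

def knn_predict(train, point, k):
    """Majority vote among the k training rows closest to `point`."""
    nearest = sorted(train, key=lambda row: math.dist(row[0], point))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]

def error_rate(train, rows, k):
    """Fraction of `rows` whose predicted label differs from the true label."""
    wrong = sum(knn_predict(train, features, k) != label for features, label in rows)
    return wrong / len(rows)

# Synthetic data (an assumption for illustration): two overlapping classes.
rng = random.Random(0)
data = [((rng.gauss(0, 1), rng.gauss(0, 1)), "RC") for _ in range(60)]
data += [((rng.gauss(1.5, 1), rng.gauss(1.5, 1)), "GS") for _ in range(60)]
rng.shuffle(data)

# Hold out 30% of the rows for validation.
split = int(0.7 * len(data))
train, validation = data[:split], data[split:]

k_values = range(1, 32, 2)  # odd values of K avoid tied votes with two classes
train_errors = {k: error_rate(train, train, k) for k in k_values}
validation_errors = {k: error_rate(train, validation, k) for k in k_values}

# The K we would pick is the one with the lowest validation error.
best_k = min(k_values, key=lambda k: validation_errors[k])
for k in k_values:
    print(f"K={k:2d}  train error={train_errors[k]:.3f}  "
          f"validation error={validation_errors[k]:.3f}")
print("K with the lowest validation error:", best_k)
```

The training error at K=1 comes out as zero because every training point is its own nearest neighbor, which is exactly why the validation error, not the training error, should drive the choice of K.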
Breaking it Down – Pseudo Code of KNN
We can implement a KNN model by following the steps below:
- Load the data
- Initialise the value of k
- To get the predicted class, iterate from 1 to the total number of training data points
- Calculate the distance between the test data and each row of training data. Here we will use Euclidean distance as our distance metric since it’s the most popular method. Other metrics that can be used are Chebyshev, cosine, etc.
- Sort the calculated distances in ascending order based on distance values
- Get top k rows from the sorted array
- Get the most frequent class of these rows
- Return the predicted class
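The distance step above names three metrics. As a rough sketch, here is how each could be written in plain Python. The function names are our own, and the cosine version assumes neither vector is all zeros (it would divide by zero otherwise).

```python
import math

def euclidean_distance(a, b):
    """Straight-line distance: square root of the summed squared differences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def chebyshev_distance(a, b):
    """Largest absolute difference along any single feature."""
    return max(abs(x - y) for x, y in zip(a, b))

def cosine_distance(a, b):
    """One minus the cosine of the angle between the two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)

# Two illustrative vectors that point in the same direction but differ in length.
a = (3.0, 4.0)
b = (6.0, 8.0)

print("Euclidean:", euclidean_distance(a, b))
print("Chebyshev:", chebyshev_distance(a, b))
print("Cosine:   ", cosine_distance(a, b))
```

The example vectors show why the choice matters: Euclidean and Chebyshev see the two points as far apart, while cosine distance treats them as identical because it only looks at direction, not magnitude.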
Implementation in Python from scratch
We will be using the popular Iris dataset for building our KNN model. You can download it from here.
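What follows is a minimal sketch of the pseudo-code steps in plain Python, not the article's original code. To keep it self-contained it uses a nine-row excerpt of the Iris data (three rows per class) instead of the full 150-row file, and the test flower is a hypothetical point picked for illustration. The neighbor indices it returns therefore refer to this excerpt, not to rows of the full dataset.

```python
import math
from collections import Counter

# Nine-row excerpt of the Iris data: (SepalLength, SepalWidth, PetalLength, PetalWidth), Name
training_data = [
    ((5.1, 3.5, 1.4, 0.2), "Iris-setosa"),
    ((4.9, 3.0, 1.4, 0.2), "Iris-setosa"),
    ((4.7, 3.2, 1.3, 0.2), "Iris-setosa"),
    ((7.0, 3.2, 4.7, 1.4), "Iris-versicolor"),
    ((6.4, 3.2, 4.5, 1.5), "Iris-versicolor"),
    ((6.9, 3.1, 4.9, 1.5), "Iris-versicolor"),
    ((6.3, 3.3, 6.0, 2.5), "Iris-virginica"),
    ((5.8, 2.7, 5.1, 1.9), "Iris-virginica"),
    ((7.1, 3.0, 5.9, 2.1), "Iris-virginica"),
]

def euclidean_distance(a, b):
    """Square root of the summed squared feature differences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def knn(training_data, test_point, k):
    """Return the predicted class and the indices of the k nearest training rows."""
    # Distance from the test point to every training row, paired with the row index.
    distances = [
        (euclidean_distance(features, test_point), index)
        for index, (features, _) in enumerate(training_data)
    ]
    # Sort in ascending order of distance and keep the k closest rows.
    distances.sort()
    neighbors = [index for _, index in distances[:k]]
    # The most frequent class among those rows is the prediction.
    votes = Counter(training_data[index][1] for index in neighbors)
    return votes.most_common(1)[0][0], neighbors

test_point = (6.5, 3.0, 5.8, 2.2)  # hypothetical test flower
predicted_class, neighbors = knn(training_data, test_point, k=3)

print(predicted_class)
print(neighbors)
```

With this excerpt, the three nearest rows are all Iris-virginica, so the vote is unanimous. On the full dataset you would load the rows from the downloaded file rather than typing them in.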
Comparing our model with scikit-learn
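```python
from sklearn.neighbors import KNeighborsClassifier

# `data` and `test` are the training DataFrame and test row from the from-scratch section
neigh = KNeighborsClassifier(n_neighbors=3)
neigh.fit(data.iloc[:, 0:4], data['Name'])

# Predicted class
print(neigh.predict(test))
# -> ['Iris-virginica']

# Indices of the 3 nearest neighbors (kneighbors returns distances first, then indices)
print(neigh.kneighbors(test)[1])
# -> [[141 139 120]]
```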
We can see that both models predicted the same class (‘Iris-virginica’) and the same nearest neighbors ([141 139 120]). Hence we can conclude that our model runs as expected.
Implementation of kNN in R
Step 1: Importing the data
Step 2: Checking the data and calculating the data summary
```
# Top observations present in the data
  SepalLength SepalWidth PetalLength PetalWidth        Name
1         5.1        3.5         1.4        0.2 Iris-setosa
2         4.9        3.0         1.4        0.2 Iris-setosa
3         4.7        3.2         1.3        0.2 Iris-setosa
4         4.6        3.1         1.5        0.2 Iris-setosa
5         5.0        3.6         1.4        0.2 Iris-setosa
6         5.4        3.9         1.7        0.4 Iris-setosa

# Check the dimensions of the data
[1] 150   5

# Summarise the data
  SepalLength      SepalWidth      PetalLength      PetalWidth                 Name
 Min.   :4.300   Min.   :2.000   Min.   :1.000   Min.   :0.100   Iris-setosa    :50
 1st Qu.:5.100   1st Qu.:2.800   1st Qu.:1.600   1st Qu.:0.300   Iris-versicolor:50
 Median :5.800   Median :3.000   Median :4.350   Median :1.300   Iris-virginica :50
 Mean   :5.843   Mean   :3.054   Mean   :3.759   Mean   :1.199
 3rd Qu.:6.400   3rd Qu.:3.300   3rd Qu.:5.100   3rd Qu.:1.800
 Max.   :7.900   Max.   :4.400   Max.   :6.900   Max.   :2.500
```
Step 3: Splitting the Data
Step 4: Calculating the Euclidean Distance
Step 5: Writing the function to predict kNN
Step 6: Calculating the label (Name) for K=1
For K=1  "Iris-virginica"
In the same way, you can compute for other values of K.
Comparing our kNN predictor function with the “class” library
For K=1  "Iris-virginica"
We can see that both models predicted the same class (‘Iris-virginica’).
The KNN algorithm is one of the simplest classification algorithms. Even with such simplicity, it can give highly competitive results. The KNN algorithm can also be used for regression problems. The only difference from the methodology discussed above is that it uses the average of the nearest neighbors rather than a vote among them. KNN can be coded in a single line in R. I am yet to explore how we can use the KNN algorithm in SAS.
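To illustrate the regression variant mentioned above, here is a small sketch that averages the targets of the nearest neighbors instead of counting votes. The function name `knn_regress` and the one-feature data are hypothetical, made up for this example.

```python
import math

def knn_regress(training_data, test_point, k):
    """Average the targets of the k training rows closest to `test_point`."""
    nearest = sorted(training_data, key=lambda row: math.dist(row[0], test_point))[:k]
    return sum(target for _, target in nearest) / k

# Hypothetical one-feature data: (features, numeric target).
training_data = [
    ((1.0,), 10.0),
    ((2.0,), 20.0),
    ((3.0,), 30.0),
    ((8.0,), 80.0),
    ((9.0,), 90.0),
]

# The three rows closest to 2.4 have targets 20, 30 and 10, so the prediction is their mean.
prediction = knn_regress(training_data, (2.4,), k=3)
print(prediction)
```

Everything else, including the distance calculation and the choice of K, stays the same as in the classification case.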
Did you find the article useful? Have you used any other machine learning tool recently? Do you plan to use KNN in any of your business problems? If yes, share with us how you plan to go about it.