Welcome to our fourth installment on Machine Learning. In this module we’re going to cover K-Means. K-Means is a clustering algorithm built around the hyperparameter “K”, which dictates how many clusters there will be. A hyperparameter is a parameter that we choose before running the algorithm, rather than one the algorithm learns from the data. Each cluster has a “centroid”, a central point that anchors the cluster. Here are the steps of the K-Means algorithm:
- Plot the data points
- Place K centroids at random
- Calculate the distance from each point to each centroid
- Assign each point the label of its nearest centroid
- Calculate the center (mean) of each cluster → that becomes the new centroid
- Repeat steps 3 to 5 until one of the following happens:
  - The centroids stop moving
  - No data point is assigned to a different centroid after an iteration
  - We reach the maximum number of iterations
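Here is a minimal NumPy sketch of the steps above, so you can see each one as code. It is not how sklearn implements K-Means internally. It assumes Euclidean distance. It places the initial centroids by picking K random data points, which is one common way to "place them randomly". The function name kmeans and its max_iter and seed parameters are our own, chosen for illustration.

```python
import numpy as np

def kmeans(X, k, max_iter=100, seed=0):
    """Illustrative K-Means sketch (Lloyd's algorithm); names are hypothetical."""
    rng = np.random.default_rng(seed)
    # Step 2: use k distinct, randomly chosen data points as starting centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    labels = None
    for _ in range(max_iter):  # stop condition: maximum number of iterations
        # Step 3: distance from every point to every centroid, shape (n, k)
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        # Step 4: label each point with the index of its nearest centroid
        new_labels = dists.argmin(axis=1)
        # Stop condition: no point was assigned to a different centroid
        if labels is not None and np.array_equal(new_labels, labels):
            break
        labels = new_labels
        # Step 5: move each centroid to the mean of its points (keep it if empty)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Stop condition: the centroids stopped moving
        moved = not np.allclose(new_centroids, centroids)
        centroids = new_centroids
        if not moved:
            break
    return labels, centroids

# Usage: two square clusters of 50 points each, near (0, 0) and (5, 5)
rng = np.random.default_rng(1)
X = np.vstack([rng.uniform(-1, 1, (50, 2)), rng.uniform(-1, 1, (50, 2)) + 5])
labels, centroids = kmeans(X, k=2)
print(labels)
print(centroids)
```

By default, sklearn's KMeans uses k-means++ initialization instead of purely random starting points. It also runs several initializations (the n_init parameter), which makes a bad starting position much less likely to matter.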
Here’s a visualization of K-Means:
We’re not going to implement K-Means by hand in this introductory module. Instead, we’ll use Python’s scikit-learn (sklearn) library, which already has an implementation for us. Let’s get into it. To start using K-Means we’ll have to install some libraries. We’ll need:
- sklearn, which has an implementation of K-Means that we can just use
- numpy, which provides numerical arrays and operations
- pandas, which is the de facto data organization library for Python
- matplotlib, which we’ve already used many times and is the go-to plotting library for Python

We can install these with just one line in the command line:
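pip install scikit-learn numpy pandas matplotlib
Note that the package is published on PyPI as scikit-learn, even though we import it as sklearn. The old "sklearn" package name on PyPI is deprecated and should not be used.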
Randomly Generated Sample Data K Means
Alright, now that we’ve got our libraries installed, we’re ready to go. We’ll cover two different K-Means examples here:
- Example number 1 is a contrived example on randomly generated data with two centroids.
- Example number 2 uses the digits dataset provided by sklearn.

As always, the first thing we’ll do is handle our imports.
import random
import pandas as pd
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
Generate Random Sample Data
Next we’re going to randomly generate 100 data points. The points will form two clusters, one around (0,0) and the other around (5,5). We’ll do this with a for loop that runs 100 times. It generates a point near (0,0) on even iterations and a point near (5,5) on odd iterations. To generate a data point, we’ll create two values, an x and a y value. Each value is made by generating a uniform random number between -1 and 1 and adding it to either 0 or 5. Once we’ve generated all 100 samples, we’ll convert our data into a pandas DataFrame object for further processing.
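samples = 100
data = []
for i in range(samples):
    # Even iterations cluster around (0, 0), odd iterations around (5, 5)
    if i % 2 == 0:
        base = 0
    else:
        base = 5
    # Uniform noise between -1 and 1, shifted by the base offset
    x = random.uniform(-1, 1) + base
    y = random.uniform(-1, 1) + base
    data.append([x, y])
# Columns are named 0 (x) and 1 (y) by default
df_rand = pd.DataFrame(data)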
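If you'd rather avoid the Python loop, you can generate the same kind of data in one shot with NumPy. This is an alternative sketch, not the approach used in the rest of this example. The seeded np.random.default_rng generator (seed 42) is our own addition, so that the data comes out the same on every run.

```python
import numpy as np
import pandas as pd

samples = 100
rng = np.random.default_rng(42)  # seeded so the data is reproducible

# Offset of 0 for even row indices and 5 for odd ones (same pattern as the loop)
base = np.where(np.arange(samples) % 2 == 0, 0, 5)
# Uniform noise in [-1, 1) for both x and y, shifted by each row's offset
xy = rng.uniform(-1, 1, size=(samples, 2)) + base[:, None]
df_rand = pd.DataFrame(xy)
print(df_rand.shape)
```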
Now we can use sklearn to run K-Means with 2 clusters. First we’ll create a KMeans object, then call its fit_predict method on the DataFrame we made earlier. Once we have our labels, we’ll split the data by label into two separate DataFrames to graph.
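# n_init is set explicitly so behavior doesn't depend on the sklearn version's default;
# random_state makes the result repeatable
k2means = KMeans(n_clusters=2, n_init=10, random_state=0)
# fit_predict fits the model and returns one cluster label (0 or 1) per row
label_rand = k2means.fit_predict(df_rand)
# Boolean masks select the rows assigned to each cluster
flabels1 = df_rand[label_rand == 1]
flabels0 = df_rand[label_rand == 0]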
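fit_predict is a shortcut for fitting the model and then reading back its labels_ attribute. The sketch below checks that equivalence. It then uses predict to assign new, unseen points to the learned clusters. So that it runs on its own, it regenerates similar two-cluster data with NumPy instead of reusing df_rand. The two test points are arbitrary choices of ours.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# 50 points near (0, 0) followed by 50 points near (5, 5)
X = np.vstack([rng.uniform(-1, 1, (50, 2)), rng.uniform(-1, 1, (50, 2)) + 5])

km = KMeans(n_clusters=2, n_init=10, random_state=0)
labels = km.fit_predict(X)
# The returned labels are the same ones stored on the fitted model
print(np.array_equal(labels, km.labels_))

# predict() assigns each new point to its nearest learned centroid
new_labels = km.predict(np.array([[0.2, -0.3], [4.8, 5.1]]))
print(new_labels)
```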
Plot Randomized Sample Data
Now all we have to do is scatter plot these with matplotlib. We’ll also get the centroids using the cluster_centers_ attribute of the KMeans object we created earlier.
# Plot each cluster, with a legend entry that matches its label
plt.scatter(flabels0[0], flabels0[1], label=0)
plt.scatter(flabels1[0], flabels1[1], label=1)
# cluster_centers_ is an array of shape (n_clusters, 2)
centroids_rand = k2means.cluster_centers_
plt.scatter(centroids_rand[:,0], centroids_rand[:,1], s=80, color="black")
plt.legend()
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Randomly Generated Two Centroid K Means")
plt.show()
Once we plot these, we should see something like the image below.
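Each centroid is the mean of the points assigned to it (step 5 of the algorithm). That means you can verify cluster_centers_ by recomputing those means yourself. So that it runs on its own, this sketch regenerates comparable data with NumPy rather than reusing df_rand. On well-separated data like this, the fit converges with stable labels, so the two sets of centroids should match to floating-point precision.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
X = np.vstack([rng.uniform(-1, 1, (50, 2)), rng.uniform(-1, 1, (50, 2)) + 5])

km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)

# Recompute each centroid by hand: the mean of the points labeled with it
manual = np.array([X[km.labels_ == j].mean(axis=0) for j in range(2)])
print(manual)
print(km.cluster_centers_)
print(np.allclose(manual, km.cluster_centers_))
```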
Digits Dataset K Means
Alright, now that we’ve seen a contrived example, let’s look at a more realistic one. For this example, we’ll run K-Means on the digits dataset, which contains 8x8 images of handwritten digits. As always, we’ll start off by importing our libraries:
- The load_digits function from sklearn.datasets, to load the digits dataset.
- PCA from sklearn.decomposition, to turn this dataset with 64 features (one per pixel) into a dataset with 2 features. PCA stands for Principal Component Analysis; in this example we’ll be using it for dimensionality reduction.
- KMeans and matplotlib.pyplot. We already imported these above, but I’ve included them here to show that this example needs them too.
- numpy, which we’ll also need for this example.
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import numpy as np
import matplotlib.pyplot as plt
We’ll start out by loading the digits and applying PCA to turn our 64-feature dataset into a 2-feature dataset.
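# One row per 8x8 image, each with 64 pixel features
data = load_digits().data
# Project the 64 features onto the first 2 principal components
pca = PCA(n_components=2)
df = pca.fit_transform(data)  # a NumPy array with 2 columns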
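To see how much information the 2-component projection keeps, you can inspect the array shapes and the PCA object's explained_variance_ratio_ attribute. That attribute gives the fraction of the dataset's total variance captured by each component. Here is a small sketch:

```python
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

data = load_digits().data
print(data.shape)  # (1797, 64): one row per image, one column per pixel

pca = PCA(n_components=2)
df = pca.fit_transform(data)
print(df.shape)  # (1797, 2)
# Fraction of the total variance kept by each of the 2 components, and in total
print(pca.explained_variance_ratio_)
print(pca.explained_variance_ratio_.sum())
```

A total well below 1 means the 2-D scatter plot is a heavily compressed view of the original 64-dimensional data.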
Load and Examine Data
Once we’ve transformed our data, let’s apply KMeans. We’ll give it 10 clusters, since there are 10 digits.
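# n_init and random_state set explicitly for consistent, repeatable results
kmeans = KMeans(n_clusters=10, n_init=10, random_state=0)
# One cluster index (0 to 9) per sample
label = kmeans.fit_predict(df)
print(label)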
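The labels will be a NumPy array with one number per sample: the index of the cluster that KMeans assigned that image to. The cluster indices are arbitrary, so cluster 3 does not necessarily contain images of the digit 3.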
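Because the cluster indices are arbitrary, a common way to compare them with the real digits is a majority vote. Map each cluster to the most frequent true digit among its members, then measure how often that mapping agrees with digits.target. This is our own evaluation sketch, not part of the K-Means algorithm. The true labels are only used for checking, never for clustering, and the names cluster_to_digit and predicted_digit are hypothetical.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

digits = load_digits()
df = PCA(n_components=2).fit_transform(digits.data)
label = KMeans(n_clusters=10, n_init=10, random_state=0).fit_predict(df)

# For each cluster, find the most common true digit among its members
cluster_to_digit = np.array([
    np.bincount(digits.target[label == c], minlength=10).argmax()
    for c in range(10)
])
# Translate every sample's cluster index into a digit guess
predicted_digit = cluster_to_digit[label]
print(cluster_to_digit)
print("agreement:", np.mean(predicted_digit == digits.target))
```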
That’s all there is to running KMeans. Now let’s see what it looks like when we plot it out:
- We’ll use np, our alias for the numpy library, to get the unique labels from the labels we made earlier.
- Then we’ll use the cluster_centers_ attribute of the KMeans object we created to get the centroids.
- Next, for each of our unique labels, we’ll scatter plot the data points that correspond to that label.
- Once we’ve plotted all the labels, we’ll create a scatter plot of the centroids.

For the centroids you’ll notice that I passed in an s parameter. This parameter sets the marker size as an area in points squared. By default it is rcParams["lines.markersize"] ** 2, which is 36 with matplotlib’s default settings.
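Because s is an area rather than a width, its effect is easy to misjudge. The sketch below prints the default value and draws three markers side by side. The Agg backend and the output file name marker_sizes.png are our own choices, so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window needed
import matplotlib.pyplot as plt

# Marker area plt.scatter uses when s is not given
default_s = plt.rcParams["lines.markersize"] ** 2
print(default_s)  # 36.0 with matplotlib's default settings

fig, ax = plt.subplots()
ax.scatter([0], [0], label=f"default (s={default_s:g})")
ax.scatter([1], [0], s=80, label="s=80 (the centroid size in this tutorial)")
# s is an area, so 4x the area makes the marker only 2x as wide
ax.scatter([2], [0], s=4 * default_s, label="4x default area = 2x width")
ax.legend()
fig.savefig("marker_sizes.png")
```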
Plot Data
Once we’ve plotted our data, we simply label our graph, add a legend, and then display it with plt.show().
unique_labels = np.unique(label)
centroids = kmeans.cluster_centers_
# Scatter the points of each cluster in its own color
for i in unique_labels:
    plt.scatter(df[label == i, 0], df[label == i, 1], label=i)
# s sets the marker area in points squared; 80 makes the centroids stand out
plt.scatter(centroids[:,0], centroids[:,1], s=80, color="black")
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Digits K Means")
plt.legend()
plt.show()
Our plot should look something like this:
That’s it, that’s all there is to K-Means. Pretty simple: with sklearn you can run it in just a few lines.