Graph Algorithms: Prim’s Algorithm in Python

Minimum Spanning Tree (MST) algorithms find the set of edges with the smallest total weight that connects all the points in a graph without forming a cycle. Algorithms that find minimum spanning trees are useful in network design, taxonomies, and cluster analysis. On PythonAlgos, we’ve already covered one MST algorithm, Kruskal’s algorithm.

Prim’s algorithm is similar to Kruskal’s algorithm. Whereas Kruskal’s adds to the MST by looping through edges, Prim’s adds to the MST by looping through vertices. In this post, we’ll cover how to implement Prim’s algorithm in Python through the following:

  • Pseudocode for Prim’s Algorithm
  • Prim’s Algorithm in Python 
    • Creating a Graph Object
    • Print Function for MST Created by Prim’s Algorithm
    • Helper Function to Find the Minimum Vertex in Prim’s
    • Prim’s Algorithm Implementation
    • Testing Our Python Implementation of Prim’s Algorithm
    • Full Code for Prim’s Algorithm in Python
  • Summary of Prim’s Algorithm

Pseudocode for Prim’s Algorithm

Prim’s Algorithm is a graph algorithm that finds the minimum spanning tree of a weighted, undirected graph. Our implementation of Prim’s algorithm will assume that the graph is connected, so a single spanning tree reaches every vertex. Perhaps we will cover a more advanced implementation in the future.

Here’s the pseudocode for implementing Prim’s Algorithm:

  1. Pick a point to start the tree at. Any point works; in our example Python implementation of Prim’s algorithm, we’ll start from point 0 (the first point)
  2. Go through the graph and find the point not yet in the MST that is the shortest distance away from the points already in the MST
  3. Add the new point to the MST and update the list of known minimum distances to the MST
  4. Repeat steps 2 and 3 until all the vertices have been added to the minimum spanning tree
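
The four steps above can be sketched as one standalone function before we build the class. This is only a rough sketch, not the implementation we write in the rest of the post: the name `prim_mst_edges` and its return format are made up for illustration, and it assumes a connected graph given as a square matrix where 0 means "no edge".

```python
INF = float("inf")

def prim_mst_edges(matrix):
    """Return MST edges as (parent, vertex, weight), in the order vertices join.

    Hypothetical helper for illustration; assumes a connected, undirected graph.
    """
    n = len(matrix)
    in_mst = [False] * n
    dist = [INF] * n       # known minimum distance from each point to the MST
    parent = [None] * n
    dist[0] = 0            # step 1: start from point 0
    edges = []
    for _ in range(n):     # step 4: repeat until every point is added
        # step 2: find the closest point that is not yet in the MST
        candidates = [v for v in range(n) if not in_mst[v]]
        u = min(candidates, key=lambda v: dist[v])
        # step 3: add it, then update the known minimum distances
        in_mst[u] = True
        if parent[u] is not None:
            edges.append((parent[u], u, matrix[u][parent[u]]))
        for v in range(n):
            w = matrix[u][v]
            if w > 0 and not in_mst[v] and w < dist[v]:
                dist[v] = w
                parent[v] = u
    return edges

# a triangle: 0-1 costs 1, 1-2 costs 2, 0-2 costs 4
triangle = [[0, 1, 4],
            [1, 0, 2],
            [4, 2, 0]]
print(prim_mst_edges(triangle))  # [(0, 1, 1), (1, 2, 2)]
```

The expensive edge 0-2 is skipped because point 2 can be reached more cheaply through point 1.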

Prim’s Algorithm in Python

Let’s go over how to implement Prim’s Algorithm in Python. We’re going to be using Python 3, and we won’t need any external libraries for this implementation. We will create a Graph object which will hold three properties and three functions. The properties our graph will have will be a large number representing the max possible value for an edge distance, the number of vertices in the graph, and an adjacency matrix representation of the graph.

Creating a Graph Object

The first thing we’re going to do is create a Graph object. Every other code block in this tutorial belongs inside this object. First, we create a class property called INF which represents the max value that an edge in the graph can have (or infinity). Then we’ll create the init function. 

The init function of the Graph object will take one parameter other than self, the number of vertices in the graph. It sets the instance property V to the number of vertices and creates a V by V adjacency matrix representation of the graph filled with zeros, where a 0 means there is no edge between two vertices.

class Graph():
   INF = 999999
   def __init__(self, num_vertices):
       self.V = num_vertices
       self.graph = [[0 for column in range(num_vertices)] for row in range(num_vertices)]
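
To make the adjacency matrix format concrete, here is a small sketch that builds one from a list of edges. The function `matrix_from_edges` is a hypothetical helper for illustration and is not part of the Graph object in this post. It assumes an undirected graph, so each weight is written to both `[u][v]` and `[v][u]`.

```python
def matrix_from_edges(num_vertices, edges):
    """Build a symmetric adjacency matrix from (u, v, weight) tuples.

    Hypothetical helper; a 0 entry means "no edge between these vertices".
    """
    matrix = [[0 for column in range(num_vertices)] for row in range(num_vertices)]
    for u, v, weight in edges:
        matrix[u][v] = weight
        matrix[v][u] = weight  # undirected: mirror the weight
    return matrix

# the seven edges of the five-vertex graph used in the test section of this post
edges = [(0, 1, 2), (0, 3, 6), (1, 2, 3), (1, 3, 8),
         (1, 4, 5), (2, 4, 7), (3, 4, 9)]
for row in matrix_from_edges(5, edges):
    print(row)
```

Because the graph is undirected, row i and column i of the matrix always hold the same values.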

Print Function for MST Created by Prim’s Algorithm

Now that we’ve finished the initial creation of our graph, let’s write the functions we need for Prim’s algorithm. In this section, we’ll create a function that prints the edges and weights of the MST that we find using Prim’s algorithm.

This function takes one parameter, parent. It expects parent to be a list of indices corresponding to the parent node of each index. Then it prints out a line that shows that we’re going to first print the edge, a tab, then the weight. Finally, we’ll loop through all the vertices from 1 (the root vertex, 0, has no parent node), until the end and print out the edge, a tab, then the weight.

   # pretty print of the minimum spanning tree
   # prints the MST stored in the list var `parent`
   def printMST(self, parent):
       print("Edge     Weight")
       for i in range(1, self.V):
           print(f"{parent[i]} - {i}       {self.graph[i][parent[i]]}")

Helper Function for Finding the Minimum Vertex in Prim’s

Now let’s create the helper function to find the next minimum distance vertex to add in Prim’s algorithm. This function takes two parameters: a list of distances called key, and a list of truth values called mstSet, where the value at index v is True if vertex v is already in the MST.

The first thing we’ll do in this function is set a min value which represents the minimum edge distance to a vertex that we’ve found so far. We will initially set this value to the INF property we declared earlier. 

Next, we’ll loop through the list of vertices and check if the distance to a vertex in the key list is less than the current minimum and the vertex is not in the minimum spanning tree. If it is, then we’ll set the new minimum value to be the distance and the current minimum index to be the vertex being iterated on. After looping through each point, we’ll return the minimum index.

   # finds the vertex with the minimum distance value
   # takes a key and the current set of nodes in the MST
   def minKey(self, key, mstSet):
       # named min_dist so we don't shadow Python's built-in min()
       min_dist = self.INF
       for v in range(self.V):
           if key[v] < min_dist and not mstSet[v]:
               min_dist = key[v]
               min_index = v
       return min_index
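
Here is a quick worked example of that selection step, written as a standalone function so it can run on its own. This is a sketch under a couple of assumptions: the name `min_key` is made up, and unlike the method above it returns `None` when no reachable vertex is left, which is a choice made here for illustration only.

```python
INF = 999999

def min_key(key, in_mst):
    """Standalone sketch of the selection step (hypothetical name).

    Returns the index of the closest vertex not yet in the MST,
    or None if every remaining vertex is still at distance INF.
    """
    min_dist = INF
    min_index = None
    for v in range(len(key)):
        if key[v] < min_dist and not in_mst[v]:
            min_dist = key[v]
            min_index = v
    return min_index

# snapshot after vertex 0 has been added: vertex 1 is 2 away, vertex 3 is 6 away
key = [0, 2, INF, 6, INF]
in_mst = [True, False, False, False, False]
print(min_key(key, in_mst))  # 1

# nothing reachable is left, which can only happen in a disconnected graph
print(min_key([0, INF], [True, False]))  # None
```

Vertex 0 has the smallest key, but it is skipped because it is already in the MST, so vertex 1 wins.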

Prim’s Algorithm Implementation

Alright, let’s get to the fun part, the actual implementation of Prim’s algorithm. This Python implementation of Prim’s algorithm doesn’t take any parameters other than the object itself. The first things we’re going to do are initialize the list of distances, parent nodes, MST nodes, and their values.

The initial list of distances to existing nodes should be set to the max value, INF, for each node in the range of vertices. We’ll set the distance to the initial node, 0, to 0 to start the MST. The list of parent nodes should be set to None for each of the nodes, then we’ll set the parent node for the initial node, 0, to -1 to show it does not exist. Finally, we’ll create a truth value list where the truth value of each index represents whether that vertex is in the minimum spanning tree yet.

Now we’ll create the outer loop of vertices. In each iteration of this loop we’ll be adding a vertex to the MST. In this loop, we first find the vertex with the minimum distance using the minKey helper function we wrote earlier. Next, we’ll set the index of that vertex in the MST list to True.

Once we’ve selected the next vertex to add to the MST, we’ll update the key list, which holds the distance from each connectable vertex to the MST, and the parent list. We’ll create another loop that goes through each vertex. For every vertex that is not in the MST yet, this loop checks whether the edge from the newly added vertex is shorter than the distance currently stored in key. If it is, we store the shorter distance and set that vertex’s parent to the newly added vertex.

   # prim's algo, graph is represented as a V by V adjacency matrix
   def prims(self):
       # used to pick minimum weight edge
       key = [self.INF for _ in range(self.V)]
       # used to store MST
       parent = [None for _ in range(self.V)]
       # pick a starting vertex, ie 0
       key[0] = 0
       # create list for t/f if a node is connected to the MST
       mstSet = [False for _ in range(self.V)]
       # set the first node to the root (ie have a parent of -1)
       parent[0] = -1

       for _ in range(self.V):
           # 1) pick the minimum distance vertex from the current key
           # from the set of points not yet in the MST
           u = self.minKey(key, mstSet)
           # 2) add the new vertex to the MST
           mstSet[u] = True

           # loop through the vertices to update the ones that are still
           # not in the MST
           for v in range(self.V):
               # if an edge from the newly added vertex (u) to v exists,
               # v hasn't been added to the MST, and the weight of that
               # edge is less than the distance currently stored in "key"
               # for v, update the "key" value of v to that weight and
               # update the parent of v to the newly added vertex (u)
               if self.graph[u][v] > 0 and not mstSet[v] and key[v] > self.graph[u][v]:
                   key[v] = self.graph[u][v]
                   parent[v] = u

       self.printMST(parent)
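
To see how the key and parent lists change, here is a sketch that records them after every pass of the outer loop. The function `trace_prims` is hypothetical and is written as a standalone function for illustration. It follows the same selection and update rules described above and assumes a connected graph.

```python
INF = 999999

def trace_prims(graph):
    """Return (added_vertex, key, parent) snapshots, one per outer-loop pass."""
    n = len(graph)
    key = [INF] * n
    parent = [None] * n
    in_mst = [False] * n
    key[0] = 0
    parent[0] = -1
    steps = []
    for _ in range(n):
        # pick the closest vertex that is not in the MST yet
        u = None
        for v in range(n):
            if not in_mst[v] and (u is None or key[v] < key[u]):
                u = v
        in_mst[u] = True
        # update the distances of vertices reachable from u
        for v in range(n):
            if graph[u][v] > 0 and not in_mst[v] and key[v] > graph[u][v]:
                key[v] = graph[u][v]
                parent[v] = u
        steps.append((u, list(key), list(parent)))
    return steps

# the five-vertex graph used in the test section of this post
graph = [[0, 2, 0, 6, 0],
         [2, 0, 3, 8, 5],
         [0, 3, 0, 0, 7],
         [6, 8, 0, 0, 9],
         [0, 5, 7, 9, 0]]
for u, key, parent in trace_prims(graph):
    print(f"added {u}: key={key} parent={parent}")
```

On this graph the vertices are added in the order 0, 1, 2, 4, 3. The key and parent lists stop changing after the second pass, because none of the later edges are shorter than what is already stored.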

Testing Our Python Implementation of Prim’s Algorithm

Let’s test out our Python implementation of Prim’s algorithm via a graph object on a graph of size five. Before we get into testing the code, let’s take a look at a visual representation of our graph and the expected MST.

Visual representation of the graph that we will be applying our Python implementation of Prim’s algorithm on:

[Figure: Example Graph to Build an MST from]

Expected Minimum Spanning Tree Yielded by Prim’s Algorithm:

[Figure: Expected MST from Prim’s Algorithm]

Now let’s take a look at the code and the output from our implementation.

g = Graph(5)
g.graph = [ [0, 2, 0, 6, 0],
           [2, 0, 3, 8, 5],
           [0, 3, 0, 0, 7],
           [6, 8, 0, 0, 9],
           [0, 5, 7, 9, 0]]
 
g.prims()

The expected print output for our Python implementation of Prim’s Algorithm:

[Figure: MST built by our Python implementation of Prim’s algorithm]
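
Tracing the code by hand on the matrix above gives the edges 0 - 1 (weight 2), 1 - 2 (weight 3), 0 - 3 (weight 6), and 1 - 4 (weight 5), for a total weight of 16. As a sanity check, the sketch below brute-forces every possible spanning tree of this small graph and picks the lightest one. This is not how Prim’s works and it would be far too slow for large graphs. The helper `connects_all` is made up for illustration.

```python
from itertools import combinations

graph = [[0, 2, 0, 6, 0],
         [2, 0, 3, 8, 5],
         [0, 3, 0, 0, 7],
         [6, 8, 0, 0, 9],
         [0, 5, 7, 9, 0]]
n = len(graph)

# list each undirected edge once as (u, v, weight) with u < v
edges = [(u, v, graph[u][v])
         for u in range(n) for v in range(u + 1, n) if graph[u][v] > 0]

def connects_all(subset):
    """True if the edges in subset join all n vertices into one component."""
    root = list(range(n))

    def find(x):
        while root[x] != x:
            x = root[x]
        return x

    for u, v, _ in subset:
        root[find(u)] = find(v)
    return len({find(x) for x in range(n)}) == 1

# n - 1 edges that connect all n vertices form a spanning tree
spanning_trees = [s for s in combinations(edges, n - 1) if connects_all(s)]
best = min(spanning_trees, key=lambda s: sum(w for _, _, w in s))
best_weight = sum(w for _, _, w in best)
print(sorted(best), best_weight)
# [(0, 1, 2), (0, 3, 6), (1, 2, 3), (1, 4, 5)] 16
```

The brute-force result contains the same four edges and the same total weight, which is what we expect from a correct MST.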

Full Code for Prim’s Algorithm in Python

Here’s the full code for Prim’s Algorithm in Python.

class Graph():
   INF = 999999
   def __init__(self, num_vertices):
       self.V = num_vertices
       self.graph = [[0 for column in range(num_vertices)] for row in range(num_vertices)]
      
   # pretty print of the minimum spanning tree
   # prints the MST stored in the list var `parent`
   def printMST(self, parent):
       print("Edge     Weight")
       for i in range(1, self.V):
           print(f"{parent[i]} - {i}       {self.graph[i][parent[i]]}")
  
   # finds the vertex with the minimum distance value
   # takes a key and the current set of nodes in the MST
   def minKey(self, key, mstSet):
       # named min_dist so we don't shadow Python's built-in min()
       min_dist = self.INF
       for v in range(self.V):
           if key[v] < min_dist and not mstSet[v]:
               min_dist = key[v]
               min_index = v
       return min_index
  
   # prim's algo, graph is represented as a V by V adjacency matrix
   def prims(self):
       # used to pick minimum weight edge
       key = [self.INF for _ in range(self.V)]
       # used to store MST
       parent = [None for _ in range(self.V)]
       # pick a starting vertex, ie 0
       key[0] = 0
       # create list for t/f if a node is connected to the MST
       mstSet = [False for _ in range(self.V)]
       # set the first node to the root (ie have a parent of -1)
       parent[0] = -1

       for _ in range(self.V):
           # 1) pick the minimum distance vertex from the current key
           # from the set of points not yet in the MST
           u = self.minKey(key, mstSet)
           # 2) add the new vertex to the MST
           mstSet[u] = True

           # loop through the vertices to update the ones that are still
           # not in the MST
           for v in range(self.V):
               # if an edge from the newly added vertex (u) to v exists,
               # v hasn't been added to the MST, and the weight of that
               # edge is less than the distance currently stored in "key"
               # for v, update the "key" value of v to that weight and
               # update the parent of v to the newly added vertex (u)
               if self.graph[u][v] > 0 and not mstSet[v] and key[v] > self.graph[u][v]:
                   key[v] = self.graph[u][v]
                   parent[v] = u

       self.printMST(parent)
 
g = Graph(5)
g.graph = [ [0, 2, 0, 6, 0],
           [2, 0, 3, 8, 5],
           [0, 3, 0, 0, 7],
           [6, 8, 0, 0, 9],
           [0, 5, 7, 9, 0]]
 
g.prims()

Summary of Prim’s Algorithm

In this post on Prim’s algorithm and how to implement it in Python, we learned that Prim’s algorithm is a minimum spanning tree graph algorithm. Prim’s algorithm creates an MST by repeatedly adding the vertex nearest to the tree built so far. We looked at the pseudocode for Prim’s algorithm, and then created a Python graph class that implements Prim’s. Finally, we tested our implementation on a five-vertex graph.

Yujian Tang