Prim’s Algorithm in Python for MST

A minimum spanning tree (MST) of a connected, weighted, undirected graph is a subset of its edges that connects all of the vertices without any cycles, using the smallest possible total edge weight. Algorithms that find minimum spanning trees are useful in network design, taxonomies, and cluster analysis. On PythonAlgos, we’ve already covered one MST algorithm, Kruskal’s algorithm. This time we’re going to cover Prim’s algorithm.
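To make the definition concrete: a spanning tree of a connected graph with V vertices has exactly V - 1 edges and no cycles, and an MST is a spanning tree whose edge weights add up to the smallest total. For a tiny graph you can find it by brute force. The sketch below is only a way to check answers, not a practical algorithm (it is exponential in the number of edges), and it assumes the graph is given as a list of (u, v, weight) tuples; the helper names are hypothetical. The edges are the ones in the five-vertex example graph used later in this post.

```python
from itertools import combinations

def is_spanning_tree(num_vertices, subset):
    """True if `subset` (exactly V - 1 edges) contains no cycle,
    which for V - 1 edges means it connects all V vertices."""
    root = list(range(num_vertices))

    def find(x):
        while root[x] != x:
            x = root[x]
        return x

    for u, v, _ in subset:
        ru, rv = find(u), find(v)
        if ru == rv:
            return False  # u and v are already connected: this edge closes a cycle
        root[ru] = rv
    return True

def brute_force_mst(num_vertices, edges):
    """Check every (V - 1)-edge subset and return (total weight, edges) of the
    cheapest spanning tree. Exponential time: only for tiny graphs."""
    return min(
        (sum(w for _, _, w in subset), subset)
        for subset in combinations(edges, num_vertices - 1)
        if is_spanning_tree(num_vertices, subset)
    )

# (u, v, weight) edges of the five-vertex example graph used later in this post
example_edges = [(0, 1, 2), (0, 3, 6), (1, 2, 3), (1, 3, 8),
                 (1, 4, 5), (2, 4, 7), (3, 4, 9)]
weight, tree = brute_force_mst(5, example_edges)
print(weight, tree)  # 16 ((0, 1, 2), (0, 3, 6), (1, 2, 3), (1, 4, 5))
```

Any correct MST algorithm, including Prim’s, should reach the same total weight of 16 on this graph.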

Prim’s algorithm is similar to Kruskal’s algorithm. Whereas Kruskal’s builds the MST by looping through all the edges from lightest to heaviest, Prim’s grows a single tree from a starting vertex, adding one vertex at a time. In this post, we’ll cover how to implement Prim’s algorithm in Python, first with an adjacency matrix and then with a priority queue.
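For comparison with the vertex-by-vertex approach in this post, here is a compact sketch of Kruskal’s edge-by-edge approach. It is not the code from our earlier Kruskal post; it assumes a list of (u, v, weight) edges and uses a simple union-find to detect cycles, and the function name is hypothetical.

```python
def kruskal_mst(num_vertices, edges):
    """Kruskal's: scan edges from lightest to heaviest and keep each edge
    that joins two different components. edges: list of (u, v, weight)."""
    root = list(range(num_vertices))

    def find(x):
        while root[x] != x:
            root[x] = root[root[x]]  # path halving keeps the trees shallow
            x = root[x]
        return x

    tree = []
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        ru, rv = find(u), find(v)
        if ru != rv:  # different components, so this edge cannot form a cycle
            root[ru] = rv
            tree.append((u, v, w))
    return tree

example_edges = [(0, 1, 2), (0, 3, 6), (1, 2, 3), (1, 3, 8),
                 (1, 4, 5), (2, 4, 7), (3, 4, 9)]
tree = kruskal_mst(5, example_edges)
print(tree)                         # [(0, 1, 2), (1, 2, 3), (1, 4, 5), (0, 3, 6)]
print(sum(w for _, _, w in tree))   # 16
```

Kruskal’s may pick edges that are far apart in the graph, while Prim’s only ever adds an edge that touches the tree it has grown so far.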

Pseudocode for Prim’s Algorithm for MST

Prim’s Algorithm is a graph algorithm that finds the minimum spanning tree of a weighted, undirected graph. Our implementation assumes that the graph is connected, so its MST is a single tree that reaches every vertex (a disconnected graph would instead have a minimum spanning forest). Perhaps we will cover a more advanced implementation in the future.

Here’s the pseudocode for implementing Prim’s Algorithm:

  1. Pick any vertex to start the tree from; in our example, we’ll start from vertex 0 (the first vertex)
  2. Among the vertices not yet in the MST, find the one that is the shortest distance (cheapest edge) away from the vertices already in the MST
  3. Add that vertex to the MST and update the list of known minimum distances from each remaining vertex to the MST
  4. Repeat steps 2 and 3 until all the vertices have been added to the minimum spanning tree
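The steps above can be translated almost literally into Python. This is a deliberately slow sketch, not the implementation used in the rest of the post: it assumes a symmetric adjacency matrix where 0 means “no edge”, and in step 2 it rescans every edge leaving the tree instead of keeping the list of known minimum distances from step 3. The function name prim_naive is hypothetical.

```python
def prim_naive(matrix, start=0):
    """Literal, deliberately slow translation of the pseudocode steps.
    matrix: symmetric adjacency matrix where 0 means 'no edge'.
    Returns the MST as a list of (tree_vertex, new_vertex, weight)."""
    n = len(matrix)
    in_tree = {start}                      # step 1: start from one vertex
    tree_edges = []
    while len(in_tree) < n:                # step 4: repeat until all vertices are in
        # step 2: cheapest edge from a vertex in the tree to one outside it
        best = None
        for u in in_tree:
            for v in range(n):
                w = matrix[u][v]
                if w > 0 and v not in in_tree and (best is None or w < best[2]):
                    best = (u, v, w)
        if best is None:
            raise ValueError("graph is not connected")
        # step 3: add the new vertex (this sketch skips the distance list
        # and rescans the edges in step 2 instead)
        in_tree.add(best[1])
        tree_edges.append(best)
    return tree_edges

example_matrix = [[0, 2, 0, 6, 0],
                  [2, 0, 3, 8, 5],
                  [0, 3, 0, 0, 7],
                  [6, 8, 0, 0, 9],
                  [0, 5, 7, 9, 0]]
print(prim_naive(example_matrix))
# [(0, 1, 2), (1, 2, 3), (1, 4, 5), (0, 3, 6)]
```

Keeping the per-vertex minimum distances up to date, as step 3 describes and the class below does, avoids rescanning every crossing edge on each pass.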

Prim’s Algorithm in Python

We’re going to be using Python 3, and we won’t need any external libraries for this implementation. We will create a Graph object which will hold three properties and three methods (plus __init__). The three properties are a large number standing in for the maximum possible edge weight, the number of vertices in the graph, and an adjacency matrix representation of the graph.

Creating a Graph Object

The first thing we’re going to do is create a Graph class. The methods in the next few sections all belong inside this class. First, we create a class property called INF, which stands in for infinity: it must be larger than any edge weight in the graph. Then we’ll create the __init__ function.

The __init__ function of the Graph object takes one parameter other than self: the number of vertices in the graph. It sets the instance property V to the number of vertices and creates an empty adjacency matrix for a graph with V vertices, that is, a V x V list of lists filled with 0s, where 0 means there is no edge.

class Graph():
   INF = 999999
   def __init__(self, num_vertices):
       self.V = num_vertices
       self.graph = [[0 for column in range(num_vertices)] for row in range(num_vertices)]
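The class leaves it to the caller to fill in self.graph. If you would rather describe the graph as a list of edges, a small helper like the one below (hypothetical, not part of the Graph class) builds the symmetric matrix; its result can be assigned to g.graph. Because 0 means “no edge” in this representation, it assumes all edge weights are positive.

```python
def matrix_from_edges(num_vertices, edges):
    """Build a symmetric adjacency matrix (0 = no edge) from (u, v, weight) tuples."""
    matrix = [[0] * num_vertices for _ in range(num_vertices)]
    for u, v, w in edges:
        matrix[u][v] = w
        matrix[v][u] = w  # undirected graph: store the edge in both directions
    return matrix

example_edges = [(0, 1, 2), (0, 3, 6), (1, 2, 3), (1, 3, 8),
                 (1, 4, 5), (2, 4, 7), (3, 4, 9)]
example_matrix = matrix_from_edges(5, example_edges)
for row in example_matrix:
    print(row)
```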

Print Function for MST Created by Prim’s Algorithm

Now that we’ve finished the initial creation of our graph, let’s write the functions we need for Prim’s algorithm. In this section, we’ll create a function that prints the edges and weights of the MST that we find using Prim’s algorithm.

This function takes one parameter, parent, a list where parent[i] is the vertex that vertex i is attached to in the MST. It first prints a header line with the column names Edge and Weight. Then it loops through the vertices from 1 to the end (the root vertex, 0, has no parent node) and prints each edge as parent[i] - i, followed by that edge’s weight, self.graph[i][parent[i]].

   # pretty print of the minimum spanning tree
   # prints the MST stored in the list var `parent`
   def printMST(self, parent):
       print("Edge     Weight")
       for i in range(1, self.V):
           print(f"{parent[i]} - {i}       {self.graph[i][parent[i]]}")
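For the example graph tested later in this post, tracing prims() by hand gives parent = [-1, 0, 1, 0, 1]. If you want the total weight of the tree rather than a printout, a hypothetical companion to printMST could sum the same edges; like printMST, this sketch assumes vertex 0 is the root.

```python
def mst_weight(matrix, parent):
    """Total weight of the MST described by `parent` (edges parent[i] - i).
    Like printMST, it assumes vertex 0 is the root and skips it."""
    return sum(matrix[i][parent[i]] for i in range(1, len(parent)))

example_matrix = [[0, 2, 0, 6, 0],
                  [2, 0, 3, 8, 5],
                  [0, 3, 0, 0, 7],
                  [6, 8, 0, 0, 9],
                  [0, 5, 7, 9, 0]]
example_parent = [-1, 0, 1, 0, 1]  # parent list prims() builds for this graph
print(mst_weight(example_matrix, example_parent))  # 16
```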

Helper Function for Finding the Minimum Vertex in Prim’s

Now let’s create the helper function that finds the next vertex to add in Prim’s algorithm. This function takes two parameters: key, a list where key[v] is the weight of the cheapest known edge connecting vertex v to the MST, and mstSet, a list of booleans where mstSet[v] is True once vertex v is in the MST.

The first thing we’ll do in this function is set a min value which represents the minimum edge distance to a vertex that we’ve found so far. We will initially set this value to the INF property we declared earlier. 

Next, we’ll loop through the list of vertices and check if the distance to a vertex in the key list is less than the current minimum and the vertex is not in the minimum spanning tree. If it is, then we’ll set the new minimum value to be the distance and the current minimum index to be the vertex being iterated on. After looping through each point, we’ll return the minimum index.

   # finds the vertex outside the MST with the minimum key value
   # takes the key list and the current set of nodes in the MST
   def minKey(self, key, mstSet):
       min_value = self.INF
       for v in range(self.V):
           if key[v] < min_value and not mstSet[v]:
               min_value = key[v]
               min_index = v
       return min_index
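minKey relies on the graph being connected: if every vertex outside the MST still has key INF, min_index is never assigned and Python raises an UnboundLocalError. The standalone sketch below (hypothetical name min_key, with the INF value passed in as an assumption) uses the built-in min() and returns None in that case instead, which makes the failure easy to check for.

```python
def min_key(keys, mst_set, inf=999999):
    """Return the vertex outside the MST with the smallest key, or None if no
    such vertex has a key below `inf` (i.e. the graph is disconnected)."""
    candidates = [v for v in range(len(keys)) if not mst_set[v] and keys[v] < inf]
    if not candidates:
        return None
    # min() returns the first minimal candidate, the same tie-break as minKey
    return min(candidates, key=lambda v: keys[v])

print(min_key([0, 2, 3, 6, 5], [True, True, False, False, False]))  # 2
print(min_key([0, 999999], [True, False]))                          # None
```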

Prim’s Algorithm Implementation

Alright, let’s get to the fun part, the actual implementation of Prim’s algorithm. The prims method doesn’t take any parameters other than self. The first things we’re going to do are initialize the lists of distances (key), parent nodes (parent), and MST membership (mstSet).

Every entry in key starts at the max value, INF, and then we set the distance of the initial node, 0, to 0 so that it is picked first. Every entry in parent starts as None, and we set the parent of the initial node to -1 to mark it as the root, which has no parent. Finally, we create a boolean list, mstSet, where mstSet[v] records whether vertex v is in the minimum spanning tree yet.

Now we’ll create the outer loop, which runs once per vertex; each iteration adds one vertex to the MST. In this loop, we first use the minKey helper we wrote earlier to find the vertex u with the minimum key among those not yet in the MST. Next, we set mstSet[u] to True.

Once we’ve added the new vertex u, we update key and parent. Another loop goes through every vertex v: if there is an edge from u to v, v is not in the MST yet, and that edge is lighter than key[v], we set key[v] to the edge’s weight and parent[v] to u. When the outer loop finishes, we print the result with printMST.

   # prim's algo, graph is represented as a V x V adjacency matrix
   def prims(self):
       # key[v]: weight of the cheapest known edge from v to the MST
       key = [self.INF for _ in range(self.V)]
       # parent[v]: the MST vertex that v is attached to
       parent = [None for _ in range(self.V)]
       # start from vertex 0 (any vertex would work)
       key[0] = 0
       # mstSet[v] is True once v has been added to the MST
       mstSet = [False for _ in range(self.V)]
       # make the first node the root (ie give it a parent of -1)
       parent[0] = -1

       for _ in range(self.V):
           # 1) pick the minimum distance vertex from the current key
           # from the set of points not yet in the MST
           u = self.minKey(key, mstSet)
           # 2) add the new vertex to the MST
           mstSet[u] = True

           # loop through the vertices to update the ones that are still
           # not in the MST
           for v in range(self.V):
               # if an edge u-v exists, v hasn't been added to the MST yet,
               # and that edge is lighter than v's current key, make the
               # edge weight v's new key and make the newly added vertex (u)
               # the parent of v
               if self.graph[u][v] > 0 and not mstSet[v] and key[v] > self.graph[u][v]:
                   key[v] = self.graph[u][v]
                   parent[v] = u

       self.printMST(parent)
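To see what the loop does on the example graph, here is a standalone sketch (hypothetical name prims_trace) that mirrors the key/parent/mstSet logic of prims() but records the order in which vertices join the MST and returns parent instead of printing. It uses float("inf") in place of the class’s INF and the built-in min() in place of minKey, and like the class it assumes a connected graph.

```python
INF = float("inf")

def prims_trace(matrix, start=0):
    """Same key/parent logic as Graph.prims, but returns the order in which
    vertices join the MST and the final parent list instead of printing."""
    n = len(matrix)
    key = [INF] * n
    parent = [None] * n
    in_mst = [False] * n
    key[start] = 0
    parent[start] = -1
    order = []
    for _ in range(n):
        # pick the vertex outside the MST with the smallest key
        u = min((v for v in range(n) if not in_mst[v]), key=lambda v: key[v])
        in_mst[u] = True
        order.append(u)
        # update keys and parents of neighbors that are still outside the MST
        for v in range(n):
            w = matrix[u][v]
            if w > 0 and not in_mst[v] and w < key[v]:
                key[v] = w
                parent[v] = u
    return order, parent

example_matrix = [[0, 2, 0, 6, 0],
                  [2, 0, 3, 8, 5],
                  [0, 3, 0, 0, 7],
                  [6, 8, 0, 0, 9],
                  [0, 5, 7, 9, 0]]
order, parent = prims_trace(example_matrix)
print(order)   # [0, 1, 2, 4, 3]
print(parent)  # [-1, 0, 1, 0, 1]
```

Vertex 3 joins last: its key is set to 6 by vertex 0 in the first pass, and the later edges 1-3 (8) and 4-3 (9) are heavier, so its parent stays 0.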

Testing Our Python Implementation of Prim’s Algorithm

Let’s test our Python implementation of Prim’s algorithm on a graph with five vertices. Before we get into testing the code, let’s take a look at a visual representation of our graph and the expected MST.

Visual representation of the graph that we will be applying our Python implementation of Prim’s algorithm on:

[Figure: Example Graph to Build an MST from Prim’s Algorithm]
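Since the figure is an image, it can help to have the same graph as text. This hypothetical helper lists each edge of the adjacency matrix used in the test code below exactly once; it assumes the matrix is symmetric with 0 meaning no edge.

```python
def edges_from_matrix(matrix):
    """List each undirected edge once as (u, v, weight), with u < v."""
    return [(u, v, matrix[u][v])
            for u in range(len(matrix))
            for v in range(u + 1, len(matrix))
            if matrix[u][v] > 0]

example_matrix = [[0, 2, 0, 6, 0],
                  [2, 0, 3, 8, 5],
                  [0, 3, 0, 0, 7],
                  [6, 8, 0, 0, 9],
                  [0, 5, 7, 9, 0]]
print(edges_from_matrix(example_matrix))
# [(0, 1, 2), (0, 3, 6), (1, 2, 3), (1, 3, 8), (1, 4, 5), (2, 4, 7), (3, 4, 9)]
```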

Expected Minimum Spanning Tree Yielded by Prim’s Algorithm:

[Figure: Expected MST from Prim’s Algorithm]

Now let’s take a look at the code and the output from our implementation.

g = Graph(5)
g.graph = [ [0, 2, 0, 6, 0],
           [2, 0, 3, 8, 5],
           [0, 3, 0, 0, 7],
           [6, 8, 0, 0, 9],
           [0, 5, 7, 9, 0]]
 
g.prims()

The expected print output for our Python implementation of Prim’s Algorithm:

[Figure: Prim’s Algorithm’s MST, as printed by our Python implementation]

Full Code for Prim’s Algorithm in Python

Here’s the full code for Prim’s Algorithm in Python.

class Graph():
   INF = 999999
   def __init__(self, num_vertices):
       self.V = num_vertices
       self.graph = [[0 for column in range(num_vertices)] for row in range(num_vertices)]
      
   # pretty print of the minimum spanning tree
   # prints the MST stored in the list var `parent`
   def printMST(self, parent):
       print("Edge     Weight")
       for i in range(1, self.V):
           print(f"{parent[i]} - {i}       {self.graph[i][parent[i]]}")
  
   # finds the vertex outside the MST with the minimum key value
   # takes the key list and the current set of nodes in the MST
   def minKey(self, key, mstSet):
       min_value = self.INF
       for v in range(self.V):
           if key[v] < min_value and not mstSet[v]:
               min_value = key[v]
               min_index = v
       return min_index
  
   # prim's algo, graph is represented as a V x V adjacency matrix
   def prims(self):
       # key[v]: weight of the cheapest known edge from v to the MST
       key = [self.INF for _ in range(self.V)]
       # parent[v]: the MST vertex that v is attached to
       parent = [None for _ in range(self.V)]
       # start from vertex 0 (any vertex would work)
       key[0] = 0
       # mstSet[v] is True once v has been added to the MST
       mstSet = [False for _ in range(self.V)]
       # make the first node the root (ie give it a parent of -1)
       parent[0] = -1

       for _ in range(self.V):
           # 1) pick the minimum distance vertex from the current key
           # from the set of points not yet in the MST
           u = self.minKey(key, mstSet)
           # 2) add the new vertex to the MST
           mstSet[u] = True

           # loop through the vertices to update the ones that are still
           # not in the MST
           for v in range(self.V):
               # if an edge u-v exists, v hasn't been added to the MST yet,
               # and that edge is lighter than v's current key, make the
               # edge weight v's new key and make the newly added vertex (u)
               # the parent of v
               if self.graph[u][v] > 0 and not mstSet[v] and key[v] > self.graph[u][v]:
                   key[v] = self.graph[u][v]
                   parent[v] = u

       self.printMST(parent)
 
g = Graph(5)
g.graph = [ [0, 2, 0, 6, 0],
           [2, 0, 3, 8, 5],
           [0, 3, 0, 0, 7],
           [6, 8, 0, 0, 9],
           [0, 5, 7, 9, 0]]
 
g.prims()

Prim’s Algorithm Using a Priority Queue in Python

Now that we’ve walked through one solution to Prim’s algorithm, let’s look at a second. This solution uses a priority queue implemented with Python’s heapq module. One of the main differences between these two solutions is that this one represents the graph as an adjacency list, whereas the code above used an adjacency matrix.

For the priority queue version, we represent each connection as a tuple. The adjacency list below describes the same graph as above: row i lists the neighbors of vertex i as (neighbor, weight) pairs. For example, (1, 2) in the first row means that vertex 0 (the first vertex) connects to vertex 1 with an edge of weight 2.

import heapq

# adjacency list for Prim's Algorithm with a Priority Queue
adj_list_graph=[[(1, 2), (3, 6)],
                [(0, 2), (2, 3), (3, 8), (4, 5)],
                [(1, 3), (4, 7)],
                [(0, 6), (1, 8), (4, 9)],
                [(1, 5), (2, 7), (3, 9)]]
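If you start from the adjacency matrix used by the Graph class instead, a small hypothetical helper can produce this (neighbor, weight) format. This sketch assumes a symmetric matrix with 0 meaning no edge; applied to the test matrix from earlier, it returns exactly adj_list_graph.

```python
def matrix_to_adj_list(matrix):
    """Convert an adjacency matrix (0 = no edge) into one list of
    (neighbor, weight) tuples per vertex."""
    return [[(v, w) for v, w in enumerate(row) if w > 0] for row in matrix]

example_matrix = [[0, 2, 0, 6, 0],
                  [2, 0, 3, 8, 5],
                  [0, 3, 0, 0, 7],
                  [6, 8, 0, 0, 9],
                  [0, 5, 7, 9, 0]]
for row in matrix_to_adj_list(example_matrix):
    print(row)
```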

Prim’s Algorithm Implementation with heapq Priority Queue

Now that we’ve set up a graph that our function can read, let’s implement the logic. This implementation of Prim’s algorithm takes two parameters: the input graph as an adjacency list, and the starting vertex.

The first thing we do in our function is establish the data structures we need. We will keep track of the edges, the weights, and the visited vertices. The edges and weights are initially empty and the visited_vertices list is populated with the starting vertex.

Then, while we have visited fewer vertices than the graph contains, we use the priority queue to figure out which vertex to add next. On each pass through the while loop, we build a heap of the possible moves: every edge that leads from a visited vertex to an unvisited one. Note that the priority queue has to be ordered by the weight of the edge, not the vertex number, so the weight goes first in each tuple.
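The ordering matters because heapq compares tuples element by element, starting with the first. This small demo uses the candidate moves from the third pass over our example graph (after vertices 0, 1 and 2 are visited) and shows what heappop returns when the weight comes first versus when the vertex number comes first.

```python
import heapq

# Candidate moves on the third pass over the example graph,
# after vertices 0, 1 and 2 have been visited
by_weight = [(6, 0, 3), (8, 1, 3), (5, 1, 4), (7, 2, 4)]   # (weight, from, to)
by_vertex = [(to, frm, w) for (w, frm, to) in by_weight]     # (to, from, weight)

heapq.heapify(by_weight)
heapq.heapify(by_vertex)
first_by_weight = heapq.heappop(by_weight)
first_by_vertex = heapq.heappop(by_vertex)
print(first_by_weight)  # (5, 1, 4): edge 1-4 with weight 5, the cheapest move
print(first_by_vertex)  # (3, 0, 6): edge 0-3 with weight 6, not the cheapest
```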

Once we’ve built that priority queue of available moves, we pop the cheapest one with heapq.heappop. We add its vertex, weight, and edge to our lists and keep going until every vertex is in our Prim’s algorithm MST. Because the heap is rebuilt from scratch on every pass, this version keeps the code simple but repeats some work.

"""Prims Algorithm with a Priority Queue 
Implemented with heapq
@parameter graph: Graph (adjacency list of (neighbor, weight) tuples)
@parameter start: integer within graph's # of vertices
@return two lists: the MST edges as (from, to) tuples and their weights"""
def prims_priority_q(graph, start):
    # establish the necessary data structures 
    edges = []
    weights = []
    visited_vertices = [start]

    while len(visited_vertices) < len(graph):
        moves = []
        for x in visited_vertices:
            for node in graph[x]:
                # push weight, cur vertex, next vertex
                # could be prettier if we used objects instead
                # of tuples
                if node[0] not in visited_vertices:
                    heapq.heappush(moves, (node[1], x, node[0]))
        
        # get the next move based on the weight
        next_move = heapq.heappop(moves)
        print(f"next move: {next_move}")
        # add the next vertex, the edge's weight, and the edge itself
        visited_vertices.append(next_move[2])
        weights.append(next_move[0])
        edges.append((next_move[1], next_move[2]))

    return edges, weights
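Since prims_priority_q rebuilds moves on every pass of the while loop, a common alternative is to keep a single heap for the whole run and skip any popped edge whose far end has already been visited (often called the “lazy” version of Prim’s). The sketch below is that variant under a hypothetical name; it assumes the same (neighbor, weight) adjacency list and a connected graph, and it does not print the moves.

```python
import heapq

def prims_lazy_heap(graph, start):
    """Prim's with one heap kept across iterations: stale entries whose far end
    is already visited are skipped when popped.
    graph: adjacency list of (neighbor, weight) tuples."""
    visited = {start}
    edges, weights = [], []
    heap = [(w, start, v) for v, w in graph[start]]
    heapq.heapify(heap)
    while heap and len(visited) < len(graph):
        w, u, v = heapq.heappop(heap)
        if v in visited:
            continue  # stale entry: v was already added via an edge at least as light
        visited.add(v)
        edges.append((u, v))
        weights.append(w)
        # push only the edges that lead to vertices not yet in the tree
        for nxt, nxt_w in graph[v]:
            if nxt not in visited:
                heapq.heappush(heap, (nxt_w, v, nxt))
    return edges, weights

example_adj_list = [[(1, 2), (3, 6)],
                    [(0, 2), (2, 3), (3, 8), (4, 5)],
                    [(1, 3), (4, 7)],
                    [(0, 6), (1, 8), (4, 9)],
                    [(1, 5), (2, 7), (3, 9)]]
print(prims_lazy_heap(example_adj_list, 0))
# ([(0, 1), (1, 2), (1, 4), (0, 3)], [2, 3, 5, 6])
```

On this graph it picks the same four edges, in the same order, as prims_priority_q, but each edge is pushed onto the heap at most once per endpoint instead of once per pass.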

Testing the Priority Queue Version of Prim’s Algorithm

Finally, we can test the priority queue implementation of Prim’s algorithm. We run it on the adjacency list graph we created above and expect to see the same MST as the one we created before. The code below executes these tests and pretty prints the edges and weights.

edges, weights = prims_priority_q(adj_list_graph, 0)
print("edges    weights")
for edge, weight in zip(edges, weights):
    print(f"{edge}      {weight}")

Running all of the code from these sections should produce output like the one below. In this implementation, we also print each move as Prim’s algorithm chooses it.

[Figure: Prim’s Algorithm Priority Queue Implementation Test Results]

Summary of Prim’s Algorithm

In this post on Prim’s algorithm in Python, we learned that Prim’s is a minimum spanning tree graph algorithm. Prim’s algorithm creates an MST by repeatedly adding the vertex nearest to the tree built so far. We looked at the pseudocode for Prim’s algorithm and then created a Python Graph class that implements Prim’s with an adjacency matrix. We tested our implementation on a five-vertex graph, and finally implemented and tested a second version that uses a heapq priority queue on an adjacency list.


Yujian Tang
