Prim's Algorithm

Prim's algorithm finds the Minimum Spanning Tree (MST) of a connected, weighted, undirected graph. By an MST we mean the set of edges that connects every node of the graph with the least total weight.

Let's say, for example, that you want to connect every major city of Mexico using the fewest total kilometers of road. Prim's algorithm solves this problem by determining which roads to use.

Example Graph

For this example I will be using the following graph:

graph

The first step when working with a graph is to convert it into an adjacency list or an adjacency matrix, depending on the implementation. I have chosen to solve this problem with a matrix, so let's use that.

Matrix View:

[ 0, 2, 4, 0 ]
[ 2, 0, 1, 5 ]
[ 4, 1, 0, 3 ]
[ 0, 5, 3, 0 ]

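This matrix view shows the weight of the edge between each pair of nodes: row i, column j holds the weight of the edge between node i and node j. A 0 means the two nodes are not directly connected.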
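To make the matrix view concrete, here is a minimal sketch of how such a matrix could be built from a list of edges. The `Edge` struct and the `to_matrix` and `example_edges` helpers are hypothetical names introduced only for this illustration. The sketch assumes an undirected graph with positive weights and 0-based node indices, so node 1 of the example is index 0.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper type: an undirected edge between 0-based nodes u and v.
struct Edge {
  std::size_t u;
  std::size_t v;
  int weight;
};

// Build an n x n adjacency matrix; 0 means "no edge", as in the matrix above.
std::vector<std::vector<int>> to_matrix(std::size_t n, const std::vector<Edge>& edges) {
  std::vector<std::vector<int>> matrix(n, std::vector<int>(n, 0));
  for (const Edge& e : edges) {
    assert(e.u < n && e.v < n && e.weight > 0);
    // Undirected graph: store the weight in both directions.
    matrix[e.u][e.v] = e.weight;
    matrix[e.v][e.u] = e.weight;
  }
  return matrix;
}

// The five edges of the example graph, read off the matrix above.
std::vector<Edge> example_edges() {
  return {{0, 1, 2}, {0, 2, 4}, {1, 2, 1}, {1, 3, 5}, {2, 3, 3}};
}
```

Calling `to_matrix(4, example_edges())` reproduces the four rows shown above. Because every edge is stored twice, the matrix is symmetric.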

Prim's Algorithm

General

Prim's algorithm starts by picking a start node, here the one with the smallest key, which forms the initial subgraph. We then look at all the edges that connect the current subgraph to nodes outside it, pick the edge with the smallest weight, and add the node it leads to to our subgraph.

We keep repeating this process until every node is connected. Bear in mind that the result cannot contain cycles, so we never add a node that is already part of the subgraph.

After performing this on a sheet of paper, starting from node 1, we get the following result (edge weights 2, 1 and 3, for a total of 6):

1 -> 2 -> 3 -> 4
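The paper exercise can be sketched in code without a priority queue: at every step, scan all edges that leave the visited set and take the cheapest one. This is a simplified illustration of the idea and not the implementation used later in this post. The names `Step`, `prim_by_hand` and `total_weight` are made up for this example, and node indices are 0-based.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical record of one step: the edge that was picked and its weight.
struct Step {
  std::size_t from;
  std::size_t to;
  int weight;
};

// Grow the tree from `start` by repeatedly taking the cheapest edge that
// leaves the visited set. Assumes 0 means "no edge" in the matrix.
std::vector<Step> prim_by_hand(const std::vector<std::vector<int>>& graph, std::size_t start) {
  const std::size_t n = graph.size();
  assert(start < n);
  std::vector<bool> visited(n, false);
  std::vector<Step> steps;
  visited[start] = true;

  while (steps.size() + 1 < n) {
    bool found = false;
    Step best{0, 0, 0};
    for (std::size_t u = 0; u < n; ++u) {
      if (!visited[u]) continue;
      for (std::size_t v = 0; v < n; ++v) {
        const int w = graph[u][v];
        // Skipping already visited targets is what prevents cycles.
        if (w == 0 || visited[v]) continue;
        if (!found || w < best.weight) {
          best = Step{u, v, w};
          found = true;
        }
      }
    }
    if (!found) break;  // disconnected graph: stop with a partial tree
    visited[best.to] = true;
    steps.push_back(best);
  }
  return steps;
}

// Sum of the weights of all picked edges.
int total_weight(const std::vector<Step>& steps) {
  int sum = 0;
  for (const Step& s : steps) sum += s.weight;
  return sum;
}
```

On the example matrix with `start = 0`, this picks the edges 0-1, 1-2 and 2-3, which is the 1 -> 2 -> 3 -> 4 result in 1-based numbering, with a total weight of 6. Rescanning every edge at each step is slow on larger graphs, which is why a priority queue is the usual choice.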

Pseudocode

// Algorithm:
// 1) Init a priority queue, visited list and solution vector
// 2) Mark every element as visited = false
// 3) Pick a start element and put it on the priority queue
// 4) while !Q.empty
//  4a) Get next unvisited node

while (!Q.empty()) {
  // Discard queue entries whose node has already been visited.
  // Check for an empty queue first, before looking at the top element.
  while (!Q.empty() && visited[Q.top()]) {
    Q.pop();
  }

  // Set unvisited node as visited and add it to the solution
  if (!Q.empty()) {
    p = Q.top();
    Q.pop();
    visited[p] = true;
    mst.push_back(p);

    // Queue every unvisited neighbour r of p, ordered by edge weight
    for (neighbour r of p in graph) {
      if (graph.weight[p][r] != 0 && !visited[r]) {
        Q.push(r);
      }
    }
  }
}

// Return solution
return mst;

C++ Implementation

#include <iostream>
#include <queue>
#include <vector>

using namespace std;

struct node {
  int key;
  int weight;
};

bool operator<(const node& a, const node& b) { return a.weight > b.weight; }

template <typename T> void print_list(const vector<T>& numbers) {
  for (size_t i = 0; i < numbers.size(); i++) {
    cout << numbers[i] << " ";
  }

  cout << endl;
}

vector<int> prim(vector<vector<int>> &graph) {
  // Init visited flags, solution and priority queue
  vector<bool> visited(graph.size(), false);
  vector<int> mst; // solution: nodes in the order they join the tree
  priority_queue<node> pq;

  struct node start;
  start.key = 0;
  start.weight = 0;

  pq.push(start);

  while (!pq.empty()) {
    // Get next unvisited node
    // Check empty() first: calling top() on an empty queue is undefined behaviour
    while (!pq.empty() && visited[pq.top().key]) {
      pq.pop();
    }

    if (!pq.empty()) {
      struct node curr_node = pq.top();
      visited[curr_node.key] = true; // set it visited (black)
      mst.push_back(curr_node.key); // Add it to the solution
      pq.pop();

      // Get the neighbours of the current node
      // Neighbours are all X values that are not 0 in graph[curr_node.key][X] and not visited
      for (int i = 0; i < graph.size(); i++) {
        if (graph[curr_node.key][i] != 0 && !visited[i]) {
          struct node node;
          node.key = i;
          node.weight = graph[curr_node.key][i];

          pq.push(node);
        }
      }
    }
  }

  return mst;
}

int main(int argc, const char * argv[]) {
  vector<vector<int>> graph;

  graph = {
    { 0, 2, 4, 0 },
    { 2, 0, 1, 5 },
    { 4, 1, 0, 3 },
    { 0, 5, 3, 0 }
  };

  // Print solution: the 0-based order in which nodes join the tree (0 1 2 3 here)
  print_list(prim(graph));

  return 0;
}
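One detail in the implementation above that is easy to miss is the `operator<`, which looks inverted. `std::priority_queue` is a max-heap by default, so it keeps the "largest" element on top. Defining "less" as "heavier" therefore puts the lightest node on top. The sketch below isolates that behaviour. `drain_keys` and `demo_order` are hypothetical helpers written only for this demonstration.

```cpp
#include <cassert>
#include <queue>
#include <vector>

struct node {
  int key;
  int weight;
};

// Same comparator as in the implementation: "less" means "heavier",
// so the lightest node ends up on top of the max-heap.
bool operator<(const node& a, const node& b) { return a.weight > b.weight; }

// Pop everything and return the keys in the order the queue hands them out.
// The queue is taken by value, so the caller's copy is left untouched.
std::vector<int> drain_keys(std::priority_queue<node> pq) {
  std::vector<int> keys;
  while (!pq.empty()) {
    keys.push_back(pq.top().key);
    pq.pop();
  }
  return keys;
}

// Push three nodes in arbitrary order and see which order they come out in.
std::vector<int> demo_order() {
  std::priority_queue<node> pq;
  pq.push(node{3, 5});
  pq.push(node{2, 1});
  pq.push(node{1, 2});
  assert(pq.top().weight == 1);  // the lightest entry is on top
  return drain_keys(pq);
}
```

Here `demo_order()` returns the keys 2, 1, 3, which is ascending order of weight (1, 2, 5). An alternative with the same effect is to leave `operator<` alone and pass a custom comparator type as the third template argument of `std::priority_queue`.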