Saturday, December 24, 2011

Quick 'n Dirty Disjoint Sets

The disjoint-set data structure is magical. Not only is it flexible and indispensable in a variety of situations, but both the associated algorithms and the implementation are remarkably clean and simple. This makes the disjoint-set data structure (sometimes called the 'union-find data structure', named after its two primary operations) a joy to encounter when programming.

The data structure answers the question "given a set of bidirectional connections between nodes, can I reach node b from node a by walking along these connections?" It can be visualised as a forest of disjoint trees, one tree per set of connected nodes, where each node points to its parent. To determine whether we can walk between two nodes (that is, whether the two nodes are part of the same set), all we have to do is check whether they have the same root. Finding a root means walking up the tree, and each time we do so we reattach the nodes we passed directly to the root. Deep nodes keep getting bumped up to the top, so later lookups are faster. It's marginally more complicated than that, but that's the general gist of union-find. The data structure also allows quick merging of trees, so you can add new connections in near-constant time.
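
To make that picture concrete, here's a small Python sketch. The forest is stored as a parent list, where a node that is its own parent is a root. The particular layout and the `root` helper are illustrative only. Walking up from node 4 finds the shared root and flattens the path at the same time.

```python
# parent[i] is node i's parent; a node that is its own parent is a root.
# Illustrative layout: nodes 0-4 form the chain 4 -> 3 -> 2 -> 1 -> 0,
# and node 5 is alone in its own tree.
parent = [0, 0, 1, 2, 3, 5]

def root(x):
    """Walk up to the root, then point every visited node straight at it."""
    path = []
    while parent[x] != x:
        path.append(x)
        x = parent[x]
    for node in path:       # path compression: bump visited nodes to the top
        parent[node] = x
    return x

print(root(4) == root(2))   # True: same root, so 4 and 2 are connected
print(root(4) == root(5))   # False: different trees
print(parent)               # [0, 0, 0, 0, 0, 5]: the chain is now flat
```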

Since it's a tree data structure, it's relatively easy to implement with pointers, where each node points to its parent. Here's a C implementation:

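#include <stdlib.h>

struct set {
    int rank;            /* upper bound on the height of this node's subtree */
    struct set *parent;  /* a root is its own parent */
};

/* Create a new singleton set, or return NULL if allocation fails. */
struct set* newSet(void) {
    struct set *s = malloc(sizeof *s);
    if (s == NULL) return NULL;
    s->parent = s;
    s->rank = 0;
    return s;
}

/* Return the root of s's set, compressing the path on the way back. */
struct set* find(struct set *s) {
    if (s->parent == s)  return s;
    else                 return (s->parent = find(s->parent));
}

/* Merge the sets containing a and b: the lower-rank root goes under the other. */
void join(struct set *a, struct set *b) {
    struct set *aRep = find(a), *bRep = find(b);
    if (aRep == bRep) return;   /* already in the same set; don't inflate the rank */
    if (aRep->rank > bRep->rank) {
        bRep->parent = aRep;
    } else if (aRep->rank < bRep->rank) {
        aRep->parent = bRep;
    } else {
        aRep->parent = bRep;
        bRep->rank++;
    }
}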

It's short, sweet and understandable (I'm glaring at you, binary indexed trees): exactly the kind of thing you can code up in a programming competition. The rank field implements union by rank. When two trees are merged, the root with the lower rank is attached under the root with the higher rank, which keeps the trees shallow. Combined with path compression in find, this gives an amortized time of O(α(n)) per operation, where α is the inverse Ackermann function. Since α(n) is at most 4 for any n you could ever store, this is effectively constant time.
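
To see what the rank field buys, here's a small experiment in Python. It is a sketch, and all of the function names are made up for illustration. It merges 1,024 singletons in the order (0,1), (1,2), (2,3) and so on, with path compression switched off so that only the linking rule matters, and then measures the tallest tree. Without path compression, union by rank guarantees a height of at most log2(n).

```python
def find(parent, x):
    # No path compression here, so heights reflect the linking rule alone
    while parent[x] != x:
        x = parent[x]
    return x

def link_naive(parent, rank, a, b):
    # Quick-and-dirty rule: a's root always goes under b's root
    parent[find(parent, a)] = find(parent, b)

def link_by_rank(parent, rank, a, b):
    # The C join() rule: lower rank under higher; on a tie, a's root under b's
    ra, rb = find(parent, a), find(parent, b)
    if ra == rb:
        return
    if rank[ra] > rank[rb]:
        parent[rb] = ra
    elif rank[ra] < rank[rb]:
        parent[ra] = rb
    else:
        parent[ra] = rb
        rank[rb] += 1

def height(parent):
    # Length of the longest node-to-root path in the forest
    def depth(x):
        d = 0
        while parent[x] != x:
            x = parent[x]
            d += 1
        return d
    return max(depth(x) for x in range(len(parent)))

def chain_height(link, n=1024):
    parent, rank = list(range(n)), [0] * n
    for i in range(n - 1):
        link(parent, rank, i, i + 1)   # same merge order for both rules
    return height(parent)

print(chain_height(link_naive))    # 1023: one long path
print(chain_height(link_by_rank))  # 1: every node hangs off a single root
```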

But hey, that C code is still pretty long, and it needs, like, pointers and stuff. Why don't we drop union by rank? We lose the inverse Ackermann upper bound, but path compression on its own still gives O(log n) amortized time per operation, which is almost as fast in practice... and who could resist the reduction in code size?

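rep = list(range(100))  # rep[p] is p's parent; roots point to themselves

def find(p):
    # Follow parents up to the root, pointing each visited node straight at it
    if rep[p] != p:
        rep[p] = find(rep[p])
    return rep[p]

def union(p, q):
    # Hang p's root under q's root (no union by rank)
    rep[find(p)] = find(q)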

This is almost the same thing, but with less code and more elegance (in my opinion, at least). I've also been told that at least one university teaches union-find without union by rank, so it's evidently considered an acceptable implementation.
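
The claim that path compression alone keeps things fast is a claim about amortized cost. To put a number on it, here's a rough experiment. It is a sketch rather than a benchmark, and the `steps` counter is a hypothetical addition. Naive linking lets n unions build a single chain, so the first find afterwards walks all n-1 links. That same find also flattens the chain, so the next n finds take at most one hop each. The find here is an iterative, two-pass version of the recursive one, because recursing down a 1,000-deep chain could exceed Python's default recursion limit of 1000.

```python
n = 1000
rep = list(range(n))
steps = 0  # total parent pointers followed by find()

def find(p):
    """Iterative two-pass find with path compression."""
    global steps
    root = p
    while rep[root] != root:      # pass 1: walk up to the root
        root = rep[root]
        steps += 1
    while p != root:              # pass 2: point every visited node at the root
        nxt = rep[p]
        rep[p] = root
        p = nxt
    return root

def union(p, q):
    rep[find(p)] = find(q)        # naive linking, no union by rank

# Naive linking lets these unions build one chain 0 -> 1 -> ... -> n-1
for i in range(n - 1):
    union(i, i + 1)

find(0)              # the first find walks the whole chain...
first = steps
steps = 0
for i in range(n):
    find(i)          # ...but afterwards every node is one hop from the root
print(first, steps)  # 999 999
```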

As an example, the program below generates a set of random points on the plane and builds a Euclidean minimum spanning tree (MST) from them, using Kruskal's algorithm and our very own disjoint-set data structure implementation. This code has not been tested for correctness.

import itertools
import random

rep = list(range(100))

def find(p):
    if rep[p] != p:
        rep[p] = find(rep[p])
    return rep[p]

def union(p, q):
    rep[find(p)] = find(q)


# generate 100 random points on the plane
pts = [(random.randint(0, 1000), random.randint(0, 1000)) for i in range(100)]

# define Euclidean distance function
def dist(a, b):
    return ((pts[a][0] - pts[b][0])**2 + (pts[a][1] - pts[b][1])**2)**0.5

# generate pairs of points representing the edges of the complete graph
edges = list(itertools.combinations(range(100), 2))

# sort the edges by length, in ascending order
edges.sort(key=lambda e: dist(*e))

# Kruskal's algorithm for MST: keep each edge that joins two different trees
total_cost = 0
for (u, v) in edges:
    if find(u) != find(v):
        total_cost += dist(u, v)
        union(u, v)

print(total_cost)
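
Since that program hasn't been tested, one cheap sanity check is to compare its result against a different MST algorithm. Every MST of a graph has the same total weight, so the two totals should agree up to floating-point rounding. Here's a sketch that runs the same Kruskal loop next to a simple O(n²) version of Prim's algorithm, which is swapped in purely as an independent check. The fixed seed is arbitrary, and the function names are illustrative.

```python
import itertools
import math
import random

random.seed(2011)  # arbitrary fixed seed so runs are repeatable
N = 100
pts = [(random.randint(0, 1000), random.randint(0, 1000)) for _ in range(N)]

def dist(a, b):
    return math.hypot(pts[a][0] - pts[b][0], pts[a][1] - pts[b][1])

def kruskal():
    """Same approach as above: sorted edges plus union-find."""
    rep = list(range(N))

    def find(p):
        if rep[p] != p:
            rep[p] = find(rep[p])
        return rep[p]

    total, used = 0.0, 0
    for u, v in sorted(itertools.combinations(range(N), 2), key=lambda e: dist(*e)):
        ru, rv = find(u), find(v)
        if ru != rv:
            rep[ru] = rv
            total += dist(u, v)
            used += 1
    return total, used

def prim():
    """O(N^2) Prim: repeatedly add the closest point not yet in the tree."""
    best = [math.inf] * N     # cheapest known edge from the tree to each point
    in_tree = [False] * N
    best[0] = 0.0
    total = 0.0
    for _ in range(N):
        u = min((i for i in range(N) if not in_tree[i]), key=lambda i: best[i])
        in_tree[u] = True
        total += best[u]
        for v in range(N):
            if not in_tree[v]:
                best[v] = min(best[v], dist(u, v))
    return total

k_total, k_edges = kruskal()
p_total = prim()
print(k_edges)                          # 99: a spanning tree on 100 points
print(math.isclose(k_total, p_total))   # True if both agree on the MST weight
```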

It might not be as theoretically sound as the pointer implementation: without union by rank the worst-case bound is weaker, and the maximum number of nodes is fixed by the size of the rep array. But it's quick and dirty, and sometimes that's just what you need.
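
If the fixed size of rep ever gets in the way, it's easy to loosen. Here's a sketch of one possible variant, not the version above. It stores parents in a dict, so nodes are created the first time they're seen and can be any hashable value. It also finds roots with a loop instead of recursion, so long chains can't run into Python's recursion limit. The use of `dict.setdefault` for lazy creation is just one way to do it.

```python
rep = {}  # node -> parent; any hashable value can be a node

def find(p):
    rep.setdefault(p, p)          # an unseen node starts as its own root
    root = p
    while rep[root] != root:      # walk up without recursion
        root = rep[root]
    while p != root:              # second pass: path compression
        nxt = rep[p]
        rep[p] = root
        p = nxt
    return root

def union(p, q):
    rep[find(p)] = find(q)

union("alice", "bob")
union("bob", "carol")
union("dave", "erin")
print(find("alice") == find("carol"))  # True
print(find("alice") == find("dave"))   # False

# A 100,000-link chain: too deep for a recursive find at default settings
for i in range(100000):
    union(i, i + 1)
print(find(0) == find(100000))         # True
```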
