Friday, February 3, 2012

We Need To Talk About Binary Search.

Although the basic idea of binary search is comparatively straightforward, the details can be surprisingly tricky.
-- Donald Knuth

Why is binary search so damn hard to get right? Why is it that 90% of programmers are unable to code up a binary search on the spot, even though it's easily the most intuitive of the standard algorithms?

  • Firstly, binary search has a lot of potential for off-by-one errors. Do you do inclusive bounds or exclusive bounds? What's your break condition: lo=hi+1, lo=hi, or lo=hi-1? Is the midpoint (lo+hi)/2 or (lo+hi)/2 - 1 or (lo+hi)/2 + 1? And what about the comparison, < or ≤? Certain combinations of these work, but it's easy to pick one that doesn't.
  • Secondly, there are actually two variants of binary search: a lower-bound search and an upper-bound search. Bugs are often caused by a careless programmer accidentally applying a lower-bound search when an upper-bound search was required, or vice versa.
  • Finally, binary search is very easy to underestimate and very hard to debug. You'll get it working on one case, but when you increase the array size by 1 it'll stop working; you'll then fix it for this case, but now it won't work in the original case!

I want to generalise and nail down the binary search, with the goal of introducing a shift in the way you perceive it. By the end of this post you should be able to code any variant of binary search without hesitation and with complete confidence. But first, back to the start: here is the binary search you were probably taught...

Input: sorted list of elements, query term
Output: the index of the first appearance of the query in the list, or an ERROR value otherwise

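I propose an alternative definition: a binary search takes as input a (monotonic) function f(x) and a boolean predicate function p(v), and searches over the finite domain of the function for arguments where the predicate is true for the function's value -- i.e., values x such that p(f(x)) is true. The search relies on p(f(x)) switching between false and true at most once as x increases; a monotonic f paired with a threshold predicate such as v ≥ query guarantees that.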

Based on this definition, here are the definitions of the variants:

Upper-bound: find the maximum argument x such that p(f(x)) is true
Lower-bound: find the minimum argument x such that p(f(x)) is true
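To pin down what these two definitions mean before any halving gets involved, here is a minimal sketch that spells them out as plain linear scans. The names lower_bound_by_scan and upper_bound_by_scan are mine, returning None when nothing satisfies the predicate is my own choice, and the squares example is made up purely for illustration. A binary search is then just a faster way of getting the same answers.

```python
def lower_bound_by_scan(domain, f, p):
    """Minimum x in domain with p(f(x)) true, or None if there is none."""
    hits = [x for x in domain if p(f(x))]
    return min(hits) if hits else None


def upper_bound_by_scan(domain, f, p):
    """Maximum x in domain with p(f(x)) true, or None if there is none."""
    hits = [x for x in domain if p(f(x))]
    return max(hits) if hits else None


def square(x):
    # A monotonic function over the non-negative integers.
    return x * x


domain = range(10)
print(lower_bound_by_scan(domain, square, lambda v: v >= 20))  # 5
print(upper_bound_by_scan(domain, square, lambda v: v <= 20))  # 4
```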

The traditional binary search described above is a special case of the general lower-bound search, where f(x) = array[x], the domain of f(x) is the set of integers {0, 1, ..., N-1} (N being the length of array) and p(v) = v ≥ query.

In other words, you're searching for the minimum argument x such that array[x] ≥ query. Hey, this is just what we had before!
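If you want to convince yourself of that, here is a small sketch that compares the "minimum x such that array[x] ≥ query" definition against Python's standard bisect.bisect_left, which computes the same index. The helper names and the use of -1 as the ERROR value are my own assumptions, as is returning len(arr) when no element qualifies (that is what bisect_left does). Note that you still need one equality check at the end to get the "or an ERROR value otherwise" behaviour.

```python
import bisect


def first_index_at_least(arr, query):
    """Smallest x with arr[x] >= query, by linear scan; len(arr) if none."""
    for x, v in enumerate(arr):
        if v >= query:
            return x
    return len(arr)


def find_first(arr, query):
    """Index of the first occurrence of query, or -1 (my stand-in for ERROR)."""
    i = bisect.bisect_left(arr, query)
    return i if i < len(arr) and arr[i] == query else -1


arr = [1, 3, 3, 7]
for query in [0, 1, 2, 3, 7, 8]:
    # The predicate-style definition and bisect_left agree on every query.
    assert first_index_at_least(arr, query) == bisect.bisect_left(arr, query)

print(find_first(arr, 3))  # 1
print(find_first(arr, 2))  # -1
```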

Let's face it: this description of binary search isn't very helpful. For example, why use this predicate thing when all we need is a simple array[mid] < query in our binary search?

The advantage of this somewhat convoluted definition comes when either the query is not in the array or it's in the array many times. Say you're searching for the first instance of the number 6 in the following array, using the traditional method:

[1, 1, 2, 4, 5, 5, 5, 6, 6, 6, 6, 8, 10, 10, 11]

It should output 7. Let's try this out...

lo, hi = 0, len(arr) - 1
while lo < hi:
    mid = (lo + hi) // 2
    if arr[mid] >= 6:  hi = mid - 1
    else:              lo = mid + 1
print(mid)  # should print 7

This looks about right. If arr[mid] is greater than or equal to 6, it'll keep searching to the left, trying to find earlier occurrences of 6. Otherwise, it'll search to the right of mid. Will mid always be the correct index by the end? I run it... apparently not, it outputs 5. Oh, I know! It's because I used ≥ instead of >! Right now it'll keep searching left after it encounters the first 6. OK, change that to a > and run it again. Now I'm getting 9... oh, it must be because I'm printing mid at the end instead of lo. On the final iteration, mid would have become too large, but lo will be just right -- it should be at exactly 7, which is what we want. Hit F5; wtf, 10? Undo that ≥/> change from before but keep the other change to see if it makes a difference -- nope, now 6? And I haven't even considered the issue of inclusive or exclusive bounds...

No joke, I just messed around with the +'s and -'s and bounds and ≥s and >s and ≤s and <s and los and mids and his for about 10 minutes and I can't find a single combination which gets me an answer of 7. This is worse than any other bug because you'll undoubtedly end up in a loop of case-bashing: getting it to work for this case, then finding it fails for another, then fixing it for that case and finding it now fails the original case. This style of debugging never works. Instead, turn off your monitor, grab a pen and paper and plan this out.

OK, are we doing upper-bound or lower-bound search? We're finding the minimum index with the value 6; lower-bound then. What's the predicate? It's a lower-bound search, so we want it to return true at our desired index and everywhere to the right of it, and false everywhere to the left. Easy! p(v) = (v >= 6).

So let's do this right. I now write my own p(v) function, even though the logic is ridiculously simple, and I translate the predicate-based binary search definition into code. Most importantly, I introduce a new variable, the best_so_far variable.

def p(v): return v >= 6

lo, hi = 0, len(arr) - 1
best_so_far = None

while lo <= hi:
    x = (lo + hi) // 2
    if p(arr[x]):
        # we found a potential minimum x, but we should still check to see if any smaller ones work
        best_so_far = x
        hi = x - 1
    else:
        # the predicate is false, so we need to go right to find true values
        lo = x + 1

print(best_so_far)

And it works first time. *whistles*

But the problem definition changes! Your boss tells you that now, you must search for the last occurrence of 6!

But hey, that's cool. Re-evaluate the problem: it's now an upper-bound search. Our predicate must return true for values less than or equal to 6 and false for anything larger, so the maximum x such that p(f(x)) is true is the index of the last 6.

def p(v): return v <= 6

lo, hi = 0, len(arr) - 1
best_so_far = None

while lo <= hi:
    x = (lo + hi) // 2
    if p(arr[x]):
        # we found a potential maximum x, but we should still check to see if any larger ones work
        best_so_far = x
        lo = x + 1
    else:
        # the predicate is false, so we need to go left to find true values
        hi = x - 1

print(best_so_far)

Again, it works first go: it prints 10, the index of the last 6. Notice that I only changed two things: the predicate function (flipped the comparison) and the direction we head when the predicate is true (i.e. I swapped the lines lo = x + 1 and hi = x - 1). I did not have to mess with any +'s or -'s. No off-by-ones were introduced during the making of this function.
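To make the "only two things change" point concrete, here is a rough sketch that folds both variants into one function. The name predicate_search and the find_max flag are hypothetical, introduced only for this illustration; the loop body is the same as in the two listings above, with the direction picked by the flag.

```python
def predicate_search(arr, p, find_max):
    """Index of the last (find_max=True) or first (find_max=False) element
    for which p is true, or None if p is false everywhere.

    Assumes p(arr[x]) flips at most once as x increases.
    """
    lo, hi = 0, len(arr) - 1
    best_so_far = None
    while lo <= hi:
        x = (lo + hi) // 2
        if p(arr[x]):
            best_so_far = x
            # Keep looking for a better candidate on the far side of x.
            if find_max:
                lo = x + 1
            else:
                hi = x - 1
        else:
            # Predicate is false: head towards where the true values live.
            if find_max:
                hi = x - 1
            else:
                lo = x + 1
    return best_so_far


arr = [1, 1, 2, 4, 5, 5, 5, 6, 6, 6, 6, 8, 10, 10, 11]
print(predicate_search(arr, lambda v: v >= 6, find_max=False))  # 7
print(predicate_search(arr, lambda v: v <= 6, find_max=True))   # 10
```

In real code I'd probably still write the two loops out separately, since the flag makes the function harder to read than either variant on its own.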

Notice also that I use lo = x + 1 and not lo = x. Similarly, hi = x - 1 and not hi = x. This is a foolproof way to avoid the nasty infinite loop binary search bug caused by integer division -- it ensures that you're never considering a value of x more than once, so you're always narrowing down your search space by at least 1 each time, hence ensuring termination. The best_so_far variable gives us complete control over how we're approaching the solution, meaning that we don't need to mess around trying to work out whether it's lo, hi or mid that contains the return value at the conclusion of the algorithm. I personally find that inclusive bounds work best with this form of binary search, but your mileage may vary. If you use exclusive bounds, I make no guarantees on this strategy's correctness.
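If you'd like to see that infinite loop without actually hanging your interpreter, here is a small sketch. The function iterations_until_done and its max_steps cap are my own inventions for the demonstration: it runs the loop with a predicate that is always false (so it always goes right) and gives up after a fixed number of iterations.

```python
def iterations_until_done(lo, hi, shrink, max_steps=50):
    """Run the search loop with a predicate that is always false ("go right").

    Returns the number of iterations, or None if max_steps is reached,
    which here signals that the loop has stopped making progress.
    """
    steps = 0
    while lo <= hi:
        if steps == max_steps:
            return None
        x = (lo + hi) // 2
        # shrink=True uses lo = x + 1; shrink=False uses the buggy lo = x.
        lo = x + 1 if shrink else x
        steps += 1
    return steps


print(iterations_until_done(0, 1, shrink=True))   # 2
print(iterations_until_done(0, 1, shrink=False))  # None: x == lo on every pass
```

With lo = 0 and hi = 1 the midpoint rounds down to 0, so lo = x never moves and the range never shrinks.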

Yet again, the requirements change. The list is now in descending order, and you need to find the index of the first item less than 5.

[11, 10, 10, 8, 6, 6, 6, 6, 5, 5, 5, 4, 2, 1, 1]

As usual, you only need to make two decisions. Upper- or lower-bound? It's clearly lower-bound: you're finding the first item that satisfies the predicate. What's the predicate? p(v) = (v < 5). Expected output is 11.

def p(v): return v < 5

lo, hi = 0, len(arr) - 1
best_so_far  = None

while lo <= hi:
    x = (lo + hi) // 2
    if p(arr[x]):
        # we found a potential minimum x, but we should still check to see if any smaller ones work
        best_so_far = x
        hi = x - 1
    else:
        # the predicate is false, so we need to go right to find true values
        lo = x + 1

print(best_so_far)

It prints 11.

Do I expect every programmer to write out a trivial p(v) function for every binary search they write? Of course not. It might help you think about the problem, but it's not required. If you take one thing away from this post, let it be this: in any binary search you ever write, whether it be over a list of strings, or a multi-dimensional space, or over the domain of a function that uses the inclusion-exclusion principle on O(1) cumulative sums of rectangular regions of a ternary predicate mapped over the integer values in a 2D grid (this has happened before), you just need to worry about whether it is upper- or lower-bound and what your predicate is.
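As one example of searching over the domain of a function rather than a list, here is a minimal sketch of an integer square root. The function name isqrt_by_search is mine, and it assumes n is a non-negative integer. The two decisions are: upper-bound (we want the largest x) and the predicate x * x <= n. Everything else is the same loop as before.

```python
def isqrt_by_search(n):
    """Largest integer x in [0, n] with x * x <= n (an upper-bound search).

    Assumes n is a non-negative integer.
    """
    lo, hi = 0, n
    best_so_far = None
    while lo <= hi:
        x = (lo + hi) // 2
        if x * x <= n:
            # x works, but a larger one might too, so go right.
            best_so_far = x
            lo = x + 1
        else:
            # x is too big, so go left.
            hi = x - 1
    return best_so_far


print(isqrt_by_search(17))  # 4
```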

BAM. No more bugs in a binary search, ever. You can thank me later.