Closest Pair Implementation Python

maverick93 · Jan 30, 2015

I am trying to implement the closest pair algorithm in Python using divide and conquer. Everything seems to work fine, except that for some inputs I get a wrong answer. My code is as follows:

def closestSplitPair(Px,Py,d):
    X = Px[len(Px)-1][0]
    Sy = [item for item in Py if item[0]>=X-d and item[0]<=X+d]
    best,p3,q3 = d,None,None
    for i in xrange(0,len(Sy)-2):
        for j in xrange(1,min(7,len(Sy)-1-i)):
            if dist(Sy[i],Sy[i+j]) < best:
                best = (Sy[i],Sy[i+j])
                p3,q3 = Sy[i],Sy[i+j]
    return (p3,q3,best)

I call the above function from a recursive function, which is as follows:

def closestPair(Px,Py):
    """Px and Py are input arrays sorted according to
    their x and y coordinates respectively"""
    if len(Px) <= 3:
        return min_dist(Px)
    else:
        mid = len(Px)/2
        Qx = Px[:mid] ### x-sorted left side of P
        Qy = Py[:mid] ### y-sorted left side of P
        Rx = Px[mid:] ### x-sorted right side of P
        Ry = Py[mid:] ### y-sorted right side of P
        (p1,q1,d1) = closestPair(Qx,Qy)
        (p2,q2,d2) = closestPair(Rx,Ry)
        d = min(d1,d2)
        (p3,q3,d3) = closestSplitPair(Px,Py,d)
        return min((p1,q1,d1),(p2,q2,d2),(p3,q3,d3),key=lambda tup: tup[2])

Here, min_dist(P) is a brute-force implementation of the closest pair algorithm for a list P with 3 or fewer elements; it returns a tuple containing the closest pair of points and their distance.
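The question does not show dist or min_dist. A minimal Python 3 sketch of what they might look like, assuming dist is the Euclidean distance and min_dist returns (p, q, distance) as described; both are hypothetical reconstructions, not the asker's code:

```python
import math
from itertools import combinations


def dist(p, q):
    """Euclidean distance between two 2-D points (assumed definition)."""
    return math.hypot(p[0] - q[0], p[1] - q[1])


def min_dist(P):
    """Brute-force closest pair for a small list P (2 or 3 points).

    Returns (p, q, distance), the tuple shape the recursion unpacks.
    """
    p, q = min(combinations(P, 2), key=lambda pair: dist(*pair))
    return (p, q, dist(p, q))
```

For example, min_dist([(0, 0), (3, 4), (10, 10)]) returns ((0, 0), (3, 4), 5.0).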

If my input is P = [(0,0),(7,6),(2,20),(12,5),(16,16),(5,8),(19,7),(14,22),(8,19),(7,29),(10,11),(1,13)], my output is ((5,8),(7,6),2.8284271), which is correct. But when my input is P = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2), (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)], the output I get is ((78, 2), (94, 5), 16.278820596099706), whereas the correct output is ((94, 5), (99, -8), 13.92838827718412).
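Each reported number is the Euclidean distance of the pair it comes with, so the failing run returns a genuine pair of points, just not the closest one:

```latex
\begin{aligned}
d\big((5,8),(7,6)\big)    &= \sqrt{(7-5)^2 + (6-8)^2}    = \sqrt{8}   \approx 2.8284 \\
d\big((78,2),(94,5)\big)  &= \sqrt{(94-78)^2 + (5-2)^2}  = \sqrt{265} \approx 16.2788 \\
d\big((94,5),(99,-8)\big) &= \sqrt{(99-94)^2 + (-8-5)^2} = \sqrt{194} \approx 13.9284
\end{aligned}
```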

Answer

Padraic Cunningham · Jan 30, 2015

You have two problems. First, you forget to call dist when updating the best distance: best = (Sy[i],Sy[i+j]) stores the pair of points instead of their distance, so later comparisons against best are meaningless. But the main problem is that more than one recursive call is happening, so a closer split pair you have found can end up being overwritten by the default best,p3,q3 = d,None,None. I passed the best pair from closest_pair as an argument to closest_split_pair so that it cannot be overwritten.

def closest_split_pair(p_x, p_y, delta, best_pair): # <- a parameter
    ln_x = len(p_x)
    mx_x = p_x[ln_x // 2][0]
    s_y = [x for x in p_y if mx_x - delta <= x[0] <= mx_x + delta]
    best = delta
    for i in range(len(s_y) - 1):
        for j in range(1, min(i + 7, (len(s_y) - i))):
            p, q = s_y[i], s_y[i + j]
            dst = dist(p, q)
            if dst < best:
                best_pair = p, q
                best = dst
    return best_pair

The end of closest_pair looks like the following:

    p_1, q_1 = closest_pair(srt_q_x, srt_q_y)
    p_2, q_2 = closest_pair(srt_r_x, srt_r_y)
    closest = min(dist(p_1, q_1), dist(p_2, q_2))
    # get min of both and then pass that as an arg to closest_split_pair
    mn = min((p_1, q_1), (p_2, q_2), key=lambda x: dist(x[0], x[1]))
    p_3, q_3 = closest_split_pair(p_x, p_y, closest,mn)
    # either return mn or we have a closer split pair
    return min(mn, (p_3, q_3), key=lambda x: dist(x[0], x[1]))

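You also have some other logic issues. Your slicing logic is not correct: Py[:mid] and Py[mid:] split the y-sorted list by position, so they do not contain the same points as Px[:mid] and Px[mid:]. I made some changes to your code, where brute is just a simple brute-force double loop: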

def closestPair(Px, Py):
    if len(Px) <= 3:
        return brute(Px)

    mid = len(Px) // 2
    # get left and right half of Px
    q, r = Px[:mid], Px[mid:]
    # x-sorted halves are just q and r; for the y-sorted halves, filter Py
    # by membership so the y-order is kept (assumes distinct points)
    left = set(q)
    Qx, Qy = q, [x for x in Py if x in left]
    Rx, Ry = r, [x for x in Py if x not in left]
    (p1, q1) = closestPair(Qx, Qy)
    (p2, q2) = closestPair(Rx, Ry)
    d = min(dist(p1, q1), dist(p2, q2))
    mn = min((p1, q1), (p2, q2), key=lambda x: dist(x[0], x[1]))
    (p3, q3) = closest_split_pair(Px, Py, d, mn)
    return min(mn, (p3, q3), key=lambda x: dist(x[0], x[1]))
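To see concretely why splitting Py by position goes wrong, here is a small Python 3 check on the first input from the question. Py[:mid] holds the mid points with the smallest y, which need not be the left half of Px; filtering Py by membership in Px[:mid] gives exactly the left half and keeps it in y-order. This sketch assumes the points are distinct (a set is used for membership):

```python
P = [(0, 0), (7, 6), (2, 20), (12, 5), (16, 16), (5, 8),
     (19, 7), (14, 22), (8, 19), (7, 29), (10, 11), (1, 13)]

Px = sorted(P)                              # sorted by x (then y)
Py = sorted(P, key=lambda p: (p[1], p[0]))  # sorted by y (then x)
mid = len(Px) // 2

left = set(Px[:mid])                        # points left of the dividing line

naive_Qy = Py[:mid]                         # positional slice of the y-sorted list
Qy = [p for p in Py if p in left]           # left half, still in y-order

print(sorted(naive_Qy) == sorted(Px[:mid]))  # False: e.g. (12, 5) is on the right
print(sorted(Qy) == sorted(Px[:mid]))        # True
```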

I only implemented the algorithm today, so there are no doubt some improvements to be made, but this will get you the correct answer.
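For comparison, a self-contained Python 3 sketch of the same divide-and-conquer structure, under a few stated assumptions: the points are distinct tuples, dist is the Euclidean distance, the y-sorted halves are built by filtering p_y by membership in the left half, and the strip loop compares each point with the next 7 points in y-order (the usual textbook bound). find_closest_pair is a hypothetical entry point that does the initial sorting; it is not code from either post:

```python
import math
from itertools import combinations


def dist(p, q):
    """Euclidean distance between two 2-D points."""
    return math.hypot(p[0] - q[0], p[1] - q[1])


def brute(points):
    """Closest pair among 2 or 3 points, by checking every pair."""
    return min(combinations(points, 2), key=lambda pair: dist(*pair))


def closest_split_pair(p_x, p_y, delta, best_pair):
    """Return a pair closer than delta that straddles the dividing line,
    or best_pair if there is none."""
    mid_x = p_x[len(p_x) // 2][0]  # x-coordinate of the dividing line
    # Points within delta of the line; still in y-order because p_y is.
    strip = [p for p in p_y if abs(p[0] - mid_x) <= delta]
    best = delta
    for i in range(len(strip) - 1):
        # Only the next 7 points in y-order can be closer than delta.
        for j in range(i + 1, min(i + 8, len(strip))):
            d = dist(strip[i], strip[j])
            if d < best:
                best, best_pair = d, (strip[i], strip[j])
    return best_pair


def closest_pair(p_x, p_y):
    """p_x and p_y hold the same points, sorted by x and by y respectively."""
    if len(p_x) <= 3:
        return brute(p_x)
    mid = len(p_x) // 2
    q_x, r_x = p_x[:mid], p_x[mid:]
    left = set(q_x)  # assumes all points are distinct
    q_y = [p for p in p_y if p in left]      # left half, y-sorted
    r_y = [p for p in p_y if p not in left]  # right half, y-sorted
    mn = min(closest_pair(q_x, q_y), closest_pair(r_x, r_y),
             key=lambda pair: dist(*pair))
    split = closest_split_pair(p_x, p_y, dist(*mn), mn)
    return min(mn, split, key=lambda pair: dist(*pair))


def find_closest_pair(points):
    """Hypothetical entry point: sort once by x and once by y, then recurse.
    Needs at least 2 points."""
    p_x = sorted(points)
    p_y = sorted(points, key=lambda p: (p[1], p[0]))
    p, q = closest_pair(p_x, p_y)
    return p, q, dist(p, q)
```

Running find_closest_pair on the second input from the question returns the pair (94, 5), (99, -8) with distance sqrt(194), about 13.928. Sorting once up front and filtering p_y at each level keeps the y-order without re-sorting, which is what makes the 7-neighbor strip check valid.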