I am trying to implement the closest pair problem in Python using divide and conquer. Everything seems to work fine, except that for some inputs the answer is wrong. My code is as follows:
```python
def closestSplitPair(Px,Py,d):
    X = Px[len(Px)-1][0]
    Sy = [item for item in Py if item[0]>=X-d and item[0]<=X+d]
    best,p3,q3 = d,None,None
    for i in xrange(0,len(Sy)-2):
        for j in xrange(1,min(7,len(Sy)-1-i)):
            if dist(Sy[i],Sy[i+j]) < best:
                best = (Sy[i],Sy[i+j])
                p3,q3 = Sy[i],Sy[i+j]
    return (p3,q3,best)
```
I am calling the above function from a recursive function, which is as follows:
```python
def closestPair(Px,Py):
    """Px and Py are input arrays sorted according to
    their x and y coordinates respectively"""
    if len(Px) <= 3:
        return min_dist(Px)
    else:
        mid = len(Px)/2
        Qx = Px[:mid] ### x-sorted left side of P
        Qy = Py[:mid] ### y-sorted left side of P
        Rx = Px[mid:] ### x-sorted right side of P
        Ry = Py[mid:] ### y-sorted right side of P
        (p1,q1,d1) = closestPair(Qx,Qy)
        (p2,q2,d2) = closestPair(Rx,Ry)
        d = min(d1,d2)
        (p3,q3,d3) = closestSplitPair(Px,Py,d)
        return min((p1,q1,d1),(p2,q2,d2),(p3,q3,d3),key=lambda tup: tup[2])
```
where `min_dist(P)` is the brute-force implementation of the closest pair algorithm for a list `P` with three or fewer elements; it returns a tuple containing the closest pair of points and their distance.
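Neither `dist` nor `min_dist` is shown above. A minimal sketch of what they are assumed to look like, taking `dist` to be the Euclidean distance and `min_dist` to return `(p, q, distance)` as described (both bodies are assumptions, not the original code):

```python
import math
from itertools import combinations


def dist(p, q):
    """Euclidean distance between two 2-D points given as (x, y) tuples."""
    return math.hypot(p[0] - q[0], p[1] - q[1])


def min_dist(P):
    """Brute-force closest pair for a small list P (assumes at least 2 points).

    Checks every pair and returns (p, q, distance) for the closest one.
    """
    p, q = min(combinations(P, 2), key=lambda pair: dist(pair[0], pair[1]))
    return p, q, dist(p, q)
```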
If my input is `P = [(0,0),(7,6),(2,20),(12,5),(16,16),(5,8),(19,7),(14,22),(8,19),(7,29),(10,11),(1,13)]`, then my output is `((5,8),(7,6),2.8284271)`, which is the correct output.

But when my input is `P = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2), (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)]`, the output I get is `((78, 2), (94, 5), 16.278820596099706)`, whereas the correct output should be `((94, 5), (99, -8), 13.92838827718412)`.
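For reference, this is one way the presorted inputs `Px` and `Py` can be built, together with an O(n²) brute-force check that confirms the expected answer for the second input. It is a sketch that assumes `dist` is the Euclidean distance; ties in the sort order are broken by the other coordinate.

```python
import math
from itertools import combinations

P = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2),
     (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)]

# Presort once: by x (ties broken by y) and by y (ties broken by x).
Px = sorted(P)
Py = sorted(P, key=lambda p: (p[1], p[0]))


def dist(p, q):
    return math.hypot(p[0] - q[0], p[1] - q[1])


# Reference answer: check every pair of points.
best_p, best_q = min(combinations(P, 2), key=lambda pair: dist(*pair))
print(best_p, best_q, dist(best_p, best_q))
```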
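You have two problems. First, you forget to call `dist` when updating the best distance: `best = (Sy[i],Sy[i+j])` stores the pair itself rather than its distance, so every later test `dist(...) < best` compares a float with a tuple (in Python 2 a number always compares as less than a tuple, so the test is always true from then on; in Python 3 it raises a `TypeError`). But the main problem is that there is more than one recursive call happening, so when you find a closer split pair you can end up overwriting it with the default `best,p3,q3 = d,None,None`. I passed the best pair from `closest_pair` as an argument to `closest_split_pair` so I would not potentially overwrite the value: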
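```python
def closest_split_pair(p_x, p_y, delta, best_pair):  # <- best_pair is the new parameter
    ln_x = len(p_x)
    mx_x = p_x[ln_x // 2][0]  # x coordinate of the dividing line
    # points within delta of the dividing line, still sorted by y
    s_y = [x for x in p_y if mx_x - delta <= x[0] <= mx_x + delta]
    best = delta
    for i in range(len(s_y) - 1):
        # only the next 7 points in y order need to be checked
        for j in range(1, min(8, len(s_y) - i)):
            p, q = s_y[i], s_y[i + j]
            dst = dist(p, q)
            if dst < best:
                best_pair = p, q
                best = dst
    return best_pair  # unchanged if no split pair is closer than delta
```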
The end of closest_pair looks like the following:
```python
p_1, q_1 = closest_pair(srt_q_x, srt_q_y)
p_2, q_2 = closest_pair(srt_r_x, srt_r_y)
closest = min(dist(p_1, q_1), dist(p_2, q_2))
# get min of both and then pass that as an arg to closest_split_pair
mn = min((p_1, q_1), (p_2, q_2), key=lambda x: dist(x[0], x[1]))
p_3, q_3 = closest_split_pair(p_x, p_y, closest, mn)
# either return mn or we have a closer split pair
return min(mn, (p_3, q_3), key=lambda x: dist(x[0], x[1]))
```
You also have some other logic issues: your slicing logic is not correct. I made some changes to your code, where `brute` is just a simple brute-force double loop that returns the closest pair:
```python
def closestPair(Px, Py):
    if len(Px) <= 3:
        return brute(Px)
    mid = len(Px) // 2  # integer division, so the slices also work in Python 3
    # get left and right half of Px (both stay sorted by x)
    Qx, Rx = Px[:mid], Px[mid:]
    # split Py by membership in the left half so Qy and Ry stay sorted by y
    # (assumes the points are distinct)
    left = set(Qx)
    Qy = [p for p in Py if p in left]
    Ry = [p for p in Py if p not in left]
    (p1, q1) = closestPair(Qx, Qy)
    (p2, q2) = closestPair(Rx, Ry)
    d = min(dist(p1, q1), dist(p2, q2))
    mn = min((p1, q1), (p2, q2), key=lambda x: dist(x[0], x[1]))
    (p3, q3) = closest_split_pair(Px, Py, d, mn)
    return min(mn, (p3, q3), key=lambda x: dist(x[0], x[1]))
```
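To see why `Qy = Py[:mid]` from the question goes wrong: it takes the `mid` lowest points of the whole set, not the left half sorted by y. A quick check on the second input (plain Python 3, no helpers needed):

```python
P = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2),
     (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)]
Px = sorted(P)                              # sorted by x
Py = sorted(P, key=lambda p: (p[1], p[0]))  # sorted by y
mid = len(Px) // 2

left = set(Px[:mid])                     # the points that actually form the left half
wrong_Qy = Py[:mid]                      # what the question passes down as Qy
right_Qy = [p for p in Py if p in left]  # left-half points, still in y order

print(wrong_Qy)
print(right_Qy)
```

Here `wrong_Qy` contains `(96, -79)` and `(99, -8)`, which both belong to the right half, so the strip search in the recursive calls works on the wrong set of points.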
I only implemented the algorithm today, so there are no doubt improvements to be made, but this will get you the correct answer.
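Putting the pieces together, here is a self-contained Python 3 sketch that combines the fixes above. It assumes the points are distinct and that there are at least two of them; the names `brute`, `closest_pair_rec` and the `closest_pair` wrapper are illustrative and not taken from the question.

```python
import math
from itertools import combinations


def dist(p, q):
    """Euclidean distance between two (x, y) points."""
    return math.hypot(p[0] - q[0], p[1] - q[1])


def brute(P):
    """O(n^2) closest pair for a small list (at least 2 points)."""
    return min(combinations(P, 2), key=lambda pair: dist(*pair))


def closest_split_pair(p_x, p_y, delta, best_pair):
    """Return a pair straddling the dividing line that is closer than delta,
    or best_pair unchanged if there is none."""
    mid_x = p_x[len(p_x) // 2][0]  # x coordinate of the dividing line
    # Strip of points within delta of the line, kept in y order.
    s_y = [p for p in p_y if mid_x - delta <= p[0] <= mid_x + delta]
    best = delta
    for i in range(len(s_y) - 1):
        # Only the next 7 points in y order can be closer than delta.
        for j in range(1, min(8, len(s_y) - i)):
            p, q = s_y[i], s_y[i + j]
            d = dist(p, q)
            if d < best:
                best_pair, best = (p, q), d
    return best_pair


def closest_pair_rec(p_x, p_y):
    """p_x and p_y hold the same points, sorted by x and by y respectively."""
    if len(p_x) <= 3:
        return brute(p_x)
    mid = len(p_x) // 2
    q_x, r_x = p_x[:mid], p_x[mid:]
    left = set(q_x)
    # Split p_y by membership so both halves stay sorted by y (distinct points).
    q_y = [p for p in p_y if p in left]
    r_y = [p for p in p_y if p not in left]
    mn = min(closest_pair_rec(q_x, q_y), closest_pair_rec(r_x, r_y),
             key=lambda pair: dist(*pair))
    split = closest_split_pair(p_x, p_y, dist(*mn), mn)
    return min(mn, split, key=lambda pair: dist(*pair))


def closest_pair(points):
    """Sort once, then recurse. Returns ((x1, y1), (x2, y2), distance)."""
    p_x = sorted(points)
    p_y = sorted(points, key=lambda p: (p[1], p[0]))
    p, q = closest_pair_rec(p_x, p_y)
    return p, q, dist(p, q)
```

On the second input from the question, `closest_pair(P)` returns the points `(99, -8)` and `(94, 5)` at distance ≈ 13.928. Splitting `p_y` by set membership costs O(n) per level, so the overall running time stays O(n log n).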