1 |
package ags.utils.dataStructures.trees.thirdGenKD; |
2 |
|
3 |
import ags.utils.dataStructures.BinaryHeap; |
4 |
import ags.utils.dataStructures.MaxHeap; |
5 |
import ags.utils.dataStructures.MinHeap; |
6 |
|
7 |
/** |
8 |
* |
9 |
*/ |
10 |
public class KdTree<T> extends KdNode<T> { |
11 |
public KdTree(int dimensions) { |
12 |
this(dimensions, 24); |
13 |
} |
14 |
|
15 |
public KdTree(int dimensions, int bucketCapacity) { |
16 |
super(dimensions, bucketCapacity); |
17 |
} |
18 |
|
19 |
public NearestNeighborIterator<T> getNearestNeighborIterator(double[] searchPoint, int maxPointsReturned, DistanceFunction distanceFunction) { |
20 |
return new NearestNeighborIterator<T>(this, searchPoint, maxPointsReturned, distanceFunction); |
21 |
} |
22 |
|
23 |
public MaxHeap<T> findNearestNeighbors(double[] searchPoint, int maxPointsReturned, DistanceFunction distanceFunction) { |
24 |
BinaryHeap.Min<KdNode<T>> pendingPaths = new BinaryHeap.Min<KdNode<T>>(); |
25 |
BinaryHeap.Max<T> evaluatedPoints = new BinaryHeap.Max<T>(); |
26 |
int pointsRemaining = Math.min(maxPointsReturned, size()); |
27 |
pendingPaths.offer(0, this); |
28 |
|
29 |
while (pendingPaths.size() > 0 && (evaluatedPoints.size() < pointsRemaining || (pendingPaths.getMinKey() < evaluatedPoints.getMaxKey()))) { |
30 |
nearestNeighborSearchStep(pendingPaths, evaluatedPoints, pointsRemaining, distanceFunction, searchPoint); |
31 |
} |
32 |
|
33 |
return evaluatedPoints; |
34 |
} |
35 |
|
36 |
@SuppressWarnings("unchecked") |
37 |
protected static <T> void nearestNeighborSearchStep ( |
38 |
MinHeap<KdNode<T>> pendingPaths, MaxHeap<T> evaluatedPoints, int desiredPoints, |
39 |
DistanceFunction distanceFunction, double[] searchPoint) { |
40 |
// If there are pending paths possibly closer than the nearest evaluated point, check it out |
41 |
KdNode<T> cursor = pendingPaths.getMin(); |
42 |
pendingPaths.removeMin(); |
43 |
|
44 |
// Descend the tree, recording paths not taken |
45 |
while (!cursor.isLeaf()) { |
46 |
KdNode<T> pathNotTaken; |
47 |
if (searchPoint[cursor.splitDimension] > cursor.splitValue) { |
48 |
pathNotTaken = cursor.left; |
49 |
cursor = cursor.right; |
50 |
} |
51 |
else { |
52 |
pathNotTaken = cursor.right; |
53 |
cursor = cursor.left; |
54 |
} |
55 |
double otherDistance = distanceFunction.distanceToRect(searchPoint, pathNotTaken.minBound, pathNotTaken.maxBound); |
56 |
// Only add a path if we either need more points or it's closer than furthest point on list so far |
57 |
if (evaluatedPoints.size() < desiredPoints || otherDistance <= evaluatedPoints.getMaxKey()) { |
58 |
pendingPaths.offer(otherDistance, pathNotTaken); |
59 |
} |
60 |
} |
61 |
|
62 |
if (cursor.singlePoint) { |
63 |
double nodeDistance = distanceFunction.distance(cursor.points[0], searchPoint); |
64 |
// Only add a point if either need more points or it's closer than furthest on list so far |
65 |
if (evaluatedPoints.size() < desiredPoints || nodeDistance <= evaluatedPoints.getMaxKey()) { |
66 |
for (int i = 0; i < cursor.size(); i++) { |
67 |
T value = (T) cursor.data[i]; |
68 |
|
69 |
// If we don't need any more, replace max |
70 |
if (evaluatedPoints.size() == desiredPoints) { |
71 |
evaluatedPoints.replaceMax(nodeDistance, value); |
72 |
} else { |
73 |
evaluatedPoints.offer(nodeDistance, value); |
74 |
} |
75 |
} |
76 |
} |
77 |
} else { |
78 |
// Add the points at the cursor |
79 |
for (int i = 0; i < cursor.size(); i++) { |
80 |
double[] point = cursor.points[i]; |
81 |
T value = (T) cursor.data[i]; |
82 |
double distance = distanceFunction.distance(point, searchPoint); |
83 |
// Only add a point if either need more points or it's closer than furthest on list so far |
84 |
if (evaluatedPoints.size() < desiredPoints) { |
85 |
evaluatedPoints.offer(distance, value); |
86 |
} else if (distance < evaluatedPoints.getMaxKey()) { |
87 |
evaluatedPoints.replaceMax(distance, value); |
88 |
} |
89 |
} |
90 |
} |
91 |
} |
92 |
} |