-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathkdtree.go
158 lines (145 loc) · 4.13 KB
/
kdtree.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
// Copyright 2012 Sonia Keys
// License MIT: http://www.opensource.org/licenses/MIT
// K-d tree example implementation.
//
// Implementation follows pseudocode from "An introductory tutorial on kd-trees"
// by Andrew W. Moore, Carnegie Mellon University, PDF accessed from
// http://www.autonlab.org/autonweb/14665
package kdtree
import (
"math"
"sort"
)
// Point is a k-dimensional point.
type Point []float64

// Sqd returns the squared Euclidean distance between p and q.
//
// Only the first len(p) coordinates take part; q must have at least that
// many coordinates or the indexing panics.
func (p Point) Sqd(q Point) float64 {
	total := 0.0
	for i := range p {
		diff := p[i] - q[i]
		total += diff * diff
	}
	return total
}
// HyperRect is a k-dimensional axis-aligned bounding box, given by its
// minimum and maximum corners.
type HyperRect struct {
	Min, Max Point
}

// Copy returns a deep copy of hr.
//
// Point is a slice type, so copying a HyperRect value on its own leaves the
// copy sharing Min and Max storage with the original. Copy allocates fresh
// coordinate slices so either rectangle can be modified without affecting
// the other.
func (hr HyperRect) Copy() HyperRect {
	lo := make(Point, len(hr.Min))
	copy(lo, hr.Min)
	hi := make(Point, len(hr.Max))
	copy(hi, hr.Max)
	return HyperRect{Min: lo, Max: hi}
}
// KdTree represents a k-d tree and associated k-d bounding box.
//
// A KdTree with a nil root holds no points; searching it finds nothing.
type KdTree struct {
	n      *kdNode   // root of the tree; nil when built from no points
	Bounds HyperRect // bounding box supplied to New, stored unchanged
}

// kdNode following field names in the paper.
// rangeElt would be whatever data is associated with the point.
// we don't bother with it for this example.
type kdNode struct {
	domElt      Point   // the point stored at this node (the paper's dom-elt)
	split       int     // index of the coordinate this node partitions on
	left, right *kdNode // as built by New: points <= / > domElt[split] on that coordinate
}
// New constructs a KdTree from a list of points and a bounding box.
//
// The bounds could be computed of course, but typically you know them
// already. bounds is stored in the tree's Bounds field unchanged.
//
// The splitting coordinate cycles through the dimensions of the points with
// depth, so all points are expected to have the same dimension.
//
// The tree is built over a private copy of the pts slice, so the caller's
// ordering is left untouched. The Point values themselves are shared, not
// copied.
func New(pts []Point, bounds HyperRect) KdTree {
	// algorithm is table 6.3 in the paper.
	var nk2 func([]Point, int) *kdNode
	nk2 = func(exset []Point, split int) *kdNode {
		if len(exset) == 0 {
			return nil
		}
		// Pivot choosing procedure: sort on the split coordinate, take the
		// median, then advance to the largest index holding the same value.
		// Everything before that index is <= the pivot and everything after
		// is strictly greater, which satisfies the inequalities of steps 6
		// and 7 in the algorithm.
		sort.Sort(part{exset, split})
		m := len(exset) / 2
		med := exset[m][split]
		for m+1 < len(exset) && exset[m+1][split] == med {
			m++
		}
		// The pivot must be the element at the final index m. Taking it
		// before advancing m would store that point twice (as the pivot and
		// again inside exset[:m]) and silently drop exset[m] from the tree.
		d := exset[m]
		// next split
		s2 := split + 1
		if s2 == len(d) {
			s2 = 0
		}
		return &kdNode{d, split, nk2(exset[:m], s2), nk2(exset[m+1:], s2)}
	}
	// Sort a copy of the slice so the caller's ordering is not disturbed.
	return KdTree{nk2(append([]Point(nil), pts...), 0), bounds}
}
// Nearest finds the point stored in t that is nearest to p.
//
// Return values:
//   - best: the nearest point in the tree, or nil when the tree is empty.
//   - bestSqd: the squared Euclidean distance from p to best (+Inf when the
//     tree is empty).
//   - nv: the number of tree nodes visited during the search.
func (t KdTree) Nearest(p Point) (best Point, bestSqd float64, nv int) {
	// Start the recursive search at the root with no distance bound.
	unbounded := math.Inf(1)
	best, bestSqd, nv = nn(t.n, p, t.Bounds, unbounded)
	return best, bestSqd, nv
}
// nn searches the subtree rooted at kd for the point nearest target.
//
// The algorithm is table 6.4 from the paper, with the addition of counting
// the number of nodes visited.
//
// hr is the bounding box of the subtree; it is cut through each pivot and
// handed down so every call knows its own region. maxDistSqd is a pruning
// bound: the far side of a node is skipped when its splitting plane is more
// than sqrt(maxDistSqd) away from target. Nearest passes +Inf.
//
// It returns the nearest point found (nil if none), its squared distance
// from target (+Inf if none), and the number of nodes visited. Only points
// strictly closer than maxDistSqd are guaranteed to be found. Within the
// recursion the bound always comes from a candidate some caller already
// holds, so nothing better than the final answer is lost.
func nn(kd *kdNode, target Point, hr HyperRect,
	maxDistSqd float64) (nearest Point, distSqd float64, nodesVisited int) {
	if kd == nil {
		return nil, math.Inf(1), 0
	}
	nodesVisited++
	s := kd.split
	pivot := kd.domElt

	// Cut hr into two sub-hyperrectangles with the plane through pivot
	// perpendicular to dimension s. Copies are taken because Point is a
	// slice type and the caller's hr must not change.
	leftHr := hr.Copy()
	rightHr := hr.Copy()
	leftHr.Max[s] = pivot[s]
	rightHr.Min[s] = pivot[s]

	// Ties on the split coordinate belong to the left half.
	var nearerKd, furtherKd *kdNode
	var nearerHr, furtherHr HyperRect
	if target[s] <= pivot[s] {
		nearerKd, nearerHr = kd.left, leftHr
		furtherKd, furtherHr = kd.right, rightHr
	} else {
		nearerKd, nearerHr = kd.right, rightHr
		furtherKd, furtherHr = kd.left, leftHr
	}

	// Search the half containing target first; its result tightens the bound.
	var nv int
	nearest, distSqd, nv = nn(nearerKd, target, nearerHr, maxDistSqd)
	nodesVisited += nv
	if distSqd < maxDistSqd {
		maxDistSqd = distSqd
	}

	// Every point on the far side, and the pivot itself, is at least as far
	// from target as the splitting plane, so stop if the plane is already
	// beyond the bound.
	d := pivot[s] - target[s]
	d *= d
	if d > maxDistSqd {
		return nearest, distSqd, nodesVisited
	}

	if d = pivot.Sqd(target); d < distSqd {
		nearest = pivot
		distSqd = d
		// Only ever tighten the bound. The paper assigns max-dist-sqd :=
		// dist-sqd here unconditionally, which loosens it whenever the bound
		// inherited from the caller was already below the pivot's distance.
		// The far-side search would then visit nodes that cannot beat the
		// caller's candidate.
		if d < maxDistSqd {
			maxDistSqd = d
		}
	}

	tempNearest, tempSqd, nv := nn(furtherKd, target, furtherHr, maxDistSqd)
	nodesVisited += nv
	if tempSqd < distSqd {
		nearest = tempNearest
		distSqd = tempSqd
	}
	return nearest, distSqd, nodesVisited
}
// part adapts a slice of points to sort.Interface, ordering the points by
// their coordinate in dimension dPart. Sorting through it reorders pts in
// place.
type part struct {
	pts   []Point
	dPart int
}

// Len reports how many points are being sorted.
func (p part) Len() int {
	return len(p.pts)
}

// Less orders points by their coordinate in dimension dPart.
func (p part) Less(i, j int) bool {
	a, b := p.pts[i], p.pts[j]
	return a[p.dPart] < b[p.dPart]
}

// Swap exchanges the points at indices i and j.
func (p part) Swap(i, j int) {
	p.pts[j], p.pts[i] = p.pts[i], p.pts[j]
}