Skip to content

Commit 73b870f

Browse files
authored
Merge pull request #2 from jlvdb/nearest-neighbours
Added fixed-number nearest neighbour search
2 parents 15d136a + 9bc0408 commit 73b870f

20 files changed

+518
-118
lines changed

README.rst

+3-6
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,9 @@
1515
1616
A fast ball tree implementation for three dimensional (weighted) data with an
1717
Euclidean distance norm. The base implementation is in `C` and there is a
18-
wrapper for `Python`.
19-
20-
The tree is optimised towards spatial correlation function calculations since
21-
the query routines are geared towards range queries, i.e. counting pairs with a
22-
given (range of) separations. Fixed number nearest neighbour search is currently
23-
not implemented.
18+
wrapper for `Python`. The tree is optimised towards spatial correlation function
19+
calculations since it provides fast counting routines, e.g. by implementing a
20+
dualtree query algorithm.
2421

2522
- Code: https://github.com/jlvdb/balltree.git
2623
- Docs: https://balltree.readthedocs.io/

balltree/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.0.1"
1+
__version__ = "1.1.0"
22

33
from .angulartree import AngularTree
44
from .balltree import BallTree, default_leafsize as _df

balltree/angulartree.py

+39-8
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ def from_random(
5151
Build a new AngularTree instance from randomly generated points.
5252
5353
The (ra, dec) coordinates are generated uniformly in the interval
54-
[`ra_min`, `ra_max`) and [`dec_min`, `dec_max`), respectively. `size`
55-
controlls the number of points generated. The optional `leafsize`
56-
determines when the tree query algorithms switch from traversal to brute
57-
force.
54+
[``ra_min``, ``ra_max``) and [``dec_min``, ``dec_max``), respectively.
55+
``size`` controls the number of points generated. The optional
56+
``leafsize`` determines when the tree query algorithms switch from
57+
traversal to brute force.
5858
"""
5959
((x_min, y_min),) = coord.angular_to_cylinder([ra_min, dec_min])
6060
((x_max, y_max),) = coord.angular_to_cylinder([ra_max, dec_max])
@@ -74,11 +74,12 @@ def data(self) -> NDArray:
7474
np.transpose([data["x"], data["y"], data["z"]])
7575
)
7676

77-
dtype = [("ra", "f8"), ("dec", "f8"), ("weight", "f8")]
77+
dtype = [("ra", "f8"), ("dec", "f8"), ("weight", "f8"), ("index", "i8")]
7878
array = np.empty(len(data), dtype=dtype)
7979
array["ra"] = radec[:, 0]
8080
array["dec"] = radec[:, 1]
8181
array["weight"] = data["weight"]
82+
array["index"] = data["index"]
8283
return array
8384

8485
@property
@@ -102,7 +103,7 @@ def radius(self) -> float:
102103
center = coord.angular_to_euclidean(self.center)[0]
103104
radec_flat = self.data.view("f8")
104105
shape = (self.num_points, -1)
105-
xyz = coord.angular_to_euclidean(radec_flat.reshape(shape)[:, :-1])
106+
xyz = coord.angular_to_euclidean(radec_flat.reshape(shape)[:, :-2])
106107
# compute the maximum distance from the center project one the sphere
107108
diff = xyz - center[np.newaxis, :]
108109
dist = np.sqrt(np.sum(diff**2, axis=1))
@@ -120,15 +121,45 @@ def get_node_data() -> NDArray:
120121
"""
121122
Collect the meta data of all tree nodes in a numpy array.
122123
123-
The array fields record `depth` (starting from the root node),
124-
`num_points`, `sum_weight`, `x`, `y`, `z` (node center) and node `radius`.
124+
The array fields record ``depth`` (starting from the root node),
125+
``num_points``, ``sum_weight``, ``x``, ``y``, ``z`` (node center) and
126+
node ``radius``.
125127
126128
.. Note::
127129
The node coordinates and radius are currently not converted to
128130
angles.
129131
"""
130132
return super().get_node_data()
131133

134+
def nearest_neighbours(self, radec, k, max_ang=-1.0) -> NDArray:
135+
"""
136+
Query a fixed number of nearest neighbours.
137+
138+
The query point(s) ``radec`` can be a numpy array of shape (2,) or (N, 2),
139+
or an equivalent python object. The number of neighbours ``k`` must be a
140+
positive integer and the optional ``max_ang`` parameter puts an upper
141+
bound on the angular separation (in radian) to the neighbours.
142+
143+
Returns an array with fields ``index``, holding the index to the neighbour
144+
in the array from which the tree was constructed, and ``angle``, the
145+
angular separation in radian. The result is sorted by separation,
146+
missing neighbours (e.g. if ``angle > max_ang``) are indicated by an
147+
index of -1 and infinite separation.
148+
"""
149+
xyz = coord.angular_to_euclidean(radec)
150+
if max_ang > 0:
151+
max_dist = coord.angle_to_chorddist(max_ang)
152+
else:
153+
max_dist = -1.0
154+
raw = super().nearest_neighbours(xyz, k, max_dist=max_dist)
155+
good = raw["index"] >= 0
156+
157+
result = np.empty(raw.shape, dtype=[("index", "i8"), ("angle", "f8")])
158+
result["index"] = raw["index"]
159+
result["angle"][~good] = np.inf
160+
result["angle"][good] = coord.chorddist_to_angle(raw["distance"][good])
161+
return result
162+
132163
def brute_radius(
133164
self,
134165
radec: ArrayLike,

balltree/balltree.c

+142-12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <stdint.h>
66
#include <stdio.h>
7+
#include <string.h>
78

89
#include "point.h"
910
#include "balltree.h"
@@ -65,6 +66,7 @@ void inputiterdata_free(InputIterData *);
6566
static PointBuffer *ptbuf_from_PyObjects(PyObject *xyz_obj, PyObject *weight_obj);
6667
static PyObject *ptbuf_get_numpy_view(PointBuffer *buffer);
6768
static PyObject *statvec_get_numpy_array(StatsVector *vec);
69+
static PyObject *queueitems_get_numpy_array(QueueItem *items, npy_intp size, npy_intp n_items);
6870
static DistHistogram *disthistogram_from_PyObject(PyObject *edges_obj);
6971
static PyObject *PyObject_from_disthistogram(DistHistogram *hist);
7072
static PyObject *PyBallTree_accumulate_radius(PyBallTree *self, count_radius_func accumulator, PyObject *xyz_obj, double radius, PyObject *weight_obj);
@@ -87,6 +89,7 @@ static PyObject *PyBallTree_get_radius(PyBallTree *self, void *closure);
8789
static PyObject *PyBallTree_str(PyObject *self);
8890
static PyObject *PyBallTree_to_file(PyBallTree *self, PyObject *args);
8991
static PyObject *PyBallTree_count_nodes(PyBallTree *self);
92+
static PyObject *PyBallTree_nearest_neighbours(PyBallTree *self, PyObject *args, PyObject *kwargs);
9093
static PyObject *PyBallTree_get_node_data(PyBallTree *self);
9194
static PyObject *PyBallTree_brute_radius(PyBallTree *self, PyObject *args, PyObject *kwargs);
9295
static PyObject *PyBallTree_count_radius(PyBallTree *self, PyObject *args, PyObject *kwargs);
@@ -230,6 +233,7 @@ PyArrayObject *numpy_array_add_dim(PyArrayObject* array) {
230233
if (reshaped == NULL) {
231234
PyErr_SetString(PyExc_MemoryError, "failed to reshape array");
232235
}
236+
Py_DECREF(array);
233237
return reshaped;
234238
}
235239

@@ -279,7 +283,6 @@ static PyArrayObject *xyz_ensure_2dim_double(PyObject *xyz_obj) {
279283
npy_int ndim = PyArray_NDIM(xyz_arr);
280284
if (ndim == 1) {
281285
xyz_arr_2dim = numpy_array_add_dim(xyz_arr);
282-
Py_DECREF(xyz_arr);
283286
if (xyz_arr_2dim == NULL) {
284287
return NULL;
285288
}
@@ -393,7 +396,12 @@ static PointBuffer *ptbuf_from_PyObjects(PyObject *xyz_obj, PyObject *weight_obj
393396
int64_t idx = 0;
394397
double x, y, z;
395398
while (iter_get_next_xyz(data->xyz_iter, &x, &y, &z)) {
396-
buffer->points[idx] = (Point){x, y, z, data->weight_buffer[idx]};
399+
Point *point = buffer->points + idx;
400+
point->x = x;
401+
point->y = y;
402+
point->z = z;
403+
point->weight = data->weight_buffer[idx];
404+
// index is already initialised
397405
++idx;
398406
}
399407
inputiterdata_free(data);
@@ -406,11 +414,12 @@ static PyObject *ptbuf_get_numpy_view(PointBuffer *buffer) {
406414

407415
// construct an appropriate dtype for Point
408416
PyObject *arr_dtype = Py_BuildValue(
409-
"[(ss)(ss)(ss)(ss)]",
417+
"[(ss)(ss)(ss)(ss)(ss)]",
410418
"x", "f8",
411419
"y", "f8",
412420
"z", "f8",
413-
"weight", "f8"
421+
"weight", "f8",
422+
"index", "i8"
414423
);
415424
if (arr_dtype == NULL) {
416425
return NULL;
@@ -494,15 +503,48 @@ static PyObject *statvec_get_numpy_array(StatsVector *vec) {
494503

495504
// create an uninitialised array and copy the data into it
496505
PyObject *array = PyArray_Empty(ndim, shape, arr_descr, 0);
497-
Py_DECREF(arr_descr);
498506
if (array == NULL) {
507+
Py_DECREF(arr_descr);
499508
return NULL;
500509
}
501510
void *ptr = PyArray_DATA(array);
502511
memcpy(ptr, vec->stats, sizeof(NodeStats) * vec->size);
503512
return array;
504513
}
505514

515+
static PyObject *queueitems_get_numpy_array(QueueItem *items, npy_intp size, npy_intp n_items) {
516+
const npy_intp ndim = 2;
517+
npy_intp shape[2] = {size, n_items};
518+
519+
// construct an appropriate dtype for QueueItem buffer
520+
PyObject *arr_dtype = Py_BuildValue(
521+
"[(ss)(ss)]",
522+
"index", "i8",
523+
"distance", "f8"
524+
);
525+
if (arr_dtype == NULL) {
526+
return NULL;
527+
}
528+
529+
// get the numpy API array descriptor
530+
PyArray_Descr *arr_descr;
531+
int result = PyArray_DescrConverter(arr_dtype, &arr_descr); // PyArray_Descr **
532+
Py_DECREF(arr_dtype);
533+
if (result != NPY_SUCCEED) {
534+
return NULL;
535+
}
536+
537+
// create an uninitialised array and copy the data into it
538+
PyObject *array = PyArray_Empty(ndim, shape, arr_descr, 0);
539+
if (array == NULL) {
540+
Py_DECREF(arr_descr);
541+
return NULL;
542+
}
543+
void *ptr = PyArray_DATA(array);
544+
memcpy(ptr, items, sizeof(QueueItem) * size * n_items);
545+
return array;
546+
}
547+
506548
static PyObject *PyBallTree_accumulate_radius(
507549
PyBallTree *self,
508550
count_radius_func accumulator,
@@ -517,7 +559,7 @@ static PyObject *PyBallTree_accumulate_radius(
517559
// count neighbours for all inputs
518560
double count = 0.0;
519561
int64_t idx = 0;
520-
Point point;
562+
Point point = {0.0, 0.0, 0.0, 0.0, 0};
521563
while (iter_get_next_xyz(data->xyz_iter, &point.x, &point.y, &point.z)) {
522564
point.weight = data->weight_buffer[idx];
523565
count += accumulator(self->balltree, &point, radius);
@@ -545,7 +587,7 @@ static PyObject *PyBallTree_accumulate_range(
545587
}
546588
// count neighbours for all inputs
547589
int64_t idx = 0;
548-
Point point;
590+
Point point = {0.0, 0.0, 0.0, 0};
549591
while (iter_get_next_xyz(data->xyz_iter, &point.x, &point.y, &point.z)) {
550592
point.weight = data->weight_buffer[idx];
551593
accumulator(self->balltree, &point, hist);
@@ -613,8 +655,8 @@ PyDoc_STRVAR(
613655
"--\n\n"
614656
"Build a new BallTree instance from randomly generated points.\n\n"
615657
"The (x, y, z) coordinates are generated uniformly in the interval\n"
616-
"[`low`, `high`), `size` controlls the number of points generated. The\n"
617-
"optional `leafsize` determines when the tree query algorithms switch from\n"
658+
"[``low``, ``high``), ``size`` controlls the number of points generated. The\n"
659+
"optional ``leafsize`` determines when the tree query algorithms switch from\n"
618660
"traversal to brute force."
619661
);
620662

@@ -794,7 +836,7 @@ static PyObject *PyBallTree_str(PyObject *self) {
794836
int n_bytes = snprintf(
795837
buffer,
796838
sizeof(buffer),
797-
"BallTree(num_points=%ld, radius=%lf, center=(%lf, %lf, %lf))",
839+
"BallTree(num_points=%lld, radius=%lf, center=(%lf, %lf, %lf))",
798840
tree->data->size,
799841
node->ball.radius,
800842
node->ball.x,
@@ -858,8 +900,9 @@ PyDoc_STRVAR(
858900
"get_node_data(self) -> NDArray\n"
859901
"--\n\n"
860902
"Collect the meta data of all tree nodes in a numpy array.\n\n"
861-
"The array fields record `depth` (starting from the root node),\n"
862-
"`num_points`, `sum_weight`, `x`, `y`, `z` (node center) and node `radius`."
903+
"The array fields record ``depth`` (starting from the root node),\n"
904+
"``num_points``, ``sum_weight``, ``x``, ``y``, ``z`` (node center) and node\n"
905+
"``radius``."
863906
);
864907

865908
static PyObject *PyBallTree_get_node_data(PyBallTree *self) {
@@ -873,6 +916,87 @@ static PyObject *PyBallTree_get_node_data(PyBallTree *self) {
873916
return array;
874917
}
875918

919+
PyDoc_STRVAR(
920+
// .. py:method::
921+
nearest_neighbours_doc,
922+
"nearest_neighbours(self, xyz: ArrayLike, k: int, max_dist: float = -1.0) -> NDArray\n"
923+
"--\n\n"
924+
"Query a fixed number of nearest neighbours.\n\n"
925+
"The query point(s) ``xyz`` can be a numpy array of shape (3,) or (N, 3),\n"
926+
"or an equivalent python object. The number of neighbours ``k`` must be a\n"
927+
"positive integer and the optional ``max_dist`` parameter puts an upper\n"
928+
"bound on the separation to the neighbours.\n\n"
929+
"Returns an array with fields ``index``, holding the index to the neighbour\n"
930+
"in the array from which the tree was constructed, and ``distance``, the\n"
931+
"separation. The result is sorted by separation, missing neighbours (e.g. if\n"
932+
"``distance > max_dist``) are indicated by an index of -1 and infinite\n"
933+
"separation.\n"
934+
);
935+
936+
static PyObject *PyBallTree_nearest_neighbours(
937+
PyBallTree *self,
938+
PyObject *args,
939+
PyObject *kwargs
940+
) {
941+
static char *kwlist[] = {"xyz", "k", "max_dist", NULL};
942+
PyObject *xyz_obj;
943+
long num_neighbours; // screw Windows
944+
double max_dist = -1.0;
945+
if (!PyArg_ParseTupleAndKeywords(
946+
args, kwargs, "Ol|d", kwlist,
947+
&xyz_obj, &num_neighbours, &max_dist)
948+
) {
949+
return NULL;
950+
}
951+
if (num_neighbours < 1) {
952+
PyErr_SetString(PyExc_ValueError, "number of neighbours must be positive");
953+
return NULL;
954+
}
955+
956+
InputIterData *data = inputiterdata_new(xyz_obj, Py_None);
957+
if (data == NULL) {
958+
return NULL;
959+
}
960+
961+
// allocate output buffer
962+
size_t n_bytes_queue = num_neighbours * sizeof(QueueItem);
963+
QueueItem *result = malloc(data->size * n_bytes_queue);
964+
if (result == NULL) {
965+
PyErr_SetString(PyExc_MemoryError, "failed to allocate output array");
966+
inputiterdata_free(data);
967+
return NULL;
968+
}
969+
970+
// find neighbours for all inputs
971+
KnnQueue *queue = NULL;
972+
PyObject *pyresult = NULL;
973+
int64_t idx = 0;
974+
Point point = {0.0, 0.0, 0.0, 0.0, 0};
975+
while (iter_get_next_xyz(data->xyz_iter, &point.x, &point.y, &point.z)) {
976+
queue = balltree_nearest_neighbours(
977+
self->balltree,
978+
&point,
979+
num_neighbours,
980+
max_dist
981+
);
982+
if (queue == NULL) {
983+
printf("oops\n");
984+
goto error;
985+
}
986+
// copy result into output buffer
987+
memcpy(&result[idx], queue->items, n_bytes_queue);
988+
knque_free(queue);
989+
idx += num_neighbours;
990+
}
991+
992+
// convert to numpy array
993+
pyresult = queueitems_get_numpy_array(result, data->size, num_neighbours);
994+
error:
995+
free(result);
996+
inputiterdata_free(data);
997+
return pyresult;
998+
}
999+
8761000
PyDoc_STRVAR(
8771001
// .. py:method::
8781002
brute_radius_doc,
@@ -1139,6 +1263,12 @@ static PyMethodDef PyBallTree_methods[] = {
11391263
.ml_flags = METH_NOARGS,
11401264
.ml_doc = get_node_data_doc
11411265
},
1266+
{
1267+
.ml_name = "nearest_neighbours",
1268+
.ml_meth = (PyCFunctionWithKeywords)PyBallTree_nearest_neighbours,
1269+
.ml_flags = METH_VARARGS | METH_KEYWORDS,
1270+
.ml_doc = nearest_neighbours_doc
1271+
},
11421272
{
11431273
.ml_name = "brute_radius",
11441274
.ml_meth = (PyCFunctionWithKeywords)PyBallTree_brute_radius,

include/ballnode.h

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "point.h"
77
#include "histogram.h"
8+
#include "queue.h"
89

910
#define BALLNODE_IS_LEAF(node) (node)->is_leaf
1011

@@ -57,6 +58,7 @@ int bnode_is_leaf(const BallNode *);
5758
PointSlice bnode_get_ptslc(const BallNode *);
5859

5960
// from ballnode_query.c
61+
void bnode_nearest_neighbours(const BallNode *, const Point *, KnnQueue *);
6062
double bnode_count_radius(const BallNode *, const Point *, double);
6163
void bnode_count_range(const BallNode *, const Point *, DistHistogram *);
6264
double bnode_dualcount_radius(const BallNode *, const BallNode *, double);

0 commit comments

Comments
 (0)