4
4
5
5
#include <stdint.h>
6
6
#include <stdio.h>
7
+ #include <string.h>
7
8
8
9
#include "point.h"
9
10
#include "balltree.h"
@@ -65,6 +66,7 @@ void inputiterdata_free(InputIterData *);
65
66
static PointBuffer * ptbuf_from_PyObjects (PyObject * xyz_obj , PyObject * weight_obj );
66
67
static PyObject * ptbuf_get_numpy_view (PointBuffer * buffer );
67
68
static PyObject * statvec_get_numpy_array (StatsVector * vec );
69
+ static PyObject * queueitems_get_numpy_array (QueueItem * items , npy_intp size , npy_intp n_items );
68
70
static DistHistogram * disthistogram_from_PyObject (PyObject * edges_obj );
69
71
static PyObject * PyObject_from_disthistogram (DistHistogram * hist );
70
72
static PyObject * PyBallTree_accumulate_radius (PyBallTree * self , count_radius_func accumulator , PyObject * xyz_obj , double radius , PyObject * weight_obj );
@@ -87,6 +89,7 @@ static PyObject *PyBallTree_get_radius(PyBallTree *self, void *closure);
87
89
static PyObject * PyBallTree_str (PyObject * self );
88
90
static PyObject * PyBallTree_to_file (PyBallTree * self , PyObject * args );
89
91
static PyObject * PyBallTree_count_nodes (PyBallTree * self );
92
+ static PyObject * PyBallTree_nearest_neighbours (PyBallTree * self , PyObject * args , PyObject * kwargs );
90
93
static PyObject * PyBallTree_get_node_data (PyBallTree * self );
91
94
static PyObject * PyBallTree_brute_radius (PyBallTree * self , PyObject * args , PyObject * kwargs );
92
95
static PyObject * PyBallTree_count_radius (PyBallTree * self , PyObject * args , PyObject * kwargs );
@@ -230,6 +233,7 @@ PyArrayObject *numpy_array_add_dim(PyArrayObject* array) {
230
233
if (reshaped == NULL ) {
231
234
PyErr_SetString (PyExc_MemoryError , "failed to reshape array" );
232
235
}
236
+ Py_DECREF (array );
233
237
return reshaped ;
234
238
}
235
239
@@ -279,7 +283,6 @@ static PyArrayObject *xyz_ensure_2dim_double(PyObject *xyz_obj) {
279
283
npy_int ndim = PyArray_NDIM (xyz_arr );
280
284
if (ndim == 1 ) {
281
285
xyz_arr_2dim = numpy_array_add_dim (xyz_arr );
282
- Py_DECREF (xyz_arr );
283
286
if (xyz_arr_2dim == NULL ) {
284
287
return NULL ;
285
288
}
@@ -393,7 +396,12 @@ static PointBuffer *ptbuf_from_PyObjects(PyObject *xyz_obj, PyObject *weight_obj
393
396
int64_t idx = 0 ;
394
397
double x , y , z ;
395
398
while (iter_get_next_xyz (data -> xyz_iter , & x , & y , & z )) {
396
- buffer -> points [idx ] = (Point ){x , y , z , data -> weight_buffer [idx ]};
399
+ Point * point = buffer -> points + idx ;
400
+ point -> x = x ;
401
+ point -> y = y ;
402
+ point -> z = z ;
403
+ point -> weight = data -> weight_buffer [idx ];
404
+ // index is already initialised
397
405
++ idx ;
398
406
}
399
407
inputiterdata_free (data );
@@ -406,11 +414,12 @@ static PyObject *ptbuf_get_numpy_view(PointBuffer *buffer) {
406
414
407
415
// construct an appropriate dtype for Point
408
416
PyObject * arr_dtype = Py_BuildValue (
409
- "[(ss)(ss)(ss)(ss)]" ,
417
+ "[(ss)(ss)(ss)(ss)(ss) ]" ,
410
418
"x" , "f8" ,
411
419
"y" , "f8" ,
412
420
"z" , "f8" ,
413
- "weight" , "f8"
421
+ "weight" , "f8" ,
422
+ "index" , "i8"
414
423
);
415
424
if (arr_dtype == NULL ) {
416
425
return NULL ;
@@ -494,15 +503,48 @@ static PyObject *statvec_get_numpy_array(StatsVector *vec) {
494
503
495
504
// create an uninitialised array and copy the data into it
496
505
PyObject * array = PyArray_Empty (ndim , shape , arr_descr , 0 );
497
- Py_DECREF (arr_descr );
498
506
if (array == NULL ) {
507
+ Py_DECREF (arr_descr );
499
508
return NULL ;
500
509
}
501
510
void * ptr = PyArray_DATA (array );
502
511
memcpy (ptr , vec -> stats , sizeof (NodeStats ) * vec -> size );
503
512
return array ;
504
513
}
505
514
515
+ static PyObject * queueitems_get_numpy_array (QueueItem * items , npy_intp size , npy_intp n_items ) {
516
+ const npy_intp ndim = 2 ;
517
+ npy_intp shape [2 ] = {size , n_items };
518
+
519
+ // construct an appropriate dtype for QueueItem buffer
520
+ PyObject * arr_dtype = Py_BuildValue (
521
+ "[(ss)(ss)]" ,
522
+ "index" , "i8" ,
523
+ "distance" , "f8"
524
+ );
525
+ if (arr_dtype == NULL ) {
526
+ return NULL ;
527
+ }
528
+
529
+ // get the numpy API array descriptor
530
+ PyArray_Descr * arr_descr ;
531
+ int result = PyArray_DescrConverter (arr_dtype , & arr_descr ); // PyArray_Descr **
532
+ Py_DECREF (arr_dtype );
533
+ if (result != NPY_SUCCEED ) {
534
+ return NULL ;
535
+ }
536
+
537
+ // create an uninitialised array and copy the data into it
538
+ PyObject * array = PyArray_Empty (ndim , shape , arr_descr , 0 );
539
+ if (array == NULL ) {
540
+ Py_DECREF (arr_descr );
541
+ return NULL ;
542
+ }
543
+ void * ptr = PyArray_DATA (array );
544
+ memcpy (ptr , items , sizeof (QueueItem ) * size * n_items );
545
+ return array ;
546
+ }
547
+
506
548
static PyObject * PyBallTree_accumulate_radius (
507
549
PyBallTree * self ,
508
550
count_radius_func accumulator ,
@@ -517,7 +559,7 @@ static PyObject *PyBallTree_accumulate_radius(
517
559
// count neighbours for all inputs
518
560
double count = 0.0 ;
519
561
int64_t idx = 0 ;
520
- Point point ;
562
+ Point point = { 0.0 , 0.0 , 0.0 , 0.0 , 0 } ;
521
563
while (iter_get_next_xyz (data -> xyz_iter , & point .x , & point .y , & point .z )) {
522
564
point .weight = data -> weight_buffer [idx ];
523
565
count += accumulator (self -> balltree , & point , radius );
@@ -545,7 +587,7 @@ static PyObject *PyBallTree_accumulate_range(
545
587
}
546
588
// count neighbours for all inputs
547
589
int64_t idx = 0 ;
548
- Point point ;
590
+ Point point = { 0.0 , 0.0 , 0.0 , 0 } ;
549
591
while (iter_get_next_xyz (data -> xyz_iter , & point .x , & point .y , & point .z )) {
550
592
point .weight = data -> weight_buffer [idx ];
551
593
accumulator (self -> balltree , & point , hist );
@@ -613,8 +655,8 @@ PyDoc_STRVAR(
613
655
"--\n\n"
614
656
"Build a new BallTree instance from randomly generated points.\n\n"
615
657
"The (x, y, z) coordinates are generated uniformly in the interval\n"
616
- "[`low`, `high`), `size` controlls the number of points generated. The\n"
617
- "optional `leafsize` determines when the tree query algorithms switch from\n"
658
+ "[`` low`` , `` high`` ), `` size` ` controlls the number of points generated. The\n"
659
+ "optional `` leafsize` ` determines when the tree query algorithms switch from\n"
618
660
"traversal to brute force."
619
661
);
620
662
@@ -794,7 +836,7 @@ static PyObject *PyBallTree_str(PyObject *self) {
794
836
int n_bytes = snprintf (
795
837
buffer ,
796
838
sizeof (buffer ),
797
- "BallTree(num_points=%ld , radius=%lf, center=(%lf, %lf, %lf))" ,
839
+ "BallTree(num_points=%lld , radius=%lf, center=(%lf, %lf, %lf))" ,
798
840
tree -> data -> size ,
799
841
node -> ball .radius ,
800
842
node -> ball .x ,
@@ -858,8 +900,9 @@ PyDoc_STRVAR(
858
900
"get_node_data(self) -> NDArray\n"
859
901
"--\n\n"
860
902
"Collect the meta data of all tree nodes in a numpy array.\n\n"
861
- "The array fields record `depth` (starting from the root node),\n"
862
- "`num_points`, `sum_weight`, `x`, `y`, `z` (node center) and node `radius`."
903
+ "The array fields record ``depth`` (starting from the root node),\n"
904
+ "``num_points``, ``sum_weight``, ``x``, ``y``, ``z`` (node center) and node\n"
905
+ "``radius``."
863
906
);
864
907
865
908
static PyObject * PyBallTree_get_node_data (PyBallTree * self ) {
@@ -873,6 +916,87 @@ static PyObject *PyBallTree_get_node_data(PyBallTree *self) {
873
916
return array ;
874
917
}
875
918
919
+ PyDoc_STRVAR (
920
+ // .. py:method::
921
+ nearest_neighbours_doc ,
922
+ "nearest_neighbours(self, xyz: ArrayLike, k: int, max_dist: float = -1.0) -> NDArray\n"
923
+ "--\n\n"
924
+ "Query a fixed number of nearest neighbours.\n\n"
925
+ "The query point(s) ``xyz`` can be a numpy array of shape (3,) or (N, 3),\n"
926
+ "or an equivalent python object. The number of neighbours ``k`` must be a\n"
927
+ "positive integer and the optional ``max_dist`` parameter puts an upper\n"
928
+ "bound on the separation to the neighbours.\n\n"
929
+ "Returns an array with fields ``index``, holding the index to the neighbour\n"
930
+ "in the array from which the tree was constructed, and ``distance``, the\n"
931
+ "separation. The result is sorted by separation, missing neighbours (e.g. if\n"
932
+ "``distance > max_dist``) are indicated by an index of -1 and infinite\n"
933
+ "separation.\n"
934
+ );
935
+
936
+ static PyObject * PyBallTree_nearest_neighbours (
937
+ PyBallTree * self ,
938
+ PyObject * args ,
939
+ PyObject * kwargs
940
+ ) {
941
+ static char * kwlist [] = {"xyz" , "k" , "max_dist" , NULL };
942
+ PyObject * xyz_obj ;
943
+ long num_neighbours ; // screw Windows
944
+ double max_dist = -1.0 ;
945
+ if (!PyArg_ParseTupleAndKeywords (
946
+ args , kwargs , "Ol|d" , kwlist ,
947
+ & xyz_obj , & num_neighbours , & max_dist )
948
+ ) {
949
+ return NULL ;
950
+ }
951
+ if (num_neighbours < 1 ) {
952
+ PyErr_SetString (PyExc_ValueError , "number of neighbours must be positive" );
953
+ return NULL ;
954
+ }
955
+
956
+ InputIterData * data = inputiterdata_new (xyz_obj , Py_None );
957
+ if (data == NULL ) {
958
+ return NULL ;
959
+ }
960
+
961
+ // allocate output buffer
962
+ size_t n_bytes_queue = num_neighbours * sizeof (QueueItem );
963
+ QueueItem * result = malloc (data -> size * n_bytes_queue );
964
+ if (result == NULL ) {
965
+ PyErr_SetString (PyExc_MemoryError , "failed to allocate output array" );
966
+ inputiterdata_free (data );
967
+ return NULL ;
968
+ }
969
+
970
+ // find neighbours for all inputs
971
+ KnnQueue * queue = NULL ;
972
+ PyObject * pyresult = NULL ;
973
+ int64_t idx = 0 ;
974
+ Point point = {0.0 , 0.0 , 0.0 , 0.0 , 0 };
975
+ while (iter_get_next_xyz (data -> xyz_iter , & point .x , & point .y , & point .z )) {
976
+ queue = balltree_nearest_neighbours (
977
+ self -> balltree ,
978
+ & point ,
979
+ num_neighbours ,
980
+ max_dist
981
+ );
982
+ if (queue == NULL ) {
983
+ printf ("oops\n" );
984
+ goto error ;
985
+ }
986
+ // copy result into output buffer
987
+ memcpy (& result [idx ], queue -> items , n_bytes_queue );
988
+ knque_free (queue );
989
+ idx += num_neighbours ;
990
+ }
991
+
992
+ // convert to numpy array
993
+ pyresult = queueitems_get_numpy_array (result , data -> size , num_neighbours );
994
+ error :
995
+ free (result );
996
+ inputiterdata_free (data );
997
+ return pyresult ;
998
+ }
999
+
876
1000
PyDoc_STRVAR (
877
1001
// .. py:method::
878
1002
brute_radius_doc ,
@@ -1139,6 +1263,12 @@ static PyMethodDef PyBallTree_methods[] = {
1139
1263
.ml_flags = METH_NOARGS ,
1140
1264
.ml_doc = get_node_data_doc
1141
1265
},
1266
+ {
1267
+ .ml_name = "nearest_neighbours" ,
1268
+ .ml_meth = (PyCFunctionWithKeywords )PyBallTree_nearest_neighbours ,
1269
+ .ml_flags = METH_VARARGS | METH_KEYWORDS ,
1270
+ .ml_doc = nearest_neighbours_doc
1271
+ },
1142
1272
{
1143
1273
.ml_name = "brute_radius" ,
1144
1274
.ml_meth = (PyCFunctionWithKeywords )PyBallTree_brute_radius ,
0 commit comments