|
| 1 | +#include "accumulate_knn_cpu.h" |
| 2 | +#include "helpers.h" |
| 3 | +#include "utils.h" |
| 4 | + |
| 5 | +#include <torch/extension.h> |
| 6 | +#include <string> //size_t, just for helper function |
| 7 | +#include <cmath> |
| 8 | +//#include <iostream> //remove later DEBUG FIXME |
| 9 | + |
| 10 | +static inline float distanceWeight(const float& distsq) { |
| 11 | + return exp(-1. * ACCUMULATE_KNN_EXPONENT * distsq); |
| 12 | +} |
| 13 | + |
| 14 | +void compute(const float_t *d_distances, |
| 15 | + const float_t *d_feat, |
| 16 | + const int32_t *d_idxs, |
| 17 | + |
| 18 | + float_t *d_out_feat, |
| 19 | + int32_t *d_out_maxidxs, |
| 20 | + |
| 21 | + size_t n_vert, |
| 22 | + size_t n_neigh, |
| 23 | + size_t n_feat, |
| 24 | + |
| 25 | + size_t n_out_feat, |
| 26 | + |
| 27 | + size_t n_moments, |
| 28 | + bool mean_and_max) |
| 29 | +{ |
| 30 | + for (size_t i_v = 0; i_v < n_vert; i_v++) { |
| 31 | + |
| 32 | + for (size_t i_f = 0; i_f < n_feat; i_f++) { |
| 33 | + |
| 34 | + float t_mean = 0; |
| 35 | + float t_max = 0; |
| 36 | + int max_i_n_gidx = 0; |
| 37 | + |
| 38 | + for (size_t i_n = 0; i_n < n_neigh; i_n++) { |
| 39 | + |
| 40 | + int nidx = d_idxs[I2D(i_v, i_n, n_neigh)]; |
| 41 | + if (nidx < 0) continue; |
| 42 | + |
| 43 | + float vnf = d_feat[I2D(nidx, i_f, n_feat)]; |
| 44 | + float distsq = d_distances[I2D(i_v, i_n, n_neigh)]; |
| 45 | + float wfeat = vnf * distanceWeight(distsq); |
| 46 | + |
| 47 | + t_mean += wfeat; |
| 48 | + |
| 49 | + if (mean_and_max && (wfeat >= t_max || !i_n)) { |
| 50 | + max_i_n_gidx = nidx; |
| 51 | + t_max = wfeat; |
| 52 | + } |
| 53 | + } |
| 54 | + |
| 55 | + t_mean /= (float)n_neigh; |
| 56 | + d_out_feat[I2D(i_v, i_f, n_out_feat)] = t_mean; |
| 57 | + if (mean_and_max) { |
| 58 | + d_out_maxidxs[I2D(i_v, i_f, n_feat)] = max_i_n_gidx; |
| 59 | + d_out_feat[I2D(i_v, i_f + n_feat, n_out_feat)] = t_max; |
| 60 | + } |
| 61 | + |
| 62 | + } |
| 63 | + } |
| 64 | +} |
| 65 | + |
| 66 | +std::tuple<torch::Tensor, torch::Tensor> |
| 67 | +accumulate_knn_cpu(torch::Tensor distances, |
| 68 | + torch::Tensor features, |
| 69 | + torch::Tensor indices, |
| 70 | + int n_moments, |
| 71 | + bool mean_and_max) |
| 72 | +{ |
| 73 | + const auto n_vert = distances.size(0); |
| 74 | + const auto n_neigh = indices.size(1); |
| 75 | + const auto n_coords = distances.size(1); |
| 76 | + const auto n_feat = features.size(1); |
| 77 | + |
| 78 | + assert(n_vert == indices.size(0) && n_vert == features.size(0)); |
| 79 | + assert(n_neigh == distances.size(1)); |
| 80 | + |
| 81 | + int64_t n_out_feat = n_feat; |
| 82 | + if (mean_and_max) { |
| 83 | + n_out_feat *= 2; } |
| 84 | + |
| 85 | + auto output_feat_tensor = torch::zeros({ n_vert,n_out_feat }, |
| 86 | + torch::TensorOptions().dtype(torch::kFloat32)); |
| 87 | + auto output_max_idxs_tensor = torch::zeros({ n_vert,n_feat }, |
| 88 | + torch::TensorOptions().dtype(torch::kInt32)); |
| 89 | + |
| 90 | + auto distances_data = distances.data_ptr<float_t>(); |
| 91 | + auto features_data = features.data_ptr<float_t>(); |
| 92 | + auto indices_data = indices.data_ptr<int32_t>(); |
| 93 | + |
| 94 | + auto output_feat_tensor_data = output_feat_tensor.data_ptr<float_t>(); |
| 95 | + auto output_max_idxs_data = output_max_idxs_tensor.data_ptr<int32_t>(); |
| 96 | + |
| 97 | + compute(distances_data, |
| 98 | + features_data, |
| 99 | + indices_data, |
| 100 | + output_feat_tensor_data, |
| 101 | + output_max_idxs_data, |
| 102 | + n_vert, |
| 103 | + n_neigh, |
| 104 | + n_feat, |
| 105 | + n_out_feat, |
| 106 | + n_moments, |
| 107 | + mean_and_max); |
| 108 | + |
| 109 | + return std::make_tuple(output_feat_tensor, output_max_idxs_tensor); |
| 110 | +} |
0 commit comments