From 0cc1b91d93959bb1f7fa47449ffdbf685669df94 Mon Sep 17 00:00:00 2001 From: "J. Sebastian Paez" Date: Sun, 7 Jul 2024 12:11:03 -0700 Subject: [PATCH] (refactor) Complete migration of dbscan to runner struct --- src/aggregation/aggregators.rs | 3 +- src/aggregation/dbscan/dbscan.rs | 294 +------------------ src/aggregation/dbscan/denseframe_dbscan.rs | 3 +- src/aggregation/dbscan/runner.rs | 309 +++++++++++++------- src/aggregation/dbscan/utils.rs | 39 ++- src/aggregation/tracing.rs | 2 +- 6 files changed, 234 insertions(+), 416 deletions(-) diff --git a/src/aggregation/aggregators.rs b/src/aggregation/aggregators.rs index 7229df0..95cfa99 100644 --- a/src/aggregation/aggregators.rs +++ b/src/aggregation/aggregators.rs @@ -1,13 +1,12 @@ use crate::ms::frames::TimsPeak; use crate::space::space_generics::HasIntensity; use crate::utils; -use std::ops::Add; use rayon::prelude::*; // I Dont really like having this here but I am not sure where else to // define it ... since its needed by the aggregation functions -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone, Copy)] pub enum ClusterLabel { Unassigned, Noise, diff --git a/src/aggregation/dbscan/dbscan.rs b/src/aggregation/dbscan/dbscan.rs index ca11d34..abbedf1 100644 --- a/src/aggregation/dbscan/dbscan.rs +++ b/src/aggregation/dbscan/dbscan.rs @@ -1,293 +1,12 @@ use crate::aggregation::aggregators::{aggregate_clusters, ClusterAggregator, ClusterLabel}; use crate::space::kdtree::RadiusKDTree; -use crate::space::space_generics::{ - HasIntensity, NDPoint, NDPointConverter, QueriableIndexedPoints, -}; +use crate::space::space_generics::{HasIntensity, NDPointConverter, QueriableIndexedPoints}; use crate::utils; -use indicatif::ProgressIterator; use log::{debug, info, trace}; -use num::cast::AsPrimitive; use rayon::prelude::*; use std::ops::Add; -use crate::aggregation::dbscan::utils::FilterFunCache; - -/// Density-based spatial clustering of applications with noise (DBSCAN) -/// -/// This module implements a variant of dbscan with a couple of modifications -/// with respect to the vanilla implementation. -/// -/// Pseudocode from wikipedia. -/// Donate to wikipedia y'all. :3 -// -/// DBSCAN(DB, distFunc, eps, minPts) { -/// C := 0 /* Cluster counter */ -/// for each point P in database DB { -/// if label(P) ≠ undefined then continue /* Previously processed in inner loop */ -/// Neighbors N := RangeQuery(DB, distFunc, P, eps) /* Find neighbors */ -/// if |N| < minPts then { /* Density check */ -/// label(P) := Noise /* Label as Noise */ -/// continue -/// } -/// C := C + 1 /* next cluster label */ -/// label(P) := C /* Label initial point */ -/// SeedSet S := N \ {P} /* Neighbors to expand */ -/// for each point Q in S { /* Process every seed point Q */ -/// if label(Q) = Noise then label(Q) := C /* Change Noise to border point */ -/// if label(Q) ≠ undefined then continue /* Previously processed (e.g., border point) */ -/// label(Q) := C /* Label neighbor */ -/// Neighbors N := RangeQuery(DB, distFunc, Q, eps) /* Find neighbors */ -/// if |N| ≥ minPts then { /* Density check (if Q is a core point) */ -/// S := S ∪ N /* Add new neighbors to seed set */ -/// } -/// } -/// } -/// } -/// Variations ... -/// 1. Indexing is am implementation detail to find the neighbors (generic indexer) -/// 2. Sort the pointd by decreasing intensity (more intense points adopt first). -/// 3. Use an intensity threshold intead of a minimum number of neighbors. -/// 4. There are ways to define the limits to the extension of a cluster. - -// TODO: rename quad_points, since this no longer uses a quadtree. -// TODO: refactor to take a filter function instead of requiting -// a min intensity and an intensity trait. -// TODO: rename the pre-filtered... -// TODO: reimplement this a two-stage pass, where the first in parallel -// gets the neighbors and the second does the iterative aggregation. -// THERE BE DRAGONS in this function ... I am thinking about sane ways to -// refactor it to make it more readable and maintainable. - -struct DBScanTimers { - main: utils::ContextTimer, - filter_fun_cache_timer: utils::ContextTimer, - outer_loop_nn_timer: utils::ContextTimer, - inner_loop_nn_timer: utils::ContextTimer, - local_neighbor_filter_timer: utils::ContextTimer, - outer_intensity_calculation: utils::ContextTimer, - inner_intensity_calculation: utils::ContextTimer, -} - -impl DBScanTimers { - fn new() -> Self { - let mut timer = utils::ContextTimer::new("internal_dbscan", false, utils::LogLevel::DEBUG); - let mut filter_fun_cache_timer = timer.start_sub_timer("filter_fun_cache"); - let mut outer_loop_nn_timer = timer.start_sub_timer("outer_loop_nn"); - let mut inner_loop_nn_timer = timer.start_sub_timer("inner_loop_nn"); - let mut local_neighbor_filter_timer = timer.start_sub_timer("local_neighbor_filter"); - let mut outer_intensity_calculation = timer.start_sub_timer("outer_intensity_calculation"); - let mut inner_intensity_calculation = timer.start_sub_timer("inner_intensity_calculation"); - Self { - main: timer, - filter_fun_cache_timer, - outer_loop_nn_timer, - inner_loop_nn_timer, - local_neighbor_filter_timer, - outer_intensity_calculation, - inner_intensity_calculation, - } - } - - fn report_if_gt_us(self, min_time: f64) { - if self.timer.cumtime.as_micros() > min_time { - self.main.report(); - self.filter_fun_cache_timer.report(); - self.outer_loop_nn_timer.report(); - self.inner_loop_nn_timer.report(); - self.local_neighbor_filter_timer.report(); - self.outer_intensity_calculation.report(); - self.inner_intensity_calculation.report(); - } - } -} - -// THIS IS A BOTTLENECK FUNCTION -fn _dbscan< - 'a, - const N: usize, - C: NDPointConverter, - E: Sync + HasIntensity, - T: QueriableIndexedPoints<'a, N, usize> + std::marker::Sync, - FF: Fn(&E, &E) -> bool + Send + Sync + Copy, ->( - indexed_points: &'a T, - prefiltered_peaks: &Vec, - quad_points: &[NDPoint], - min_n: usize, - min_intensity: u64, - intensity_sorted_indices: &Vec<(usize, u64)>, - filter_fun: Option, - converter: C, - progress: bool, - max_extension_distances: &[f32; N], -) -> (u64, Vec>) { - let mut initial_candidates_counts = utils::RollingSDCalculator::default(); - let mut final_candidates_counts = utils::RollingSDCalculator::default(); - - let mut cluster_labels = vec![ClusterLabel::Unassigned; prefiltered_peaks.len()]; - let mut cluster_id = 0; - let mut timers = DBScanTimers::new(); - - let usize_filterfun = |a: &usize, b: &usize| { - filter_fun.expect("filter_fun should be Some")( - &prefiltered_peaks[*a], - &prefiltered_peaks[*b], - ) - }; - let mut filterfun_cache = - FilterFunCache::new(Box::new(&usize_filterfun), prefiltered_peaks.len()); - let mut filterfun_with_cache = |elem_idx: usize, reference_idx: usize| { - timers.filter_fun_cache_timer.reset_start(); - let out = filterfun_cache.get(elem_idx, reference_idx); - timers.filter_fun_cache_timer.stop(false); - out - }; - - let my_progbar = if progress { - indicatif::ProgressBar::new(intensity_sorted_indices.len() as u64) - } else { - indicatif::ProgressBar::hidden() - }; - - for (point_index, _intensity) in intensity_sorted_indices.iter().progress_with(my_progbar) { - let point_index = *point_index; - if cluster_labels[point_index] != ClusterLabel::Unassigned { - continue; - } - - timers.outer_loop_nn_timer.reset_start(); - let query_elems = converter.convert_to_bounds_query(&quad_points[point_index]); - let mut neighbors = indexed_points.query_ndrange(&query_elems.0, query_elems.1); - timers.outer_loop_nn_timer.stop(false); - - if neighbors.len() < min_n { - cluster_labels[point_index] = ClusterLabel::Noise; - continue; - } - - if filter_fun.is_some() { - let num_initial_candidates = neighbors.len(); - neighbors.retain(|i| filterfun_with_cache(**i, point_index)); - // .filter(|i| filter_fun.unwrap()(&prefiltered_peaks[**i], &query_peak)) - - let candidates_after_filter = neighbors.len(); - initial_candidates_counts.add(num_initial_candidates as f32, 1); - final_candidates_counts.add(candidates_after_filter as f32, 1); - - if neighbors.len() < min_n { - cluster_labels[point_index] = ClusterLabel::Noise; - continue; - } - } - - // Q: Do I need to care about overflows here? - Sebastian - timers.outer_intensity_calculation.reset_start(); - let neighbor_intensity_total = neighbors - .iter() - .map(|i| prefiltered_peaks[**i].intensity().as_()) - .sum::(); - timers.outer_intensity_calculation.stop(false); - - if neighbor_intensity_total < min_intensity { - cluster_labels[point_index] = ClusterLabel::Noise; - continue; - } - - cluster_id += 1; - cluster_labels[point_index] = ClusterLabel::Cluster(cluster_id); - let mut seed_set: Vec<&usize> = Vec::new(); - seed_set.extend(neighbors); - - while let Some(neighbor) = seed_set.pop() { - let neighbor_index = *neighbor; - if cluster_labels[neighbor_index] == ClusterLabel::Noise { - cluster_labels[neighbor_index] = ClusterLabel::Cluster(cluster_id); - } - - if cluster_labels[neighbor_index] != ClusterLabel::Unassigned { - continue; - } - - cluster_labels[neighbor_index] = ClusterLabel::Cluster(cluster_id); - - timers.inner_loop_nn_timer.reset_start(); - let inner_query_elems = converter.convert_to_bounds_query(&quad_points[*neighbor]); - let mut local_neighbors = - indexed_points.query_ndrange(&inner_query_elems.0, inner_query_elems.1); - timers.inner_loop_nn_timer.stop(false); - - if filter_fun.is_some() { - local_neighbors.retain(|i| filterfun_with_cache(**i, point_index)) - // .filter(|i| filter_fun.unwrap()(&prefiltered_peaks[**i], &query_peak)) - } - - timers.inner_intensity_calculation.reset_start(); - let query_intensity = prefiltered_peaks[neighbor_index].intensity(); - let neighbor_intensity_total = local_neighbors - .iter() - .map(|i| prefiltered_peaks[**i].intensity().as_()) - .sum::(); - timers.inner_intensity_calculation.stop(false); - - if local_neighbors.len() >= min_n && neighbor_intensity_total >= min_intensity { - // Keep only the neighbors that are not already in a cluster - local_neighbors - .retain(|i| !matches!(cluster_labels[**i], ClusterLabel::Cluster(_))); - - // Keep only the neighbors that are within the max extension distance - // It might be worth setting a different max extension distance for the mz and mobility dimensions. - timers.local_neighbor_filter_timer.reset_start(); - local_neighbors.retain(|i| { - let going_downhill = prefiltered_peaks[**i].intensity() <= query_intensity; - - let p = &quad_points[**i]; - let query_point = query_elems.1.unwrap(); - // Using minkowski distance with p = 1, manhattan distance. - let mut within_distance = true; - for ((p, q), max_dist) in p - .values - .iter() - .zip(query_point.values) - .zip(max_extension_distances.iter()) - { - let dist = (p - q).abs(); - within_distance = within_distance && dist <= *max_dist; - if !within_distance { - break; - } - } - - going_downhill && within_distance - }); - timers.local_neighbor_filter_timer.stop(false); - - seed_set.extend(local_neighbors); - } - } - } - - let (tot_queries, cached_queries) = timers.filterfun_cache.get_stats(); - - if tot_queries > 1000 { - let cache_hit_rate = cached_queries as f64 / tot_queries as f64; - info!( - "Cache hit rate: {} / {} = {}", - cached_queries, tot_queries, cache_hit_rate - ); - - let avg_initial_candidates = initial_candidates_counts.get_mean(); - let avg_final_candidates = final_candidates_counts.get_mean(); - debug!( - "Avg initial candidates: {} Avg final candidates: {}", - avg_initial_candidates, avg_final_candidates - ); - } - - timers.main.stop(false); - timers.report_if_gt_us(1000000); - - (cluster_id, cluster_labels) -} +use crate::aggregation::dbscan::runner::_dbscan; // Pretty simple function ... it uses every passed centroid, converts it to a point // and generates a new centroid that aggregates all the points in its range. @@ -348,14 +67,13 @@ pub fn dbscan_generic< T: HasIntensity + Send + Clone + Copy + Sync, F: Fn() -> G + Send + Sync, const N: usize, - FF: Send + Sync + Fn(&T, &T) -> bool, >( converter: C, prefiltered_peaks: Vec, min_n: usize, min_intensity: u64, def_aggregator: F, - extra_filter_fun: Option<&FF>, + extra_filter_fun: Option<&(dyn Fn(&T, &T) -> bool + Send + Sync)>, log_level: Option, keep_unclustered: bool, max_extension_distances: &[f32; N], @@ -392,7 +110,7 @@ pub fn dbscan_generic< i_timer.stop(true); let mut i_timer = timer.start_sub_timer("dbscan"); - let (tot_clusters, cluster_labels) = _dbscan( + let cluster_labels = _dbscan( &tree, &prefiltered_peaks, &ndpoints, @@ -407,8 +125,8 @@ pub fn dbscan_generic< i_timer.stop(true); let centroids = aggregate_clusters( - tot_clusters, - cluster_labels, + cluster_labels.num_clusters, + cluster_labels.cluster_labels, &prefiltered_peaks, &def_aggregator, log_level, diff --git a/src/aggregation/dbscan/denseframe_dbscan.rs b/src/aggregation/dbscan/denseframe_dbscan.rs index af0fa7e..501bcc8 100644 --- a/src/aggregation/dbscan/denseframe_dbscan.rs +++ b/src/aggregation/dbscan/denseframe_dbscan.rs @@ -4,7 +4,6 @@ use crate::aggregation::dbscan::dbscan::dbscan_generic; use crate::ms::frames::{DenseFrame, TimsPeak}; use crate::utils::within_distance_apply; -type FFTimsPeak = fn(&TimsPeak, &TimsPeak) -> bool; // bool> pub fn dbscan_denseframe( mut denseframe: DenseFrame, @@ -51,7 +50,7 @@ pub fn dbscan_denseframe( min_n, min_intensity, TimsPeakAggregator::default, - None::<&FFTimsPeak>, + None::<&(dyn Fn(&TimsPeak, &TimsPeak) -> bool + Send + Sync)>, None, true, &[max_mz_extension as f32, max_ims_extension], diff --git a/src/aggregation/dbscan/runner.rs b/src/aggregation/dbscan/runner.rs index c21889f..b8beb32 100644 --- a/src/aggregation/dbscan/runner.rs +++ b/src/aggregation/dbscan/runner.rs @@ -1,24 +1,54 @@ -use std::process::Output; - use crate::space::space_generics::NDPointConverter; use crate::space::space_generics::{HasIntensity, NDPoint, QueriableIndexedPoints}; use crate::utils; -use crate::utils::within_distance_apply; use indicatif::ProgressIterator; -use log::{debug, info, trace}; use rayon::prelude::*; -use crate::aggregation::aggregators::{ - aggregate_clusters, ClusterAggregator, ClusterLabel, TimsPeakAggregator, -}; -use crate::space::kdtree::RadiusKDTree; - +use crate::aggregation::aggregators::ClusterLabel; use crate::aggregation::dbscan::utils::FilterFunCache; -struct ClusterLabels { - cluster_labels: Vec>, - num_clusters: u64, +/// Density-based spatial clustering of applications with noise (DBSCAN) +/// +/// This module implements a variant of dbscan with a couple of modifications +/// with respect to the vanilla implementation. +/// +/// Pseudocode from wikipedia. +/// Donate to wikipedia y'all. :3 +// +/// DBSCAN(DB, distFunc, eps, minPts) { +/// C := 0 /* Cluster counter */ +/// for each point P in database DB { +/// if label(P) ≠ undefined then continue /* Previously processed in inner loop */ +/// Neighbors N := RangeQuery(DB, distFunc, P, eps) /* Find neighbors */ +/// if |N| < minPts then { /* Density check */ +/// label(P) := Noise /* Label as Noise */ +/// continue +/// } +/// C := C + 1 /* next cluster label */ +/// label(P) := C /* Label initial point */ +/// SeedSet S := N \ {P} /* Neighbors to expand */ +/// for each point Q in S { /* Process every seed point Q */ +/// if label(Q) = Noise then label(Q) := C /* Change Noise to border point */ +/// if label(Q) ≠ undefined then continue /* Previously processed (e.g., border point) */ +/// label(Q) := C /* Label neighbor */ +/// Neighbors N := RangeQuery(DB, distFunc, Q, eps) /* Find neighbors */ +/// if |N| ≥ minPts then { /* Density check (if Q is a core point) */ +/// S := S ∪ N /* Add new neighbors to seed set */ +/// } +/// } +/// } +/// } +/// Variations ... +/// 1. Indexing is am implementation detail to find the neighbors (generic indexer) +/// 2. Sort the pointd by decreasing intensity (more intense points adopt first). +/// 3. Use an intensity threshold intead of a minimum number of neighbors. +/// 4. There are ways to define the limits to the extension of a cluster. + +#[derive(Debug, Clone)] +pub struct ClusterLabels { + pub cluster_labels: Vec>, + pub num_clusters: u64, } impl ClusterLabels { @@ -65,13 +95,13 @@ struct DBScanTimers { impl DBScanTimers { fn new() -> Self { - let mut timer = utils::ContextTimer::new("internal_dbscan", false, utils::LogLevel::DEBUG); - let mut filter_fun_cache_timer = timer.start_sub_timer("filter_fun_cache"); - let mut outer_loop_nn_timer = timer.start_sub_timer("outer_loop_nn"); - let mut inner_loop_nn_timer = timer.start_sub_timer("inner_loop_nn"); - let mut local_neighbor_filter_timer = timer.start_sub_timer("local_neighbor_filter"); - let mut outer_intensity_calculation = timer.start_sub_timer("outer_intensity_calculation"); - let mut inner_intensity_calculation = timer.start_sub_timer("inner_intensity_calculation"); + let timer = utils::ContextTimer::new("internal_dbscan", false, utils::LogLevel::DEBUG); + let filter_fun_cache_timer = timer.start_sub_timer("filter_fun_cache"); + let outer_loop_nn_timer = timer.start_sub_timer("outer_loop_nn"); + let inner_loop_nn_timer = timer.start_sub_timer("inner_loop_nn"); + let local_neighbor_filter_timer = timer.start_sub_timer("local_neighbor_filter"); + let outer_intensity_calculation = timer.start_sub_timer("outer_intensity_calculation"); + let inner_intensity_calculation = timer.start_sub_timer("inner_intensity_calculation"); Self { main: timer, filter_fun_cache_timer, @@ -83,7 +113,7 @@ impl DBScanTimers { } } - fn report_if_gt_us(self, min_time: u128) { + fn report_if_gt_us(&self, min_time: u128) { if self.main.cumtime.as_micros() > min_time { self.main.report(); self.filter_fun_cache_timer.report(); @@ -110,21 +140,25 @@ impl CandidateCountMetrics { } } -struct DBSCANRunnerState<'a> { +struct DBSCANRunnerState { cluster_labels: ClusterLabels, - filter_fun_cache: FilterFunCache<'a>, + filter_fun_cache: Option, timers: DBScanTimers, candidate_metrics: CandidateCountMetrics, } -impl DBSCANRunnerState<'_> { - fn new<'a>( - nlabels: usize, - min_n: usize, - usize_filterfun: &dyn Fn(&usize, &usize) -> bool, - ) -> Self { - let mut cluster_labels = ClusterLabels::new(nlabels); - let filter_fun_cache = FilterFunCache::new(Box::new(&usize_filterfun), nlabels); +impl DBSCANRunnerState { + fn new

(nlabels: usize, usize_filterfun: Option

) -> Self + where + P: Fn(&usize, &usize) -> bool + Send + Sync, + { + let cluster_labels = ClusterLabels::new(nlabels); + + let filter_fun_cache = match usize_filterfun { + Some(_) => Some(FilterFunCache::new(nlabels)), + None => None, + }; + //FilterFunCache::new(Box::new(&usize_filterfun), nlabels); let timers = DBScanTimers::new(); let candidate_metrics = CandidateCountMetrics::new(); @@ -150,11 +184,17 @@ impl DBSCANRunnerState<'_> { struct DBSCANRunner<'a, const N: usize, C, E> { min_n: usize, min_intensity: u64, - filter_fun: &'a (dyn Fn(&E, &E) -> bool + Send + Sync), + filter_fun: Option<&'a (dyn Fn(&E, &E) -> bool + Send + Sync)>, converter: C, progress: bool, max_extension_distances: &'a [f32; N], - state: Option>, +} + +struct DBSCANPoints<'a, const N: usize, E> { + prefiltered_peaks: &'a Vec, + intensity_sorted_indices: &'a Vec<(usize, u64)>, + indexed_points: &'a (dyn QueriableIndexedPoints<'a, N, usize> + std::marker::Sync), + quad_points: &'a [NDPoint], } // C: NDPointConverter, @@ -166,61 +206,87 @@ struct DBSCANRunner<'a, const N: usize, C, E> { // const N: usize, // FF: Send + Sync + Fn(&T, &T) -> bool, -impl<'a, const N: usize, C, E> DBSCANRunner<'a, N, C, E> +impl<'a, 'b: 'a, const N: usize, C, E> DBSCANRunner<'a, N, C, E> where C: NDPointConverter, E: Sync + HasIntensity, - //T: QueriableIndexedPoints<'a, N, usize> + std::marker::Sync, { fn run( &self, - prefiltered_peaks: &'a Vec, - intensity_sorted_indices: &'a Vec<(usize, f64)>, + prefiltered_peaks: &'b Vec, + intensity_sorted_indices: &'b Vec<(usize, u64)>, + indexed_points: &'b (dyn QueriableIndexedPoints<'a, N, usize> + std::marker::Sync), + quad_points: &'b [NDPoint], ) -> ClusterLabels { - let usize_filterfun = |a: &usize, b: &usize| { - (self.filter_fun)(&prefiltered_peaks[*a], &prefiltered_peaks[*b]) + let usize_filterfun = match self.filter_fun { + Some(filterfun) => { + let cl = |a: &usize, b: &usize| { + filterfun(&prefiltered_peaks[*a], &prefiltered_peaks[*b]) + }; + let bind = Some(cl); + bind + } + None => None, }; - self.state = Some(DBSCANRunnerState::new( - intensity_sorted_indices.len(), - self.min_n, - &usize_filterfun, - )); + // |a: &usize, b: &usize| { + // (self.filter_fun)(&prefiltered_peaks[*a], &prefiltered_peaks[*b]) + // }; + let mut state = DBSCANRunnerState::new(intensity_sorted_indices.len(), usize_filterfun); - let mut state = self.state.expect("State is created in this function."); + let points: DBSCANPoints = DBSCANPoints { + prefiltered_peaks, + intensity_sorted_indices, + indexed_points, + quad_points, + }; // Q: if filter fun is required ... why is it an option? - self.process_points(state, prefiltered_peaks, intensity_sorted_indices); + state = self.process_points(state, &points); + state = self.report_timers(state); + self.take_cluster_labels(state) + } + + fn report_timers(&self, mut state: DBSCANRunnerState) -> DBSCANRunnerState { state.timers.main.stop(false); state.timers.report_if_gt_us(1000000); + state + } + + fn take_cluster_labels(&self, state: DBSCANRunnerState) -> ClusterLabels { state.cluster_labels } fn process_points( &self, - mut state: DBSCANRunnerState<'a>, - prefiltered_peaks: &'a Vec, - intensity_sorted_indices: &'a Vec<(usize, f64)>, - ) { - let my_progbar = state.create_progress_bar(intensity_sorted_indices.len(), self.progress); - - for (point_index, _intensity) in intensity_sorted_indices.iter().progress_with(my_progbar) { + mut state: DBSCANRunnerState, + points: &DBSCANPoints<'a, N, E>, + ) -> DBSCANRunnerState { + let my_progbar = + state.create_progress_bar(points.intensity_sorted_indices.len(), self.progress); + + for (point_index, _intensity) in points + .intensity_sorted_indices + .iter() + .progress_with(my_progbar) + { self.process_single_point( *point_index, - prefiltered_peaks, + &points, &mut state.cluster_labels, &mut state.filter_fun_cache, &mut state.timers, &mut state.candidate_metrics, ); } + state } fn process_single_point( &self, point_index: usize, - prefiltered_peaks: &'a Vec, + points: &DBSCANPoints<'a, N, E>, cluster_labels: &mut ClusterLabels, - filter_fun_cache: &mut FilterFunCache<'a>, + filter_fun_cache: &mut Option, timers: &mut DBScanTimers, cc_metrics: &mut CandidateCountMetrics, ) { @@ -228,22 +294,18 @@ where return; } - let neighbors = self.find_neighbors( - point_index, - prefiltered_peaks, - filter_fun_cache, - timers, - cc_metrics, - ); - if !self.is_core_point(&neighbors, prefiltered_peaks, timers) { + let (neighbors, ref_point) = + self.find_neighbors(point_index, points, filter_fun_cache, timers, cc_metrics); + if !self.is_core_point(&neighbors, points.prefiltered_peaks, timers) { cluster_labels.set_noise(point_index); return; } self.expand_cluster( point_index, + ref_point.unwrap(), neighbors, - prefiltered_peaks, + points, cluster_labels, filter_fun_cache, timers, @@ -253,22 +315,43 @@ where fn find_neighbors( &self, point_index: usize, - prefiltered_peaks: &'a Vec, - filter_fun_cache: &mut FilterFunCache<'a>, + points: &DBSCANPoints<'a, N, E>, + filter_fun_cache: &mut Option, timers: &mut DBScanTimers, cc_metrics: &mut CandidateCountMetrics, - ) -> Vec { + ) -> (Vec, Option<&NDPoint>) { timers.outer_loop_nn_timer.reset_start(); let query_elems = self .converter - .convert_to_bounds_query(&quad_points[point_index]); - let mut candidate_neighbors = self + .convert_to_bounds_query(&points.quad_points[point_index]); + let mut candidate_neighbors = points .indexed_points - .query_ndrange(&query_elems.0, query_elems.1); + .query_ndrange(&query_elems.0, query_elems.1) + .iter() + .map(|x| **x) + .collect::>(); timers.outer_loop_nn_timer.stop(false); + if filter_fun_cache.is_none() { + return (candidate_neighbors, query_elems.1); + } + let num_initial_candidates = candidate_neighbors.len(); - candidate_neighbors.retain(|i| filter_fun_cache(**i, point_index)); + candidate_neighbors.retain(|i| { + let tmp = filter_fun_cache.as_mut().unwrap(); + let res_in_cache = tmp.get(*i, point_index); + match res_in_cache { + Some(res) => res, + None => { + let res = (self.filter_fun.unwrap())( + &points.prefiltered_peaks[*i], + &points.prefiltered_peaks[point_index], + ); + tmp.set(*i, point_index, res); + res + } + } + }); let neighbors = candidate_neighbors; let candidates_after_filter = neighbors.len(); @@ -279,7 +362,7 @@ where .final_candidates_counts .add(candidates_after_filter as f32, 1); - neighbors + (neighbors, query_elems.1) } fn is_core_point( @@ -291,7 +374,7 @@ where timers.outer_intensity_calculation.reset_start(); let neighbor_intensity_total = neighbors .iter() - .map(|i| prefiltered_peaks[**i].intensity().as_()) + .map(|i| prefiltered_peaks[*i].intensity()) .sum::(); timers.outer_intensity_calculation.stop(false); return neighbor_intensity_total >= self.min_intensity; @@ -300,10 +383,11 @@ where fn expand_cluster( &self, point_index: usize, - mut neighbors: Vec, - prefiltered_peaks: &'a Vec, + query_point: &NDPoint, + neighbors: Vec, + points: &DBSCANPoints<'a, N, E>, cluster_labels: &mut ClusterLabels, - filter_fun_cache: &mut FilterFunCache<'a>, + filter_fun_cache: &mut Option, timers: &mut DBScanTimers, ) { cluster_labels.set_new_cluster(point_index); @@ -324,31 +408,51 @@ where cluster_labels.set_current_cluster(neighbor_index); timers.inner_loop_nn_timer.reset_start(); - let inner_query_elems = converter.convert_to_bounds_query(&quad_points[*neighbor]); - let mut local_neighbors = - indexed_points.query_ndrange(&inner_query_elems.0, inner_query_elems.1); + let inner_query_elems = self + .converter + .convert_to_bounds_query(&points.quad_points[neighbor]); + let mut local_neighbors = points + .indexed_points + .query_ndrange(&inner_query_elems.0, inner_query_elems.1); timers.inner_loop_nn_timer.stop(false); - local_neighbors.retain(|i| filterfun_with_cache(**i, point_index)); + if filter_fun_cache.is_some() { + local_neighbors.retain(|i| { + let cache = filter_fun_cache.as_mut().unwrap(); + let res = cache.get(**i, point_index); + match res { + Some(res) => res, + None => { + let res = (self.filter_fun.unwrap())( + &points.prefiltered_peaks[**i], + &points.prefiltered_peaks[point_index], + ); + cache.set(**i, point_index, res); + res + } + } + }); + } timers.inner_intensity_calculation.reset_start(); - let query_intensity = prefiltered_peaks[neighbor_index].intensity(); + let query_intensity = points.prefiltered_peaks[neighbor_index].intensity(); let neighbor_intensity_total = local_neighbors .iter() - .map(|i| prefiltered_peaks[**i].intensity().as_()) + .map(|i| points.prefiltered_peaks[**i].intensity()) .sum::(); timers.inner_intensity_calculation.stop(false); - if local_neighbors.len() >= min_n && neighbor_intensity_total >= min_intensity { + if local_neighbors.len() >= self.min_n && neighbor_intensity_total >= self.min_intensity + { local_neighbors - .retain(|i| !matches!(cluster_labels[**i], ClusterLabel::Cluster(_))); + .retain(|i| !matches!(cluster_labels.get(**i), ClusterLabel::Cluster(_))); timers.local_neighbor_filter_timer.reset_start(); local_neighbors.retain(|i| { - let going_downhill = prefiltered_peaks[**i].intensity() <= query_intensity; + let going_downhill = + points.prefiltered_peaks[**i].intensity() <= query_intensity; - let p = &quad_points[**i]; - let query_point = query_elems.1.unwrap(); + let p: &NDPoint = &points.quad_points[**i]; let mut within_distance = true; for ((p, q), max_dist) in p .values @@ -373,36 +477,39 @@ where } } -fn _dbscan<'a, const N: usize, C, I, E, T, FF>( +pub fn _dbscan< + 'a, + const N: usize, + C: NDPointConverter, + E: Sync + HasIntensity, + T: QueriableIndexedPoints<'a, N, usize> + std::marker::Sync, +>( indexed_points: &'a T, prefiltered_peaks: &'a Vec, quad_points: &'a [NDPoint], min_n: usize, min_intensity: u64, - intensity_sorted_indices: &'a Vec<(usize, I)>, - filter_fun: Option, + intensity_sorted_indices: &'a Vec<(usize, u64)>, + filter_fun: Option<&'a (dyn Fn(&E, &E) -> bool + Send + Sync)>, converter: C, progress: bool, max_extension_distances: &'a [f32; N], -) -> (u64, Vec>) { - let runner = DBSCANRunner::new( - indexed_points, - quad_points, +) -> ClusterLabels { + let runner = DBSCANRunner { min_n, min_intensity, - filter_fun, converter, progress, + filter_fun: filter_fun, max_extension_distances, - ); - - let mut cluster_labels = vec![ClusterLabel::Unassigned; prefiltered_peaks.len()]; + }; - let cluster_id = runner.run( + let cluster_labels = runner.run( prefiltered_peaks, intensity_sorted_indices, - &mut cluster_labels, + indexed_points, + quad_points, ); - (cluster_id, cluster_labels) + cluster_labels } diff --git a/src/aggregation/dbscan/utils.rs b/src/aggregation/dbscan/utils.rs index e4808d3..5886d35 100644 --- a/src/aggregation/dbscan/utils.rs +++ b/src/aggregation/dbscan/utils.rs @@ -1,50 +1,45 @@ use std::collections::BTreeMap; -pub struct FilterFunCache<'a> { +pub struct FilterFunCache { cache: Vec>>, - filter_fun: Box<&'a dyn Fn(&usize, &usize) -> bool>, tot_queries: u64, cached_queries: u64, } -impl<'a> FilterFunCache<'a> { - pub fn new(filter_fun: Box<&'a dyn Fn(&usize, &usize) -> bool>, capacity: usize) -> Self { +impl FilterFunCache { + pub fn new(capacity: usize) -> Self { Self { cache: vec![None; capacity], - filter_fun, tot_queries: 0, cached_queries: 0, } } - pub fn get(&mut self, elem_idx: usize, reference_idx: usize) -> bool { - // Get the value if it exists, call the functon, insert it and - // return it if it doesn't. + pub fn get(&mut self, elem_idx: usize, reference_idx: usize) -> Option { self.tot_queries += 1; - let out: bool = match self.cache[elem_idx] { + let out: Option = match self.cache[elem_idx] { Some(ref map) => match map.get(&reference_idx) { Some(x) => { self.cached_queries += 1; - *x - } - None => { - let out: bool = (self.filter_fun)(&elem_idx, &reference_idx); - self.insert(elem_idx, reference_idx, out); - self.insert(reference_idx, elem_idx, out); - out + Some(*x) } + None => None, }, - None => { - let out = (self.filter_fun)(&elem_idx, &reference_idx); - self.insert(elem_idx, reference_idx, out); - self.insert(reference_idx, elem_idx, out); - out - } + None => None, }; out } + pub fn set(&mut self, elem_idx: usize, reference_idx: usize, value: bool) { + self.insert_both_ways(elem_idx, reference_idx, value); + } + + fn insert_both_ways(&mut self, elem_idx: usize, reference_idx: usize, value: bool) { + self.insert(elem_idx, reference_idx, value); + self.insert(reference_idx, elem_idx, value); + } + fn insert(&mut self, elem_idx: usize, reference_idx: usize, value: bool) { match self.cache[elem_idx] { Some(ref mut map) => { diff --git a/src/aggregation/tracing.rs b/src/aggregation/tracing.rs index 04b07da..ae7d23c 100644 --- a/src/aggregation/tracing.rs +++ b/src/aggregation/tracing.rs @@ -457,7 +457,7 @@ fn _combine_single_window_traces( quad_low_high: window_quad_low_high, btree_chromatogram: BTreeChromatogram::new_lazy(rt_binsize), }, - None::<&FFTimeTimsPeak>, + None::<&(dyn Fn(&TimeTimsPeak, &TimeTimsPeak) -> bool + Send + Sync)>, None, false, &max_extension_distances,