Skip to content

Commit

Permalink
Groundwork for some new non-linear algorithms (#196)
Browse files Browse the repository at this point in the history
* Moving some shared characteristics of non linear algos into a trait

* Using par_iter for find_similar_n

* Fix clippy lints
  • Loading branch information
deven96 authored Feb 7, 2025
1 parent 7f7987b commit 88fcbab
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 297 deletions.
2 changes: 1 addition & 1 deletion ahnlich/ai/src/server/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ impl AhnlichProtocol for AIProxyTask {
}
AIQuery::GetKey { store, keys } => {
let metadata_values: HashSet<MetadataValue> =
keys.into_iter().map(|value| value.into()).collect();
keys.into_par_iter().map(|value| value.into()).collect();
let get_key_condition = PredicateCondition::Value(Predicate::In {
key: AHNLICH_AI_RESERVED_META_KEY.clone(),
value: metadata_values,
Expand Down
177 changes: 7 additions & 170 deletions ahnlich/db/src/algorithm/heap.rs
Original file line number Diff line number Diff line change
@@ -1,181 +1,18 @@
#![allow(dead_code)]
use super::LinearAlgorithm;
use super::SimilarityVector;
use ahnlich_types::keyval::StoreKey;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::num::NonZeroUsize;

pub(crate) struct MinHeap<'a> {
max_capacity: NonZeroUsize,
heap: BinaryHeap<Reverse<SimilarityVector<'a>>>,
pub enum HeapOrder {
Min,
Max,
}

impl<'a> MinHeap<'a> {
pub(crate) fn new(capacity: NonZeroUsize) -> Self {
Self {
heap: BinaryHeap::new(),
max_capacity: capacity,
}
}
#[tracing::instrument(skip_all)]
pub(crate) fn len(&self) -> usize {
self.heap.len()
}
#[tracing::instrument(skip_all)]
pub(crate) fn push(&mut self, item: SimilarityVector<'a>) {
self.heap.push(Reverse(item));
}
#[tracing::instrument(skip_all)]
pub(crate) fn pop(&mut self) -> Option<SimilarityVector<'a>> {
self.heap.pop().map(|popped_item| popped_item.0)
}

#[tracing::instrument(skip_all)]
pub(crate) fn output(&mut self) -> Vec<(StoreKey, f32)> {
let mut result: Vec<_> = Vec::with_capacity(self.max_capacity.get());

loop {
match self.pop() {
Some(value) if result.len() < self.max_capacity.get() => {
let vector_sim = value.0;
result.push((vector_sim.0.clone(), vector_sim.1));
}
_ => break,
}
}
result
}
}

pub(crate) struct MaxHeap<'a> {
max_capacity: NonZeroUsize,
heap: BinaryHeap<SimilarityVector<'a>>,
}

impl<'a> MaxHeap<'a> {
pub(crate) fn new(capacity: NonZeroUsize) -> Self {
Self {
heap: BinaryHeap::new(),
max_capacity: capacity,
}
}
#[tracing::instrument(skip_all)]
fn push(&mut self, item: SimilarityVector<'a>) {
self.heap.push(item);
}
#[tracing::instrument(skip_all)]
pub(crate) fn pop(&mut self) -> Option<SimilarityVector<'a>> {
self.heap.pop()
}
#[tracing::instrument(skip_all)]
pub(crate) fn len(&self) -> usize {
self.heap.len()
}

#[tracing::instrument(skip_all)]
fn output(&mut self) -> Vec<(StoreKey, f32)> {
let mut result: Vec<_> = Vec::with_capacity(self.max_capacity.get());

loop {
match self.heap.pop() {
Some(value) if result.len() < self.max_capacity.get() => {
let vector_sim = value.0;
result.push((vector_sim.0.clone(), vector_sim.1));
}
_ => break,
}
}
result
}
}

pub(crate) enum AlgorithmHeapType<'a> {
Min(MinHeap<'a>),
Max(MaxHeap<'a>),
}

impl<'a> AlgorithmHeapType<'a> {
#[tracing::instrument(skip_all)]
pub(crate) fn push(&mut self, item: SimilarityVector<'a>) {
match self {
Self::Max(h) => h.push(item),
Self::Min(h) => h.push(item),
}
}
#[tracing::instrument(skip_all)]
pub(crate) fn pop(&mut self) -> Option<SimilarityVector<'a>> {
match self {
Self::Max(h) => h.pop(),
Self::Min(h) => h.pop(),
}
}

#[tracing::instrument(skip_all)]
pub(crate) fn output(&mut self) -> Vec<(StoreKey, f32)> {
match self {
Self::Min(h) => h.output(),
Self::Max(h) => h.output(),
}
}
}

impl From<(&LinearAlgorithm, NonZeroUsize)> for AlgorithmHeapType<'_> {
fn from((value, capacity): (&LinearAlgorithm, NonZeroUsize)) -> Self {
impl From<&LinearAlgorithm> for HeapOrder {
fn from(value: &LinearAlgorithm) -> Self {
match value {
LinearAlgorithm::EuclideanDistance => AlgorithmHeapType::Min(MinHeap::new(capacity)),
LinearAlgorithm::EuclideanDistance => HeapOrder::Min,
LinearAlgorithm::CosineSimilarity | LinearAlgorithm::DotProductSimilarity => {
AlgorithmHeapType::Max(MaxHeap::new(capacity))
HeapOrder::Max
}
}
}
}

#[cfg(test)]
mod tests {

use super::*;

#[test]
fn test_min_heap_ordering_works() {
let mut heap = MinHeap::new(NonZeroUsize::new(3).unwrap());
let mut count = 0.0;
let first_vector = StoreKey(vec![2.0, 2.0]);

// If we pop these scores now, they should come back in the reverse order.
while count < 5.0 {
let similarity: f32 = 1.0 + count;

let item: SimilarityVector = (&first_vector, similarity).into();

heap.push(item);

count += 1.0;
}

assert_eq!(heap.pop(), Some((&first_vector, 1.0).into()));
assert_eq!(heap.pop(), Some((&first_vector, 2.0).into()));
assert_eq!(heap.pop(), Some((&first_vector, 3.0).into()));
}

#[test]
fn test_max_heap_ordering_works() {
let mut heap = MaxHeap::new(NonZeroUsize::new(3).unwrap());
let mut count = 0.0;
let first_vector = StoreKey(vec![2.0, 2.0]);

// If we pop these scores now, they should come back the right order(max first).
while count < 5.0 {
let similarity: f32 = 1.0 + count;
let item: SimilarityVector = (&first_vector, similarity).into();

heap.push(item);

count += 1.0;
}

assert_eq!(heap.pop(), Some((&first_vector, 5.0).into()));
assert_eq!(heap.pop(), Some((&first_vector, 4.0).into()));
assert_eq!(heap.pop(), Some((&first_vector, 3.0).into()));
}
}
61 changes: 44 additions & 17 deletions ahnlich/db/src/algorithm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ mod heap;
pub mod non_linear;
mod similarity;

use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::num::NonZeroUsize;

use ahnlich_types::keyval::StoreKey;
use ahnlich_types::similarity::Algorithm;
use ahnlich_types::similarity::NonLinearAlgorithm;
use heap::HeapOrder;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;

use self::{heap::AlgorithmHeapType, similarity::SimilarityFunc};
use self::similarity::SimilarityFunc;

#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) enum AlgorithmByType {
Expand Down Expand Up @@ -48,11 +53,6 @@ impl<'a> From<(&'a StoreKey, f32)> for SimilarityVector<'a> {
SimilarityVector((value.0, value.1))
}
}
impl<'a> From<SimilarityVector<'a>> for (&'a StoreKey, f32) {
fn from(value: SimilarityVector<'a>) -> (&'a StoreKey, f32) {
((value.0).0, (value.0).1)
}
}

impl<'a> PartialEq for SimilarityVector<'a> {
fn eq(&self, other: &Self) -> bool {
Expand All @@ -77,11 +77,16 @@ impl Ord for SimilarityVector<'_> {
}
}

// Pop the topmost N from a generic binary tree
fn pop_n<T: Ord>(heap: &mut BinaryHeap<T>, n: NonZeroUsize) -> Vec<T> {
(0..n.get()).filter_map(|_| heap.pop()).collect()
}

pub(crate) trait FindSimilarN {
fn find_similar_n<'a>(
&'a self,
search_vector: &StoreKey,
search_list: impl Iterator<Item = &'a StoreKey>,
search_list: impl ParallelIterator<Item = &'a StoreKey>,
_used_all: bool,
n: NonZeroUsize,
) -> Vec<(StoreKey, f32)>;
Expand All @@ -92,26 +97,48 @@ impl FindSimilarN for LinearAlgorithm {
fn find_similar_n<'a>(
&'a self,
search_vector: &StoreKey,
search_list: impl Iterator<Item = &'a StoreKey>,
search_list: impl ParallelIterator<Item = &'a StoreKey>,
_used_all: bool,
n: NonZeroUsize,
) -> Vec<(StoreKey, f32)> {
let mut heap: AlgorithmHeapType = (self, n).into();

let heap_order: HeapOrder = self.into();
let similarity_function: SimilarityFunc = self.into();

for second_vector in search_list {
let similarity = similarity_function(search_vector, second_vector);

let heap_value: SimilarityVector = (second_vector, similarity).into();
heap.push(heap_value)
match heap_order {
HeapOrder::Min => pop_n(
&mut search_list
.map(|second_vector| {
let similarity = similarity_function(search_vector, second_vector);
let heap_value: SimilarityVector = (second_vector, similarity).into();
Reverse(heap_value)
})
.collect::<BinaryHeap<_>>(),
n,
)
.into_par_iter()
.map(|a| (a.0 .0 .0.clone(), a.0 .0 .1))
.collect(),
HeapOrder::Max => pop_n(
&mut search_list
.map(|second_vector| {
let similarity = similarity_function(search_vector, second_vector);
let heap_value: SimilarityVector = (second_vector, similarity).into();
heap_value
})
.collect::<BinaryHeap<_>>(),
n,
)
.into_par_iter()
.map(|a| (a.0 .0.clone(), a.0 .1))
.collect(),
}
heap.output()
}
}

#[cfg(test)]
mod tests {
use rayon::iter::IntoParallelRefIterator;

use super::*;
use crate::tests::*;

Expand All @@ -135,7 +162,7 @@ mod tests {

let similar_n_search = cosine_algorithm.find_similar_n(
&first_vector,
search_list.iter(),
search_list.par_iter(),
false,
NonZeroUsize::new(no_similar_values).unwrap(),
);
Expand Down
Loading

0 comments on commit 88fcbab

Please sign in to comment.