From 305e4a3473b4bffa8e0ed48625a801c7a585ae94 Mon Sep 17 00:00:00 2001
From: "Eric S. Tellez"
Date: Sun, 10 Mar 2024 08:35:00 -0600
Subject: [PATCH] adds farthest first traversal fft

---
 src/SimilaritySearch.jl |  1 +
 src/fft.jl              | 49 +++++++++++++++++++++++++++++++++++++++++
 test/runtests.jl        |  1 +
 test/testfft.jl         | 12 ++++++++++
 4 files changed, 63 insertions(+)
 create mode 100644 src/fft.jl
 create mode 100644 test/testfft.jl

diff --git a/src/SimilaritySearch.jl b/src/SimilaritySearch.jl
index 21dcb19..b3752ba 100644
--- a/src/SimilaritySearch.jl
+++ b/src/SimilaritySearch.jl
@@ -94,6 +94,7 @@ include("deprecated.jl")
 
 include("allknn.jl")
 include("neardup.jl")
+include("fft.jl")
 include("closestpair.jl")
 include("hsp.jl")
 
diff --git a/src/fft.jl b/src/fft.jl
new file mode 100644
index 0000000..4739941
--- /dev/null
+++ b/src/fft.jl
@@ -0,0 +1,49 @@
+# This file is a part of SimilaritySearch.jl
+
+export fft
+
+"""
+    fft(dist::SemiMetric, X::AbstractDatabase, k; verbose=true)
+
+Selects `k` items far from each other based on Farthest First Traversal algorithm.
+The first center is chosen uniformly at random; at most `length(X)` centers are selected.
+
+Returns a named tuple with the following fields:
+- `centers` contains the list of centers (indexes to ``X``)
+- `nn` the id of the nearest center (in ``X`` order, identifiers between 1 to `length(X)`)
+- `dists` the distance from each object in the database to its nearest center (in ``X`` order)
+- `dmax` the largest value in `dists` (`typemax(Float32)` when `X` is empty)
+
+Based on `enet.jl` from `KCenters.jl`
+"""
+function fft(dist::SemiMetric, X::AbstractDatabase, k::Integer; verbose=true)
+    N = length(X)
+    centers = Int32[]
+    nndists = fill(typemax(Float32), N)
+    nn = zeros(UInt32, N)
+    dmax::Float32 = typemax(Float32)
+    # return before sampling the first center: `rand(1:0)` throws on an empty range
+    N == 0 && return (; centers, nn, dists=nndists, dmax)
+    imax::Int = rand(1:N)
+
+    @inbounds for _ in 1:N
+        push!(centers, imax)
+        verbose && println(stderr, "computing fartest point $(length(centers)), dmax: $dmax, imax: $imax, n: $(length(X))")
+
+        pivot = X[imax]
+        @batch minbatch=getminbatch(0, N) for i in 1:N
+            d = evaluate(dist, X[i], pivot)
+            if d < nndists[i]
+                nndists[i] = d
+                nn[i] = imax
+            end
+        end
+
+        # the object farthest from its nearest center becomes the next center
+        dmax, imax = findmax(nndists)
+        length(centers) < k || break
+    end
+
+    (; centers, nn, dists=nndists, dmax)
+end
+
diff --git a/test/runtests.jl b/test/runtests.jl
index d896863..7dd56c9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -23,5 +23,6 @@ include("testadj.jl")
 include("testallknn.jl")
 include("testhsp.jl")
 include("testneardup.jl")
+include("testfft.jl")
 include("testclosestpair.jl")
 include("testsearchgraph.jl")
diff --git a/test/testfft.jl b/test/testfft.jl
new file mode 100644
index 0000000..cc61493
--- /dev/null
+++ b/test/testfft.jl
@@ -0,0 +1,12 @@
+# This file is a part of SimilaritySearch.jl
+
+using Test, SimilaritySearch, LinearAlgebra
+
+@testset "farthest first traversal" begin
+    dist = L2Distance()
+    X = rand(Float32, 4, 300)
+    res = fft(dist, MatrixDatabase(X), 30)
+    @test Set(res.centers) == Set(res.nn)
+    @test all(res.dmax .>= res.dists)
+end
+