diff --git a/src/knnresult.jl b/src/knnresult.jl new file mode 100644 index 0000000..d282531 --- /dev/null +++ b/src/knnresult.jl @@ -0,0 +1,174 @@ +# This file is a part of SimilaritySearch.jl +using Intersections + +export maxlength, maxlength, getpair, getdist, getid, initialstate, idview, distview, reuse! + +export KnnResult + +""" + KnnResult(ksearch::Integer) + +Creates a priority queue with fixed capacity (`ksearch`) representing a knn result set. +It starts with zero items and grows with [`push!(res, id, dist)`](@ref) calls until `ksearch` +size is reached. After this only the smallest items based on distance are preserved. +""" +struct KnnResult # <: AbstractVector{Tuple{IdType,DistType}} + id::Vector{Int32} + dist::Vector{Float32} + k::Int # number of neighbors +end + +function KnnResult(k::Integer) + @assert k > 0 + res = KnnResult(Vector{Int32}(undef, 0), Vector{Float32}(undef, 0), k) + sizehint!(res.id, k) + sizehint!(res.dist, k) + res +end + +""" + _shifted_fixorder!(res, shift=0) + +Sorts the result in place; the possible element out of order is on the last entry always. +It implements a kind of insertion sort that it is efficient due to the expected +distribution of the items being inserted (it is expected just a few elements smaller than the current ones) +""" +function _shifted_fixorder!(res, shift=0) + sp = shift + 1 + pos = N = lastindex(res.id) + id = res.id + dist = res.dist + id_, dist_ = res.id[end], res.dist[end] + + #pos = doublingsearch(dist, dist_, sp, N) + #pos = binarysearch(dist, dist_, sp, N) + #if N > 16 + # pos = doublingsearchrev(dist, dist_, sp, N)::Int + #else + @inbounds while pos > sp && dist_ < dist[pos-1] + pos -= 1 + end + #end + + @inbounds if pos < N + while N > pos + id[N] = id[N-1] + dist[N] = dist[N-1] + N -= 1 + end + + dist[N] = dist_ + id[N] = id_ + end + + nothing +end + + +""" + push!(res::KnnResult, item::Pair) + push!(res::KnnResult, id::Integer, dist::Real) + +Appends an item into the result set +""" +@inline function Base.push!(res::KnnResult, id::Integer, dist::Real) + if length(res) < maxlength(res) + k = res.k + push!(res.id, id) + push!(res.dist, dist) + + _shifted_fixorder!(res) + return true + end + + dist >= last(res.dist) && return false + + @inbounds res.id[end], res.dist[end] = id, dist + _shifted_fixorder!(res) + #_shifted_fixorder!(res.shift, res.id, res.dist) + true +end + +#@inline Base.push!(res::KnnResult, id::Integer, dist::Real) = push!(res, convert(Int32, id), convert(Float32, dist)) +@inline Base.push!(res::KnnResult, p::Pair) = push!(res, p.first, p.second) + +""" + popfirst!(p::KnnResult) + +Removes and returns the nearest neeighboor pair from the pool, an O(length(p.pool)) operation +""" +@inline function Base.popfirst!(res::KnnResult) + popfirst!(res.id) => popfirst!(res.dist) +end + +""" + pop!(res::KnnResult) + +Removes and returns the last item in the pool, it is an O(1) operation +""" +@inline function Base.pop!(res::KnnResult) + pop!(res.id) => pop!(res.dist) +end + +""" + maxlength(res::KnnResult) + +The maximum allowed cardinality (the k of knn) +""" +@inline maxlength(res::KnnResult) = res.k +@inline Base.length(res::KnnResult) = length(res.id) + +""" + reuse!(res::KnnResult) + reuse!(res::KnnResult, k::Integer) + +Returns a result set and a new initial state; reuse the memory buffers +""" +@inline function reuse!(res::KnnResult, k::Integer=res.k) + @assert k > 0 + empty!(res.id) + empty!(res.dist) + if k > res.k + sizehint!(res.id, k) + sizehint!(res.dist, k) + end + KnnResult(res.id, res.dist, k) +end + +""" + getindex(res::KnnResult, i) + +Access the i-th item in `res` +""" +@inline function getpair(res::KnnResult, i) + @inbounds res.id[i] => res.dist[i] +end + +@inline getid(res::KnnResult, i) = @inbounds res.id[i] +@inline getdist(res::KnnResult, i) = @inbounds res.dist[i] + +@inline Base.last(res::KnnResult) = last(res.id) => last(res.dist) +@inline Base.first(res::KnnResult) = @inbounds res.id[1] => res.dist[1] +@inline Base.maximum(res::KnnResult) = last(res.dist) +@inline Base.minimum(res::KnnResult) = @inbounds res.dist[1] +@inline Base.argmax(res::KnnResult) = last(res.id) +@inline Base.argmin(res::KnnResult) = @inbounds res.id[1] + +@inline idview(res::KnnResult) = res.id +@inline distview(res::KnnResult) = res.dist + +@inline Base.eachindex(res::KnnResult) = 1:length(res) +Base.eltype(res::KnnResult) = Pair{Int32,Float32} + +##### iterator interface +### KnnResult +""" + Base.iterate(res::KnnResult, state::Int=1) + +Support for iteration +""" +function Base.iterate(res::KnnResult, i::Int=1) + n = length(res) + (n == 0 || i > n) && return nothing + @inbounds res.id[i] => res.dist[i], i+1 +end \ No newline at end of file diff --git a/src/knnresultshift.jl b/src/knnresultshift.jl new file mode 100644 index 0000000..708435b --- /dev/null +++ b/src/knnresultshift.jl @@ -0,0 +1,198 @@ +# This file is a part of SimilaritySearch.jl +struct KnnResultState + shift::Int +end + +export initialstate, KnnResultShift + +""" + KnnResultShift(ksearch::Integer) + +Creates a priority queue with fixed capacity (`ksearch`) representing a knn result set. +It starts with zero items and grows with [`push!(res, id, dist)`](@ref) calls until `ksearch` +size is reached. After this only the smallest items based on distance are preserved. +""" +struct KnnResultShift + id::Vector{Int32} + dist::Vector{Float32} + k::Int # number of neighbors +end + +function KnnResultShift(k::Integer) + @assert k > 0 + res = KnnResultShift(Vector{Int32}(undef, 0), Vector{Float32}(undef, 0), k) + sizehint!(res.id, k) + sizehint!(res.dist, k) + res +end + +function initialstate(::KnnResultShift) + KnnResultState(0) +end + +""" + push!(res::KnnResultShift, item::Pair) + push!(res::KnnResultShift, id::Integer, dist::Real) + +Appends an item into the result set +""" +@inline function Base.push!(res::KnnResultShift, st::KnnResultState, id::Integer, dist::Real) + if length(res, st) < maxlength(res) + k = res.k + if length(res.id) >= 2k-1 + compact!(res, st, 1) + st = KnnResultState(0) + @inbounds res.id[end], res.dist[end] = id, dist + else + push!(res.id, id) + push!(res.dist, dist) + end + + _shifted_fixorder!(res, st.shift) + #_shifted_fixorder!(res.shift, res.id, res.dist) + return st + end + + dist >= last(res.dist) && return st + + @inbounds res.id[end], res.dist[end] = id, dist + _shifted_fixorder!(res, st.shift) + #_shifted_fixorder!(res.shift, res.id, res.dist) + st +end + +#@inline Base.push!(res::KnnResultShift, st::KnnResultState, id::Integer, dist::Real) = push!(res, st, convert(Int32, id), convert(Float32, dist)) +@inline Base.push!(res::KnnResultShift, st::KnnResultState, p::Pair) = push!(res, st, p.first, p.second) + +function compact!(res::KnnResultShift, st::KnnResultState, resize_extra) + shift = st.shift + if shift > 0 + n = length(res, st) + j = shift + @inbounds for i in 1:n + j += 1 + res.id[i] = res.id[j] + res.dist[i] = res.dist[j] + end + + resize!(res.id, n+resize_extra) + resize!(res.dist, n+resize_extra) + end + + res +end + +""" + popfirst!(p::KnnResultShift, st::KnnResultState) + +Removes and returns the nearest neeighboor pair from the pool, an O(length(p.pool)) operation +""" +@inline function Base.popfirst!(res::KnnResultShift, st::KnnResultState) + p = argmin(res, st) => minimum(res, st) + res.id[1] = 0 # mark as deleted + p, KnnResultState(st.shift+1) +end + +""" + pop!(res::KnnResultShift, st::KnnResultState) + +Removes and returns the last item in the pool, it is an O(1) operation +""" +@inline function Base.pop!(res::KnnResultShift, st::KnnResultState) + pop!(res.id) => pop!(res.dist), st +end + +""" + maxlength(res::KnnResultShift) + +The maximum allowed cardinality (the k of knn) +""" +@inline maxlength(res::KnnResultShift) = res.k +@inline Base.length(res::KnnResultShift, st::KnnResultState) = length(res.id) - st.shift + +function Base.length(res::KnnResultShift) + i = 1 + n = length(res.id) + while i < n && res.id[i] == 0 + i += 1 + end + + n - i + 1 +end + +""" + reuse!(res::KnnResultShift) + reuse!(res::KnnResultShift, k::Integer) + +Returns a result set and a new initial state; reuse the memory buffers +""" +@inline function reuse!(res::KnnResultShift, k::Integer=res.k) + @assert k > 0 + empty!(res.id) + empty!(res.dist) + if k > res.k + sizehint!(res.id, k) + sizehint!(res.dist, k) + end + KnnResultShift(res.id, res.dist, k) +end + +""" + getindex(res::KnnResultShift, st::KnnResultState, i) + +Access the i-th item in `res` +""" +@inline function getpair(res::KnnResultShift, st::KnnResultState, i) + i += st.shift + @inbounds res.id[i] => res.dist[i] +end + +@inline getid(res::KnnResultShift, st::KnnResultState, i) = @inbounds res.id[i+st.shift] +@inline getdist(res::KnnResultShift, st::KnnResultState, i) = @inbounds res.dist[i+st.shift] + +@inline Base.last(res::KnnResultShift, st::KnnResultState) = last(res.id) => last(res.dist) +@inline Base.first(res::KnnResultShift, st::KnnResultState) = res.id[st.shift+1] => res.dist[st.shift+1] +@inline Base.maximum(res::KnnResultShift, st::KnnResultState) = last(res.dist) +@inline Base.minimum(res::KnnResultShift, st::KnnResultState) = res.dist[1+st.shift] +@inline Base.argmax(res::KnnResultShift, st::KnnResultState) = last(res.id) +@inline Base.argmin(res::KnnResultShift, st::KnnResultState) = res.id[1+st.shift] + +Base.maximum(res::KnnResultShift) = last(res.dist) +Base.argmax(res::KnnResultShift) = last(res.id) +Base.minimum(res::KnnResultShift) = res.dist[_find_start_position(res)] +Base.argmin(res::KnnResultShift) = res.id[_find_start_position(res)] + +@inline idview(res::KnnResultShift, st::KnnResultState) = @view res.id[st.shift+1:end] +@inline distview(res::KnnResultShift, st::KnnResultState) = @view res.dist[st.shift+1:end] + +@inline Base.eachindex(res::KnnResultShift, st::KnnResultState) = 1:length(res, st) +Base.eltype(res::KnnResultShift) = Pair{Int32,Float32} + +##### iterator interface +### KnnResultShift +""" + Base.iterate(res::KnnResultShift, state::Int=1) + +Support for iteration +""" +@inline function _find_start_position(res::KnnResultShift) + i = 1 + id = res.id + n = length(id) + @inbounds while i <= n && id[i] == 0 + i += 1 + end + + i +end + +function Base.iterate(res::KnnResultShift, i::Int=-1) + n = length(res.id) + n == 0 && return nothing + if i == -1 + i = _find_start_position(res) + end + + i > n && return nothing + @inbounds res.id[i] => res.dist[i], i+1 +end \ No newline at end of file