Skip to content

Commit

Permalink
simplified knnresult*
Browse files Browse the repository at this point in the history
  • Loading branch information
sadit committed Jan 16, 2022
1 parent fd43dbf commit f158f56
Show file tree
Hide file tree
Showing 2 changed files with 372 additions and 0 deletions.
174 changes: 174 additions & 0 deletions src/knnresult.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# This file is a part of SimilaritySearch.jl
using Intersections

# Exported accessors and helpers for knn result sets.
export maxlength, getpair, getdist, getid, initialstate, idview, distview, reuse!

export KnnResult

"""
KnnResult(ksearch::Integer)
Creates a priority queue with fixed capacity (`ksearch`) representing a knn result set.
It starts with zero items and grows with [`push!(res, id, dist)`](@ref) calls until `ksearch`
size is reached. After this only the smallest items based on distance are preserved.
"""
struct KnnResult # <: AbstractVector{Tuple{IdType,DistType}}
    id::Vector{Int32}      # identifiers of the stored items
    dist::Vector{Float32}  # distances, parallel to `id`
    k::Int                 # number of neighbors
end

# Builds an empty result set of capacity `k`; both buffers get a size hint of `k` items.
function KnnResult(k::Integer)
    @assert k > 0
    ids = sizehint!(Int32[], k)
    dists = sizehint!(Float32[], k)
    return KnnResult(ids, dists, k)
end

"""
    _shifted_fixorder!(res, shift=0)

Restores the order of the result in place, assuming that the only element that may be out of order
is the last one; entries at positions `1:shift` are not examined.
It performs a single insertion step (as in insertion sort), which is efficient because only a few
of the inserted items are expected to be smaller than the ones already stored.
"""
function _shifted_fixorder!(res, shift=0)
    ids, dists = res.id, res.dist
    lower = shift + 1          # first position that takes part in the ordering
    slot = lastindex(ids)      # candidate position for the new item
    newid, newdist = ids[end], dists[end]

    # Walk towards the front, moving every strictly larger distance one place to the
    # right; stop at the first entry that is not larger (ties keep the older item first).
    @inbounds while slot > lower
        prev = slot - 1
        newdist < dists[prev] || break
        ids[slot] = ids[prev]
        dists[slot] = dists[prev]
        slot = prev
    end

    @inbounds begin
        dists[slot] = newdist
        ids[slot] = newid
    end

    return nothing
end


"""
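    push!(res::KnnResult, item::Pair)
    push!(res::KnnResult, id::Integer, dist::Real)

Inserts an item into the result set, keeping the items ordered by distance. Once `maxlength(res)`
items are stored, a new item is accepted only if its distance is smaller than the largest stored
distance, in which case it replaces that item. Returns `true` if the item was stored and `false` otherwise.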
"""
@inline function Base.push!(res::KnnResult, id::Integer, dist::Real)
    if length(res) >= maxlength(res)
        # Full: reject anything that does not improve on the current farthest item,
        # otherwise overwrite that item (it lives in the last slot).
        dist >= last(res.dist) && return false
        @inbounds begin
            res.id[end] = id
            res.dist[end] = dist
        end
    else
        # Still growing: append, the order is restored below.
        push!(res.id, id)
        push!(res.dist, dist)
    end

    _shifted_fixorder!(res)
    return true
end

#@inline Base.push!(res::KnnResult, id::Integer, dist::Real) = push!(res, convert(Int32, id), convert(Float32, dist))
# Pair form: `p.first` is used as the identifier and `p.second` as the distance.
@inline function Base.push!(res::KnnResult, p::Pair)
    return push!(res, p.first, p.second)
end

"""
    popfirst!(res::KnnResult)

Removes and returns the nearest neighbor of the result set as an `id => dist` pair; its cost is
that of `popfirst!` on the two underlying vectors.
"""
@inline function Base.popfirst!(res::KnnResult)
    # The identifier is removed first, then the matching distance.
    nearest_id = popfirst!(res.id)
    nearest_dist = popfirst!(res.dist)
    return nearest_id => nearest_dist
end

"""
    pop!(res::KnnResult)

Removes and returns the last item of the result set (the one with the largest distance) as an
`id => dist` pair; it is an O(1) operation.
"""
@inline function Base.pop!(res::KnnResult)
    # The identifier is removed first, then the matching distance.
    farthest_id = pop!(res.id)
    farthest_dist = pop!(res.dist)
    return farthest_id => farthest_dist
end

"""
maxlength(res::KnnResult)
The maximum allowed cardinality (the k of knn)
"""
@inline function maxlength(res::KnnResult)
    return res.k
end

# Number of items currently stored (length of the identifier buffer).
@inline function Base.length(res::KnnResult)
    return length(res.id)
end

"""
    reuse!(res::KnnResult)
    reuse!(res::KnnResult, k::Integer)

Empties the buffers of `res` and returns a `KnnResult` with capacity `k` (by default `res.k`)
that reuses those same memory buffers.
"""
@inline function reuse!(res::KnnResult, k::Integer=res.k)
    @assert k > 0
    ids = empty!(res.id)
    dists = empty!(res.dist)

    # Only a larger capacity than before needs a new size hint.
    if k > res.k
        sizehint!(ids, k)
        sizehint!(dists, k)
    end

    return KnnResult(ids, dists, k)
end

"""
    getpair(res::KnnResult, i)

Returns the `i`-th item of `res` as an `id => dist` pair (the access is marked `@inbounds`).
"""
@inline function getpair(res::KnnResult, i)
    return @inbounds Pair(res.id[i], res.dist[i])
end

# Identifier of the `i`-th item (access marked `@inbounds`).
@inline function getid(res::KnnResult, i)
    return @inbounds res.id[i]
end

# Distance of the `i`-th item (access marked `@inbounds`).
@inline function getdist(res::KnnResult, i)
    return @inbounds res.dist[i]
end

# Items are kept ordered by distance, so the extremes live at both ends of the buffers.
# Note that `argmin`/`argmax` return the stored identifier, not a position.

@inline function Base.last(res::KnnResult)
    return Pair(last(res.id), last(res.dist))
end

@inline function Base.first(res::KnnResult)
    return @inbounds Pair(res.id[1], res.dist[1])
end

@inline function Base.maximum(res::KnnResult)
    return last(res.dist)
end

@inline function Base.minimum(res::KnnResult)
    return @inbounds res.dist[1]
end

@inline function Base.argmax(res::KnnResult)
    return last(res.id)
end

@inline function Base.argmin(res::KnnResult)
    return @inbounds res.id[1]
end

# Both functions hand back the underlying buffer itself (no copy is made).
@inline function idview(res::KnnResult)
    return res.id
end

@inline function distview(res::KnnResult)
    return res.dist
end

# Valid positions of the stored items.
@inline Base.eachindex(res::KnnResult) = 1:length(res)

# Items are `id => dist` pairs. The method is defined on the type, as the iteration
# interface expects; `eltype(res)` on an instance keeps working through Base's
# generic `eltype(x) = eltype(typeof(x))` fallback.
Base.eltype(::Type{KnnResult}) = Pair{Int32,Float32}

##### iterator interface
### KnnResult
"""
Base.iterate(res::KnnResult, state::Int=1)
Support for iteration
"""
function Base.iterate(res::KnnResult, i::Int=1)
    n = length(res)
    if n == 0 || i > n
        return nothing
    end

    # Yield the item at the current position and the position to resume from.
    item = @inbounds res.id[i] => res.dist[i]
    return item, i + 1
end
198 changes: 198 additions & 0 deletions src/knnresultshift.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# This file is a part of SimilaritySearch.jl
"""
    KnnResultState(shift)

State that accompanies a `KnnResultShift`: `shift` is the number of leading buffer entries
that the state-based accessors skip.
"""
struct KnnResultState
    shift::Int
end

export initialstate, KnnResultShift

"""
KnnResultShift(ksearch::Integer)
Creates a priority queue with fixed capacity (`ksearch`) representing a knn result set.
It starts with zero items and grows with [`push!(res, id, dist)`](@ref) calls until `ksearch`
size is reached. After this only the smallest items based on distance are preserved.
"""
struct KnnResultShift
    id::Vector{Int32}      # identifiers of the stored items
    dist::Vector{Float32}  # distances, parallel to `id`
    k::Int                 # number of neighbors
end

# Builds an empty result set of capacity `k`; both buffers get a size hint of `k` items.
function KnnResultShift(k::Integer)
    @assert k > 0
    ids = sizehint!(Int32[], k)
    dists = sizehint!(Float32[], k)
    return KnnResultShift(ids, dists, k)
end

# State to start working with a result set: nothing has been shifted yet.
initialstate(::KnnResultShift) = KnnResultState(0)

"""
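    push!(res::KnnResultShift, st::KnnResultState, item::Pair)
    push!(res::KnnResultShift, st::KnnResultState, id::Integer, dist::Real)

Inserts an item into the result set, keeping the items ordered by distance, and returns the state
that must be used in subsequent calls (it differs from `st` when the buffers were compacted).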
"""
@inline function Base.push!(res::KnnResultShift, st::KnnResultState, id::Integer, dist::Real)
    if length(res, st) >= maxlength(res)
        # Full: only an item closer than the current farthest one is accepted,
        # and it takes over the last slot.
        dist >= last(res.dist) && return st
        @inbounds res.id[end], res.dist[end] = id, dist
    elseif length(res.id) >= 2 * res.k - 1
        # The buffers reached 2k-1 entries: move the live items to the front, leaving
        # one extra slot at the end for the new item, and restart with a zero shift.
        compact!(res, st, 1)
        st = KnnResultState(0)
        @inbounds res.id[end], res.dist[end] = id, dist
    else
        push!(res.id, id)
        push!(res.dist, dist)
    end

    _shifted_fixorder!(res, st.shift)
    return st
end

#@inline Base.push!(res::KnnResultShift, st::KnnResultState, id::Integer, dist::Real) = push!(res, st, convert(Int32, id), convert(Float32, dist))
# Pair form: `p.first` is used as the identifier and `p.second` as the distance.
@inline function Base.push!(res::KnnResultShift, st::KnnResultState, p::Pair)
    return push!(res, st, p.first, p.second)
end

"""
    compact!(res::KnnResultShift, st::KnnResultState, resize_extra)

Moves the items that are live under `st` to the front of the buffers and resizes both buffers to
that number of items plus `resize_extra`. Nothing is done when `st.shift` is zero. Returns `res`;
the caller is responsible for switching to a state with zero shift afterwards.
"""
function compact!(res::KnnResultShift, st::KnnResultState, resize_extra)
    offset = st.shift
    offset > 0 || return res

    live = length(res, st)
    ids, dists = res.id, res.dist

    # Source positions are always ahead of destinations, so a forward copy is safe.
    @inbounds for dst in 1:live
        src = dst + offset
        ids[dst] = ids[src]
        dists[dst] = dists[src]
    end

    resize!(ids, live + resize_extra)
    resize!(dists, live + resize_extra)
    return res
end

"""
    popfirst!(res::KnnResultShift, st::KnnResultState)

Removes the nearest neighbor and returns a tuple with its `id => dist` pair and the new state,
whose shift is advanced by one; no items are moved in the buffers.
"""
@inline function Base.popfirst!(res::KnnResultShift, st::KnnResultState)
    # Read the nearest live item before its slot is overwritten.
    p = argmin(res, st) => minimum(res, st)
    # Mark the popped slot itself as deleted (identifier 0). The stateless accessors
    # locate the first live item by skipping the leading run of zero identifiers, so
    # the marker has to follow the shift rather than always land on position 1.
    res.id[st.shift + 1] = 0
    return p, KnnResultState(st.shift + 1)
end

"""
    pop!(res::KnnResultShift, st::KnnResultState)

Removes the last item of the result set and returns a tuple with its `id => dist` pair and the
unchanged state; it is an O(1) operation.
"""
@inline function Base.pop!(res::KnnResultShift, st::KnnResultState)
    # The identifier is removed first, then the matching distance; `st` is passed through.
    farthest_id = pop!(res.id)
    farthest_dist = pop!(res.dist)
    return (farthest_id => farthest_dist), st
end

"""
maxlength(res::KnnResultShift)
The maximum allowed cardinality (the k of knn)
"""
@inline function maxlength(res::KnnResultShift)
    return res.k
end

# Number of items that are live under `st`: buffer length minus the shifted prefix.
@inline function Base.length(res::KnnResultShift, st::KnnResultState)
    return length(res.id) - st.shift
end

"""
    length(res::KnnResultShift)

Number of live items when no state is available: the leading entries whose identifier is zero
(the deletion marker) are not counted. Returns `0` when the buffers are empty or when every
entry is marked as deleted.
"""
function Base.length(res::KnnResultShift)
    # Use the same start-position scan as `iterate`, `minimum` and `argmin`, so that
    # all stateless accessors agree (it yields `length(res.id) + 1` if nothing is live).
    return length(res.id) - _find_start_position(res) + 1
end

"""
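    reuse!(res::KnnResultShift)
    reuse!(res::KnnResultShift, k::Integer)

Empties the buffers of `res` and returns a `KnnResultShift` with capacity `k` (by default `res.k`)
that reuses those same memory buffers. No state is returned; see `initialstate`.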
"""
@inline function reuse!(res::KnnResultShift, k::Integer=res.k)
    @assert k > 0
    ids = empty!(res.id)
    dists = empty!(res.dist)

    # Only a larger capacity than before needs a new size hint.
    if k > res.k
        sizehint!(ids, k)
        sizehint!(dists, k)
    end

    return KnnResultShift(ids, dists, k)
end

"""
    getpair(res::KnnResultShift, st::KnnResultState, i)

Returns the `i`-th item of `res`, counted after the `st.shift` leading entries, as an `id => dist` pair.
"""
@inline function getpair(res::KnnResultShift, st::KnnResultState, i)
    pos = i + st.shift  # translate the logical position into a buffer position
    return @inbounds Pair(res.id[pos], res.dist[pos])
end

# Identifier of the `i`-th item under state `st` (access marked `@inbounds`).
@inline function getid(res::KnnResultShift, st::KnnResultState, i)
    return @inbounds res.id[i + st.shift]
end

# Distance of the `i`-th item under state `st` (access marked `@inbounds`).
@inline function getdist(res::KnnResultShift, st::KnnResultState, i)
    return @inbounds res.dist[i + st.shift]
end

# State-based extremes: the farthest item is the last buffer entry, the nearest one sits
# right after the shifted prefix. `argmin`/`argmax` return the stored identifier.

@inline function Base.last(res::KnnResultShift, st::KnnResultState)
    return Pair(last(res.id), last(res.dist))
end

@inline function Base.first(res::KnnResultShift, st::KnnResultState)
    return Pair(res.id[st.shift + 1], res.dist[st.shift + 1])
end

@inline function Base.maximum(res::KnnResultShift, st::KnnResultState)
    return last(res.dist)
end

@inline function Base.minimum(res::KnnResultShift, st::KnnResultState)
    return res.dist[1 + st.shift]
end

@inline function Base.argmax(res::KnnResultShift, st::KnnResultState)
    return last(res.id)
end

@inline function Base.argmin(res::KnnResultShift, st::KnnResultState)
    return res.id[1 + st.shift]
end

# Stateless extremes: the nearest item is found by scanning for the start position.

function Base.maximum(res::KnnResultShift)
    return last(res.dist)
end

function Base.argmax(res::KnnResultShift)
    return last(res.id)
end

function Base.minimum(res::KnnResultShift)
    start = _find_start_position(res)
    return res.dist[start]
end

function Base.argmin(res::KnnResultShift)
    start = _find_start_position(res)
    return res.id[start]
end

# Views (no copy) over the entries that follow the shifted prefix.
@inline function idview(res::KnnResultShift, st::KnnResultState)
    return view(res.id, (st.shift + 1):lastindex(res.id))
end

@inline function distview(res::KnnResultShift, st::KnnResultState)
    return view(res.dist, (st.shift + 1):lastindex(res.dist))
end

# Valid logical positions of the items that are live under `st`.
@inline Base.eachindex(res::KnnResultShift, st::KnnResultState) = 1:length(res, st)

# Items are `id => dist` pairs. The method is defined on the type, as the iteration
# interface expects; `eltype(res)` on an instance keeps working through Base's
# generic `eltype(x) = eltype(typeof(x))` fallback.
Base.eltype(::Type{KnnResultShift}) = Pair{Int32,Float32}

##### iterator interface
### KnnResultShift
"""
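    _find_start_position(res::KnnResultShift)

Returns the position of the first entry of `res.id` that is not zero (zero identifiers mark
removed items), or `length(res.id) + 1` when there is no such entry.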
"""
@inline function _find_start_position(res::KnnResultShift)
    ids = res.id
    # Skip the leading run of zero identifiers.
    @inbounds for pos in 1:length(ids)
        ids[pos] == 0 || return pos
    end

    return length(ids) + 1
end

"""
    Base.iterate(res::KnnResultShift, i::Int=-1)

Support for iteration without a `KnnResultState`. The default state `-1` means "not started":
iteration then begins at the position given by `_find_start_position`. Each step yields an
`id => dist` pair and the next buffer position.
"""
function Base.iterate(res::KnnResultShift, i::Int=-1)
    total = length(res.id)
    if total == 0
        return nothing
    end

    pos = i == -1 ? _find_start_position(res) : i
    pos > total && return nothing

    item = @inbounds res.id[pos] => res.dist[pos]
    return item, pos + 1
end

2 comments on commit f158f56

@sadit
Copy link
Owner Author

@sadit sadit commented on f158f56 Jan 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/52520

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.8 -m "<description of version>" f158f56cea19a6f18784e8127aaef85d6e4d84ae
git push origin v0.8.8

Please sign in to comment.