-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
372 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
# This file is a part of SimilaritySearch.jl | ||
using Intersections | ||
|
||
export maxlength, maxlength, getpair, getdist, getid, initialstate, idview, distview, reuse! | ||
|
||
export KnnResult | ||
|
||
""" | ||
KnnResult(ksearch::Integer) | ||
Creates a priority queue with fixed capacity (`ksearch`) representing a knn result set. | ||
It starts with zero items and grows with [`push!(res, id, dist)`](@ref) calls until `ksearch` | ||
size is reached. After this only the smallest items based on distance are preserved. | ||
""" | ||
struct KnnResult
    # identifiers of the stored items; entry `i` pairs with `dist[i]`
    id::Vector{Int32}
    # distances of the stored items; entry `i` pairs with `id[i]`
    dist::Vector{Float32}
    # maximum number of neighbors to keep
    k::Int
end
|
||
function KnnResult(k::Integer)
    # the capacity must be strictly positive
    @assert k > 0
    # start with empty buffers and reserve room for `k` entries in each
    ids = sizehint!(Int32[], k)
    dists = sizehint!(Float32[], k)
    return KnnResult(ids, dists, k)
end
|
||
""" | ||
_shifted_fixorder!(res, shift=0) | ||
Sorts the result in place; the only element that may be out of order is always the last entry. | ||
It implements a kind of insertion sort that is efficient due to the expected | ||
distribution of the items being inserted (only a few of the inserted elements are expected to be smaller than the current ones) | ||
""" | ||
function _shifted_fixorder!(res, shift=0)
    # `res` only needs `id` and `dist` vectors; entries at positions <= `shift` are never touched
    ids, dists = res.id, res.dist
    lo = shift + 1
    hi = lastindex(ids)
    # the candidate that may be out of order always sits in the last slot
    newid, newdist = ids[end], dists[end]

    # walk backwards to find the slot where the candidate belongs
    target = hi
    @inbounds while target > lo && newdist < dists[target - 1]
        target -= 1
    end

    # already in place: nothing to move
    target < hi || return nothing

    # shift the larger entries one slot to the right and drop the candidate into the gap
    @inbounds for j in hi:-1:(target + 1)
        ids[j] = ids[j - 1]
        dists[j] = dists[j - 1]
    end
    @inbounds dists[target] = newdist
    @inbounds ids[target] = newid

    return nothing
end
|
||
|
||
""" | ||
push!(res::KnnResult, item::Pair) | ||
push!(res::KnnResult, id::Integer, dist::Real) | ||
Appends an item into the result set | ||
""" | ||
@inline function Base.push!(res::KnnResult, id::Integer, dist::Real)
    # Returns `true` when the item was stored and `false` when it was rejected.
    if length(res) < maxlength(res)
        # still below capacity: append, then move the new item to its sorted position
        push!(res.id, id)
        push!(res.dist, dist)
        _shifted_fixorder!(res)
        return true
    end

    # at capacity: an item that is not closer than the current farthest one is discarded
    dist >= last(res.dist) && return false

    # overwrite the farthest item (last slot) and restore the order
    @inbounds res.id[end], res.dist[end] = id, dist
    _shifted_fixorder!(res)
    return true
end
|
||
#@inline Base.push!(res::KnnResult, id::Integer, dist::Real) = push!(res, convert(Int32, id), convert(Float32, dist)) | ||
@inline function Base.push!(res::KnnResult, p::Pair)
    # unpack the `id => dist` pair and delegate to the three-argument method
    return push!(res, first(p), last(p))
end
|
||
""" | ||
popfirst!(p::KnnResult) | ||
Removes and returns the nearest neighbor pair from the pool, an O(length(p.pool)) operation | ||
""" | ||
@inline function Base.popfirst!(res::KnnResult)
    # remove the first entry from both buffers and return it as `id => dist`
    id = popfirst!(res.id)
    dist = popfirst!(res.dist)
    return id => dist
end
|
||
""" | ||
pop!(res::KnnResult) | ||
Removes and returns the last item in the pool, it is an O(1) operation | ||
""" | ||
@inline function Base.pop!(res::KnnResult)
    # remove the last entry from both buffers and return it as `id => dist`
    id = pop!(res.id)
    dist = pop!(res.dist)
    return id => dist
end
|
||
""" | ||
maxlength(res::KnnResult) | ||
The maximum allowed cardinality (the k of knn) | ||
""" | ||
@inline function maxlength(res::KnnResult)
    return res.k
end

@inline function Base.length(res::KnnResult)
    # number of items currently stored
    return length(res.id)
end
|
||
""" | ||
reuse!(res::KnnResult) | ||
reuse!(res::KnnResult, k::Integer) | ||
Empties `res` and returns a result set with capacity `k` that reuses the memory buffers of `res` | ||
""" | ||
@inline function reuse!(res::KnnResult, k::Integer=res.k)
    @assert k > 0
    ids, dists = res.id, res.dist
    # drop the stored items but keep the buffers
    empty!(ids)
    empty!(dists)
    # reserve extra room only when the new capacity exceeds the previous one
    k > res.k && (sizehint!(ids, k); sizehint!(dists, k))
    # the returned result set shares the buffers of `res`
    return KnnResult(ids, dists, k)
end
|
||
""" | ||
getpair(res::KnnResult, i) | ||
Access the i-th item in `res` | ||
""" | ||
@inline function getpair(res::KnnResult, i)
    # no bounds checking is performed on `i`
    pair = @inbounds(res.id[i] => res.dist[i])
    return pair
end
|
||
# Identifier of the i-th stored item (no bounds checking).
@inline function getid(res::KnnResult, i)
    return @inbounds(res.id[i])
end

# Distance of the i-th stored item (no bounds checking).
@inline function getdist(res::KnnResult, i)
    return @inbounds(res.dist[i])
end
|
||
# Last stored item as `id => dist`.
@inline function Base.last(res::KnnResult)
    return last(res.id) => last(res.dist)
end

# First stored item as `id => dist` (no bounds checking).
@inline function Base.first(res::KnnResult)
    return @inbounds(res.id[1] => res.dist[1])
end

# Distance of the last stored item.
@inline function Base.maximum(res::KnnResult)
    return last(res.dist)
end

# Distance of the first stored item (no bounds checking).
@inline function Base.minimum(res::KnnResult)
    return @inbounds(res.dist[1])
end

# Identifier of the last stored item.
@inline function Base.argmax(res::KnnResult)
    return last(res.id)
end

# Identifier of the first stored item (no bounds checking).
@inline function Base.argmin(res::KnnResult)
    return @inbounds(res.id[1])
end
|
||
# The identifier buffer itself (not a copy).
@inline function idview(res::KnnResult)
    return res.id
end

# The distance buffer itself (not a copy).
@inline function distview(res::KnnResult)
    return res.dist
end
|
||
# Valid positions of the stored items.
@inline Base.eachindex(res::KnnResult) = 1:length(res)

# Element type produced by iteration. It is declared on the type so that both
# `eltype(KnnResult)` and `eltype(res)` (which falls back to `eltype(typeof(res))`) agree.
Base.eltype(::Type{KnnResult}) = Pair{Int32,Float32}
|
||
##### iterator interface | ||
### KnnResult | ||
""" | ||
Base.iterate(res::KnnResult, state::Int=1) | ||
Support for iteration | ||
""" | ||
function Base.iterate(res::KnnResult, i::Int=1)
    n = length(res)
    # stop on an empty result set or once the position runs past the last item
    if n == 0 || i > n
        return nothing
    end
    item = @inbounds(res.id[i] => res.dist[i])
    # yield `id => dist` together with the next position
    return item, i + 1
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
# This file is a part of SimilaritySearch.jl | ||
struct KnnResultState
    # number of leading buffer slots that are skipped when accessing a `KnnResultShift`
    shift::Int
end
|
||
export initialstate, KnnResultShift | ||
|
||
""" | ||
KnnResultShift(ksearch::Integer) | ||
Creates a priority queue with fixed capacity (`ksearch`) representing a knn result set. | ||
It starts with zero items and grows with [`push!(res, id, dist)`](@ref) calls until `ksearch` | ||
size is reached. After this only the smallest items based on distance are preserved. | ||
""" | ||
struct KnnResultShift
    # identifiers of the stored items; entry `i` pairs with `dist[i]`
    id::Vector{Int32}
    # distances of the stored items; entry `i` pairs with `id[i]`
    dist::Vector{Float32}
    # maximum number of neighbors to keep
    k::Int
end
|
||
function KnnResultShift(k::Integer)
    # the capacity must be strictly positive
    @assert k > 0
    # start with empty buffers and reserve room for `k` entries in each
    ids = sizehint!(Int32[], k)
    dists = sizehint!(Float32[], k)
    return KnnResultShift(ids, dists, k)
end
|
||
# State to pair with a result set: a shift of zero, i.e., no leading slots are skipped.
initialstate(::KnnResultShift) = KnnResultState(0)
|
||
""" | ||
push!(res::KnnResultShift, st::KnnResultState, item::Pair) | ||
push!(res::KnnResultShift, st::KnnResultState, id::Integer, dist::Real) | ||
Appends an item into the result set and returns the (possibly updated) state | ||
""" | ||
@inline function Base.push!(res::KnnResultShift, st::KnnResultState, id::Integer, dist::Real)
    if length(res, st) >= maxlength(res)
        # at capacity: keep the item only if it is closer than the current last one
        dist >= last(res.dist) && return st
        @inbounds res.id[end], res.dist[end] = id, dist
        _shifted_fixorder!(res, st.shift)
        return st
    end

    # below capacity
    if length(res.id) >= 2 * res.k - 1
        # the buffers grew too much because of the shift: move the live items to the
        # front leaving one extra slot, reset the shift, and store the item in that slot
        compact!(res, st, 1)
        st = KnnResultState(0)
        @inbounds res.id[end], res.dist[end] = id, dist
    else
        push!(res.id, id)
        push!(res.dist, dist)
    end

    # move the new item (last slot) to its sorted position within the live region
    _shifted_fixorder!(res, st.shift)
    return st
end
|
||
#@inline Base.push!(res::KnnResultShift, st::KnnResultState, id::Integer, dist::Real) = push!(res, st, convert(Int32, id), convert(Float32, dist)) | ||
@inline function Base.push!(res::KnnResultShift, st::KnnResultState, p::Pair)
    # unpack the `id => dist` pair and delegate to the four-argument method
    return push!(res, st, first(p), last(p))
end
|
||
function compact!(res::KnnResultShift, st::KnnResultState, resize_extra)
    offset = st.shift
    # nothing to do (not even the resize) when no slots are being skipped
    offset > 0 || return res

    live = length(res, st)
    ids, dists = res.id, res.dist
    # move the live items to the beginning of the buffers
    @inbounds for dst in 1:live
        src = dst + offset
        ids[dst] = ids[src]
        dists[dst] = dists[src]
    end

    # keep the live items plus `resize_extra` trailing slots
    newlen = live + resize_extra
    resize!(ids, newlen)
    resize!(dists, newlen)

    return res
end
|
||
""" | ||
popfirst!(p::KnnResultShift, st::KnnResultState) | ||
Removes and returns the nearest neighbor pair from the pool, an O(length(p.pool)) operation | ||
""" | ||
@inline function Base.popfirst!(res::KnnResultShift, st::KnnResultState)
    # Returns `(id => dist, newstate)`; the buffers keep their length and the
    # removal is expressed by a state whose shift is one larger.
    p = argmin(res, st) => minimum(res, st)
    # Mark the removed slot itself as deleted. The first live slot is `st.shift + 1`,
    # not 1; the state-less methods locate the first live item by skipping leading
    # zero ids, so every removed slot must carry the marker.
    res.id[st.shift + 1] = 0
    return p, KnnResultState(st.shift + 1)
end
|
||
""" | ||
pop!(res::KnnResultShift, st::KnnResultState) | ||
Removes and returns the last item in the pool, it is an O(1) operation | ||
""" | ||
@inline function Base.pop!(res::KnnResultShift, st::KnnResultState)
    # remove the last entry from both buffers; the state is returned unchanged
    id = pop!(res.id)
    dist = pop!(res.dist)
    return id => dist, st
end
|
||
""" | ||
maxlength(res::KnnResultShift) | ||
The maximum allowed cardinality (the k of knn) | ||
""" | ||
@inline function maxlength(res::KnnResultShift)
    return res.k
end

@inline function Base.length(res::KnnResultShift, st::KnnResultState)
    # buffer length minus the skipped leading slots
    return length(res.id) - st.shift
end
|
||
function Base.length(res::KnnResultShift)
    # Number of items that are not marked as deleted (leading entries with id 0).
    # The first live position comes from `_find_start_position`, which also scans
    # the last slot: when every entry is deleted it returns `length(res.id) + 1`
    # and the result is 0, consistent with what `iterate(res)` yields.
    return length(res.id) - _find_start_position(res) + 1
end
|
||
""" | ||
reuse!(res::KnnResultShift) | ||
reuse!(res::KnnResultShift, k::Integer) | ||
Empties `res` and returns a result set with capacity `k` that reuses the memory buffers of `res` | ||
""" | ||
@inline function reuse!(res::KnnResultShift, k::Integer=res.k)
    @assert k > 0
    ids, dists = res.id, res.dist
    # drop the stored items but keep the buffers
    empty!(ids)
    empty!(dists)
    # reserve extra room only when the new capacity exceeds the previous one
    k > res.k && (sizehint!(ids, k); sizehint!(dists, k))
    # the returned result set shares the buffers of `res`
    return KnnResultShift(ids, dists, k)
end
|
||
""" | ||
getpair(res::KnnResultShift, st::KnnResultState, i) | ||
Access the i-th item in `res` | ||
""" | ||
@inline function getpair(res::KnnResultShift, st::KnnResultState, i)
    # translate the logical position into a buffer position (no bounds checking)
    j = i + st.shift
    return @inbounds(res.id[j] => res.dist[j])
end
|
||
# Identifier at logical position `i`, i.e., buffer position `i + st.shift` (no bounds checking).
@inline function getid(res::KnnResultShift, st::KnnResultState, i)
    return @inbounds(res.id[i + st.shift])
end

# Distance at logical position `i`, i.e., buffer position `i + st.shift` (no bounds checking).
@inline function getdist(res::KnnResultShift, st::KnnResultState, i)
    return @inbounds(res.dist[i + st.shift])
end
|
||
# Last stored item as `id => dist`.
@inline function Base.last(res::KnnResultShift, st::KnnResultState)
    return last(res.id) => last(res.dist)
end

# First item after the shift as `id => dist`.
@inline function Base.first(res::KnnResultShift, st::KnnResultState)
    return res.id[st.shift + 1] => res.dist[st.shift + 1]
end

# Distance of the last stored item.
@inline function Base.maximum(res::KnnResultShift, st::KnnResultState)
    return last(res.dist)
end

# Distance of the first item after the shift.
@inline function Base.minimum(res::KnnResultShift, st::KnnResultState)
    return res.dist[1 + st.shift]
end

# Identifier of the last stored item.
@inline function Base.argmax(res::KnnResultShift, st::KnnResultState)
    return last(res.id)
end

# Identifier of the first item after the shift.
@inline function Base.argmin(res::KnnResultShift, st::KnnResultState)
    return res.id[1 + st.shift]
end
|
||
# State-less accessors: the first live position is found by skipping leading zero ids.

# Distance of the last stored item.
function Base.maximum(res::KnnResultShift)
    return last(res.dist)
end

# Identifier of the last stored item.
function Base.argmax(res::KnnResultShift)
    return last(res.id)
end

# Distance of the first item whose id is not 0.
function Base.minimum(res::KnnResultShift)
    start = _find_start_position(res)
    return res.dist[start]
end

# Identifier of the first item whose id is not 0.
function Base.argmin(res::KnnResultShift)
    start = _find_start_position(res)
    return res.id[start]
end
|
||
# View (not a copy) of the identifiers located after the shift.
@inline function idview(res::KnnResultShift, st::KnnResultState)
    return view(res.id, (st.shift + 1):lastindex(res.id))
end

# View (not a copy) of the distances located after the shift.
@inline function distview(res::KnnResultShift, st::KnnResultState)
    return view(res.dist, (st.shift + 1):lastindex(res.dist))
end
|
||
# Valid logical positions of the items located after the shift.
@inline Base.eachindex(res::KnnResultShift, st::KnnResultState) = 1:length(res, st)

# Element type produced by iteration. It is declared on the type so that both
# `eltype(KnnResultShift)` and `eltype(res)` (which falls back to `eltype(typeof(res))`) agree.
Base.eltype(::Type{KnnResultShift}) = Pair{Int32,Float32}
|
||
##### iterator interface | ||
### KnnResultShift | ||
""" | ||
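_find_start_position(res::KnnResultShift) | ||
Returns the position of the first entry of `res.id` whose id is not 0 (0 marks a deleted entry), or `length(res.id) + 1` if there is none | ||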
""" | ||
@inline function _find_start_position(res::KnnResultShift)
    ids = res.id
    # first position holding a non-zero id; one past the end when every id is zero
    # (or the buffer is empty)
    return something(findfirst(!iszero, ids), length(ids) + 1)
end
|
||
function Base.iterate(res::KnnResultShift, i::Int=-1)
    # Iteration protocol without an explicit state object: `i == -1` (the default)
    # requests the start, which is located by skipping the leading zero ids.
    total = length(res.id)
    total == 0 && return nothing

    pos = i == -1 ? _find_start_position(res) : i
    pos > total && return nothing

    item = @inbounds(res.id[pos] => res.dist[pos])
    # yield `id => dist` together with the next buffer position
    return item, pos + 1
end
f158f56
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register()
f158f56
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/52520
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: