Skip to content

Commit

Permalink
adds IdView and DistView
Browse files Browse the repository at this point in the history
  • Loading branch information
sadit committed Feb 18, 2023
1 parent 92a5dfe commit 5dab3d6
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
25 changes: 17 additions & 8 deletions src/knnresult.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is a part of SimilaritySearch.jl
# export AbstractResult
export KnnResult
export KnnResult, IdView, DistView
export covradius, maxlength, reuse!

"""
Expand Down Expand Up @@ -128,14 +128,23 @@ end
@inline Base.eachindex(res::KnnResult) = firstindex(res):lastindex(res)
Base.eltype(res::KnnResult) = IdWeight

##### iterator interface
### KnnResult
"""
Base.iterate(res::KnnResult, state::Int=1)
struct IdView
res::KnnResult
end

Support for iteration
"""
function Base.iterate(res::KnnResult, i::Int=1)
struct DistView
res::KnnResult
end

@inline Base.getindex(v::IdView, i::Integer) = v.res[i].id
@inline Base.getindex(v::DistView, i::Integer) = v.res[i].weight
@inline Base.eachindex(v::IdView) = 1:length(v.res)
@inline Base.eachindex(v::DistView) = 1:length(v.res)
@inline Base.length(v::IdView) = length(v.res)
@inline Base.length(v::DistView) = length(v.res)

##### iterator interface
function Base.iterate(res::Union{KnnResult,IdView,DistView}, i::Int=1)
n = length(res)
(n == 0 || i > n) && return nothing
@inbounds res[i], i+1
Expand Down
4 changes: 1 addition & 3 deletions src/opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ function create_error_function(index::AbstractSearchIndex, gold, knnlist::Vector
recall = if gold !== nothing
for (i, res) in enumerate(knnlist)
empty!(R[i])
for item in res
push!(R[i], item.id)
end
union!(R[i], IdView(res))
end

macrorecall(gold, R)
Expand Down
8 changes: 4 additions & 4 deletions src/perf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ export recallscore, macrorecall
Compute recall and precision scores from the result sets.
"""
function recallscore(gold, res)::Float64
length(intersect(_convert_as_set(gold), _convert_as_set(res))) / length(gold)
length(intersect(as_set(gold), as_set(res))) / length(gold)
end

_convert_as_set(a::Set) = a
_convert_as_set(a::AbstractVector) = Set(a)
_convert_as_set(a::KnnResult) = Set(item.id for item in res)
as_set(a::Set) = a
as_set(a::AbstractVector) = Set(a)
as_set(res::KnnResult) = Set(IdView(res))

"""
macrorecall(goldI::AbstractMatrix, resI::AbstractMatrix, k=size(goldI, 1))::Float64
Expand Down

2 comments on commit 5dab3d6

@sadit
Copy link
Owner Author

@sadit sadit commented on 5dab3d6 Feb 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/77998

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.0 -m "<description of version>" 5dab3d674b93ac35a4c75313a895a04a6499851c
git push origin v0.10.0

Please sign in to comment.