Skip to content

Commit

Permalink
Uses objects instead of names to identify error functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sadit committed Jan 15, 2022
1 parent 34f0f0e commit c5cbbde
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimilaritySearch"
uuid = "053f045d-5466-53fd-b400-a066f88fe02a"
authors = ["Eric S. Tellez <donsadit@gmail.com>"]
version = "0.8.6"
version = "0.8.7"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
17 changes: 10 additions & 7 deletions src/graph/opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
using SearchModels, Random
using StatsBase
import SearchModels: combine, mutate
export OptimizeParameters, optimize!, BeamSearchSpace
export OptimizeParameters, optimize!, BeamSearchSpace, MinRecall, ParetoRecall, ParetoRadius

# Marker types that select which error function `optimize!` minimizes; an
# instance is stored in the `kind` field of `OptimizeParameters`.
abstract type ErrorFunction end
# Recall floor: while a configuration's recall is below `minrecall` its error is
# `3.0 - 2 * recall`; once the floor is met the error is the search cost alone.
# A gold standard is computed with an exhaustive search for this kind.
struct MinRecall <: ErrorFunction end
# Joint cost/recall objective: the error is `cost^2 + (1.0 - recall)^2`.
# A gold standard is computed with an exhaustive search for this kind.
struct ParetoRecall <: ErrorFunction end
# Handled by the remaining (`else`) branch of the error computation, which
# combines the cost and the result radius (scaled by `R[]`) through `_kfun`.
struct ParetoRadius <: ErrorFunction end

@with_kw struct BeamSearchSpace <: AbstractSolutionSpace
bsize = 8:8:64
Expand All @@ -31,12 +35,12 @@ function mutate(space::BeamSearchSpace, c::BeamSearch, iter)
end

@with_kw mutable struct OptimizeParameters <: Callback
kind = :pareto_recall_searchtime # :pareto_distance_searchtime, :pareto_recall_searchtime, :minimum_recall_searchtime
kind::ErrorFunction = ParetoRecall()
initialpopulation = 16
params = SearchParams(maxpopulation=16, bsize=4, mutbsize=16, crossbsize=8, tol=-1.0, maxiters=16)
ksearch::Int32 = 10
numqueries::Int32 = 64
minrecall = 0.9 # used with :minimum_recall_searchtime
minrecall = 0.9 # used with MinRecall()
space::BeamSearchSpace = BeamSearchSpace()
end

Expand Down Expand Up @@ -130,9 +134,8 @@ function optimize!(
queries = SubDatabase(index.db, sample)
end

recall_options = (:pareto_recall_searchtime, :minimum_recall_searchtime)
knnlist = [KnnResult(opt.ksearch) for i in eachindex(queries)]
gold = if opt.kind in recall_options
gold = if opt.kind isa ParetoRecall || opt.kind isa MinRecall
db = @view index.db[1:length(index)]
seq = ExhaustiveSearch(index.dist, db)
searchbatch(seq, queries, knnlist; parallel=true)
Expand All @@ -156,9 +159,9 @@ function optimize!(

function geterr(p)
cost = p.visited[2] / M[]
if opt.kind === :pareto_recall_searchtime
if opt.kind isa ParetoRecall
cost^2 + (1.0 - p.recall)^2
elseif opt.kind === :minimum_recall_searchtime
elseif opt.kind isa MinRecall
p.recall < opt.minrecall ? 3.0 - 2 * p.recall : cost
else
_kfun(cost) + _kfun(p.radius[2] / R[])
Expand Down
22 changes: 11 additions & 11 deletions test/testsearchgraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,32 @@ using Test
@info "===="
end

@info "--- Optimizing parameters :pareto_distance_searchtime ---"
@info "--- Optimizing parameters ParetoRadius ---"
graph = SearchGraph(; dist, search_algo=BeamSearch(bsize=2), verbose=false)
graph.neighborhood.reduce = SatNeighborhood()
append!(graph, db)
@info "---- starting :pareto_distance_searchtime optimization ---"
@info "---- starting ParetoRadius optimization ---"
optimize!(graph, OptimizeParameters())
I, D, searchtime = timedsearchbatch(graph, queries, ksearch)
recall = macrorecall(goldI, I)
@info ":pareto_distance_search_time:> queries per second: ", 1/searchtime, ", recall:", recall
@info "ParetoRadius:> queries per second: ", 1/searchtime, ", recall:", recall
@info graph.search_algo
@test recall >= 0.6


@info "---- starting :pareto_recall_searchtime optimization ---"
optimize!(graph, OptimizeParameters(kind=:pareto_recall_searchtime))
@info "---- starting ParetoRecall optimization ---"
optimize!(graph, OptimizeParameters(kind=ParetoRecall()))
I, D, searchtime = timedsearchbatch(graph, queries, ksearch)
recall = macrorecall(goldI, I)
@info ":pareto_recall_search_time:> queries per second: ", 1/searchtime, ", recall:", recall
@info "ParetoRecall:> queries per second: ", 1/searchtime, ", recall:", recall
@info graph.search_algo
@test recall >= 0.6

@info "========================= Callback optimization ======================"
@info "--- Optimizing parameters :pareto_distance_searchtime ---"
@info "--- Optimizing parameters ParetoRadius ---"
graph = SearchGraph(; db, dist, search_algo=BeamSearch(bsize=2), verbose=false)
graph.neighborhood.reduce = SatNeighborhood()
push!(graph.callbacks, OptimizeParameters(kind=:pareto_distance_searchtime))
push!(graph.callbacks, OptimizeParameters(kind=ParetoRadius()))
index!(graph)
I, D, searchtime = timedsearchbatch(graph, queries, ksearch)
recall = macrorecall(goldI, I)
Expand All @@ -72,17 +72,17 @@ using Test
@test recall >= 0.6

@info "#############=========== Callback optimization 2 ==========###########"
@info "--- Optimizing parameters :pareto_distance_searchtime L2 ---"
@info "--- Optimizing parameters ParetoRadius L2 ---"
dim = 4
db = MatrixDatabase(ceil.(Int32, rand(Float32, dim, n) .* 100))
queries = VectorDatabase(ceil.(Int32, rand(Float32, dim, m) .* 100))
seq = ExhaustiveSearch(dist, db)
goldI, goldD = searchbatch(seq, queries, ksearch)
graph = SearchGraph(; db, dist, search_algo=BeamSearch(bsize=2), verbose=false)
graph.neighborhood.reduce = SatNeighborhood()
push!(graph.callbacks, OptimizeParameters(kind=:pareto_recall_searchtime))
push!(graph.callbacks, OptimizeParameters(kind=ParetoRecall()))
index!(graph)
#optimize!(graph, OptimizeParameters(kind=:minimum_recall_searchtime, minrecall=0.7))
#optimize!(graph, OptimizeParameters(kind=MinRecall(), minrecall=0.7))
I, D, searchtime = timedsearchbatch(graph, queries, ksearch)
recall = macrorecall(goldI, I)
@info "testing without additional optimizations> queries per second:", 1/searchtime, ", recall: ", recall
Expand Down

2 comments on commit c5cbbde

@sadit
Copy link
Owner Author

@sadit sadit commented on c5cbbde Jan 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/52450

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.8.7 -m "<description of version>" c5cbbde7f43e22df97d62096cc14b427ed04e1b1
git push origin v0.8.7

Please sign in to comment.