Skip to content

Commit

Permalink
chore: add gathers to the python bench script
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <kemingy94@gmail.com>
  • Loading branch information
kemingy committed Dec 23, 2024
1 parent 788408e commit 7dddee2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ description = "Clustering algorithms."
documentation = "https://docs.rs/gathers"
keywords = ["cluster", "kmeans", "rabitq", "machine-learning", "vector-search"]
categories = ["algorithms", "science"]
rust-version = "1.83"

[dependencies]
argh = "0.1.12"
Expand Down
31 changes: 26 additions & 5 deletions benches/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from faiss import Kmeans
from sklearn.cluster import MiniBatchKMeans
import numpy as np
from gathers import Gathers


def build_arg_parser():
Expand All @@ -15,7 +16,11 @@ def build_arg_parser():
parser.add_argument("--input", "-i", type=str, required=True)
parser.add_argument("--output", "-o", type=str, required=True)
parser.add_argument(
"--library", "-l", type=str, default="faiss", choices=["faiss", "sklearn"]
"--library",
"-l",
type=str,
default="faiss",
choices=["faiss", "sklearn", "gathers"],
)
return parser

Expand Down Expand Up @@ -73,10 +78,26 @@ def sklearn_cluster(args):
write_vec(args.output, kmeans.cluster_centers_)


def gathers_cluster(args):
vecs = read_vec(args.input)
gathers = Gathers(verbose=args.verbose)
t_start = perf_counter()
centroids = gathers.fit(
vecs=vecs, n_cluster=args.n_clusters, max_iter=args.max_iter
)
print(f"gathers k-means training time: {perf_counter() - t_start:.6f}s")
write_vec(args.output, centroids)


if __name__ == "__main__":
args = build_arg_parser().parse_args()
print(args)
if args.library == "faiss":
faiss_cluster(args)
else:
sklearn_cluster(args)
match args.library:
case "faiss":
faiss_cluster(args)
case "sklearn":
sklearn_cluster(args)
case "gathers":
gathers_cluster(args)
case _:
raise ValueError(f"Invalid library name: {args.library}")

0 comments on commit 7dddee2

Please sign in to comment.