forked from infiniflow/infinity
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase_client.py
75 lines (67 loc) · 2.42 KB
/
base_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import logging
import os
import subprocess
import time
from abc import abstractmethod
from typing import Any
from urllib.parse import urlparse
class BaseClient:
    """
    Base class for all benchmark clients (Qdrant, Elasticsearch, Infinity).

    ``mode`` is a string that corresponds to a JSON configuration file's
    address; each concrete client reads the parameters it needs from that
    JSON file.

    NOTE(review): the ``@abstractmethod`` decorators below are not enforced,
    because this class does not derive from ``abc.ABC`` (its metaclass is not
    ``ABCMeta``). Python therefore still allows instantiating a subclass that
    forgets to override them. Inheriting from ``ABC`` would enforce the
    contract, but could break existing subclasses, so it is left as-is.
    """

    @abstractmethod
    def __init__(self,
                 mode: str,
                 options: argparse.Namespace,
                 drop_old: bool = True) -> None:
        """
        Parse the ``mode`` configuration file and store every parameter the
        other methods need.

        Args:
            mode: address of the JSON configuration for this run.
            options: parsed command-line options.
            drop_old: presumably whether pre-existing data/indexes should be
                dropped before uploading — TODO confirm against subclasses.
        """
        pass

    @abstractmethod
    def upload(self):
        """
        Upload data and build indexes (parameters are parsed by ``__init__``).
        """
        pass

    @abstractmethod
    def search(self) -> list[list[Any]]:
        """
        Execute the configured query tasks (vector search, full-text search,
        hybrid search) based on the parsed parameters.

        Returns:
            One list of result ids per query.
        """
        pass

    def download_data(self, url, target_path):
        """
        Download the dataset at ``url`` and store it at ``target_path``.

        If the URL's path ends in ``.bz2``, the archive is fetched to
        ``target_path + '.bz2'`` and decompressed with ``bunzip2``. The
        decompressed output lands at ``target_path``, and bunzip2 deletes
        the ``.bz2`` input by default. Any other URL is saved to
        ``target_path`` unchanged.

        Requires the ``wget`` executable, plus ``bunzip2`` for ``.bz2``
        archives, on ``PATH``.

        Raises:
            subprocess.CalledProcessError: if ``wget`` or ``bunzip2`` exits
                with a non-zero status.
            FileNotFoundError: if a required executable cannot be found.
        """
        if self._url_suffix(url) == '.bz2':
            bz2_path = target_path + '.bz2'
            subprocess.run(['wget', '-O', bz2_path, url], check=True)
            # bunzip2 writes its output to bz2_path minus ".bz2", which is
            # exactly target_path, so no rename is needed. -f lets it replace
            # an existing target_path, matching the branch below where
            # `wget -O` always overwrites.
            subprocess.run(['bunzip2', '-f', bz2_path], check=True)
        else:
            subprocess.run(['wget', '-O', target_path, url], check=True)

    @staticmethod
    def _url_suffix(url: str) -> str:
        """
        Return the file extension of the path component of ``url``.

        Only the URL path is inspected, so a query string or fragment
        (e.g. ``.../data.bz2?token=abc``) cannot hide the real extension,
        as it would with ``os.path.splitext(url)``.
        """
        return os.path.splitext(urlparse(url).path)[1]

    @abstractmethod
    def check_and_save_results(self, results: list[list[Any]]):
        """
        Compare ``results`` against the expected answers named in the mode
        configuration file to compute recall, then record the measured
        metrics and save them in the results folder.

        Args:
            results: per-query id lists, as returned by :meth:`search`.
        """
        pass

    def run_experiment(self, args):
        """
        Run the experiment steps selected by ``args`` and save the results.

        Args:
            args: namespace whose truthy ``import_data`` attribute triggers
                :meth:`upload` (its duration is logged), and whose truthy
                ``query`` attribute triggers :meth:`search` followed by
                :meth:`check_and_save_results`.
        """
        if args.import_data:
            # perf_counter is monotonic, so the measured duration is not
            # skewed by wall-clock adjustments the way time.time() can be.
            start_time = time.perf_counter()
            self.upload()
            finish_time = time.perf_counter()
            logging.info("upload finish, cost time = %s", finish_time - start_time)
        if args.query:
            results = self.search()
            self.check_and_save_results(results)