Skip to content

Commit 311b609

Browse files
authored
Let thrift client reconnect on insert failure (#1156)
Let thrift client reconnect on insert failure. Updated requirements.txt. Issue link: #1125 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
1 parent 848b3a8 commit 311b609

10 files changed

+82
-55
lines changed

python/benchmark/clients/base_client.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import argparse
22
from abc import abstractmethod
3-
from typing import Any, List, Optional, Dict, Union
4-
from enum import Enum
3+
from typing import Any
54
import subprocess
6-
import sys
75
import os
8-
from urllib.parse import urlparse
96
import time
7+
import logging
8+
109

1110
class BaseClient:
1211
"""
@@ -25,14 +24,21 @@ def __init__(self,
2524
"""
2625
pass
2726

27+
@abstractmethod
28+
def upload(self):
29+
"""
30+
Upload data and build indexes (parameters are parsed by __init__).
31+
"""
32+
pass
33+
2834
@abstractmethod
2935
def search(self) -> list[list[Any]]:
3036
"""
3137
Execute the corresponding query tasks (vector search, full-text search, hybrid search) based on the parsed parameters.
3238
The function returns id list.
3339
"""
3440
pass
35-
41+
3642
def download_data(self, url, target_path):
3743
"""
3844
Download dataset and extract it into path.
@@ -59,6 +65,11 @@ def run_experiment(self, args):
5965
"""
6066
run experiment and save results.
6167
"""
68+
if args.import_data:
69+
start_time = time.time()
70+
self.upload()
71+
finish_time = time.time()
72+
logging.info(f"upload finish, cost time = {finish_time - start_time}")
6273
if args.query:
6374
results = self.search()
64-
self.check_and_save_results(results)
75+
self.check_and_save_results(results)

python/benchmark/clients/elasticsearch_client.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from elasticsearch import Elasticsearch, helpers
44
import json
55
import time
6-
from typing import List, Optional
6+
from typing import List
77
import os
88
import h5py
99
import uuid
1010
import numpy as np
11-
import csv
11+
import logging
1212

1313
from .base_client import BaseClient
1414

@@ -74,7 +74,7 @@ def upload(self):
7474
for i, line in enumerate(data_file):
7575
row = line.strip().split('\t')
7676
if len(row) != len(headers):
77-
print(f"row = {i}, row_len = {len(row)}, not equal headers len, skip")
77+
logging.info(f"row = {i}, row_len = {len(row)}, not equal headers len, skip")
7878
continue
7979
row_dict = {header: value for header, value in zip(headers, row)}
8080
current_batch.append({"_index": self.collection_name, "_id": uuid.UUID(int=i).hex, "_source": row_dict})
@@ -133,7 +133,7 @@ def search(self) -> list[list[Any]]:
133133
The function returns id list.
134134
"""
135135
query_path = os.path.join(self.path_prefix, self.data["query_path"])
136-
print(query_path)
136+
logging.info(query_path)
137137
results = []
138138
_, ext = os.path.splitext(query_path)
139139
if ext == '.json' or ext == '.jsonl':
@@ -184,7 +184,7 @@ def search(self) -> list[list[Any]]:
184184
latency = (end - start) * 1000
185185
result = [(uuid.UUID(hex=hit['_id']).int, hit['_score']) for hit in result['hits']['hits']]
186186
result.append(latency)
187-
print(f"{line[:-1]}, {latency}")
187+
logging.info(f"{line[:-1]}, {latency}")
188188
results.append(result)
189189
else:
190190
raise TypeError("Unsupported file type")
@@ -214,7 +214,7 @@ def check_and_save_results(self, results: List[List[Any]]):
214214
precisions.append(precision)
215215
latencies.append(result[-1])
216216

217-
print(
217+
logging.info(
218218
f'''mean_time: {np.mean(latencies)}, mean_precisions: {np.mean(precisions)},
219219
std_time: {np.std(latencies)}, min_time: {np.min(latencies)}, \n
220220
max_time: {np.max(latencies)}, p95_time: {np.percentile(latencies, 95)},
@@ -223,7 +223,7 @@ def check_and_save_results(self, results: List[List[Any]]):
223223
latencies = []
224224
for result in results:
225225
latencies.append(result[-1])
226-
print(
226+
logging.info(
227227
f'''mean_time: {np.mean(latencies)}, std_time: {np.std(latencies)},
228228
max_time: {np.max(latencies)}, min_time: {np.min(latencies)},
229229
p95_time: {np.percentile(latencies, 95)}, p99_time: {np.percentile(latencies, 99)}''')

python/benchmark/clients/infinity_client.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
import time
66
import numpy as np
77
from typing import Any, List
8+
import logging
89

910
import infinity
1011
import infinity.index as index
1112
from infinity import NetworkAddress
1213
from .base_client import BaseClient
13-
import infinity.remote_thrift.infinity_thrift_rpc.ttypes as ttypes
14-
import csv
1514

1615
class InfinityClient(BaseClient):
1716
def __init__(self,
@@ -93,9 +92,9 @@ def upload(self):
9392
for i, line in enumerate(data_file):
9493
row = line.strip().split('\t')
9594
if (i % 100000 == 0):
96-
print(f"row {i}")
95+
logging.info(f"row {i}")
9796
if len(row) != len(headers):
98-
print(f"row = {i}, row_len = {len(row)}, not equal headers len, skip")
97+
logging.info(f"row = {i}, row_len = {len(row)}, not equal headers len, skip")
9998
continue
10099
row_dict = {header: value for header, value in zip(headers, row)}
101100
current_batch.append(row_dict)
@@ -166,7 +165,7 @@ def search(self) -> list[list[Any]]:
166165
latency = (time.time() - start) * 1000
167166
result = [(row_id[0], score) for row_id, score in zip(res['ROW_ID'], res['SCORE'])]
168167
result.append(latency)
169-
print(f"{query}, {latency}")
168+
logging.info(f"{query}, {latency}")
170169
results.append(result)
171170
else:
172171
raise TypeError("Unsupported file type")
@@ -197,7 +196,7 @@ def check_and_save_results(self, results: List[List[Any]]):
197196
precisions.append(precision)
198197
latencies.append(result[-1])
199198

200-
print(
199+
logging.info(
201200
f'''mean_time: {np.mean(latencies)}, mean_precisions: {np.mean(precisions)},
202201
std_time: {np.std(latencies)}, min_time: {np.min(latencies)},
203202
max_time: {np.max(latencies)}, p95_time: {np.percentile(latencies, 95)},
@@ -206,7 +205,7 @@ def check_and_save_results(self, results: List[List[Any]]):
206205
latencies = []
207206
for result in results:
208207
latencies.append(result[-1])
209-
print(
208+
logging.info(
210209
f'''mean_time: {np.mean(latencies)}, std_time: {np.std(latencies)},
211210
max_time: {np.max(latencies)}, min_time: {np.min(latencies)},
212211
p95_time: {np.percentile(latencies, 95)}, p99_time: {np.percentile(latencies, 99)}''')

python/benchmark/clients/qdrant_client.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import time
77
import json
88
import h5py
9-
from typing import Any, List, Optional
9+
from typing import Any
10+
import logging
1011

1112
from .base_client import BaseClient
1213

@@ -137,7 +138,7 @@ def search(self) -> list[list[Any]]:
137138
with_payload=False
138139
)
139140
end = time.time()
140-
print(f"latency of search: {(end - start)*1000:.2f} milliseconds")
141+
logging.info(f"latency of search: {(end - start)*1000:.2f} milliseconds")
141142
results.append(result)
142143
elif ext == '.hdf5' and self.data['mode'] == 'vector':
143144
with h5py.File(query_path, 'r') as f:
@@ -150,7 +151,7 @@ def search(self) -> list[list[Any]]:
150151
)
151152
results.append(result)
152153
end = time.time()
153-
print(f"latency of KNN search: {(end - start)*1000/len(f['test']):.2f} milliseconds")
154+
logging.info(f"latency of KNN search: {(end - start)*1000/len(f['test']):.2f} milliseconds")
154155
else:
155156
raise TypeError("Unsupported file type")
156157
return results

python/benchmark/configs/infinity_enwiki.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"query_link": "to_be_set",
1010
"mode": "fulltext",
1111
"topK": 10,
12-
"use_import": true,
12+
"use_import": false,
1313
"schema": {
1414
"doctitle": {"type": "varchar", "default":""},
1515
"docdate": {"type": "varchar", "default":""},

python/benchmark/requirements.txt

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
elasticsearch~=8.13.0
2+
h5py~=3.11.0
3+
qdrant_client~=1.9.0
4+

python/benchmark/run.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import argparse
22
import os
3+
import logging
34

45
from clients.elasticsearch_client import ElasticsearchClient
56
from clients.infinity_client import InfinityClient
67
from clients.qdrant_client import QdrantClient
7-
from generate_query_json import generate_query_json
88
from generate_query_json import generate_query_txt
99

1010
ENGINES = ['infinity', 'qdrant', 'elasticsearch']
@@ -53,6 +53,7 @@ def get_client(engine: str, config: str, options: argparse.Namespace):
5353
raise ValueError(f"Unknown engine: {engine}")
5454

5555
if __name__ == '__main__':
56+
logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(levelname)-8s %(message)s')
5657
args = parse_args()
5758
config_paths = generate_config_paths(args)
5859

@@ -65,9 +66,9 @@ def get_client(engine: str, config: str, options: argparse.Namespace):
6566

6667
for path, engine in config_paths:
6768
if not os.path.exists(path):
68-
print(f"qdrant does not support full text search")
69+
logging.info("qdrant does not support full text search")
6970
continue
70-
print("Running", engine, "with", path)
71+
logging.info("Running {} with {}".format(engine, path))
7172
client = get_client(engine, path, args)
7273
client.run_experiment(args)
73-
print("Finished", engine, "with", path)
74+
logging.info("Finished {} with {}".format(engine, path))

python/infinity/remote_thrift/client.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,13 @@
2727
class ThriftInfinityClient:
2828
def __init__(self, uri: URI):
2929
self.uri = uri
30+
self.transport = None
3031
self.reconnect()
3132

3233
def reconnect(self):
34+
if self.transport is not None:
35+
self.transport.close()
36+
self.transport = None
3337
# self.transport = TTransport.TFramedTransport(TSocket.TSocket(self.uri.ip, self.uri.port)) # async
3438
self.transport = TTransport.TBufferedTransport(
3539
TSocket.TSocket(self.uri.ip, self.uri.port)) # sync
@@ -126,7 +130,8 @@ def list_indexes(self, db_name: str, table_name: str):
126130

127131
def insert(self, db_name: str, table_name: str, column_names: list[str], fields: list[Field]):
128132
retry = 0
129-
while retry <= 10:
133+
inner_ex = None
134+
while retry <= 2:
130135
try:
131136
res = self.client.Insert(InsertRequest(session_id=self.session_id,
132137
db_name=db_name,
@@ -135,12 +140,14 @@ def insert(self, db_name: str, table_name: str, column_names: list[str], fields:
135140
fields=fields))
136141
return res
137142
except TTransportException as ex:
138-
if ex.type == ex.END_OF_FILE:
139-
self.reconnect()
140-
retry += 1
141-
else:
142-
break
143-
return CommonResponse(ErrorCode.TOO_MANY_CONNECTIONS, "retry insert failed")
143+
#import traceback
144+
#traceback.print_exc()
145+
self.reconnect()
146+
inner_ex = ex
147+
retry += 1
148+
except Exception as ex:
149+
inner_ex = ex
150+
return CommonResponse(ErrorCode.TOO_MANY_CONNECTIONS, "insert failed with exception: " + str(inner_ex))
144151

145152
# Can be used in compact mode
146153
# def insert(self, db_name: str, table_name: str, column_names: list[str], fields: list[Field]):
@@ -198,7 +205,11 @@ def update(self, db_name: str, table_name: str, where_expr, update_expr_array):
198205
update_expr_array=update_expr_array))
199206

200207
def disconnect(self):
201-
res = self.client.Disconnect(CommonRequest(session_id=self.session_id))
208+
res = None
209+
try:
210+
res = self.client.Disconnect(CommonRequest(session_id=self.session_id))
211+
except Exception:
212+
pass
202213
self.transport.close()
203214
return res
204215

python/pyproject.toml

+10-10
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
name = "infinity_sdk"
33
version = "0.2.0.dev1"
44
dependencies = [
5-
"sqlglot==11.7.1",
6-
"pydantic",
7-
"thrift",
8-
"setuptools",
9-
"pytest",
10-
"pandas",
11-
"numpy",
12-
"pyarrow",
13-
"openpyxl",
14-
"polars"
5+
"sqlglot~=11.7.1",
6+
"pydantic~=2.7.1",
7+
"thrift~=0.20.0",
8+
"setuptools~=69.5.1",
9+
"pytest~=8.2.0",
10+
"pandas~=2.2.2",
11+
"numpy~=1.26.4",
12+
"pyarrow~=16.0.0",
13+
"polars~=0.20.23",
14+
"openpyxl~=3.1.2"
1515
]
1616
description = "infinity"
1717
readme = "README.md"

python/requirements.txt

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
sqlglot==11.7.1
2-
pydantic~=1.10.12
1+
sqlglot~=11.7.1
2+
pydantic~=2.7.1
33
thrift~=0.20.0
4-
setuptools~=68.0.0
5-
pytest~=7.4.0
6-
pandas~=2.1.1
7-
openpyxl
8-
numpy~=1.26.0
9-
polars~=0.19.0
10-
pyarrow
4+
setuptools~=69.5.1
5+
pytest~=8.2.0
6+
pandas~=2.2.2
7+
numpy~=1.26.4
8+
pyarrow~=16.0.0
9+
polars~=0.20.23
10+
openpyxl~=3.1.2

0 commit comments

Comments (0)