Skip to content

Commit

Permalink
Update doc for tensor search (#1371)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Update doc for tensor search

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
- [x] Test cases
  • Loading branch information
yangzq50 authored Jun 21, 2024
1 parent 5b1e692 commit 343b4c5
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 9 deletions.
38 changes: 38 additions & 0 deletions docs/references/pysdk_api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,43 @@ Create a full-text search expression.
table_obj.match('body', 'harmful', 'topn=2')
```

## match tensor

**RemoteTable.match_tensor(*vector_column_name, tensor_data, tensor_data_type, method_type, topn, extra_option)**

Build a KNN tensor search expression. Find the top n closet records to the given tensor according to chosen method.

For example, find k most match tensors generated by ColBERT.

### Parameters

- **vector_column_name : str**
- **tensor_data : list/np.ndarray**
- **tensor_data_type : str**
- **method_type : str**
- `'maxsim'`

- **extra_option : str** options seperated by ';'
- `'topn'`
- **EMVB index options**
- `'emvb_centroid_nprobe'`
- `'emvb_threshold_first'`
- `'emvb_n_doc_to_score'`
- `'emvb_n_doc_out_second_stage'`
- `'emvb_threshold_final'`

### Returns

- Success: Self `RemoteTable`
- Failure: `Exception`

### Examples

```python
match_tensor('t', [[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], 'float', 'maxsim', 'topn=2')
match_tensor('t', [[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], 'float', 'maxsim', 'topn=10;emvb_centroid_nprobe=4;emvb_threshold_first=0.4;emvb_threshold_final=0.5')
```

## fusion

**RemoteTable.fusion(*method, options_text = ''*)**
Expand Down Expand Up @@ -743,6 +780,7 @@ Build a fusion expression.
table_obj.fusion('rrf')
table_obj.fusion('rrf', 'topn=10')
table_obj.fusion('weighted_sum', 'weights=1,2,0.5')
table_obj.fusion('match_tensor', 'topn=2', make_match_tensor_expr('t', [[0.0, -10.0, 0.0, 0.7], [9.2, 45.6, -55.8, 3.5]], 'float', 'maxsim'))
```

### Details
Expand Down
10 changes: 6 additions & 4 deletions src/storage/knn_index/emvb/emvb_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,16 @@ std::tuple<u32, std::unique_ptr<f32[]>, std::unique_ptr<u32[]>> EMVBSearch<FIXED
const u32 k,
const f32 thresh_query) const {
assert(n_centroids_ % 8 == 0);
assert(nprobe > 0);
auto query_token_centroids_scores = Get256AlignedF32Array(FIXED_QUERY_TOKEN_NUM * n_centroids_);
matrixA_multiply_transpose_matrixB_output_to_C(query_ptr,
centroids_data_,
FIXED_QUERY_TOKEN_NUM,
n_centroids_,
embedding_dimension_,
query_token_centroids_scores.get());
auto [candidate_docs, centroid_q_token_sim] = find_candidate_docs(query_token_centroids_scores.get(), nprobe, thresh);
const u32 real_nprobe = std::min(n_centroids_, nprobe);
assert(real_nprobe > 0);
auto [candidate_docs, centroid_q_token_sim] = find_candidate_docs(query_token_centroids_scores.get(), real_nprobe, thresh);
auto selected_cnt_and_docs = compute_hit_frequency(std::move(candidate_docs), n_doc_to_score, std::move(centroid_q_token_sim));
auto selected_docs_centroid_scores =
second_stage_filtering(std::move(selected_cnt_and_docs), out_second_stage, std::move(query_token_centroids_scores));
Expand All @@ -357,15 +358,16 @@ Tuple<u32, UniquePtr<f32[]>, UniquePtr<u32[]>> EMVBSearch<FIXED_QUERY_TOKEN_NUM>
const BlockIndex *block_index,
const TxnTimeStamp begin_ts) const {
assert(n_centroids_ % 8 == 0);
assert(nprobe > 0);
auto query_token_centroids_scores = Get256AlignedF32Array(FIXED_QUERY_TOKEN_NUM * n_centroids_);
matrixA_multiply_transpose_matrixB_output_to_C(query_ptr,
centroids_data_,
FIXED_QUERY_TOKEN_NUM,
n_centroids_,
embedding_dimension_,
query_token_centroids_scores.get());
auto [candidate_docs, centroid_q_token_sim] = find_candidate_docs(query_token_centroids_scores.get(), nprobe, thresh);
const u32 real_nprobe = std::min(n_centroids_, nprobe);
assert(real_nprobe > 0);
auto [candidate_docs, centroid_q_token_sim] = find_candidate_docs(query_token_centroids_scores.get(), real_nprobe, thresh);
std::vector<u32> candidate_docs_filtered;
auto filter_doc = [&candidate_docs_filtered, &candidate_docs, start_segment_offset](auto &&filter) {
candidate_docs_filtered.reserve(candidate_docs.size());
Expand Down
32 changes: 28 additions & 4 deletions tools/generate_emvb_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def generate(generate_if_exists: bool, copy_dir: str):
fix_embedding_num_in_tensor = 32
fix_dim = 128
row_n = 1024
pq_subspace_num = 16
pq_subspace_num = 32
pq_subspace_bits = 8
csv_dir = "./test/data/csv"
slt_dir = "./test/sql/dql/knn/tensor"
Expand Down Expand Up @@ -42,9 +42,33 @@ def generate(generate_if_exists: bool, copy_dir: str):
top_slt_file.write("statement ok\n")
top_slt_file.write("COPY {} FROM '{}' WITH ( DELIMITER ',' );\n".format(table_name, copy_path))
top_slt_file.write("\nstatement ok\n")
top_slt_file.write(
"CREATE INDEX idx1 ON {} (c2) USING EMVB WITH (pq_subspace_num = {}, pq_subspace_bits = {});\n".format(
table_name, pq_subspace_num, pq_subspace_bits))
top_slt_file.write("CREATE INDEX idx1 ON {} (c2) USING EMVB WITH ".format(table_name))
top_slt_file.write("(pq_subspace_num = {}, pq_subspace_bits = {});\n".format(pq_subspace_num, pq_subspace_bits))
query_vec = [0] * fix_dim
query_vec[0] = 1
query_vec[1] = 1
query_vec[2] = 1
query_vec[3] = 1
top_slt_file.write("\n# test index search")
top_slt_file.write("\nstatement ok\n")
top_slt_file.write("SELECT c1 FROM {} SEARCH MATCH TENSOR".format(table_name))
top_slt_file.write(" (c2, {}, 'float', 'maxsim', 'topn=10');\n".format(query_vec))
top_slt_file.write("\nstatement ok\n")
top_slt_file.write("SELECT c1 FROM {} SEARCH MATCH TENSOR (c2, {}".format(table_name, query_vec))
top_slt_file.write(", 'float', 'maxsim', 'topn=10;emvb_threshold_first=0.4;emvb_threshold_final=0.5');\n")
top_slt_file.write("\n# test small mem index of exhaustive scan")
insert_vec = [0] * fix_dim
insert_vec[0] = 2 * row_n
insert_vec[1] = 2 * row_n
insert_vec[2] = 2 * row_n
insert_vec[3] = 2 * row_n
top_slt_file.write("\nstatement ok\n")
top_slt_file.write("INSERT INTO {} VALUES ({}, {});\n".format(table_name, row_n, insert_vec))
top_slt_file.write("\nquery I\n")
top_slt_file.write("SELECT c1 FROM {} SEARCH MATCH TENSOR".format(table_name))
top_slt_file.write(" (c2, {}, 'float', 'maxsim', 'topn=1');\n".format(query_vec))
top_slt_file.write("----\n")
top_slt_file.write("{}\n".format(row_n))
top_slt_file.write("\nstatement ok\n")
top_slt_file.write("DROP TABLE {};\n".format(table_name))

Expand Down
2 changes: 1 addition & 1 deletion tools/sqllogictest.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def copy_all(data_dir, copy_dir):
generate15(args.generate_if_exists, args.copy)
generate16(args.generate_if_exists, args.copy)
generate17(args.generate_if_exists, args.copy)
#generate18(args.generate_if_exists, args.copy)
generate18(args.generate_if_exists, args.copy)
#generate19(args.generate_if_exists, args.copy)
print("Generate file finshed.")

Expand Down

0 comments on commit 343b4c5

Please sign in to comment.