From 9bf2e600962add1462fef7472c2b6269d6898686 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Wed, 3 Jan 2024 21:29:07 -0800 Subject: [PATCH] Enabled Filtering on Nested Vector fields with top level filters Signed-off-by: Navneet Verma --- CHANGELOG.md | 3 +- .../knn/index/query/KNNQueryFactory.java | 20 +- .../index/AdvancedFilteringUseCasesIT.java | 550 ++++++++++++++++++ .../opensearch/knn/index/NestedSearchIT.java | 62 +- .../knn/index/query/KNNQueryFactoryTests.java | 60 ++ .../org/opensearch/knn/KNNRestTestCase.java | 26 + .../opensearch/knn/NestedKnnDocBuilder.java | 90 +++ 7 files changed, 754 insertions(+), 57 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java create mode 100644 src/testFixtures/java/org/opensearch/knn/NestedKnnDocBuilder.java diff --git a/CHANGELOG.md b/CHANGELOG.md index b2ab6b446..aafc2e585 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,12 +12,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Maintenance ### Refactoring -## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.11...2.x) +## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.12...2.x) ### Features * Add parent join support for lucene knn [#1182](https://github.com/opensearch-project/k-NN/pull/1182) ### Enhancements * Increase Lucene max dimension limit to 16,000 [#1346](https://github.com/opensearch-project/k-NN/pull/1346) * Tuned default values for ef_search and ef_construction for better indexing and search performance for vector search [#1353](https://github.com/opensearch-project/k-NN/pull/1353) +* Enabled Filtering on Nested Vector fields with top level filters [#1372](https://github.com/opensearch-project/k-NN/pull/1372) ### Bug Fixes * Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305) * Allow nested knn field mapping when train model 
[#1318](https://github.com/opensearch-project/k-NN/pull/1318) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index c073450af..7772ab582 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -17,8 +17,10 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.search.NestedHelper; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; @@ -155,11 +157,27 @@ private static Query getFilterQuery(CreateQueryRequest createQueryRequest) { createQueryRequest.k ) ); + final Query filterQuery; try { - return createQueryRequest.getFilter().get().toQuery(queryShardContext); + filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); } catch (IOException e) { throw new RuntimeException("Cannot create knn query with filter", e); } + // If the k-NN field is a nested field then parentFilter will not be null. This parentFilter is set by + // OpenSearch core. Ref PR: https://github.com/opensearch-project/OpenSearch/pull/10246 + if (queryShardContext.getParentFilter() != null) { + // If the filter is itself a nested query clause, return it as is; it must not be wrapped in a + // join with the parent documents. + if (new NestedHelper(queryShardContext.getMapperService()).mightMatchNestedDocs(filterQuery)) { + return filterQuery; + } + // This branch is reached when the filter is applied on top level fields while the k-NN + // query field is a nested field.
In this case we need to wrap the filter query with + // ToChildBlockJoinQuery to ensure parent documents which will be retrieved from filters can be + // joined with the child documents containing vector field. + return new ToChildBlockJoinQuery(filterQuery, queryShardContext.getParentFilter()); + } + return filterQuery; } return null; } diff --git a/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java b/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java new file mode 100644 index 000000000..17625fcda --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java @@ -0,0 +1,550 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index; + +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Assert; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.NestedKnnDocBuilder; +import org.opensearch.knn.index.util.KNNEngine; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.K; +import static org.opensearch.knn.common.KNNConstants.KNN; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static 
org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PATH; +import static org.opensearch.knn.common.KNNConstants.QUERY; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; +import static org.opensearch.knn.common.KNNConstants.TYPE_NESTED; +import static org.opensearch.knn.common.KNNConstants.VECTOR; + +/** + * This class contains integration tests for some advanced and tricky filter use cases. + * Github issue + */ +public class AdvancedFilteringUseCasesIT extends KNNRestTestCase { + + private static final String INDEX_NAME = "advanced_filtering_test_index"; + + private static final String FIELD_NAME_NESTED = "test_nested"; + + private static final String FIELD_NAME_VECTOR = "test_vector"; + + private static final String PROPERTIES_FIELD = "properties"; + + private static final String FILTER_FIELD = "filter"; + + private static final String TERM_FIELD = "term"; + + private static final int k = 20; + + private static final String FIELD_NAME_METADATA = "parking"; + + private static final int NUM_DOCS = 50; + + private static final int DOCUMENT_IN_RESPONSE = 10; + + private static final Float[] QUERY_VECTOR = { 5f }; + + private static final List enginesToTest = KNNEngine.getEnginesThatSupportsFilters() + .stream() + .map(KNNEngine::getName) + .collect(Collectors.toList()); + + /** + * { + * "query": { + * "nested": { + * "path": "test_nested", + * "query": { + * "knn": { + * "test_nested.test_vector": { + * "vector": [ + * 5 + * ], + * "k": 20, + * "filter": { + * "nested": { + * "path": "test_nested", + * "query": { + * "term": { + * "test_nested.parking": "false" + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testFiltering_whenNestedKNNAndFilterFieldWithNestedQueries_thenSuccess() { + for (final String engine :
enginesToTest) { + // Set up the index with nested k-nn and metadata fields + createKnnIndex(INDEX_NAME, createNestedMappings(1, engine)); + for (int i = 1; i <= NUM_DOCS; i++) { + // only documents with an even id get the "false" value that the filter matches + final String metadataFieldValue = i % 2 == 0 ? "false" : "true"; + String doc = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .addVectors(FIELD_NAME_VECTOR, new Float[] { (float) i + 1 }) + .addVectorWithMetadata(FIELD_NAME_VECTOR, new Float[] { (float) i }, FIELD_NAME_METADATA, metadataFieldValue) + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + // Build the query with both k-nn and filters as nested fields. The filter should also have a nested context + final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); + builder.startObject(TYPE_NESTED); + builder.field(PATH, FIELD_NAME_NESTED); + builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR); + builder.field(VECTOR, QUERY_VECTOR); + builder.field(K, k); + builder.startObject(FILTER_FIELD); + builder.startObject(TYPE_NESTED); + builder.field(PATH, FIELD_NAME_NESTED); + builder.startObject(QUERY); + builder.startObject(TERM_FIELD); + builder.field(FIELD_NAME_NESTED + "."
+ FIELD_NAME_METADATA, "false"); + builder.endObject(); + builder.endObject(); + builder.endObject(); + builder.endObject(); + builder.endObject().endObject().endObject().endObject().endObject().endObject(); + + validateFilterSearch(builder.toString(), engine); + // cleanup + deleteKNNIndex(INDEX_NAME); + } + } + + /** + * { + * "query": { + * "nested": { + * "path": "test_nested", + * "query": { + * "knn": { + * "test_nested.test_vector": { + * "vector": [ + * 5 + * ], + * "k": 20, + * "filter": { + * "term": { + * "test_nested.parking": "false" + * } + * } + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testFiltering_whenNestedKNNAndFilterFieldWithNoNestedContextInFilterQuery_thenSuccess() { + for (final String engine : enginesToTest) { + // Set up the index with nested k-nn and metadata fields + createKnnIndex(INDEX_NAME, createNestedMappings(1, engine)); + for (int i = 1; i <= NUM_DOCS; i++) { + // only documents with an even id get the "false" value that the filter matches + final String metadataFieldValue = i % 2 == 0 ? "false" : "true"; + String doc = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .addVectors(FIELD_NAME_VECTOR, new Float[] { (float) i + 1 }) + .addVectorWithMetadata(FIELD_NAME_VECTOR, new Float[] { (float) i }, FIELD_NAME_METADATA, metadataFieldValue) + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + // Build the query with both k-nn and filters as nested fields but a single nested context + final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); + builder.startObject(TYPE_NESTED); + builder.field(PATH, FIELD_NAME_NESTED); + builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR); + builder.field(VECTOR, QUERY_VECTOR); + builder.field(K, k); + builder.startObject(FILTER_FIELD); + builder.startObject(TERM_FIELD); + builder.field(FIELD_NAME_NESTED + "."
+ FIELD_NAME_METADATA, "false"); + builder.endObject(); + builder.endObject(); + builder.endObject().endObject().endObject().endObject().endObject().endObject(); + + validateFilterSearch(builder.toString(), engine); + + // cleanup + deleteKNNIndex(INDEX_NAME); + } + } + + /** + * { + * "query": { + * "nested": { + * "path": "test_nested", + * "query": { + * "knn": { + * "test_nested.test_vector": { + * "vector": [ + * 5 + * ], + * "k": 20, + * "filter": { + * "term": { + * "parking": "false" + * } + * } + * } + * } + * } + * } + * } + * } + * + */ + @SneakyThrows + public void testFiltering_whenNestedKNNAndNonNestedFilterFieldWithNonNestedFilterQuery_thenSuccess() { + for (final String engine : enginesToTest) { + // Set up the index with a nested k-nn field and a top level metadata field + createKnnIndex(INDEX_NAME, createNestedMappings(1, engine)); + for (int i = 1; i <= NUM_DOCS; i++) { + final String metadataFieldValue = i % 2 == 0 ? "false" : "true"; + String doc = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .addVectors(FIELD_NAME_VECTOR, new Float[] { (float) i + 1 }, new Float[] { (float) i }) + .addTopLevelField(FIELD_NAME_METADATA, metadataFieldValue) + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + // Build the query with k-nn field as nested query and filter on the top level fields + final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); + builder.startObject(TYPE_NESTED); + builder.field(PATH, FIELD_NAME_NESTED); + builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "."
+ FIELD_NAME_VECTOR); + builder.field(VECTOR, QUERY_VECTOR); + builder.field(K, k); + builder.startObject(FILTER_FIELD); + builder.startObject(TERM_FIELD); + builder.field(FIELD_NAME_METADATA, "false"); + builder.endObject(); + builder.endObject(); + builder.endObject().endObject().endObject().endObject().endObject().endObject(); + + validateFilterSearch(builder.toString(), engine); + + // cleanup + deleteKNNIndex(INDEX_NAME); + } + } + + /** + * { + * "query": { + * "knn": { + * "test_vector": { + * "vector": [ + * 5 + * ], + * "k": 20, + * "filter": { + * "bool": { + * "should": [ + * { + * "nested": { + * "path": "test_nested", + * "query": { + * "term": { + * "test_nested.parking": "false" + * } + * } + * } + * } + * ] + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testFiltering_whenNonNestedKNNAndNestedFilterFieldWithNestedFilterQuery_thenSuccess() { + for (final String engine : enginesToTest) { + // Set up the index with a top level (non-nested) k-nn field and a nested metadata field + createKnnIndex(INDEX_NAME, createVectorNonNestedMappings(1, engine)); + for (int i = 1; i <= NUM_DOCS; i++) { + final String metadataFieldValue = i % 2 == 0 ? "false" : "true"; + String doc = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .addMetadata(ImmutableMap.of(FIELD_NAME_METADATA, metadataFieldValue)) + .addTopLevelField(FIELD_NAME_VECTOR, new Float[] { (float) i }) + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + // Build the query where the filter is a nested query with a nested path and the k-NN field is not nested.
+ final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); + builder.startObject(KNN).startObject(FIELD_NAME_VECTOR); + builder.field(VECTOR, QUERY_VECTOR); + builder.field(K, k); + builder.startObject(FILTER_FIELD); + builder.startObject("bool"); + builder.startArray("should"); + builder.startObject(); + + builder.startObject(TYPE_NESTED); + builder.field(PATH, FIELD_NAME_NESTED); + + builder.startObject(QUERY); + builder.startObject(TERM_FIELD); + builder.field(FIELD_NAME_NESTED + "." + FIELD_NAME_METADATA, "false"); + builder.endObject(); + builder.endObject(); + + builder.endObject(); + + builder.endObject(); + builder.endArray(); + builder.endObject(); + builder.endObject(); + builder.endObject(); + builder.endObject(); + builder.endObject(); + builder.endObject(); + + validateFilterSearch(builder.toString(), engine); + // cleanup + deleteKNNIndex(INDEX_NAME); + } + } + + /** + * { + * "query": { + * "knn": { + * "test_vector": { + * "vector": [ + * 5 + * ], + * "k": 20, + * "filter": { + * "bool": { + * "must": [ + * { + * "nested": { + * "path": "test_nested", + * "query": { + * "term": { + * "test_nested.parking": "false" + * } + * } + * } + * }, + * { + * "term": { + * "parking": "false" + * } + * } + * ] + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testFiltering_whenNonNestedKNNAndNestedFilterAndNonNestedFieldWithNestedAndNonNestedFilterQuery_thenSuccess() { + for (final String engine : enginesToTest) { + // Set up the index with a top level (non-nested) k-nn field plus nested and top level metadata fields + createKnnIndex(INDEX_NAME, createVectorNonNestedMappings(1, engine)); + for (int i = 1; i <= NUM_DOCS; i++) { + final String metadataFieldValue = i % 2 == 0 ?
"false" : "true"; + String doc = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .addMetadata(ImmutableMap.of(FIELD_NAME_METADATA, metadataFieldValue)) + .addTopLevelField(FIELD_NAME_VECTOR, new Float[] { (float) i }) + .addTopLevelField(FIELD_NAME_METADATA, metadataFieldValue) + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + // Build the query when filters are nested with nested path and k-NN field is non nested. + final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); + builder.startObject(KNN).startObject(FIELD_NAME_VECTOR); + builder.field(VECTOR, QUERY_VECTOR); + builder.field(K, k); + builder.startObject(FILTER_FIELD); + builder.startObject("bool"); + builder.startArray("must"); + builder.startObject(); + builder.startObject(TERM_FIELD); + builder.field(FIELD_NAME_METADATA, "false"); + builder.endObject(); + builder.endObject(); + + builder.startObject(); + + builder.startObject(TYPE_NESTED); + builder.field(PATH, FIELD_NAME_NESTED); + + builder.startObject(QUERY); + builder.startObject(TERM_FIELD); + builder.field(FIELD_NAME_NESTED + "." 
+ FIELD_NAME_METADATA, "false"); + builder.endObject(); + builder.endObject(); + + builder.endObject(); + + builder.endObject(); + builder.endArray(); + builder.endObject(); + builder.endObject(); + builder.endObject(); + builder.endObject(); + builder.endObject(); + builder.endObject(); + + validateFilterSearch(builder.toString(), engine); + // cleanup + deleteKNNIndex(INDEX_NAME); + } + } + + private void validateFilterSearch(final String query, final String engine) throws IOException, ParseException { + String response = EntityUtils.toString(performSearch(INDEX_NAME, query).getEntity()); + // Validate number of documents returned as the expected number of documents + Assert.assertEquals("For engine " + engine + " : ", DOCUMENT_IN_RESPONSE, parseHits(response)); + if (KNNEngine.getEngine(engine) == KNNEngine.FAISS) { + // Update the filter threshold to 0 to ensure that we are hitting ANN Search use case for FAISS + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 0)); + response = EntityUtils.toString(performSearch(INDEX_NAME, query).getEntity()); + + // Validate number of documents returned as the expected number of documents + Assert.assertEquals("For engine " + engine + " with ANN search :", DOCUMENT_IN_RESPONSE, parseHits(response)); + } + } + + /** + * Sample return + * { + * "properties": { + * "test_nested": { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 1, + * "method": { + * "name": "hnsw", + * "space_type": "l2", + * "engine": "lucene" + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + private String createNestedMappings(final int dimension, final String engine) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME_NESTED) + .field(TYPE, TYPE_NESTED) + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME_VECTOR) + .field(TYPE, 
TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNN_ENGINE, engine) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + return builder.toString(); + } + + /** + * Sample return + * { + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 1, + * "method": { + * "name": "hnsw", + * "space_type": "l2", + * "engine": "lucene" + * } + * }, + * "test_nested": { + * "type": "nested" + * } + * } + * } + */ + @SneakyThrows + private String createVectorNonNestedMappings(final int dimension, final String engine) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME_NESTED) + .field(TYPE, TYPE_NESTED) + .endObject() + .startObject(FIELD_NAME_VECTOR) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNN_ENGINE, engine) + .endObject() + .endObject() + .endObject() + .endObject(); + + return builder.toString(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java index 4f39c8bba..9cd446195 100644 --- a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java @@ -15,6 +15,7 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.NestedKnnDocBuilder; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; @@ -59,14 +60,16 @@ public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() { createKnnIndex(2, KNNEngine.LUCENE.getName()); String doc1 = 
NestedKnnDocBuilder.create(FIELD_NAME_NESTED) - .add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f }) + .addVectors(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f }) .build(); - addNestedKnnDoc(INDEX_NAME, "1", doc1); + addKnnDoc(INDEX_NAME, "1", doc1); String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) - .add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f }) + .addVectors(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f }) .build(); - addNestedKnnDoc(INDEX_NAME, "2", doc2); + addKnnDoc(INDEX_NAME, "2", doc2); + + refreshIndex(INDEX_NAME); Float[] queryVector = { 1f, 1f }; Response response = queryNestedField(INDEX_NAME, 2, queryVector); @@ -131,30 +134,6 @@ private void createKnnIndex(final int dimension, final String engine) throws Exc createKnnIndex(INDEX_NAME, mapping); } - @SneakyThrows - private void ingestTestData() { - String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) - .add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f }) - .build(); - addNestedKnnDoc(INDEX_NAME, "1", doc1); - - String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) - .add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f }) - .build(); - addNestedKnnDoc(INDEX_NAME, "2", doc2); - } - - private void addNestedKnnDoc(final String index, final String docId, final String document) throws IOException { - Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); - - request.setJsonEntity(document); - client().performRequest(request); - - request = new Request("POST", "/" + index + "/_refresh"); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - } - private Response queryNestedField(final String index, final int k, final Object[] vector) throws IOException { XContentBuilder builder = 
XContentFactory.jsonBuilder().startObject().startObject(QUERY); builder.startObject(TYPE_NESTED); @@ -172,31 +151,4 @@ private Response queryNestedField(final String index, final int k, final Object[ return response; } - - private static class NestedKnnDocBuilder { - private XContentBuilder builder; - - public NestedKnnDocBuilder(final String fieldName) throws IOException { - builder = XContentFactory.jsonBuilder().startObject().startArray(fieldName); - } - - public static NestedKnnDocBuilder create(final String fieldName) throws IOException { - return new NestedKnnDocBuilder(fieldName); - } - - public NestedKnnDocBuilder add(final String fieldName, final Object[]... vectors) throws IOException { - for (Object[] vector : vectors) { - builder.startObject(); - builder.field(fieldName, vector); - builder.endObject(); - } - return this; - } - - public String build() throws IOException { - builder.endArray().endObject(); - return builder.toString(); - } - - } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index a6b915a85..62d2db544 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -12,11 +12,15 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.mockito.MockedConstruction; import org.mockito.Mockito; import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.search.NestedHelper; import 
org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; @@ -131,6 +135,62 @@ public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, DiversifyingChildrenFloatKnnVectorQuery.class); } + public void testCreate_whenNestedVectorFiledAndNonNestedFilterField_thenReturnToChildBlockJoinQueryForFilters() { + MapperService mockMapperService = mock(MapperService.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.getMapperService()).thenReturn(mockMapperService); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY); + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter); + MockedConstruction mockedNestedHelper = Mockito.mockConstruction( + NestedHelper.class, + (mock, context) -> when(mock.mightMatchNestedDocs(FILTER_QUERY)).thenReturn(false) + ); + + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.FAISS) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + KNNQuery query = (KNNQuery) KNNQueryFactory.create(createQueryRequest); + mockedNestedHelper.close(); + assertEquals(ToChildBlockJoinQuery.class, query.getFilterQuery().getClass()); + } + + public void testCreate_whenNestedVectorAndFilterField_thenReturnSameFilterQuery() { + MapperService mockMapperService = mock(MapperService.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + 
when(mockQueryShardContext.getMapperService()).thenReturn(mockMapperService); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY); + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter); + MockedConstruction mockedNestedHelper = Mockito.mockConstruction( + NestedHelper.class, + (mock, context) -> when(mock.mightMatchNestedDocs(FILTER_QUERY)).thenReturn(true) + ); + + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.FAISS) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + KNNQuery query = (KNNQuery) KNNQueryFactory.create(createQueryRequest); + mockedNestedHelper.close(); + assertEquals(FILTER_QUERY.getClass(), query.getFilterQuery().getClass()); + } + private void validateDiversifyingQueryWithParentFilter(final VectorDataType type, final Class expectedQueryClass) { List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 78cd84f05..d0ece655c 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -222,6 +222,15 @@ protected Response searchExists(String index, ExistsQueryBuilder existsQueryBuil return response; } + protected Response performSearch(final String indexName, final String query) throws IOException { + Request request = new Request("POST", "/" + indexName + 
"/_search"); + request.setJsonEntity(query); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return response; + } + /** * Parse the response of KNN search into a List of KNNResults */ @@ -501,6 +510,15 @@ protected void addKnnDoc(String index, String docId, List fieldNames, Li assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + /** + * Adds a doc where document is represented as a string. + */ + protected void addKnnDoc(final String index, final String docId, final String document) throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId); + request.setJsonEntity(document); + client().performRequest(request); + } + /** * Add a single numeric field Doc to an index */ @@ -695,6 +713,14 @@ protected int parseTotalSearchHits(String searchResponseBody) throws IOException return (int) ((Map) responseMap.get("total")).get("value"); } + protected int parseHits(String searchResponseBody) throws IOException { + Map responseMap = (Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + searchResponseBody + ).map().get("hits"); + return ((List) responseMap.get("hits")).size(); + } + /** * Get the total number of graphs in the cache across all nodes */ diff --git a/src/testFixtures/java/org/opensearch/knn/NestedKnnDocBuilder.java b/src/testFixtures/java/org/opensearch/knn/NestedKnnDocBuilder.java new file mode 100644 index 000000000..dca58838a --- /dev/null +++ b/src/testFixtures/java/org/opensearch/knn/NestedKnnDocBuilder.java @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. 
See + * GitHub history for details. + */ + +package org.opensearch.knn; + +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +public class NestedKnnDocBuilder { + private XContentBuilder builder; + private boolean isNestedFieldBuildCompleted; + + public NestedKnnDocBuilder(final String fieldName) throws IOException { + isNestedFieldBuildCompleted = false; + builder = XContentFactory.jsonBuilder().startObject().startArray(fieldName); + } + + public static NestedKnnDocBuilder create(final String fieldName) throws IOException { + return new NestedKnnDocBuilder(fieldName); + } + + public NestedKnnDocBuilder addVectors(final String fieldName, final Object[]... vectors) throws IOException { + for (Object[] vector : vectors) { + builder.startObject(); + builder.field(fieldName, vector); + builder.endObject(); + } + return this; + } + + public NestedKnnDocBuilder addVectorWithMetadata( + final String fieldName, + final Object[] vectorValue, + final String metadataFieldName, + final Object metadataValue + ) throws IOException { + builder.startObject(); + builder.field(fieldName, vectorValue); + builder.field(metadataFieldName, metadataValue); + builder.endObject(); + return this; + } + + public NestedKnnDocBuilder addMetadata(final Map metadata) throws IOException { + builder.startObject(); + metadata.forEach((k, v) -> { + try { + builder.field(k, v); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + builder.endObject(); + return this; + } + + /** + * Use this function when you want to add top level fields in the document that contains nested fields. Once you + * run this function you cannot add anything in the nested field. 
+ */ + public NestedKnnDocBuilder addTopLevelField(final String fieldName, final Object value) throws IOException { + if (isNestedFieldBuildCompleted == false) { + // Making sure that we close the building of nested field. + isNestedFieldBuildCompleted = true; + builder.endArray(); + } + builder.field(fieldName, value); + return this; + } + + public String build() throws IOException { + if (isNestedFieldBuildCompleted) { + builder.endObject(); + } else { + builder.endArray().endObject(); + } + return builder.toString(); + } +}