diff --git a/CHANGELOG.md b/CHANGELOG.md
index b2ab6b446..aafc2e585 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,12 +12,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Maintenance
### Refactoring
-## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.11...2.x)
+## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.12...2.x)
### Features
* Add parent join support for lucene knn [#1182](https://github.com/opensearch-project/k-NN/pull/1182)
### Enhancements
* Increase Lucene max dimension limit to 16,000 [#1346](https://github.com/opensearch-project/k-NN/pull/1346)
* Tuned default values for ef_search and ef_construction for better indexing and search performance for vector search [#1353](https://github.com/opensearch-project/k-NN/pull/1353)
+* Enabled Filtering on Nested Vector fields with top level filters [#1372](https://github.com/opensearch-project/k-NN/pull/1372)
### Bug Fixes
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
index c073450af..7772ab582 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
@@ -17,8 +17,10 @@
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
+import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
+import org.opensearch.index.search.NestedHelper;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;
@@ -155,11 +157,27 @@ private static Query getFilterQuery(CreateQueryRequest createQueryRequest) {
createQueryRequest.k
)
);
+ final Query filterQuery;
try {
- return createQueryRequest.getFilter().get().toQuery(queryShardContext);
+ filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext);
} catch (IOException e) {
throw new RuntimeException("Cannot create knn query with filter", e);
}
+ // If the k-NN field is a nested field, parentFilter is non-null. The parentFilter is set by
+ // OpenSearch core. Ref PR: https://github.com/opensearch-project/OpenSearch/pull/10246
+ if (queryShardContext.getParentFilter() != null) {
+ // if the filter is also a nested query clause then we should just return the same query without
+ // considering it to join with the parent documents.
+ if (new NestedHelper(queryShardContext.getMapperService()).mightMatchNestedDocs(filterQuery)) {
+ return filterQuery;
+ }
+ // This condition will be hit when filters are getting applied on the top level fields and k-nn
+ // query field is a nested field. In this case we need to wrap the filter query with
+ // ToChildBlockJoinQuery to ensure parent documents which will be retrieved from filters can be
+ // joined with the child documents containing vector field.
+ return new ToChildBlockJoinQuery(filterQuery, queryShardContext.getParentFilter());
+ }
+ return filterQuery;
}
return null;
}
diff --git a/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java b/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java
new file mode 100644
index 000000000..17625fcda
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java
@@ -0,0 +1,550 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.knn.index;
+
+import com.google.common.collect.ImmutableMap;
+import lombok.SneakyThrows;
+import org.apache.hc.core5.http.ParseException;
+import org.apache.hc.core5.http.io.entity.EntityUtils;
+import org.junit.Assert;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.xcontent.XContentFactory;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.knn.KNNRestTestCase;
+import org.opensearch.knn.NestedKnnDocBuilder;
+import org.opensearch.knn.index.util.KNNEngine;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.opensearch.knn.common.KNNConstants.DIMENSION;
+import static org.opensearch.knn.common.KNNConstants.K;
+import static org.opensearch.knn.common.KNNConstants.KNN;
+import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
+import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
+import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
+import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
+import static org.opensearch.knn.common.KNNConstants.NAME;
+import static org.opensearch.knn.common.KNNConstants.PATH;
+import static org.opensearch.knn.common.KNNConstants.QUERY;
+import static org.opensearch.knn.common.KNNConstants.TYPE;
+import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR;
+import static org.opensearch.knn.common.KNNConstants.TYPE_NESTED;
+import static org.opensearch.knn.common.KNNConstants.VECTOR;
+
+/**
+ * Integration tests for advanced and tricky filtering use cases.
+ * GitHub issue: TODO(review) - link the tracking issue here.
+ */
+public class AdvancedFilteringUseCasesIT extends KNNRestTestCase {
+
+ private static final String INDEX_NAME = "advanced_filtering_test_index";
+
+ private static final String FIELD_NAME_NESTED = "test_nested";
+
+ private static final String FIELD_NAME_VECTOR = "test_vector";
+
+ private static final String PROPERTIES_FIELD = "properties";
+
+ private static final String FILTER_FIELD = "filter";
+
+ private static final String TERM_FIELD = "term";
+
+ private static final int k = 20;
+
+ private static final String FIELD_NAME_METADATA = "parking";
+
+ private static final int NUM_DOCS = 50;
+
+ private static final int DOCUMENT_IN_RESPONSE = 10;
+
+ private static final Float[] QUERY_VECTOR = { 5f };
+
+    // Names of the engines returned by KNNEngine.getEnginesThatSupportsFilters(); every test iterates over all of
+    // them. The element type must be String because the tests loop with "for (final String engine : ...)".
+    private static final List<String> enginesToTest = KNNEngine.getEnginesThatSupportsFilters()
+        .stream()
+        .map(KNNEngine::getName)
+        .collect(Collectors.toList());
+
+ /**
+ * {
+ * "query": {
+ * "nested": {
+ * "path": "test_nested",
+ * "query": {
+ * "knn": {
+ * "test_nested.test_vector": {
+ * "vector": [
+ * 5
+ * ],
+ * "k": 20,
+ * "filter": {
+ * "nested": {
+ * "path": "test_nested",
+ * "query": {
+ * "term": {
+ * "test_nested.parking": "false"
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ */
+    @SneakyThrows
+    public void testFiltering_whenNestedKNNAndFilterFieldWithNestedQueries_thenSuccess() {
+        final String nestedVectorPath = FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR;
+        final String nestedMetadataPath = FIELD_NAME_NESTED + "." + FIELD_NAME_METADATA;
+
+        for (final String engineName : enginesToTest) {
+            // Fresh index per engine; the vector field lives inside the nested field.
+            createKnnIndex(INDEX_NAME, createNestedMappings(1, engineName));
+            for (int docId = 1; docId <= NUM_DOCS; docId++) {
+                // Each document gets two nested entries: a bare vector and a vector tagged with the metadata field.
+                // Even ids are tagged "false" (the value the filter looks for), odd ids are tagged "true".
+                final String parkingValue = docId % 2 == 0 ? "false" : "true";
+                final String document = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
+                    .addVectors(FIELD_NAME_VECTOR, new Float[] { docId + 1f })
+                    .addVectorWithMetadata(FIELD_NAME_VECTOR, new Float[] { (float) docId }, FIELD_NAME_METADATA, parkingValue)
+                    .build();
+                addKnnDoc(INDEX_NAME, String.valueOf(docId), document);
+            }
+            refreshIndex(INDEX_NAME);
+            forceMergeKnnIndex(INDEX_NAME);
+
+            // nested(k-NN) query whose filter is itself wrapped in its own nested clause on the same path
+            final XContentBuilder queryBuilder = XContentFactory.jsonBuilder()
+                .startObject()
+                .startObject(QUERY)
+                .startObject(TYPE_NESTED)
+                .field(PATH, FIELD_NAME_NESTED)
+                .startObject(QUERY)
+                .startObject(KNN)
+                .startObject(nestedVectorPath)
+                .field(VECTOR, QUERY_VECTOR)
+                .field(K, k)
+                .startObject(FILTER_FIELD)
+                .startObject(TYPE_NESTED)
+                .field(PATH, FIELD_NAME_NESTED)
+                .startObject(QUERY)
+                .startObject(TERM_FIELD)
+                .field(nestedMetadataPath, "false")
+                .endObject() // term
+                .endObject() // filter -> nested -> query
+                .endObject() // filter -> nested
+                .endObject() // filter
+                .endObject() // vector field
+                .endObject() // knn
+                .endObject() // nested -> query
+                .endObject() // nested
+                .endObject() // query
+                .endObject(); // root
+
+            validateFilterSearch(queryBuilder.toString(), engineName);
+            // cleanup
+            deleteKNNIndex(INDEX_NAME);
+        }
+    }
+
+ /**
+ * {
+ * "query": {
+ * "nested": {
+ * "path": "test_nested",
+ * "query": {
+ * "knn": {
+ * "test_nested.test_vector": {
+ * "vector": [
+ * 5
+ * ],
+ * "k": 20,
+ * "filter": {
+ * "term": {
+ * "test_nested.parking": "false"
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ */
+    @SneakyThrows
+    public void testFiltering_whenNestedKNNAndFilterFieldWithNoNestedContextInFilterQuery_thenSuccess() {
+        final String nestedVectorPath = FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR;
+        final String nestedMetadataPath = FIELD_NAME_NESTED + "." + FIELD_NAME_METADATA;
+
+        for (final String engineName : enginesToTest) {
+            // Fresh index per engine; the vector field lives inside the nested field.
+            createKnnIndex(INDEX_NAME, createNestedMappings(1, engineName));
+            for (int docId = 1; docId <= NUM_DOCS; docId++) {
+                // Each document gets two nested entries: a bare vector and a vector tagged with the metadata field.
+                // Even ids are tagged "false" (the value the filter looks for), odd ids are tagged "true".
+                final String parkingValue = docId % 2 == 0 ? "false" : "true";
+                final String document = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
+                    .addVectors(FIELD_NAME_VECTOR, new Float[] { docId + 1f })
+                    .addVectorWithMetadata(FIELD_NAME_VECTOR, new Float[] { (float) docId }, FIELD_NAME_METADATA, parkingValue)
+                    .build();
+                addKnnDoc(INDEX_NAME, String.valueOf(docId), document);
+            }
+            refreshIndex(INDEX_NAME);
+            forceMergeKnnIndex(INDEX_NAME);
+
+            // nested(k-NN) query whose filter is a plain term query on a nested field, i.e. the filter relies on
+            // the single surrounding nested context instead of declaring its own
+            final XContentBuilder queryBuilder = XContentFactory.jsonBuilder()
+                .startObject()
+                .startObject(QUERY)
+                .startObject(TYPE_NESTED)
+                .field(PATH, FIELD_NAME_NESTED)
+                .startObject(QUERY)
+                .startObject(KNN)
+                .startObject(nestedVectorPath)
+                .field(VECTOR, QUERY_VECTOR)
+                .field(K, k)
+                .startObject(FILTER_FIELD)
+                .startObject(TERM_FIELD)
+                .field(nestedMetadataPath, "false")
+                .endObject() // term
+                .endObject() // filter
+                .endObject() // vector field
+                .endObject() // knn
+                .endObject() // nested -> query
+                .endObject() // nested
+                .endObject() // query
+                .endObject(); // root
+
+            validateFilterSearch(queryBuilder.toString(), engineName);
+
+            // cleanup
+            deleteKNNIndex(INDEX_NAME);
+        }
+    }
+
+ /**
+ * {
+ * "query": {
+ * "nested": {
+ * "path": "test_nested",
+ * "query": {
+ * "knn": {
+ * "test_nested.test_vector": {
+ * "vector": [
+ * 5
+ * ],
+ * "k": 20,
+ * "filter": {
+ * "term": {
+ * "parking": "false"
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ *
+ */
+    @SneakyThrows
+    public void testFiltering_whenNestedKNNAndNonNestedFilterFieldWithNonNestedFilterQuery_thenSuccess() {
+        final String nestedVectorPath = FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR;
+
+        for (final String engineName : enginesToTest) {
+            // Fresh index per engine; the vector field is nested, the metadata field is added at the top level.
+            createKnnIndex(INDEX_NAME, createNestedMappings(1, engineName));
+            for (int docId = 1; docId <= NUM_DOCS; docId++) {
+                // Two nested vectors per document plus one top-level metadata value:
+                // "false" for even ids (the value the filter looks for), "true" for odd ids.
+                final String parkingValue = docId % 2 == 0 ? "false" : "true";
+                final String document = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
+                    .addVectors(FIELD_NAME_VECTOR, new Float[] { docId + 1f }, new Float[] { (float) docId })
+                    .addTopLevelField(FIELD_NAME_METADATA, parkingValue)
+                    .build();
+                addKnnDoc(INDEX_NAME, String.valueOf(docId), document);
+            }
+            refreshIndex(INDEX_NAME);
+            forceMergeKnnIndex(INDEX_NAME);
+
+            // nested(k-NN) query filtered by a term query on the top-level metadata field
+            final XContentBuilder queryBuilder = XContentFactory.jsonBuilder()
+                .startObject()
+                .startObject(QUERY)
+                .startObject(TYPE_NESTED)
+                .field(PATH, FIELD_NAME_NESTED)
+                .startObject(QUERY)
+                .startObject(KNN)
+                .startObject(nestedVectorPath)
+                .field(VECTOR, QUERY_VECTOR)
+                .field(K, k)
+                .startObject(FILTER_FIELD)
+                .startObject(TERM_FIELD)
+                .field(FIELD_NAME_METADATA, "false")
+                .endObject() // term
+                .endObject() // filter
+                .endObject() // vector field
+                .endObject() // knn
+                .endObject() // nested -> query
+                .endObject() // nested
+                .endObject() // query
+                .endObject(); // root
+
+            validateFilterSearch(queryBuilder.toString(), engineName);
+
+            // cleanup
+            deleteKNNIndex(INDEX_NAME);
+        }
+    }
+
+ /**
+ * {
+ * "query": {
+ * "knn": {
+ * "test_vector": {
+ * "vector": [
+ * 5
+ * ],
+ * "k": 20,
+ * "filter": {
+ * "bool": {
+ * "should": [
+ * {
+ * "nested": {
+ * "path": "test_nested",
+ * "query": {
+ * "term": {
+ * "test_nested.parking": "false"
+ * }
+ * }
+ * }
+ * }
+ * ]
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ */
+    @SneakyThrows
+    public void testFiltering_whenNonNestedKNNAndNestedFilterFieldWithNestedFilterQuery_thenSuccess() {
+        final String nestedMetadataPath = FIELD_NAME_NESTED + "." + FIELD_NAME_METADATA;
+
+        for (final String engineName : enginesToTest) {
+            // Fresh index per engine; here the vector field is top level and only the metadata is nested.
+            createKnnIndex(INDEX_NAME, createVectorNonNestedMappings(1, engineName));
+            for (int docId = 1; docId <= NUM_DOCS; docId++) {
+                // One nested metadata entry per document: "false" for even ids (the value the filter looks for),
+                // "true" for odd ids. The vector is written as a top-level field.
+                final String parkingValue = docId % 2 == 0 ? "false" : "true";
+                final String document = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
+                    .addMetadata(ImmutableMap.of(FIELD_NAME_METADATA, parkingValue))
+                    .addTopLevelField(FIELD_NAME_VECTOR, new Float[] { (float) docId })
+                    .build();
+                addKnnDoc(INDEX_NAME, String.valueOf(docId), document);
+            }
+            refreshIndex(INDEX_NAME);
+            forceMergeKnnIndex(INDEX_NAME);
+
+            // top-level k-NN query whose filter is bool/should around a nested term query
+            final XContentBuilder queryBuilder = XContentFactory.jsonBuilder()
+                .startObject()
+                .startObject(QUERY)
+                .startObject(KNN)
+                .startObject(FIELD_NAME_VECTOR)
+                .field(VECTOR, QUERY_VECTOR)
+                .field(K, k)
+                .startObject(FILTER_FIELD)
+                .startObject("bool")
+                .startArray("should")
+                .startObject()
+                .startObject(TYPE_NESTED)
+                .field(PATH, FIELD_NAME_NESTED)
+                .startObject(QUERY)
+                .startObject(TERM_FIELD)
+                .field(nestedMetadataPath, "false")
+                .endObject() // term
+                .endObject() // nested -> query
+                .endObject() // nested
+                .endObject() // should[0]
+                .endArray() // should
+                .endObject() // bool
+                .endObject() // filter
+                .endObject() // vector field
+                .endObject() // knn
+                .endObject() // query
+                .endObject(); // root
+
+            validateFilterSearch(queryBuilder.toString(), engineName);
+            // cleanup
+            deleteKNNIndex(INDEX_NAME);
+        }
+    }
+
+ /**
+ * {
+ * "query": {
+ * "knn": {
+ * "test_vector": {
+ * "vector": [
+ * 5
+ * ],
+ * "k": 20,
+ * "filter": {
+ * "bool": {
+ * "must": [
+ * {
+ * "nested": {
+ * "path": "test_nested",
+ * "query": {
+ * "term": {
+ * "test_nested.parking": "false"
+ * }
+ * }
+ * }
+ * },
+ * {
+ * "term": {
+ * "parking": "false"
+ * }
+ * }
+ * ]
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ */
+    @SneakyThrows
+    public void testFiltering_whenNonNestedKNNAndNestedFilterAndNonNestedFieldWithNestedAndNonNestedFilterQuery_thenSuccess() {
+        final String nestedMetadataPath = FIELD_NAME_NESTED + "." + FIELD_NAME_METADATA;
+
+        for (final String engineName : enginesToTest) {
+            // Fresh index per engine; the vector field is top level, metadata exists both nested and top level.
+            createKnnIndex(INDEX_NAME, createVectorNonNestedMappings(1, engineName));
+            for (int docId = 1; docId <= NUM_DOCS; docId++) {
+                // The same value is written to the nested and the top-level metadata field:
+                // "false" for even ids (the value both filter clauses look for), "true" for odd ids.
+                final String parkingValue = docId % 2 == 0 ? "false" : "true";
+                final String document = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
+                    .addMetadata(ImmutableMap.of(FIELD_NAME_METADATA, parkingValue))
+                    .addTopLevelField(FIELD_NAME_VECTOR, new Float[] { (float) docId })
+                    .addTopLevelField(FIELD_NAME_METADATA, parkingValue)
+                    .build();
+                addKnnDoc(INDEX_NAME, String.valueOf(docId), document);
+            }
+            refreshIndex(INDEX_NAME);
+            forceMergeKnnIndex(INDEX_NAME);
+
+            // top-level k-NN query whose filter is bool/must combining a top-level term query with a nested one
+            final XContentBuilder queryBuilder = XContentFactory.jsonBuilder()
+                .startObject()
+                .startObject(QUERY)
+                .startObject(KNN)
+                .startObject(FIELD_NAME_VECTOR)
+                .field(VECTOR, QUERY_VECTOR)
+                .field(K, k)
+                .startObject(FILTER_FIELD)
+                .startObject("bool")
+                .startArray("must")
+                // must[0]: term query on the top-level metadata field
+                .startObject()
+                .startObject(TERM_FIELD)
+                .field(FIELD_NAME_METADATA, "false")
+                .endObject() // term
+                .endObject() // must[0]
+                // must[1]: nested term query on the nested metadata field
+                .startObject()
+                .startObject(TYPE_NESTED)
+                .field(PATH, FIELD_NAME_NESTED)
+                .startObject(QUERY)
+                .startObject(TERM_FIELD)
+                .field(nestedMetadataPath, "false")
+                .endObject() // term
+                .endObject() // nested -> query
+                .endObject() // nested
+                .endObject() // must[1]
+                .endArray() // must
+                .endObject() // bool
+                .endObject() // filter
+                .endObject() // vector field
+                .endObject() // knn
+                .endObject() // query
+                .endObject(); // root
+
+            validateFilterSearch(queryBuilder.toString(), engineName);
+            // cleanup
+            deleteKNNIndex(INDEX_NAME);
+        }
+    }
+
+    /**
+     * Runs {@code query} against {@link #INDEX_NAME} and asserts that the response contains
+     * {@link #DOCUMENT_IN_RESPONSE} hits. For the FAISS engine the search is repeated after setting the index's
+     * advanced-filtering exact-search threshold to 0, and the same hit count is asserted again.
+     *
+     * @param query  search request body as JSON
+     * @param engine engine name, used to pick the FAISS branch and to label assertion failures
+     */
+    private void validateFilterSearch(final String query, final String engine) throws IOException, ParseException {
+        final String initialResponse = EntityUtils.toString(performSearch(INDEX_NAME, query).getEntity());
+        // The number of returned hits must equal the expected number of documents
+        Assert.assertEquals("For engine " + engine + " : ", DOCUMENT_IN_RESPONSE, parseHits(initialResponse));
+
+        if (KNNEngine.getEngine(engine) != KNNEngine.FAISS) {
+            return;
+        }
+
+        // Threshold 0 is meant to make FAISS take the ANN search path for the second run
+        updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 0));
+        final String annResponse = EntityUtils.toString(performSearch(INDEX_NAME, query).getEntity());
+        Assert.assertEquals("For engine " + engine + " with ANN search :", DOCUMENT_IN_RESPONSE, parseHits(annResponse));
+    }
+
+ /**
+ * Sample return
+ * {
+ * "properties": {
+ * "test_nested": {
+ * "type": "nested",
+ * "properties": {
+ * "test_vector": {
+ * "type": "knn_vector",
+ * "dimension": 1,
+ * "method": {
+ * "name": "hnsw",
+ * "space_type": "l2",
+ * "engine": "lucene"
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }
+ */
+    @SneakyThrows
+    private String createNestedMappings(final int dimension, final String engine) {
+        final XContentBuilder mapping = XContentFactory.jsonBuilder();
+        mapping.startObject();
+        mapping.startObject(PROPERTIES_FIELD);
+
+        // nested field that holds the knn_vector field
+        mapping.startObject(FIELD_NAME_NESTED);
+        mapping.field(TYPE, TYPE_NESTED);
+        mapping.startObject(PROPERTIES_FIELD);
+
+        mapping.startObject(FIELD_NAME_VECTOR);
+        mapping.field(TYPE, TYPE_KNN_VECTOR);
+        mapping.field(DIMENSION, dimension);
+
+        // HNSW method with L2 space on the requested engine
+        mapping.startObject(KNN_METHOD);
+        mapping.field(NAME, METHOD_HNSW);
+        mapping.field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue());
+        mapping.field(KNN_ENGINE, engine);
+        mapping.endObject(); // method
+
+        mapping.endObject(); // vector field
+        mapping.endObject(); // nested properties
+        mapping.endObject(); // nested field
+        mapping.endObject(); // properties
+        mapping.endObject(); // root
+
+        return mapping.toString();
+    }
+
+ /**
+ * Sample return
+ * {
+ * "properties": {
+ * "test_vector": {
+ * "type": "knn_vector",
+ * "dimension": 1,
+ * "method": {
+ * "name": "hnsw",
+ * "space_type": "l2",
+ * "engine": "lucene"
+ * }
+ * },
+ * "test_nested": {
+ * "type": "nested"
+ * }
+ * }
+ * }
+ */
+    @SneakyThrows
+    private String createVectorNonNestedMappings(final int dimension, final String engine) {
+        final XContentBuilder mapping = XContentFactory.jsonBuilder();
+        mapping.startObject();
+        mapping.startObject(PROPERTIES_FIELD);
+
+        // nested field declared without explicit sub-properties
+        mapping.startObject(FIELD_NAME_NESTED);
+        mapping.field(TYPE, TYPE_NESTED);
+        mapping.endObject();
+
+        // top-level knn_vector field
+        mapping.startObject(FIELD_NAME_VECTOR);
+        mapping.field(TYPE, TYPE_KNN_VECTOR);
+        mapping.field(DIMENSION, dimension);
+
+        // HNSW method with L2 space on the requested engine
+        mapping.startObject(KNN_METHOD);
+        mapping.field(NAME, METHOD_HNSW);
+        mapping.field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue());
+        mapping.field(KNN_ENGINE, engine);
+        mapping.endObject(); // method
+
+        mapping.endObject(); // vector field
+        mapping.endObject(); // properties
+        mapping.endObject(); // root
+
+        return mapping.toString();
+    }
+}
diff --git a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java
index 4f39c8bba..9cd446195 100644
--- a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java
+++ b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java
@@ -15,6 +15,7 @@
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNRestTestCase;
+import org.opensearch.knn.NestedKnnDocBuilder;
import org.opensearch.knn.index.util.KNNEngine;
import java.io.IOException;
@@ -59,14 +60,16 @@ public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() {
createKnnIndex(2, KNNEngine.LUCENE.getName());
String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
- .add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f })
+ .addVectors(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f })
.build();
- addNestedKnnDoc(INDEX_NAME, "1", doc1);
+ addKnnDoc(INDEX_NAME, "1", doc1);
String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
- .add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f })
+ .addVectors(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f })
.build();
- addNestedKnnDoc(INDEX_NAME, "2", doc2);
+ addKnnDoc(INDEX_NAME, "2", doc2);
+
+ refreshIndex(INDEX_NAME);
Float[] queryVector = { 1f, 1f };
Response response = queryNestedField(INDEX_NAME, 2, queryVector);
@@ -131,30 +134,6 @@ private void createKnnIndex(final int dimension, final String engine) throws Exc
createKnnIndex(INDEX_NAME, mapping);
}
- @SneakyThrows
- private void ingestTestData() {
- String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
- .add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f })
- .build();
- addNestedKnnDoc(INDEX_NAME, "1", doc1);
-
- String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
- .add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f })
- .build();
- addNestedKnnDoc(INDEX_NAME, "2", doc2);
- }
-
- private void addNestedKnnDoc(final String index, final String docId, final String document) throws IOException {
- Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
-
- request.setJsonEntity(document);
- client().performRequest(request);
-
- request = new Request("POST", "/" + index + "/_refresh");
- Response response = client().performRequest(request);
- assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
- }
-
private Response queryNestedField(final String index, final int k, final Object[] vector) throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY);
builder.startObject(TYPE_NESTED);
@@ -172,31 +151,4 @@ private Response queryNestedField(final String index, final int k, final Object[
return response;
}
-
- private static class NestedKnnDocBuilder {
- private XContentBuilder builder;
-
- public NestedKnnDocBuilder(final String fieldName) throws IOException {
- builder = XContentFactory.jsonBuilder().startObject().startArray(fieldName);
- }
-
- public static NestedKnnDocBuilder create(final String fieldName) throws IOException {
- return new NestedKnnDocBuilder(fieldName);
- }
-
- public NestedKnnDocBuilder add(final String fieldName, final Object[]... vectors) throws IOException {
- for (Object[] vector : vectors) {
- builder.startObject();
- builder.field(fieldName, vector);
- builder.endObject();
- }
- return this;
- }
-
- public String build() throws IOException {
- builder.endArray().endObject();
- return builder.toString();
- }
-
- }
}
diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java
index a6b915a85..62d2db544 100644
--- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java
+++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java
@@ -12,11 +12,15 @@
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
+import org.apache.lucene.search.join.ToChildBlockJoinQuery;
+import org.mockito.MockedConstruction;
import org.mockito.Mockito;
import org.opensearch.index.mapper.MappedFieldType;
+import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
+import org.opensearch.index.search.NestedHelper;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;
@@ -131,6 +135,62 @@ public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery()
validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, DiversifyingChildrenFloatKnnVectorQuery.class);
}
+    /**
+     * When the query context has a parent filter (nested k-NN field) and NestedHelper reports that the filter query
+     * does not match nested docs, the filter handed to the KNNQuery must be a ToChildBlockJoinQuery.
+     */
+    public void testCreate_whenNestedVectorFiledAndNonNestedFilterField_thenReturnToChildBlockJoinQueryForFilters() {
+        MapperService mockMapperService = mock(MapperService.class);
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        when(mockQueryShardContext.getMapperService()).thenReturn(mockMapperService);
+        MappedFieldType testMapper = mock(MappedFieldType.class);
+        when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper);
+        when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY);
+        BitSetProducer parentFilter = mock(BitSetProducer.class);
+        when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter);
+
+        final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
+            .knnEngine(KNNEngine.FAISS)
+            .indexName(testIndexName)
+            .fieldName(testFieldName)
+            .vector(testQueryVector)
+            .k(testK)
+            .context(mockQueryShardContext)
+            .filter(FILTER_QUERY_BUILDER)
+            .build();
+
+        final KNNQuery query;
+        // try-with-resources closes the construction mock even if create() throws, so a failure here cannot leave
+        // NestedHelper construction mocked on this thread for later tests.
+        try (
+            MockedConstruction<NestedHelper> ignored = Mockito.mockConstruction(
+                NestedHelper.class,
+                (mock, context) -> when(mock.mightMatchNestedDocs(FILTER_QUERY)).thenReturn(false)
+            )
+        ) {
+            query = (KNNQuery) KNNQueryFactory.create(createQueryRequest);
+        }
+        assertEquals(ToChildBlockJoinQuery.class, query.getFilterQuery().getClass());
+    }
+
+    /**
+     * When the query context has a parent filter (nested k-NN field) and NestedHelper reports that the filter query
+     * might match nested docs, the filter handed to the KNNQuery must keep the class of the original filter query.
+     */
+    public void testCreate_whenNestedVectorAndFilterField_thenReturnSameFilterQuery() {
+        MapperService mockMapperService = mock(MapperService.class);
+        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+        when(mockQueryShardContext.getMapperService()).thenReturn(mockMapperService);
+        MappedFieldType testMapper = mock(MappedFieldType.class);
+        when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper);
+        when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY);
+        BitSetProducer parentFilter = mock(BitSetProducer.class);
+        when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter);
+
+        final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
+            .knnEngine(KNNEngine.FAISS)
+            .indexName(testIndexName)
+            .fieldName(testFieldName)
+            .vector(testQueryVector)
+            .k(testK)
+            .context(mockQueryShardContext)
+            .filter(FILTER_QUERY_BUILDER)
+            .build();
+
+        final KNNQuery query;
+        // try-with-resources closes the construction mock even if create() throws, so a failure here cannot leave
+        // NestedHelper construction mocked on this thread for later tests.
+        try (
+            MockedConstruction<NestedHelper> ignored = Mockito.mockConstruction(
+                NestedHelper.class,
+                (mock, context) -> when(mock.mightMatchNestedDocs(FILTER_QUERY)).thenReturn(true)
+            )
+        ) {
+            query = (KNNQuery) KNNQueryFactory.create(createQueryRequest);
+        }
+        assertEquals(FILTER_QUERY.getClass(), query.getFilterQuery().getClass());
+    }
+
private void validateDiversifyingQueryWithParentFilter(final VectorDataType type, final Class expectedQueryClass) {
List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values())
.filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine))
diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
index 78cd84f05..d0ece655c 100644
--- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
+++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
@@ -222,6 +222,15 @@ protected Response searchExists(String index, ExistsQueryBuilder existsQueryBuil
return response;
}
+    /**
+     * Sends {@code query} as the JSON body of a POST to {@code /<indexName>/_search} and asserts an OK status.
+     *
+     * @param indexName index to search
+     * @param query     search request body as JSON
+     * @return the search response
+     * @throws IOException if the request cannot be performed
+     */
+    protected Response performSearch(final String indexName, final String query) throws IOException {
+        final String searchEndpoint = "/" + indexName + "/_search";
+        final Request searchRequest = new Request("POST", searchEndpoint);
+        searchRequest.setJsonEntity(query);
+
+        final Response searchResponse = client().performRequest(searchRequest);
+        final RestStatus actualStatus = RestStatus.fromCode(searchResponse.getStatusLine().getStatusCode());
+        assertEquals(searchRequest.getEndpoint() + ": failed", RestStatus.OK, actualStatus);
+        return searchResponse;
+    }
+
/**
* Parse the response of KNN search into a List of KNNResults
*/
@@ -501,6 +510,15 @@ protected void addKnnDoc(String index, String docId, List fieldNames, Li
assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}
+    /**
+     * Indexes a document whose complete JSON source is supplied as a string.
+     *
+     * @param index    name of the target index
+     * @param docId    id under which the document is stored
+     * @param document JSON source of the document
+     * @throws IOException if the request cannot be performed
+     */
+    protected void addKnnDoc(final String index, final String docId, final String document) throws IOException {
+        final String docEndpoint = "/" + index + "/_doc/" + docId;
+        final Request indexRequest = new Request("POST", docEndpoint);
+        indexRequest.setJsonEntity(document);
+        // The response is not inspected and no refresh is requested here; callers that need the document to be
+        // searchable have to refresh the index themselves.
+        client().performRequest(indexRequest);
+    }
+
/**
* Add a single numeric field Doc to an index
*/
@@ -695,6 +713,14 @@ protected int parseTotalSearchHits(String searchResponseBody) throws IOException
return (int) ((Map) responseMap.get("total")).get("value");
}
+    /**
+     * Counts the entries of the {@code hits.hits} array in a search response body.
+     *
+     * @param searchResponseBody search response as a JSON string
+     * @return number of elements in {@code hits.hits}
+     * @throws IOException if the body cannot be parsed
+     */
+    protected int parseHits(String searchResponseBody) throws IOException {
+        // Unbounded wildcards instead of raw types: only get() and size() are needed, so no unchecked cast is required.
+        final Map<?, ?> hitsSection = (Map<?, ?>) createParser(
+            MediaTypeRegistry.getDefaultMediaType().xContent(),
+            searchResponseBody
+        ).map().get("hits");
+        return ((List<?>) hitsSection.get("hits")).size();
+    }
+
/**
* Get the total number of graphs in the cache across all nodes
*/
diff --git a/src/testFixtures/java/org/opensearch/knn/NestedKnnDocBuilder.java b/src/testFixtures/java/org/opensearch/knn/NestedKnnDocBuilder.java
new file mode 100644
index 000000000..dca58838a
--- /dev/null
+++ b/src/testFixtures/java/org/opensearch/knn/NestedKnnDocBuilder.java
@@ -0,0 +1,90 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.knn;
+
+import org.opensearch.common.xcontent.XContentFactory;
+import org.opensearch.core.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Test helper that builds the JSON source of a document containing one nested field (written as an array of
+ * objects) and, optionally, top-level fields after it.
+ *
+ * Nested entries must be added first. The first call to {@link #addTopLevelField(String, Object)} closes the nested
+ * array; adding further nested entries afterwards is rejected with an {@link IllegalStateException}.
+ */
+public class NestedKnnDocBuilder {
+    private final XContentBuilder builder;
+    // true once the nested array has been closed by addTopLevelField
+    private boolean isNestedFieldBuildCompleted;
+
+    /**
+     * Starts a document and opens the nested array.
+     *
+     * @param fieldName name of the nested field
+     */
+    public NestedKnnDocBuilder(final String fieldName) throws IOException {
+        isNestedFieldBuildCompleted = false;
+        builder = XContentFactory.jsonBuilder().startObject().startArray(fieldName);
+    }
+
+    /**
+     * Static factory equivalent to {@code new NestedKnnDocBuilder(fieldName)}.
+     */
+    public static NestedKnnDocBuilder create(final String fieldName) throws IOException {
+        return new NestedKnnDocBuilder(fieldName);
+    }
+
+    /**
+     * Adds one nested entry per vector, each entry holding only {@code fieldName -> vector}.
+     *
+     * @throws IllegalStateException if a top-level field has already been added
+     */
+    public NestedKnnDocBuilder addVectors(final String fieldName, final Object[]... vectors) throws IOException {
+        ensureNestedFieldOpen();
+        for (Object[] vector : vectors) {
+            builder.startObject();
+            builder.field(fieldName, vector);
+            builder.endObject();
+        }
+        return this;
+    }
+
+    /**
+     * Adds a single nested entry holding both a vector field and a metadata field.
+     *
+     * @throws IllegalStateException if a top-level field has already been added
+     */
+    public NestedKnnDocBuilder addVectorWithMetadata(
+        final String fieldName,
+        final Object[] vectorValue,
+        final String metadataFieldName,
+        final Object metadataValue
+    ) throws IOException {
+        ensureNestedFieldOpen();
+        builder.startObject();
+        builder.field(fieldName, vectorValue);
+        builder.field(metadataFieldName, metadataValue);
+        builder.endObject();
+        return this;
+    }
+
+    /**
+     * Adds a single nested entry whose fields are the entries of {@code metadata}, written in the map's iteration
+     * order.
+     *
+     * @throws IllegalStateException if a top-level field has already been added
+     */
+    public NestedKnnDocBuilder addMetadata(final Map<String, ?> metadata) throws IOException {
+        ensureNestedFieldOpen();
+        builder.startObject();
+        // Plain loop instead of Map.forEach so an IOException propagates as declared instead of being wrapped.
+        for (Map.Entry<String, ?> entry : metadata.entrySet()) {
+            builder.field(entry.getKey(), entry.getValue());
+        }
+        builder.endObject();
+        return this;
+    }
+
+    /**
+     * Use this function when you want to add top level fields in the document that contains nested fields. Once you
+     * run this function you cannot add anything in the nested field.
+     */
+    public NestedKnnDocBuilder addTopLevelField(final String fieldName, final Object value) throws IOException {
+        if (isNestedFieldBuildCompleted == false) {
+            // Making sure that we close the building of nested field.
+            isNestedFieldBuildCompleted = true;
+            builder.endArray();
+        }
+        builder.field(fieldName, value);
+        return this;
+    }
+
+    /**
+     * Closes whatever is still open (the nested array, if no top-level field was added, and the document object)
+     * and returns the document as a JSON string.
+     */
+    public String build() throws IOException {
+        if (isNestedFieldBuildCompleted) {
+            builder.endObject();
+        } else {
+            builder.endArray().endObject();
+        }
+        return builder.toString();
+    }
+
+    // Fails fast with a readable message instead of emitting a nested entry outside the already closed array.
+    private void ensureNestedFieldOpen() {
+        if (isNestedFieldBuildCompleted) {
+            throw new IllegalStateException("Nested entries cannot be added after a top level field has been added");
+        }
+    }
+}