From a7b424da91ab8fe57ccf4dde9e6e3930b471414e Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Tue, 9 Jul 2024 17:54:30 -0700 Subject: [PATCH] Fixed LeafReaders casting errors to SegmentReaders when segment replication is enabled during search Signed-off-by: Navneet Verma --- .github/workflows/CI.yml | 6 +- .github/workflows/test_security.yml | 2 +- CHANGELOG.md | 1 + .../opensearch/knn/index/KNNIndexShard.java | 4 +- .../opensearch/knn/index/query/KNNWeight.java | 6 +- .../knn/index/SegmentReplicationIT.java | 94 +++++++++++++++++++ .../plugin/action/RestKNNStatsHandlerIT.java | 72 ++++++++++---- .../org/opensearch/knn/KNNRestTestCase.java | 35 ++++++- 8 files changed, 190 insertions(+), 30 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/SegmentReplicationIT.java diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0b9b24d983..295bf98108 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -61,7 +61,7 @@ jobs: su `id -un 1000` -c "whoami && java -version && ./gradlew build" else echo "avx2 not available on system" - su `id -un 1000` -c "whoami && java -version && ./gradlew build -Dsimd.enabled=false" + su `id -un 1000` -c "whoami && java -version && ./gradlew build -Dsimd.enabled=false -PnumNodes=2" fi @@ -107,7 +107,7 @@ jobs: ./gradlew build else echo "avx2 not available on system" - ./gradlew build -Dsimd.enabled=false + ./gradlew build -Dsimd.enabled=false -PnumNodes=2 fi Build-k-NN-Windows: @@ -167,4 +167,4 @@ jobs: - name: Run build run: | - ./gradlew.bat build -D'simd.enabled=false' + ./gradlew.bat build -D'simd.enabled=false' -PnumNodes=2 diff --git a/.github/workflows/test_security.yml b/.github/workflows/test_security.yml index e0f2dbf451..2b93e044a2 100644 --- a/.github/workflows/test_security.yml +++ b/.github/workflows/test_security.yml @@ -54,4 +54,4 @@ jobs: # switching the user, as OpenSearch cluster can only be started as root/Administrator on linux-deb/linux-rpm/windows-zip. 
run: | chown -R 1000:1000 `pwd` - su `id -un 1000` -c "whoami && java -version && ./gradlew integTest -Dsecurity.enabled=true -Dsimd.enabled=true" + su `id -un 1000` -c "whoami && java -version && ./gradlew integTest -PnumNodes=2 -Dsecurity.enabled=true -Dsimd.enabled=true" diff --git a/CHANGELOG.md b/CHANGELOG.md index 09f4c0d8df..024a05a1c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements ### Bug Fixes * Fixing the arithmetic to find the number of vectors to stream from java to jni layer.[#1804](https://github.com/opensearch-project/k-NN/pull/1804) +* Fixed LeafReaders casting errors to SegmentReaders when segment replication is enabled during search.[#1808](https://github.com/opensearch-project/k-NN/pull/1808) * Release memory properly for an array type [#1820](https://github.com/opensearch-project/k-NN/pull/1820) ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index baddd674d2..32595e6025 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -10,12 +10,12 @@ import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FilterDirectory; +import org.opensearch.common.lucene.Lucene; import org.opensearch.index.engine.Engine; import org.opensearch.index.shard.IndexShard; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -160,7 +160,7 @@ List getEngineFileContexts(IndexReader indexReader, KNNEngine List engineFiles = new ArrayList<>(); for 
(LeafReaderContext leafReaderContext : indexReader.leaves()) { - SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); + SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); Path shardPath = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory(); String fileExtension = reader.getSegmentInfo().info.getUseCompoundFile() ? knnEngine.getCompoundExtension() diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index fce8e8e04b..3cd0118b63 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -10,7 +10,6 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; import org.apache.lucene.search.DocIdSetIterator; @@ -28,6 +27,7 @@ import org.apache.lucene.util.DocIdSetBuilder; import org.apache.lucene.util.FixedBitSet; import org.opensearch.common.io.PathUtils; +import org.opensearch.common.lucene.Lucene; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; @@ -197,7 +197,7 @@ private int[] bitSetToIntArray(final BitSet bitSet) { private Map doANNSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality) throws IOException { - SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); + final SegmentReader reader = Lucene.segmentReader(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); @@ -369,7 +369,7 @@ private Map 
doExactSearch(final LeafReaderContext leafReaderCont private FilteredIdsKNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) throws IOException { - final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); + final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); final SpaceType spaceType = getSpaceType(fieldInfo); diff --git a/src/test/java/org/opensearch/knn/index/SegmentReplicationIT.java b/src/test/java/org/opensearch/knn/index/SegmentReplicationIT.java new file mode 100644 index 0000000000..af83adac5d --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/SegmentReplicationIT.java @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Assert; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; + +import java.util.List; + +/** + * This IT class contains will contain special cases of IT for segment replication behavior. + * All the index created in this test will have replication type SEGMENT, number of replicas: 1 and should be run on + * at-least 2 node configuration. 
+ */ +@Log4j2 +public class SegmentReplicationIT extends KNNRestTestCase { + private static final String INDEX_NAME = "segment-replicated-knn-index"; + + @SneakyThrows + public void testSearchOnReplicas_whenIndexHasDeletedDocs_thenSuccess() { + if (ensureMinDataNodesCountForSegmentReplication() == false) { + return; + } + createKnnIndex(INDEX_NAME, getKNNSegmentReplicatedIndexSettings(), createKNNIndexMethodFieldMapping(FIELD_NAME, 2)); + + Float[] vector = { 1.3f, 2.2f }; + int docsInIndex = 10; + + for (int i = 0; i < docsInIndex; i++) { + addKnnDoc(INDEX_NAME, Integer.toString(i), FIELD_NAME, vector); + } + refreshIndex(INDEX_NAME); + int deleteDocs = 5; + for (int i = 0; i < deleteDocs; i++) { + deleteKnnDoc(INDEX_NAME, Integer.toString(i)); + } + refreshIndex(INDEX_NAME); + // sleep for 5sec to ensure data is replicated. I don't have a better way here to know if segments have been + // replicated. + Thread.sleep(5000); + // validate that warmup is successful. + doKnnWarmup(List.of(INDEX_NAME)); + + XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilder.startObject("knn"); + queryBuilder.startObject(FIELD_NAME); + queryBuilder.field("vector", vector); + queryBuilder.field("k", docsInIndex); + queryBuilder.endObject().endObject().endObject().endObject(); + // validate replicas are working + Response searchResponse = performSearch(INDEX_NAME, queryBuilder.toString(), "preference=_replica"); + String responseBody = EntityUtils.toString(searchResponse.getEntity()); + List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + assertEquals(docsInIndex - deleteDocs, knnResults.size()); + + // validate primaries are working + searchResponse = performSearch(INDEX_NAME, queryBuilder.toString(), "preference=_primary"); + responseBody = EntityUtils.toString(searchResponse.getEntity()); + knnResults = parseSearchResponse(responseBody, FIELD_NAME); + assertEquals(docsInIndex - deleteDocs, knnResults.size()); + 
} + + private boolean ensureMinDataNodesCountForSegmentReplication() { + int dataNodeCount = getDataNodeCount(); + if (dataNodeCount <= 1) { + log.warn( + "Not running segment replication tests named: " + + "testSearchOnReplicas_whenIndexHasDeletedDocs_thenSuccess, as data nodes count is not atleast 2. " + + "Actual datanode count : {}", + dataNodeCount + ); + Assert.assertTrue(true); + // making the test successful because we don't want to break already running tests. + return false; + } + return true; + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index d9949aaf24..1a6f465c8a 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -214,9 +214,13 @@ public void testScriptStats_singleShard() throws Exception { ) ); List> nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); - int initialScriptCompilations = (int) (nodeStats.get(0).get(StatNames.SCRIPT_COMPILATIONS.getName())); - int initialScriptQueryRequests = (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName())); - int initialScriptQueryErrors = (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName())); + int initialScriptCompilations = nodeStats.stream() + .mapToInt(value -> (int) value.get(StatNames.SCRIPT_COMPILATIONS.getName())) + .sum(); + int initialScriptQueryRequests = nodeStats.stream() + .mapToInt(value -> (int) value.get(StatNames.SCRIPT_QUERY_REQUESTS.getName())) + .sum(); + int initialScriptQueryErrors = nodeStats.stream().mapToInt(value -> (int) value.get(StatNames.SCRIPT_QUERY_ERRORS.getName())).sum(); // Create an index with a single vector createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); @@ -239,8 +243,14 @@ public void testScriptStats_singleShard() throws Exception { 
Arrays.asList(StatNames.SCRIPT_COMPILATIONS.getName(), StatNames.SCRIPT_QUERY_REQUESTS.getName()) ); nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); - assertEquals((int) (nodeStats.get(0).get(StatNames.SCRIPT_COMPILATIONS.getName())), initialScriptCompilations + 1); - assertEquals(initialScriptQueryRequests + 1, (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName()))); + assertEquals( + nodeStats.stream().mapToInt(value -> (int) value.get(StatNames.SCRIPT_COMPILATIONS.getName())).sum(), + initialScriptCompilations + 1 + ); + assertEquals( + initialScriptQueryRequests + 1, + nodeStats.stream().mapToInt(value -> (int) value.get(StatNames.SCRIPT_QUERY_REQUESTS.getName())).sum() + ); // Check query error stats params = new HashMap<>(); @@ -253,7 +263,10 @@ public void testScriptStats_singleShard() throws Exception { response = getKnnStats(Collections.emptyList(), Collections.singletonList(StatNames.SCRIPT_QUERY_ERRORS.getName())); nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); - assertEquals(initialScriptQueryErrors + 1, (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName()))); + assertEquals( + initialScriptQueryErrors + 1, + nodeStats.stream().mapToInt(value -> (int) value.get(StatNames.SCRIPT_QUERY_ERRORS.getName())).sum() + ); } /** @@ -272,9 +285,13 @@ public void testScriptStats_multipleShards() throws Exception { ) ); List> nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); - int initialScriptCompilations = (int) (nodeStats.get(0).get(StatNames.SCRIPT_COMPILATIONS.getName())); - int initialScriptQueryRequests = (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName())); - int initialScriptQueryErrors = (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName())); + int initialScriptCompilations = nodeStats.stream() + .mapToInt(value -> (int) value.get(StatNames.SCRIPT_COMPILATIONS.getName())) + .sum(); + int 
initialScriptQueryRequests = nodeStats.stream() + .mapToInt(value -> (int) value.get(StatNames.SCRIPT_QUERY_REQUESTS.getName())) + .sum(); + int initialScriptQueryErrors = nodeStats.stream().mapToInt(value -> (int) value.get(StatNames.SCRIPT_QUERY_ERRORS.getName())).sum(); // Create an index with a single vector createKnnIndex( @@ -305,10 +322,16 @@ public void testScriptStats_multipleShards() throws Exception { Arrays.asList(StatNames.SCRIPT_COMPILATIONS.getName(), StatNames.SCRIPT_QUERY_REQUESTS.getName()) ); nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); - assertEquals((int) (nodeStats.get(0).get(StatNames.SCRIPT_COMPILATIONS.getName())), initialScriptCompilations + 1); + assertEquals( + nodeStats.stream().mapToInt(value -> (int) value.get(StatNames.SCRIPT_COMPILATIONS.getName())).sum(), + initialScriptCompilations + nodeStats.size() + ); // TODO fix the test case. For some reason request count is treated as 4. // https://github.com/opendistro-for-elasticsearch/k-NN/issues/272 - assertEquals(initialScriptQueryRequests + 4, (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName()))); + assertEquals( + initialScriptQueryRequests + 4, + nodeStats.stream().mapToInt(value -> (int) value.get(StatNames.SCRIPT_QUERY_REQUESTS.getName())).sum() + ); // Check query error stats params = new HashMap<>(); @@ -321,7 +344,10 @@ public void testScriptStats_multipleShards() throws Exception { response = getKnnStats(Collections.emptyList(), Collections.singletonList(StatNames.SCRIPT_QUERY_ERRORS.getName())); nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); - assertEquals(initialScriptQueryErrors + 2, (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName()))); + assertEquals( + initialScriptQueryErrors + 2, + nodeStats.stream().mapToInt(value -> (int) value.get(StatNames.SCRIPT_QUERY_ERRORS.getName())).sum() + ); } public void testModelIndexHealthMetricsStats() throws Exception { @@ -378,19 
+404,25 @@ public void testModelIndexingDegradedMetricsStats() throws Exception { * * @throws IOException throws IOException */ - public void testFieldByEngineStats() throws Exception { + @SneakyThrows + public void testFieldByEngineStats() { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2, METHOD_HNSW, NMSLIB_NAME)); - putMappingRequest(INDEX_NAME, createKnnIndexMapping(FIELD_NAME_2, 3, METHOD_HNSW, LUCENE_NAME)); - putMappingRequest(INDEX_NAME, createKnnIndexMapping(FIELD_NAME_3, 3, METHOD_HNSW, FAISS_NAME)); + putMappingRequest(INDEX_NAME, createKnnIndexMapping(FIELD_NAME_2, 2, METHOD_HNSW, LUCENE_NAME)); + putMappingRequest(INDEX_NAME, createKnnIndexMapping(FIELD_NAME_3, 2, METHOD_HNSW, FAISS_NAME)); + Float[] vector = { 6.0f, 6.0f }; + // adding the doc to ensure that stats are initialized when test is run independently + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); + addKnnDoc(INDEX_NAME, "2", FIELD_NAME_2, vector); + addKnnDoc(INDEX_NAME, "3", FIELD_NAME_3, vector); Response response = getKnnStats(Collections.emptyList(), Collections.emptyList()); String responseBody = EntityUtils.toString(response.getEntity()); - - Map nodeStats0 = parseNodeStatsResponse(responseBody).get(0); - boolean faissField = (Boolean) nodeStats0.get(StatNames.FAISS_LOADED.getName()); - boolean luceneField = (Boolean) nodeStats0.get(StatNames.LUCENE_LOADED.getName()); - boolean nmslibField = (Boolean) nodeStats0.get(StatNames.NMSLIB_LOADED.getName()); + boolean faissField = parseNodeStatsResponse(responseBody).stream().anyMatch(v -> (Boolean) v.get(StatNames.FAISS_LOADED.getName())); + boolean luceneField = parseNodeStatsResponse(responseBody).stream() + .anyMatch(v -> (Boolean) v.get(StatNames.LUCENE_LOADED.getName())); + boolean nmslibField = parseNodeStatsResponse(responseBody).stream() + .anyMatch(v -> (Boolean) v.get(StatNames.NMSLIB_LOADED.getName())); assertTrue(faissField); assertTrue(luceneField); diff --git 
a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 860cd2efaa..ebb171c193 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -235,7 +235,11 @@ protected Response searchExists(String index, ExistsQueryBuilder existsQueryBuil } protected Response performSearch(final String indexName, final String query) throws IOException { - Request request = new Request("POST", "/" + indexName + "/_search"); + return performSearch(indexName, query, ""); + } + + protected Response performSearch(final String indexName, final String query, final String urlParameters) throws IOException { + Request request = new Request("POST", "/" + indexName + "/_search?" + urlParameters); request.setJsonEntity(query); Response response = client().performRequest(request); @@ -667,6 +671,35 @@ protected Settings getKNNDefaultIndexSettings() { return Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", true).build(); } + protected Settings getKNNSegmentReplicatedIndexSettings() { + return Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 1) + .put("index.knn", true) + .put("index.replication.type", "SEGMENT") + .build(); + } + + @SneakyThrows + protected int getDataNodeCount() { + Request request = new Request("GET", "_nodes/stats?filter_path=nodes.*.roles"); + + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + String responseBody = EntityUtils.toString(response.getEntity()); + + Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); + Map nodesInfo = (Map) responseMap.get("nodes"); + int dataNodeCount = 0; + for (String key : nodesInfo.keySet()) { + Map> nodeRoles = (Map>) nodesInfo.get(key); + if 
(nodeRoles.get("roles").contains("data")) { + dataNodeCount++; + } + } + return dataNodeCount; + } + /** * Get Stats from KNN Plugin */