Skip to content

Commit

Permalink
Fix BWC test failures
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
  • Loading branch information
junqiu-lei committed Feb 12, 2025
1 parent 1e33d65 commit 45b99f4
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.bwc.rolling;

import org.opensearch.neuralsearch.util.TestUtils;
import org.opensearch.ml.common.model.MLModelState;

import java.nio.file.Files;
import java.nio.file.Path;
Expand All @@ -28,29 +29,59 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
case OLD:
sparseModelId = uploadSparseEncodingModel();
loadModel(sparseModelId);
MLModelState oldModelState = getModelState(sparseModelId);
logger.info("Model state in OLD phase: {}", oldModelState);
if (oldModelState != MLModelState.LOADED) {
logger.error("Model {} is not in LOADED state in OLD phase. Current state: {}", sparseModelId, oldModelState);
waitForModelToLoad(sparseModelId);
}
createPipelineForSparseEncodingProcessor(sparseModelId, SPARSE_PIPELINE, 2);
logger.info("Pipeline state in OLD phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
createIndexWithConfiguration(
indexName,
Files.readString(Path.of(classLoader.getResource("processor/SparseIndexMappings.json").toURI())),
SPARSE_PIPELINE
);
List<Map<String, String>> docs = prepareDataForBulkIngestion(0, 5);
bulkAddDocuments(indexName, TEXT_FIELD_NAME, SPARSE_PIPELINE, docs);
logger.info("Document count after OLD phase ingestion: {}", getDocCount(indexName));
validateDocCountAndInfo(indexName, 5, () -> getDocById(indexName, "4"), EMBEDDING_FIELD_NAME, Map.class);
break;
case MIXED:
sparseModelId = TestUtils.getModelId(getIngestionPipeline(SPARSE_PIPELINE), SPARSE_ENCODING_PROCESSOR);
loadModel(sparseModelId);
MLModelState mixedModelState = getModelState(sparseModelId);
logger.info("Model state in MIXED phase: {}", mixedModelState);
if (mixedModelState != MLModelState.LOADED) {
logger.error("Model {} is not in LOADED state in MIXED phase. Current state: {}", sparseModelId, mixedModelState);
waitForModelToLoad(sparseModelId);
}
logger.info("Pipeline state in MIXED phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
List<Map<String, String>> docsForMixed = prepareDataForBulkIngestion(5, 5);
logger.info("Document count before MIXED phase ingestion: {}", getDocCount(indexName));
bulkAddDocuments(indexName, TEXT_FIELD_NAME, SPARSE_PIPELINE, docsForMixed);
logger.info("Document count after MIXED phase ingestion: {}", getDocCount(indexName));
validateDocCountAndInfo(indexName, 10, () -> getDocById(indexName, "9"), EMBEDDING_FIELD_NAME, Map.class);
break;
case UPGRADED:
try {
sparseModelId = TestUtils.getModelId(getIngestionPipeline(SPARSE_PIPELINE), SPARSE_ENCODING_PROCESSOR);
loadModel(sparseModelId);
MLModelState upgradedModelState = getModelState(sparseModelId);
logger.info("Model state in UPGRADED phase: {}", upgradedModelState);
if (upgradedModelState != MLModelState.LOADED) {
logger.error(
"Model {} is not in LOADED state in UPGRADED phase. Current state: {}",
sparseModelId,
upgradedModelState
);
waitForModelToLoad(sparseModelId);
}
logger.info("Pipeline state in UPGRADED phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
List<Map<String, String>> docsForUpgraded = prepareDataForBulkIngestion(10, 5);
logger.info("Document count before UPGRADED phase ingestion: {}", getDocCount(indexName));
bulkAddDocuments(indexName, TEXT_FIELD_NAME, SPARSE_PIPELINE, docsForUpgraded);
logger.info("Document count after UPGRADED phase ingestion: {}", getDocCount(indexName));
validateDocCountAndInfo(indexName, 15, () -> getDocById(indexName, "14"), EMBEDDING_FIELD_NAME, Map.class);
} finally {
wipeOfTestResources(indexName, SPARSE_PIPELINE, sparseModelId, null);
Expand All @@ -60,4 +91,20 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
throw new IllegalStateException("Unexpected value: " + getClusterType());
}
}

/**
 * Polls the ML model state until it reaches {@link MLModelState#LOADED}.
 * <p>
 * Retries up to a fixed number of attempts with a fixed delay between polls;
 * used as a safety net in BWC tests where model loading may lag behind the
 * loadModel call on a mixed-version cluster.
 *
 * @param modelId id of the model to wait for
 * @throws Exception        if polling or sleeping is interrupted
 * @throws RuntimeException if the model does not reach LOADED within the retry budget
 */
private void waitForModelToLoad(String modelId) throws Exception {
    final int maxAttempts = 30;             // give up after this many polls
    final long pollIntervalMillis = 2000L;  // delay between successive polls (2 seconds)

    int attempt = 0;
    while (attempt < maxAttempts) {
        attempt++;
        MLModelState currentState = getModelState(modelId);
        if (currentState == MLModelState.LOADED) {
            logger.info("Model {} is now loaded after {} attempts", modelId, attempt);
            return;
        }
        logger.info("Waiting for model {} to load. Current state: {}. Attempt {}/{}", modelId, currentState, attempt, maxAttempts);
        Thread.sleep(pollIntervalMillis);
    }
    throw new RuntimeException("Model " + modelId + " failed to load after " + maxAttempts + " attempts");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.IndexUtil;

import java.io.IOException;
import java.util.Map;
Expand Down Expand Up @@ -127,12 +129,12 @@ private NeuralKNNQueryBuilder(KNNQueryBuilder knnQueryBuilder) {

@Override
public void doWriteTo(StreamOutput out) throws IOException {
    // Serialize through KNNQueryBuilderParser with an explicit cluster-version
    // check instead of delegating to knnQueryBuilder.writeTo(out) directly, so
    // that nodes on older versions in a mixed (rolling-upgrade) cluster receive
    // a stream format they can read. Fixes BWC test failures.
    KNNQueryBuilderParser.streamOutput(out, knnQueryBuilder, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
}

@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
    // Delegate to the wrapped builder's doXContent (field contents only) rather
    // than toXContent, which would wrap the output in an extra object start/end
    // and produce invalid nesting inside this builder's own XContent envelope.
    knnQueryBuilder.doXContent(builder, params);
}

@Override
Expand Down

0 comments on commit 45b99f4

Please sign in to comment.