Encapsulate KNNQueryBuilder creation within NeuralKNNQueryBuilder #1183

Merged (2 commits) on Feb 12, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
### Maintenance
### Refactoring
- Encapsulate KNNQueryBuilder creation within NeuralKNNQueryBuilder ([#1183](https://github.com/opensearch-project/neural-search/pull/1183))

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.19...2.x)
### Features
@@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.bwc.rolling;

import org.opensearch.neuralsearch.util.TestUtils;
import org.opensearch.ml.common.model.MLModelState;

import java.nio.file.Files;
import java.nio.file.Path;
@@ -28,29 +29,59 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exception
case OLD:
sparseModelId = uploadSparseEncodingModel();
loadModel(sparseModelId);
MLModelState oldModelState = getModelState(sparseModelId);
logger.info("Model state in OLD phase: {}", oldModelState);
if (oldModelState != MLModelState.LOADED) {
logger.error("Model {} is not in LOADED state in OLD phase. Current state: {}", sparseModelId, oldModelState);
waitForModelToLoad(sparseModelId);
}
createPipelineForSparseEncodingProcessor(sparseModelId, SPARSE_PIPELINE, 2);
logger.info("Pipeline state in OLD phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
createIndexWithConfiguration(
indexName,
Files.readString(Path.of(classLoader.getResource("processor/SparseIndexMappings.json").toURI())),
SPARSE_PIPELINE
);
List<Map<String, String>> docs = prepareDataForBulkIngestion(0, 5);
bulkAddDocuments(indexName, TEXT_FIELD_NAME, SPARSE_PIPELINE, docs);
logger.info("Document count after OLD phase ingestion: {}", getDocCount(indexName));
validateDocCountAndInfo(indexName, 5, () -> getDocById(indexName, "4"), EMBEDDING_FIELD_NAME, Map.class);
break;
case MIXED:
sparseModelId = TestUtils.getModelId(getIngestionPipeline(SPARSE_PIPELINE), SPARSE_ENCODING_PROCESSOR);
loadModel(sparseModelId);
MLModelState mixedModelState = getModelState(sparseModelId);
logger.info("Model state in MIXED phase: {}", mixedModelState);
if (mixedModelState != MLModelState.LOADED) {
logger.error("Model {} is not in LOADED state in MIXED phase. Current state: {}", sparseModelId, mixedModelState);
waitForModelToLoad(sparseModelId);
}
logger.info("Pipeline state in MIXED phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
List<Map<String, String>> docsForMixed = prepareDataForBulkIngestion(5, 5);
logger.info("Document count before MIXED phase ingestion: {}", getDocCount(indexName));
bulkAddDocuments(indexName, TEXT_FIELD_NAME, SPARSE_PIPELINE, docsForMixed);
logger.info("Document count after MIXED phase ingestion: {}", getDocCount(indexName));
validateDocCountAndInfo(indexName, 10, () -> getDocById(indexName, "9"), EMBEDDING_FIELD_NAME, Map.class);
break;
case UPGRADED:
try {
sparseModelId = TestUtils.getModelId(getIngestionPipeline(SPARSE_PIPELINE), SPARSE_ENCODING_PROCESSOR);
loadModel(sparseModelId);
MLModelState upgradedModelState = getModelState(sparseModelId);
logger.info("Model state in UPGRADED phase: {}", upgradedModelState);
if (upgradedModelState != MLModelState.LOADED) {
logger.error(
"Model {} is not in LOADED state in UPGRADED phase. Current state: {}",
sparseModelId,
upgradedModelState
);
waitForModelToLoad(sparseModelId);
}
logger.info("Pipeline state in UPGRADED phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
List<Map<String, String>> docsForUpgraded = prepareDataForBulkIngestion(10, 5);
logger.info("Document count before UPGRADED phase ingestion: {}", getDocCount(indexName));
bulkAddDocuments(indexName, TEXT_FIELD_NAME, SPARSE_PIPELINE, docsForUpgraded);
logger.info("Document count after UPGRADED phase ingestion: {}", getDocCount(indexName));
validateDocCountAndInfo(indexName, 15, () -> getDocById(indexName, "14"), EMBEDDING_FIELD_NAME, Map.class);
} finally {
wipeOfTestResources(indexName, SPARSE_PIPELINE, sparseModelId, null);
@@ -60,4 +91,20 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exception
throw new IllegalStateException("Unexpected value: " + getClusterType());
}
}

private void waitForModelToLoad(String modelId) throws Exception {
int maxAttempts = 30; // Maximum number of attempts
int waitTimeInSeconds = 2; // Time to wait between attempts

for (int attempt = 0; attempt < maxAttempts; attempt++) {
MLModelState state = getModelState(modelId);
if (state == MLModelState.LOADED) {
logger.info("Model {} is now loaded after {} attempts", modelId, attempt + 1);
return;
}
logger.info("Waiting for model {} to load. Current state: {}. Attempt {}/{}", modelId, state, attempt + 1, maxAttempts);
Thread.sleep(waitTimeInSeconds * 1000);
}
throw new RuntimeException("Model " + modelId + " failed to load after " + maxAttempts + " attempts");
}
}
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.query;

import lombok.Getter;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;
import java.util.Objects;

/**
* Wraps KNN Lucene query to support neural search extensions.
* Delegates core operations to the underlying KNN query.
*/
@Getter
public class NeuralKNNQuery extends Query {
private final Query knnQuery;

public NeuralKNNQuery(Query knnQuery) {
this.knnQuery = knnQuery;
}

@Override
public String toString(String field) {
return knnQuery.toString(field);
}

@Override
public void visit(QueryVisitor visitor) {
// Delegate the visitor to the underlying KNN query
knnQuery.visit(visitor);
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
// Delegate weight creation to the underlying KNN query
return knnQuery.createWeight(searcher, scoreMode, boost);
}

@Override
public Query rewrite(IndexReader reader) throws IOException {
Query rewritten = knnQuery.rewrite(reader);
if (rewritten == knnQuery) {
return this;
}
return new NeuralKNNQuery(rewritten);
}

@Override
public boolean equals(Object other) {
if (this == other) return true;
if (other == null || getClass() != other.getClass()) return false;
NeuralKNNQuery that = (NeuralKNNQuery) other;
return Objects.equals(knnQuery, that.knnQuery);
}

@Override
public int hashCode() {
return Objects.hash(knnQuery);
}
}
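For orientation, here is a minimal usage sketch (not part of the PR) of the wrapper above. It uses Lucene's MatchAllDocsQuery as a stand-in for the real KNN Lucene query, since in practice a NeuralKNNQuery is produced by NeuralKNNQueryBuilder#doToQuery rather than constructed directly:

```java
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.neuralsearch.query.NeuralKNNQuery;

public class NeuralKNNQueryDelegationSketch {
    public static void main(String[] args) {
        // Stand-in for the Lucene KNN query that the KNN plugin would normally build.
        Query delegate = new MatchAllDocsQuery();
        NeuralKNNQuery wrapper = new NeuralKNNQuery(delegate);

        // toString() is forwarded to the wrapped query, so both lines print the same text.
        System.out.println(wrapper.toString("passage_embedding"));
        System.out.println(delegate.toString("passage_embedding"));

        // equals()/hashCode() are defined purely in terms of the wrapped query.
        System.out.println(wrapper.equals(new NeuralKNNQuery(new MatchAllDocsQuery()))); // true
    }
}
```

Because visit(), createWeight() and rewrite() also delegate, the wrapper scores exactly like the wrapped KNN query while remaining distinguishable by type in neural-search code paths.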
@@ -0,0 +1,169 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.query;

import lombok.Getter;
import org.apache.lucene.search.Query;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.IndexUtil;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

/**
* NeuralKNNQueryBuilder wraps KNNQueryBuilder to:
* 1. Isolate KNN plugin API changes to a single location
* 2. Allow extension with neural-search-specific information (e.g., query text)
*/

@Getter
public class NeuralKNNQueryBuilder extends AbstractQueryBuilder<NeuralKNNQueryBuilder> {
private final KNNQueryBuilder knnQueryBuilder;

/**
* Creates a new builder instance.
*/
public static Builder builder() {
return new Builder();
}

public String fieldName() {
return knnQueryBuilder.fieldName();
}

public int k() {
return knnQueryBuilder.getK();
}

/**
* Builder for NeuralKNNQueryBuilder.
*/
public static class Builder {
private String fieldName;
private float[] vector;
private Integer k;
private QueryBuilder filter;
private Float maxDistance;
private Float minScore;
private Boolean expandNested;
private Map<String, ?> methodParameters;
private RescoreContext rescoreContext;

private Builder() {}

public Builder fieldName(String fieldName) {
this.fieldName = fieldName;
return this;
}

public Builder vector(float[] vector) {
this.vector = vector;
return this;
}

public Builder k(Integer k) {
this.k = k;
return this;
}

public Builder filter(QueryBuilder filter) {
this.filter = filter;
return this;
}

public Builder maxDistance(Float maxDistance) {
this.maxDistance = maxDistance;
return this;
}

public Builder minScore(Float minScore) {
this.minScore = minScore;
return this;
}

public Builder expandNested(Boolean expandNested) {
this.expandNested = expandNested;
return this;
}

public Builder methodParameters(Map<String, ?> methodParameters) {
this.methodParameters = methodParameters;
return this;
}

public Builder rescoreContext(RescoreContext rescoreContext) {
this.rescoreContext = rescoreContext;
return this;
}

public NeuralKNNQueryBuilder build() {
KNNQueryBuilder knnBuilder = KNNQueryBuilder.builder()
.fieldName(fieldName)
.vector(vector)
.k(k)
.filter(filter)
.maxDistance(maxDistance)
.minScore(minScore)
.expandNested(expandNested)
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
.build();
return new NeuralKNNQueryBuilder(knnBuilder);
}
}

private NeuralKNNQueryBuilder(KNNQueryBuilder knnQueryBuilder) {
this.knnQueryBuilder = knnQueryBuilder;
}

@Override
public void doWriteTo(StreamOutput out) throws IOException {
KNNQueryBuilderParser.streamOutput(out, knnQueryBuilder, IndexUtil::isClusterOnOrAfterMinRequiredVersion);
}

@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
knnQueryBuilder.doXContent(builder, params);
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext context) throws IOException {
QueryBuilder rewritten = knnQueryBuilder.rewrite(context);
if (rewritten == knnQueryBuilder) {
return this;
}
return new NeuralKNNQueryBuilder((KNNQueryBuilder) rewritten);
}

@Override
protected Query doToQuery(QueryShardContext context) throws IOException {
Query knnQuery = knnQueryBuilder.toQuery(context);
return new NeuralKNNQuery(knnQuery);
}

@Override
protected boolean doEquals(NeuralKNNQueryBuilder other) {
return Objects.equals(knnQueryBuilder, other.knnQueryBuilder);
}

@Override
protected int doHashCode() {
return Objects.hash(knnQueryBuilder);
}

@Override
public String getWriteableName() {
return knnQueryBuilder.getWriteableName();
}
}
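As a usage sketch (not part of the PR), the builder added above can be exercised the same way NeuralQueryBuilder's rewrite does once a query vector is available. The field name and vector below are hypothetical placeholders for the index mapping and the ML inference output:

```java
import org.opensearch.neuralsearch.query.NeuralKNNQueryBuilder;

public class NeuralKNNQueryBuilderUsageSketch {
    public static NeuralKNNQueryBuilder buildExample() {
        // Hypothetical vector, standing in for the ML inference response.
        float[] queryVector = new float[] { 0.12f, 0.34f, 0.56f };

        return NeuralKNNQueryBuilder.builder()
            .fieldName("passage_embedding")   // hypothetical kNN field name
            .vector(queryVector)
            .k(10)
            .build();
    }

    public static void main(String[] args) {
        NeuralKNNQueryBuilder builder = buildExample();
        // The wrapper exposes the values callers previously read off KNNQueryBuilder directly.
        System.out.println(builder.fieldName() + ", k=" + builder.k());
    }
}
```

The optional clauses (filter, maxDistance, minScore, expandNested, methodParameters, rescoreContext) are forwarded unchanged to KNNQueryBuilder.builder(), so the KNN plugin API is touched in only this one place.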
@@ -40,7 +40,6 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
- import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
@@ -463,22 +462,23 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
// https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L117.
// With the asynchronous call, on first rewrite, we create a new
// vector supplier that will get populated once the asynchronous call finishes and pass this supplier in to
- // create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just
- // return the current unmodified query builder.
+ // create a new builder. Once the supplier's value gets set, we return a NeuralKNNQueryBuilder,
+ // which wraps a KNNQueryBuilder. Otherwise, we just return the current unmodified query builder.
if (vectorSupplier() != null) {
if (vectorSupplier().get() == null) {
return this;
}
-             return KNNQueryBuilder.builder()
+
+             return NeuralKNNQueryBuilder.builder()
.fieldName(fieldName())
.vector(vectorSupplier.get())
+                 .k(k())
.filter(filter())
-                 .maxDistance(maxDistance)
-                 .minScore(minScore)
-                 .expandNested(expandNested)
-                 .k(k)
-                 .methodParameters(methodParameters)
-                 .rescoreContext(rescoreContext)
+                 .maxDistance(maxDistance())
+                 .minScore(minScore())
+                 .expandNested(expandNested())
+                 .methodParameters(methodParameters())
+                 .rescoreContext(rescoreContext())
.build();
}
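To make the rewrite comment above concrete, the sketch below is a simplification of the repeated-rewrite behaviour it relies on (not OpenSearch's actual org.opensearch.index.query.Rewriteable loop, which also runs registered async actions and caps the number of passes): the coordinator keeps rewriting until the builder stops changing, so returning `this` while the vector supplier is still empty simply defers the conversion, and once the inference call has populated the supplier the next pass yields the NeuralKNNQueryBuilder.

```java
import java.io.IOException;

import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;

public class RewriteLoopSketch {
    // Simplified rewrite-until-stable loop; the real coordinator also executes the
    // async actions (e.g. the ML inference call) registered on the context between passes.
    public static QueryBuilder rewriteUntilStable(QueryBuilder builder, QueryRewriteContext context) throws IOException {
        QueryBuilder current = builder;
        QueryBuilder rewritten = current.rewrite(context);
        while (rewritten != current) {
            current = rewritten;
            rewritten = current.rewrite(context);
        }
        // For NeuralQueryBuilder this ends up being a NeuralKNNQueryBuilder once the
        // vector supplier has been populated by the inference response.
        return current;
    }
}
```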
