-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Isolate KNNQueryBuilder creation in NeuralKNNQueryBuilder
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
- Loading branch information
1 parent
7dc84b5
commit 6a1f045
Showing
4 changed files
with
369 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
158 changes: 158 additions & 0 deletions
158
src/main/java/org/opensearch/neuralsearch/query/NeuralKNNQueryBuilder.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.query; | ||
|
||
import org.apache.lucene.index.IndexReader; | ||
import org.apache.lucene.search.IndexSearcher; | ||
import org.apache.lucene.search.Query; | ||
import org.apache.lucene.search.QueryVisitor; | ||
import org.apache.lucene.search.ScoreMode; | ||
import org.apache.lucene.search.Weight; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.index.query.AbstractQueryBuilder; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.index.query.QueryRewriteContext; | ||
import org.opensearch.index.query.QueryShardContext; | ||
import org.opensearch.knn.index.query.KNNQueryBuilder; | ||
import org.opensearch.knn.index.query.rescore.RescoreContext; | ||
|
||
import java.io.IOException; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
|
||
/** | ||
* NeuralKNNQueryBuilder wraps KNNQueryBuilder to: | ||
* 1. Isolate KNN plugin API changes to a single location | ||
* 2. Allow extension with neural-search-specific information (e.g., query text) | ||
*/ | ||
public class NeuralKNNQueryBuilder extends AbstractQueryBuilder<NeuralKNNQueryBuilder> { | ||
private final KNNQueryBuilder knnQueryBuilder; | ||
|
||
/** | ||
* Wraps KNN Lucene query to support neural search extensions. | ||
* Delegates core operations to the underlying KNN query. | ||
*/ | ||
public static class NeuralKNNQuery extends Query { | ||
private final Query knnQuery; | ||
|
||
public NeuralKNNQuery(Query knnQuery) { | ||
this.knnQuery = knnQuery; | ||
} | ||
|
||
@Override | ||
public String toString(String field) { | ||
return knnQuery.toString(field); | ||
} | ||
|
||
@Override | ||
public void visit(QueryVisitor visitor) { | ||
// Delegate the visitor to the underlying KNN query | ||
knnQuery.visit(visitor); | ||
} | ||
|
||
@Override | ||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { | ||
// Delegate weight creation to the underlying KNN query | ||
return knnQuery.createWeight(searcher, scoreMode, boost); | ||
} | ||
|
||
@Override | ||
public Query rewrite(IndexReader reader) throws IOException { | ||
Query rewritten = knnQuery.rewrite(reader); | ||
if (rewritten == knnQuery) { | ||
return this; | ||
} | ||
return new NeuralKNNQuery(rewritten); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object other) { | ||
if (this == other) return true; | ||
if (other == null || getClass() != other.getClass()) return false; | ||
NeuralKNNQuery that = (NeuralKNNQuery) other; | ||
return Objects.equals(knnQuery, that.knnQuery); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(knnQuery); | ||
} | ||
} | ||
|
||
/** | ||
* Factory method to create NeuralKNNQueryBuilder with given parameters. | ||
* Centralizes KNNQueryBuilder creation to handle breaking changes. | ||
*/ | ||
public static NeuralKNNQueryBuilder fromKNNBuilder( | ||
String fieldName, | ||
float[] vector, | ||
Integer k, | ||
QueryBuilder filter, | ||
Float maxDistance, | ||
Float minScore, | ||
Boolean expandNested, | ||
Map<String, ?> methodParameters, | ||
RescoreContext rescoreContext | ||
) { | ||
KNNQueryBuilder knnBuilder = KNNQueryBuilder.builder() | ||
.fieldName(fieldName) | ||
.vector(vector) | ||
.k(k) | ||
.filter(filter) | ||
.maxDistance(maxDistance) | ||
.minScore(minScore) | ||
.expandNested(expandNested) | ||
.methodParameters(methodParameters) | ||
.rescoreContext(rescoreContext) | ||
.build(); | ||
return new NeuralKNNQueryBuilder(knnBuilder); | ||
} | ||
|
||
private NeuralKNNQueryBuilder(KNNQueryBuilder knnQueryBuilder) { | ||
super(); | ||
this.knnQueryBuilder = knnQueryBuilder; | ||
} | ||
|
||
@Override | ||
public void doWriteTo(StreamOutput out) throws IOException { | ||
knnQueryBuilder.writeTo(out); | ||
} | ||
|
||
@Override | ||
protected void doXContent(XContentBuilder builder, Params params) throws IOException { | ||
knnQueryBuilder.toXContent(builder, params); | ||
} | ||
|
||
@Override | ||
protected QueryBuilder doRewrite(QueryRewriteContext context) throws IOException { | ||
QueryBuilder rewritten = knnQueryBuilder.rewrite(context); | ||
if (rewritten == knnQueryBuilder) { | ||
return this; | ||
} | ||
return new NeuralKNNQueryBuilder((KNNQueryBuilder) rewritten); | ||
} | ||
|
||
@Override | ||
protected Query doToQuery(QueryShardContext context) throws IOException { | ||
Query knnQuery = knnQueryBuilder.toQuery(context); | ||
return new NeuralKNNQuery(knnQuery); | ||
} | ||
|
||
@Override | ||
protected boolean doEquals(NeuralKNNQueryBuilder other) { | ||
return Objects.equals(knnQueryBuilder, other.knnQueryBuilder); | ||
} | ||
|
||
@Override | ||
protected int doHashCode() { | ||
return Objects.hash(knnQueryBuilder); | ||
} | ||
|
||
@Override | ||
public String getWriteableName() { | ||
return knnQueryBuilder.getWriteableName(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.