Skip to content

Commit

Permalink
Isolate KNNQueryBuilder creation in NeuralKNNQueryBuilder
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
  • Loading branch information
junqiu-lei committed Feb 11, 2025
1 parent 7dc84b5 commit 6a1f045
Show file tree
Hide file tree
Showing 4 changed files with 369 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
### Maintenance
### Refactoring
- Isolate KNNQueryBuilder creation in NeuralKNNQueryBuilder ([#1183](https://github.com/opensearch-project/neural-search/pull/1183))

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.19...2.x)
### Features
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.query;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

/**
* NeuralKNNQueryBuilder wraps KNNQueryBuilder to:
* 1. Isolate KNN plugin API changes to a single location
* 2. Allow extension with neural-search-specific information (e.g., query text)
*/
public class NeuralKNNQueryBuilder extends AbstractQueryBuilder<NeuralKNNQueryBuilder> {
private final KNNQueryBuilder knnQueryBuilder;

/**
* Wraps KNN Lucene query to support neural search extensions.
* Delegates core operations to the underlying KNN query.
*/
public static class NeuralKNNQuery extends Query {
private final Query knnQuery;

public NeuralKNNQuery(Query knnQuery) {
this.knnQuery = knnQuery;
}

@Override
public String toString(String field) {
return knnQuery.toString(field);
}

@Override
public void visit(QueryVisitor visitor) {
// Delegate the visitor to the underlying KNN query
knnQuery.visit(visitor);
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
// Delegate weight creation to the underlying KNN query
return knnQuery.createWeight(searcher, scoreMode, boost);
}

@Override
public Query rewrite(IndexReader reader) throws IOException {
Query rewritten = knnQuery.rewrite(reader);
if (rewritten == knnQuery) {
return this;
}
return new NeuralKNNQuery(rewritten);
}

@Override
public boolean equals(Object other) {
if (this == other) return true;
if (other == null || getClass() != other.getClass()) return false;
NeuralKNNQuery that = (NeuralKNNQuery) other;
return Objects.equals(knnQuery, that.knnQuery);
}

@Override
public int hashCode() {
return Objects.hash(knnQuery);
}
}

/**
* Factory method to create NeuralKNNQueryBuilder with given parameters.
* Centralizes KNNQueryBuilder creation to handle breaking changes.
*/
public static NeuralKNNQueryBuilder fromKNNBuilder(
String fieldName,
float[] vector,
Integer k,
QueryBuilder filter,
Float maxDistance,
Float minScore,
Boolean expandNested,
Map<String, ?> methodParameters,
RescoreContext rescoreContext
) {
KNNQueryBuilder knnBuilder = KNNQueryBuilder.builder()
.fieldName(fieldName)
.vector(vector)
.k(k)
.filter(filter)
.maxDistance(maxDistance)
.minScore(minScore)
.expandNested(expandNested)
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
.build();
return new NeuralKNNQueryBuilder(knnBuilder);
}

private NeuralKNNQueryBuilder(KNNQueryBuilder knnQueryBuilder) {
super();
this.knnQueryBuilder = knnQueryBuilder;
}

@Override
public void doWriteTo(StreamOutput out) throws IOException {
knnQueryBuilder.writeTo(out);
}

@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
knnQueryBuilder.toXContent(builder, params);
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext context) throws IOException {
QueryBuilder rewritten = knnQueryBuilder.rewrite(context);
if (rewritten == knnQueryBuilder) {
return this;
}
return new NeuralKNNQueryBuilder((KNNQueryBuilder) rewritten);
}

@Override
protected Query doToQuery(QueryShardContext context) throws IOException {
Query knnQuery = knnQueryBuilder.toQuery(context);
return new NeuralKNNQuery(knnQuery);
}

@Override
protected boolean doEquals(NeuralKNNQueryBuilder other) {
return Objects.equals(knnQueryBuilder, other.knnQueryBuilder);
}

@Override
protected int doHashCode() {
return Objects.hash(knnQueryBuilder);
}

@Override
public String getWriteableName() {
return knnQueryBuilder.getWriteableName();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
Expand Down Expand Up @@ -462,23 +461,24 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
// https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L117.
// With the asynchronous call, on first rewrite, we create a new
// vector supplier that will get populated once the asynchronous call finishes and pass this supplier in to
// create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just
// return the current unmodified query builder.
// create a new builder. Once the supplier's value gets set, we return a NeuralKNNQueryBuilder
// which wrapped KNNQueryBuilder. Otherwise, we just return the current unmodified query builder.
if (vectorSupplier() != null) {
if (vectorSupplier().get() == null) {
return this;
}
return KNNQueryBuilder.builder()
.fieldName(fieldName())
.vector(vectorSupplier.get())
.filter(filter())
.maxDistance(maxDistance)
.minScore(minScore)
.expandNested(expandNested)
.k(k)
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
.build();

return NeuralKNNQueryBuilder.fromKNNBuilder(
fieldName(),
vectorSupplier.get(),
k(),
filter(),
maxDistance(),
minScore(),
expandNested(),
methodParameters(),
rescoreContext()
);
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand Down
Loading

0 comments on commit 6a1f045

Please sign in to comment.