Skip to content

Commit

Permalink
Encapsulate KNNQueryBuilder creation within NeuralKNNQueryBuilder
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
  • Loading branch information
junqiu-lei committed Feb 12, 2025
1 parent 7dc84b5 commit ebfc232
Show file tree
Hide file tree
Showing 8 changed files with 430 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
### Maintenance
### Refactoring
- Encapsulate KNNQueryBuilder creation within NeuralKNNQueryBuilder ([#1183](https://github.com/opensearch-project/neural-search/pull/1183))

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.19...2.x)
### Features
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.query;

import lombok.Getter;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;
import java.util.Objects;

/**
 * Lucene {@link Query} that wraps a KNN query to support neural search extensions.
 * Every core operation (string form, visiting, weight creation, rewriting) is
 * delegated to the wrapped query.
 */
@Getter
public class NeuralKNNQuery extends Query {
    /** The wrapped query that all operations are delegated to; never {@code null}. */
    private final Query knnQuery;

    /**
     * Wraps the given query.
     *
     * @param knnQuery the query to delegate to; must not be {@code null}
     * @throws NullPointerException if {@code knnQuery} is {@code null}
     */
    public NeuralKNNQuery(Query knnQuery) {
        // Fail fast: toString, visit, createWeight and rewrite all dereference the
        // delegate, so a null here would only surface later as an anonymous NPE.
        this.knnQuery = Objects.requireNonNull(knnQuery, "knnQuery must not be null");
    }

    /**
     * Returns the string form of the wrapped query for the given field.
     *
     * @param field the field name passed through to the wrapped query
     * @return whatever the wrapped query's {@code toString(field)} returns
     */
    @Override
    public String toString(String field) {
        return knnQuery.toString(field);
    }

    /**
     * Passes the visitor straight to the wrapped query.
     *
     * @param visitor the visitor to hand to the wrapped query
     */
    @Override
    public void visit(QueryVisitor visitor) {
        knnQuery.visit(visitor);
    }

    /**
     * Creates the weight by delegating to the wrapped query with the same arguments.
     *
     * @param searcher  the searcher passed through to the wrapped query
     * @param scoreMode the score mode passed through to the wrapped query
     * @param boost     the boost passed through to the wrapped query
     * @return the weight produced by the wrapped query
     * @throws IOException if the wrapped query's weight creation fails
     */
    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        return knnQuery.createWeight(searcher, scoreMode, boost);
    }

    /**
     * Rewrites the wrapped query and keeps the result wrapped.
     *
     * @param reader the index reader passed through to the wrapped query
     * @return this instance when the wrapped query's rewrite returned the same
     *         instance; otherwise a new {@code NeuralKNNQuery} around the rewritten query
     * @throws IOException if the wrapped query's rewrite fails
     */
    @Override
    public Query rewrite(IndexReader reader) throws IOException {
        Query rewritten = knnQuery.rewrite(reader);
        // Identity comparison on purpose: an unchanged delegate means this wrapper is
        // unchanged too, so return the same instance.
        if (rewritten == knnQuery) {
            return this;
        }
        return new NeuralKNNQuery(rewritten);
    }

    /**
     * Two instances are equal when they have exactly the same class and their
     * wrapped queries are equal.
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (other == null || getClass() != other.getClass()) {
            return false;
        }
        NeuralKNNQuery that = (NeuralKNNQuery) other;
        return Objects.equals(knnQuery, that.knnQuery);
    }

    /** Hash code derived solely from the wrapped query, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Objects.hash(knnQuery);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.query;

import lombok.Getter;
import org.apache.lucene.search.Query;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

/**
 * NeuralKNNQueryBuilder wraps KNNQueryBuilder to:
 * 1. Isolate KNN plugin API changes to a single location
 * 2. Allow extension with neural-search-specific information (e.g., query text)
 *
 * Instances are created through {@link #builder()}; serialization, XContent output,
 * rewriting and Lucene query creation are forwarded to the wrapped builder.
 */
@Getter
public class NeuralKNNQueryBuilder extends AbstractQueryBuilder<NeuralKNNQueryBuilder> {
    /** The wrapped KNN query builder that this class delegates to. */
    private final KNNQueryBuilder knnQueryBuilder;

    /**
     * Wraps an already-built KNN query builder. Private: callers go through {@link #builder()}.
     *
     * @param knnQueryBuilder the builder to delegate to
     */
    private NeuralKNNQueryBuilder(KNNQueryBuilder knnQueryBuilder) {
        this.knnQueryBuilder = knnQueryBuilder;
    }

    /**
     * Creates a new builder instance.
     *
     * @return an empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * @return the field name reported by the wrapped KNN query builder
     */
    public String fieldName() {
        return knnQueryBuilder.fieldName();
    }

    /**
     * @return the {@code k} value reported by the wrapped KNN query builder
     */
    public int k() {
        return knnQueryBuilder.getK();
    }

    /**
     * Builder for NeuralKNNQueryBuilder.
     * Collects the settings and forwards all of them, unchanged, to
     * {@code KNNQueryBuilder.builder()} when {@link #build()} is called.
     */
    public static class Builder {
        private String fieldName;
        private float[] vector;
        private Integer k;
        private QueryBuilder filter;
        private Float maxDistance;
        private Float minScore;
        private Boolean expandNested;
        private Map<String, ?> methodParameters;
        private RescoreContext rescoreContext;

        private Builder() {}

        /** Sets the field name forwarded to the KNN query builder. */
        public Builder fieldName(String fieldName) {
            this.fieldName = fieldName;
            return this;
        }

        /** Sets the query vector forwarded to the KNN query builder (stored by reference, not copied). */
        public Builder vector(float[] vector) {
            this.vector = vector;
            return this;
        }

        /** Sets the {@code k} value forwarded to the KNN query builder. */
        public Builder k(Integer k) {
            this.k = k;
            return this;
        }

        /** Sets the filter query forwarded to the KNN query builder. */
        public Builder filter(QueryBuilder filter) {
            this.filter = filter;
            return this;
        }

        /** Sets the max distance forwarded to the KNN query builder. */
        public Builder maxDistance(Float maxDistance) {
            this.maxDistance = maxDistance;
            return this;
        }

        /** Sets the min score forwarded to the KNN query builder. */
        public Builder minScore(Float minScore) {
            this.minScore = minScore;
            return this;
        }

        /** Sets the expand-nested flag forwarded to the KNN query builder. */
        public Builder expandNested(Boolean expandNested) {
            this.expandNested = expandNested;
            return this;
        }

        /** Sets the method parameters forwarded to the KNN query builder. */
        public Builder methodParameters(Map<String, ?> methodParameters) {
            this.methodParameters = methodParameters;
            return this;
        }

        /** Sets the rescore context forwarded to the KNN query builder. */
        public Builder rescoreContext(RescoreContext rescoreContext) {
            this.rescoreContext = rescoreContext;
            return this;
        }

        /**
         * Builds the underlying KNN query builder from the collected settings and wraps it.
         *
         * @return a new {@link NeuralKNNQueryBuilder} around the freshly built KNN query builder
         */
        public NeuralKNNQueryBuilder build() {
            KNNQueryBuilder knnBuilder = KNNQueryBuilder.builder()
                .fieldName(fieldName)
                .vector(vector)
                .k(k)
                .filter(filter)
                .maxDistance(maxDistance)
                .minScore(minScore)
                .expandNested(expandNested)
                .methodParameters(methodParameters)
                .rescoreContext(rescoreContext)
                .build();
            return new NeuralKNNQueryBuilder(knnBuilder);
        }
    }

    /**
     * Serializes this builder by writing the wrapped KNN query builder to the stream.
     *
     * @param out the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void doWriteTo(StreamOutput out) throws IOException {
        // NOTE(review): this invokes the delegate's public writeTo from inside our own
        // doWriteTo hook. Confirm the resulting wire format is what the reader registered
        // under getWriteableName() expects (e.g. nothing gets written twice).
        knnQueryBuilder.writeTo(out);
    }

    /**
     * Renders this builder by emitting the wrapped KNN query builder's XContent.
     *
     * @param builder the XContent builder to write into
     * @param params  rendering parameters passed through to the wrapped builder
     * @throws IOException if writing fails
     */
    @Override
    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
        // NOTE(review): this invokes the delegate's public toXContent from inside our own
        // doXContent hook. Confirm the enclosing object is not opened twice in the output.
        knnQueryBuilder.toXContent(builder, params);
    }

    /**
     * Rewrites the wrapped KNN query builder and re-wraps the result when possible.
     *
     * @param context the rewrite context passed through to the wrapped builder
     * @return this instance when the wrapped builder's rewrite returned the same instance;
     *         a new {@code NeuralKNNQueryBuilder} when it returned a different
     *         {@code KNNQueryBuilder}; otherwise the rewritten builder itself
     * @throws IOException if the wrapped builder's rewrite fails
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext context) throws IOException {
        QueryBuilder rewritten = knnQueryBuilder.rewrite(context);
        if (rewritten == knnQueryBuilder) {
            return this;
        }
        if (rewritten instanceof KNNQueryBuilder) {
            return new NeuralKNNQueryBuilder((KNNQueryBuilder) rewritten);
        }
        // The delegate rewrote itself into another builder type, so it can no longer be
        // wrapped. Return it as-is instead of failing with a ClassCastException.
        return rewritten;
    }

    /**
     * Builds the Lucene query from the wrapped KNN query builder and wraps it in a
     * {@link NeuralKNNQuery}.
     *
     * @param context the shard context passed through to the wrapped builder
     * @return a {@link NeuralKNNQuery} around the query produced by the wrapped builder
     * @throws IOException if the wrapped builder fails to create its query
     */
    @Override
    protected Query doToQuery(QueryShardContext context) throws IOException {
        Query knnQuery = knnQueryBuilder.toQuery(context);
        return new NeuralKNNQuery(knnQuery);
    }

    /** Equality is decided solely by the wrapped KNN query builders. */
    @Override
    protected boolean doEquals(NeuralKNNQueryBuilder other) {
        return Objects.equals(knnQueryBuilder, other.knnQueryBuilder);
    }

    /** Hash code derived solely from the wrapped KNN query builder. */
    @Override
    protected int doHashCode() {
        return Objects.hash(knnQueryBuilder);
    }

    /**
     * @return the writeable name of the wrapped KNN query builder
     */
    @Override
    public String getWriteableName() {
        return knnQueryBuilder.getWriteableName();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
Expand Down Expand Up @@ -462,22 +461,23 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
// https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L117.
// With the asynchronous call, on first rewrite, we create a new
// vector supplier that will get populated once the asynchronous call finishes and pass this supplier in to
// create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just
// return the current unmodified query builder.
// create a new builder. Once the supplier's value gets set, we return a NeuralKNNQueryBuilder
// that wraps a KNNQueryBuilder. Otherwise, we just return the current unmodified query builder.
if (vectorSupplier() != null) {
if (vectorSupplier().get() == null) {
return this;
}
return KNNQueryBuilder.builder()

return NeuralKNNQueryBuilder.builder()
.fieldName(fieldName())
.vector(vectorSupplier.get())
.k(k())
.filter(filter())
.maxDistance(maxDistance)
.minScore(minScore)
.expandNested(expandNested)
.k(k)
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
.maxDistance(maxDistance())
.minScore(minScore())
.expandNested(expandNested())
.methodParameters(methodParameters())
.rescoreContext(rescoreContext())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

Expand Down Expand Up @@ -154,11 +151,10 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() {
assertNotNull(queryOnlyNeural);
assertTrue(queryOnlyNeural instanceof HybridQuery);
assertEquals(1, ((HybridQuery) queryOnlyNeural).getSubQueries().size());
assertTrue(((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next() instanceof NativeEngineKnnVectorQuery);
KNNQuery knnQuery = ((NativeEngineKnnVectorQuery) ((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next()).getKnnQuery();
assertEquals(VECTOR_FIELD_NAME, knnQuery.getField());
assertEquals(K, knnQuery.getK());
assertNotNull(knnQuery.getQueryVector());
assertTrue(((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next() instanceof NeuralKNNQuery);
Query knnQuery = ((NeuralKNNQuery) ((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next()).getKnnQuery();
assertNotNull(knnQuery);
assertTrue(knnQuery.toString(VECTOR_FIELD_NAME).contains(VECTOR_FIELD_NAME));
}

@SneakyThrows
Expand Down Expand Up @@ -203,11 +199,10 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() {
// verify knn vector query
Iterator<Query> queryIterator = ((HybridQuery) queryTwoSubQueries).getSubQueries().iterator();
Query firstQuery = queryIterator.next();
assertTrue(firstQuery instanceof NativeEngineKnnVectorQuery);
KNNQuery knnQuery = ((NativeEngineKnnVectorQuery) firstQuery).getKnnQuery();
assertEquals(VECTOR_FIELD_NAME, knnQuery.getField());
assertEquals(K, knnQuery.getK());
assertNotNull(knnQuery.getQueryVector());
assertTrue(firstQuery instanceof NeuralKNNQuery);
Query knnQuery = ((NeuralKNNQuery) firstQuery).getKnnQuery();
assertNotNull(knnQuery);
assertTrue(knnQuery.toString(VECTOR_FIELD_NAME).contains(VECTOR_FIELD_NAME));
// verify term query
Query secondQuery = queryIterator.next();
assertTrue(secondQuery instanceof TermQuery);
Expand Down Expand Up @@ -765,10 +760,10 @@ public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery(
assertEquals(2, hybridQueryBuilder.queries().size());
List<QueryBuilder> queryBuilders = hybridQueryBuilder.queries();
// verify each sub-query builder
assertTrue(queryBuilders.get(0) instanceof KNNQueryBuilder);
KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilders.get(0);
assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName());
assertEquals((int) neuralQueryBuilder.k(), knnQueryBuilder.getK());
assertTrue(queryBuilders.get(0) instanceof NeuralKNNQueryBuilder);
NeuralKNNQueryBuilder neuralKNNQueryBuilder = (NeuralKNNQueryBuilder) queryBuilders.get(0);
assertEquals(neuralQueryBuilder.fieldName(), neuralKNNQueryBuilder.fieldName());
assertEquals((int) neuralQueryBuilder.k(), neuralKNNQueryBuilder.k());
assertTrue(queryBuilders.get(1) instanceof TermQueryBuilder);
TermQueryBuilder termQueryBuilder = (TermQueryBuilder) queryBuilders.get(1);
assertEquals(termSubQuery.fieldName(), termQueryBuilder.fieldName());
Expand Down
Loading

0 comments on commit ebfc232

Please sign in to comment.