Skip to content

Commit

Permalink
Separate out NeuralKNNQuery and use builder pattern for NeuralKNNQuer…
Browse files Browse the repository at this point in the history
…yBuilder

Signed-off-by: Junqiu Lei <junqiu@amazon.com>
  • Loading branch information
junqiu-lei committed Feb 11, 2025
1 parent 6a1f045 commit defb108
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 185 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.query;

import lombok.Getter;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;
import java.util.Objects;

/**
* Wraps KNN Lucene query to support neural search extensions.
* Delegates core operations to the underlying KNN query.
*/
@Getter
public class NeuralKNNQuery extends Query {
private final Query knnQuery;

public NeuralKNNQuery(Query knnQuery) {
this.knnQuery = knnQuery;
}

@Override
public String toString(String field) {
return knnQuery.toString(field);
}

@Override
public void visit(QueryVisitor visitor) {
// Delegate the visitor to the underlying KNN query
knnQuery.visit(visitor);
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
// Delegate weight creation to the underlying KNN query
return knnQuery.createWeight(searcher, scoreMode, boost);
}

@Override
public Query rewrite(IndexReader reader) throws IOException {
Query rewritten = knnQuery.rewrite(reader);
if (rewritten == knnQuery) {
return this;
}
return new NeuralKNNQuery(rewritten);
}

@Override
public boolean equals(Object other) {
if (this == other) return true;
if (other == null || getClass() != other.getClass()) return false;
NeuralKNNQuery that = (NeuralKNNQuery) other;
return Objects.equals(knnQuery, that.knnQuery);
}

@Override
public int hashCode() {
return Objects.hash(knnQuery);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,8 @@
*/
package org.opensearch.neuralsearch.query;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import lombok.Getter;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.AbstractQueryBuilder;
Expand All @@ -28,91 +24,104 @@
* 1. Isolate KNN plugin API changes to a single location
* 2. Allow extension with neural-search-specific information (e.g., query text)
*/

@Getter
public class NeuralKNNQueryBuilder extends AbstractQueryBuilder<NeuralKNNQueryBuilder> {
private final KNNQueryBuilder knnQueryBuilder;

/**
* Wraps KNN Lucene query to support neural search extensions.
* Delegates core operations to the underlying KNN query.
* Creates a new builder instance.
*/
public static class NeuralKNNQuery extends Query {
private final Query knnQuery;
public static Builder builder() {
return new Builder();
}

public String fieldName() {
return knnQueryBuilder.fieldName();
}

public NeuralKNNQuery(Query knnQuery) {
this.knnQuery = knnQuery;
public int k() {
return knnQueryBuilder.getK();
}

/**
* Builder for NeuralKNNQueryBuilder.
*/
public static class Builder {
private String fieldName;
private float[] vector;
private Integer k;
private QueryBuilder filter;
private Float maxDistance;
private Float minScore;
private Boolean expandNested;
private Map<String, ?> methodParameters;
private RescoreContext rescoreContext;

private Builder() {}

public Builder fieldName(String fieldName) {
this.fieldName = fieldName;
return this;
}

@Override
public String toString(String field) {
return knnQuery.toString(field);
public Builder vector(float[] vector) {
this.vector = vector;
return this;
}

@Override
public void visit(QueryVisitor visitor) {
// Delegate the visitor to the underlying KNN query
knnQuery.visit(visitor);
public Builder k(Integer k) {
this.k = k;
return this;
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
// Delegate weight creation to the underlying KNN query
return knnQuery.createWeight(searcher, scoreMode, boost);
public Builder filter(QueryBuilder filter) {
this.filter = filter;
return this;
}

@Override
public Query rewrite(IndexReader reader) throws IOException {
Query rewritten = knnQuery.rewrite(reader);
if (rewritten == knnQuery) {
return this;
}
return new NeuralKNNQuery(rewritten);
public Builder maxDistance(Float maxDistance) {
this.maxDistance = maxDistance;
return this;
}

@Override
public boolean equals(Object other) {
if (this == other) return true;
if (other == null || getClass() != other.getClass()) return false;
NeuralKNNQuery that = (NeuralKNNQuery) other;
return Objects.equals(knnQuery, that.knnQuery);
public Builder minScore(Float minScore) {
this.minScore = minScore;
return this;
}

@Override
public int hashCode() {
return Objects.hash(knnQuery);
public Builder expandNested(Boolean expandNested) {
this.expandNested = expandNested;
return this;
}
}

/**
* Factory method to create NeuralKNNQueryBuilder with given parameters.
* Centralizes KNNQueryBuilder creation to handle breaking changes.
*/
public static NeuralKNNQueryBuilder fromKNNBuilder(
String fieldName,
float[] vector,
Integer k,
QueryBuilder filter,
Float maxDistance,
Float minScore,
Boolean expandNested,
Map<String, ?> methodParameters,
RescoreContext rescoreContext
) {
KNNQueryBuilder knnBuilder = KNNQueryBuilder.builder()
.fieldName(fieldName)
.vector(vector)
.k(k)
.filter(filter)
.maxDistance(maxDistance)
.minScore(minScore)
.expandNested(expandNested)
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
.build();
return new NeuralKNNQueryBuilder(knnBuilder);
public Builder methodParameters(Map<String, ?> methodParameters) {
this.methodParameters = methodParameters;
return this;
}

public Builder rescoreContext(RescoreContext rescoreContext) {
this.rescoreContext = rescoreContext;
return this;
}

public NeuralKNNQueryBuilder build() {
KNNQueryBuilder knnBuilder = KNNQueryBuilder.builder()
.fieldName(fieldName)
.vector(vector)
.k(k)
.filter(filter)
.maxDistance(maxDistance)
.minScore(minScore)
.expandNested(expandNested)
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
.build();
return new NeuralKNNQueryBuilder(knnBuilder);
}
}

private NeuralKNNQueryBuilder(KNNQueryBuilder knnQueryBuilder) {
super();
this.knnQueryBuilder = knnQueryBuilder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,17 +468,17 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
return this;
}

return NeuralKNNQueryBuilder.fromKNNBuilder(
fieldName(),
vectorSupplier.get(),
k(),
filter(),
maxDistance(),
minScore(),
expandNested(),
methodParameters(),
rescoreContext()
);
return NeuralKNNQueryBuilder.builder()
.fieldName(fieldName())
.vector(vectorSupplier.get())
.k(k())
.filter(filter())
.maxDistance(maxDistance())
.minScore(minScore())
.expandNested(expandNested())
.methodParameters(methodParameters())
.rescoreContext(rescoreContext())
.build();
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

Expand Down Expand Up @@ -154,11 +151,10 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() {
assertNotNull(queryOnlyNeural);
assertTrue(queryOnlyNeural instanceof HybridQuery);
assertEquals(1, ((HybridQuery) queryOnlyNeural).getSubQueries().size());
assertTrue(((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next() instanceof NativeEngineKnnVectorQuery);
KNNQuery knnQuery = ((NativeEngineKnnVectorQuery) ((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next()).getKnnQuery();
assertEquals(VECTOR_FIELD_NAME, knnQuery.getField());
assertEquals(K, knnQuery.getK());
assertNotNull(knnQuery.getQueryVector());
assertTrue(((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next() instanceof NeuralKNNQuery);
Query knnQuery = ((NeuralKNNQuery) ((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next()).getKnnQuery();
assertNotNull(knnQuery);
assertTrue(knnQuery.toString(VECTOR_FIELD_NAME).contains(VECTOR_FIELD_NAME));
}

@SneakyThrows
Expand Down Expand Up @@ -203,11 +199,10 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() {
// verify knn vector query
Iterator<Query> queryIterator = ((HybridQuery) queryTwoSubQueries).getSubQueries().iterator();
Query firstQuery = queryIterator.next();
assertTrue(firstQuery instanceof NativeEngineKnnVectorQuery);
KNNQuery knnQuery = ((NativeEngineKnnVectorQuery) firstQuery).getKnnQuery();
assertEquals(VECTOR_FIELD_NAME, knnQuery.getField());
assertEquals(K, knnQuery.getK());
assertNotNull(knnQuery.getQueryVector());
assertTrue(firstQuery instanceof NeuralKNNQuery);
Query knnQuery = ((NeuralKNNQuery) firstQuery).getKnnQuery();
assertNotNull(knnQuery);
assertTrue(knnQuery.toString(VECTOR_FIELD_NAME).contains(VECTOR_FIELD_NAME));
// verify term query
Query secondQuery = queryIterator.next();
assertTrue(secondQuery instanceof TermQuery);
Expand Down Expand Up @@ -765,10 +760,10 @@ public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery(
assertEquals(2, hybridQueryBuilder.queries().size());
List<QueryBuilder> queryBuilders = hybridQueryBuilder.queries();
// verify each sub-query builder
assertTrue(queryBuilders.get(0) instanceof KNNQueryBuilder);
KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilders.get(0);
assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName());
assertEquals((int) neuralQueryBuilder.k(), knnQueryBuilder.getK());
assertTrue(queryBuilders.get(0) instanceof NeuralKNNQueryBuilder);
NeuralKNNQueryBuilder neuralKNNQueryBuilder = (NeuralKNNQueryBuilder) queryBuilders.get(0);
assertEquals(neuralQueryBuilder.fieldName(), neuralKNNQueryBuilder.fieldName());
assertEquals((int) neuralQueryBuilder.k(), neuralKNNQueryBuilder.k());
assertTrue(queryBuilders.get(1) instanceof TermQueryBuilder);
TermQueryBuilder termQueryBuilder = (TermQueryBuilder) queryBuilders.get(1);
assertEquals(termSubQuery.fieldName(), termQueryBuilder.fieldName());
Expand Down
Loading

0 comments on commit defb108

Please sign in to comment.