Skip to content

Commit

Permalink
Fixed the query rewrite logic to happen during createWeight
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Aug 26, 2024
1 parent bbaaaf9 commit f4a5415
Showing 1 changed file with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.opensearch.knn.index.query.KNNQuery;
Expand Down Expand Up @@ -44,8 +45,7 @@ public class NativeEngineKnnVectorQuery extends Query {

private final KNNQuery knnQuery;

@Override
public Query rewrite(final IndexSearcher indexSearcher) throws IOException {
public Weight createWeight(final IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException {
final IndexReader reader = indexSearcher.getIndexReader();
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, ScoreMode.COMPLETE, 1);
List<LeafReaderContext> leafReaderContexts = reader.leaves();
Expand All @@ -69,9 +69,9 @@ public Query rewrite(final IndexSearcher indexSearcher) throws IOException {

TopDocs topK = TopDocs.merge(knnQuery.getK(), topDocs);
if (topK.scoreDocs.length == 0) {
return new MatchNoDocsQuery();
return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost);
}
return createRewrittenQuery(reader, topK);
return createRewrittenQueryWeight(reader, topK, scoreMode, boost, indexSearcher);
}

private List<Map<Integer, Float>> doSearch(
Expand Down Expand Up @@ -106,7 +106,7 @@ private List<Map<Integer, Float>> doRescore(
return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks);
}

private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
private Weight createRewrittenQueryWeight(IndexReader reader, TopDocs topK, ScoreMode scoreMode, float boost, final IndexSearcher indexSearcher) {
int len = topK.scoreDocs.length;
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
int[] docs = new int[len];
Expand All @@ -116,7 +116,7 @@ private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
scores[i] = topK.scoreDocs[i].score;
}
int[] segmentStarts = findSegmentStarts(reader, docs);
return new DocAndScoreQuery(knnQuery.getK(), docs, scores, segmentStarts, reader.getContext().id());
return new DocAndScoreQuery(knnQuery.getK(), docs, scores, segmentStarts, reader.getContext().id()).createWeight(indexSearcher, scoreMode, boost);
}

static int[] findSegmentStarts(IndexReader reader, int[] docs) {
Expand Down

0 comments on commit f4a5415

Please sign in to comment.