Skip to content

Commit

Permalink
Fixing some of the things from merge side for KNNVectorValues
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Aug 29, 2024
1 parent 3e19b9d commit ab1ee62
Showing 1 changed file with 14 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.lucene.index.Sorter;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
Expand Down Expand Up @@ -72,6 +73,8 @@ public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOExc
*/
@Override
public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
StopWatch stopWatch = new StopWatch();
stopWatch.start();
flatVectorsWriter.flush(maxDoc, sortMap);
for (final NativeEngineFieldVectorsWriter<?> field : fields) {
trainAndIndex(
Expand All @@ -81,15 +84,23 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
field
);
}
stopWatch.stop();
long time_in_millis = stopWatch.totalTime().millis();
log.warn("Refresh operation complete in {} ms", time_in_millis);
}

/**
 * Merges a single vector field and rebuilds its native index for the merged segment.
 *
 * <p>The field is first merged through {@code flatVectorsWriter}, which is then finished. After that,
 * {@code trainAndIndex} is invoked with the merge-specific vector values retriever and
 * {@code NativeIndexWriter::mergeIndex}. The time taken by the whole call is measured with a
 * {@link StopWatch} and logged.
 *
 * @param fieldInfo  the field being merged
 * @param mergeState the merge state supplied by the caller
 * @throws IOException propagated from the flat vectors writer or from {@code trainAndIndex}
 */
@Override
public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException {
    final StopWatch stopWatch = new StopWatch();
    stopWatch.start();

    // This will ensure that we are merging the FlatIndex during force merge.
    flatVectorsWriter.mergeOneField(fieldInfo, mergeState);
    // NOTE(review): finish() is invoked here on every mergeOneField call. If one writer merges more than
    // one field, the flat writer is finished repeatedly -- confirm the flat writer tolerates that.
    flatVectorsWriter.finish();

    // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs
    trainAndIndex(fieldInfo, this::getKNNVectorValuesForMerge, NativeIndexWriter::mergeIndex, mergeState);

    stopWatch.stop();
    final long timeInMillis = stopWatch.totalTime().millis();
    // Timing is diagnostic information, not a warning condition, so it is logged at DEBUG.
    log.debug("Merge operation complete in {} ms", timeInMillis);
}

/**
Expand All @@ -101,7 +112,6 @@ public void finish() throws IOException {
throw new IllegalStateException("NativeEnginesKNNVectorsWriter is already finished");
}
finished = true;
flatVectorsWriter.finish();
}

/**
Expand Down Expand Up @@ -217,17 +227,15 @@ private <T, C> void trainAndIndex(
final C VectorProcessingContext
) throws IOException {
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
KNNVectorValues<T> knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;
if (quantizationParams != null) {
KNNVectorValues<T> knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
quantizationState = quantizationService.train(quantizationParams, knnVectorValues);
}
NativeIndexWriter writer = (quantizationParams != null)
? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)
: NativeIndexWriter.getWriter(fieldInfo, segmentWriteState);

knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
indexOperation.buildAndWrite(writer, knnVectorValues);
indexOperation.buildAndWrite(writer, vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext));
}
}

0 comments on commit ab1ee62

Please sign in to comment.