Skip to content

Commit

Permalink
Initial commit for enabling the float vector values for vector search.
Browse files Browse the repository at this point in the history
Known gaps (not working or not yet tested):
1. Filter query not working
2. Training index creation not tested.
  • Loading branch information
navneet1v committed Jun 4, 2024
1 parent 623b610 commit bf240e3
Show file tree
Hide file tree
Showing 16 changed files with 693 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index;

import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
Expand Down Expand Up @@ -57,8 +58,10 @@ public ScriptDocValues<float[]> getScriptValues() {
default:
throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding());
}
} else {
} else if (fieldInfo.getDocValuesType() == DocValuesType.BINARY) {
values = DocValues.getBinary(reader, fieldName);
} else {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
}
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines99KnnVectorsFormat;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelCache;

import java.util.Map;
import java.util.Optional;
Expand All @@ -25,7 +28,7 @@
@Log4j2
public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {

private final Optional<MapperService> mapperService;
private final Optional<MapperService> optionalMapperService;
private final int defaultMaxConnections;
private final int defaultBeamWidth;
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
Expand All @@ -42,12 +45,22 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
);
return defaultFormatSupplier.get();
}
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
if (optionalMapperService.isEmpty()) {
throw new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponentContext().getParameters();
);
}
final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) optionalMapperService
.get()
.fieldType(field);

final KNNEngine knnEngine = getKNNEngine(mappedFieldType);
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) {
log.debug("Native Engine present hence using NativeEnginesKNNVectorsFormat. Engine found: {}", knnEngine);
return new NativeEngines99KnnVectorsFormat();
}

final Map<String, Object> params = mappedFieldType.getKnnMethodContext().getMethodComponentContext().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
log.debug(
Expand All @@ -65,7 +78,8 @@ public int getMaxDimensions(String fieldName) {
}

private boolean isKnnVectorFieldType(final String field) {
return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
return optionalMapperService.isPresent()
&& optionalMapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
}

private int getMaxConnections(final Map<String, Object> params) {
Expand All @@ -81,4 +95,18 @@ private int getBeamWidth(final Map<String, Object> params) {
}
return defaultBeamWidth;
}

/**
 * Determines which {@link KNNEngine} backs the given k-NN vector field type.
 *
 * <p>Resolution order:
 * <ol>
 *   <li>If the field type carries a model id, the engine is read from that model's metadata
 *       (the model is fetched through {@link ModelCache}).</li>
 *   <li>Otherwise, if the field type has no method context, {@link KNNEngine#DEFAULT} is used.</li>
 *   <li>Otherwise the engine declared on the method context is used.</li>
 * </ol>
 *
 * @param mappedFieldType the k-NN vector field type to inspect
 * @return the engine resolved by the rules above
 */
private KNNEngine getKNNEngine(final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType) {
    final String modelId = mappedFieldType.getModelId();

    // Model-based fields: the engine is recorded on the model's metadata.
    if (modelId != null) {
        return ModelCache.getInstance().get(modelId).getModelMetadata().getKnnEngine();
    }

    // Method-based fields: fall back to the default engine when no method context is present.
    return mappedFieldType.getKnnMethodContext() == null
        ? KNNEngine.DEFAULT
        : mappedFieldType.getKnnMethodContext().getKnnEngine();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.knn.index.codec.KNN80Codec;

import com.google.common.collect.ImmutableMap;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.store.ChecksumIndexInput;
import org.opensearch.common.StopWatch;
Expand Down Expand Up @@ -61,7 +60,7 @@
* This class writes the KNN docvalues to the segments
*/
@Log4j2
class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {
public class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {

private final Logger logger = LogManager.getLogger(KNN80DocValuesConsumer.class);

Expand Down Expand Up @@ -90,22 +89,14 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th
}

private boolean isKNNBinaryFieldRequired(FieldInfo field) {
final KNNEngine knnEngine = getKNNEngine(field);
final KNNEngine knnEngine = KNNCodecUtil.getKNNEngine(field);
log.debug(String.format("Read engine [%s] for field [%s]", knnEngine.getName(), field.getName()));
return field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)
// This value will not be set: field.getVectorDimension()
return field.getVectorDimension() <= 0
&& field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)
&& KNNEngine.getEnginesThatCreateCustomSegmentFiles().stream().anyMatch(engine -> engine == knnEngine);
}

private KNNEngine getKNNEngine(@NonNull FieldInfo field) {
final String modelId = field.attributes().get(MODEL_ID);
if (modelId != null) {
var model = ModelCache.getInstance().get(modelId);
return model.getModelMetadata().getKnnEngine();
}
final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
return KNNEngine.getEngine(engineName);
}

public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh)
throws IOException {
// Get values to be indexed
Expand All @@ -123,7 +114,18 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
}
// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
final KNNEngine knnEngine = getKNNEngine(field);
if (isMerge) {
recordMergeStats(pair.docs.length, arraySize);
}

if (isRefresh) {
recordRefreshStats();
}
createNativeIndex(state, field, pair);
}

public static void createNativeIndex(SegmentWriteState state, FieldInfo field, KNNCodecUtil.Pair pair) throws IOException {
final KNNEngine knnEngine = KNNCodecUtil.getKNNEngine(field);
final String engineFileName = buildEngineFileName(
state.segmentInfo.name,
knnEngine.getVersion(),
Expand All @@ -147,20 +149,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
}

if (isMerge) {
recordMergeStats(pair.docs.length, arraySize);
}

if (isRefresh) {
recordRefreshStats();
}

// This is a bit of a hack. We have to create an output here and then immediately close it to ensure that
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
// not be marked as added to the directory.
state.directory.createOutput(engineFileName, state.context).close();
indexCreator.createIndex();
writeFooter(indexPath, engineFileName);
writeFooter(state, indexPath, engineFileName);
}

private void recordMergeStats(int length, long arraySize) {
Expand All @@ -176,7 +170,7 @@ private void recordRefreshStats() {
KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment();
}

private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
private static void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
Map<String, Object> parameters = ImmutableMap.of(
KNNConstants.INDEX_THREAD_QTY,
KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)
Expand All @@ -195,7 +189,7 @@ private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KN
});
}

private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath)
private static void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath)
throws IOException {
Map<String, Object> parameters = new HashMap<>();
Map<String, String> fieldAttributes = fieldInfo.attributes();
Expand Down Expand Up @@ -295,7 +289,7 @@ private interface NativeIndexCreator {
void createIndex() throws IOException;
}

private void writeFooter(String indexPath, String engineFileName) throws IOException {
private static void writeFooter(SegmentWriteState state, String indexPath, String engineFileName) throws IOException {
// Opens the engine file that was created and appends a footer to it. The footer consists of
// 1. A Footer magic number (int - 4 bytes)
// 2. A checksum algorithm id (int - 4 bytes)
Expand Down Expand Up @@ -325,7 +319,7 @@ private void writeFooter(String indexPath, String engineFileName) throws IOExcep
os.close();
}

private boolean isChecksumValid(long value) {
private static boolean isChecksumValid(long value) {
// Check pulled from
// https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647
return (value & CRC32_CHECKSUM_SANITY) != 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

import java.io.IOException;

/**
 * A {@link KnnVectorsFormat} whose writers and readers wrap those produced by a
 * {@link Lucene99FlatVectorsFormat}.
 *
 * <p>The flat vectors format instance is static, so it is shared by every instance of this class.
 */
public class NativeEngines99KnnVectorsFormat extends KnnVectorsFormat {

    /** Name handed to the {@link KnnVectorsFormat} constructor and echoed by {@link #toString()}. */
    private static final String FORMAT_NAME = "NativeEngines99KnnVectorsFormat";

    /** The format for storing, reading and merging vectors on disk. */
    private static final FlatVectorsFormat FLAT_VECTORS_FORMAT = new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());

    /**
     * Sole constructor; passes the fixed format name to the superclass.
     */
    public NativeEngines99KnnVectorsFormat() {
        super(FORMAT_NAME);
    }

    /**
     * Returns a {@link KnnVectorsWriter} to write the vectors to the index.
     *
     * @param state the write state of the segment being written; it is passed both to the flat
     *              vectors format and to the returned writer
     * @return a {@code NativeEnginesKNNVectorsWriter} wrapping the flat vectors writer for {@code state}
     * @throws IOException if creating either writer fails
     */
    @Override
    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        return new NativeEnginesKNNVectorsWriter(state, FLAT_VECTORS_FORMAT.fieldsWriter(state));
    }

    /**
     * Returns a {@link KnnVectorsReader} to read the vectors from the index.
     *
     * @param state the read state of the segment being opened; it is passed both to the flat
     *              vectors format and to the returned reader
     * @return a {@code NativeEnginesKNNVectorsReader} wrapping the flat vectors reader for {@code state}
     * @throws IOException if creating either reader fails
     */
    @Override
    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new NativeEnginesKNNVectorsReader(state, FLAT_VECTORS_FORMAT.fieldsReader(state));
    }

    /**
     * Describes this format by its name and the flat vectors format it delegates to.
     *
     * @return a human-readable description of this format
     */
    @Override
    public String toString() {
        return FORMAT_NAME + "(name=" + FORMAT_NAME + ", flatVectorsFormat=" + FLAT_VECTORS_FORMAT + ")";
    }

}
Loading

0 comments on commit bf240e3

Please sign in to comment.