Skip to content

Commit

Permalink
Added spaceType as a top level parameter
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Sep 5, 2024
1 parent 697a51c commit b4a6c9c
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public class KNNConstants {
public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD;
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;
public static final String MINIMAL_MODE_AND_COMPRESSION_FEATURE = "mode_and_compression_feature";
public static final String TOP_LEVEL_SPACE_TYPE_FEATURE = "top_level_space_type_feature";

public static final String RADIAL_SEARCH_KEY = "radial_search";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
CompressionLevel.NAMES_ARRAY
).acceptsNull();

// Optional top-level space type mapping parameter, registered under KNNConstants.SPACE_TYPE.
// - Defaults to SpaceType.UNDEFINED so that users who do not want to set a space type can simply omit it.
// - The raw mapping value is cast to String and converted with SpaceType.getSpace.
// - The initializer reads the value back from the mapper's original mapping parameters.
// NOTE(review): the `false` argument is presumably the "updateable" flag of Parameter — confirm.
protected final Parameter<SpaceType> topLevelSpaceType = new Parameter<>(
KNNConstants.SPACE_TYPE,
false,
() -> SpaceType.UNDEFINED,
(n, c, o) -> SpaceType.getSpace((String) o),
m -> toType(m).originalMappingParameters.getTopLevelSpaceType()
);

protected final Parameter<Map<String, String>> meta = Parameter.metaParam();

protected ModelDao modelDao;
Expand All @@ -187,7 +197,8 @@ public Builder(

/**
 * Lists the mapping parameters exposed by this builder, including the top-level space type.
 *
 * @return the builder's parameters as a fixed-size list
 */
@Override
protected List<Parameter<?>> getParameters() {
    // Single return only: the earlier list without topLevelSpaceType was superseded by this one.
    return Arrays.asList(
        stored,
        hasDocValues,
        dimension,
        vectorDataType,
        meta,
        knnMethodContext,
        modelId,
        mode,
        compressionLevel,
        topLevelSpaceType
    );
}

protected Explicit<Boolean> ignoreMalformed(BuilderContext context) {
Expand Down Expand Up @@ -346,13 +357,27 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
validateFromModel(builder);
} else {
validateMode(builder);
validateSpaceType(builder);
resolveKNNMethodComponents(builder, parserContext);
validateFromKNNMethod(builder);
}

return builder;
}

/**
 * Validates that the top-level space type does not conflict with the space type declared inside "method".
 *
 * Nothing is checked when no method context was supplied. Otherwise the mapping is rejected only when
 * both space types are explicitly set (neither is {@link SpaceType#UNDEFINED}) and they differ. Setting
 * just one of the two, or setting both to the same value, is accepted, as the error message states.
 *
 * @param builder builder holding the parsed mapping parameters
 * @throws MapperParsingException if both space types are defined and differ
 */
private void validateSpaceType(KNNVectorFieldMapper.Builder builder) {
    final KNNMethodContext knnMethodContext = builder.originalParameters.getKnnMethodContext();
    if (knnMethodContext == null) {
        // No "method" block, so there is nothing the top-level space type could conflict with.
        return;
    }

    final SpaceType knnMethodContextSpaceType = knnMethodContext.getSpaceType();
    final SpaceType topLevelSpaceType = builder.topLevelSpaceType.get();

    // UNDEFINED on either side means "not set by the user"; only two explicit values can conflict.
    // NOTE(review): when only the top-level value is set, the caller is expected to apply it to the
    // resolved method context — confirm that happens during resolution.
    final boolean bothDefined = topLevelSpaceType != SpaceType.UNDEFINED && knnMethodContextSpaceType != SpaceType.UNDEFINED;
    if (bothDefined && topLevelSpaceType != knnMethodContextSpaceType) {
        throw new MapperParsingException(
            "Space type in \"method\" and top level space type should be same or one of them should be defined"
        );
    }
}

private void validateMode(KNNVectorFieldMapper.Builder builder) {
boolean isKNNMethodContextConfigured = builder.originalParameters.getKnnMethodContext() != null;
boolean isModeConfigured = builder.mode.isConfigured() || builder.compressionLevel.isConfigured();
Expand Down Expand Up @@ -386,6 +411,12 @@ private void validateFromModel(KNNVectorFieldMapper.Builder builder) {
if (builder.dimension.getValue() == UNSET_MODEL_DIMENSION_IDENTIFIER && builder.modelId.get() == null) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Dimension value missing for vector: %s", builder.name()));
}
if(builder.modelId.get() != null && builder.topLevelSpaceType.get() != SpaceType.UNDEFINED) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "TopLevel Space type and model can not be both specified in the mapping: %s", name)
);
}

validateCompressionAndModeNotSet(builder, builder.name(), "model");
}

Expand Down Expand Up @@ -442,15 +473,22 @@ private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, Pa
createKNNMethodContextFromLegacy(parserContext.getSettings(), parserContext.indexVersionCreated())
);
} else if (builder.mode.isConfigured() || builder.compressionLevel.isConfigured()) {
// We don't need to resolve the space type here: whatever top-level value we currently have is passed
// down while resolving the KNNMethodContext for the mode and compression, and the correct space type
// is set afterwards, when the space type itself is resolved.
builder.originalParameters.setResolvedKnnMethodContext(
ModeBasedResolver.INSTANCE.resolveKNNMethodContext(
builder.knnMethodConfigContext.getMode(),
builder.knnMethodConfigContext.getCompressionLevel(),
false
false,
builder.originalParameters.getTopLevelSpaceType()
)
);
}
setDefaultSpaceType(builder.originalParameters.getResolvedKnnMethodContext(), builder.originalParameters.getVectorDataType());
// this function should now correct the space type for the above resolved context too, if spaceType was
// not provided.
setDefaultSpaceType(builder.originalParameters.getResolvedKnnMethodContext(),
builder.originalParameters.getVectorDataType());
}

private boolean isKNNDisabled(Settings settings) {
Expand All @@ -459,8 +497,10 @@ private boolean isKNNDisabled(Settings settings) {
}

private void setDefaultSpaceType(final KNNMethodContext knnMethodContext, final VectorDataType vectorDataType) {
// KNNMethodContext should never be null at this point: the only case in which it could be null is the
// flat mapper, which is already handled.
if (knnMethodContext == null) {
return;
throw new IllegalArgumentException("KNNMethodContext cannot be null");
}

if (SpaceType.UNDEFINED == knnMethodContext.getSpaceType()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,14 @@ private ModeBasedResolver() {}
* @param requiresTraining whether config requires training
* @return {@link KNNMethodContext}
*/
public KNNMethodContext resolveKNNMethodContext(Mode mode, CompressionLevel compressionLevel, boolean requiresTraining) {
public KNNMethodContext resolveKNNMethodContext(Mode mode, CompressionLevel compressionLevel, boolean requiresTraining, SpaceType spaceType) {
if (requiresTraining) {
return resolveWithTraining(mode, compressionLevel);
return resolveWithTraining(mode, compressionLevel, spaceType);
}

return resolveWithoutTraining(mode, compressionLevel);
return resolveWithoutTraining(mode, compressionLevel, spaceType);
}

private KNNMethodContext resolveWithoutTraining(Mode mode, CompressionLevel compressionLevel) {
private KNNMethodContext resolveWithoutTraining(Mode mode, CompressionLevel compressionLevel, final SpaceType spaceType) {
CompressionLevel resolvedCompressionLevel = resolveCompressionLevel(mode, compressionLevel);
MethodComponentContext encoderContext = resolveEncoder(resolvedCompressionLevel);

Expand All @@ -72,7 +71,7 @@ private KNNMethodContext resolveWithoutTraining(Mode mode, CompressionLevel comp
if (encoderContext != null) {
return new KNNMethodContext(
knnEngine,
SpaceType.DEFAULT,
spaceType,
new MethodComponentContext(
METHOD_HNSW,
Map.of(
Expand All @@ -92,7 +91,7 @@ private KNNMethodContext resolveWithoutTraining(Mode mode, CompressionLevel comp
if (knnEngine == KNNEngine.FAISS) {
return new KNNMethodContext(
knnEngine,
SpaceType.DEFAULT,
spaceType,
new MethodComponentContext(
METHOD_HNSW,
Map.of(
Expand All @@ -109,7 +108,7 @@ private KNNMethodContext resolveWithoutTraining(Mode mode, CompressionLevel comp

return new KNNMethodContext(
knnEngine,
SpaceType.DEFAULT,
spaceType,
new MethodComponentContext(
METHOD_HNSW,
Map.of(
Expand All @@ -122,13 +121,13 @@ private KNNMethodContext resolveWithoutTraining(Mode mode, CompressionLevel comp
);
}

private KNNMethodContext resolveWithTraining(Mode mode, CompressionLevel compressionLevel) {
private KNNMethodContext resolveWithTraining(Mode mode, CompressionLevel compressionLevel, SpaceType spaceType) {
CompressionLevel resolvedCompressionLevel = resolveCompressionLevel(mode, compressionLevel);
MethodComponentContext encoderContext = resolveEncoder(resolvedCompressionLevel);
if (encoderContext != null) {
return new KNNMethodContext(
KNNEngine.FAISS,
SpaceType.DEFAULT,
spaceType,
new MethodComponentContext(
METHOD_IVF,
Map.of(
Expand All @@ -145,7 +144,7 @@ private KNNMethodContext resolveWithTraining(Mode mode, CompressionLevel compres

return new KNNMethodContext(
KNNEngine.FAISS,
SpaceType.DEFAULT,
spaceType,
new MethodComponentContext(
METHOD_IVF,
Map.of(METHOD_PARAMETER_NLIST, METHOD_PARAMETER_NLIST_DEFAULT, METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NPROBES_DEFAULT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.opensearch.core.common.Strings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNMethodContext;

Expand Down Expand Up @@ -42,6 +43,7 @@ public final class OriginalMappingParameters {
private final String mode;
private final String compressionLevel;
private final String modelId;
private final SpaceType topLevelSpaceType;

/**
* Initialize the parameters from the builder
Expand All @@ -56,6 +58,7 @@ public OriginalMappingParameters(KNNVectorFieldMapper.Builder builder) {
this.mode = builder.mode.get();
this.compressionLevel = builder.compressionLevel.get();
this.modelId = builder.modelId.get();
this.topLevelSpaceType = builder.topLevelSpaceType.get();
}

/**
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class IndexUtil {
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE = Version.V_2_16_0;
private static final Version MINIMAL_RESCORE_FEATURE = Version.V_2_17_0;
private static final Version MINIMAL_MODE_AND_COMPRESSION_FEATURE = Version.V_2_17_0;
private static final Version MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE = Version.V_2_17_0;
// public so neural search can access it
public static final Map<String, Version> minimalRequiredVersionMap = initializeMinimalRequiredVersionMap();
public static final Set<VectorDataType> VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS = Set.of(VectorDataType.BINARY, VectorDataType.BYTE);
Expand Down Expand Up @@ -390,6 +391,7 @@ private static Map<String, Version> initializeMinimalRequiredVersionMap() {
put(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE);
put(RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE);
put(KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE, MINIMAL_MODE_AND_COMPRESSION_FEATURE);
put(KNNConstants.TOP_LEVEL_SPACE_TYPE_FEATURE, MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
int dimension = DEFAULT_NOT_SET_INT_VALUE;
int maximumVectorCount = DEFAULT_NOT_SET_INT_VALUE;
int searchSize = DEFAULT_NOT_SET_INT_VALUE;
SpaceType topLevelSpaceType = SpaceType.UNDEFINED;

String compressionLevel = null;
String mode = null;
Expand All @@ -109,9 +110,6 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
trainingField = parser.textOrNull();
} else if (KNN_METHOD.equals(fieldName) && ensureNotSet(fieldName, knnMethodContext)) {
knnMethodContext = KNNMethodContext.parse(parser.map());
if (SpaceType.UNDEFINED == knnMethodContext.getSpaceType()) {
knnMethodContext.setSpaceType(SpaceType.L2);
}
} else if (DIMENSION.equals(fieldName) && ensureNotSet(fieldName, dimension)) {
dimension = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (MAX_VECTOR_COUNT_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, maximumVectorCount)) {
Expand All @@ -127,6 +125,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
mode = parser.text();
} else if (KNNConstants.COMPRESSION_LEVEL_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, compressionLevel)) {
compressionLevel = parser.text();
} else if(KNNConstants.SPACE_TYPE.equals(fieldName) && ensureSpaceTypeNotSet(topLevelSpaceType)) {
topLevelSpaceType = SpaceType.getSpace(parser.text());
} else {
throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter.");
}
Expand Down Expand Up @@ -159,7 +159,11 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
if (vectorDataType == DEFAULT_NOT_SET_OBJECT_VALUE) {
vectorDataType = VectorDataType.DEFAULT;
}

resolveSpaceTypeAndSetInKNNMethodContext(topLevelSpaceType, knnMethodContext);
// If no KNNMethodContext was provided and the top-level space type is still undefined, fall back to the default space type.
if(knnMethodContext == null && topLevelSpaceType == SpaceType.UNDEFINED) {
topLevelSpaceType = SpaceType.DEFAULT;
}
TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
Expand All @@ -170,7 +174,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
description,
vectorDataType,
Mode.fromName(mode),
CompressionLevel.fromName(compressionLevel)
CompressionLevel.fromName(compressionLevel),
topLevelSpaceType
);

if (maximumVectorCount != DEFAULT_NOT_SET_INT_VALUE) {
Expand Down Expand Up @@ -204,6 +209,33 @@ private void ensureMutualExclusion(String fieldNameA, Object valueA, String fiel
}
}

/**
 * Guards against the space type being supplied more than once while parsing the request.
 *
 * @param spaceType the space type parsed so far; {@link SpaceType#UNDEFINED} means none has been seen yet
 * @return {@code true} when no space type has been set yet
 * @throws IllegalArgumentException if a space type was already set
 */
private boolean ensureSpaceTypeNotSet(SpaceType spaceType) {
    if (SpaceType.UNDEFINED == spaceType) {
        return true;
    }
    throw new IllegalArgumentException("Unable to parse SpaceType as it is duplicated.");
}

/**
 * Reconciles the top-level space type with the one carried by the method context.
 *
 * Does nothing when there is no method context. When the method context has no space type of its own,
 * it receives the top-level value, or {@link SpaceType#DEFAULT} if that is undefined as well. When the
 * method context already has a space type, a defined top-level value is rejected.
 *
 * @param topLevelSpaceType space type given at the top level of the request, or {@link SpaceType#UNDEFINED}
 * @param knnMethodContext  parsed method context; may be {@code null}
 * @throws IllegalArgumentException if the space type is set both at the top level and in the method
 */
private void resolveSpaceTypeAndSetInKNNMethodContext(SpaceType topLevelSpaceType, KNNMethodContext knnMethodContext) {
    // The method context is optional; with none present there is nothing to update.
    if (knnMethodContext == null) {
        return;
    }

    final boolean topLevelDefined = topLevelSpaceType != SpaceType.UNDEFINED;

    if (knnMethodContext.getSpaceType() != SpaceType.UNDEFINED) {
        // The method already names a space type, so a second one at the top level is an error.
        if (topLevelDefined) {
            throw new IllegalArgumentException("Top Level spaceType and space type in method both are set. Set space type at 1 place.");
        }
        return;
    }

    // The method has no space type: use the top-level one, or the default when neither was given.
    final SpaceType effectiveSpaceType = topLevelDefined ? topLevelSpaceType : SpaceType.DEFAULT;
    knnMethodContext.setSpaceType(effectiveSpaceType);
}

private void ensureIfSetThenEquals(
String fieldNameA,
Object valueA,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;
Expand Down Expand Up @@ -56,6 +57,21 @@ public class TrainingModelRequest extends ActionRequest {
private final Mode mode;
private final CompressionLevel compressionLevel;

/**
 * Convenience constructor for callers that do not supply a top-level space type.
 *
 * Delegates to the full constructor with {@link SpaceType#DEFAULT} rather than {@code null}, so that a
 * method context resolved from mode/compression never ends up with a null space type.
 *
 * @param modelId          identifier of the model to train
 * @param knnMethodContext method context, or {@code null} to resolve one from mode/compression
 * @param dimension        vector dimension
 * @param trainingIndex    index holding the training data
 * @param trainingField    field holding the training vectors
 * @param preferredNodeId  preferred node for training
 * @param description      model description
 * @param vectorDataType   vector data type
 * @param mode             configured mode
 * @param compressionLevel configured compression level
 */
public TrainingModelRequest(
    String modelId,
    KNNMethodContext knnMethodContext,
    int dimension,
    String trainingIndex,
    String trainingField,
    String preferredNodeId,
    String description,
    VectorDataType vectorDataType,
    Mode mode,
    CompressionLevel compressionLevel
) {
    this(
        modelId,
        knnMethodContext,
        dimension,
        trainingIndex,
        trainingField,
        preferredNodeId,
        description,
        vectorDataType,
        mode,
        compressionLevel,
        SpaceType.DEFAULT
    );
}

/**
* Constructor.
*
Expand All @@ -77,7 +93,8 @@ public TrainingModelRequest(
String description,
VectorDataType vectorDataType,
Mode mode,
CompressionLevel compressionLevel
CompressionLevel compressionLevel,
SpaceType spaceType
) {
super();
this.modelId = modelId;
Expand Down Expand Up @@ -107,7 +124,7 @@ public TrainingModelRequest(
.build();

if (knnMethodContext == null && (Mode.isConfigured(mode) || CompressionLevel.isConfigured(compressionLevel))) {
this.knnMethodContext = ModeBasedResolver.INSTANCE.resolveKNNMethodContext(mode, compressionLevel, true);
this.knnMethodContext = ModeBasedResolver.INSTANCE.resolveKNNMethodContext(mode, compressionLevel, true, spaceType);
} else {
this.knnMethodContext = knnMethodContext;
}
Expand Down Expand Up @@ -144,6 +161,13 @@ public TrainingModelRequest(StreamInput in) throws IOException {
this.compressionLevel = CompressionLevel.NOT_CONFIGURED;
}

// SpaceType topLevelSpaceType = SpaceType.DEFAULT;
//
// if(IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.TOP_LEVEL_SPACE_TYPE_FEATURE)) {
// topLevelSpaceType = SpaceType.getSpace(in.readOptionalString());
// }


this.knnMethodConfigContext = KNNMethodConfigContext.builder()
.vectorDataType(vectorDataType)
.dimension(dimension)
Expand Down

0 comments on commit b4a6c9c

Please sign in to comment.