Skip to content

Commit

Permalink
Fixing the tests for when lucene fields are used.
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Aug 16, 2024
1 parent bf94d2c commit b5cbc60
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.mapper.ModelFieldMapper;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -78,9 +81,17 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
)
).fieldType(field);

KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig();
KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext()
.orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty"));
final KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig();
final KNNMethodContext knnMethodContext;
if (knnMappingConfig.getModelId().isPresent()) {
ModelMetadata modelMetadata = ModelUtil.getModelMetadata(knnMappingConfig.getModelId().get());
assert modelMetadata != null : String.format("Model ID '%s' is not " + "created.", knnMappingConfig.getModelId().get());
knnMethodContext = ModelFieldMapper.getKNNMethodContextFromModelMetadata(modelMetadata);
} else if (knnMappingConfig.getKnnMethodContext().isPresent()) {
knnMethodContext = knnMappingConfig.getKnnMethodContext().get();
} else {
throw new IllegalArgumentException("KNN method context cannot is empty and also model Id not present");
}

final KNNEngine engine = knnMethodContext.getKnnEngine();
final Map<String, Object> params = knnMethodContext.getMethodComponentContext().getParameters();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ protected void parseCreateField(ParseContext context) throws IOException {
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
if (useLuceneBasedVectorField) {
int adjustedDimension = modelMetadata.getVectorDataType() == VectorDataType.BINARY
? modelMetadata.getDimension()
: modelMetadata.getDimension() / 8;
? modelMetadata.getDimension() / 8
: modelMetadata.getDimension();
final VectorEncoding encoding = modelMetadata.getVectorDataType() == VectorDataType.FLOAT
? VectorEncoding.FLOAT32
: VectorEncoding.BYTE;
Expand All @@ -212,7 +212,7 @@ protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType());
}

private static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) {
public static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) {
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (methodComponentContext == MethodComponentContext.EMPTY) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.indices.ModelUtil;

import java.io.IOException;
import java.time.ZoneOffset;
Expand All @@ -66,6 +68,7 @@
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
Expand Down Expand Up @@ -872,6 +875,122 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT
utilMockedStatic.close();
}

@SneakyThrows
public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTypes() {
MockedStatic<KNNVectorFieldMapperUtil> utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class);
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
MockedStatic<ModelUtil> modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class);
final MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_IVF, Collections.emptyMap());

for (VectorDataType dataType : VectorDataType.values()) {
log.info("Vector Data Type is : {}", dataType);
SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT;
int dimension = dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION;
when(modelDao.getMetadata(MODEL_ID)).thenReturn(modelMetadata);
modelUtilMockedStatic.when(() -> ModelUtil.isModelCreated(modelMetadata)).thenReturn(true);
when(modelMetadata.getDimension()).thenReturn(dimension);
when(modelMetadata.getVectorDataType()).thenReturn(dataType);
when(modelMetadata.getSpaceType()).thenReturn(spaceType);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getMethodComponentContext()).thenReturn(methodComponentContext);

ParseContext.Document document = new ParseContext.Document();
ContentPath contentPath = new ContentPath();
ParseContext parseContext = mock(ParseContext.class);
when(parseContext.doc()).thenReturn(document);
when(parseContext.path()).thenReturn(contentPath);



utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true);
ModelFieldMapper modelFieldMapper = Mockito.spy(
ModelFieldMapper.createFieldMapper(
TEST_FIELD_NAME,
TEST_FIELD_NAME,
Collections.emptyMap(),
dataType,
MODEL_ID,
FieldMapper.MultiFields.empty(),
FieldMapper.CopyTo.empty(),
new Explicit<>(true, true),
false,
false,
modelDao,
CURRENT
)
);

if (dataType == VectorDataType.BINARY) {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper)
.getBytesFromContext(parseContext, TEST_DIMENSION * 8, dataType);
} else if (dataType == VectorDataType.BYTE) {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION, dataType);
} else {
doReturn(Optional.of(TEST_VECTOR)).when(modelFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION);
}

modelFieldMapper.parseCreateField(parseContext);

List<IndexableField> fields = document.getFields();
assertEquals(1, fields.size());
IndexableField field1 = fields.get(0);
if (dataType == VectorDataType.FLOAT) {
assertTrue(field1 instanceof KnnFloatVectorField);
assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.FLOAT32);
} else {
assertTrue(field1 instanceof KnnByteVectorField);
assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.BYTE);
}

assertEquals(field1.fieldType().vectorDimension(), TEST_DIMENSION);
assertEquals(
field1.fieldType().vectorSimilarityFunction(),
SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction()
);

utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false);

document = new ParseContext.Document();
contentPath = new ContentPath();
when(parseContext.doc()).thenReturn(document);
when(parseContext.path()).thenReturn(contentPath);
modelFieldMapper = Mockito.spy(
ModelFieldMapper.createFieldMapper(
TEST_FIELD_NAME,
TEST_FIELD_NAME,
Collections.emptyMap(),
dataType,
MODEL_ID,
FieldMapper.MultiFields.empty(),
FieldMapper.CopyTo.empty(),
new Explicit<>(true, true),
false,
false,
modelDao,
CURRENT
)
);

if (dataType == VectorDataType.FLOAT) {
doReturn(Optional.of(TEST_VECTOR)).when(modelFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION);
} else {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper)
.getBytesFromContext(parseContext, dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION, dataType);
}

modelFieldMapper.parseCreateField(parseContext);
fields = document.getFields();
assertEquals(1, fields.size());
field1 = fields.get(0);
assertTrue(field1 instanceof VectorField);
}
// making sure to close the static mock to ensure that for tests running on this thread are not impacted by
// this mocking
utilMockedStatic.close();
modelUtilMockedStatic.close();
}

@SneakyThrows
public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() {
// Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField
Expand Down

0 comments on commit b5cbc60

Please sign in to comment.