From b5cbc60194880bfe04a2c0a98ece6c6b9e248e36 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Thu, 15 Aug 2024 22:45:03 -0700 Subject: [PATCH] Fixing the tests for when lucene fields are used. Signed-off-by: Navneet Verma --- .../codec/BasePerFieldKnnVectorsFormat.java | 17 ++- .../knn/index/mapper/ModelFieldMapper.java | 6 +- .../mapper/KNNVectorFieldMapperTests.java | 119 ++++++++++++++++++ 3 files changed, 136 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 8beced605..97ce20666 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -19,6 +19,9 @@ import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.knn.index.mapper.ModelFieldMapper; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; import java.util.Map; import java.util.Optional; @@ -78,9 +81,17 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { ) ).fieldType(field); - KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); - KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() - .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); + final KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); + final KNNMethodContext knnMethodContext; + if (knnMappingConfig.getModelId().isPresent()) { + ModelMetadata modelMetadata = ModelUtil.getModelMetadata(knnMappingConfig.getModelId().get()); + assert modelMetadata != null : String.format("Model ID '%s' is not " + "created.", knnMappingConfig.getModelId().get()); + knnMethodContext = ModelFieldMapper.getKNNMethodContextFromModelMetadata(modelMetadata); + } else if (knnMappingConfig.getKnnMethodContext().isPresent()) { + knnMethodContext = knnMappingConfig.getKnnMethodContext().get(); + } else { + throw new IllegalArgumentException("KNN method context cannot is empty and also model Id not present"); + } final KNNEngine engine = knnMethodContext.getKnnEngine(); final Map params = knnMethodContext.getMethodComponentContext().getParameters(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 954d6addf..1f66feb37 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -196,8 +196,8 @@ protected void parseCreateField(ParseContext context) throws IOException { ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); if (useLuceneBasedVectorField) { int adjustedDimension = modelMetadata.getVectorDataType() == VectorDataType.BINARY - ? modelMetadata.getDimension() - : modelMetadata.getDimension() / 8; + ? modelMetadata.getDimension() / 8 + : modelMetadata.getDimension(); final VectorEncoding encoding = modelMetadata.getVectorDataType() == VectorDataType.FLOAT ? VectorEncoding.FLOAT32 : VectorEncoding.BYTE; @@ -212,7 +212,7 @@ protected void parseCreateField(ParseContext context) throws IOException { parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType()); } - private static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) { + public static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) { MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); if (methodComponentContext == MethodComponentContext.EMPTY) { return null; diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 4034e4cb6..6aa4fa893 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -39,9 +39,11 @@ import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; +import org.opensearch.knn.indices.ModelUtil; import java.io.IOException; import java.time.ZoneOffset; @@ -66,6 +68,7 @@ import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; @@ -872,6 +875,122 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT utilMockedStatic.close(); } + @SneakyThrows + public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTypes() { + MockedStatic utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class); + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + MockedStatic modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class); + final MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_IVF, Collections.emptyMap()); + + for (VectorDataType dataType : VectorDataType.values()) { + log.info("Vector Data Type is : {}", dataType); + SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT; + int dimension = dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION; + when(modelDao.getMetadata(MODEL_ID)).thenReturn(modelMetadata); + modelUtilMockedStatic.when(() -> ModelUtil.isModelCreated(modelMetadata)).thenReturn(true); + when(modelMetadata.getDimension()).thenReturn(dimension); + when(modelMetadata.getVectorDataType()).thenReturn(dataType); + when(modelMetadata.getSpaceType()).thenReturn(spaceType); + when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(modelMetadata.getMethodComponentContext()).thenReturn(methodComponentContext); + + ParseContext.Document document = new ParseContext.Document(); + ContentPath contentPath = new ContentPath(); + ParseContext parseContext = mock(ParseContext.class); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + + + + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true); + ModelFieldMapper modelFieldMapper = Mockito.spy( + ModelFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + TEST_FIELD_NAME, + Collections.emptyMap(), + dataType, + MODEL_ID, + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, + modelDao, + CURRENT + ) + ); + + if (dataType == VectorDataType.BINARY) { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper) + .getBytesFromContext(parseContext, TEST_DIMENSION * 8, dataType); + } else if (dataType == VectorDataType.BYTE) { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION, dataType); + } else { + doReturn(Optional.of(TEST_VECTOR)).when(modelFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + } + + modelFieldMapper.parseCreateField(parseContext); + + List fields = document.getFields(); + assertEquals(1, fields.size()); + IndexableField field1 = fields.get(0); + if (dataType == VectorDataType.FLOAT) { + assertTrue(field1 instanceof KnnFloatVectorField); + assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.FLOAT32); + } else { + assertTrue(field1 instanceof KnnByteVectorField); + assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.BYTE); + } + + assertEquals(field1.fieldType().vectorDimension(), TEST_DIMENSION); + assertEquals( + field1.fieldType().vectorSimilarityFunction(), + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false); + + document = new ParseContext.Document(); + contentPath = new ContentPath(); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + modelFieldMapper = Mockito.spy( + ModelFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + TEST_FIELD_NAME, + Collections.emptyMap(), + dataType, + MODEL_ID, + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, + modelDao, + CURRENT + ) + ); + + if (dataType == VectorDataType.FLOAT) { + doReturn(Optional.of(TEST_VECTOR)).when(modelFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + } else { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper) + .getBytesFromContext(parseContext, dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION, dataType); + } + + modelFieldMapper.parseCreateField(parseContext); + fields = document.getFields(); + assertEquals(1, fields.size()); + field1 = fields.get(0); + assertTrue(field1 instanceof VectorField); + } + // making sure to close the static mock to ensure that for tests running on this thread are not impacted by + // this mocking + utilMockedStatic.close(); + modelUtilMockedStatic.close(); + } + @SneakyThrows public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { // Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField