Skip to content

Commit

Permalink
Fixing the dimension for the vector when using Lucene field in ModelF…
Browse files Browse the repository at this point in the history
…ieldMapper

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Aug 16, 2024
1 parent f42e86e commit d85baa4
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ protected void parseCreateField(ParseContext context) throws IOException {
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
if (useLuceneBasedVectorField) {
int adjustedDimension = modelMetadata.getVectorDataType() == VectorDataType.BINARY
? modelMetadata.getDimension()
: modelMetadata.getDimension() / 8;
? modelMetadata.getDimension() / 8
: modelMetadata.getDimension();
final VectorEncoding encoding = modelMetadata.getVectorDataType() == VectorDataType.FLOAT
? VectorEncoding.FLOAT32
: VectorEncoding.BYTE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.indices.ModelUtil;

import java.io.IOException;
import java.time.ZoneOffset;
Expand All @@ -66,6 +67,7 @@
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
Expand Down Expand Up @@ -827,6 +829,7 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT
}

assertEquals(field1.fieldType().vectorDimension(), TEST_DIMENSION);
assertEquals(Integer.parseInt(field1.fieldType().getAttributes().get(DIMENSION_FIELD_NAME)), dimension);
assertEquals(
field1.fieldType().vectorSimilarityFunction(),
SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction()
Expand Down Expand Up @@ -866,12 +869,127 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT
assertEquals(1, fields.size());
field1 = fields.get(0);
assertTrue(field1 instanceof VectorField);
assertEquals(Integer.parseInt(field1.fieldType().getAttributes().get(DIMENSION_FIELD_NAME)), dimension);
}
// making sure to close the static mock to ensure that for tests running on this thread are not impacted by
// this mocking
utilMockedStatic.close();
}

@SneakyThrows
public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTypes() {
MockedStatic<KNNVectorFieldMapperUtil> utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class);
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
MockedStatic<ModelUtil> modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class);
final MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_IVF, Collections.emptyMap());

for (VectorDataType dataType : VectorDataType.values()) {
log.info("Vector Data Type is : {}", dataType);
SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT;
int dimension = dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION;
when(modelDao.getMetadata(MODEL_ID)).thenReturn(modelMetadata);
modelUtilMockedStatic.when(() -> ModelUtil.isModelCreated(modelMetadata)).thenReturn(true);
when(modelMetadata.getDimension()).thenReturn(dimension);
when(modelMetadata.getVectorDataType()).thenReturn(dataType);
when(modelMetadata.getSpaceType()).thenReturn(spaceType);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getMethodComponentContext()).thenReturn(methodComponentContext);

ParseContext.Document document = new ParseContext.Document();
ContentPath contentPath = new ContentPath();
ParseContext parseContext = mock(ParseContext.class);
when(parseContext.doc()).thenReturn(document);
when(parseContext.path()).thenReturn(contentPath);

utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true);
ModelFieldMapper modelFieldMapper = Mockito.spy(
ModelFieldMapper.createFieldMapper(
TEST_FIELD_NAME,
TEST_FIELD_NAME,
Collections.emptyMap(),
dataType,
MODEL_ID,
FieldMapper.MultiFields.empty(),
FieldMapper.CopyTo.empty(),
new Explicit<>(true, true),
false,
false,
modelDao,
CURRENT
)
);

if (dataType == VectorDataType.BINARY) {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper)
.getBytesFromContext(parseContext, TEST_DIMENSION * 8, dataType);
} else if (dataType == VectorDataType.BYTE) {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION, dataType);
} else {
doReturn(Optional.of(TEST_VECTOR)).when(modelFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION);
}

modelFieldMapper.parseCreateField(parseContext);

List<IndexableField> fields = document.getFields();
assertEquals(1, fields.size());
IndexableField field1 = fields.get(0);
if (dataType == VectorDataType.FLOAT) {
assertTrue(field1 instanceof KnnFloatVectorField);
assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.FLOAT32);
} else {
assertTrue(field1 instanceof KnnByteVectorField);
assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.BYTE);
}

assertEquals(field1.fieldType().vectorDimension(), TEST_DIMENSION);
assertEquals(
field1.fieldType().vectorSimilarityFunction(),
SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction()
);

utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false);

document = new ParseContext.Document();
contentPath = new ContentPath();
when(parseContext.doc()).thenReturn(document);
when(parseContext.path()).thenReturn(contentPath);
modelFieldMapper = Mockito.spy(
ModelFieldMapper.createFieldMapper(
TEST_FIELD_NAME,
TEST_FIELD_NAME,
Collections.emptyMap(),
dataType,
MODEL_ID,
FieldMapper.MultiFields.empty(),
FieldMapper.CopyTo.empty(),
new Explicit<>(true, true),
false,
false,
modelDao,
CURRENT
)
);

if (dataType == VectorDataType.FLOAT) {
doReturn(Optional.of(TEST_VECTOR)).when(modelFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION);
} else {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper)
.getBytesFromContext(parseContext, dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION, dataType);
}

modelFieldMapper.parseCreateField(parseContext);
fields = document.getFields();
assertEquals(1, fields.size());
field1 = fields.get(0);
assertTrue(field1 instanceof VectorField);
}
// making sure to close the static mock to ensure that for tests running on this thread are not impacted by
// this mocking
utilMockedStatic.close();
modelUtilMockedStatic.close();
}

@SneakyThrows
public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() {
// Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField
Expand Down

0 comments on commit d85baa4

Please sign in to comment.