Skip to content

Commit

Permalink
Fixing the dimension for the vector when using Lucene field in ModelFieldMapper
Browse files Browse the repository at this point in the history

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Aug 17, 2024
1 parent f42e86e commit d55034b
Show file tree
Hide file tree
Showing 5 changed files with 357 additions and 238 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ protected void parseCreateField(ParseContext context) throws IOException {
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
if (useLuceneBasedVectorField) {
int adjustedDimension = modelMetadata.getVectorDataType() == VectorDataType.BINARY
? modelMetadata.getDimension()
: modelMetadata.getDimension() / 8;
? modelMetadata.getDimension() / Byte.SIZE
: modelMetadata.getDimension();
final VectorEncoding encoding = modelMetadata.getVectorDataType() == VectorDataType.FLOAT
? VectorEncoding.FLOAT32
: VectorEncoding.BYTE;
Expand Down
22 changes: 22 additions & 0 deletions src/test/java/org/opensearch/knn/KNNTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNLibrarySearchContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
Expand Down Expand Up @@ -146,4 +147,25 @@ public int getDimension() {
}
};
}

/**
 * Produces the dimension to use during ingestion for the given {@link VectorDataType}.
 * A {@link VectorDataType#BINARY} dimension is multiplied by {@link Byte#SIZE}; for any other
 * data type the dimension is returned as-is.
 *
 * @param dimension dimension to adjust
 * @param vectorDataType data type that decides whether the dimension is scaled
 * @return {@code dimension * Byte.SIZE} for binary vectors, otherwise {@code dimension}
 */
protected int adjustDimensionForIndexing(final int dimension, final VectorDataType vectorDataType) {
    // Guard clause: only binary vectors need scaling.
    if (vectorDataType != VectorDataType.BINARY) {
        return dimension;
    }
    return dimension * Byte.SIZE;
}

/**
 * Produces the dimension to use at search time for the given {@link VectorDataType}.
 * A {@link VectorDataType#BINARY} dimension is divided by {@link Byte#SIZE} using integer
 * division; for any other data type the dimension is returned as-is.
 *
 * @param dimension dimension to adjust
 * @param vectorDataType data type that decides whether the dimension is scaled
 * @return {@code dimension / Byte.SIZE} for binary vectors, otherwise {@code dimension}
 */
protected int adjustDimensionForSearch(final int dimension, final VectorDataType vectorDataType) {
    final boolean binaryVector = VectorDataType.BINARY == vectorDataType;
    if (binaryVector) {
        // Integer division: any remainder is discarded.
        return dimension / Byte.SIZE;
    }
    return dimension;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,23 @@ public class FieldInfoExtractorTests extends KNNTestCase {

/**
 * Checks {@code FieldInfoExtractor.extractVectorDataType(fieldInfo)} against a mocked {@code FieldInfo}
 * in three situations: no data type attribute and no model metadata (expects {@code VectorDataType.DEFAULT}),
 * data type attribute present on the field (expects {@code BINARY}), and data type supplied only by the
 * mocked {@code ModelMetadata} returned from {@code ModelUtil.getModelMetadata(MODEL_ID)} (expects {@code BYTE}).
 *
 * NOTE(review): this listing looks like a rendered diff without +/- markers. The statements up to the
 * explicit {@code modelUtilMockedStatic.close()} and the try-with-resources block after it repeat the same
 * scenario, and both declare {@code modelUtilMockedStatic} and {@code modelMetadata}. Presumably the first
 * run is the removed version and the try-with-resources block is the replacement; confirm against the
 * actual file.
 */
public void testExtractVectorDataType_whenDifferentConditions_thenSuccess() {
FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
// Static mock opened manually; it is only released by the explicit close() call further down.
MockedStatic<ModelUtil> modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class);

// default case
Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(null);
Mockito.when(fieldInfo.getAttribute(KNNConstants.MODEL_ID)).thenReturn(MODEL_ID);
modelUtilMockedStatic.when(() -> ModelUtil.getModelMetadata(MODEL_ID)).thenReturn(null);
Assert.assertEquals(VectorDataType.DEFAULT, FieldInfoExtractor.extractVectorDataType(fieldInfo));

// VectorDataType present in fieldInfo
Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(VectorDataType.BINARY.getValue());
Assert.assertEquals(VectorDataType.BINARY, FieldInfoExtractor.extractVectorDataType(fieldInfo));

// VectorDataType present in ModelMetadata
ModelMetadata modelMetadata = Mockito.mock(ModelMetadata.class);
Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(null);
modelUtilMockedStatic.when(() -> ModelUtil.getModelMetadata(MODEL_ID)).thenReturn(modelMetadata);
Mockito.when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.BYTE);
Assert.assertEquals(VectorDataType.BYTE, FieldInfoExtractor.extractVectorDataType(fieldInfo));

// Reached only when every assertion above passes; a failing assertion skips this close().
modelUtilMockedStatic.close();
// Same scenario with the static mock scoped by try-with-resources, so it is closed even if an assertion throws.
try (MockedStatic<ModelUtil> modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class)) {
// default case
Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(null);
Mockito.when(fieldInfo.getAttribute(KNNConstants.MODEL_ID)).thenReturn(MODEL_ID);
modelUtilMockedStatic.when(() -> ModelUtil.getModelMetadata(MODEL_ID)).thenReturn(null);
Assert.assertEquals(VectorDataType.DEFAULT, FieldInfoExtractor.extractVectorDataType(fieldInfo));

// VectorDataType present in fieldInfo
Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(VectorDataType.BINARY.getValue());
Assert.assertEquals(VectorDataType.BINARY, FieldInfoExtractor.extractVectorDataType(fieldInfo));

// VectorDataType present in ModelMetadata
ModelMetadata modelMetadata = Mockito.mock(ModelMetadata.class);
Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(null);
modelUtilMockedStatic.when(() -> ModelUtil.getModelMetadata(MODEL_ID)).thenReturn(modelMetadata);
Mockito.when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.BYTE);
Assert.assertEquals(VectorDataType.BYTE, FieldInfoExtractor.extractVectorDataType(fieldInfo));
}
}
}
Loading

0 comments on commit d55034b

Please sign in to comment.