Skip to content

Commit

Permalink
fix uTs
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
jmazanec15 committed Sep 4, 2024
1 parent 379a10e commit 697a51c
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 76 deletions.
3 changes: 3 additions & 0 deletions src/test/java/org/opensearch/knn/index/SpaceTypeTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.knn.index.engine.KNNEngine;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -66,6 +67,8 @@ public void testGetVectorSimilarityFunction_whenInnerproduct_thenConsistentWithS

public void testValidateVectorDataType_whenCalled_thenReturn() {
Map<SpaceType, Set<VectorDataType>> expected = Map.of(
SpaceType.UNDEFINED,
Collections.emptySet(),
SpaceType.L2,
Set.of(VectorDataType.FLOAT, VectorDataType.BYTE),
SpaceType.COSINESIMIL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;

import java.io.IOException;
Expand Down Expand Up @@ -137,8 +136,6 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc
indexWriter.commit();
indexWriter.close();

assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue());

// Validate to see if correct values are returned, assumption here is only 1 segment is getting created
IndexSearcher searcher = new IndexSearcher(indexReader);
final LeafReader leafReader = searcher.getLeafContexts().get(0).reader();
Expand Down Expand Up @@ -208,7 +205,6 @@ public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSucce
indexWriter.flush();
indexWriter.commit();
indexWriter.close();
assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue());

IndexSearcher searcher = new IndexSearcher(indexReader);
final LeafReader leafReader = searcher.getLeafContexts().get(0).reader();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ public void testParse_valid() throws IOException {
KNNMethodContext knnMethodContext = KNNMethodContext.parse(in);

assertEquals(KNNEngine.DEFAULT, knnMethodContext.getKnnEngine());
assertEquals(SpaceType.DEFAULT, knnMethodContext.getSpaceType());
assertEquals(SpaceType.UNDEFINED, knnMethodContext.getSpaceType());
assertEquals(methodName, knnMethodContext.getMethodComponentContext().getName());
assertTrue(knnMethodContext.getMethodComponentContext().getParameters().isEmpty());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,7 @@ public void testTypeParser_whenBinaryWithInvalidDimension_thenException() throws

public void testTypeParser_whenBinaryFaissHNSWWithInvalidSpaceType_thenException() throws IOException {
for (SpaceType spaceType : SpaceType.values()) {
if (SpaceType.HAMMING == spaceType) {
if (SpaceType.UNDEFINED == spaceType || SpaceType.HAMMING == spaceType) {
continue;
}
testTypeParserWithBinaryDataType(KNNEngine.FAISS, spaceType, METHOD_HNSW, 8, "is not supported with");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,13 @@ public void testValidation_invalid_invalidMethodContext() {
String modelId = "test-model-id";

// Mock throwing an exception on validation
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
String validationExceptionMessage = "knn method invalid";
ValidationException validationException = new ValidationException();
validationException.addValidationError(validationExceptionMessage);
KNNEngine knnEngine = mock(KNNEngine.class);
when(knnEngine.validateMethod(any(), any())).thenReturn(validationException);
when(knnEngine.isTrainingRequired(any())).thenReturn(false);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine);
when(knnMethodContext.validate(any())).thenReturn(validationException);

when(knnMethodContext.isTrainingRequired()).thenReturn(false);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
Expand Down Expand Up @@ -346,26 +345,20 @@ public void testValidation_invalid_trainingIndexDoesNotExist() {

// Setup the training request
String modelId = "test-model-id";

KNNEngine knnEngine = mock(KNNEngine.class);
when(knnEngine.validateMethod(any(), any())).thenReturn(null);
when(knnEngine.isTrainingRequired(any())).thenReturn(true);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";

TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
null,
dimension,
trainingIndex,
trainingField,
null,
null,
VectorDataType.DEFAULT,
Mode.NOT_CONFIGURED,
Mode.ON_DISK,
CompressionLevel.NOT_CONFIGURED
);

Expand Down Expand Up @@ -396,26 +389,20 @@ public void testValidation_invalid_trainingFieldDoesNotExist() {

// Setup the training request
String modelId = "test-model-id";

KNNEngine knnEngine = mock(KNNEngine.class);
when(knnEngine.validateMethod(any(), any())).thenReturn(null);
when(knnEngine.isTrainingRequired(any())).thenReturn(true);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";

TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
null,
dimension,
trainingIndex,
trainingField,
null,
null,
VectorDataType.DEFAULT,
Mode.NOT_CONFIGURED,
Mode.ON_DISK,
CompressionLevel.NOT_CONFIGURED
);

Expand All @@ -442,6 +429,7 @@ public void testValidation_invalid_trainingFieldDoesNotExist() {
ActionRequestValidationException exception = trainingModelRequest.validate();
assertNotNull(exception);
List<String> validationErrors = exception.validationErrors();
logger.error("Validation errors: " + validationErrors);
assertEquals(1, validationErrors.size());
assertTrue(validationErrors.get(0).contains("does not exist"));
}
Expand All @@ -451,26 +439,20 @@ public void testValidation_invalid_trainingFieldNotKnnVector() {

// Setup the training request
String modelId = "test-model-id";

KNNEngine knnEngine = mock(KNNEngine.class);
when(knnEngine.validateMethod(any(), any())).thenReturn(null);
when(knnEngine.isTrainingRequired(any())).thenReturn(true);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";

TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
null,
dimension,
trainingIndex,
trainingField,
null,
null,
VectorDataType.DEFAULT,
Mode.NOT_CONFIGURED,
Mode.ON_DISK,
CompressionLevel.NOT_CONFIGURED
);

Expand Down Expand Up @@ -510,27 +492,20 @@ public void testValidation_invalid_dimensionDoesNotMatch() {

// Setup the training request
String modelId = "test-model-id";

KNNEngine knnEngine = mock(KNNEngine.class);
when(knnEngine.validateMethod(any(), any())).thenReturn(null);
when(knnEngine.isTrainingRequired(any())).thenReturn(true);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";

TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
null,
dimension,
trainingIndex,
trainingField,
null,
null,
VectorDataType.DEFAULT,
Mode.NOT_CONFIGURED,
Mode.ON_DISK,
CompressionLevel.NOT_CONFIGURED
);

Expand Down Expand Up @@ -573,27 +548,21 @@ public void testValidation_invalid_preferredNodeDoesNotExist() {

// Setup the training request
String modelId = "test-model-id";
KNNEngine knnEngine = mock(KNNEngine.class);
when(knnEngine.validateMethod(any(), any())).thenReturn(null);
when(knnEngine.isTrainingRequired(any())).thenReturn(true);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
String preferredNode = "preferred-node";

TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
null,
dimension,
trainingIndex,
trainingField,
preferredNode,
null,
VectorDataType.DEFAULT,
Mode.NOT_CONFIGURED,
Mode.ON_DISK,
CompressionLevel.NOT_CONFIGURED
);

Expand Down Expand Up @@ -640,12 +609,6 @@ public void testValidation_invalid_descriptionToLong() {

// Setup the training request
String modelId = "test-model-id";
KNNEngine knnEngine = mock(KNNEngine.class);
when(knnEngine.validateMethod(any(), any())).thenReturn(null);
when(knnEngine.isTrainingRequired(any())).thenReturn(true);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand All @@ -657,14 +620,14 @@ public void testValidation_invalid_descriptionToLong() {

TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
null,
dimension,
trainingIndex,
trainingField,
null,
description,
VectorDataType.DEFAULT,
Mode.NOT_CONFIGURED,
Mode.ON_DISK,
CompressionLevel.NOT_CONFIGURED
);

Expand All @@ -686,6 +649,7 @@ public void testValidation_invalid_descriptionToLong() {
ActionRequestValidationException exception = trainingModelRequest.validate();
assertNotNull(exception);
List<String> validationErrors = exception.validationErrors();
logger.error("Validation errorsa " + validationErrors);
assertEquals(1, validationErrors.size());
assertTrue(validationErrors.get(0).contains("Description exceeds limit"));
}
Expand All @@ -695,26 +659,20 @@ public void testValidation_valid_trainingIndexBuiltFromMethod() {

// Setup the training request
String modelId = "test-model-id";
KNNEngine knnEngine = mock(KNNEngine.class);
when(knnEngine.validateMethod(any(), any())).thenReturn(null);
when(knnEngine.isTrainingRequired(any())).thenReturn(true);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";

TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
null,
dimension,
trainingIndex,
trainingField,
null,
null,
VectorDataType.DEFAULT,
Mode.NOT_CONFIGURED,
Mode.ON_DISK,
CompressionLevel.NOT_CONFIGURED
);

Expand All @@ -737,27 +695,21 @@ public void testValidation_valid_trainingIndexBuiltFromModel() {

// Setup the training request
String modelId = "test-model-id";
KNNEngine knnEngine = mock(KNNEngine.class);
when(knnEngine.validateMethod(any(), any())).thenReturn(null);
when(knnEngine.isTrainingRequired(any())).thenReturn(true);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
String trainingFieldModeId = "training-field-model-id";

TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
null,
dimension,
trainingIndex,
trainingField,
null,
null,
VectorDataType.DEFAULT,
Mode.NOT_CONFIGURED,
Mode.ON_DISK,
CompressionLevel.NOT_CONFIGURED
);

Expand Down

0 comments on commit 697a51c

Please sign in to comment.