Skip to content

Commit

Permalink
separate encoding and embedding classes
Browse files Browse the repository at this point in the history
  • Loading branch information
will-hwang committed Feb 6, 2025
1 parent bbd71d2 commit 4a43ca2
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.nio.file.Path;
import java.util.Locale;
import java.util.Optional;

import org.junit.Before;
import org.opensearch.common.settings.Settings;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;
Expand All @@ -16,6 +17,9 @@
import static org.opensearch.neuralsearch.util.TestUtils.RESTART_UPGRADE_OLD_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.BWC_VERSION;
import static org.opensearch.neuralsearch.util.TestUtils.generateModelId;

import org.opensearch.neuralsearch.util.SparseEncodingModel;
import org.opensearch.neuralsearch.util.TextEmbeddingModel;
import org.opensearch.test.rest.OpenSearchRestTestCase;

public abstract class AbstractRestartUpgradeRestTestCase extends BaseNeuralSearchIT {
Expand All @@ -27,34 +31,6 @@ protected String getIndexNameForTest() {
return NEURAL_SEARCH_BWC_PREFIX + getTestName().toLowerCase(Locale.ROOT);
}

/**
 * Holder for a single text embedding model id.
 *
 * <p>The id lives in a static field and is read and written through the static
 * accessors below; the {@code INSTANCE} constant itself carries no state.
 */
protected enum TextEmbeddingModel {
INSTANCE;

// Remains null until setModelId is called.
private static String modelId;

/**
 * Stores the given model id, replacing any previously stored value.
 *
 * @param id the model id to store
 */
public static void setModelId(String id) {
modelId = id;
}

/**
 * Returns the stored model id.
 *
 * @return the last value passed to {@link #setModelId(String)}, or {@code null} if none was set
 */
public static String getModelId() {
return modelId;
}
}

/**
 * Holder for a single sparse encoding model id.
 *
 * <p>The id lives in a static field and is read and written through the static
 * accessors below; the {@code INSTANCE} constant itself carries no state.
 */
protected enum SparseEncodingModel {
INSTANCE;

// Remains null until setModelId is called.
private static String modelId;

/**
 * Stores the given model id, replacing any previously stored value.
 *
 * @param id the model id to store
 */
public static void setModelId(String id) {
modelId = id;
}

/**
 * Returns the stored model id.
 *
 * @return the last value passed to {@link #setModelId(String)}, or {@code null} if none was set
 */
public static String getModelId() {
return modelId;
}
}

@Override
protected final boolean preserveIndicesUponCompletion() {
return true;
Expand Down Expand Up @@ -90,11 +66,12 @@ protected final Optional<String> getBWCVersion() {
}

protected String uploadTextEmbeddingModel() throws Exception {
String modelId = TextEmbeddingModel.getModelId();
TextEmbeddingModel textEmbeddingModel = TextEmbeddingModel.getInstance();
String modelId = textEmbeddingModel.getModelId();
if (modelId == null) {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
String id = registerModelGroupAndGetModelId(requestBody);
TextEmbeddingModel.setModelId(id);
textEmbeddingModel.setModelId(id);
return id;
}
return modelId;
Expand All @@ -114,13 +91,14 @@ protected void createPipelineProcessor(final String modelId, final String pipeli
}

protected String uploadSparseEncodingModel() throws Exception {
String modelId = SparseEncodingModel.getModelId();
SparseEncodingModel sparseEncodingModel = SparseEncodingModel.getInstance();
String modelId = sparseEncodingModel.getModelId();
if (modelId == null) {
String requestBody = Files.readString(
Path.of(classLoader.getResource("processor/UploadSparseEncodingModelRequestBody.json").toURI())
);
String id = registerModelGroupAndGetModelId(requestBody);
SparseEncodingModel.setModelId(id);
sparseEncodingModel.setModelId(id);
return id;
}
return modelId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/
package org.opensearch.neuralsearch.bwc.restart;

import org.opensearch.neuralsearch.util.TestUtils;
import org.opensearch.neuralsearch.util.SparseEncodingModel;

import java.nio.file.Files;
import java.nio.file.Path;
Expand All @@ -13,7 +13,6 @@

import static org.opensearch.neuralsearch.util.BatchIngestionUtils.prepareDataForBulkIngestion;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.SPARSE_ENCODING_PROCESSOR;

public class BatchIngestionIT extends AbstractRestartUpgradeRestTestCase {
private static final String PIPELINE_NAME = "pipeline-BatchIngestionIT";
Expand All @@ -38,7 +37,7 @@ public void testBatchIngestionWithNeuralSparseProcessor_E2EFlow() throws Excepti
validateDocCountAndInfo(indexName, 5, () -> getDocById(indexName, "4"), EMBEDDING_FIELD_NAME, Map.class);
} else {
String modelId = null;
modelId = TestUtils.getModelId(getIngestionPipeline(PIPELINE_NAME), SPARSE_ENCODING_PROCESSOR);
modelId = SparseEncodingModel.getInstance().getModelId();
loadModel(modelId);
try {
List<Map<String, String>> docs = prepareDataForBulkIngestion(5, 5);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.util.TextEmbeddingModel;

public class HybridSearchIT extends AbstractRestartUpgradeRestTestCase {
private static final String PIPELINE_NAME = "nlp-hybrid-pipeline";
Expand Down Expand Up @@ -67,7 +68,7 @@ private void validateNormalizationProcessor(final String fileName, final String
} else {
String modelId = null;
try {
modelId = TextEmbeddingModel.getModelId();
modelId = TextEmbeddingModel.getInstance().getModelId();
loadModel(modelId);
addDocuments(getIndexNameForTest(), false);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.util.TextEmbeddingModel;

public class HybridSearchWithRescoreIT extends AbstractRestartUpgradeRestTestCase {
private static final String PIPELINE_NAME = "nlp-hybrid-with-rescore-pipeline";
Expand Down Expand Up @@ -59,7 +60,7 @@ public void testHybridQueryWithRescore_whenIndexWithMultipleShards_E2EFlow() thr
} else {
String modelId = null;
try {
modelId = TextEmbeddingModel.getModelId();
modelId = TextEmbeddingModel.getInstance().getModelId();
loadModel(modelId);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_UPGRADED, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Map;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.util.TextEmbeddingModel;

public class KnnRadialSearchIT extends AbstractRestartUpgradeRestTestCase {
private static final String PIPELINE_NAME = "radial-search-pipeline";
Expand Down Expand Up @@ -38,7 +39,7 @@ public void testKnnRadialSearch_E2EFlow() throws Exception {
} else {
String modelId = null;
try {
modelId = TextEmbeddingModel.getModelId();
modelId = TextEmbeddingModel.getInstance().getModelId();
loadModel(modelId);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_1, TEST_IMAGE_FIELD, TEST_IMAGE_TEXT_1);
validateIndexQuery(modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Map;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.util.TextEmbeddingModel;

public class MultiModalSearchIT extends AbstractRestartUpgradeRestTestCase {
private static final String PIPELINE_NAME = "nlp-ingest-pipeline";
Expand Down Expand Up @@ -38,7 +39,7 @@ public void testTextImageEmbeddingProcessor_E2EFlow() throws Exception {
} else {
String modelId = null;
try {
modelId = TextEmbeddingModel.getModelId();
modelId = TextEmbeddingModel.getInstance().getModelId();
loadModel(modelId);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_1, TEST_IMAGE_FIELD, TEST_IMAGE_TEXT_1);
validateTestIndex(modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR;

import org.opensearch.common.settings.Settings;
import org.opensearch.neuralsearch.util.SparseEncodingModel;
import org.opensearch.neuralsearch.util.TestUtils;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
Expand Down Expand Up @@ -60,7 +61,7 @@ public void testNeuralQueryEnricherProcessor_NeuralSparseSearch_E2EFlow() throws
} else {
String modelId = null;
try {
modelId = SparseEncodingModel.getModelId();
modelId = SparseEncodingModel.getInstance().getModelId();
loadModel(modelId);
sparseEncodingQueryBuilderWithModelId.modelId(modelId);
assertEquals(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Map;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.neuralsearch.util.SparseEncodingModel;
import org.opensearch.neuralsearch.util.TestUtils;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.objectToFloat;
Expand Down Expand Up @@ -52,7 +53,7 @@ public void testSparseEncodingProcessor_E2EFlow() throws Exception {
} else {
String modelId = null;
try {
modelId = SparseEncodingModel.getModelId();
modelId = SparseEncodingModel.getInstance().getModelId();
loadModel(modelId);
addSparseEncodingDoc(
getIndexNameForTest(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import org.opensearch.common.settings.Settings;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.util.SparseEncodingModel;

import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -45,7 +46,7 @@ public void testNeuralSparseQueryTwoPhaseProcessor_NeuralSearch_E2EFlow() throws
} else {
String modelId = null;
try {
modelId = SparseEncodingModel.getModelId();
modelId = SparseEncodingModel.getInstance().getModelId();
loadModel(modelId);
neuralSparseQueryBuilder.modelId(modelId);
Object resultWith2PhasePipeline = search(getIndexNameForTest(), neuralSparseQueryBuilder, 1).get("hits");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Map;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.util.TextEmbeddingModel;

public class SemanticSearchIT extends AbstractRestartUpgradeRestTestCase {

Expand All @@ -22,7 +23,6 @@ public class SemanticSearchIT extends AbstractRestartUpgradeRestTestCase {
// Validate processor, pipeline and document count in restart-upgrade scenario
public void testTextEmbeddingProcessor_E2EFlow() throws Exception {
waitForClusterHealthGreen(NODES_BWC_CLUSTER);

if (isRunningAgainstOldCluster()) {
String modelId = uploadTextEmbeddingModel();
loadModel(modelId);
Expand All @@ -36,7 +36,7 @@ public void testTextEmbeddingProcessor_E2EFlow() throws Exception {
} else {
String modelId = null;
try {
modelId = TextEmbeddingModel.getModelId();
modelId = TextEmbeddingModel.getInstance().getModelId();
loadModel(modelId);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_1, null, null);
validateTestIndex(modelId);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.util;

/**
 * Singleton holder for the id of a sparse encoding model.
 *
 * <p>The single instance is created eagerly during class initialization, so
 * {@link #getInstance()} returns the same object on every call, including when
 * it is first called from several threads at once.
 */
public final class SparseEncodingModel {
    // Eager initialization: the JVM runs this exactly once when the class is initialized,
    // which removes the check-then-create race of a lazily populated static field.
    private static final SparseEncodingModel INSTANCE = new SparseEncodingModel();

    // Left public so existing direct field access keeps compiling; prefer the accessors.
    public String modelId = null;

    private SparseEncodingModel() {
        // Not instantiable from outside; obtain the instance through getInstance().
    }

    /**
     * Stores the given model id, replacing any previously stored value.
     *
     * @param modelId the model id to store; {@code null} clears the stored id
     */
    public void setModelId(String modelId) {
        this.modelId = modelId;
    }

    /**
     * Returns the stored model id.
     *
     * @return the stored model id, or {@code null} if none has been set
     */
    public String getModelId() {
        return modelId;
    }

    /**
     * Returns the single shared instance of this class.
     *
     * @return the singleton instance, never {@code null}
     */
    public static SparseEncodingModel getInstance() {
        return INSTANCE;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.util;

/**
 * Singleton holder for the id of a text embedding model.
 *
 * <p>The single instance is created eagerly during class initialization, so
 * {@link #getInstance()} returns the same object on every call, including when
 * it is first called from several threads at once.
 */
public final class TextEmbeddingModel {
    // Eager initialization: the JVM runs this exactly once when the class is initialized,
    // which removes the check-then-create race of a lazily populated static field.
    private static final TextEmbeddingModel INSTANCE = new TextEmbeddingModel();

    // Left public so existing direct field access keeps compiling; prefer the accessors.
    public String modelId = null;

    private TextEmbeddingModel() {
        // Not instantiable from outside; obtain the instance through getInstance().
    }

    /**
     * Stores the given model id, replacing any previously stored value.
     *
     * @param modelId the model id to store; {@code null} clears the stored id
     */
    public void setModelId(String modelId) {
        this.modelId = modelId;
    }

    /**
     * Returns the stored model id.
     *
     * @return the stored model id, or {@code null} if none has been set
     */
    public String getModelId() {
        return modelId;
    }

    /**
     * Returns the single shared instance of this class.
     *
     * @return the singleton instance, never {@code null}
     */
    public static TextEmbeddingModel getInstance() {
        return INSTANCE;
    }
}

0 comments on commit 4a43ca2

Please sign in to comment.