diff --git a/CHANGELOG.md b/CHANGELOG.md index f2ac91ef5..9652942ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD) ### Features +* [Remote Vector Index Build] Introduce Client Skeleton + Build Request implementation [#2548](https://github.com/opensearch-project/k-NN/pull/2548/files) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/build.gradle b/build.gradle index b5f715847..4ee818c27 100644 --- a/build.gradle +++ b/build.gradle @@ -321,7 +321,11 @@ dependencies { api "net.java.dev.jna:jna-platform:5.13.0" // OpenSearch core is using slf4j 1.7.36. Therefore, we cannot change the version here. implementation 'org.slf4j:slf4j-api:1.7.36' - + api "org.apache.httpcomponents.client5:httpclient5:${versions.httpclient5}" + api "org.apache.httpcomponents.core5:httpcore5:${versions.httpcore5}" + api "org.apache.httpcomponents.core5:httpcore5-h2:${versions.httpcore5}" + api "com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}" + api "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}" } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index c33f3ea63..04d28ea0e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -15,10 +15,12 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Booleans; +import org.opensearch.common.settings.SecureSetting; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.settings.SecureString; import org.opensearch.core.common.unit.ByteSizeUnit; import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.index.IndexModule; @@ -96,6 +98,11 @@ public class KNNSettings { public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled"; public static final String KNN_INDEX_REMOTE_VECTOR_BUILD = "index.knn.remote_index_build.enabled"; public static final String KNN_REMOTE_VECTOR_REPO = "knn.remote_index_build.vector_repo"; + public static final String KNN_REMOTE_BUILD_SERVICE_ENDPOINT = "knn.remote_build_service.endpoint"; + public static final String KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL = "knn.remote_build_service.poll_interval"; + public static final String KNN_REMOTE_BUILD_SERVICE_TIMEOUT = "knn.remote_build_service.timeout"; + public static final String KNN_REMOTE_BUILD_SERVICE_USERNAME = "knn.remote_build_service.username"; + public static final String KNN_REMOTE_BUILD_SERVICE_PASSWORD = "knn.remote_build_service.password"; /** * Default setting values @@ -127,6 +134,9 @@ public class KNNSettings { public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60; public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = false; + public static final Integer KNN_DEFAULT_REMOTE_BUILD_SERVICE_TIMEOUT_MINUTES = 60; + public static final Integer KNN_DEFAULT_REMOTE_BUILD_SERVICE_POLL_INTERVAL_SECONDS = 30; + /** * Settings Definition */ @@ -388,6 +398,47 @@ public class KNNSettings { */ public static final Setting KNN_REMOTE_VECTOR_REPO_SETTING = Setting.simpleString(KNN_REMOTE_VECTOR_REPO, Dynamic, NodeScope); + /** + * Remote build service endpoint to be used for remote index build. //TODO we can add String validators on these endpoint settings + */ + public static final Setting KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING = Setting.simpleString( + KNN_REMOTE_BUILD_SERVICE_ENDPOINT, + NodeScope, + Dynamic + ); + + /** + * Time the remote build service client will wait before falling back to CPU index build. + */ + public static final Setting KNN_REMOTE_BUILD_SERVICE_TIMEOUT_SETTING = Setting.timeSetting( + KNN_REMOTE_BUILD_SERVICE_TIMEOUT, + TimeValue.timeValueMinutes(KNN_DEFAULT_REMOTE_BUILD_SERVICE_TIMEOUT_MINUTES), + NodeScope, + Dynamic + ); + + /** + * Setting to control how often the remote build service client polls the build service for the status of the job. + */ + public static final Setting KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL_SETTING = Setting.timeSetting( + KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL, + TimeValue.timeValueSeconds(KNN_DEFAULT_REMOTE_BUILD_SERVICE_POLL_INTERVAL_SECONDS), + NodeScope, + Dynamic + ); + + /** + * Keystore settings for build service HTTP authorization + */ + public static final Setting KNN_REMOTE_BUILD_SERVICE_USERNAME_SETTING = SecureSetting.secureString( + KNN_REMOTE_BUILD_SERVICE_USERNAME, + null + ); + public static final Setting KNN_REMOTE_BUILD_SERVICE_PASSWORD_SETTING = SecureSetting.secureString( + KNN_REMOTE_BUILD_SERVICE_PASSWORD, + null + ); + /** * Dynamic settings */ @@ -550,6 +601,26 @@ private Setting getSetting(String key) { return KNN_REMOTE_VECTOR_REPO_SETTING; } + if (KNN_REMOTE_BUILD_SERVICE_ENDPOINT.equals(key)) { + return KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING; + } + + if (KNN_REMOTE_BUILD_SERVICE_TIMEOUT.equals(key)) { + return KNN_REMOTE_BUILD_SERVICE_TIMEOUT_SETTING; + } + + if (KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL.equals(key)) { + return KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL_SETTING; + } + + if (KNN_REMOTE_BUILD_SERVICE_USERNAME.equals(key)) { + return KNN_REMOTE_BUILD_SERVICE_USERNAME_SETTING; + } + + if (KNN_REMOTE_BUILD_SERVICE_PASSWORD.equals(key)) { + return KNN_REMOTE_BUILD_SERVICE_PASSWORD_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -577,7 +648,12 @@ public List> getSettings() { KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, KNN_DERIVED_SOURCE_ENABLED_SETTING, KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING, - KNN_REMOTE_VECTOR_REPO_SETTING + KNN_REMOTE_VECTOR_REPO_SETTING, + KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING, + KNN_REMOTE_BUILD_SERVICE_TIMEOUT_SETTING, + KNN_REMOTE_BUILD_SERVICE_POLL_INTERVAL_SETTING, + KNN_REMOTE_BUILD_SERVICE_USERNAME_SETTING, + KNN_REMOTE_BUILD_SERVICE_PASSWORD_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java index 8555e2ad6..de40440c4 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java @@ -11,10 +11,13 @@ import org.opensearch.common.StopWatch; import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.index.IndexSettings; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.remote.RemoteBuildRequest; +import org.opensearch.knn.index.remote.RemoteIndexClient; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.Repository; @@ -22,10 +25,13 @@ import org.opensearch.repositories.blobstore.BlobStoreRepository; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.function.Supplier; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; +import static org.opensearch.knn.index.KNNSettings.state; /** * This class orchestrates building vector indices. It handles uploading data to a repository, submitting a remote @@ -54,7 +60,7 @@ public RemoteIndexBuildStrategy(Supplier repositoriesServic * @return whether to use the remote build feature */ public static boolean shouldBuildIndexRemotely(IndexSettings indexSettings) { - String vectorRepo = KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_REPO_SETTING.getKey()); + String vectorRepo = state().getSettingValue(KNN_REMOTE_VECTOR_REPO_SETTING.getKey()); return KNNFeatureFlags.isKNNRemoteVectorBuildEnabled() && indexSettings.getValue(KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING) && vectorRepo != null @@ -88,17 +94,18 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { log.debug("Repository write took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); stopWatch = new StopWatch().start(); - submitVectorBuild(); + RemoteBuildRequest buildRequest = constructBuildRequest(indexInfo); + String jobId = RemoteIndexClient.getInstance().submitVectorBuild(buildRequest); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); stopWatch = new StopWatch().start(); - awaitVectorBuild(); + String indexPath = awaitVectorBuild(jobId); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); stopWatch = new StopWatch().start(); - readFromRepository(); + readFromRepository(indexPath); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); } catch (Exception e) { @@ -156,14 +163,57 @@ private void submitVectorBuild() { /** * Wait on remote vector build to complete */ - private void awaitVectorBuild() { + private String awaitVectorBuild(String jobId) { throw new NotImplementedException(); } /** * Read constructed vector file from remote repository and write to IndexOutput */ - private void readFromRepository() { + private void readFromRepository(String indexPath) { throw new NotImplementedException(); } + + /** + * Construct the JSON request body and HTTP request for the index build request + * @return HttpExecuteRequest for the index build request with parameters set + */ + public RemoteBuildRequest constructBuildRequest(BuildIndexParams indexInfo) throws IOException { + String repositoryType = getRepository().getMetadata().type(); + String containerName = switch (repositoryType) { + case "s3" -> getRepository().getMetadata().settings().get("bucket"); + case "fs" -> getRepository().getMetadata().settings().get("location"); + default -> throw new IllegalStateException("Unexpected value: " + repositoryType); + }; + String vectorPath = null; // blobName + VECTOR_BLOB_FILE_EXTENSION + String docIdPath = null; // blobName + DOC_ID_FILE_EXTENSION + String tenantId = null; // indexSettings.getSettings().get(ClusterName.CLUSTER_NAME_SETTING.getKey()); + int dimension = 0; // TODO + int docCount = indexInfo.getTotalLiveDocs(); + String dataType = indexInfo.getVectorDataType().getValue(); // TODO need to fetch encoder param to get fp16 vs fp32 + String engine = indexInfo.getKnnEngine().getName(); + + String spaceType = indexInfo.getParameters().get(KNNConstants.SPACE_TYPE).toString(); // OR + + Map algorithmParams = new HashMap<>(); + algorithmParams.put("ef_construction", 100); + algorithmParams.put("m", 16); + + Map indexParameters = new HashMap<>(); + indexParameters.put("algorithm", "hnsw"); + indexParameters.put("algorithm_parameters", algorithmParams); + + return RemoteBuildRequest.builder() + .repositoryType(repositoryType) + .containerName(containerName) + .vectorPath(vectorPath) + .docIdPath(docIdPath) + .tenantId(tenantId) + .dimension(dimension) + .docCount(docCount) + .dataType(dataType) + .engine(engine) + .indexParameters(indexParameters) + .build(); + } } diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteBuildRequest.java b/src/main/java/org/opensearch/knn/index/remote/RemoteBuildRequest.java new file mode 100644 index 000000000..9cd19173f --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteBuildRequest.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import org.opensearch.common.xcontent.json.JsonXContent; +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +@Builder +@Getter +public class RemoteBuildRequest { + private final String repositoryType; + private final String containerName; + private final String vectorPath; + private final String docIdPath; + private final String tenantId; + private final int dimension; + private final int docCount; + private final String dataType; + private final String engine; + @Builder.Default + private final Map indexParameters = new HashMap<>(); + + public String toJson() throws IOException { + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + builder.startObject(); + builder.field("repository_type", repositoryType); + builder.field("container_name", containerName); + builder.field("vector_path", vectorPath); + builder.field("doc_id_path", docIdPath); + builder.field("tenant_id", tenantId); + builder.field("dimension", dimension); + builder.field("doc_count", docCount); + builder.field("data_type", dataType); + builder.field("engine", engine); + builder.field("index_parameters", indexParameters); + builder.endObject(); + return builder.toString(); + } + } + +} diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClient.java b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClient.java new file mode 100644 index 000000000..3f2e0bcd8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClient.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.NotImplementedException; +import org.apache.hc.client5.http.classic.methods.HttpGet; +import org.apache.hc.client5.http.classic.methods.HttpPost; +import org.apache.hc.client5.http.classic.methods.HttpUriRequestBase; +import org.apache.hc.client5.http.impl.classic.BasicHttpClientResponseHandler; +import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; +import org.apache.hc.client5.http.impl.classic.HttpClients; +import org.apache.hc.client5.http.utils.Base64; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.opensearch.core.common.settings.SecureString; +import org.opensearch.knn.index.KNNSettings; + +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; + +/** + * Class to handle all interactions with the remote vector build service. + * InterruptedExceptions will cause a fallback to local CPU build. + */ +@Log4j2 +public class RemoteIndexClient { + private static RemoteIndexClient INSTANCE; + private volatile CloseableHttpClient httpClient; + public static final int MAX_RETRIES = 1; // 2 total attempts + public static final long BASE_DELAY_MS = 1000; + + private static final ObjectMapper objectMapper = new ObjectMapper(); + + RemoteIndexClient() { + this.httpClient = createHttpClient(); + } + + /** + * Return the Singleton instance of the node's RemoteIndexClient + * @return RemoteIndexClient instance + */ + public static synchronized RemoteIndexClient getInstance() { + if (INSTANCE == null) { + INSTANCE = new RemoteIndexClient(); + } + return INSTANCE; + } + + /** + * Initialize the httpClient to be used + * @return The HTTP Client + */ + private CloseableHttpClient createHttpClient() { + // TODO The client will need to be rebuilt iff we decide to allow for retry configuration in the future + return HttpClients.custom().setRetryStrategy(new RemoteIndexClientRetryStrategy()).build(); + } + + /** + * Submit a build to the Remote Vector Build Service endpoint using round robin task assignment. + * @return job_id from the server response used to track the job + */ + public String submitVectorBuild(RemoteBuildRequest request) throws IOException { + URI endpoint = URI.create(KNNSettings.state().getSettingValue(KNNSettings.KNN_REMOTE_BUILD_SERVICE_ENDPOINT)); + HttpPost buildRequest = new HttpPost(endpoint + "/_build"); + buildRequest.setHeader("Content-Type", "application/json"); + buildRequest.setEntity(new StringEntity(request.toJson())); + authenticateRequest(buildRequest); + + String response = httpClient.execute(buildRequest, body -> { + if (body.getCode() != 200) { + throw new IOException("Failed to submit build request after retries with code: " + body.getCode()); + } + return EntityUtils.toString(body.getEntity()); + }); + + if (response == null) { + throw new IOException("Received 200 status code but response is null."); + } + + return getValueFromResponse(response, "job_id"); + } + + /** + * Await the completion of the index build by polling periodically and handling the returned statuses. + * @param jobId identifier from the server to track the job + * @return the path to the completed index + */ + private void awaitVectorBuild(String jobId) { + throw new NotImplementedException(); + } + + /** + * Helper method to directly get the status response for a given job ID + * @param jobId to check + * @return The entire HttpExecuteResponse for the status request + */ + public String getBuildStatus(String jobId) throws IOException { + URI endpoint = URI.create(KNNSettings.state().getSettingValue(KNNSettings.KNN_REMOTE_BUILD_SERVICE_ENDPOINT)); + HttpGet request = new HttpGet(endpoint + "/_status/" + jobId); + authenticateRequest(request); + return httpClient.execute(request, new BasicHttpClientResponseHandler()); + } + + /** + * Given a JSON response string, get a value for a specific key. Converts json {@literal } to Java null. + * @param responseBody The response to read + * @param key The key to lookup + * @return The value for the key, or null if not found + */ + public static String getValueFromResponse(String responseBody, String key) throws JsonProcessingException { + // TODO See if I can use OpenSearch XContent tools here to avoid Jackson dependency + ObjectNode jsonResponse = (ObjectNode) objectMapper.readTree(responseBody); + if (jsonResponse.has(key)) { + if (jsonResponse.get(key).isNull()) { + return null; + } + return jsonResponse.get(key).asText(); + } + throw new IllegalArgumentException("Key " + key + " not found in response"); + } + + /** + * Authenticate the HTTP request by manually setting the auth header. + * This is done to allow for dynamic credential updates. + * @param request to be authenticated + */ + public void authenticateRequest(HttpUriRequestBase request) { + SecureString username = KNNSettings.state().getSettingValue(KNNSettings.KNN_REMOTE_BUILD_SERVICE_USERNAME); + SecureString password = KNNSettings.state().getSettingValue(KNNSettings.KNN_REMOTE_BUILD_SERVICE_PASSWORD); + + if (password != null) { + final String auth = username + ":" + password.clone(); + final byte[] encodedAuth = Base64.encodeBase64(auth.getBytes(StandardCharsets.ISO_8859_1)); + final String authHeader = "Basic " + new String(encodedAuth); + request.setHeader(HttpHeaders.AUTHORIZATION, authHeader); + } + } + + /** + * Close the httpClient + */ + public void close() throws IOException { + if (httpClient != null) { + httpClient.close(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClientRetryStrategy.java b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClientRetryStrategy.java new file mode 100644 index 000000000..7aa411f77 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClientRetryStrategy.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import org.apache.hc.client5.http.impl.DefaultHttpRequestRetryStrategy; +import org.apache.hc.core5.http.ConnectionClosedException; +import org.apache.hc.core5.http.HttpResponse; +import org.apache.hc.core5.http.protocol.HttpContext; +import org.apache.hc.core5.util.TimeValue; + +import javax.net.ssl.SSLException; +import java.io.InterruptedIOException; +import java.net.ConnectException; +import java.net.NoRouteToHostException; +import java.net.UnknownHostException; +import java.util.Arrays; + +/** + * The public constructors for the Apache HTTP client default retry strategies allow customization of max retries + * and retry interval, but not retriable status codes. + * In order to add the other retriable status codes from our Remote Build API Contract, we must extend this class. + * @see org.apache.hc.client5.http.impl.DefaultHttpRequestRetryStrategy + */ +public class RemoteIndexClientRetryStrategy extends DefaultHttpRequestRetryStrategy { + public RemoteIndexClientRetryStrategy() { + super( + RemoteIndexClient.MAX_RETRIES, + TimeValue.ofMilliseconds(RemoteIndexClient.BASE_DELAY_MS), + Arrays.asList( + InterruptedIOException.class, + UnknownHostException.class, + ConnectException.class, + ConnectionClosedException.class, + NoRouteToHostException.class, + SSLException.class + ), + Arrays.asList(408, 429, 500, 502, 503, 504, 509) + ); + } + + /** + * Override retry interval setting to implement backoff strategy. This is only relevant for future implementations where we may increase the retry count from 1 max retry. + */ + @Override + public TimeValue getRetryInterval(HttpResponse response, int execCount, HttpContext context) { + if (response.getCode() == 429 || response.getCode() == 503) { + long delay = RemoteIndexClient.BASE_DELAY_MS; + long backoffDelay = delay * (long) Math.pow(2, execCount - 1); + return TimeValue.ofMilliseconds(Math.min(backoffDelay, TimeValue.ofMinutes(1).toMilliseconds())); + } + return super.getRetryInterval(response, execCount, context); + } +} diff --git a/src/test/java/org/opensearch/knn/index/remote/RemoteIndexClientTests.java b/src/test/java/org/opensearch/knn/index/remote/RemoteIndexClientTests.java new file mode 100644 index 000000000..93d725c7e --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/remote/RemoteIndexClientTests.java @@ -0,0 +1,116 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RemoteIndexClientTests extends OpenSearchSingleNodeTestCase { + + @Mock + protected ClusterService clusterService; + @Mock + protected ClusterSettings clusterSettings; + + protected AutoCloseable openMocks; + + private ObjectMapper mapper; + + @Before + public void setup() { + this.mapper = new ObjectMapper(); + openMocks = MockitoAnnotations.openMocks(this); + clusterService = mock(ClusterService.class); + Set> defaultClusterSettings = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + defaultClusterSettings.addAll( + KNNSettings.state().getSettings().stream().filter(s -> s.getProperties().contains(Setting.Property.NodeScope)).toList() + ); + KNNSettings.state().setClusterService(clusterService); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, defaultClusterSettings)); + } + + public void testGetHttpClient_success() throws IOException { + RemoteIndexClient client = RemoteIndexClient.getInstance(); + assertNotNull(client); + client.close(); + } + + public void testConstructBuildRequest() throws IOException { + Map algorithmParams = new HashMap<>(); + algorithmParams.put("ef_construction", 100); + algorithmParams.put("m", 16); + + Map indexParameters = new HashMap<>(); + indexParameters.put("algorithm", "hnsw"); + indexParameters.put("space_type", "l2"); + indexParameters.put("algorithm_parameters", algorithmParams); + + RemoteBuildRequest request = RemoteBuildRequest.builder() + .repositoryType("S3") + .containerName("MyVectorStore") + .vectorPath("MyVectorPath") + .docIdPath("MyDocIdPath") + .tenantId("MyTenant") + .dimension(256) + .docCount(1_000_000) + .dataType("fp32") + .engine("faiss") + .indexParameters(indexParameters) + .build(); + + String expectedJson = "{" + + "\"repository_type\":\"S3\"," + + "\"container_name\":\"MyVectorStore\"," + + "\"vector_path\":\"MyVectorPath\"," + + "\"doc_id_path\":\"MyDocIdPath\"," + + "\"tenant_id\":\"MyTenant\"," + + "\"dimension\":256," + + "\"doc_count\":1000000," + + "\"data_type\":\"fp32\"," + + "\"engine\":\"faiss\"," + + "\"index_parameters\":{" + + "\"space_type\":\"l2\"," + + "\"algorithm\":\"hnsw\"," + + "\"algorithm_parameters\":{" + + "\"ef_construction\":100," + + "\"m\":16" + + "}" + + "}" + + "}"; + assertEquals(mapper.readTree(expectedJson), mapper.readTree(request.toJson())); + } + + public void testGetValueFromResponse() throws JsonProcessingException { + String jobID = "{\"job_id\": \"job-1739930402\"}"; + assertEquals("job-1739930402", RemoteIndexClient.getValueFromResponse(jobID, "job_id")); + String failedIndexBuild = "{" + + "\"task_status\":\"FAILED_INDEX_BUILD\"," + + "\"error\":\"Index build process interrupted.\"," + + "\"index_path\": null" + + "}"; + String error = RemoteIndexClient.getValueFromResponse(failedIndexBuild, "error"); + assertEquals("Index build process interrupted.", error); + assertNull(RemoteIndexClient.getValueFromResponse(failedIndexBuild, "index_path")); + } +}