From dbddb29da5c28365cfc6105944ed94e76cacca68 Mon Sep 17 00:00:00 2001 From: owenhalpert Date: Thu, 27 Feb 2025 18:34:19 -0800 Subject: [PATCH] Refactor, add remote build request interface, complete FAISS request parameters, secure setting testing Signed-off-by: owenhalpert --- .../opensearch/knn/common/KNNConstants.java | 21 ++ .../remote/RemoteIndexBuildStrategy.java | 30 +- .../remote/RemoteStatusResponse.java | 15 + .../remote/VectorRepositoryAccessor.java | 3 +- .../knn/index/engine/KNNEngine.java | 6 + .../knn/index/engine/KNNLibrary.java | 10 + .../knn/index/engine/faiss/Faiss.java | 78 +++++ .../knn/index/engine/lucene/Lucene.java | 7 + .../knn/index/engine/nmslib/Nmslib.java | 7 + .../index/remote/HTTPRemoteBuildRequest.java | 65 ++++ .../index/remote/HTTPRemoteBuildResponse.java | 24 ++ .../knn/index/remote/RemoteBuildRequest.java | 94 ++++-- .../knn/index/remote/RemoteBuildResponse.java | 13 + .../knn/index/remote/RemoteIndexClient.java | 40 ++- .../RemoteIndexClientRetryStrategy.java | 25 +- .../index/remote/RemoteIndexHTTPClient.java | 222 ++++--------- .../org/opensearch/knn/plugin/KNNPlugin.java | 4 +- .../remote/RemoteIndexHTTPClientTests.java | 306 ++++++++++-------- 18 files changed, 619 insertions(+), 351 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteStatusResponse.java create mode 100644 src/main/java/org/opensearch/knn/index/remote/HTTPRemoteBuildRequest.java create mode 100644 src/main/java/org/opensearch/knn/index/remote/HTTPRemoteBuildResponse.java create mode 100644 src/main/java/org/opensearch/knn/index/remote/RemoteBuildResponse.java diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 14e95887c7..c2d4289685 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -165,4 +165,25 @@ public class KNNConstants { public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY = "knn-derived-source-enabled"; public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE = "true"; public static final String DERIVED_VECTOR_FIELD_ATTRIBUTE_FALSE_VALUE = "false"; + + // Remote build constants + public static final String BUILD_ENDPOINT = "/_build"; + public static final String STATUS_ENDPOINT = "/_status"; + public static final String S3 = "s3"; + public static final String BUCKET = "bucket"; + // Build request keys + public static final String ALGORITHM = "algorithm"; + public static final String ALGORITHM_PARAMETERS = "algorithm_parameters"; + public static final String INDEX_PARAMETERS = "index_parameters"; + public static final String DOC_COUNT = "doc_count"; + public static final String TENANT_ID = "tenant_id"; + public static final String DOC_ID_PATH = "doc_id_path"; + public static final String VECTOR_PATH = "vector_path"; + public static final String CONTAINER_NAME = "container_name"; + public static final String REPOSITORY_TYPE = "repository_type"; + // Server responses + public static final String JOB_ID = "job_id"; + public static final String TASK_STATUS = "task_status"; + public static final String INDEX_PATH = "index_path"; + public static final String ERROR_MESSAGE = "error_message"; } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java index 9c3706f1b6..575c493c32 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java @@ -6,7 +6,6 @@ package org.opensearch.knn.index.codec.nativeindex.remote; import lombok.extern.log4j.Log4j2; -import org.apache.commons.lang.NotImplementedException; import org.opensearch.common.StopWatch; import org.opensearch.common.UUIDs; import org.opensearch.common.annotation.ExperimentalApi; @@ -14,6 +13,9 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.remote.RemoteBuildRequest; +import org.opensearch.knn.index.remote.RemoteBuildResponse; +import org.opensearch.knn.index.remote.RemoteIndexClient; import org.opensearch.knn.index.remote.RemoteIndexHTTPClient; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.Repository; @@ -25,6 +27,7 @@ import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING; +import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING; import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; /** @@ -127,18 +130,24 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { log.debug("Repository write took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); stopWatch = new StopWatch().start(); - String jobId = RemoteIndexHTTPClient.getInstance() - .submitVectorBuild(indexSettings, indexInfo, getRepository().getMetadata(), blobName); + RemoteIndexClient client = getRemoteIndexClient(); + RemoteBuildRequest remoteBuildRequest = client.constructBuildRequest( + indexSettings, + indexInfo, + getRepository().getMetadata(), + blobName + ); + RemoteBuildResponse remoteBuildResponse = client.submitVectorBuild(remoteBuildRequest); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); stopWatch = new StopWatch().start(); - awaitVectorBuild(); + RemoteStatusResponse remoteStatusResponse = client.awaitVectorBuild(remoteBuildResponse); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); stopWatch = new StopWatch().start(); - vectorRepositoryAccessor.readFromRepository(); + vectorRepositoryAccessor.readFromRepository(remoteStatusResponse.getIndexPath(), indexInfo.getIndexOutputWithBuffer()); time_in_millis = stopWatch.stop().totalTime().millis(); log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); } catch (Exception e) { @@ -167,9 +176,14 @@ private BlobStoreRepository getRepository() throws RepositoryMissingException { } /** - * Wait on remote vector build to complete + * Determine which implementation of RemoteIndexClient to be used by the build strategy + * @return Concrete RemoteIndexClient implementation */ - private void awaitVectorBuild() { - throw new NotImplementedException(); + private RemoteIndexClient getRemoteIndexClient() { + String endpoint = KNNSettings.state().getSettingValue(KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING.getKey()); + if (endpoint == null || endpoint.isEmpty()) { + throw new IllegalArgumentException("No endpoint set for RemoteIndexClient"); + } + return RemoteIndexHTTPClient.getInstance(); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteStatusResponse.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteStatusResponse.java new file mode 100644 index 0000000000..a9a87722a7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteStatusResponse.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.remote; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +@AllArgsConstructor +public class RemoteStatusResponse { + private String indexPath; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorRepositoryAccessor.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorRepositoryAccessor.java index 7d93177d2f..f80a76461d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorRepositoryAccessor.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/VectorRepositoryAccessor.java @@ -7,6 +7,7 @@ import org.apache.commons.lang.NotImplementedException; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import java.io.IOException; @@ -36,7 +37,7 @@ void writeToRepository( /** * Read constructed vector file from remote repository and write to IndexOutput */ - default void readFromRepository() { + default void readFromRepository(String path, IndexOutputWithBuffer indexOutputWithBuffer) { throw new NotImplementedException(); } } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index 0bd4b0f27a..dd9058a3e7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableSet; import org.opensearch.common.ValidationException; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.engine.faiss.Faiss; import org.opensearch.knn.index.engine.lucene.Lucene; import org.opensearch.knn.index.engine.nmslib.Nmslib; @@ -177,6 +178,11 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext( return knnLibrary.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); } + @Override + public Map getRemoteIndexingParameters(BuildIndexParams params) { + return knnLibrary.getRemoteIndexingParameters(params); + } + @Override public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) { return knnLibrary.getKNNLibrarySearchContext(methodName); diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java index 29e6442f48..338a79b80c 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java @@ -7,9 +7,11 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import java.util.Collections; import java.util.List; +import java.util.Map; /** * KNNLibrary is an interface that helps the plugin communicate with k-NN libraries @@ -147,4 +149,12 @@ default List mmapFileExtensions() { default boolean supportsRemoteIndexBuild() { return false; } + + /** + * Get the remote build supported index parameter mapping to be sent to the remote build service. + * @param params to parse + */ + default Map getRemoteIndexingParameters(BuildIndexParams params) { + throw new UnsupportedOperationException("This method must be implemented by the implementing class"); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java index d23a475aa7..2956a29e5d 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.engine.KNNMethod; import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -15,11 +16,29 @@ import org.opensearch.knn.index.engine.NativeLibrary; import org.opensearch.knn.index.engine.ResolvedMethodContext; +import java.util.HashMap; import java.util.Map; import java.util.function.Function; +import static org.opensearch.knn.common.KNNConstants.ALGORITHM; +import static org.opensearch.knn.common.KNNConstants.ALGORITHM_PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST_DEFAULT; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_DEFAULT; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; /** * Implements NativeLibrary for the faiss native library @@ -109,6 +128,65 @@ public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { return spaceType.scoreToDistanceTranslation(score); } + /** + * Get the parameters that need to be passed to the remote build service for training + * @param indexInfo to parse + * @return Map of parameters to be used as "index_parameters" + */ + @Override + public Map getRemoteIndexingParameters(BuildIndexParams indexInfo) { + Map indexParameters = new HashMap<>(); + String methodName = (String) indexInfo.getParameters().get(NAME); + indexParameters.put(ALGORITHM, methodName); + indexParameters.put(METHOD_PARAMETER_SPACE_TYPE, indexInfo.getParameters().getOrDefault(SPACE_TYPE, INDEX_KNN_DEFAULT_SPACE_TYPE)); + + assert (indexInfo.getParameters().containsKey(PARAMETERS)); + Object innerParams = indexInfo.getParameters().get(PARAMETERS); + assert (innerParams instanceof Map); + { + Map algorithmParams = new HashMap<>(); + Map innerMap = (Map) innerParams; + switch (methodName) { + case METHOD_HNSW -> { + algorithmParams.put( + METHOD_PARAMETER_EF_CONSTRUCTION, + innerMap.getOrDefault(METHOD_PARAMETER_EF_CONSTRUCTION, INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION) + ); + algorithmParams.put( + METHOD_PARAMETER_EF_SEARCH, + innerMap.getOrDefault(METHOD_PARAMETER_EF_SEARCH, INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH) + ); + Object indexDescription = indexInfo.getParameters().get(INDEX_DESCRIPTION_PARAMETER); + assert indexDescription instanceof String; + algorithmParams.put(METHOD_PARAMETER_M, getMFromIndexDescription((String) indexDescription)); + } + case METHOD_IVF -> { + algorithmParams.put( + METHOD_PARAMETER_NLIST, + innerMap.getOrDefault(METHOD_PARAMETER_NLIST, METHOD_PARAMETER_NLIST_DEFAULT) + ); + algorithmParams.put( + METHOD_PARAMETER_NPROBES, + innerMap.getOrDefault(METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NPROBES_DEFAULT) + ); + } + } + indexParameters.put(ALGORITHM_PARAMETERS, algorithmParams); + } + return indexParameters; + } + + public static int getMFromIndexDescription(String indexDescription) { + int commaIndex = indexDescription.indexOf(","); + if (commaIndex == -1) { + throw new IllegalArgumentException("Invalid index description: " + indexDescription); + } + String hnswPart = indexDescription.substring(0, commaIndex); + int m = Integer.parseInt(hnswPart.substring(4)); + assert (m > 1 && m < 100); + return m; + } + @Override public ResolvedMethodContext resolveMethod( KNNMethodContext knnMethodContext, diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/Lucene.java b/src/main/java/org/opensearch/knn/index/engine/lucene/Lucene.java index db516d309a..c98185fd19 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/Lucene.java @@ -8,7 +8,9 @@ import com.google.common.collect.ImmutableMap; import org.apache.lucene.util.Version; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.engine.JVMLibrary; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNMethod; import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -89,6 +91,11 @@ public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { return score; } + @Override + public Map getRemoteIndexingParameters(BuildIndexParams params) { + throw new UnsupportedOperationException(KNNEngine.LUCENE.getName() + " engine not supported for remote index build"); + } + @Override public List mmapFileExtensions() { return List.of("vec", "vex"); diff --git a/src/main/java/org/opensearch/knn/index/engine/nmslib/Nmslib.java b/src/main/java/org/opensearch/knn/index/engine/nmslib/Nmslib.java index 4d7f7f4237..6a13d47274 100644 --- a/src/main/java/org/opensearch/knn/index/engine/nmslib/Nmslib.java +++ b/src/main/java/org/opensearch/knn/index/engine/nmslib/Nmslib.java @@ -7,6 +7,8 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNMethod; import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -60,6 +62,11 @@ public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { return score; } + @Override + public Map getRemoteIndexingParameters(BuildIndexParams params) { + throw new UnsupportedOperationException(KNNEngine.NMSLIB.getName() + " not supported for remote index build"); + } + @Override public ResolvedMethodContext resolveMethod( KNNMethodContext knnMethodContext, diff --git a/src/main/java/org/opensearch/knn/index/remote/HTTPRemoteBuildRequest.java b/src/main/java/org/opensearch/knn/index/remote/HTTPRemoteBuildRequest.java new file mode 100644 index 0000000000..c570257ce8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/HTTPRemoteBuildRequest.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import lombok.Getter; +import org.opensearch.cluster.metadata.RepositoryMetadata; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexSettings; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; + +import java.io.IOException; +import java.net.URI; + +import static org.opensearch.knn.common.KNNConstants.CONTAINER_NAME; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.DOC_COUNT; +import static org.opensearch.knn.common.KNNConstants.VECTOR_PATH; +import static org.opensearch.knn.common.KNNConstants.DOC_ID_PATH; +import static org.opensearch.knn.common.KNNConstants.INDEX_PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.REPOSITORY_TYPE; +import static org.opensearch.knn.common.KNNConstants.TENANT_ID; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING; + +/** + * RemoteBuildRequest implementation for HTTP clients that sets the endpoint and offers a JSON conversion. + */ +@Getter +public class HTTPRemoteBuildRequest extends RemoteBuildRequest { + private final URI endpoint; + + public HTTPRemoteBuildRequest( + IndexSettings indexSettings, + BuildIndexParams indexInfo, + RepositoryMetadata repositoryMetadata, + String blobName + ) throws IOException { + super(indexSettings, indexInfo, repositoryMetadata, blobName); + this.endpoint = URI.create(KNNSettings.state().getSettingValue(KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING.getKey())); + } + + public String toJson() throws IOException { + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + builder.startObject(); + builder.field(REPOSITORY_TYPE, repositoryType); + builder.field(CONTAINER_NAME, containerName); + builder.field(VECTOR_PATH, vectorPath); + builder.field(DOC_ID_PATH, docIdPath); + builder.field(TENANT_ID, tenantId); + builder.field(DIMENSION, dimension); + builder.field(DOC_COUNT, docCount); + builder.field(VECTOR_DATA_TYPE_FIELD, dataType); + builder.field(KNN_ENGINE, engine); + builder.field(INDEX_PARAMETERS, indexParameters); + builder.endObject(); + return builder.toString(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/remote/HTTPRemoteBuildResponse.java b/src/main/java/org/opensearch/knn/index/remote/HTTPRemoteBuildResponse.java new file mode 100644 index 0000000000..037a05a18a --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/HTTPRemoteBuildResponse.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import lombok.Getter; + +import java.net.URI; + +/** + * HTTP-specific implementation of RemoteBuildResponse to pass the endpoint back to awaitVectorBuild + */ +@Getter +public class HTTPRemoteBuildResponse implements RemoteBuildResponse { + private final String jobId; + private final URI endpoint; + + public HTTPRemoteBuildResponse(String requestId, URI endpoint) { + this.jobId = requestId; + this.endpoint = endpoint; + } +} diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteBuildRequest.java b/src/main/java/org/opensearch/knn/index/remote/RemoteBuildRequest.java index 9cd19173ff..225a660db3 100644 --- a/src/main/java/org/opensearch/knn/index/remote/RemoteBuildRequest.java +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteBuildRequest.java @@ -5,46 +5,72 @@ package org.opensearch.knn.index.remote; -import org.opensearch.common.xcontent.json.JsonXContent; -import lombok.Builder; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.metadata.RepositoryMetadata; import lombok.Getter; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexSettings; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import java.io.IOException; -import java.util.HashMap; import java.util.Map; -@Builder +import static org.opensearch.knn.common.KNNConstants.BUCKET; +import static org.opensearch.knn.common.KNNConstants.S3; + @Getter -public class RemoteBuildRequest { - private final String repositoryType; - private final String containerName; - private final String vectorPath; - private final String docIdPath; - private final String tenantId; - private final int dimension; - private final int docCount; - private final String dataType; - private final String engine; - @Builder.Default - private final Map indexParameters = new HashMap<>(); - - public String toJson() throws IOException { - try (XContentBuilder builder = JsonXContent.contentBuilder()) { - builder.startObject(); - builder.field("repository_type", repositoryType); - builder.field("container_name", containerName); - builder.field("vector_path", vectorPath); - builder.field("doc_id_path", docIdPath); - builder.field("tenant_id", tenantId); - builder.field("dimension", dimension); - builder.field("doc_count", docCount); - builder.field("data_type", dataType); - builder.field("engine", engine); - builder.field("index_parameters", indexParameters); - builder.endObject(); - return builder.toString(); +public abstract class RemoteBuildRequest { + public String repositoryType; + public String containerName; + public String vectorPath; + public String docIdPath; + public String tenantId; + public int dimension; + public int docCount; + public String dataType; + public String engine; + public Map indexParameters; + + /** + * Construct the RemoteBuildRequest object for the index build request + * @param indexSettings Index settings + * @param indexInfo Index parameters + * @param repositoryMetadata Metadata of the repository containing the index + * @param blobName File name generated by the Build Strategy with a UUID + */ + public RemoteBuildRequest( + IndexSettings indexSettings, + BuildIndexParams indexInfo, + RepositoryMetadata repositoryMetadata, + String blobName + ) throws IOException { + String repositoryType = repositoryMetadata.type(); + String containerName; + switch (repositoryType) { + case S3 -> containerName = repositoryMetadata.settings().get(BUCKET); + default -> throw new IllegalArgumentException( + "Repository type " + repositoryType + " is not supported by the remote build service" + ); } - } + String vectorDataType = indexInfo.getVectorDataType().getValue(); + KNNVectorValues vectorValues = indexInfo.getKnnVectorValuesSupplier().get(); + KNNCodecUtil.initializeVectorValues(vectorValues); + assert (vectorValues.dimension() > 0); + + Map indexParameters = indexInfo.getKnnEngine().getRemoteIndexingParameters(indexInfo); + + this.repositoryType = repositoryType; + this.containerName = containerName; + this.vectorPath = blobName + RemoteIndexBuildStrategy.VECTOR_BLOB_FILE_EXTENSION; + this.docIdPath = blobName + RemoteIndexBuildStrategy.DOC_ID_FILE_EXTENSION; + this.tenantId = indexSettings.getSettings().get(ClusterName.CLUSTER_NAME_SETTING.getKey()); + this.dimension = vectorValues.dimension(); + this.docCount = indexInfo.getTotalLiveDocs(); + this.dataType = vectorDataType; + this.engine = indexInfo.getKnnEngine().getName(); + this.indexParameters = indexParameters; + } } diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteBuildResponse.java b/src/main/java/org/opensearch/knn/index/remote/RemoteBuildResponse.java new file mode 100644 index 0000000000..868d0e8ac0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteBuildResponse.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +/** + * Generic remote build response interface + */ +public interface RemoteBuildResponse { + String getJobId(); +} diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClient.java b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClient.java index 6da0c8e23d..33293e3db5 100644 --- a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClient.java +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClient.java @@ -8,33 +8,41 @@ import org.opensearch.cluster.metadata.RepositoryMetadata; import org.opensearch.index.IndexSettings; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.nativeindex.remote.RemoteStatusResponse; import java.io.IOException; /** * Interface which dictates how we interact with a remote index build service. */ -interface RemoteIndexClient { +public interface RemoteIndexClient { + + /** + * Submit a build to the Remote Vector Build Service. + * @return RemoteBuildResponse from the server + */ + RemoteBuildResponse submitVectorBuild(RemoteBuildRequest remoteBuildRequest) throws IOException; + + /** + * Await the completion of the index build and for the server to return the path to the completed index + * @param remoteBuildResponse the /_build request response from the server + * @return remoteStatusResponse from the server + */ + RemoteStatusResponse awaitVectorBuild(RemoteBuildResponse remoteBuildResponse); + /** - * Submit an index build request to the build service endpoint. - * @param indexSettings IndexSettings for the index being built - * @param indexInfo BuildIndexParams for the index being built - * @param repositoryMetadata RepositoryMetadata representing the registered repo - * @param blobName The name of the blob written to the repo, to be suffixed with ".knnvec" or ".knndid" - * @return job_id from the server response used to track the job - * @throws IOException if there is an issue with the request + * Construct the RemoteBuildRequest from the given parameters + * @param indexSettings IndexSettings to use to get the repository metadata + * @param indexInfo BuildIndexParams to use to get the index info + * @param repositoryMetadata RepositoryMetadata to use to get the repository type + * @param blobName blob name to use to get the blob name + * @return RemoteBuildRequest to use to submit the build + * @throws IOException if there is an error constructing the request */ - String submitVectorBuild( + RemoteBuildRequest constructBuildRequest( IndexSettings indexSettings, BuildIndexParams indexInfo, RepositoryMetadata repositoryMetadata, String blobName ) throws IOException; - - /** - * Await the completion of the index build and for the server to return the path to the completed index - * @param jobId identifier from the server to track the job - * @return the path to the completed index - */ - String awaitVectorBuild(String jobId); } diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClientRetryStrategy.java b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClientRetryStrategy.java index 256ed017d8..8cabaf2b6a 100644 --- a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClientRetryStrategy.java +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexClientRetryStrategy.java @@ -16,6 +16,13 @@ import java.net.UnknownHostException; import java.util.List; +import static org.apache.hc.core5.http.HttpStatus.SC_BAD_GATEWAY; +import static org.apache.hc.core5.http.HttpStatus.SC_CONFLICT; +import static org.apache.hc.core5.http.HttpStatus.SC_INTERNAL_SERVER_ERROR; +import static org.apache.hc.core5.http.HttpStatus.SC_REQUEST_TIMEOUT; +import static org.apache.hc.core5.http.HttpStatus.SC_SERVICE_UNAVAILABLE; +import static org.apache.hc.core5.http.HttpStatus.SC_TOO_MANY_REQUESTS; + /** * The public constructors for the Apache HTTP client default retry strategies allow customization of max retries * and retry interval, but not retryable status codes. @@ -23,12 +30,24 @@ * @see org.apache.hc.client5.http.impl.DefaultHttpRequestRetryStrategy */ public class RemoteIndexClientRetryStrategy extends DefaultHttpRequestRetryStrategy { - private static final List retryableCodes = List.of(408, 429, 500, 502, 503, 504, 509); + private static final int SC_BANDWIDTH_LIMIT_EXCEEDED = 509; + private static final int MAX_RETRIES = 1; // 2 total attempts + private static final long BASE_DELAY_MS = 1000; + + private static final List retryableCodes = List.of( + SC_REQUEST_TIMEOUT, + SC_TOO_MANY_REQUESTS, + SC_INTERNAL_SERVER_ERROR, + SC_BAD_GATEWAY, + SC_SERVICE_UNAVAILABLE, + SC_CONFLICT, + SC_BANDWIDTH_LIMIT_EXCEEDED + ); public RemoteIndexClientRetryStrategy() { super( - RemoteIndexHTTPClient.MAX_RETRIES, - TimeValue.ofMilliseconds(RemoteIndexHTTPClient.BASE_DELAY_MS), + MAX_RETRIES, + TimeValue.ofMilliseconds(BASE_DELAY_MS), List.of( InterruptedIOException.class, UnknownHostException.class, diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexHTTPClient.java b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexHTTPClient.java index 5d32f0ab75..2e18bcd830 100644 --- a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexHTTPClient.java +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexHTTPClient.java @@ -14,51 +14,29 @@ import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; import org.apache.hc.client5.http.impl.classic.HttpClients; import org.apache.hc.client5.http.utils.Base64; +import org.apache.hc.core5.http.ContentType; import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.HttpStatus; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.io.entity.StringEntity; -import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.metadata.RepositoryMetadata; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.settings.SecureString; import org.opensearch.index.IndexSettings; -import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; -import org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.codec.nativeindex.remote.RemoteStatusResponse; import java.io.Closeable; import java.io.IOException; -import java.net.URI; import java.nio.charset.StandardCharsets; import java.security.AccessController; -import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; -import java.util.HashMap; -import java.util.Map; -import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; -import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST_DEFAULT; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_DEFAULT; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.NAME; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; -import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION; -import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH; -import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M; -import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; +import static org.apache.hc.core5.http.HttpStatus.SC_OK; +import static org.opensearch.knn.common.KNNConstants.JOB_ID; import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_CLIENT_PASSWORD_SETTING; import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_CLIENT_USERNAME_SETTING; -import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING; /** * Class to handle all interactions with the remote vector build service. @@ -66,11 +44,9 @@ */ @Log4j2 public class RemoteIndexHTTPClient implements RemoteIndexClient, Closeable { + public static final String BASIC_PREFIX = "Basic "; private static RemoteIndexHTTPClient INSTANCE; private volatile CloseableHttpClient httpClient; - protected static final int MAX_RETRIES = 1; // 2 total attempts - protected static final long BASE_DELAY_MS = 1000; - protected static final String BUILD_ENDPOINT = "/_build"; private static final ObjectMapper objectMapper = new ObjectMapper(); private String authHeader = null; @@ -100,28 +76,18 @@ private static CloseableHttpClient createHttpClient() { /** * Submit a build to the Remote Vector Build Service endpoint. - * @return job_id from the server response used to track the job + * @return RemoteBuildResponse containing job_id from the server response used to track the job */ @Override - public String submitVectorBuild( - IndexSettings indexSettings, - BuildIndexParams indexInfo, - RepositoryMetadata repositoryMetadata, - String blobName - ) throws IOException { - RemoteBuildRequest request = constructBuildRequest(indexSettings, indexInfo, repositoryMetadata, blobName); - URI endpoint = URI.create(KNNSettings.state().getSettingValue(KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING.getKey())); - HttpPost buildRequest = new HttpPost(endpoint + BUILD_ENDPOINT); - buildRequest.setHeader("Content-Type", "application/json"); - buildRequest.setEntity(new StringEntity(request.toJson())); - if (authHeader != null) { - buildRequest.setHeader(HttpHeaders.AUTHORIZATION, authHeader); - } + public RemoteBuildResponse submitVectorBuild(RemoteBuildRequest remoteBuildRequest) throws IOException { + assert (remoteBuildRequest instanceof HTTPRemoteBuildRequest); + HTTPRemoteBuildRequest request = (HTTPRemoteBuildRequest) remoteBuildRequest; + HttpPost buildRequest = getHttpPost(request); try { String response = AccessController.doPrivileged( (PrivilegedExceptionAction) () -> httpClient.execute(buildRequest, body -> { - if (body.getCode() != 200) { + if (body.getCode() < SC_OK || body.getCode() > HttpStatus.SC_MULTIPLE_CHOICES) { throw new IOException("Failed to submit build request, got status code: " + body.getCode()); } return EntityUtils.toString(body.getEntity()); @@ -129,25 +95,53 @@ public String submitVectorBuild( ); if (response == null) { - throw new IOException("Received 200 status code but response is null."); + throw new IOException("Received success status code but response is null."); } - return getValueFromResponse(response, "job_id"); - } catch (PrivilegedActionException e) { - throw new IOException("Failed to execute HTTP request", e.getException()); + return new HTTPRemoteBuildResponse(getValueFromResponse(response, JOB_ID), request.getEndpoint()); + } catch (Exception e) { + throw new IOException("Failed to execute HTTP request", e); } } + private HttpPost getHttpPost(HTTPRemoteBuildRequest request) throws IOException { + HttpPost buildRequest = new HttpPost(request.getEndpoint() + KNNConstants.BUILD_ENDPOINT); + buildRequest.setHeader(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.toString()); + buildRequest.setEntity(new StringEntity(request.toJson())); + if (authHeader != null) { + buildRequest.setHeader(HttpHeaders.AUTHORIZATION, authHeader); + } + return buildRequest; + } + /** * Await the completion of the index build by polling periodically and handling the returned statuses until timeout. - * @param jobId identifier from the server to track the job - * @return the path to the completed index + * @param remoteBuildResponse containing job_id from the server response used to track the job + * @return RemoteStatusResponse containing the path to the completed index */ @Override - public String awaitVectorBuild(String jobId) { + public RemoteStatusResponse awaitVectorBuild(RemoteBuildResponse remoteBuildResponse) { throw new NotImplementedException(); } + /** + * Construct the HTTP specific build request (with endpoint and .toJson method) + * @param indexSettings IndexSettings for the index being built + * @param indexInfo BuildIndexParams for the index being built + * @param repositoryMetadata Metadata of the vector repository + * @param blobName The name of the blob written to the repo, to be suffixed with ".knnvec" or ".knndid" + * @return RemoteBuildRequest with parameters set + */ + @Override + public RemoteBuildRequest constructBuildRequest( + IndexSettings indexSettings, + BuildIndexParams indexInfo, + RepositoryMetadata repositoryMetadata, + String blobName + ) throws IOException { + return new HTTPRemoteBuildRequest(indexSettings, indexInfo, repositoryMetadata, blobName); + } + /** * Given a JSON response string, get a value for a specific key. Converts json {@literal } to Java null. * @param responseBody The response to read @@ -167,113 +161,10 @@ static String getValueFromResponse(String responseBody, String key) throws JsonP } /** - * Construct the RemoteBuildRequest object for the index build request - * @param indexSettings Index settings - * @param indexInfo Index parameters - * @param repositoryMetadata Metadata of the repository containing the index - * @param blobName File name generated by the Build Strategy with a UUID - * @return RemoteBuildRequest with parameters set + * Set the global auth header to use the refreshed secure settings + * @param settings Settings to use to get the credentials */ - RemoteBuildRequest constructBuildRequest( - IndexSettings indexSettings, - BuildIndexParams indexInfo, - RepositoryMetadata repositoryMetadata, - String blobName - ) throws IOException { - String repositoryType = repositoryMetadata.type(); - String containerName; - switch (repositoryType) { - case "s3" -> containerName = repositoryMetadata.settings().get("bucket"); - default -> throw new IllegalArgumentException( - "Repository type " + repositoryType + " is not supported by the remote build service" - ); - } - String vectorDataType = indexInfo.getVectorDataType().getValue(); - - KNNVectorValues vectorValues = indexInfo.getKnnVectorValuesSupplier().get(); - KNNCodecUtil.initializeVectorValues(vectorValues); - assert (vectorValues.dimension() > 0); - - Map indexParameters = null; - if (indexInfo.getParameters() != null) { - indexParameters = constructIndexParams(indexInfo); - } - - return RemoteBuildRequest.builder() - .repositoryType(repositoryType) - .containerName(containerName) - .vectorPath(blobName + RemoteIndexBuildStrategy.VECTOR_BLOB_FILE_EXTENSION) - .docIdPath(blobName + RemoteIndexBuildStrategy.DOC_ID_FILE_EXTENSION) - .tenantId(indexSettings.getSettings().get(ClusterName.CLUSTER_NAME_SETTING.getKey())) - .dimension(vectorValues.dimension()) - .docCount(indexInfo.getTotalLiveDocs()) - .dataType(vectorDataType) - .engine(indexInfo.getKnnEngine().getName()) - .indexParameters(indexParameters) - .build(); - } - - /** - * Helper method to construct the index parameter object. Depending on the engine and algorithm, different parameters are needed. - * @param indexInfo Index parameters - * @return Map of necessary index parameters - */ - private Map constructIndexParams(BuildIndexParams indexInfo) { - Map indexParameters = new HashMap<>(); - String methodName = (String) indexInfo.getParameters().get(NAME); - indexParameters.put("algorithm", methodName); - indexParameters.put(METHOD_PARAMETER_SPACE_TYPE, indexInfo.getParameters().getOrDefault(SPACE_TYPE, INDEX_KNN_DEFAULT_SPACE_TYPE)); - - if (indexInfo.getParameters().containsKey(PARAMETERS)) { - Object innerParams = indexInfo.getParameters().get(PARAMETERS); - if (innerParams instanceof Map) { - Map algorithmParams = new HashMap<>(); - Map innerMap = (Map) innerParams; - switch (methodName) { - case METHOD_HNSW -> { - algorithmParams.put( - METHOD_PARAMETER_EF_CONSTRUCTION, - innerMap.getOrDefault(METHOD_PARAMETER_EF_CONSTRUCTION, INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION) - ); - algorithmParams.put(METHOD_PARAMETER_M, innerMap.getOrDefault(METHOD_PARAMETER_M, INDEX_KNN_DEFAULT_ALGO_PARAM_M)); - if (indexInfo.getKnnEngine().getName().equals(FAISS_NAME)) { - algorithmParams.put( - METHOD_PARAMETER_EF_SEARCH, - innerMap.getOrDefault(METHOD_PARAMETER_EF_SEARCH, INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH) - ); - } - } - case METHOD_IVF -> { - algorithmParams.put( - METHOD_PARAMETER_NLIST, - innerMap.getOrDefault(METHOD_PARAMETER_NLIST, METHOD_PARAMETER_NLIST_DEFAULT) - ); - algorithmParams.put( - METHOD_PARAMETER_NPROBES, - innerMap.getOrDefault(METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NPROBES_DEFAULT) - ); - } - } - indexParameters.put("algorithm_parameters", algorithmParams); - } - } - return indexParameters; - } - - /** - * Close the httpClient - */ - public void close() throws IOException { - if (httpClient != null) { - httpClient.close(); - } - } - - /** - * Rebuild the httpClient with the new credentials - * @param settings Settings to use to get the new credentials - */ - public void reloadSecureSettings(Settings settings) { + public void reloadAuthHeader(Settings settings) { SecureString username = KNN_REMOTE_BUILD_CLIENT_USERNAME_SETTING.get(settings); SecureString password = KNN_REMOTE_BUILD_CLIENT_PASSWORD_SETTING.get(settings); @@ -283,9 +174,18 @@ public void reloadSecureSettings(Settings settings) { } final String auth = username + ":" + password.clone(); final byte[] encodedAuth = Base64.encodeBase64(auth.getBytes(StandardCharsets.ISO_8859_1)); - this.authHeader = "Basic " + new String(encodedAuth); + this.authHeader = BASIC_PREFIX + new String(encodedAuth); } else { this.authHeader = null; } } + + /** + * Close the httpClient + */ + public void close() throws IOException { + if (httpClient != null) { + httpClient.close(); + } + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 07665f2e9e..812a6071c1 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -391,9 +391,9 @@ public void onNodeStarted(DiscoveryNode localNode) { * Update the secure settings by passing the updated settings down upon reload */ @Override - public void reload(Settings settings) throws Exception { + public void reload(Settings settings) { if (KNNFeatureFlags.isKNNRemoteVectorBuildEnabled()) { - RemoteIndexHTTPClient.getInstance().reloadSecureSettings(settings); + RemoteIndexHTTPClient.getInstance().reloadAuthHeader(settings); } } } diff --git a/src/test/java/org/opensearch/knn/index/remote/RemoteIndexHTTPClientTests.java b/src/test/java/org/opensearch/knn/index/remote/RemoteIndexHTTPClientTests.java index fef98e65b4..1fe404d59b 100644 --- a/src/test/java/org/opensearch/knn/index/remote/RemoteIndexHTTPClientTests.java +++ b/src/test/java/org/opensearch/knn/index/remote/RemoteIndexHTTPClientTests.java @@ -2,14 +2,15 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.knn.index.remote; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.codec.binary.Base64; import org.apache.hc.client5.http.classic.methods.HttpPost; import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.ProtocolException; import org.apache.hc.core5.http.io.HttpClientResponseHandler; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -21,23 +22,22 @@ import org.opensearch.cluster.metadata.RepositoryMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.MockSecureSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.index.IndexSettings; -import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; -import org.opensearch.repositories.RepositoriesService; -import org.opensearch.repositories.blobstore.BlobStoreRepository; import org.opensearch.test.OpenSearchSingleNodeTestCase; import java.io.IOException; +import java.net.URI; import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -46,29 +46,32 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.*; +import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_CLIENT_PASSWORD_SETTING; +import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_CLIENT_USERNAME_SETTING; import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING; import static org.opensearch.knn.index.SpaceType.L2; import static org.opensearch.knn.index.VectorDataType.FLOAT; import static org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy.DOC_ID_FILE_EXTENSION; import static org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy.VECTOR_BLOB_FILE_EXTENSION; +import static org.opensearch.knn.index.engine.faiss.Faiss.getMFromIndexDescription; +import static org.opensearch.knn.index.remote.RemoteIndexHTTPClient.BASIC_PREFIX; public class RemoteIndexHTTPClientTests extends OpenSearchSingleNodeTestCase { public static final String S3 = "s3"; public static final String TEST_BUCKET = "test-bucket"; - public static final String BLOB = "blob"; public static final String TEST_CLUSTER = "test-cluster"; public static final String MOCK_JOB_ID_RESPONSE = "{\"job_id\": \"job-1739930402\"}"; public static final String MOCK_JOB_ID = "job-1739930402"; public static final String MOCK_BLOB_NAME = "blob"; public static final String MOCK_ENDPOINT = "https://mock-build-service.com"; + public static final String USERNAME = "username"; + public static final String PASSWORD = "password"; @Mock protected ClusterService clusterService; @@ -92,120 +95,75 @@ public void testGetHttpClient_success() throws IOException { client.close(); } - public void testConstructBuildRequestJson() throws IOException { - Map algorithmParams = new HashMap<>(); - algorithmParams.put(METHOD_PARAMETER_EF_CONSTRUCTION, 100); - algorithmParams.put(METHOD_PARAMETER_M, 16); - - Map indexParameters = new HashMap<>(); - indexParameters.put("algorithm", METHOD_HNSW); - indexParameters.put(METHOD_PARAMETER_SPACE_TYPE, L2.getValue()); - indexParameters.put("algorithm_parameters", algorithmParams); - - RemoteBuildRequest request = RemoteBuildRequest.builder() - .repositoryType(S3) - .containerName(TEST_BUCKET) - .vectorPath(BLOB + VECTOR_BLOB_FILE_EXTENSION) - .docIdPath(BLOB + DOC_ID_FILE_EXTENSION) - .tenantId(TEST_CLUSTER) - .dimension(256) - .docCount(1_000_000) - .dataType(FLOAT.getValue()) - .engine(FAISS_NAME) - .indexParameters(indexParameters) - .build(); - - String expectedJson = "{" - + "\"repository_type\":\"s3\"," - + "\"container_name\":\"test-bucket\"," - + "\"vector_path\":\"blob.knnvec\"," - + "\"doc_id_path\":\"blob.knndid\"," - + "\"tenant_id\":\"test-cluster\"," - + "\"dimension\":256," - + "\"doc_count\":1000000," - + "\"data_type\":\"float\"," - + "\"engine\":\"faiss\"," - + "\"index_parameters\":{" - + "\"space_type\":\"l2\"," - + "\"algorithm\":\"hnsw\"," - + "\"algorithm_parameters\":{" - + "\"ef_construction\":100," - + "\"m\":16" - + "}" - + "}" - + "}"; - assertEquals(mapper.readTree(expectedJson), mapper.readTree(request.toJson())); - } - public void testGetValueFromResponse() throws JsonProcessingException { String jobID = "{\"job_id\": \"job-1739930402\"}"; - assertEquals("job-1739930402", RemoteIndexHTTPClient.getValueFromResponse(jobID, "job_id")); + assertEquals("job-1739930402", RemoteIndexHTTPClient.getValueFromResponse(jobID, JOB_ID)); String failedIndexBuild = "{" + "\"task_status\":\"FAILED_INDEX_BUILD\"," - + "\"error\":\"Index build process interrupted.\"," + + "\"error_message\":\"Index build process interrupted.\"," + "\"index_path\": null" + "}"; - String error = RemoteIndexHTTPClient.getValueFromResponse(failedIndexBuild, "error"); + String error = RemoteIndexHTTPClient.getValueFromResponse(failedIndexBuild, ERROR_MESSAGE); assertEquals("Index build process interrupted.", error); - assertNull(RemoteIndexHTTPClient.getValueFromResponse(failedIndexBuild, "index_path")); + assertNull(RemoteIndexHTTPClient.getValueFromResponse(failedIndexBuild, INDEX_PATH)); } - public void testBuildRequest() { - RepositoryMetadata metadata = mock(RepositoryMetadata.class); - Settings repoSettings = Settings.builder().put("bucket", TEST_BUCKET).build(); - when(metadata.type()).thenReturn(S3); - when(metadata.settings()).thenReturn(repoSettings); + public void testGetMFromIndexDescription() { + assertEquals(16, getMFromIndexDescription("HNSW16,Flat")); + assertEquals(8, getMFromIndexDescription("HNSW8,SQ")); + assertThrows(IllegalArgumentException.class, () -> getMFromIndexDescription("Invalid description")); + } + public void testBuildRequest() { + RepositoryMetadata metadata = createTestRepositoryMetadata(); KNNSettings knnSettingsMock = mock(KNNSettings.class); - IndexSettings mockIndexSettings = mock(IndexSettings.class); - Settings indexSettingsSettings = Settings.builder().put(ClusterName.CLUSTER_NAME_SETTING.getKey(), TEST_CLUSTER).build(); - when(mockIndexSettings.getSettings()).thenReturn(indexSettingsSettings); + IndexSettings mockIndexSettings = createTestIndexSettings(); + setupTestClusterSettings(); try (MockedStatic knnSettingsStaticMock = Mockito.mockStatic(KNNSettings.class)) { knnSettingsStaticMock.when(KNNSettings::state).thenReturn(knnSettingsMock); + when(knnSettingsMock.getSettingValue(KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING.getKey())).thenReturn(MOCK_ENDPOINT); + KNNSettings.state().setClusterService(clusterService); - List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }); - final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( - vectorValues - ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(FLOAT, randomVectorValues); - - Map algorithmParams = Map.of(METHOD_PARAMETER_EF_CONSTRUCTION, 94, METHOD_PARAMETER_M, 2); - - BuildIndexParams buildIndexParams = BuildIndexParams.builder() - .knnEngine(KNNEngine.FAISS) - .vectorDataType(FLOAT) - .parameters( - Map.of( - KNNConstants.SPACE_TYPE, - SpaceType.HAMMING.getValue(), - KNNConstants.NAME, - KNNConstants.METHOD_HNSW, - PARAMETERS, - algorithmParams - ) - ) - .knnVectorValuesSupplier(() -> knnVectorValues) - .totalLiveDocs(vectorValues.size()) - .build(); + BuildIndexParams buildIndexParams = createTestBuildIndexParams(); RemoteBuildRequest request = RemoteIndexHTTPClient.getInstance() - .constructBuildRequest(mockIndexSettings, buildIndexParams, metadata, "blob"); + .constructBuildRequest(mockIndexSettings, buildIndexParams, metadata, MOCK_BLOB_NAME); assertEquals(S3, request.getRepositoryType()); assertEquals(TEST_BUCKET, request.getContainerName()); - assertEquals(KNNConstants.FAISS_NAME, request.getEngine()); + assertEquals(FAISS_NAME, request.getEngine()); assertEquals(FLOAT.getValue(), request.getDataType()); - assertEquals(BLOB + VECTOR_BLOB_FILE_EXTENSION, request.getVectorPath()); - assertEquals(BLOB + DOC_ID_FILE_EXTENSION, request.getDocIdPath()); + assertEquals(MOCK_BLOB_NAME + VECTOR_BLOB_FILE_EXTENSION, request.getVectorPath()); + assertEquals(MOCK_BLOB_NAME + DOC_ID_FILE_EXTENSION, request.getDocIdPath()); assertEquals(TEST_CLUSTER, request.getTenantId()); - assertEquals(vectorValues.size(), request.getDocCount()); + assertEquals(2, request.getDocCount()); assertEquals(2, request.getDimension()); - assertEquals(request.getIndexParameters().get(METHOD_PARAMETER_SPACE_TYPE), SpaceType.HAMMING.getValue()); - Object algorithmParameters = request.getIndexParameters().get("algorithm_parameters"); - Map algoMap = (Map) algorithmParameters; - assertEquals(2, algoMap.get(METHOD_PARAMETER_M)); - assertEquals(94, algoMap.get(METHOD_PARAMETER_EF_CONSTRUCTION)); + + HTTPRemoteBuildRequest httpRequest = (HTTPRemoteBuildRequest) request; + assertEquals(URI.create(MOCK_ENDPOINT), httpRequest.getEndpoint()); + + String expectedJson = "{" + + "\"repository_type\":\"s3\"," + + "\"container_name\":\"test-bucket\"," + + "\"vector_path\":\"blob.knnvec\"," + + "\"doc_id_path\":\"blob.knndid\"," + + "\"tenant_id\":\"test-cluster\"," + + "\"dimension\":2," + + "\"doc_count\":2," + + "\"data_type\":\"float\"," + + "\"engine\":\"faiss\"," + + "\"index_parameters\":{" + + "\"space_type\":\"l2\"," + + "\"algorithm\":\"hnsw\"," + + "\"algorithm_parameters\":{" + + "\"ef_construction\":94," + + "\"ef_search\":89," + + "\"m\":14" + + "}" + + "}" + + "}"; + assertEquals(mapper.readTree(expectedJson), mapper.readTree(httpRequest.toJson())); } catch (IOException e) { throw new RuntimeException(e); } @@ -219,47 +177,143 @@ public void testSubmitVectorBuild() throws IOException, URISyntaxException { response -> MOCK_JOB_ID_RESPONSE ); - RepositoriesService repositoriesService = mock(RepositoriesService.class); - BlobStoreRepository blobStoreRepository = mock(BlobStoreRepository.class); - RepositoryMetadata metadata = mock(RepositoryMetadata.class); - Settings repoSettings = Settings.builder().put("bucket", TEST_BUCKET).build(); + RepositoryMetadata metadata = createTestRepositoryMetadata(); + KNNSettings knnSettingsMock = mock(KNNSettings.class); + IndexSettings mockIndexSettings = createTestIndexSettings(); + setupTestClusterSettings(); - when(metadata.type()).thenReturn(S3); - when(metadata.settings()).thenReturn(repoSettings); - when(blobStoreRepository.getMetadata()).thenReturn(metadata); - when(repositoriesService.repository("test-repo")).thenReturn(blobStoreRepository); + try (MockedStatic knnSettingsStaticMock = Mockito.mockStatic(KNNSettings.class)) { + knnSettingsStaticMock.when(KNNSettings::state).thenReturn(knnSettingsMock); + when(knnSettingsMock.getSettingValue(KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING.getKey())).thenReturn(MOCK_ENDPOINT); + KNNSettings.state().setClusterService(clusterService); - IndexSettings mockIndexSettings = mock(IndexSettings.class); - Settings indexSettingsSettings = Settings.builder().put(ClusterName.CLUSTER_NAME_SETTING.getKey(), TEST_CLUSTER).build(); - when(mockIndexSettings.getSettings()).thenReturn(indexSettingsSettings); + BuildIndexParams buildIndexParams = createTestBuildIndexParams(); + + RemoteBuildResponse remoteBuildResponse = client.submitVectorBuild( + new HTTPRemoteBuildRequest(mockIndexSettings, buildIndexParams, metadata, MOCK_BLOB_NAME) + ); + assertEquals(MOCK_JOB_ID, remoteBuildResponse.getJobId()); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpPost.class); + verify(mockHttpClient).execute(requestCaptor.capture(), any(HttpClientResponseHandler.class)); + HttpPost capturedRequest = requestCaptor.getValue(); + assertEquals(MOCK_ENDPOINT + BUILD_ENDPOINT, capturedRequest.getUri().toString()); + assert (!capturedRequest.containsHeader(HttpHeaders.AUTHORIZATION)); + } + } + + public void testSecureSettingsReloadAndException() throws IOException { + final MockSecureSettings secureSettings = new MockSecureSettings(); + secureSettings.setString(KNN_REMOTE_BUILD_CLIENT_USERNAME_SETTING.getKey(), USERNAME); + secureSettings.setString(KNN_REMOTE_BUILD_CLIENT_PASSWORD_SETTING.getKey(), PASSWORD); + final Settings settings = Settings.builder().setSecureSettings(secureSettings).build(); + CloseableHttpClient mockHttpClient = mock(CloseableHttpClient.class); + RemoteIndexHTTPClient client = new RemoteIndexHTTPClient(mockHttpClient); + client.reloadAuthHeader(settings); + + when(mockHttpClient.execute(any(HttpPost.class), any(HttpClientResponseHandler.class))).thenAnswer( + response -> MOCK_JOB_ID_RESPONSE + ); + + RepositoryMetadata metadata = createTestRepositoryMetadata(); + KNNSettings knnSettingsMock = mock(KNNSettings.class); + IndexSettings mockIndexSettings = createTestIndexSettings(); + setupTestClusterSettings(); + + try (MockedStatic knnSettingsStaticMock = Mockito.mockStatic(KNNSettings.class)) { + knnSettingsStaticMock.when(KNNSettings::state).thenReturn(knnSettingsMock); + when(knnSettingsMock.getSettingValue(KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING.getKey())).thenReturn(MOCK_ENDPOINT); + KNNSettings.state().setClusterService(clusterService); + + BuildIndexParams buildIndexParams = createTestBuildIndexParams(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpPost.class); + client.submitVectorBuild(new HTTPRemoteBuildRequest(mockIndexSettings, buildIndexParams, metadata, MOCK_BLOB_NAME)); + + final MockSecureSettings emptySettings = new MockSecureSettings(); + final Settings nullPasswordSettings = Settings.builder().setSecureSettings(emptySettings).build(); + + client.reloadAuthHeader(nullPasswordSettings); + client.submitVectorBuild(new HTTPRemoteBuildRequest(mockIndexSettings, buildIndexParams, metadata, MOCK_BLOB_NAME)); + + verify(mockHttpClient, times(2)).execute(requestCaptor.capture(), any(HttpClientResponseHandler.class)); + List capturedRequests = requestCaptor.getAllValues(); + + HttpPost firstRequest = capturedRequests.getFirst(); + assert (firstRequest.containsHeader(HttpHeaders.AUTHORIZATION)); + assertTrue(firstRequest.getHeader(HttpHeaders.AUTHORIZATION).getValue().startsWith(BASIC_PREFIX)); + String authHeader = firstRequest.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(); + byte[] decodedBytes = Base64.decodeBase64(authHeader.substring(6).getBytes(StandardCharsets.ISO_8859_1)); + String decodedCredentials = new String(decodedBytes, StandardCharsets.ISO_8859_1); + assertEquals("username:password", decodedCredentials); + + HttpPost secondRequest = capturedRequests.get(1); + assertFalse(secondRequest.containsHeader(HttpHeaders.AUTHORIZATION)); + + final MockSecureSettings passwordOnlySettings = new MockSecureSettings(); + passwordOnlySettings.setString(KNN_REMOTE_BUILD_CLIENT_PASSWORD_SETTING.getKey(), PASSWORD); + final Settings exceptionSettings = Settings.builder().setSecureSettings(passwordOnlySettings).build(); + + assertThrows(IllegalArgumentException.class, () -> client.reloadAuthHeader(exceptionSettings)); + } catch (ProtocolException e) { + throw new RuntimeException(e); + } + } + + // Utility methods to populate settings for build requests + + private BuildIndexParams createTestBuildIndexParams() { List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }); final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( vectorValues ); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(FLOAT, randomVectorValues); - BuildIndexParams buildIndexParams = BuildIndexParams.builder() + Map encoderParams = new HashMap<>(); + encoderParams.put(NAME, ENCODER_FLAT); + encoderParams.put(PARAMETERS, Map.of()); + + Map algorithmParams = new HashMap<>(); + algorithmParams.put(METHOD_PARAMETER_EF_SEARCH, 89); + algorithmParams.put(METHOD_PARAMETER_EF_CONSTRUCTION, 94); + algorithmParams.put(ENCODER_FLAT, encoderParams); + + Map parameters = new HashMap<>(); + parameters.put(NAME, METHOD_HNSW); + parameters.put(VECTOR_DATA_TYPE_FIELD, FLOAT.getValue()); + parameters.put(INDEX_DESCRIPTION_PARAMETER, "HNSW14,Flat"); + parameters.put(SPACE_TYPE, L2.getValue()); + parameters.put(PARAMETERS, algorithmParams); + + return BuildIndexParams.builder() .knnEngine(KNNEngine.FAISS) .vectorDataType(FLOAT) - .parameters(Map.of(KNNConstants.SPACE_TYPE, L2.getValue(), KNNConstants.NAME, KNNConstants.METHOD_HNSW)) + .parameters(parameters) .knnVectorValuesSupplier(() -> knnVectorValues) .totalLiveDocs(vectorValues.size()) .build(); + } + + private RepositoryMetadata createTestRepositoryMetadata() { + RepositoryMetadata metadata = mock(RepositoryMetadata.class); + Settings repoSettings = Settings.builder().put(BUCKET, TEST_BUCKET).build(); + when(metadata.type()).thenReturn(S3); + when(metadata.settings()).thenReturn(repoSettings); + return metadata; + } + private IndexSettings createTestIndexSettings() { + IndexSettings mockIndexSettings = mock(IndexSettings.class); + Settings indexSettingsSettings = Settings.builder().put(ClusterName.CLUSTER_NAME_SETTING.getKey(), TEST_CLUSTER).build(); + when(mockIndexSettings.getSettings()).thenReturn(indexSettingsSettings); + return mockIndexSettings; + } + + private void setupTestClusterSettings() { ClusterSettings clusterSettings = mock(ClusterSettings.class); when(clusterSettings.get(KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING)).thenReturn(MOCK_ENDPOINT); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); KNNSettings.state().setClusterService(clusterService); - - String jobId = client.submitVectorBuild(mockIndexSettings, buildIndexParams, metadata, MOCK_BLOB_NAME); - // Isolated job_id from expectedResponse - assertEquals(MOCK_JOB_ID, jobId); - - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpPost.class); - Mockito.verify(mockHttpClient).execute(requestCaptor.capture(), any(HttpClientResponseHandler.class)); - HttpPost capturedRequest = requestCaptor.getValue(); - assertEquals(MOCK_ENDPOINT + RemoteIndexHTTPClient.BUILD_ENDPOINT, capturedRequest.getUri().toString()); - assert (!capturedRequest.containsHeader(HttpHeaders.AUTHORIZATION)); } }