Skip to content

Commit

Permalink
Add secure settings, access control logic for http request
Browse files Browse the repository at this point in the history
Signed-off-by: owenhalpert <ohalpert@gmail.com>
  • Loading branch information
owenhalpert committed Feb 27, 2025
1 parent 0dcb7d9 commit a96aea8
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 84 deletions.
30 changes: 28 additions & 2 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Booleans;
import org.opensearch.common.settings.SecureSetting;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.settings.SecureString;
import org.opensearch.core.common.unit.ByteSizeUnit;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.index.IndexModule;
Expand Down Expand Up @@ -103,6 +105,8 @@ public class KNNSettings {
public static final String KNN_REMOTE_BUILD_SERVICE_ENDPOINT = "knn.remote_index_build.client.endpoint";
public static final String KNN_REMOTE_BUILD_CLIENT_POLL_INTERVAL = "knn.remote_index_build.client.poll_interval";
public static final String KNN_REMOTE_BUILD_CLIENT_TIMEOUT = "knn.remote_index_build.client.timeout";
public static final String KNN_REMOTE_BUILD_CLIENT_USERNAME = "knn.remote_index_build.client.username";
public static final String KNN_REMOTE_BUILD_CLIENT_PASSWORD = "knn.remote_index_build.client.password";

/**
* Default setting values
Expand Down Expand Up @@ -417,7 +421,7 @@ public class KNNSettings {
IndexScope
);
/**
* Remote build service endpoint to be used for remote index build. //TODO we can add String validators on these endpoint settings
* Remote build service endpoint to be used for remote index build.
*/
public static final Setting<String> KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING = Setting.simpleString(
KNN_REMOTE_BUILD_SERVICE_ENDPOINT,
Expand Down Expand Up @@ -445,6 +449,18 @@ public class KNNSettings {
Dynamic
);

/**
* Keystore settings for build service HTTP authorization
*/
public static final Setting<SecureString> KNN_REMOTE_BUILD_CLIENT_USERNAME_SETTING = SecureSetting.secureString(
KNN_REMOTE_BUILD_CLIENT_USERNAME,
null
);
public static final Setting<SecureString> KNN_REMOTE_BUILD_CLIENT_PASSWORD_SETTING = SecureSetting.secureString(
KNN_REMOTE_BUILD_CLIENT_PASSWORD,
null
);

/**
* Dynamic settings
*/
Expand Down Expand Up @@ -648,6 +664,14 @@ private Setting<?> getSetting(String key) {
return KNN_REMOTE_BUILD_CLIENT_POLL_INTERVAL_SETTING;
}

if (KNN_REMOTE_BUILD_CLIENT_USERNAME.equals(key)) {
return KNN_REMOTE_BUILD_CLIENT_USERNAME_SETTING;
}

if (KNN_REMOTE_BUILD_CLIENT_PASSWORD.equals(key)) {
return KNN_REMOTE_BUILD_CLIENT_PASSWORD_SETTING;
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
}

Expand Down Expand Up @@ -679,7 +703,9 @@ public List<Setting<?>> getSettings() {
KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING,
KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING,
KNN_REMOTE_BUILD_CLIENT_TIMEOUT_SETTING,
KNN_REMOTE_BUILD_CLIENT_POLL_INTERVAL_SETTING
KNN_REMOTE_BUILD_CLIENT_POLL_INTERVAL_SETTING,
KNN_REMOTE_BUILD_CLIENT_USERNAME_SETTING,
KNN_REMOTE_BUILD_CLIENT_PASSWORD_SETTING
);
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import java.io.IOException;

/**
* Interface for a Remote Index Client. This will support future implementations for protocols such as gRPC.
* Interface which dictates how we interact with a remote index build service.
*/
public interface RemoteIndexClient {
interface RemoteIndexClient {
/**
* Submit an index build request to the build service endpoint.
* @param indexSettings IndexSettings for the index being built
Expand Down
180 changes: 112 additions & 68 deletions src/main/java/org/opensearch/knn/index/remote/RemoteIndexHTTPClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.NotImplementedException;
import org.apache.hc.client5.http.classic.methods.HttpPost;
import org.apache.hc.client5.http.classic.methods.HttpUriRequestBase;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.client5.http.utils.Base64;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.io.entity.StringEntity;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.metadata.RepositoryMetadata;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.settings.SecureString;
import org.opensearch.index.IndexSettings;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
Expand All @@ -28,10 +31,15 @@

import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
Expand All @@ -42,10 +50,14 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_DEFAULT;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION;
import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH;
import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M;
import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE;
import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_CLIENT_PASSWORD_SETTING;
import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_CLIENT_USERNAME_SETTING;
import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_BUILD_SERVICE_ENDPOINT_SETTING;

/**
Expand All @@ -59,20 +71,19 @@ public class RemoteIndexHTTPClient implements RemoteIndexClient {
protected static final int MAX_RETRIES = 1; // 2 total attempts
protected static final long BASE_DELAY_MS = 1000;
protected static final String BUILD_ENDPOINT = "/_build";
protected static final String FP32 = "fp32";
protected static final String FP16 = "fp16";

private static final ObjectMapper objectMapper = new ObjectMapper();

RemoteIndexHTTPClient() {
this.httpClient = createHttpClient();
}
private String authHeader = null;

/**
* Return the Singleton instance of the node's RemoteIndexClient
* @return RemoteIndexClient instance
*/
public static synchronized RemoteIndexHTTPClient getInstance() {
if (INSTANCE == null) {
INSTANCE = new RemoteIndexHTTPClient();
INSTANCE = new RemoteIndexHTTPClient(createHttpClient());
}
return INSTANCE;
}
Expand All @@ -81,10 +92,14 @@ public static synchronized RemoteIndexHTTPClient getInstance() {
* Initialize the httpClient to be used
* @return The HTTP Client
*/
private CloseableHttpClient createHttpClient() {
private static CloseableHttpClient createHttpClient() {
return HttpClients.custom().setRetryStrategy(new RemoteIndexClientRetryStrategy()).build();
}

RemoteIndexHTTPClient(CloseableHttpClient httpClient) {
this.httpClient = httpClient;
}

/**
* Submit a build to the Remote Vector Build Service endpoint.
* @return job_id from the server response used to track the job
Expand All @@ -101,19 +116,28 @@ public String submitVectorBuild(
HttpPost buildRequest = new HttpPost(endpoint + BUILD_ENDPOINT);
buildRequest.setHeader("Content-Type", "application/json");
buildRequest.setEntity(new StringEntity(request.toJson()));
if (authHeader != null) {
buildRequest.setHeader(HttpHeaders.AUTHORIZATION, authHeader);
}

String response = httpClient.execute(buildRequest, body -> {
if (body.getCode() != 200) {
throw new IOException("Failed to submit build request, got status code: " + body.getCode());
try {
String response = AccessController.doPrivileged(
(PrivilegedExceptionAction<String>) () -> httpClient.execute(buildRequest, body -> {
if (body.getCode() != 200) {
throw new IOException("Failed to submit build request, got status code: " + body.getCode());
}
return EntityUtils.toString(body.getEntity());
})
);

if (response == null) {
throw new IOException("Received 200 status code but response is null.");
}
return EntityUtils.toString(body.getEntity());
});

if (response == null) {
throw new IOException("Received 200 status code but response is null.");
return getValueFromResponse(response, "job_id");
} catch (PrivilegedActionException e) {
throw new IOException("Failed to execute HTTP request", e.getException());
}

return getValueFromResponse(response, "job_id");
}

/**
Expand All @@ -126,15 +150,6 @@ public String awaitVectorBuild(String jobId) {
throw new NotImplementedException();
}

/**
* Helper method to directly get the status response for a given job ID
* @param jobId to check
* @return The entire response for the status request
*/
private String getBuildStatus(String jobId) throws IOException {
throw new NotImplementedException();
}

/**
* Given a JSON response string, get a value for a specific key. Converts json {@literal <null>} to Java null.
* @param responseBody The response to read
Expand All @@ -153,15 +168,6 @@ static String getValueFromResponse(String responseBody, String key) throws JsonP
throw new IllegalArgumentException("Key " + key + " not found in response");
}

/**
* Authenticate the HTTP request by manually setting the auth header iff the credentials are configured.
* This is favored over setting a global auth scheme to allow for dynamic credential updates.
* @param request to be authenticated
*/
private void maybeAddAuthHeader(HttpUriRequestBase request) {
throw new NotImplementedException();
}

/**
* Construct the RemoteBuildRequest object for the index build request
* @param indexSettings Index settings
Expand All @@ -187,10 +193,13 @@ RemoteBuildRequest constructBuildRequest(
VectorDataType vectorDataType = indexInfo.getVectorDataType();
String exactDataType;
switch (vectorDataType) {
case FLOAT -> exactDataType = resolveFloatDataType();
case FLOAT -> exactDataType = resolveFloatDataType(indexInfo);
default -> exactDataType = vectorDataType.getValue();
}
Map<String, Object> indexParameters = constructIndexParams(indexInfo);
Map<String, Object> indexParameters = null;
if (indexInfo.getParameters() != null) {
indexParameters = constructIndexParams(indexInfo);
}
KNNVectorValues<?> vectorValues = indexInfo.getKnnVectorValuesSupplier().get();
KNNCodecUtil.initializeVectorValues(vectorValues);
assert (vectorValues.dimension() > 0);
Expand All @@ -216,49 +225,64 @@ RemoteBuildRequest constructBuildRequest(
*/
private Map<String, Object> constructIndexParams(BuildIndexParams indexInfo) {
Map<String, Object> indexParameters = new HashMap<>();
indexParameters.put("algorithm", indexInfo.getParameters().get("name"));
String methodName = (String) indexInfo.getParameters().get(NAME);
indexParameters.put("algorithm", methodName);
indexParameters.put(
METHOD_PARAMETER_SPACE_TYPE,
indexInfo.getParameters().getOrDefault(METHOD_PARAMETER_SPACE_TYPE, INDEX_KNN_DEFAULT_SPACE_TYPE)
);

String methodName = (String) indexInfo.getParameters().get("name");
Map<String, Object> algorithmParams = new HashMap<>(); // TODO add other method/engine combos and their params
switch (methodName) {
case METHOD_HNSW -> {
algorithmParams.put(
METHOD_PARAMETER_EF_CONSTRUCTION,
indexInfo.getParameters().getOrDefault(METHOD_PARAMETER_EF_CONSTRUCTION, INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION)
);
algorithmParams.put(
METHOD_PARAMETER_M,
indexInfo.getParameters().getOrDefault(METHOD_PARAMETER_M, INDEX_KNN_DEFAULT_ALGO_PARAM_M)
);
if (indexInfo.getKnnEngine().getName().equals(FAISS_NAME)) {
algorithmParams.put(
METHOD_PARAMETER_EF_SEARCH,
indexInfo.getParameters().getOrDefault(METHOD_PARAMETER_EF_SEARCH, INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH)
);
//TODO solidify defaults being sent
if (indexInfo.getParameters().containsKey(PARAMETERS)) {
Object innerParams = indexInfo.getParameters().get(PARAMETERS);
if (innerParams instanceof Map) {
Map<String, Object> algorithmParams = new HashMap<>();
Map<String, Object> innerMap = (Map<String, Object>) innerParams;
switch (methodName) {
case METHOD_HNSW -> {
algorithmParams.put(
METHOD_PARAMETER_EF_CONSTRUCTION,
innerMap.getOrDefault(METHOD_PARAMETER_EF_CONSTRUCTION, INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION)
);
algorithmParams.put(METHOD_PARAMETER_M, innerMap.getOrDefault(METHOD_PARAMETER_M, INDEX_KNN_DEFAULT_ALGO_PARAM_M));
if (indexInfo.getKnnEngine().getName().equals(FAISS_NAME)) {
algorithmParams.put(
METHOD_PARAMETER_EF_SEARCH,
innerMap.getOrDefault(METHOD_PARAMETER_EF_SEARCH, INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH)
);
}
}
case METHOD_IVF -> {
algorithmParams.put(
METHOD_PARAMETER_NLIST,
innerMap.getOrDefault(METHOD_PARAMETER_NLIST, METHOD_PARAMETER_NLIST_DEFAULT)
);
algorithmParams.put(
METHOD_PARAMETER_NPROBES,
innerMap.getOrDefault(METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NPROBES_DEFAULT)
);
}
}
}
case METHOD_IVF -> {
algorithmParams.put(
METHOD_PARAMETER_NLIST,
indexInfo.getParameters().getOrDefault(METHOD_PARAMETER_NLIST, METHOD_PARAMETER_NLIST_DEFAULT)
);
algorithmParams.put(
METHOD_PARAMETER_NPROBES,
indexInfo.getParameters().getOrDefault(METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NPROBES_DEFAULT)
);
indexParameters.put("algorithm_parameters", algorithmParams);
}
}
indexParameters.put("algorithm_parameters", algorithmParams);

return indexParameters;
}

private String resolveFloatDataType() {
return "fp32"; // TODO fetch and use encoder to determine fp16 vs fp32
/**
* Use the index description in the index mappings to determine whether the float type is specifically fp16 or 32.
* @param indexInfo Index parameters
* @return fp16 or fp32 concrete type
*/
private String resolveFloatDataType(BuildIndexParams indexInfo) {
String dataType = FP32;
if (indexInfo.getParameters().containsKey(INDEX_DESCRIPTION_PARAMETER)) {
String indexDescription = (String) indexInfo.getParameters().get(INDEX_DESCRIPTION_PARAMETER);
if (indexDescription.contains(FP16)) {
dataType = FP16;
}
}
return dataType;
}

/**
Expand All @@ -269,4 +293,24 @@ public void close() throws IOException {
httpClient.close();
}
}

/**
* Rebuild the httpClient with the new credentials
* @param settings Settings to use to get the new credentials
*/
public void reloadSecureSettings(Settings settings) {
SecureString username = KNN_REMOTE_BUILD_CLIENT_USERNAME_SETTING.get(settings);
SecureString password = KNN_REMOTE_BUILD_CLIENT_PASSWORD_SETTING.get(settings);

if (password != null && !password.isEmpty()) {
if (username == null || username.isEmpty()) {
throw new IllegalArgumentException("Username must be set if password is set");
}
final String auth = username + ":" + password.clone();
final byte[] encodedAuth = Base64.encodeBase64(auth.getBytes(StandardCharsets.ISO_8859_1));
this.authHeader = "Basic " + new String(encodedAuth);
} else {
this.authHeader = null;
}
}
}
Loading

0 comments on commit a96aea8

Please sign in to comment.