From cec4609543131c0e66c551deaca6da276882dc98 Mon Sep 17 00:00:00 2001 From: Rene Groeschke Date: Thu, 10 Oct 2024 19:01:09 +0200 Subject: [PATCH 01/25] [CI] Fix PublishPluginFuncTest (#114511) (#114517) (cherry picked from commit 388d24f2f53e8ba0217f7530c4c9c22f3bf3713a) --- .../internal/PublishPluginFuncTest.groovy | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/PublishPluginFuncTest.groovy b/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/PublishPluginFuncTest.groovy index c0b85ed7450f6..99d451116dbe7 100644 --- a/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/PublishPluginFuncTest.groovy +++ b/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/PublishPluginFuncTest.groovy @@ -70,13 +70,13 @@ class PublishPluginFuncTest extends AbstractGradleFuncTest { repo - Server Side Public License, v 1 - https://www.mongodb.com/licensing/server-side-public-license + GNU Affero General Public License Version 3 + https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt repo - The OSI-approved Open Source license Version 3.0 - https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt + Server Side Public License, v 1 + https://www.mongodb.com/licensing/server-side-public-license repo @@ -150,13 +150,13 @@ class PublishPluginFuncTest extends AbstractGradleFuncTest { repo - Server Side Public License, v 1 - https://www.mongodb.com/licensing/server-side-public-license + GNU Affero General Public License Version 3 + https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt repo - The OSI-approved Open Source license Version 3.0 - 
https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt + Server Side Public License, v 1 + https://www.mongodb.com/licensing/server-side-public-license repo @@ -239,13 +239,13 @@ class PublishPluginFuncTest extends AbstractGradleFuncTest { repo - Server Side Public License, v 1 - https://www.mongodb.com/licensing/server-side-public-license + GNU Affero General Public License Version 3 + https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt repo - The OSI-approved Open Source license Version 3.0 - https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt + Server Side Public License, v 1 + https://www.mongodb.com/licensing/server-side-public-license repo @@ -337,13 +337,13 @@ class PublishPluginFuncTest extends AbstractGradleFuncTest { repo - Server Side Public License, v 1 - https://www.mongodb.com/licensing/server-side-public-license + GNU Affero General Public License Version 3 + https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt repo - The OSI-approved Open Source license Version 3.0 - https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt + Server Side Public License, v 1 + https://www.mongodb.com/licensing/server-side-public-license repo @@ -415,13 +415,13 @@ class PublishPluginFuncTest extends AbstractGradleFuncTest { repo - Server Side Public License, v 1 - https://www.mongodb.com/licensing/server-side-public-license + GNU Affero General Public License Version 3 + https://raw.githubusercontent.com/elastic/elasticsearch/v2.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt repo - The OSI-approved Open Source license Version 3.0 - https://raw.githubusercontent.com/elastic/elasticsearch/v2.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt + Server Side Public License, v 1 + 
https://www.mongodb.com/licensing/server-side-public-license repo From 515e5d9f1e6c85b30cf2e72bf76f3bbcc1db9b6c Mon Sep 17 00:00:00 2001 From: Patrick Doyle <810052+prdoyle@users.noreply.github.com> Date: Thu, 10 Oct 2024 13:20:28 -0400 Subject: [PATCH 02/25] Fix max file size check to use getMaxFileSize (#113723) (#114508) * Fix max file size check to use getMaxFileSize * Update docs/changelog/113723.yaml * CURSE YOU SPOTLESS --- docs/changelog/113723.yaml | 6 ++++++ .../java/org/elasticsearch/bootstrap/BootstrapChecks.java | 8 ++++---- .../org/elasticsearch/bootstrap/BootstrapChecksTests.java | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) create mode 100644 docs/changelog/113723.yaml diff --git a/docs/changelog/113723.yaml b/docs/changelog/113723.yaml new file mode 100644 index 0000000000000..2cbcf49102719 --- /dev/null +++ b/docs/changelog/113723.yaml @@ -0,0 +1,6 @@ +pr: 113723 +summary: Fix max file size check to use `getMaxFileSize` +area: Infra/Core +type: bug +issues: + - 113705 diff --git a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java index 566c8001dea56..021ad8127a2d0 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java @@ -412,12 +412,12 @@ static class MaxFileSizeCheck implements BootstrapCheck { @Override public BootstrapCheckResult check(BootstrapContext context) { - final long maxFileSize = getMaxFileSize(); + final long maxFileSize = getProcessLimits().maxFileSize(); if (maxFileSize != Long.MIN_VALUE && maxFileSize != ProcessLimits.UNLIMITED) { final String message = String.format( Locale.ROOT, "max file size [%d] for user [%s] is too low, increase to [unlimited]", - getMaxFileSize(), + maxFileSize, BootstrapInfo.getSystemProperties().get("user.name") ); return BootstrapCheckResult.failure(message); @@ -426,8 +426,8 @@ public 
BootstrapCheckResult check(BootstrapContext context) { } } - long getMaxFileSize() { - return NativeAccess.instance().getProcessLimits().maxVirtualMemorySize(); + protected ProcessLimits getProcessLimits() { + return NativeAccess.instance().getProcessLimits(); } @Override diff --git a/server/src/test/java/org/elasticsearch/bootstrap/BootstrapChecksTests.java b/server/src/test/java/org/elasticsearch/bootstrap/BootstrapChecksTests.java index 9a51757189f8b..8c3749dbd3a45 100644 --- a/server/src/test/java/org/elasticsearch/bootstrap/BootstrapChecksTests.java +++ b/server/src/test/java/org/elasticsearch/bootstrap/BootstrapChecksTests.java @@ -389,8 +389,8 @@ public void testMaxFileSizeCheck() throws NodeValidationException { final AtomicLong maxFileSize = new AtomicLong(randomIntBetween(0, Integer.MAX_VALUE)); final BootstrapChecks.MaxFileSizeCheck check = new BootstrapChecks.MaxFileSizeCheck() { @Override - long getMaxFileSize() { - return maxFileSize.get(); + protected ProcessLimits getProcessLimits() { + return new ProcessLimits(ProcessLimits.UNKNOWN, ProcessLimits.UNKNOWN, maxFileSize.get()); } }; From bb6470c6256f6535a7fd724f260be76397231d78 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Thu, 10 Oct 2024 19:56:12 +0200 Subject: [PATCH 03/25] Improve performance of LongObjectPagedHashMap#removeAndAdd and ObjectObjectPagedHashMap#removeAndAdd (#114280) (#114518) --- .../common/util/LongObjectPagedHashMap.java | 24 +++++++++++++------ .../common/util/ObjectObjectPagedHashMap.java | 21 ++++++++++++---- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/util/LongObjectPagedHashMap.java b/server/src/main/java/org/elasticsearch/common/util/LongObjectPagedHashMap.java index d955863caa091..86e1bc2d27a09 100644 --- a/server/src/main/java/org/elasticsearch/common/util/LongObjectPagedHashMap.java +++ b/server/src/main/java/org/elasticsearch/common/util/LongObjectPagedHashMap.java @@ -63,6 +63,7 @@ public T get(long 
key) { * an insertion. */ public T put(long key, T value) { + assert value != null : "Null values are not supported"; if (size >= maxSize) { assert size == maxSize; grow(); @@ -94,9 +95,6 @@ public T remove(long key) { } private T set(long key, T value) { - if (value == null) { - throw new IllegalArgumentException("Null values are not supported"); - } for (long i = slot(hash(key), mask);; i = nextSlot(i, mask)) { final T previous = values.getAndSet(i, value); if (previous == null) { @@ -116,7 +114,7 @@ private T set(long key, T value) { @Override public Iterator> iterator() { - return new Iterator>() { + return new Iterator<>() { boolean cached; final Cursor cursor; @@ -181,9 +179,21 @@ protected boolean used(long bucket) { protected void removeAndAdd(long index) { final long key = keys.get(index); final T value = values.getAndSet(index, null); - --size; - final T removed = set(key, value); - assert removed == null; + reset(key, value); + } + + private void reset(long key, T value) { + final ObjectArray values = this.values; + final long mask = this.mask; + for (long i = slot(hash(key), mask);; i = nextSlot(i, mask)) { + final T previous = values.get(i); + if (previous == null) { + // slot was free + keys.set(i, key); + values.set(i, value); + break; + } + } } public static final class Cursor { diff --git a/server/src/main/java/org/elasticsearch/common/util/ObjectObjectPagedHashMap.java b/server/src/main/java/org/elasticsearch/common/util/ObjectObjectPagedHashMap.java index 298f910d65a9f..a743c535a57d9 100644 --- a/server/src/main/java/org/elasticsearch/common/util/ObjectObjectPagedHashMap.java +++ b/server/src/main/java/org/elasticsearch/common/util/ObjectObjectPagedHashMap.java @@ -67,6 +67,7 @@ public V get(K key) { * an insertion. 
*/ public V put(K key, V value) { + assert value != null : "Null values are not supported"; if (size >= maxSize) { assert size == maxSize; grow(); @@ -100,7 +101,6 @@ public V remove(K key) { private V set(K key, int code, V value) { assert key.hashCode() == code; - assert value != null; assert size < maxSize; final long slot = slot(code, mask); for (long index = slot;; index = nextSlot(index, mask)) { @@ -187,9 +187,22 @@ protected boolean used(long bucket) { protected void removeAndAdd(long index) { final K key = keys.get(index); final V value = values.getAndSet(index, null); - --size; - final V removed = set(key, key.hashCode(), value); - assert removed == null; + reset(key, value); + } + + private void reset(K key, V value) { + final ObjectArray values = this.values; + final long mask = this.mask; + final long slot = slot(key.hashCode(), mask); + for (long index = slot;; index = nextSlot(index, mask)) { + final V previous = values.get(index); + if (previous == null) { + // slot was free + values.set(index, value); + keys.set(index, key); + break; + } + } } public static final class Cursor { From 07d57e24eb44a68f34bc62cc3210a1af934025aa Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Thu, 10 Oct 2024 14:25:31 -0400 Subject: [PATCH 04/25] [8.x] Adding chunking settings to GoogleVertexAiService, AzureAiStudioService, and AlibabaCloudSearchService (#113981) (#114449) * Adding chunking settings to GoogleVertexAiService, AzureAiStudioService, and AlibabaCloudSearchService (#113981) * Adding chunking settings to GoogleVertexAiService, AzureAiStudioService, and AlibabaCloudSearchService * Update docs/changelog/113981.yaml * Updating AlibabaService chunkedInfer to handle sparse embedding task types --------- Co-authored-by: Elastic Machine * Fix enum switch case error in AlibabaSearchService (#114504) Co-authored-by: Elastic Machine --------- Co-authored-by: Elastic Machine --- docs/changelog/113981.yaml | 6 + .../AlibabaCloudSearchService.java | 57 ++- 
.../AlibabaCloudSearchEmbeddingsModel.java | 6 +- ...aCloudSearchEmbeddingsServiceSettings.java | 14 +- .../sparse/AlibabaCloudSearchSparseModel.java | 6 +- .../azureaistudio/AzureAiStudioService.java | 49 +- .../AzureAiStudioEmbeddingsModel.java | 9 +- .../googlevertexai/GoogleVertexAiService.java | 48 +- .../GoogleVertexAiEmbeddingsModel.java | 6 +- .../AlibabaCloudSearchServiceTests.java | 441 +++++++++++++++--- ...libabaCloudSearchEmbeddingsModelTests.java | 2 + .../AlibabaCloudSearchSparseModelTests.java | 2 + .../AzureAiStudioServiceTests.java | 329 ++++++++++++- .../AzureAiStudioEmbeddingsModelTests.java | 47 ++ .../GoogleVertexAiServiceTests.java | 413 ++++++++++++++++ .../GoogleVertexAiEmbeddingsModelTests.java | 2 +- 16 files changed, 1350 insertions(+), 87 deletions(-) create mode 100644 docs/changelog/113981.yaml diff --git a/docs/changelog/113981.yaml b/docs/changelog/113981.yaml new file mode 100644 index 0000000000000..38f3a6f04ae46 --- /dev/null +++ b/docs/changelog/113981.yaml @@ -0,0 +1,6 @@ +pr: 113981 +summary: "Adding chunking settings to `GoogleVertexAiService,` `AzureAiStudioService,`\ + \ and `AlibabaCloudSearchService`" +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 0bd0eee1aa9a1..c5c88ad978d63 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -15,6 +15,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import 
org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -24,6 +25,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -74,11 +77,19 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + AlibabaCloudSearchModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -99,6 +110,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage ) { @@ -107,6 +119,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations( taskType, serviceSettings, taskSettings, + chunkingSettings, 
secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT @@ -118,6 +131,7 @@ private static AlibabaCloudSearchModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -129,6 +143,7 @@ private static AlibabaCloudSearchModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); @@ -138,6 +153,7 @@ private static AlibabaCloudSearchModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); @@ -174,11 +190,17 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelWithoutLoggingDeprecations( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -189,11 +211,17 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, 
ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelWithoutLoggingDeprecations( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -238,17 +266,36 @@ protected void doChunkedInfer( AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model; var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType()), + alibabaCloudSearchModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType()) + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); } } + private EmbeddingRequestChunker.EmbeddingType getEmbeddingTypeFromTaskType(TaskType taskType) { + return switch (taskType) { + case TEXT_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.FLOAT; + case SPARSE_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.SPARSE; + default -> throw new IllegalArgumentException("Unsupported task type for chunking: " + taskType); + }; + } + /** * For text embedding models get the embedding size and * update the service settings. 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java index 87e5e59ae3434..2654ee4d22ce6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -39,6 +40,7 @@ public AlibabaCloudSearchEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -48,6 +50,7 @@ public AlibabaCloudSearchEmbeddingsModel( service, AlibabaCloudSearchEmbeddingsServiceSettings.fromMap(serviceSettings, context), AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(taskSettings), + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -59,10 +62,11 @@ public AlibabaCloudSearchEmbeddingsModel( String service, AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings, AlibabaCloudSearchEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new 
ModelSecrets(secretSettings), serviceSettings.getCommonSettings() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java index 76dfd01f333da..8896e983d3e7f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; @@ -81,10 +82,21 @@ public SimilarityMeasure getSimilarity() { return similarity; } - public Integer getDimensions() { + @Override + public Integer dimensions() { return dimensions; } + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + public Integer getMaxInputTokens() { return maxInputTokens; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java index b551ba389136b..0155d8fbc1f08 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -39,6 +40,7 @@ public AlibabaCloudSearchSparseModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -48,6 +50,7 @@ public AlibabaCloudSearchSparseModel( service, AlibabaCloudSearchSparseServiceSettings.fromMap(serviceSettings, context), AlibabaCloudSearchSparseTaskSettings.fromMap(taskSettings), + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -59,10 +62,11 @@ public AlibabaCloudSearchSparseModel( String service, AlibabaCloudSearchSparseServiceSettings serviceSettings, AlibabaCloudSearchSparseTaskSettings taskSettings, + ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secretSettings), serviceSettings.getCommonSettings() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 7981fb393a842..c1ca50d41268e 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -24,6 +25,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -90,11 +93,23 @@ protected void doChunkedInfer( ) { if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) { var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + baseAzureAiStudioModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + 
inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); @@ -115,11 +130,19 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + AzureAiStudioModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -146,11 +169,17 @@ public AzureAiStudioModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -161,11 +190,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = 
removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -186,6 +221,7 @@ private static AzureAiStudioModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -198,6 +234,7 @@ private static AzureAiStudioModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); @@ -235,6 +272,7 @@ private AzureAiStudioModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage ) { @@ -243,6 +281,7 @@ private AzureAiStudioModel createModelFromPersistent( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java index a999b9f0312e6..edbefe07cff02 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.azureaistudio.embeddings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; @@ -44,9 +45,13 @@ public AzureAiStudioEmbeddingsModel( String service, AzureAiStudioEmbeddingsServiceSettings serviceSettings, AzureAiStudioEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, DefaultSecretSettings secrets ) { - super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets)); + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), + new ModelSecrets(secrets) + ); } public AzureAiStudioEmbeddingsModel( @@ -55,6 +60,7 @@ public AzureAiStudioEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -64,6 +70,7 @@ public AzureAiStudioEmbeddingsModel( service, AzureAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context), AzureAiStudioEmbeddingsTaskSettings.fromMap(taskSettings), + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index d9d8850048564..ae9219ba38499 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -23,6 +24,8 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -70,11 +73,19 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + GoogleVertexAiModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -101,11 +112,17 @@ public Model parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map 
secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -116,11 +133,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -179,11 +202,22 @@ protected void doChunkedInfer( GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model; var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + googleVertexAiModel.getConfigurations().getChunkingSettings() + 
).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = googleVertexAiModel.accept(actionCreator, taskSettings); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); @@ -225,6 +259,7 @@ private static GoogleVertexAiModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage ) { @@ -233,6 +268,7 @@ private static GoogleVertexAiModel createModelFromPersistent( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT @@ -244,6 +280,7 @@ private static GoogleVertexAiModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -255,6 +292,7 @@ private static GoogleVertexAiModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java index 99110045fc3da..3a5fae09b40ef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java @@ -9,6 +9,7 @@ import org.apache.http.client.utils.URIBuilder; import 
org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; @@ -35,6 +36,7 @@ public GoogleVertexAiEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secrets, ConfigurationParseContext context ) { @@ -44,6 +46,7 @@ public GoogleVertexAiEmbeddingsModel( service, GoogleVertexAiEmbeddingsServiceSettings.fromMap(serviceSettings, context), GoogleVertexAiEmbeddingsTaskSettings.fromMap(taskSettings), + chunkingSettings, GoogleVertexAiSecretSettings.fromMap(secrets) ); } @@ -59,10 +62,11 @@ public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, Google String service, GoogleVertexAiEmbeddingsServiceSettings serviceSettings, GoogleVertexAiEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, @Nullable GoogleVertexAiSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secrets), serviceSettings ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index e8c34eec96171..7cedc36ffa5f0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -7,12 +7,14 @@ package org.elasticsearch.xpack.inference.services.alibabacloudsearch; +import 
org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -20,9 +22,12 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -31,28 +36,34 @@ import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.ServiceFields; +import 
org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests; -import org.hamcrest.CoreMatchers; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Before; import java.io.IOException; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static 
org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; @@ -99,6 +110,233 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), 
is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model 
parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var model = service.parsePersistedConfig( + "id", + TaskType.TEXT_EMBEDDING, + getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap() + ).config() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new 
AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var model = service.parsePersistedConfig( + "id", + TaskType.TEXT_EMBEDDING, + getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap() + ).config() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var model = service.parsePersistedConfig( + "id", + TaskType.TEXT_EMBEDDING, + getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ).config() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + 
assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var persistedConfig = getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature 
flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var persistedConfig = getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var persistedConfig = getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + 
var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + public void testCheckModelConfig() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -165,49 +403,71 @@ public void doInfer( } } - public void testChunkedInfer_Batches() throws IOException { - var input = List.of("foo", "bar"); + public void testChunkedInfer_TextEmbeddingBatches() throws IOException { + testChunkedInfer(TaskType.TEXT_EMBEDDING, null); + } + public void testChunkedInfer_TextEmbeddingChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + testChunkedInfer(TaskType.TEXT_EMBEDDING, ChunkingSettingsTests.createRandomChunkingSettings()); + } + + public void testChunkedInfer_TextEmbeddingChunkingSettingsNotSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + testChunkedInfer(TaskType.TEXT_EMBEDDING, null); + } + + public void testChunkedInfer_SparseEmbeddingBatches() throws IOException { + testChunkedInfer(TaskType.SPARSE_EMBEDDING, null); + } + + public void 
testChunkedInfer_SparseEmbeddingChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + testChunkedInfer(TaskType.SPARSE_EMBEDDING, ChunkingSettingsTests.createRandomChunkingSettings()); + } + + public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + testChunkedInfer(TaskType.SPARSE_EMBEDDING, null); + } + + public void testChunkedInfer_InvalidTaskType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { - Map serviceSettingsMap = new HashMap<>(); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); - serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); + var model = AlibabaCloudSearchCompletionModelTests.createModel( + randomAlphaOfLength(10), + TaskType.COMPLETION, + AlibabaCloudSearchCompletionServiceSettingsTests.createRandom(), + AlibabaCloudSearchCompletionTaskSettingsTests.createRandom(), + null + ); - Map taskSettingsMap = new HashMap<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); + try { + service.chunkedInfer( + model, + null, + List.of("foo", "bar"), + new HashMap<>(), + InputType.INGEST, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + } catch (Exception e) { + assertThat(e, instanceOf(IllegalArgumentException.class)); + } + } + } - Map secretSettingsMap = new HashMap<>(); - secretSettingsMap.put("api_key", "secret"); + private void 
testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException { + var input = List.of("foo", "bar"); - var model = new AlibabaCloudSearchEmbeddingsModel( - "service", - TaskType.TEXT_EMBEDDING, - AlibabaCloudSearchUtils.SERVICE_NAME, - serviceSettingsMap, - taskSettingsMap, - secretSettingsMap, - null - ) { - public ExecutableAction accept( - AlibabaCloudSearchActionVisitor visitor, - Map taskSettings, - InputType inputType - ) { - return (inferenceInputs, timeout, listener) -> { - InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults( - List.of( - new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123f, -0.0123f }), - new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0456f, -0.0456f }) - ) - ); - - listener.onResponse(results); - }; - } - }; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + var model = createModelForTaskType(taskType, chunkingSettings); PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( @@ -222,26 +482,101 @@ public ExecutableAction accept( ); var results = listener.actionGet(TIMEOUT); + assertThat(results, instanceOf(List.class)); assertThat(results, hasSize(2)); + var firstResult = results.get(0); + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + assertThat(firstResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { + assertThat(firstResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + } + } + } + + private AlibabaCloudSearchModel createModelForTaskType(TaskType taskType, ChunkingSettings chunkingSettings) { + Map serviceSettingsMap = new HashMap<>(); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); + 
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); + serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); + + Map taskSettingsMap = new HashMap<>(); + + Map secretSettingsMap = new HashMap<>(); + + secretSettingsMap.put("api_key", "secret"); + return switch (taskType) { + case TEXT_EMBEDDING -> createEmbeddingsModel(serviceSettingsMap, taskSettingsMap, chunkingSettings, secretSettingsMap); + case SPARSE_EMBEDDING -> createSparseEmbeddingsModel(serviceSettingsMap, taskSettingsMap, chunkingSettings, secretSettingsMap); + default -> throw new IllegalArgumentException("Unsupported task type for chunking: " + taskType); + }; + } - // first result - { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); - assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(input.get(0), floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding())); + private AlibabaCloudSearchModel createEmbeddingsModel( + Map serviceSettingsMap, + Map taskSettingsMap, + ChunkingSettings chunkingSettings, + Map secretSettingsMap + ) { + return new AlibabaCloudSearchEmbeddingsModel( + "service", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchUtils.SERVICE_NAME, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + null + ) { + public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings, InputType inputType) { + return (inferenceInputs, timeout, listener) -> { + InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123f, -0.0123f }), + new 
InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0456f, -0.0456f }) + ) + ); + + listener.onResponse(results); + }; } + }; + } - // second result - { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); - assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(input.get(1), floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding())); + private AlibabaCloudSearchModel createSparseEmbeddingsModel( + Map serviceSettingsMap, + Map taskSettingsMap, + ChunkingSettings chunkingSettings, + Map secretSettingsMap + ) { + return new AlibabaCloudSearchSparseModel( + "service", + TaskType.SPARSE_EMBEDDING, + AlibabaCloudSearchUtils.SERVICE_NAME, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + null + ) { + public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings, InputType inputType) { + return (inferenceInputs, timeout, listener) -> { + listener.onResponse(SparseEmbeddingResultsTests.createRandomResults(2, 1)); + }; } - } + }; + } + + public Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java index fca0ee11e5c78..957b7149b14f1 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java @@ -47,6 +47,7 @@ public static AlibabaCloudSearchEmbeddingsModel createModel( AlibabaCloudSearchUtils.SERVICE_NAME, serviceSettings, taskSettings, + null, secrets, null ); @@ -65,6 +66,7 @@ public static AlibabaCloudSearchEmbeddingsModel createModel( AlibabaCloudSearchUtils.SERVICE_NAME, serviceSettings, taskSettings, + null, secretSettings ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java index 4e9179b66c36f..4a89e1fc924a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java @@ -47,6 +47,7 @@ public static AlibabaCloudSearchSparseModel createModel( AlibabaCloudSearchUtils.SERVICE_NAME, serviceSettings, taskSettings, + null, secrets, null ); @@ -65,6 +66,7 @@ public static AlibabaCloudSearchSparseModel createModel( AlibabaCloudSearchUtils.SERVICE_NAME, serviceSettings, taskSettings, + null, secretSettings ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 1df457b3211ea..683f32710bcb3 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -29,6 +30,7 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; @@ -62,6 +64,8 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER; @@ -124,6 +128,90 @@ public void testParseRequestConfig_CreatesAnAzureAiStudioEmbeddingsModel() throw } } + public void 
testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", null, null, null, null); + + var config = getRequestConfigMap( + serviceSettings, + getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + } + } + + public void testParseRequestConfig_CreatesAnAzureAiStudioEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + 
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", null, null, null, null), + getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAnAzureAiStudioEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", null, null, null, null), + getEmbeddingsTaskSettingsMap("user"), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + public void 
testParseRequestConfig_CreatesAnAzureAiStudioChatCompletionModel() throws IOException { try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { @@ -461,6 +549,89 @@ public void testParsePersistedConfig_CreatesAnAzureAiStudioEmbeddingsModel() thr } } + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + } + } + + public void 
testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + 
getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + getEmbeddingsTaskSettingsMap("user"), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + public void testParsePersistedConfig_CreatesAnAzureAiStudioChatCompletionModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( @@ -651,6 +822,84 @@ public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() thro } } + public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWithoutChunkingSettingsFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + 
getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + 
assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + getEmbeddingsTaskSettingsMap("user"), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + public void 
testParsePersistedConfig_WithoutSecretsCreatesChatCompletionModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( @@ -843,6 +1092,61 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc } public void testChunkedInfer() throws IOException { + var model = AzureAiStudioEmbeddingsModelTests.createModel( + "id", + getUrl(webServer), + AzureAiStudioProvider.OPENAI, + AzureAiStudioEndpointType.TOKEN, + "apikey", + null, + false, + null, + null, + "user", + null + ); + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var model = AzureAiStudioEmbeddingsModelTests.createModel( + "id", + getUrl(webServer), + AzureAiStudioProvider.OPENAI, + AzureAiStudioEndpointType.TOKEN, + createRandomChunkingSettings(), + "apikey", + null, + false, + null, + null, + "user", + null + ); + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsNotSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var model = AzureAiStudioEmbeddingsModelTests.createModel( + "id", + getUrl(webServer), + AzureAiStudioProvider.OPENAI, + AzureAiStudioEndpointType.TOKEN, + null, + "apikey", + null, + false, + null, + null, + "user", + null + ); + testChunkedInfer(model); + } + + private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { @@ -877,19 +1181,6 @@ public void testChunkedInfer() throws IOException { """; webServer.enqueue(new 
MockResponse().setResponseCode(200).setBody(responseJson)); - var model = AzureAiStudioEmbeddingsModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - null, - false, - null, - null, - "user", - null - ); PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, @@ -1020,6 +1311,18 @@ private AzureAiStudioService createService() { return new AzureAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); } + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModelTests.java index 5a450f03b4e01..c9b0f905abaa4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModelTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -104,6 +105,51 @@ public static AzureAiStudioEmbeddingsModel createModel( return 
createModel(inferenceId, target, provider, endpointType, apiKey, null, false, null, null, null, null); } + public static AzureAiStudioEmbeddingsModel createModel( + String inferenceId, + String target, + AzureAiStudioProvider provider, + AzureAiStudioEndpointType endpointType, + ChunkingSettings chunkingSettings, + String apiKey + ) { + return createModel(inferenceId, target, provider, endpointType, chunkingSettings, apiKey, null, false, null, null, null, null); + } + + public static AzureAiStudioEmbeddingsModel createModel( + String inferenceId, + String target, + AzureAiStudioProvider provider, + AzureAiStudioEndpointType endpointType, + ChunkingSettings chunkingSettings, + String apiKey, + @Nullable Integer dimensions, + boolean dimensionsSetByUser, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarity, + @Nullable String user, + RateLimitSettings rateLimitSettings + ) { + return new AzureAiStudioEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + "azureaistudio", + new AzureAiStudioEmbeddingsServiceSettings( + target, + provider, + endpointType, + dimensions, + dimensionsSetByUser, + maxTokens, + similarity, + rateLimitSettings + ), + new AzureAiStudioEmbeddingsTaskSettings(user), + chunkingSettings, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + public static AzureAiStudioEmbeddingsModel createModel( String inferenceId, String target, @@ -132,6 +178,7 @@ public static AzureAiStudioEmbeddingsModel createModel( rateLimitSettings ), new AzureAiStudioEmbeddingsTaskSettings(user), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 6a96d289a8190..70ec6522c0fcb 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -10,12 +10,14 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -37,7 +39,9 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -107,6 +111,130 @@ public void testParseRequestConfig_CreatesGoogleVertexAiEmbeddingsModel() throws } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 
'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createGoogleVertexAiService()) { + var config = getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + "model", + GoogleVertexAiServiceFields.LOCATION, + "location", + GoogleVertexAiServiceFields.PROJECT_ID, + "project" + ) + ), + getTaskSettingsMap(true), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("{}") + ); + + var failureListener = ActionListener.wrap(model -> fail("Expected exception, but got model: " + model), exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + }); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + }, e -> 
fail("Model parsing should succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId + ) + ), + new HashMap<>(Map.of()), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(serviceAccountJson) + ), + modelListener + ); + } + } + + public void testParseRequestConfig_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + }, e -> fail("Model parsing should succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId + ) + ), + new HashMap<>(Map.of()), + 
getSecretSettingsMap(serviceAccountJson) + ), + modelListener + ); + } + } + public void testParseRequestConfig_CreatesGoogleVertexAiRerankModel() throws IOException { var projectId = "project"; var serviceAccountJson = """ @@ -321,6 +449,161 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiEmbeddingsM } } + public void testParsePersistedConfigWithSecrets_CreatesAGoogleVertexAiEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + 
assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + 
assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } + } + public void 
testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiRerankModel() throws IOException { var projectId = "project"; var topN = 1; @@ -550,12 +833,142 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } + public void testParsePersistedConfig_CreatesAGoogleVertexAiEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + } + } + + public void testParsePersistedConfig_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 
'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + 
GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + // testInfer tested via end-to-end notebook tests in AppEx repo private GoogleVertexAiService createGoogleVertexAiService() { return new GoogleVertexAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); } + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java index 
ca38bdb6e2c6c..68d03d350d06e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java @@ -79,8 +79,8 @@ public static GoogleVertexAiEmbeddingsModel createModel(String modelId, @Nullabl null ), new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate), + null, new GoogleVertexAiSecretSettings(new SecureString(randomAlphaOfLength(8).toCharArray())) ); } - } From 1dfce84a35dc1cc740e5e82123df41958f65f599 Mon Sep 17 00:00:00 2001 From: David Turner Date: Thu, 10 Oct 2024 19:57:46 +0100 Subject: [PATCH 05/25] Improve handling of failure to create persistent task (#114386) (#114523) Today if creating a persistent task fails with an exception then we submit a cluster state update to fail the task but until that update executes we will retry the failing task creation and cluster state submission on all other cluster state updates that change the persistent tasks metadata. With this commit we register a placeholder task on the executing node to block further attempts to create it until the cluster state update is processed. 
--- docs/changelog/114386.yaml | 5 + .../PersistentTaskCreationFailureIT.java | 228 ++++++++++++++++++ .../PersistentTasksNodeService.java | 86 +++++-- 3 files changed, 296 insertions(+), 23 deletions(-) create mode 100644 docs/changelog/114386.yaml create mode 100644 server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTaskCreationFailureIT.java diff --git a/docs/changelog/114386.yaml b/docs/changelog/114386.yaml new file mode 100644 index 0000000000000..cf9edda9de21e --- /dev/null +++ b/docs/changelog/114386.yaml @@ -0,0 +1,5 @@ +pr: 114386 +summary: Improve handling of failure to create persistent task +area: Task Management +type: bug +issues: [] diff --git a/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTaskCreationFailureIT.java b/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTaskCreationFailureIT.java new file mode 100644 index 0000000000000..8a4d1ceda784b --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTaskCreationFailureIT.java @@ -0,0 +1,228 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.persistent; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateUpdateTask; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Priority; +import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.SettingsModule; +import org.elasticsearch.plugins.PersistentTaskPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.PluginsService; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ClusterServiceUtils; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.StreamSupport; + +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class PersistentTaskCreationFailureIT extends ESIntegTestCase { + @Override + protected Collection> nodePlugins() { + return List.of(FailingCreationPersistentTasksPlugin.class); + } + + private static boolean hasPersistentTask(ClusterState clusterState) { + return findTasks(clusterState, FailingCreationPersistentTaskExecutor.TASK_NAME).isEmpty() == false; + } + + public void testPersistentTasksThatFailDuringCreationAreRemovedFromClusterState() { + + final var masterClusterService = 
internalCluster().getCurrentMasterNodeInstance(ClusterService.class); + final var plugins = StreamSupport.stream(internalCluster().getInstances(PluginsService.class).spliterator(), false) + .flatMap(ps -> ps.filterPlugins(FailingCreationPersistentTasksPlugin.class)) + .toList(); + plugins.forEach(plugin -> plugin.hasFailedToCreateTask.set(false)); + + final var taskCreatedListener = ClusterServiceUtils.addTemporaryStateListener( + masterClusterService, + PersistentTaskCreationFailureIT::hasPersistentTask + ); + + taskCreatedListener.andThenAccept(v -> { + // enqueue some higher-priority cluster state updates to check that they do not cause retries of the failing task creation step + for (int i = 0; i < 5; i++) { + masterClusterService.submitUnbatchedStateUpdateTask("test", new ClusterStateUpdateTask(Priority.IMMEDIATE) { + @Override + public ClusterState execute(ClusterState currentState) { + assertTrue(hasPersistentTask(currentState)); + + assertTrue(waitUntil(() -> { + final var completePersistentTaskPendingTasksCount = masterClusterService.getMasterService() + .pendingTasks() + .stream() + .filter( + pendingClusterTask -> pendingClusterTask.getSource().string().equals("finish persistent task (failed)") + ) + .count(); + assertThat(completePersistentTaskPendingTasksCount, lessThanOrEqualTo(1L)); + return completePersistentTaskPendingTasksCount == 1L; + })); + + return currentState.copyAndUpdateMetadata( + mdb -> mdb.putCustom( + PersistentTasksCustomMetadata.TYPE, + PersistentTasksCustomMetadata.builder( + PersistentTasksCustomMetadata.getPersistentTasksCustomMetadata(currentState) + ) + // create and remove a fake task just to force a change in lastAllocationId so that + // PersistentTasksNodeService checks for changes and potentially retries + .addTask("test", "test", null, PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT) + .removeTask("test") + .build() + ) + ); + } + + @Override + public void onFailure(Exception e) { + fail(e); + } + }); + } + }); + + 
safeAwait( + l -> internalCluster().getInstance(PersistentTasksService.class) + .sendStartRequest( + UUIDs.base64UUID(), + FailingCreationPersistentTaskExecutor.TASK_NAME, + new FailingCreationTaskParams(), + null, + l.map(ignored -> null) + ) + ); + + safeAwait( + taskCreatedListener.andThen( + (l, v) -> ClusterServiceUtils.addTemporaryStateListener( + masterClusterService, + clusterState -> hasPersistentTask(clusterState) == false + ).addListener(l) + ) + ); + + assertEquals(1L, plugins.stream().filter(plugin -> plugin.hasFailedToCreateTask.get()).count()); + } + + public static class FailingCreationPersistentTasksPlugin extends Plugin implements PersistentTaskPlugin { + + private final AtomicBoolean hasFailedToCreateTask = new AtomicBoolean(); + + @Override + public List> getPersistentTasksExecutor( + ClusterService clusterService, + ThreadPool threadPool, + Client client, + SettingsModule settingsModule, + IndexNameExpressionResolver expressionResolver + ) { + return List.of(new FailingCreationPersistentTaskExecutor(hasFailedToCreateTask)); + } + + @Override + public List getNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry( + PersistentTaskParams.class, + FailingCreationPersistentTaskExecutor.TASK_NAME, + FailingCreationTaskParams::new + ) + ); + } + + @Override + public List getNamedXContent() { + return List.of( + new NamedXContentRegistry.Entry( + PersistentTaskParams.class, + new ParseField(FailingCreationPersistentTaskExecutor.TASK_NAME), + p -> { + p.skipChildren(); + return new FailingCreationTaskParams(); + } + ) + ); + } + } + + public static class FailingCreationTaskParams implements PersistentTaskParams { + public FailingCreationTaskParams() {} + + public FailingCreationTaskParams(StreamInput in) {} + + @Override + public String getWriteableName() { + return FailingCreationPersistentTaskExecutor.TASK_NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); + } + + 
@Override + public void writeTo(StreamOutput out) throws IOException {} + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + } + + static class FailingCreationPersistentTaskExecutor extends PersistentTasksExecutor { + static final String TASK_NAME = "cluster:admin/persistent/test_creation_failure"; + + private final AtomicBoolean hasFailedToCreateTask; + + FailingCreationPersistentTaskExecutor(AtomicBoolean hasFailedToCreateTask) { + super(TASK_NAME, r -> fail("execution is unexpected")); + this.hasFailedToCreateTask = hasFailedToCreateTask; + } + + @Override + protected AllocatedPersistentTask createTask( + long id, + String type, + String action, + TaskId parentTaskId, + PersistentTasksCustomMetadata.PersistentTask taskInProgress, + Map headers + ) { + assertTrue("already failed before", hasFailedToCreateTask.compareAndSet(false, true)); + throw new RuntimeException("simulated"); + } + + @Override + protected void nodeOperation(AllocatedPersistentTask task, FailingCreationTaskParams params, PersistentTaskState state) { + fail("execution is unexpected"); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java index b86292be8e9ee..ff6a0b9018704 100644 --- a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java +++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.gateway.GatewayService; import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask; import 
org.elasticsearch.tasks.Task; @@ -32,6 +33,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.concurrent.Executor; import static java.util.Objects.requireNonNull; import static org.elasticsearch.core.Strings.format; @@ -172,33 +174,57 @@ private void startTask(PersistentTask(taskInProgress, executor); + try (var ignored = threadPool.getThreadContext().newTraceContext()) { + doStartTask(taskInProgress, executor, request); + } + } - @Override - public void setParentTask(TaskId taskId) { - throw new UnsupportedOperationException("parent task if for persistent tasks shouldn't change"); - } + /** + * A {@link TaskAwareRequest} which creates the relevant task using a {@link PersistentTasksExecutor}. + */ + private static class PersistentTaskAwareRequest implements TaskAwareRequest { + private final PersistentTask taskInProgress; + private final TaskId parentTaskId; + private final PersistentTasksExecutor executor; + + private PersistentTaskAwareRequest(PersistentTask taskInProgress, PersistentTasksExecutor executor) { + this.taskInProgress = taskInProgress; + this.parentTaskId = new TaskId("cluster", taskInProgress.getAllocationId()); + this.executor = executor; + } - @Override - public void setRequestId(long requestId) { - throw new UnsupportedOperationException("does not have a request ID"); - } + @Override + public void setParentTask(TaskId taskId) { + throw new UnsupportedOperationException("parent task if for persistent tasks shouldn't change"); + } - @Override - public TaskId getParentTask() { - return parentTaskId; - } + @Override + public void setRequestId(long requestId) { + throw new UnsupportedOperationException("does not have a request ID"); + } - @Override - public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { - return executor.createTask(id, type, action, parentTaskId, taskInProgress, headers); - } - }; + @Override + public TaskId getParentTask() { + return parentTaskId; + 
} - try (var ignored = threadPool.getThreadContext().newTraceContext()) { - doStartTask(taskInProgress, executor, request); + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return executor.createTask(id, type, action, parentTaskId, taskInProgress, headers); + } + } + + /** + * A no-op {@link PersistentTasksExecutor} to create a placeholder task if creating the real task fails for some reason. + */ + private static class PersistentTaskStartupFailureExecutor extends PersistentTasksExecutor { + PersistentTaskStartupFailureExecutor(String taskName, Executor executor) { + super(taskName, executor); } + + @Override + protected void nodeOperation(AllocatedPersistentTask task, Params params, PersistentTaskState state) {} } private void doStartTask( @@ -206,7 +232,7 @@ private void doStartTask( PersistentTasksExecutor executor, TaskAwareRequest request ) { - AllocatedPersistentTask task; + final AllocatedPersistentTask task; try { task = (AllocatedPersistentTask) taskManager.register("persistent", taskInProgress.getTaskName() + "[c]", request); } catch (Exception e) { @@ -220,7 +246,21 @@ private void doStartTask( + "], removing from persistent tasks", e ); - notifyMasterOfFailedTask(taskInProgress, e); + + // create a no-op placeholder task so that we don't keep trying to start this task while we wait for the cluster state update + // which handles the failure + final var placeholderTask = (AllocatedPersistentTask) taskManager.register( + "persistent", + taskInProgress.getTaskName() + "[c]", + new PersistentTaskAwareRequest<>( + taskInProgress, + new PersistentTaskStartupFailureExecutor<>(executor.getTaskName(), EsExecutors.DIRECT_EXECUTOR_SERVICE) + ) + ); + placeholderTask.init(persistentTasksService, taskManager, taskInProgress.getId(), taskInProgress.getAllocationId()); + taskManager.unregister(placeholderTask); + runningTasks.put(taskInProgress.getAllocationId(), placeholderTask); + 
placeholderTask.markAsFailed(e); return; } From 3b8e33768a3bce50f1a894246099e0e6be362605 Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Thu, 10 Oct 2024 17:24:24 -0400 Subject: [PATCH 06/25] Verify Maxmind database types in the geoip processor (#114527) (#114532) --- docs/changelog/114527.yaml | 5 + ...gDatabasesWhilePerformingGeoLookupsIT.java | 3 +- .../geoip/GeoIpDownloaderTaskExecutor.java | 15 ++- .../ingest/geoip/GeoIpProcessor.java | 91 +++++++++++++------ .../ingest/geoip/IngestGeoIpPlugin.java | 2 +- .../ingest/geoip/IpinfoIpDataLookups.java | 4 + .../ingest/geoip/MaxmindIpDataLookups.java | 6 ++ .../geoip/GeoIpProcessorFactoryTests.java | 81 ++++++++++++----- .../ingest/geoip/GeoIpProcessorTests.java | 28 +++++- 9 files changed, 176 insertions(+), 59 deletions(-) create mode 100644 docs/changelog/114527.yaml diff --git a/docs/changelog/114527.yaml b/docs/changelog/114527.yaml new file mode 100644 index 0000000000000..74d95edcd1a1d --- /dev/null +++ b/docs/changelog/114527.yaml @@ -0,0 +1,5 @@ +pr: 114527 +summary: Verify Maxmind database types in the geoip processor +area: Ingest Node +type: enhancement +issues: [] diff --git a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java index b28926673069d..0499b0f94106b 100644 --- a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java +++ b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/ReloadingDatabasesWhilePerformingGeoLookupsIT.java @@ -32,6 +32,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.ingest.geoip.GeoIpProcessor.GEOIP_TYPE; import static 
org.elasticsearch.ingest.geoip.GeoIpTestUtils.copyDatabase; import static org.elasticsearch.ingest.geoip.GeoIpTestUtils.copyDefaultDatabases; import static org.hamcrest.Matchers.equalTo; @@ -66,7 +67,7 @@ public void test() throws Exception { ClusterService clusterService = mock(ClusterService.class); when(clusterService.state()).thenReturn(ClusterState.EMPTY_STATE); DatabaseNodeService databaseNodeService = createRegistry(geoIpConfigDir, geoIpTmpDir, clusterService); - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); copyDatabase("GeoLite2-City-Test.mmdb", geoIpTmpDir.resolve("GeoLite2-City.mmdb")); copyDatabase("GeoLite2-City-Test.mmdb", geoIpTmpDir.resolve("GeoLite2-City-Test.mmdb")); databaseNodeService.updateDatabase("GeoLite2-City.mmdb", "md5", geoIpTmpDir.resolve("GeoLite2-City.mmdb")); diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTaskExecutor.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTaskExecutor.java index fbbd4c3e5f8be..ed469628c5f50 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTaskExecutor.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTaskExecutor.java @@ -54,6 +54,7 @@ import static org.elasticsearch.ingest.geoip.GeoIpDownloader.DATABASES_INDEX; import static org.elasticsearch.ingest.geoip.GeoIpDownloader.GEOIP_DOWNLOADER; import static org.elasticsearch.ingest.geoip.GeoIpProcessor.Factory.downloadDatabaseOnPipelineCreation; +import static org.elasticsearch.ingest.geoip.GeoIpProcessor.GEOIP_TYPE; /** * Persistent task executor that is responsible for starting {@link GeoIpDownloader} after task is allocated by master node. 
@@ -296,9 +297,9 @@ private static boolean hasAtLeastOneGeoipProcessor(Map processor return false; } - if (processor.containsKey(GeoIpProcessor.TYPE)) { - Map processorConfig = (Map) processor.get(GeoIpProcessor.TYPE); - return downloadDatabaseOnPipelineCreation(processorConfig) == downloadDatabaseOnPipelineCreation; + final Map processorConfig = (Map) processor.get(GEOIP_TYPE); + if (processorConfig != null) { + return downloadDatabaseOnPipelineCreation(GEOIP_TYPE, processorConfig, null) == downloadDatabaseOnPipelineCreation; } return isProcessorWithOnFailureGeoIpProcessor(processor, downloadDatabaseOnPipelineCreation) @@ -336,11 +337,9 @@ && hasAtLeastOneGeoipProcessor( */ @SuppressWarnings("unchecked") private static boolean isForeachProcessorWithGeoipProcessor(Map processor, boolean downloadDatabaseOnPipelineCreation) { - return processor.containsKey("foreach") - && hasAtLeastOneGeoipProcessor( - ((Map>) processor.get("foreach")).get("processor"), - downloadDatabaseOnPipelineCreation - ); + final Map processorConfig = (Map) processor.get("foreach"); + return processorConfig != null + && hasAtLeastOneGeoipProcessor((Map) processorConfig.get("processor"), downloadDatabaseOnPipelineCreation); } @UpdateForV9 // use MINUS_ONE once that means no timeout diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpProcessor.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpProcessor.java index e2b516bf5b943..f8ca6d87924a4 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpProcessor.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpProcessor.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.core.Assertions; +import org.elasticsearch.core.Strings; import org.elasticsearch.ingest.AbstractProcessor; import 
org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.Processor; @@ -22,6 +23,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.function.Supplier; @@ -36,9 +38,12 @@ public final class GeoIpProcessor extends AbstractProcessor { private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(GeoIpProcessor.class); static final String DEFAULT_DATABASES_DEPRECATION_MESSAGE = "the [fallback_to_default_databases] has been deprecated, because " + "Elasticsearch no longer includes the default Maxmind geoip databases. This setting will be removed in Elasticsearch 9.0"; + static final String UNSUPPORTED_DATABASE_DEPRECATION_MESSAGE = "the geoip processor will no longer support database type [{}] " + + "in a future version of Elasticsearch"; // TODO add a message about migration? - public static final String TYPE = "geoip"; + public static final String GEOIP_TYPE = "geoip"; + private final String type; private final String field; private final Supplier isValid; private final String targetField; @@ -62,6 +67,7 @@ public final class GeoIpProcessor extends AbstractProcessor { * @param databaseFile the name of the database file being queried; used only for tagging documents if the database is unavailable */ GeoIpProcessor( + final String type, final String tag, final String description, final String field, @@ -74,6 +80,7 @@ public final class GeoIpProcessor extends AbstractProcessor { final String databaseFile ) { super(tag, description); + this.type = type; this.field = field; this.isValid = isValid; this.targetField = targetField; @@ -93,7 +100,7 @@ public IngestDocument execute(IngestDocument document) throws IOException { Object ip = document.getFieldValue(field, Object.class, ignoreMissing); if (isValid.get() == false) { - document.appendFieldValue("tags", "_geoip_expired_database", false); + 
document.appendFieldValue("tags", "_" + type + "_expired_database", false); return document; } else if (ip == null && ignoreMissing) { return document; @@ -104,7 +111,7 @@ public IngestDocument execute(IngestDocument document) throws IOException { try (IpDatabase ipDatabase = this.supplier.get()) { if (ipDatabase == null) { if (ignoreMissing == false) { - tag(document, databaseFile); + tag(document, type, databaseFile); } return document; } @@ -146,7 +153,7 @@ public IngestDocument execute(IngestDocument document) throws IOException { @Override public String getType() { - return TYPE; + return type; } String getField() { @@ -202,9 +209,11 @@ public IpDatabase get() throws IOException { public static final class Factory implements Processor.Factory { + private final String type; // currently always just "geoip" private final IpDatabaseProvider ipDatabaseProvider; - public Factory(IpDatabaseProvider ipDatabaseProvider) { + public Factory(String type, IpDatabaseProvider ipDatabaseProvider) { + this.type = type; this.ipDatabaseProvider = ipDatabaseProvider; } @@ -215,16 +224,16 @@ public Processor create( final String description, final Map config ) throws IOException { - String ipField = readStringProperty(TYPE, processorTag, config, "field"); - String targetField = readStringProperty(TYPE, processorTag, config, "target_field", "geoip"); - String databaseFile = readStringProperty(TYPE, processorTag, config, "database_file", "GeoLite2-City.mmdb"); - List propertyNames = readOptionalList(TYPE, processorTag, config, "properties"); - boolean ignoreMissing = readBooleanProperty(TYPE, processorTag, config, "ignore_missing", false); - boolean firstOnly = readBooleanProperty(TYPE, processorTag, config, "first_only", true); + String ipField = readStringProperty(type, processorTag, config, "field"); + String targetField = readStringProperty(type, processorTag, config, "target_field", "geoip"); + String databaseFile = readStringProperty(type, processorTag, config, 
"database_file", "GeoLite2-City.mmdb"); + List propertyNames = readOptionalList(type, processorTag, config, "properties"); + boolean ignoreMissing = readBooleanProperty(type, processorTag, config, "ignore_missing", false); + boolean firstOnly = readBooleanProperty(type, processorTag, config, "first_only", true); // Validating the download_database_on_pipeline_creation even if the result // is not used directly by the factory. - downloadDatabaseOnPipelineCreation(config, processorTag); + downloadDatabaseOnPipelineCreation(type, config, processorTag); // noop, should be removed in 9.0 Object value = config.remove("fallback_to_default_databases"); @@ -239,7 +248,7 @@ public Processor create( // at a later moment, so a processor impl is returned that tags documents instead. If a database cannot be sourced // then the processor will continue to tag documents with a warning until it is remediated by providing a database // or changing the pipeline. - return new DatabaseUnavailableProcessor(processorTag, description, databaseFile); + return new DatabaseUnavailableProcessor(type, processorTag, description, databaseFile); } databaseType = ipDatabase.getDatabaseType(); } @@ -248,17 +257,48 @@ public Processor create( try { factory = IpDataLookupFactories.get(databaseType, databaseFile); } catch (IllegalArgumentException e) { - throw newConfigurationException(TYPE, processorTag, "database_file", e.getMessage()); + throw newConfigurationException(type, processorTag, "database_file", e.getMessage()); + } + + // the "geoip" processor type does additional validation of the database_type + if (GEOIP_TYPE.equals(type)) { + // type sniffing is done with the lowercased type + final String lowerCaseDatabaseType = databaseType.toLowerCase(Locale.ROOT); + + // start with a strict positive rejection check -- as we support addition database providers, + // we should expand these checks when possible + if (lowerCaseDatabaseType.startsWith(IpinfoIpDataLookups.IPINFO_PREFIX)) { + throw 
newConfigurationException( + type, + processorTag, + "database_file", + Strings.format("Unsupported database type [%s] for file [%s]", databaseType, databaseFile) + ); + } + + // end with a lax negative rejection check -- if we aren't *certain* it's a maxmind database, then we'll warn -- + // it's possible for example that somebody cooked up a custom database of their own that happened to work with + // our preexisting code, they should migrate to the new processor, but we're not going to break them right now + if (lowerCaseDatabaseType.startsWith(MaxmindIpDataLookups.GEOIP2_PREFIX) == false + && lowerCaseDatabaseType.startsWith(MaxmindIpDataLookups.GEOLITE2_PREFIX) == false) { + deprecationLogger.warn( + DeprecationCategory.OTHER, + "unsupported_database_type", + UNSUPPORTED_DATABASE_DEPRECATION_MESSAGE, + databaseType + ); + } } final IpDataLookup ipDataLookup; try { ipDataLookup = factory.create(propertyNames); } catch (IllegalArgumentException e) { - throw newConfigurationException(TYPE, processorTag, "properties", e.getMessage()); + throw newConfigurationException(type, processorTag, "properties", e.getMessage()); } return new GeoIpProcessor( + type, processorTag, description, ipField, @@ -272,34 +312,31 @@ public Processor create( ); } - public static boolean downloadDatabaseOnPipelineCreation(Map config) { - return downloadDatabaseOnPipelineCreation(config, null); - } - - public static boolean downloadDatabaseOnPipelineCreation(Map config, String processorTag) { - return readBooleanProperty(GeoIpProcessor.TYPE, processorTag, config, "download_database_on_pipeline_creation", true); + public static boolean downloadDatabaseOnPipelineCreation(String type, Map config, String processorTag) { + return readBooleanProperty(type, processorTag, config, "download_database_on_pipeline_creation", true); } - } static class DatabaseUnavailableProcessor extends AbstractProcessor { + private final String type; private final String databaseName; - 
DatabaseUnavailableProcessor(String tag, String description, String databaseName) { + DatabaseUnavailableProcessor(String type, String tag, String description, String databaseName) { super(tag, description); + this.type = type; this.databaseName = databaseName; } @Override public IngestDocument execute(IngestDocument ingestDocument) throws Exception { - tag(ingestDocument, databaseName); + tag(ingestDocument, this.type, databaseName); return ingestDocument; } @Override public String getType() { - return TYPE; + return type; } public String getDatabaseName() { @@ -307,7 +344,7 @@ public String getDatabaseName() { } } - private static void tag(IngestDocument ingestDocument, String databaseName) { - ingestDocument.appendFieldValue("tags", "_geoip_database_unavailable_" + databaseName, true); + private static void tag(IngestDocument ingestDocument, String type, String databaseName) { + ingestDocument.appendFieldValue("tags", "_" + type + "_database_unavailable_" + databaseName, true); } } diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpPlugin.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpPlugin.java index f5ae869841b82..49932f342086e 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpPlugin.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpPlugin.java @@ -129,7 +129,7 @@ public Map getProcessors(Processor.Parameters paramet parameters.ingestService.getClusterService() ); databaseRegistry.set(registry); - return Map.of(GeoIpProcessor.TYPE, new GeoIpProcessor.Factory(registry)); + return Map.of(GeoIpProcessor.GEOIP_TYPE, new GeoIpProcessor.Factory(GeoIpProcessor.GEOIP_TYPE, registry)); } @Override diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java index 
06051879a0745..efc6734b3bd93 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java @@ -39,6 +39,10 @@ private IpinfoIpDataLookups() { private static final Logger logger = LogManager.getLogger(IpinfoIpDataLookups.class); + // the actual prefix from the metadata is cased like the literal string, and + // prefix dispatch and checks case-insensitive, so that works out nicely + static final String IPINFO_PREFIX = "ipinfo"; + /** * Lax-ly parses a string that (ideally) looks like 'AS123' into a Long like 123L (or null, if such parsing isn't possible). * @param asn a potentially empty (or null) ASN string that is expected to contain 'AS' and then a parsable long diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java index e7c3481938033..5fe2e980d2ab0 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java @@ -34,6 +34,7 @@ import java.net.InetAddress; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; @@ -46,6 +47,11 @@ private MaxmindIpDataLookups() { // utility class } + // the actual prefixes from the metadata are cased like the literal strings, but + // prefix dispatch and checks case-insensitive, so the actual constants are lowercase + static final String GEOIP2_PREFIX = "GeoIP2".toLowerCase(Locale.ROOT); + static final String GEOLITE2_PREFIX = "GeoLite2".toLowerCase(Locale.ROOT); + static class AnonymousIp extends AbstractBase { AnonymousIp(final Set properties) { super( diff --git 
a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java index cfea54d2520bd..bf268e17edccb 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java @@ -43,6 +43,7 @@ import java.util.Map; import java.util.Set; +import static org.elasticsearch.ingest.geoip.GeoIpProcessor.GEOIP_TYPE; import static org.elasticsearch.ingest.geoip.GeoIpTestUtils.DEFAULT_DATABASES; import static org.elasticsearch.ingest.geoip.GeoIpTestUtils.copyDatabase; import static org.elasticsearch.ingest.geoip.GeoIpTestUtils.copyDefaultDatabases; @@ -88,7 +89,7 @@ public void closeDatabaseReaders() throws IOException { } public void testBuildDefaults() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "_field"); @@ -104,7 +105,7 @@ public void testBuildDefaults() throws Exception { } public void testSetIgnoreMissing() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "_field"); @@ -121,7 +122,7 @@ public void testSetIgnoreMissing() throws Exception { } public void testCountryBuildDefaults() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "_field"); @@ -139,7 +140,7 @@ public void testCountryBuildDefaults() throws 
Exception { } public void testAsnBuildDefaults() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "_field"); @@ -157,7 +158,7 @@ public void testAsnBuildDefaults() throws Exception { } public void testBuildTargetField() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "_field"); config.put("target_field", "_field"); @@ -168,7 +169,7 @@ public void testBuildTargetField() throws Exception { } public void testBuildDbFile() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "_field"); config.put("database_file", "GeoLite2-Country.mmdb"); @@ -181,7 +182,7 @@ public void testBuildDbFile() throws Exception { } public void testBuildWithCountryDbAndAsnFields() { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "_field"); config.put("database_file", "GeoLite2-Country.mmdb"); @@ -201,7 +202,7 @@ public void testBuildWithCountryDbAndAsnFields() { } public void testBuildWithAsnDbAndCityFields() { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "_field"); config.put("database_file", "GeoLite2-ASN.mmdb"); @@ -219,7 +220,7 @@ public void 
testBuildWithAsnDbAndCityFields() { public void testBuildNonExistingDbFile() throws Exception { copyDatabase("GeoLite2-City-Test.mmdb", geoipTmpDir.resolve("GeoLite2-City.mmdb")); databaseNodeService.updateDatabase("GeoLite2-City.mmdb", "md5", geoipTmpDir.resolve("GeoLite2-City.mmdb")); - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "_field"); @@ -229,7 +230,7 @@ public void testBuildNonExistingDbFile() throws Exception { } public void testBuildBuiltinDatabaseMissing() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); cleanDatabases(geoIpConfigDir, configDatabases); Map config = new HashMap<>(); @@ -240,7 +241,7 @@ public void testBuildBuiltinDatabaseMissing() throws Exception { } public void testBuildFields() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Set properties = new HashSet<>(); List fieldNames = new ArrayList<>(); @@ -264,7 +265,7 @@ public void testBuildFields() throws Exception { } public void testBuildIllegalFieldOption() { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config1 = new HashMap<>(); config1.put("field", "_field"); @@ -293,7 +294,7 @@ public void testBuildUnsupportedDatabase() throws Exception { IpDatabaseProvider provider = mock(IpDatabaseProvider.class); when(provider.getDatabase(anyString())).thenReturn(database); - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(provider); + GeoIpProcessor.Factory factory = new 
GeoIpProcessor.Factory(GEOIP_TYPE, provider); Map config1 = new HashMap<>(); config1.put("field", "_field"); @@ -312,7 +313,7 @@ public void testBuildNullDatabase() throws Exception { IpDatabaseProvider provider = mock(IpDatabaseProvider.class); when(provider.getDatabase(anyString())).thenReturn(database); - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(provider); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, provider); Map config1 = new HashMap<>(); config1.put("field", "_field"); @@ -321,6 +322,44 @@ public void testBuildNullDatabase() throws Exception { assertThat(e.getMessage(), equalTo("[database_file] Unsupported database type [null] for file [GeoLite2-City.mmdb]")); } + public void testStrictMaxmindSupport() throws Exception { + IpDatabase database = mock(IpDatabase.class); + when(database.getDatabaseType()).thenReturn("ipinfo some_ipinfo_database.mmdb-City"); + IpDatabaseProvider provider = mock(IpDatabaseProvider.class); + when(provider.getDatabase(anyString())).thenReturn(database); + + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, provider); + + Map config1 = new HashMap<>(); + config1.put("database_file", "some-ipinfo-database.mmdb"); + config1.put("field", "_field"); + config1.put("properties", List.of("ip")); + Exception e = expectThrows(ElasticsearchParseException.class, () -> factory.create(null, null, null, config1)); + assertThat( + e.getMessage(), + equalTo( + "[database_file] Unsupported database type [ipinfo some_ipinfo_database.mmdb-City] " + + "for file [some-ipinfo-database.mmdb]" + ) + ); + } + + public void testLaxMaxmindSupport() throws Exception { + IpDatabase database = mock(IpDatabase.class); + when(database.getDatabaseType()).thenReturn("some_custom_database.mmdb-City"); + IpDatabaseProvider provider = mock(IpDatabaseProvider.class); + when(provider.getDatabase(anyString())).thenReturn(database); + + GeoIpProcessor.Factory factory = new 
GeoIpProcessor.Factory(GEOIP_TYPE, provider); + + Map config1 = new HashMap<>(); + config1.put("database_file", "some-custom-database.mmdb"); + config1.put("field", "_field"); + config1.put("properties", List.of("ip")); + factory.create(null, null, null, config1); + assertWarnings(GeoIpProcessor.UNSUPPORTED_DATABASE_DEPRECATION_MESSAGE.replaceAll("\\{}", "some_custom_database.mmdb-City")); + } + public void testLazyLoading() throws Exception { final Path configDir = createTempDir(); final Path geoIpConfigDir = configDir.resolve("ingest-geoip"); @@ -341,7 +380,7 @@ public void testLazyLoading() throws Exception { Runnable::run, clusterService ); - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); for (DatabaseReaderLazyLoader lazyLoader : configDatabases.getConfigDatabases().values()) { assertNull(lazyLoader.databaseReader.get()); } @@ -410,7 +449,7 @@ public void testLoadingCustomDatabase() throws IOException { clusterService ); databaseNodeService.initialize("nodeId", resourceWatcherService, mock(IngestService.class)); - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); for (DatabaseReaderLazyLoader lazyLoader : configDatabases.getConfigDatabases().values()) { assertNull(lazyLoader.databaseReader.get()); } @@ -433,7 +472,7 @@ public void testLoadingCustomDatabase() throws IOException { } public void testFallbackUsingDefaultDatabases() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "source_field"); config.put("fallback_to_default_databases", randomBoolean()); @@ -442,7 +481,7 @@ public void 
testFallbackUsingDefaultDatabases() throws Exception { } public void testDownloadDatabaseOnPipelineCreation() throws IOException { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", randomIdentifier()); config.put("download_database_on_pipeline_creation", randomBoolean()); @@ -460,7 +499,7 @@ public void testDefaultDatabaseWithTaskPresent() throws Exception { .metadata(Metadata.builder().putCustom(PersistentTasksCustomMetadata.TYPE, tasks)) .build(); when(clusterService.state()).thenReturn(clusterState); - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "_field"); @@ -472,7 +511,7 @@ public void testDefaultDatabaseWithTaskPresent() throws Exception { } public void testUpdateDatabaseWhileIngesting() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); Map config = new HashMap<>(); config.put("field", "source_field"); GeoIpProcessor processor = (GeoIpProcessor) factory.create(null, null, null, config); @@ -511,7 +550,7 @@ public void testUpdateDatabaseWhileIngesting() throws Exception { } public void testDatabaseNotReadyYet() throws Exception { - GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(databaseNodeService); + GeoIpProcessor.Factory factory = new GeoIpProcessor.Factory(GEOIP_TYPE, databaseNodeService); cleanDatabases(geoIpConfigDir, configDatabases); { diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java index 
ffc40324bd886..fbceac3b9cce6 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java @@ -28,6 +28,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import static org.elasticsearch.ingest.IngestDocumentMatcher.assertIngestDocument; +import static org.elasticsearch.ingest.geoip.GeoIpProcessor.GEOIP_TYPE; import static org.elasticsearch.ingest.geoip.GeoIpTestUtils.copyDatabase; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; @@ -85,6 +86,7 @@ public void testDatabasePropertyInvariants() { public void testCity() throws Exception { String ip = "8.8.8.8"; GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -119,6 +121,7 @@ public void testCity() throws Exception { public void testNullValueWithIgnoreMissing() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -141,6 +144,7 @@ public void testNullValueWithIgnoreMissing() throws Exception { public void testNonExistentWithIgnoreMissing() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -160,6 +164,7 @@ public void testNonExistentWithIgnoreMissing() throws Exception { public void testNullWithoutIgnoreMissing() { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -182,6 +187,7 @@ public void testNullWithoutIgnoreMissing() { public void testNonExistentWithoutIgnoreMissing() { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -202,6 +208,7 @@ public void testNonExistentWithoutIgnoreMissing() { public void testCity_withIpV6() throws Exception { String ip = "2602:306:33d3:8000::3257:9652"; GeoIpProcessor 
processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -242,6 +249,7 @@ public void testCity_withIpV6() throws Exception { public void testCityWithMissingLocation() throws Exception { String ip = "80.231.5.0"; GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -270,6 +278,7 @@ public void testCityWithMissingLocation() throws Exception { public void testCountry() throws Exception { String ip = "82.170.213.79"; GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -303,6 +312,7 @@ public void testCountry() throws Exception { public void testCountryWithMissingLocation() throws Exception { String ip = "80.231.5.0"; GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -331,6 +341,7 @@ public void testCountryWithMissingLocation() throws Exception { public void testAsn() throws Exception { String ip = "82.171.64.0"; GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -362,6 +373,7 @@ public void testAsn() throws Exception { public void testAnonymmousIp() throws Exception { String ip = "81.2.69.1"; GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -396,6 +408,7 @@ public void testAnonymmousIp() throws Exception { public void testConnectionType() throws Exception { String ip = "214.78.120.5"; GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -425,6 +438,7 @@ public void testConnectionType() throws Exception { public void testDomain() throws Exception { String ip = "69.219.64.2"; GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -454,6 +468,7 @@ public void testDomain() throws Exception { public void testEnterprise() throws 
Exception { String ip = "74.209.24.4"; GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -511,6 +526,7 @@ public void testEnterprise() throws Exception { public void testIsp() throws Exception { String ip = "149.101.100.1"; GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -545,6 +561,7 @@ public void testIsp() throws Exception { public void testAddressIsNotInTheDatabase() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -569,6 +586,7 @@ public void testAddressIsNotInTheDatabase() throws Exception { */ public void testInvalid() { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -590,6 +608,7 @@ public void testInvalid() { public void testListAllValid() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -617,6 +636,7 @@ public void testListAllValid() throws Exception { public void testListPartiallyValid() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -644,6 +664,7 @@ public void testListPartiallyValid() throws Exception { public void testListNoMatches() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -667,7 +688,7 @@ public void testListNoMatches() throws Exception { public void testListDatabaseReferenceCounting() throws Exception { AtomicBoolean closeCheck = new AtomicBoolean(false); var loader = loader("GeoLite2-City.mmdb", closeCheck); - GeoIpProcessor processor = new GeoIpProcessor(randomAlphaOfLength(10), null, "source_field", () -> { + GeoIpProcessor processor = new GeoIpProcessor(GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", () -> { 
loader.preLookup(); return loader; }, () -> true, "target_field", ipDataLookupAll(Database.City), false, false, "filename"); @@ -692,6 +713,7 @@ public void testListDatabaseReferenceCounting() throws Exception { public void testListFirstOnly() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -717,6 +739,7 @@ public void testListFirstOnly() throws Exception { public void testListFirstOnlyNoMatches() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -739,6 +762,7 @@ public void testListFirstOnlyNoMatches() throws Exception { public void testInvalidDatabase() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -762,6 +786,7 @@ public void testInvalidDatabase() throws Exception { public void testNoDatabase() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", @@ -785,6 +810,7 @@ public void testNoDatabase() throws Exception { public void testNoDatabase_ignoreMissing() throws Exception { GeoIpProcessor processor = new GeoIpProcessor( + GEOIP_TYPE, randomAlphaOfLength(10), null, "source_field", From 94bd6878e49655d76b7b7564fb878aa908164399 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Fri, 11 Oct 2024 08:40:07 +1100 Subject: [PATCH 07/25] Mute org.elasticsearch.xpack.inference.InferenceRestIT test {p0=inference/30_semantic_text_inference/Calculates embeddings using the default ELSER 2 endpoint} #114412 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 987174517d6f9..ab1bcee20c6fd 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -292,6 +292,9 @@ tests: issue: https://github.com/elastic/elasticsearch/issues/113874 - class: 
org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanosTests issue: https://github.com/elastic/elasticsearch/issues/113661 +- class: org.elasticsearch.xpack.inference.InferenceRestIT + method: test {p0=inference/30_semantic_text_inference/Calculates embeddings using the default ELSER 2 endpoint} + issue: https://github.com/elastic/elasticsearch/issues/114412 # Examples: # From 215d7c096412cbddd268ce5e26deed4d34f48cc7 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Fri, 11 Oct 2024 08:40:14 +1100 Subject: [PATCH 08/25] Mute org.elasticsearch.xpack.inference.InferenceRestIT test {p0=inference/40_semantic_text_query/Query a field that uses the default ELSER 2 endpoint} #114376 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index ab1bcee20c6fd..19b684ffbff20 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -295,6 +295,9 @@ tests: - class: org.elasticsearch.xpack.inference.InferenceRestIT method: test {p0=inference/30_semantic_text_inference/Calculates embeddings using the default ELSER 2 endpoint} issue: https://github.com/elastic/elasticsearch/issues/114412 +- class: org.elasticsearch.xpack.inference.InferenceRestIT + method: test {p0=inference/40_semantic_text_query/Query a field that uses the default ELSER 2 endpoint} + issue: https://github.com/elastic/elasticsearch/issues/114376 # Examples: # From b2792a9a78c368cd262b703fcb24661affd1305a Mon Sep 17 00:00:00 2001 From: David Turner Date: Thu, 10 Oct 2024 22:44:06 +0100 Subject: [PATCH 09/25] Remove unnecessary test overrides (#114291) (#114316) These test overrides were introduced so that we had somewhere to hang an `@AwaitsFix` annotation, but now the tests are unmuted again there's no need for the overrides. 
Relates #108336 --- .../AzureStorageCleanupThirdPartyTests.java | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureStorageCleanupThirdPartyTests.java b/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureStorageCleanupThirdPartyTests.java index b5987bf6338bb..b2b6eba4a7c03 100644 --- a/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureStorageCleanupThirdPartyTests.java +++ b/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureStorageCleanupThirdPartyTests.java @@ -59,31 +59,6 @@ public class AzureStorageCleanupThirdPartyTests extends AbstractThirdPartyReposi AzureHttpFixture.sharedKeyForAccountPredicate(AZURE_ACCOUNT) ); - @Override - public void testCreateSnapshot() { - super.testCreateSnapshot(); - } - - @Override - public void testIndexLatest() throws Exception { - super.testIndexLatest(); - } - - @Override - public void testListChildren() { - super.testListChildren(); - } - - @Override - public void testCleanup() throws Exception { - super.testCleanup(); - } - - @Override - public void testReadFromPositionWithLength() { - super.testReadFromPositionWithLength(); - } - @Override protected Collection> getPlugins() { return pluginList(AzureRepositoryPlugin.class); From 90241ac2d34b12631ae427c4b4725bfe63a37d9b Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Thu, 10 Oct 2024 18:18:20 -0400 Subject: [PATCH 10/25] [ML] Stream Azure Completion (#114464) (#114534) Includes both Azure AI Studio and Azure Open AI. Both streaming responses are processed using Open AI's SSE format. 
--- docs/changelog/114464.yaml | 5 ++ .../AzureOpenAiResponseHandler.java | 5 +- ...eAiStudioChatCompletionRequestManager.java | 10 ++- ...AzureAiStudioEmbeddingsRequestManager.java | 3 +- .../AzureOpenAiCompletionRequestManager.java | 9 ++- .../AzureOpenAiEmbeddingsRequestManager.java | 2 +- .../MistralEmbeddingsRequestManager.java | 3 +- .../AzureAiStudioChatCompletionRequest.java | 12 ++- ...reAiStudioChatCompletionRequestEntity.java | 8 +- .../AzureAiStudioRequestFields.java | 1 + .../AzureOpenAiCompletionRequest.java | 12 ++- .../AzureOpenAiCompletionRequestEntity.java | 8 +- ...eMistralOpenAiExternalResponseHandler.java | 27 ++++++- .../azureaistudio/AzureAiStudioService.java | 6 ++ .../azureopenai/AzureOpenAiService.java | 6 ++ ...tudioChatCompletionRequestEntityTests.java | 65 ++++++++++++--- ...ureAiStudioChatCompletionRequestTests.java | 2 +- ...ureOpenAiCompletionRequestEntityTests.java | 4 +- .../AzureOpenAiCompletionRequestTests.java | 2 +- ...AndOpenAiExternalResponseHandlerTests.java | 3 +- ...udioChatCompletionResponseEntityTests.java | 4 +- .../AzureAiStudioServiceTests.java | 78 ++++++++++++++++++ .../azureopenai/AzureOpenAiServiceTests.java | 81 +++++++++++++++++++ 23 files changed, 318 insertions(+), 38 deletions(-) create mode 100644 docs/changelog/114464.yaml diff --git a/docs/changelog/114464.yaml b/docs/changelog/114464.yaml new file mode 100644 index 0000000000000..5f5ee816aa28d --- /dev/null +++ b/docs/changelog/114464.yaml @@ -0,0 +1,5 @@ +pr: 114464 +summary: Stream Azure Completion +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java index 4b2168a42e3ac..f6f907569ffda 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java @@ -31,8 +31,8 @@ public class AzureOpenAiResponseHandler extends OpenAiResponseHandler { // The remaining number of tokens that are permitted before exhausting the rate limit. static final String REMAINING_TOKENS = "x-ratelimit-remaining-tokens"; - public AzureOpenAiResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, false); + public AzureOpenAiResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) { + super(requestType, parseFunction, canHandleStreamingResponses); } @Override @@ -48,5 +48,4 @@ static String buildRateLimitErrorMessage(HttpResult result) { return RATE_LIMIT + ". " + usageMessage; } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index c5e5a5251f7db..21cec68b14a49 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.inference.external.response.azureaistudio.AzureAiStudioChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel; -import java.util.List; import java.util.function.Supplier; public class AzureAiStudioChatCompletionRequestManager extends AzureAiStudioRequestManager { @@ 
-42,8 +41,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); - AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, docsInput); + var docsOnly = DocumentsOnlyInput.of(inferenceInputs); + var docsInput = docsOnly.getInputs(); + var stream = docsOnly.stream(); + AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, docsInput, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } @@ -52,7 +53,8 @@ private static ResponseHandler createCompletionHandler() { return new AzureMistralOpenAiExternalResponseHandler( "azure ai studio completion", new AzureAiStudioChatCompletionResponseEntity(), - ErrorMessageResponseEntity::fromResponse + ErrorMessageResponseEntity::fromResponse, + true ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index c610a7f31f7ba..5f4984fabab69 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -56,7 +56,8 @@ private static ResponseHandler createEmbeddingsHandler() { return new AzureMistralOpenAiExternalResponseHandler( "azure ai studio text embedding", new AzureAiStudioEmbeddingsResponseEntity(), - ErrorMessageResponseEntity::fromResponse + ErrorMessageResponseEntity::fromResponse, + false ); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java index 8c9b848f78e3c..d036559ec3dcb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.response.azureopenai.AzureOpenAiCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -32,7 +31,7 @@ public class AzureOpenAiCompletionRequestManager extends AzureOpenAiRequestManag private final AzureOpenAiCompletionModel model; private static ResponseHandler createCompletionHandler() { - return new AzureOpenAiResponseHandler("azure openai completion", AzureOpenAiCompletionResponseEntity::fromResponse); + return new AzureOpenAiResponseHandler("azure openai completion", AzureOpenAiCompletionResponseEntity::fromResponse, true); } public AzureOpenAiCompletionRequestManager(AzureOpenAiCompletionModel model, ThreadPool threadPool) { @@ -47,8 +46,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); - AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model); + var docsOnly = DocumentsOnlyInput.of(inferenceInputs); + var docsInput = docsOnly.getInputs(); + var stream = docsOnly.stream(); + AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model, stream); execute(new 
ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java index 8d4162858b36f..fc39bd4af96d4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java @@ -33,7 +33,7 @@ public class AzureOpenAiEmbeddingsRequestManager extends AzureOpenAiRequestManag private static final ResponseHandler HANDLER = createEmbeddingsHandler(); private static ResponseHandler createEmbeddingsHandler() { - return new AzureOpenAiResponseHandler("azure openai text embedding", OpenAiEmbeddingsResponseEntity::fromResponse); + return new AzureOpenAiResponseHandler("azure openai text embedding", OpenAiEmbeddingsResponseEntity::fromResponse, false); } public static AzureOpenAiEmbeddingsRequestManager of(AzureOpenAiEmbeddingsModel model, Truncator truncator, ThreadPool threadPool) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java index d550749cc2348..d18c3227ed444 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java @@ -38,7 +38,8 @@ private static ResponseHandler createEmbeddingsHandler() { return 
new AzureMistralOpenAiExternalResponseHandler( "mistral text embedding", new MistralEmbeddingsResponseEntity(), - ErrorMessageResponseEntity::fromResponse + ErrorMessageResponseEntity::fromResponse, + false ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequest.java index b913f79e39202..377afee12f394 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequest.java @@ -23,11 +23,13 @@ public class AzureAiStudioChatCompletionRequest extends AzureAiStudioRequest { private final List input; private final AzureAiStudioChatCompletionModel completionModel; + private final boolean stream; - public AzureAiStudioChatCompletionRequest(AzureAiStudioChatCompletionModel model, List input) { + public AzureAiStudioChatCompletionRequest(AzureAiStudioChatCompletionModel model, List input, boolean stream) { super(model); this.input = Objects.requireNonNull(input); this.completionModel = Objects.requireNonNull(model); + this.stream = stream; } public boolean isRealtimeEndpoint() { @@ -59,6 +61,11 @@ public boolean[] getTruncationInfo() { return null; } + @Override + public boolean isStreaming() { + return stream; + } + private AzureAiStudioChatCompletionRequestEntity createRequestEntity() { var taskSettings = completionModel.getTaskSettings(); var serviceSettings = completionModel.getServiceSettings(); @@ -68,7 +75,8 @@ private AzureAiStudioChatCompletionRequestEntity createRequestEntity() { taskSettings.temperature(), taskSettings.topP(), taskSettings.doSample(), - taskSettings.maxNewTokens() + 
taskSettings.maxNewTokens(), + isStreaming() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntity.java index a4f685530f942..ce9c8c662bfb4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntity.java @@ -22,6 +22,7 @@ import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.MESSAGE_CONTENT; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.PARAMETERS_OBJECT; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.ROLE; +import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.STREAM; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.USER_ROLE; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.DO_SAMPLE_FIELD; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.MAX_NEW_TOKENS_FIELD; @@ -34,7 +35,8 @@ public record AzureAiStudioChatCompletionRequestEntity( @Nullable Double temperature, @Nullable Double topP, @Nullable Boolean doSample, - @Nullable Integer maxNewTokens + @Nullable Integer maxNewTokens, + boolean stream ) implements ToXContentObject { public AzureAiStudioChatCompletionRequestEntity { @@ -52,6 +54,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws 
createRealtimeRequest(builder, params); } + if (stream) { + builder.field(STREAM, true); + } + builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioRequestFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioRequestFields.java index ad10410792867..56c26775bc01d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioRequestFields.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioRequestFields.java @@ -16,6 +16,7 @@ public final class AzureAiStudioRequestFields { public static final String MESSAGE_CONTENT = "content"; public static final String ROLE = "role"; public static final String USER_ROLE = "user"; + public static final String STREAM = "stream"; private AzureAiStudioRequestFields() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequest.java index 8854dc7950365..41f05b500efa8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequest.java @@ -27,16 +27,19 @@ public class AzureOpenAiCompletionRequest implements AzureOpenAiRequest { private final AzureOpenAiCompletionModel model; - public AzureOpenAiCompletionRequest(List input, AzureOpenAiCompletionModel model) { + private final boolean stream; + + public AzureOpenAiCompletionRequest(List input, AzureOpenAiCompletionModel model, 
boolean stream) { this.input = input; this.model = Objects.requireNonNull(model); this.uri = model.getUri(); + this.stream = stream; } @Override public HttpRequest createHttpRequest() { var httpPost = new HttpPost(uri); - var requestEntity = Strings.toString(new AzureOpenAiCompletionRequestEntity(input, model.getTaskSettings().user())); + var requestEntity = Strings.toString(new AzureOpenAiCompletionRequestEntity(input, model.getTaskSettings().user(), isStreaming())); ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); httpPost.setEntity(byteEntity); @@ -56,6 +59,11 @@ public String getInferenceEntityId() { return model.getInferenceEntityId(); } + @Override + public boolean isStreaming() { + return stream; + } + @Override public Request truncate() { // No truncation for Azure OpenAI completion diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequestEntity.java index 86614ef32855f..725e51c06c494 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequestEntity.java @@ -16,7 +16,7 @@ import java.util.List; import java.util.Objects; -public record AzureOpenAiCompletionRequestEntity(List messages, @Nullable String user) implements ToXContentObject { +public record AzureOpenAiCompletionRequestEntity(List messages, @Nullable String user, boolean stream) implements ToXContentObject { private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; @@ -28,6 +28,8 @@ public record AzureOpenAiCompletionRequestEntity(List messages, @Nullabl private static final String 
USER_FIELD = "user"; + private static final String STREAM_FIELD = "stream"; + public AzureOpenAiCompletionRequestEntity { Objects.requireNonNull(messages); } @@ -58,6 +60,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(USER_FIELD, user); } + if (stream) { + builder.field(STREAM_FIELD, true); + } + builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java index e4e96ca644c7f..01b463d4bc8bd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java @@ -9,15 +9,21 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.ContentTooLargeException; import org.elasticsearch.xpack.inference.external.http.retry.ErrorMessage; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.openai.OpenAiStreamingProcessor; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import 
org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import java.util.concurrent.Flow; import java.util.function.Function; import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; @@ -43,12 +49,16 @@ public class AzureMistralOpenAiExternalResponseHandler extends BaseResponseHandl static final String CONTENT_TOO_LARGE_MESSAGE = "Please reduce your prompt; or completion length."; static final String SERVER_BUSY_ERROR = "Received a server busy error status code"; + private final boolean canHandleStreamingResponses; + public AzureMistralOpenAiExternalResponseHandler( String requestType, ResponseParser parseFunction, - Function errorParseFunction + Function errorParseFunction, + boolean canHandleStreamingResponses ) { super(requestType, parseFunction, errorParseFunction); + this.canHandleStreamingResponses = canHandleStreamingResponses; } @Override @@ -58,6 +68,21 @@ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, R checkForEmptyBody(throttlerManager, logger, request, result); } + @Override + public boolean canHandleStreamingResponses() { + return canHandleStreamingResponses; + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); + var openAiProcessor = new OpenAiStreamingProcessor(); + + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(openAiProcessor); + return new StreamingChatCompletionResults(openAiProcessor); + } + public void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { int statusCode = result.response().getStatusLine().getStatusCode(); if (statusCode >= 200 && statusCode < 300) { diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index c1ca50d41268e..ba36febc3c162 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -43,6 +43,7 @@ import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -216,6 +217,11 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO; } + @Override + public Set supportedStreamingTasks() { + return COMPLETION_ONLY; + } + private static AzureAiStudioModel createModel( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 07708ee072099..9f7bcfc256117 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -42,6 +42,7 @@ import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -325,4 +326,9 @@ private 
AzureOpenAiEmbeddingsModel updateModelWithEmbeddingDetails(AzureOpenAiEm public TransportVersion getMinimalSupportedVersion() { return TransportVersions.ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED; } + + @Override + public Set supportedStreamingTasks() { + return COMPLETION_ONLY; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntityTests.java index 3b086f4d3b900..59bcb82171349 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntityTests.java @@ -23,35 +23,75 @@ public class AzureAiStudioChatCompletionRequestEntityTests extends ESTestCase { public void testToXContent_WhenTokenEndpoint_NoParameters() throws IOException { - var entity = new AzureAiStudioChatCompletionRequestEntity(List.of("abc"), AzureAiStudioEndpointType.TOKEN, null, null, null, null); + var entity = new AzureAiStudioChatCompletionRequestEntity( + List.of("abc"), + AzureAiStudioEndpointType.TOKEN, + null, + null, + null, + null, + false + ); var request = getXContentAsString(entity); var expectedRequest = getExpectedTokenEndpointRequest(List.of("abc"), null, null, null, null); assertThat(request, is(expectedRequest)); } public void testToXContent_WhenTokenEndpoint_WithTemperatureParam() throws IOException { - var entity = new AzureAiStudioChatCompletionRequestEntity(List.of("abc"), AzureAiStudioEndpointType.TOKEN, 1.0, null, null, null); + var entity = new AzureAiStudioChatCompletionRequestEntity( + List.of("abc"), + AzureAiStudioEndpointType.TOKEN, + 1.0, + null, + null, + null, + false + 
); var request = getXContentAsString(entity); var expectedRequest = getExpectedTokenEndpointRequest(List.of("abc"), 1.0, null, null, null); assertThat(request, is(expectedRequest)); } public void testToXContent_WhenTokenEndpoint_WithTopPParam() throws IOException { - var entity = new AzureAiStudioChatCompletionRequestEntity(List.of("abc"), AzureAiStudioEndpointType.TOKEN, null, 2.0, null, null); + var entity = new AzureAiStudioChatCompletionRequestEntity( + List.of("abc"), + AzureAiStudioEndpointType.TOKEN, + null, + 2.0, + null, + null, + false + ); var request = getXContentAsString(entity); var expectedRequest = getExpectedTokenEndpointRequest(List.of("abc"), null, 2.0, null, null); assertThat(request, is(expectedRequest)); } public void testToXContent_WhenTokenEndpoint_WithDoSampleParam() throws IOException { - var entity = new AzureAiStudioChatCompletionRequestEntity(List.of("abc"), AzureAiStudioEndpointType.TOKEN, null, null, true, null); + var entity = new AzureAiStudioChatCompletionRequestEntity( + List.of("abc"), + AzureAiStudioEndpointType.TOKEN, + null, + null, + true, + null, + false + ); var request = getXContentAsString(entity); var expectedRequest = getExpectedTokenEndpointRequest(List.of("abc"), null, null, true, null); assertThat(request, is(expectedRequest)); } public void testToXContent_WhenTokenEndpoint_WithMaxNewTokensParam() throws IOException { - var entity = new AzureAiStudioChatCompletionRequestEntity(List.of("abc"), AzureAiStudioEndpointType.TOKEN, null, null, null, 512); + var entity = new AzureAiStudioChatCompletionRequestEntity( + List.of("abc"), + AzureAiStudioEndpointType.TOKEN, + null, + null, + null, + 512, + false + ); var request = getXContentAsString(entity); var expectedRequest = getExpectedTokenEndpointRequest(List.of("abc"), null, null, null, 512); assertThat(request, is(expectedRequest)); @@ -64,7 +104,8 @@ public void testToXContent_WhenRealtimeEndpoint_NoParameters() throws IOExceptio null, null, null, - null + null, + false 
); var request = getXContentAsString(entity); var expectedRequest = getExpectedRealtimeEndpointRequest(List.of("abc"), null, null, null, null); @@ -78,7 +119,8 @@ public void testToXContent_WhenRealtimeEndpoint_WithTemperatureParam() throws IO 1.0, null, null, - null + null, + false ); var request = getXContentAsString(entity); var expectedRequest = getExpectedRealtimeEndpointRequest(List.of("abc"), 1.0, null, null, null); @@ -92,7 +134,8 @@ public void testToXContent_WhenRealtimeEndpoint_WithTopPParam() throws IOExcepti null, 2.0, null, - null + null, + false ); var request = getXContentAsString(entity); var expectedRequest = getExpectedRealtimeEndpointRequest(List.of("abc"), null, 2.0, null, null); @@ -106,7 +149,8 @@ public void testToXContent_WhenRealtimeEndpoint_WithDoSampleParam() throws IOExc null, null, true, - null + null, + false ); var request = getXContentAsString(entity); var expectedRequest = getExpectedRealtimeEndpointRequest(List.of("abc"), null, null, true, null); @@ -120,7 +164,8 @@ public void testToXContent_WhenRealtimeEndpoint_WithMaxNewTokensParam() throws I null, null, null, - 512 + 512, + false ); var request = getXContentAsString(entity); var expectedRequest = getExpectedRealtimeEndpointRequest(List.of("abc"), null, null, null, 512); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestTests.java index f3ddf7f9299d9..71d4c0cd42351 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestTests.java @@ -460,6 +460,6 @@ public static 
AzureAiStudioChatCompletionRequest createRequest( maxNewTokens, null ); - return new AzureAiStudioChatCompletionRequest(model, List.of(input)); + return new AzureAiStudioChatCompletionRequest(model, List.of(input), false); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestEntityTests.java index 7647a4983f4be..6942f62756c50 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestEntityTests.java @@ -22,7 +22,7 @@ public class AzureOpenAiCompletionRequestEntityTests extends ESTestCase { public void testXContent_WritesSingleMessage_DoesNotWriteUserWhenItIsNull() throws IOException { - var entity = new AzureOpenAiCompletionRequestEntity(List.of("input"), null); + var entity = new AzureOpenAiCompletionRequestEntity(List.of("input"), null, false); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -33,7 +33,7 @@ public void testXContent_WritesSingleMessage_DoesNotWriteUserWhenItIsNull() thro } public void testXContent_WritesSingleMessage_WriteUserWhenItIsNull() throws IOException { - var entity = new AzureOpenAiCompletionRequestEntity(List.of("input"), "user"); + var entity = new AzureOpenAiCompletionRequestEntity(List.of("input"), "user", false); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestTests.java index 048d4ea16d56f..d2761bf007927 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestTests.java @@ -94,7 +94,7 @@ protected AzureOpenAiCompletionRequest createRequest( "id" ); - return new AzureOpenAiCompletionRequest(List.of(input), completionModel); + return new AzureOpenAiCompletionRequest(List.of(input), completionModel, false); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java index 53bb38943d35b..d816d7a7e6274 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java @@ -45,7 +45,8 @@ public void testCheckForFailureStatusCode() { var handler = new AzureMistralOpenAiExternalResponseHandler( "", (request, result) -> null, - ErrorMessageResponseEntity::fromResponse + ErrorMessageResponseEntity::fromResponse, + false ); // 200 ok diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntityTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntityTests.java index 7d5aafa181b19..bb518b1860ee8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntityTests.java @@ -35,7 +35,7 @@ public void testCompletionResponse_FromTokenEndpoint() throws IOException { AzureAiStudioEndpointType.TOKEN, "apikey" ); - var request = new AzureAiStudioChatCompletionRequest(model, List.of("test input")); + var request = new AzureAiStudioChatCompletionRequest(model, List.of("test input"), false); var result = (ChatCompletionResults) entity.apply( request, new HttpResult(mock(HttpResponse.class), testTokenResponseJson.getBytes(StandardCharsets.UTF_8)) @@ -54,7 +54,7 @@ public void testCompletionResponse_FromRealtimeEndpoint() throws IOException { AzureAiStudioEndpointType.REALTIME, "apikey" ); - var request = new AzureAiStudioChatCompletionRequest(model, List.of("test input")); + var request = new AzureAiStudioChatCompletionRequest(model, List.of("test input"), false); var result = (ChatCompletionResults) entity.apply( request, new HttpResult(mock(HttpResponse.class), testRealtimeResponseJson.getBytes(StandardCharsets.UTF_8)) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 683f32710bcb3..37a8c25461045 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -39,6 +39,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettingsTests; @@ -55,6 +56,7 @@ import org.junit.Before; import java.io.IOException; +import java.net.URISyntaxException; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1305,6 +1307,82 @@ public void testInfer_UnauthorisedResponse() throws IOException { } } + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":null\ + }\ + ]\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var result = streamChatCompletion(); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello, world"}]}"""); + } + + private InferenceServiceResults streamChatCompletion() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new AzureAiStudioService(senderFactory, 
createWithEmptySettings(threadPool))) { + var model = AzureAiStudioChatCompletionModelTests.createModel( + "id", + getUrl(webServer), + AzureAiStudioProvider.OPENAI, + AzureAiStudioEndpointType.TOKEN, + "apikey" + ); + var listener = new PlainActionFuture(); + service.infer( + model, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return listener.actionGet(TIMEOUT); + } + } + + public void testInfer_StreamRequest_ErrorResponse() throws Exception { + String responseJson = """ + { + "error": { + "message": "You didn't provide an API key...", + "type": "invalid_request_error", + "param": null, + "code": null + } + }"""; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var result = streamChatCompletion(); + + InferenceEventsAssertion.assertThat(result) + .hasFinishedStream() + .hasNoEvents() + .hasErrorWithStatusCode(401) + .hasErrorContaining("You didn't provide an API key..."); + } + // ---------------------------------------------------------------- private AzureAiStudioService createService() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 17a00a9eb829c..0d1c6befa6219 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -38,6 +38,8 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import 
org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests; import org.hamcrest.CoreMatchers; @@ -1422,6 +1424,85 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti } } + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":null\ + }\ + ]\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var result = streamChatCompletion(); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello, world"}]}"""); + } + + private InferenceServiceResults streamChatCompletion() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + var model = AzureOpenAiCompletionModelTests.createCompletionModel( + "resource", + "deployment", + "apiversion", + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + var listener = new PlainActionFuture(); + service.infer( + model, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return listener.actionGet(TIMEOUT); + } + } + + public void testInfer_StreamRequest_ErrorResponse() throws Exception 
{ + String responseJson = """ + { + "error": { + "message": "You didn't provide an API key...", + "type": "invalid_request_error", + "param": null, + "code": null + } + }"""; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var result = streamChatCompletion(); + + InferenceEventsAssertion.assertThat(result) + .hasFinishedStream() + .hasNoEvents() + .hasErrorWithStatusCode(401) + .hasErrorContaining("You didn't provide an API key..."); + } + private AzureOpenAiService createAzureOpenAiService() { return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); } From 695fce42e3ef889056d1a598fc187fb8c715619f Mon Sep 17 00:00:00 2001 From: Keith Massey Date: Thu, 10 Oct 2024 17:31:07 -0500 Subject: [PATCH 11/25] Mute org.elasticsearch.search.retriever.RankDocsRetrieverBuilderTests testRewrite #114467 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 19b684ffbff20..e63499aef6c18 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -298,6 +298,9 @@ tests: - class: org.elasticsearch.xpack.inference.InferenceRestIT method: test {p0=inference/40_semantic_text_query/Query a field that uses the default ELSER 2 endpoint} issue: https://github.com/elastic/elasticsearch/issues/114376 +- class: org.elasticsearch.search.retriever.RankDocsRetrieverBuilderTests + method: testRewrite + issue: https://github.com/elastic/elasticsearch/issues/114467 # Examples: # From 3adce49fdb8c513facbacd30dfb98e9d886bcfdd Mon Sep 17 00:00:00 2001 From: Keith Massey Date: Thu, 10 Oct 2024 18:41:24 -0500 Subject: [PATCH 12/25] Adding support for registered country fields for maxmind geoip databases (#114521) (#114543) Co-authored-by: Joe Gallo --- .../elasticsearch/ingest/geoip/Database.java | 24 +++++++-- .../ingest/geoip/MaxmindIpDataLookups.java | 51 +++++++++++++++++++ .../geoip/GeoIpProcessorFactoryTests.java | 6 ++- 
.../ingest/geoip/GeoIpProcessorTests.java | 20 ++++++-- .../geoip/IpinfoIpDataLookupsTests.java | 6 ++- .../ingest/geoip/MaxMindSupportTests.java | 23 +++++---- 6 files changed, 108 insertions(+), 22 deletions(-) diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java index 61ec1e74b40a4..fd88e2e71f0c9 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java @@ -43,7 +43,10 @@ enum Database { Property.TIMEZONE, Property.LOCATION, Property.POSTAL_CODE, - Property.ACCURACY_RADIUS + Property.ACCURACY_RADIUS, + Property.REGISTERED_COUNTRY_IN_EUROPEAN_UNION, + Property.REGISTERED_COUNTRY_ISO_CODE, + Property.REGISTERED_COUNTRY_NAME ), Set.of( Property.COUNTRY_ISO_CODE, @@ -62,7 +65,10 @@ enum Database { Property.CONTINENT_NAME, Property.COUNTRY_NAME, Property.COUNTRY_IN_EUROPEAN_UNION, - Property.COUNTRY_ISO_CODE + Property.COUNTRY_ISO_CODE, + Property.REGISTERED_COUNTRY_IN_EUROPEAN_UNION, + Property.REGISTERED_COUNTRY_ISO_CODE, + Property.REGISTERED_COUNTRY_NAME ), Set.of(Property.CONTINENT_NAME, Property.COUNTRY_NAME, Property.COUNTRY_ISO_CODE) ), @@ -124,7 +130,10 @@ enum Database { Property.CONNECTION_TYPE, Property.POSTAL_CODE, Property.POSTAL_CONFIDENCE, - Property.ACCURACY_RADIUS + Property.ACCURACY_RADIUS, + Property.REGISTERED_COUNTRY_IN_EUROPEAN_UNION, + Property.REGISTERED_COUNTRY_ISO_CODE, + Property.REGISTERED_COUNTRY_NAME ), Set.of( Property.COUNTRY_ISO_CODE, @@ -182,6 +191,10 @@ enum Database { ), Set.of(Property.COUNTRY_ISO_CODE, Property.REGION_NAME, Property.CITY_NAME, Property.LOCATION) ), + CountryV2( + Set.of(Property.IP, Property.CONTINENT_CODE, Property.CONTINENT_NAME, Property.COUNTRY_NAME, Property.COUNTRY_ISO_CODE), + Set.of(Property.CONTINENT_NAME, Property.COUNTRY_NAME, Property.COUNTRY_ISO_CODE) + ), 
PrivacyDetection( Set.of(Property.IP, Property.HOSTING, Property.PROXY, Property.RELAY, Property.TOR, Property.VPN, Property.SERVICE), Set.of(Property.HOSTING, Property.PROXY, Property.RELAY, Property.TOR, Property.VPN, Property.SERVICE) @@ -272,7 +285,10 @@ enum Property { PROXY, RELAY, VPN, - SERVICE; + SERVICE, + REGISTERED_COUNTRY_IN_EUROPEAN_UNION, + REGISTERED_COUNTRY_ISO_CODE, + REGISTERED_COUNTRY_NAME; /** * Parses a string representation of a property into an actual Property instance. Not all properties that exist are diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java index 5fe2e980d2ab0..4297413073e52 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java @@ -142,6 +142,7 @@ static class City extends AbstractBase { @Override protected Map transform(final CityResponse response) { com.maxmind.geoip2.record.Country country = response.getCountry(); + com.maxmind.geoip2.record.Country registeredCountry = response.getRegisteredCountry(); com.maxmind.geoip2.record.City city = response.getCity(); Location location = response.getLocation(); Continent continent = response.getContinent(); @@ -231,6 +232,22 @@ protected Map transform(final CityResponse response) { data.put("postal_code", postal.getCode()); } } + case REGISTERED_COUNTRY_IN_EUROPEAN_UNION -> { + if (registeredCountry.getIsoCode() != null) { + // isInEuropeanUnion is a boolean so it can't be null. 
But it really only makes sense if we have a country + data.put("registered_country_in_european_union", registeredCountry.isInEuropeanUnion()); + } + } + case REGISTERED_COUNTRY_ISO_CODE -> { + if (registeredCountry.getIsoCode() != null) { + data.put("registered_country_iso_code", registeredCountry.getIsoCode()); + } + } + case REGISTERED_COUNTRY_NAME -> { + if (registeredCountry.getName() != null) { + data.put("registered_country_name", registeredCountry.getName()); + } + } } } return data; @@ -273,6 +290,7 @@ static class Country extends AbstractBase { @Override protected Map transform(final CountryResponse response) { com.maxmind.geoip2.record.Country country = response.getCountry(); + com.maxmind.geoip2.record.Country registeredCountry = response.getRegisteredCountry(); Continent continent = response.getContinent(); Map data = new HashMap<>(); @@ -309,6 +327,22 @@ protected Map transform(final CountryResponse response) { data.put("continent_name", continentName); } } + case REGISTERED_COUNTRY_IN_EUROPEAN_UNION -> { + if (registeredCountry.getIsoCode() != null) { + // isInEuropeanUnion is a boolean so it can't be null. 
But it really only makes sense if we have a country + data.put("registered_country_in_european_union", registeredCountry.isInEuropeanUnion()); + } + } + case REGISTERED_COUNTRY_ISO_CODE -> { + if (registeredCountry.getIsoCode() != null) { + data.put("registered_country_iso_code", registeredCountry.getIsoCode()); + } + } + case REGISTERED_COUNTRY_NAME -> { + if (registeredCountry.getName() != null) { + data.put("registered_country_name", registeredCountry.getName()); + } + } } } return data; @@ -351,6 +385,7 @@ static class Enterprise extends AbstractBase { @Override protected Map transform(final EnterpriseResponse response) { com.maxmind.geoip2.record.Country country = response.getCountry(); + com.maxmind.geoip2.record.Country registeredCountry = response.getRegisteredCountry(); com.maxmind.geoip2.record.City city = response.getCity(); Location location = response.getLocation(); Continent continent = response.getContinent(); @@ -548,6 +583,22 @@ protected Map transform(final EnterpriseResponse response) { data.put("connection_type", connectionType.toString()); } } + case REGISTERED_COUNTRY_IN_EUROPEAN_UNION -> { + if (registeredCountry.getIsoCode() != null) { + // isInEuropeanUnion is a boolean so it can't be null. 
But it really only makes sense if we have a country + data.put("registered_country_in_european_union", registeredCountry.isInEuropeanUnion()); + } + } + case REGISTERED_COUNTRY_ISO_CODE -> { + if (registeredCountry.getIsoCode() != null) { + data.put("registered_country_iso_code", registeredCountry.getIsoCode()); + } + } + case REGISTERED_COUNTRY_NAME -> { + if (registeredCountry.getName() != null) { + data.put("registered_country_name", registeredCountry.getName()); + } + } } } return data; diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java index bf268e17edccb..5ac0c76054d33 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java @@ -196,7 +196,8 @@ public void testBuildWithCountryDbAndAsnFields() { equalTo( "[properties] illegal property value [" + asnProperty - + "]. valid values are [IP, COUNTRY_IN_EUROPEAN_UNION, COUNTRY_ISO_CODE, COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME]" + + "]. valid values are [IP, COUNTRY_IN_EUROPEAN_UNION, COUNTRY_ISO_CODE, COUNTRY_NAME, CONTINENT_CODE, " + + "CONTINENT_NAME, REGISTERED_COUNTRY_IN_EUROPEAN_UNION, REGISTERED_COUNTRY_ISO_CODE, REGISTERED_COUNTRY_NAME]" ) ); } @@ -276,7 +277,8 @@ public void testBuildIllegalFieldOption() { equalTo( "[properties] illegal property value [invalid]. 
valid values are [IP, COUNTRY_IN_EUROPEAN_UNION, COUNTRY_ISO_CODE, " + "COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME, REGION_ISO_CODE, REGION_NAME, CITY_NAME, TIMEZONE, " - + "LOCATION, POSTAL_CODE, ACCURACY_RADIUS]" + + "LOCATION, POSTAL_CODE, ACCURACY_RADIUS, REGISTERED_COUNTRY_IN_EUROPEAN_UNION, REGISTERED_COUNTRY_ISO_CODE, " + + "REGISTERED_COUNTRY_NAME]" ) ); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java index fbceac3b9cce6..50b59c26749fc 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java @@ -108,7 +108,7 @@ public void testCity() throws Exception { @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); assertThat(geoData, notNullValue()); - assertThat(geoData.size(), equalTo(9)); + assertThat(geoData.size(), equalTo(12)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("country_in_european_union"), equalTo(false)); assertThat(geoData.get("country_iso_code"), equalTo("US")); @@ -117,6 +117,9 @@ public void testCity() throws Exception { assertThat(geoData.get("continent_name"), equalTo("North America")); assertThat(geoData.get("timezone"), equalTo("America/Chicago")); assertThat(geoData.get("location"), equalTo(Map.of("lat", 37.751d, "lon", -97.822d))); + assertThat(geoData.get("registered_country_in_european_union"), equalTo(false)); + assertThat(geoData.get("registered_country_iso_code"), equalTo("US")); + assertThat(geoData.get("registered_country_name"), equalTo("United States")); } public void testNullValueWithIgnoreMissing() throws Exception { @@ -230,7 +233,7 @@ public void testCity_withIpV6() throws Exception { @SuppressWarnings("unchecked") Map geoData = (Map) 
ingestDocument.getSourceAndMetadata().get("target_field"); assertThat(geoData, notNullValue()); - assertThat(geoData.size(), equalTo(13)); + assertThat(geoData.size(), equalTo(16)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("country_in_european_union"), equalTo(false)); assertThat(geoData.get("country_iso_code"), equalTo("US")); @@ -244,6 +247,9 @@ public void testCity_withIpV6() throws Exception { assertThat(geoData.get("location"), equalTo(Map.of("lat", 25.4573d, "lon", -80.4572d))); assertThat(geoData.get("accuracy_radius"), equalTo(50)); assertThat(geoData.get("postal_code"), equalTo("33035")); + assertThat(geoData.get("registered_country_in_european_union"), equalTo(false)); + assertThat(geoData.get("registered_country_iso_code"), equalTo("US")); + assertThat(geoData.get("registered_country_name"), equalTo("United States")); } public void testCityWithMissingLocation() throws Exception { @@ -300,13 +306,16 @@ public void testCountry() throws Exception { @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); assertThat(geoData, notNullValue()); - assertThat(geoData.size(), equalTo(6)); + assertThat(geoData.size(), equalTo(9)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("country_in_european_union"), equalTo(true)); assertThat(geoData.get("country_iso_code"), equalTo("NL")); assertThat(geoData.get("country_name"), equalTo("Netherlands")); assertThat(geoData.get("continent_code"), equalTo("EU")); assertThat(geoData.get("continent_name"), equalTo("Europe")); + assertThat(geoData.get("registered_country_in_european_union"), equalTo(true)); + assertThat(geoData.get("registered_country_iso_code"), equalTo("NL")); + assertThat(geoData.get("registered_country_name"), equalTo("Netherlands")); } public void testCountryWithMissingLocation() throws Exception { @@ -490,7 +499,7 @@ public void testEnterprise() throws Exception { @SuppressWarnings("unchecked") Map geoData = 
(Map) ingestDocument.getSourceAndMetadata().get("target_field"); assertThat(geoData, notNullValue()); - assertThat(geoData.size(), equalTo(30)); + assertThat(geoData.size(), equalTo(33)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("country_confidence"), equalTo(99)); assertThat(geoData.get("country_in_european_union"), equalTo(false)); @@ -521,6 +530,9 @@ public void testEnterprise() throws Exception { assertThat(geoData.get("isp_organization_name"), equalTo("Fairpoint Communications")); assertThat(geoData.get("user_type"), equalTo("residential")); assertThat(geoData.get("connection_type"), equalTo("Cable/DSL")); + assertThat(geoData.get("registered_country_in_european_union"), equalTo(false)); + assertThat(geoData.get("registered_country_iso_code"), equalTo("US")); + assertThat(geoData.get("registered_country_name"), equalTo("United States")); } public void testIsp() throws Exception { diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java index 039c826337caa..4ecf3056db738 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java @@ -79,6 +79,10 @@ public void testDatabasePropertyInvariants() { // the second City variant database is like a version of the ordinary City database but lacking many fields assertThat(Sets.difference(Database.CityV2.properties(), Database.City.properties()), is(empty())); assertThat(Sets.difference(Database.CityV2.defaultProperties(), Database.City.defaultProperties()), is(empty())); + + // the second Country variant database is like a version of the ordinary Country database but lacking come fields + assertThat(Sets.difference(Database.CountryV2.properties(), Database.CountryV2.properties()), is(empty())); + 
assertThat(Database.CountryV2.defaultProperties(), equalTo(Database.Country.defaultProperties())); } public void testParseAsn() { @@ -219,7 +223,7 @@ public void testCountry() throws IOException { // this is the 'free' Country database (sample) try (DatabaseReaderLazyLoader loader = configDatabases.getDatabase("ip_country_sample.mmdb")) { - IpDataLookup lookup = new IpinfoIpDataLookups.Country(Database.Country.properties()); + IpDataLookup lookup = new IpinfoIpDataLookups.Country(Database.CountryV2.properties()); Map data = lookup.getData(loader, "4.221.143.168"); assertThat( data, diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java index 068867deeea3c..79a4190af284a 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java @@ -87,7 +87,10 @@ public class MaxMindSupportTests extends ESTestCase { "location.timeZone", "mostSpecificSubdivision.isoCode", "mostSpecificSubdivision.name", - "postal.code" + "postal.code", + "registeredCountry.inEuropeanUnion", + "registeredCountry.isoCode", + "registeredCountry.name" ); private static final Set CITY_UNSUPPORTED_FIELDS = Set.of( "city.confidence", @@ -113,9 +116,6 @@ public class MaxMindSupportTests extends ESTestCase { "postal.confidence", "registeredCountry.confidence", "registeredCountry.geoNameId", - "registeredCountry.inEuropeanUnion", - "registeredCountry.isoCode", - "registeredCountry.name", "registeredCountry.names", "representedCountry.confidence", "representedCountry.geoNameId", @@ -162,7 +162,10 @@ public class MaxMindSupportTests extends ESTestCase { "country.inEuropeanUnion", "country.isoCode", "continent.code", - "country.name" + "country.name", + "registeredCountry.inEuropeanUnion", + "registeredCountry.isoCode", + 
"registeredCountry.name" ); private static final Set COUNTRY_UNSUPPORTED_FIELDS = Set.of( "continent.geoNameId", @@ -173,9 +176,6 @@ public class MaxMindSupportTests extends ESTestCase { "maxMind", "registeredCountry.confidence", "registeredCountry.geoNameId", - "registeredCountry.inEuropeanUnion", - "registeredCountry.isoCode", - "registeredCountry.name", "registeredCountry.names", "representedCountry.confidence", "representedCountry.geoNameId", @@ -229,6 +229,9 @@ public class MaxMindSupportTests extends ESTestCase { "mostSpecificSubdivision.name", "postal.code", "postal.confidence", + "registeredCountry.inEuropeanUnion", + "registeredCountry.isoCode", + "registeredCountry.name", "traits.anonymous", "traits.anonymousVpn", "traits.autonomousSystemNumber", @@ -267,9 +270,6 @@ public class MaxMindSupportTests extends ESTestCase { "mostSpecificSubdivision.names", "registeredCountry.confidence", "registeredCountry.geoNameId", - "registeredCountry.inEuropeanUnion", - "registeredCountry.isoCode", - "registeredCountry.name", "registeredCountry.names", "representedCountry.confidence", "representedCountry.geoNameId", @@ -364,6 +364,7 @@ public class MaxMindSupportTests extends ESTestCase { private static final Set KNOWN_UNSUPPORTED_DATABASE_VARIANTS = Set.of( Database.AsnV2, Database.CityV2, + Database.CountryV2, Database.PrivacyDetection ); From a575b1ee440f65bccfbaee6b47868e8034dba33d Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Thu, 10 Oct 2024 21:19:30 -0400 Subject: [PATCH 13/25] [ML] Mute tests using mock web server for streaming (#114542) (#114545) Relates #114385 --- .../inference/services/anthropic/AnthropicServiceTests.java | 2 ++ .../services/azureaistudio/AzureAiStudioServiceTests.java | 2 ++ .../inference/services/azureopenai/AzureOpenAiServiceTests.java | 2 ++ .../xpack/inference/services/cohere/CohereServiceTests.java | 2 ++ .../xpack/inference/services/openai/OpenAiServiceTests.java | 2 ++ 5 files changed, 10 insertions(+) diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 48277112d9306..8adf75b4c0a81 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -532,6 +532,7 @@ public void testInfer_SendsCompletionRequest() throws IOException { } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114385") public void testInfer_StreamRequest() throws Exception { String responseJson = """ data: {"type": "message_start", "message": {"model": "claude, probably"}} @@ -577,6 +578,7 @@ private InferenceServiceResults streamChatCompletion() throws IOException { } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114385") public void testInfer_StreamRequest_ErrorResponse() throws Exception { String responseJson = """ data: {"type": "error", "error": {"type": "request_too_large", "message": "blah"}} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 37a8c25461045..192f97aa01e98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -1307,6 +1307,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114385") public void 
testInfer_StreamRequest() throws Exception { String responseJson = """ data: {\ @@ -1362,6 +1363,7 @@ private InferenceServiceResults streamChatCompletion() throws IOException, URISy } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114385") public void testInfer_StreamRequest_ErrorResponse() throws Exception { String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 0d1c6befa6219..d0c7841125187 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -1424,6 +1424,7 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114385") public void testInfer_StreamRequest() throws Exception { String responseJson = """ data: {\ @@ -1482,6 +1483,7 @@ private InferenceServiceResults streamChatCompletion() throws IOException, URISy } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114385") public void testInfer_StreamRequest_ErrorResponse() throws Exception { String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 8ca49a47b943a..9df9dd841bdb9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -1633,6 +1633,7 @@ public void testDefaultSimilarity() { assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity()); } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114385") public void testInfer_StreamRequest() throws Exception { String responseJson = """ {"event_type":"text-generation", "text":"hello"} @@ -1666,6 +1667,7 @@ private InferenceServiceResults streamChatCompletion() throws IOException { } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114385") public void testInfer_StreamRequest_ErrorResponse() throws Exception { String responseJson = """ { "event_type":"stream-end", "finish_reason":"ERROR", "response":{ "text": "how dare you" } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index cd6846747135a..0a56dd87ec02d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -1007,6 +1007,7 @@ public void testInfer_SendsRequest() throws IOException { } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114385") public void testInfer_StreamRequest() throws Exception { String responseJson = """ data: {\ @@ -1056,6 +1057,7 @@ private InferenceServiceResults streamChatCompletion() throws IOException { } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114385") public void testInfer_StreamRequest_ErrorResponse() throws Exception { String responseJson = """ { From 43716eb962cda0e132ef7f928a3e60c2d53ca0ce Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Thu, 10 
Oct 2024 23:50:28 -0400 Subject: [PATCH 14/25] [ML] Upgrade to AWS SDK v2 (#114309) (#114550) - Replaced AWS 1.12.740 with 2.28.13 - Removed `aws-java-sdk*` and its transitive dependencies. - Added `awssdk:bedrockruntime` as an `implementations`, all transitive dependencies are added as `api` matching their marked `Compile` in maven. - Added `awssdk:netty-nio-client` as our client implementation, since our v1 integration is using the respective Async client. - Added netty packages as `runtimeOnly` since they are only used during runtime. - Replaced AWS's use of SLF4J-1.7 with our declaration of SLF4J-2.x, since SLF4J includes backwards-compatible bindings. - Migrated all references from the v1 package (`com.amazonaws`) to the v2 package (`software.amazon.awssdk`). Notable changes in the SDK: - *Result objects are renamed to *Response objects. - Objects are now immutable and require Builders to set fields. - Getters no longer have the `get*` prefix, e.g. `getModelId()` is now `modelId()`. - `Future` has been replaced with `CompletableFuture`. - There is no longer a need to invoke the `IdleConnectionReaper`, this is now done when the client is closed. - Builders have a consumer mutation pattern for modifying many fields at once. Security changes: - The underlying Builder objects always check to see if the `.aws/credentials` and `.aws/config` files exist, even if they are not used, so our `plugin-security.policy` now allows reading these files. - The Builder always checks for the `http.proxyHost` property before defaulting to the hardcoded Bedrock URL. 
Resolve #110590 --- docs/changelog/114309.yaml | 6 + gradle/verification-metadata.xml | 140 +++++++++ x-pack/plugin/inference/build.gradle | 215 ++++++++++++-- .../licenses/aws-java-sdk-LICENSE.txt | 63 ---- .../licenses/aws-java-sdk-NOTICE.txt | 15 - .../inference/licenses/aws-sdk-2-LICENSE.txt | 206 +++++++++++++ .../inference/licenses/aws-sdk-2-NOTICE.txt | 26 ++ ...me-LICENSE.txt => eventstream-LICENSE.txt} | 0 .../inference/licenses/eventstream-NOTICE.txt | 2 + .../inference/licenses/jaxb-LICENSE.txt | 274 ------------------ .../plugin/inference/licenses/jaxb-NOTICE.txt | 1 - .../inference/licenses/joda-time-NOTICE.txt | 5 - .../inference/licenses/netty-LICENSE.txt | 202 +++++++++++++ .../inference/licenses/netty-NOTICE.txt | 116 ++++++++ .../licenses/reactive-streams-LICENSE.txt | 7 + .../licenses/reactive-streams-NOTICE.txt | 0 .../inference/licenses/slf4j-LICENSE.txt | 23 ++ .../inference/licenses/slf4j-NOTICE.txt | 0 .../inference/src/main/java/module-info.java | 13 +- .../amazonbedrock/AmazonBedrockClient.java | 12 +- .../AmazonBedrockInferenceClient.java | 122 ++++---- .../AmazonBedrockInferenceClientCache.java | 6 - .../AmazonBedrockJsonWriter.java | 20 -- ...edrockAI21LabsCompletionRequestEntity.java | 35 +-- ...drockAnthropicCompletionRequestEntity.java | 37 ++- .../AmazonBedrockChatCompletionRequest.java | 6 +- ...nBedrockCohereCompletionRequestEntity.java | 37 ++- .../AmazonBedrockConverseRequestEntity.java | 8 +- .../AmazonBedrockConverseUtils.java | 18 +- ...zonBedrockMetaCompletionRequestEntity.java | 35 +-- ...BedrockMistralCompletionRequestEntity.java | 37 ++- ...onBedrockTitanCompletionRequestEntity.java | 35 +-- .../AmazonBedrockEmbeddingsRequest.java | 15 +- .../AmazonBedrockChatCompletionResponse.java | 25 +- ...nBedrockChatCompletionResponseHandler.java | 6 +- ...BedrockChatCompletionResponseListener.java | 6 +- .../AmazonBedrockEmbeddingsResponse.java | 10 +- ...mazonBedrockEmbeddingsResponseHandler.java | 6 +- 
...azonBedrockEmbeddingsResponseListener.java | 8 +- .../plugin-metadata/plugin-security.policy | 5 + .../AmazonBedrockExecutorTests.java | 35 +-- .../AmazonBedrockMockClientCache.java | 30 +- ...AmazonBedrockMockExecuteRequestSender.java | 12 +- .../AmazonBedrockMockInferenceClient.java | 112 ++----- ...kAI21LabsCompletionRequestEntityTests.java | 10 +- ...AnthropicCompletionRequestEntityTests.java | 12 +- ...ockCohereCompletionRequestEntityTests.java | 12 +- .../AmazonBedrockConverseRequestUtils.java | 56 ++-- ...drockMetaCompletionRequestEntityTests.java | 10 +- ...ckMistralCompletionRequestEntityTests.java | 12 +- ...rockTitanCompletionRequestEntityTests.java | 10 +- 51 files changed, 1284 insertions(+), 830 deletions(-) create mode 100644 docs/changelog/114309.yaml delete mode 100644 x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt delete mode 100644 x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt create mode 100644 x-pack/plugin/inference/licenses/aws-sdk-2-LICENSE.txt create mode 100644 x-pack/plugin/inference/licenses/aws-sdk-2-NOTICE.txt rename x-pack/plugin/inference/licenses/{joda-time-LICENSE.txt => eventstream-LICENSE.txt} (100%) create mode 100644 x-pack/plugin/inference/licenses/eventstream-NOTICE.txt delete mode 100644 x-pack/plugin/inference/licenses/jaxb-LICENSE.txt delete mode 100644 x-pack/plugin/inference/licenses/jaxb-NOTICE.txt delete mode 100644 x-pack/plugin/inference/licenses/joda-time-NOTICE.txt create mode 100644 x-pack/plugin/inference/licenses/netty-LICENSE.txt create mode 100644 x-pack/plugin/inference/licenses/netty-NOTICE.txt create mode 100644 x-pack/plugin/inference/licenses/reactive-streams-LICENSE.txt create mode 100644 x-pack/plugin/inference/licenses/reactive-streams-NOTICE.txt create mode 100644 x-pack/plugin/inference/licenses/slf4j-LICENSE.txt create mode 100644 x-pack/plugin/inference/licenses/slf4j-NOTICE.txt delete mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java diff --git a/docs/changelog/114309.yaml b/docs/changelog/114309.yaml new file mode 100644 index 0000000000000..bcd1262062943 --- /dev/null +++ b/docs/changelog/114309.yaml @@ -0,0 +1,6 @@ +pr: 114309 +summary: Upgrade to AWS SDK v2 +area: Machine Learning +type: enhancement +issues: + - 110590 diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index 53a65e217ed18..443417e6a5b92 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -1456,6 +1456,11 @@ + + + + + @@ -4222,6 +4227,11 @@ + + + + + @@ -4435,6 +4445,136 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 211b99343340d..28e1405cf7b97 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -28,7 +28,7 @@ base { } versions << [ - 'awsbedrockruntime': '1.12.740' + 'aws2': '2.28.13' ] dependencies { @@ -57,19 +57,58 @@ dependencies { implementation 'com.google.http-client:google-http-client-appengine:1.42.3' implementation 'com.google.http-client:google-http-client-jackson2:1.42.3' implementation "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" - implementation "com.fasterxml.jackson.core:jackson-databind:${versions.jackson}" - implementation "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" - implementation "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${versions.jackson}" - implementation "com.fasterxml.jackson:jackson-bom:${versions.jackson}" implementation 'com.google.api:gax-httpjson:0.105.1' implementation 
'io.grpc:grpc-context:1.49.2' implementation 'io.opencensus:opencensus-api:0.31.1' implementation 'io.opencensus:opencensus-contrib-http-util:0.31.1' - implementation "com.amazonaws:aws-java-sdk-bedrockruntime:${versions.awsbedrockruntime}" - implementation "com.amazonaws:aws-java-sdk-core:${versions.aws}" - implementation "com.amazonaws:jmespath-java:${versions.aws}" - implementation "joda-time:joda-time:2.10.10" - implementation 'javax.xml.bind:jaxb-api:2.2.2' + + /* AWS SDK v2 */ + implementation ("software.amazon.awssdk:bedrockruntime:${versions.aws2}") + api "software.amazon.awssdk:protocol-core:${versions.aws2}" + api "software.amazon.awssdk:aws-json-protocol:${versions.aws2}" + api "software.amazon.awssdk:third-party-jackson-core:${versions.aws2}" + api "software.amazon.awssdk:http-auth-aws:${versions.aws2}" + api "software.amazon.awssdk:checksums-spi:${versions.aws2}" + api "software.amazon.awssdk:checksums:${versions.aws2}" + api "software.amazon.awssdk:sdk-core:${versions.aws2}" + api "org.reactivestreams:reactive-streams:1.0.4" + api "org.reactivestreams:reactive-streams-tck:1.0.4" + api "software.amazon.awssdk:profiles:${versions.aws2}" + api "software.amazon.awssdk:retries:${versions.aws2}" + api "software.amazon.awssdk:auth:${versions.aws2}" + api "software.amazon.awssdk:http-auth-aws-eventstream:${versions.aws2}" + api "software.amazon.eventstream:eventstream:1.0.1" + api "software.amazon.awssdk:http-auth-spi:${versions.aws2}" + api "software.amazon.awssdk:http-auth:${versions.aws2}" + api "software.amazon.awssdk:identity-spi:${versions.aws2}" + api "software.amazon.awssdk:http-client-spi:${versions.aws2}" + api "software.amazon.awssdk:regions:${versions.aws2}" + api "software.amazon.awssdk:annotations:${versions.aws2}" + api "software.amazon.awssdk:utils:${versions.aws2}" + api "software.amazon.awssdk:aws-core:${versions.aws2}" + api "software.amazon.awssdk:metrics-spi:${versions.aws2}" + api "software.amazon.awssdk:json-utils:${versions.aws2}" + 
api "software.amazon.awssdk:endpoints-spi:${versions.aws2}" + api "software.amazon.awssdk:retries-spi:${versions.aws2}" + + /* Netty (via AWS SDKv2) */ + implementation "software.amazon.awssdk:netty-nio-client:${versions.aws2}" + runtimeOnly "io.netty:netty-buffer:${versions.netty}" + runtimeOnly "io.netty:netty-codec-dns:${versions.netty}" + runtimeOnly "io.netty:netty-codec-http2:${versions.netty}" + runtimeOnly "io.netty:netty-codec-http:${versions.netty}" + runtimeOnly "io.netty:netty-codec:${versions.netty}" + runtimeOnly "io.netty:netty-common:${versions.netty}" + runtimeOnly "io.netty:netty-handler:${versions.netty}" + runtimeOnly "io.netty:netty-resolver-dns:${versions.netty}" + runtimeOnly "io.netty:netty-resolver:${versions.netty}" + runtimeOnly "io.netty:netty-transport-classes-epoll:${versions.netty}" + runtimeOnly "io.netty:netty-transport-native-unix-common:${versions.netty}" + runtimeOnly "io.netty:netty-transport:${versions.netty}" + + /* SLF4J (via AWS SDKv2) */ + api "org.slf4j:slf4j-api:${versions.slf4j}" + runtimeOnly "org.slf4j:slf4j-nop:${versions.slf4j}" } tasks.named("dependencyLicenses").configure { @@ -79,9 +118,46 @@ tasks.named("dependencyLicenses").configure { mapping from: /protobuf.*/, to: 'protobuf' mapping from: /proto-google.*/, to: 'proto-google' mapping from: /jackson.*/, to: 'jackson' - mapping from: /aws-java-sdk-.*/, to: 'aws-java-sdk' - mapping from: /jmespath-java.*/, to: 'aws-java-sdk' - mapping from: /jaxb-.*/, to: 'jaxb' + mapping from: /reactive-streams.*/, to: 'reactive-streams' + mapping from: /eventstream.*/, to: 'eventstream' + mapping from: /slf4j.*/, to: 'slf4j' + mapping from: /protocol-core.*/, to: 'aws-sdk-2' + mapping from: /aws-json-protocol.*/, to: 'aws-sdk-2' + mapping from: /third-party-jackson-core.*/, to: 'aws-sdk-2' + mapping from: /checksums-spi.*/, to: 'aws-sdk-2' + mapping from: /checksums.*/, to: 'aws-sdk-2' + mapping from: /sdk-core.*/, to: 'aws-sdk-2' + mapping from: /profiles.*/, to: 'aws-sdk-2' + 
mapping from: /retries.*/, to: 'aws-sdk-2' + mapping from: /auth.*/, to: 'aws-sdk-2' + mapping from: /http-auth-aws-eventstream.*/, to: 'aws-sdk-2' + mapping from: /http-auth-spi.*/, to: 'aws-sdk-2' + mapping from: /http-auth.*/, to: 'aws-sdk-2' + mapping from: /http-auth-aws.*/, to: 'aws-sdk-2' + mapping from: /identity-spi.*/, to: 'aws-sdk-2' + mapping from: /http-client-spi.*/, to: 'aws-sdk-2' + mapping from: /regions.*/, to: 'aws-sdk-2' + mapping from: /annotations.*/, to: 'aws-sdk-2' + mapping from: /utils.*/, to: 'aws-sdk-2' + mapping from: /aws-core.*/, to: 'aws-sdk-2' + mapping from: /metrics-spi.*/, to: 'aws-sdk-2' + mapping from: /json-utils.*/, to: 'aws-sdk-2' + mapping from: /endpoints-spi.*/, to: 'aws-sdk-2' + mapping from: /bedrockruntime.*/, to: 'aws-sdk-2' + mapping from: /netty-nio-client/, to: 'aws-sdk-2' + /* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */ + mapping from: /netty-buffer/, to: 'netty' + mapping from: /netty-codec-dns/, to: 'netty' + mapping from: /netty-codec-http2/, to: 'netty' + mapping from: /netty-codec-http/, to: 'netty' + mapping from: /netty-codec/, to: 'netty' + mapping from: /netty-common/, to: 'netty' + mapping from: /netty-handler/, to: 'netty' + mapping from: /netty-resolver-dns/, to: 'netty' + mapping from: /netty-resolver/, to: 'netty' + mapping from: /netty-transport-classes-epoll/, to: 'netty' + mapping from: /netty-transport-native-unix-common/, to: 'netty' + mapping from: /netty-transport/, to: 'netty' } tasks.named("thirdPartyAudit").configure { @@ -108,6 +184,29 @@ tasks.named("thirdPartyAudit").configure { 'com.google.common.hash.LittleEndianByteArray$UnsafeByteArray', 'com.google.common.primitives.UnsignedBytes$LexicographicalComparatorHolder$UnsafeComparator', 'com.google.common.primitives.UnsignedBytes$LexicographicalComparatorHolder$UnsafeComparator$1', + 'io.netty.handler.ssl.util.OpenJdkSelfSignedCertGenerator', + 
'io.netty.handler.ssl.util.OpenJdkSelfSignedCertGenerator$1', + 'io.netty.handler.ssl.util.OpenJdkSelfSignedCertGenerator$2', + 'io.netty.handler.ssl.util.OpenJdkSelfSignedCertGenerator$3', + 'io.netty.handler.ssl.util.OpenJdkSelfSignedCertGenerator$4', + 'io.netty.handler.ssl.util.OpenJdkSelfSignedCertGenerator$5', + 'io.netty.util.internal.PlatformDependent0', + 'io.netty.util.internal.PlatformDependent0$1', + 'io.netty.util.internal.PlatformDependent0$2', + 'io.netty.util.internal.PlatformDependent0$3', + 'io.netty.util.internal.PlatformDependent0$4', + 'io.netty.util.internal.PlatformDependent0$6', + 'io.netty.util.internal.shaded.org.jctools.queues.BaseLinkedQueueConsumerNodeRef', + 'io.netty.util.internal.shaded.org.jctools.queues.BaseLinkedQueueProducerNodeRef', + 'io.netty.util.internal.shaded.org.jctools.queues.BaseMpscLinkedArrayQueueColdProducerFields', + 'io.netty.util.internal.shaded.org.jctools.queues.BaseMpscLinkedArrayQueueConsumerFields', + 'io.netty.util.internal.shaded.org.jctools.queues.BaseMpscLinkedArrayQueueProducerFields', + 'io.netty.util.internal.shaded.org.jctools.queues.LinkedQueueNode', + 'io.netty.util.internal.shaded.org.jctools.queues.MpscArrayQueueConsumerIndexField', + 'io.netty.util.internal.shaded.org.jctools.queues.MpscArrayQueueProducerIndexField', + 'io.netty.util.internal.shaded.org.jctools.queues.MpscArrayQueueProducerLimitField', + 'io.netty.util.internal.shaded.org.jctools.util.UnsafeAccess', + 'io.netty.util.internal.shaded.org.jctools.util.UnsafeRefArrayAccess', ) ignoreMissingClasses( @@ -215,17 +314,87 @@ tasks.named("thirdPartyAudit").configure { 'com.google.appengine.api.urlfetch.HTTPRequest', 'com.google.appengine.api.urlfetch.HTTPResponse', 'com.google.appengine.api.urlfetch.URLFetchService', + 'com.aayushatharva.brotli4j.Brotli4jLoader', + 'com.aayushatharva.brotli4j.decoder.DecoderJNI$Status', + 'com.aayushatharva.brotli4j.decoder.DecoderJNI$Wrapper', + 'com.aayushatharva.brotli4j.encoder.BrotliEncoderChannel', + 
'com.aayushatharva.brotli4j.encoder.Encoder$Mode', + 'com.aayushatharva.brotli4j.encoder.Encoder$Parameters', + 'com.github.luben.zstd.BaseZstdBufferDecompressingStreamNoFinalizer', + 'com.github.luben.zstd.Zstd', + 'com.github.luben.zstd.ZstdBufferDecompressingStreamNoFinalizer', + 'com.github.luben.zstd.ZstdDirectBufferDecompressingStreamNoFinalizer', 'com.google.appengine.api.urlfetch.URLFetchServiceFactory', - 'software.amazon.ion.IonReader', - 'software.amazon.ion.IonSystem', - 'software.amazon.ion.IonType', - 'software.amazon.ion.IonWriter', - 'software.amazon.ion.Timestamp', - 'software.amazon.ion.system.IonBinaryWriterBuilder', - 'software.amazon.ion.system.IonSystemBuilder', - 'software.amazon.ion.system.IonTextWriterBuilder', - 'software.amazon.ion.system.IonWriterBuilder', - 'javax.activation.DataHandler' + 'com.google.protobuf.nano.CodedOutputByteBufferNano', + 'com.google.protobuf.nano.MessageNano', + 'com.jcraft.jzlib.Deflater', + 'com.jcraft.jzlib.Inflater', + 'com.jcraft.jzlib.JZlib', + 'com.jcraft.jzlib.JZlib$WrapperType', + 'com.ning.compress.BufferRecycler', + 'com.ning.compress.lzf.ChunkDecoder', + 'com.ning.compress.lzf.ChunkEncoder', + 'com.ning.compress.lzf.LZFChunk', + 'com.ning.compress.lzf.LZFEncoder', + 'com.ning.compress.lzf.util.ChunkDecoderFactory', + 'com.ning.compress.lzf.util.ChunkEncoderFactory', + 'io.netty.internal.tcnative.AsyncSSLPrivateKeyMethod', + 'io.netty.internal.tcnative.AsyncTask', + 'io.netty.internal.tcnative.Buffer', + 'io.netty.internal.tcnative.CertificateCallback', + 'io.netty.internal.tcnative.CertificateCompressionAlgo', + 'io.netty.internal.tcnative.CertificateVerifier', + 'io.netty.internal.tcnative.Library', + 'io.netty.internal.tcnative.ResultCallback', + 'io.netty.internal.tcnative.SSL', + 'io.netty.internal.tcnative.SSLContext', + 'io.netty.internal.tcnative.SSLPrivateKeyMethod', + 'io.netty.internal.tcnative.SSLSession', + 'io.netty.internal.tcnative.SSLSessionCache', + 
'io.netty.internal.tcnative.SessionTicketKey', + 'io.netty.internal.tcnative.SniHostNameMatcher', + 'lzma.sdk.lzma.Encoder', + 'org.bouncycastle.cert.X509v3CertificateBuilder', + 'org.bouncycastle.cert.jcajce.JcaX509CertificateConverter', + 'org.bouncycastle.openssl.PEMEncryptedKeyPair', + 'org.bouncycastle.openssl.PEMParser', + 'org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter', + 'org.bouncycastle.openssl.jcajce.JceOpenSSLPKCS8DecryptorProviderBuilder', + 'org.bouncycastle.openssl.jcajce.JcePEMDecryptorProviderBuilder', + 'org.bouncycastle.operator.jcajce.JcaContentSignerBuilder', + 'org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo', + 'org.conscrypt.AllocatedBuffer', + 'org.conscrypt.BufferAllocator', + 'org.conscrypt.Conscrypt', + 'org.conscrypt.HandshakeListener', + 'org.eclipse.jetty.alpn.ALPN', + 'org.eclipse.jetty.alpn.ALPN$ClientProvider', + 'org.eclipse.jetty.alpn.ALPN$ServerProvider', + 'org.eclipse.jetty.npn.NextProtoNego', + 'org.eclipse.jetty.npn.NextProtoNego$ClientProvider', + 'org.eclipse.jetty.npn.NextProtoNego$ServerProvider', + 'org.jboss.marshalling.ByteInput', + 'org.jboss.marshalling.ByteOutput', + 'org.jboss.marshalling.Marshaller', + 'org.jboss.marshalling.MarshallerFactory', + 'org.jboss.marshalling.MarshallingConfiguration', + 'org.jboss.marshalling.Unmarshaller', + 'org.reactivestreams.example.unicast.AsyncIterablePublisher', + 'org.testng.Assert', + 'reactor.blockhound.BlockHound$Builder', + 'reactor.blockhound.integration.BlockHoundIntegration', + 'software.amazon.awssdk.crt.auth.credentials.Credentials', + 'software.amazon.awssdk.crt.auth.signing.AwsSigner', + 'software.amazon.awssdk.crt.auth.signing.AwsSigningConfig', + 'software.amazon.awssdk.crt.auth.signing.AwsSigningConfig$AwsSignatureType', + 'software.amazon.awssdk.crt.auth.signing.AwsSigningConfig$AwsSignedBodyHeaderType', + 'software.amazon.awssdk.crt.auth.signing.AwsSigningConfig$AwsSigningAlgorithm', + 'software.amazon.awssdk.crt.auth.signing.AwsSigningResult', + 
'software.amazon.awssdk.crt.checksums.CRC32', + 'software.amazon.awssdk.crt.checksums.CRC32C', + 'software.amazon.awssdk.crt.http.HttpHeader', + 'software.amazon.awssdk.crt.http.HttpRequest', + 'software.amazon.awssdk.crt.http.HttpRequestBodyStream', ) } diff --git a/x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt b/x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt deleted file mode 100644 index 98d1f9319f374..0000000000000 --- a/x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt +++ /dev/null @@ -1,63 +0,0 @@ -Apache License -Version 2.0, January 2004 - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - -"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. - -"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. - -"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. - -"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. - -"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. - -"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
- -"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). - -"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. - -"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." - -"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. - -2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: - - 1. You must give any other recipients of the Work or Derivative Works a copy of this License; and - 2. You must cause any modified files to carry prominent notices stating that You changed the files; and - 3. 
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and - 4. If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. - -You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. 
Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. 
While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -Note: Other license terms may apply to certain, identified software files contained within or distributed with the accompanying software if such terms are included in the directory containing the accompanying software. Such other license terms will then apply in lieu of the terms of the software license above. - -JSON processing code subject to the JSON License from JSON.org: - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -The Software shall be used for Good, not Evil. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt b/x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt deleted file mode 100644 index 565bd6085c71a..0000000000000 --- a/x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt +++ /dev/null @@ -1,15 +0,0 @@ -AWS SDK for Java -Copyright 2010-2014 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -This product includes software developed by -Amazon Technologies, Inc (http://www.amazon.com/). - -********************** -THIRD PARTY COMPONENTS -********************** -This software includes third party software subject to the following copyrights: -- XML parsing and utility functions from JetS3t - Copyright 2006-2009 James Murty. -- JSON parsing and utility functions from JSON.org - Copyright 2002 JSON.org. -- PKCS#1 PEM encoded private key parsing and utility functions from oauth.googlecode.com - Copyright 1998-2010 AOL Inc. - -The licenses for these third party components are included in LICENSE.txt diff --git a/x-pack/plugin/inference/licenses/aws-sdk-2-LICENSE.txt b/x-pack/plugin/inference/licenses/aws-sdk-2-LICENSE.txt new file mode 100644 index 0000000000000..1eef70a9b9f42 --- /dev/null +++ b/x-pack/plugin/inference/licenses/aws-sdk-2-LICENSE.txt @@ -0,0 +1,206 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Note: Other license terms may apply to certain, identified software files contained within or distributed + with the accompanying software if such terms are included in the directory containing the accompanying software. + Such other license terms will then apply in lieu of the terms of the software license above. diff --git a/x-pack/plugin/inference/licenses/aws-sdk-2-NOTICE.txt b/x-pack/plugin/inference/licenses/aws-sdk-2-NOTICE.txt new file mode 100644 index 0000000000000..f3c4db7d1724e --- /dev/null +++ b/x-pack/plugin/inference/licenses/aws-sdk-2-NOTICE.txt @@ -0,0 +1,26 @@ +AWS SDK for Java 2.0 +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +This product includes software developed by +Amazon Technologies, Inc (http://www.amazon.com/). 
+ +********************** +THIRD PARTY COMPONENTS +********************** +This software includes third party software subject to the following copyrights: +- XML parsing and utility functions from JetS3t - Copyright 2006-2009 James Murty. +- PKCS#1 PEM encoded private key parsing and utility functions from oauth.googlecode.com - Copyright 1998-2010 AOL Inc. +- Apache Commons Lang - https://github.com/apache/commons-lang +- Netty Reactive Streams - https://github.com/playframework/netty-reactive-streams +- Jackson-core - https://github.com/FasterXML/jackson-core +- Jackson-dataformat-cbor - https://github.com/FasterXML/jackson-dataformats-binary + +The licenses for these third party components are included in LICENSE.txt + +- For Apache Commons Lang see also this required NOTICE: + Apache Commons Lang + Copyright 2001-2020 The Apache Software Foundation + + This product includes software developed at + The Apache Software Foundation (https://www.apache.org/). + diff --git a/x-pack/plugin/inference/licenses/joda-time-LICENSE.txt b/x-pack/plugin/inference/licenses/eventstream-LICENSE.txt similarity index 100% rename from x-pack/plugin/inference/licenses/joda-time-LICENSE.txt rename to x-pack/plugin/inference/licenses/eventstream-LICENSE.txt diff --git a/x-pack/plugin/inference/licenses/eventstream-NOTICE.txt b/x-pack/plugin/inference/licenses/eventstream-NOTICE.txt new file mode 100644 index 0000000000000..1a066ac2925f7 --- /dev/null +++ b/x-pack/plugin/inference/licenses/eventstream-NOTICE.txt @@ -0,0 +1,2 @@ +AWS EventStream for Java +Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/x-pack/plugin/inference/licenses/jaxb-LICENSE.txt b/x-pack/plugin/inference/licenses/jaxb-LICENSE.txt deleted file mode 100644 index 833a843cfeee1..0000000000000 --- a/x-pack/plugin/inference/licenses/jaxb-LICENSE.txt +++ /dev/null @@ -1,274 +0,0 @@ -COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL)Version 1.1 - -1. Definitions. - - 1.1. 
"Contributor" means each individual or entity that creates or contributes to the creation of Modifications. - - 1.2. "Contributor Version" means the combination of the Original Software, prior Modifications used by a Contributor (if any), and the Modifications made by that particular Contributor. - - 1.3. "Covered Software" means (a) the Original Software, or (b) Modifications, or (c) the combination of files containing Original Software with files containing Modifications, in each case including portions thereof. - - 1.4. "Executable" means the Covered Software in any form other than Source Code. - - 1.5. "Initial Developer" means the individual or entity that first makes Original Software available under this License. - - 1.6. "Larger Work" means a work which combines Covered Software or portions thereof with code not governed by the terms of this License. - - 1.7. "License" means this document. - - 1.8. "Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently acquired, any and all of the rights conveyed herein. - - 1.9. "Modifications" means the Source Code and Executable form of any of the following: - - A. Any file that results from an addition to, deletion from or modification of the contents of a file containing Original Software or previous Modifications; - - B. Any new file that contains any part of the Original Software or previous Modification; or - - C. Any new file that is contributed or otherwise made available under the terms of this License. - - 1.10. "Original Software" means the Source Code and Executable form of computer software code that is originally released under this License. - - 1.11. "Patent Claims" means any patent claim(s), now owned or hereafter acquired, including without limitation, method, process, and apparatus claims, in any patent Licensable by grantor. - - 1.12. 
"Source Code" means (a) the common form of computer software code in which modifications are made and (b) associated documentation included in or with such code. - - 1.13. "You" (or "Your") means an individual or a legal entity exercising rights under, and complying with all of the terms of, this License. For legal entities, "You" includes any entity which controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. - -2. License Grants. - - 2.1. The Initial Developer Grant. - - Conditioned upon Your compliance with Section 3.1 below and subject to third party intellectual property claims, the Initial Developer hereby grants You a world-wide, royalty-free, non-exclusive license: - - (a) under intellectual property rights (other than patent or trademark) Licensable by Initial Developer, to use, reproduce, modify, display, perform, sublicense and distribute the Original Software (or portions thereof), with or without Modifications, and/or as part of a Larger Work; and - - (b) under Patent Claims infringed by the making, using or selling of Original Software, to make, have made, use, practice, sell, and offer for sale, and/or otherwise dispose of the Original Software (or portions thereof). - - (c) The licenses granted in Sections 2.1(a) and (b) are effective on the date Initial Developer first distributes or otherwise makes the Original Software available to a third party under the terms of this License. 
- - (d) Notwithstanding Section 2.1(b) above, no patent license is granted: (1) for code that You delete from the Original Software, or (2) for infringements caused by: (i) the modification of the Original Software, or (ii) the combination of the Original Software with other software or devices. - - 2.2. Contributor Grant. - - Conditioned upon Your compliance with Section 3.1 below and subject to third party intellectual property claims, each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: - - (a) under intellectual property rights (other than patent or trademark) Licensable by Contributor to use, reproduce, modify, display, perform, sublicense and distribute the Modifications created by such Contributor (or portions thereof), either on an unmodified basis, with other Modifications, as Covered Software and/or as part of a Larger Work; and - - (b) under Patent Claims infringed by the making, using, or selling of Modifications made by that Contributor either alone and/or in combination with its Contributor Version (or portions of such combination), to make, use, sell, offer for sale, have made, and/or otherwise dispose of: (1) Modifications made by that Contributor (or portions thereof); and (2) the combination of Modifications made by that Contributor with its Contributor Version (or portions of such combination). - - (c) The licenses granted in Sections 2.2(a) and 2.2(b) are effective on the date Contributor first distributes or otherwise makes the Modifications available to a third party. 
- - (d) Notwithstanding Section 2.2(b) above, no patent license is granted: (1) for any code that Contributor has deleted from the Contributor Version; (2) for infringements caused by: (i) third party modifications of Contributor Version, or (ii) the combination of Modifications made by that Contributor with other software (except as part of the Contributor Version) or other devices; or (3) under Patent Claims infringed by Covered Software in the absence of Modifications made by that Contributor. - -3. Distribution Obligations. - - 3.1. Availability of Source Code. - - Any Covered Software that You distribute or otherwise make available in Executable form must also be made available in Source Code form and that Source Code form must be distributed only under the terms of this License. You must include a copy of this License with every copy of the Source Code form of the Covered Software You distribute or otherwise make available. You must inform recipients of any such Covered Software in Executable form as to how they can obtain such Covered Software in Source Code form in a reasonable manner on or through a medium customarily used for software exchange. - - 3.2. Modifications. - - The Modifications that You create or to which You contribute are governed by the terms of this License. You represent that You believe Your Modifications are Your original creation(s) and/or You have sufficient rights to grant the rights conveyed by this License. - - 3.3. Required Notices. - - You must include a notice in each of Your Modifications that identifies You as the Contributor of the Modification. You may not remove or alter any copyright, patent or trademark notices contained within the Covered Software, or any notices of licensing or any descriptive text giving attribution to any Contributor or the Initial Developer. - - 3.4. Application of Additional Terms. 
- - You may not offer or impose any terms on any Covered Software in Source Code form that alters or restricts the applicable version of this License or the recipients' rights hereunder. You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, you may do so only on Your own behalf, and not on behalf of the Initial Developer or any Contributor. You must make it absolutely clear that any such warranty, support, indemnity or liability obligation is offered by You alone, and You hereby agree to indemnify the Initial Developer and every Contributor for any liability incurred by the Initial Developer or such Contributor as a result of warranty, support, indemnity or liability terms You offer. - - 3.5. Distribution of Executable Versions. - - You may distribute the Executable form of the Covered Software under the terms of this License or under the terms of a license of Your choice, which may contain terms different from this License, provided that You are in compliance with the terms of this License and that the license for the Executable form does not attempt to limit or alter the recipient's rights in the Source Code form from the rights set forth in this License. If You distribute the Covered Software in Executable form under a different license, You must make it absolutely clear that any terms which differ from this License are offered by You alone, not by the Initial Developer or Contributor. You hereby agree to indemnify the Initial Developer and every Contributor for any liability incurred by the Initial Developer or such Contributor as a result of any such terms You offer. - - 3.6. Larger Works. - - You may create a Larger Work by combining Covered Software with other code not governed by the terms of this License and distribute the Larger Work as a single product. 
In such a case, You must make sure the requirements of this License are fulfilled for the Covered Software. - -4. Versions of the License. - - 4.1. New Versions. - - Oracle is the initial license steward and may publish revised and/or new versions of this License from time to time. Each version will be given a distinguishing version number. Except as provided in Section 4.3, no one other than the license steward has the right to modify this License. - - 4.2. Effect of New Versions. - - You may always continue to use, distribute or otherwise make the Covered Software available under the terms of the version of the License under which You originally received the Covered Software. If the Initial Developer includes a notice in the Original Software prohibiting it from being distributed or otherwise made available under any subsequent version of the License, You must distribute and make the Covered Software available under the terms of the version of the License under which You originally received the Covered Software. Otherwise, You may also choose to use, distribute or otherwise make the Covered Software available under the terms of any subsequent version of the License published by the license steward. - - 4.3. Modified Versions. - - When You are an Initial Developer and You want to create a new license for Your Original Software, You may create and use a modified version of this License if You: (a) rename the license and remove any references to the name of the license steward (except to note that the license differs from this License); and (b) otherwise make it clear that the license contains terms which differ from this License. - -5. DISCLAIMER OF WARRANTY. - - COVERED SOFTWARE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS" BASIS, WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, WITHOUT LIMITATION, WARRANTIES THAT THE COVERED SOFTWARE IS FREE OF DEFECTS, MERCHANTABLE, FIT FOR A PARTICULAR PURPOSE OR NON-INFRINGING. 
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE COVERED SOFTWARE IS WITH YOU. SHOULD ANY COVERED SOFTWARE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL DEVELOPER OR ANY OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY SERVICING, REPAIR OR CORRECTION. THIS DISCLAIMER OF WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS LICENSE. NO USE OF ANY COVERED SOFTWARE IS AUTHORIZED HEREUNDER EXCEPT UNDER THIS DISCLAIMER. - -6. TERMINATION. - - 6.1. This License and the rights granted hereunder will terminate automatically if You fail to comply with terms herein and fail to cure such breach within 30 days of becoming aware of the breach. Provisions which, by their nature, must remain in effect beyond the termination of this License shall survive. - - 6.2. If You assert a patent infringement claim (excluding declaratory judgment actions) against Initial Developer or a Contributor (the Initial Developer or Contributor against whom You assert such claim is referred to as "Participant") alleging that the Participant Software (meaning the Contributor Version where the Participant is a Contributor or the Original Software where the Participant is the Initial Developer) directly or indirectly infringes any patent, then any and all rights granted directly or indirectly to You by such Participant, the Initial Developer (if the Initial Developer is not the Participant) and all Contributors under Sections 2.1 and/or 2.2 of this License shall, upon 60 days notice from Participant terminate prospectively and automatically at the expiration of such 60 day notice period, unless if within such 60 day period You withdraw Your claim with respect to the Participant Software against such Participant either unilaterally or pursuant to a written agreement with Participant. - - 6.3. 
If You assert a patent infringement claim against Participant alleging that the Participant Software directly or indirectly infringes any patent where such claim is resolved (such as by license or settlement) prior to the initiation of patent infringement litigation, then the reasonable value of the licenses granted by such Participant under Sections 2.1 or 2.2 shall be taken into account in determining the amount or value of any payment or license. - - 6.4. In the event of termination under Sections 6.1 or 6.2 above, all end user licenses that have been validly granted by You or any distributor hereunder prior to termination (excluding licenses granted to You by any distributor) shall survive termination. - -7. LIMITATION OF LIABILITY. - - UNDER NO CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE, SHALL YOU, THE INITIAL DEVELOPER, ANY OTHER CONTRIBUTOR, OR ANY DISTRIBUTOR OF COVERED SOFTWARE, OR ANY SUPPLIER OF ANY OF SUCH PARTIES, BE LIABLE TO ANY PERSON FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER INCLUDING, WITHOUT LIMITATION, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES, EVEN IF SUCH PARTY SHALL HAVE BEEN INFORMED OF THE POSSIBILITY OF SUCH DAMAGES. THIS LIMITATION OF LIABILITY SHALL NOT APPLY TO LIABILITY FOR DEATH OR PERSONAL INJURY RESULTING FROM SUCH PARTY'S NEGLIGENCE TO THE EXTENT APPLICABLE LAW PROHIBITS SUCH LIMITATION. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR LIMITATION OF INCIDENTAL OR CONSEQUENTIAL DAMAGES, SO THIS EXCLUSION AND LIMITATION MAY NOT APPLY TO YOU. - -8. U.S. GOVERNMENT END USERS. - - The Covered Software is a "commercial item," as that term is defined in 48 C.F.R. 2.101 (Oct. 1995), consisting of "commercial computer software" (as that term is defined at 48 C.F.R. ? 
252.227-7014(a)(1)) and "commercial computer software documentation" as such terms are used in 48 C.F.R. 12.212 (Sept. 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users acquire Covered Software with only those rights set forth herein. This U.S. Government Rights clause is in lieu of, and supersedes, any other FAR, DFAR, or other clause or provision that addresses Government rights in computer software under this License. - -9. MISCELLANEOUS. - - This License represents the complete agreement concerning subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. This License shall be governed by the law of the jurisdiction specified in a notice contained within the Original Software (except to the extent applicable law, if any, provides otherwise), excluding such jurisdiction's conflict-of-law provisions. Any litigation relating to this License shall be subject to the jurisdiction of the courts located in the jurisdiction and venue specified in a notice contained within the Original Software, with the losing party responsible for costs, including, without limitation, court costs and reasonable attorneys' fees and expenses. The application of the United Nations Convention on Contracts for the International Sale of Goods is expressly excluded. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not apply to this License. You agree that You alone are responsible for compliance with the United States export administration regulations (and the export control laws and regulation of any other countries) when You use, distribute or otherwise make available any Covered Software. - -10. RESPONSIBILITY FOR CLAIMS. 
- - As between Initial Developer and the Contributors, each party is responsible for claims and damages arising, directly or indirectly, out of its utilization of rights under this License and You agree to work with Initial Developer and Contributors to distribute such responsibility on an equitable basis. Nothing herein is intended or shall be deemed to constitute any admission of liability. - ----------- -NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) -The code released under the CDDL shall be governed by the laws of the State of California (excluding conflict-of-law provisions). Any litigation relating to this License shall be subject to the jurisdiction of the Federal Courts of the Northern District of California and the state courts of the State of California, with venue lying in Santa Clara County, California. - - - - -The GNU General Public License (GPL) Version 2, June 1991 - - -Copyright (C) 1989, 1991 Free Software Foundation, Inc. 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - -Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. - -Preamble - -The licenses for most software are designed to take away your freedom to share and change it. By contrast, the GNU General Public License is intended to guarantee your freedom to share and change free software--to make sure the software is free for all its users. This General Public License applies to most of the Free Software Foundation's software and to any other program whose authors commit to using it. (Some other Free Software Foundation software is covered by the GNU Library General Public License instead.) You can apply it to your programs, too. - -When we speak of free software, we are referring to freedom, not price. 
Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for this service if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs; and that you know you can do these things. - -To protect your rights, we need to make restrictions that forbid anyone to deny you these rights or to ask you to surrender the rights. These restrictions translate to certain responsibilities for you if you distribute copies of the software, or if you modify it. - -For example, if you distribute copies of such a program, whether gratis or for a fee, you must give the recipients all the rights that you have. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights. - -We protect your rights with two steps: (1) copyright the software, and (2) offer you this license which gives you legal permission to copy, distribute and/or modify the software. - -Also, for each author's protection and ours, we want to make certain that everyone understands that there is no warranty for this free software. If the software is modified by someone else and passed on, we want its recipients to know that what they have is not the original, so that any problems introduced by others will not reflect on the original authors' reputations. - -Finally, any free program is threatened constantly by software patents. We wish to avoid the danger that redistributors of a free program will individually obtain patent licenses, in effect making the program proprietary. To prevent this, we have made it clear that any patent must be licensed for everyone's free use or not licensed at all. - -The precise terms and conditions for copying, distribution and modification follow. - - -TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION - -0. 
This License applies to any program or other work which contains a notice placed by the copyright holder saying it may be distributed under the terms of this General Public License. The "Program", below, refers to any such program or work, and a "work based on the Program" means either the Program or any derivative work under copyright law: that is to say, a work containing the Program or a portion of it, either verbatim or with modifications and/or translated into another language. (Hereinafter, translation is included without limitation in the term "modification".) Each licensee is addressed as "you". - -Activities other than copying, distribution and modification are not covered by this License; they are outside its scope. The act of running the Program is not restricted, and the output from the Program is covered only if its contents constitute a work based on the Program (independent of having been made by running the Program). Whether that is true depends on what the Program does. - -1. You may copy and distribute verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice and disclaimer of warranty; keep intact all the notices that refer to this License and to the absence of any warranty; and give any other recipients of the Program a copy of this License along with the Program. - -You may charge a fee for the physical act of transferring a copy, and you may at your option offer warranty protection in exchange for a fee. - -2. You may modify your copy or copies of the Program or any portion of it, thus forming a work based on the Program, and copy and distribute such modifications or work under the terms of Section 1 above, provided that you also meet all of these conditions: - - a) You must cause the modified files to carry prominent notices stating that you changed the files and the date of any change. 
- - b) You must cause any work that you distribute or publish, that in whole or in part contains or is derived from the Program or any part thereof, to be licensed as a whole at no charge to all third parties under the terms of this License. - - c) If the modified program normally reads commands interactively when run, you must cause it, when started running for such interactive use in the most ordinary way, to print or display an announcement including an appropriate copyright notice and a notice that there is no warranty (or else, saying that you provide a warranty) and that users may redistribute the program under these conditions, and telling the user how to view a copy of this License. (Exception: if the Program itself is interactive but does not normally print such an announcement, your work based on the Program is not required to print an announcement.) - -These requirements apply to the modified work as a whole. If identifiable sections of that work are not derived from the Program, and can be reasonably considered independent and separate works in themselves, then this License, and its terms, do not apply to those sections when you distribute them as separate works. But when you distribute the same sections as part of a whole which is a work based on the Program, the distribution of the whole must be on the terms of this License, whose permissions for other licensees extend to the entire whole, and thus to each and every part regardless of who wrote it. - -Thus, it is not the intent of this section to claim rights or contest your rights to work written entirely by you; rather, the intent is to exercise the right to control the distribution of derivative or collective works based on the Program. - -In addition, mere aggregation of another work not based on the Program with the Program (or with a work based on the Program) on a volume of a storage or distribution medium does not bring the other work under the scope of this License. - -3. 
You may copy and distribute the Program (or a work based on it, under Section 2) in object code or executable form under the terms of Sections 1 and 2 above provided that you also do one of the following: - - a) Accompany it with the complete corresponding machine-readable source code, which must be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange; or, - - b) Accompany it with a written offer, valid for at least three years, to give any third party, for a charge no more than your cost of physically performing source distribution, a complete machine-readable copy of the corresponding source code, to be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange; or, - - c) Accompany it with the information you received as to the offer to distribute corresponding source code. (This alternative is allowed only for noncommercial distribution and only if you received the program in object code or executable form with such an offer, in accord with Subsection b above.) - -The source code for a work means the preferred form of the work for making modifications to it. For an executable work, complete source code means all the source code for all modules it contains, plus any associated interface definition files, plus the scripts used to control compilation and installation of the executable. However, as a special exception, the source code distributed need not include anything that is normally distributed (in either source or binary form) with the major components (compiler, kernel, and so on) of the operating system on which the executable runs, unless that component itself accompanies the executable. 
- -If distribution of executable or object code is made by offering access to copy from a designated place, then offering equivalent access to copy the source code from the same place counts as distribution of the source code, even though third parties are not compelled to copy the source along with the object code. - -4. You may not copy, modify, sublicense, or distribute the Program except as expressly provided under this License. Any attempt otherwise to copy, modify, sublicense or distribute the Program is void, and will automatically terminate your rights under this License. However, parties who have received copies, or rights, from you under this License will not have their licenses terminated so long as such parties remain in full compliance. - -5. You are not required to accept this License, since you have not signed it. However, nothing else grants you permission to modify or distribute the Program or its derivative works. These actions are prohibited by law if you do not accept this License. Therefore, by modifying or distributing the Program (or any work based on the Program), you indicate your acceptance of this License to do so, and all its terms and conditions for copying, distributing or modifying the Program or works based on it. - -6. Each time you redistribute the Program (or any work based on the Program), the recipient automatically receives a license from the original licensor to copy, distribute or modify the Program subject to these terms and conditions. You may not impose any further restrictions on the recipients' exercise of the rights granted herein. You are not responsible for enforcing compliance by third parties to this License. - -7. 
If, as a consequence of a court judgment or allegation of patent infringement or for any other reason (not limited to patent issues), conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot distribute so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not distribute the Program at all. For example, if a patent license would not permit royalty-free redistribution of the Program by all those who receive copies directly or indirectly through you, then the only way you could satisfy both it and this License would be to refrain entirely from distribution of the Program. - -If any portion of this section is held invalid or unenforceable under any particular circumstance, the balance of the section is intended to apply and the section as a whole is intended to apply in other circumstances. - -It is not the purpose of this section to induce you to infringe any patents or other property right claims or to contest validity of any such claims; this section has the sole purpose of protecting the integrity of the free software distribution system, which is implemented by public license practices. Many people have made generous contributions to the wide range of software distributed through that system in reliance on consistent application of that system; it is up to the author/donor to decide if he or she is willing to distribute software through any other system and a licensee cannot impose that choice. - -This section is intended to make thoroughly clear what is believed to be a consequence of the rest of this License. - -8. 
If the distribution and/or use of the Program is restricted in certain countries either by patents or by copyrighted interfaces, the original copyright holder who places the Program under this License may add an explicit geographical distribution limitation excluding those countries, so that distribution is permitted only in or among countries not thus excluded. In such case, this License incorporates the limitation as if written in the body of this License. - -9. The Free Software Foundation may publish revised and/or new versions of the General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. - -Each version is given a distinguishing version number. If the Program specifies a version number of this License which applies to it and "any later version", you have the option of following the terms and conditions either of that version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of this License, you may choose any version ever published by the Free Software Foundation. - -10. If you wish to incorporate parts of the Program into other free programs whose distribution conditions are different, write to the author to ask for permission. For software which is copyrighted by the Free Software Foundation, write to the Free Software Foundation; we sometimes make exceptions for this. Our decision will be guided by the two goals of preserving the free status of all derivatives of our free software and of promoting the sharing and reuse of software generally. - -NO WARRANTY - -11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. 
EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. - -12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. - -END OF TERMS AND CONDITIONS - - -How to Apply These Terms to Your New Programs - -If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. - -To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively convey the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. - - One line to give the program's name and a brief idea of what it does. - - Copyright (C) - - This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. 
- - This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. - - You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - -Also add information on how to contact you by electronic and paper mail. - -If the program is interactive, make it output a short notice like this when it starts in an interactive mode: - - Gnomovision version 69, Copyright (C) year name of author - Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. This is free software, and you are welcome to redistribute it under certain conditions; type `show c' for details. - -The hypothetical commands `show w' and `show c' should show the appropriate parts of the General Public License. Of course, the commands you use may be called something other than `show w' and `show c'; they could even be mouse-clicks or menu items--whatever suits your program. - -You should also get your employer (if you work as a programmer) or your school, if any, to sign a "copyright disclaimer" for the program, if necessary. Here is a sample; alter the names: - - Yoyodyne, Inc., hereby disclaims all copyright interest in the program `Gnomovision' (which makes passes at compilers) written by James Hacker. - - signature of Ty Coon, 1 April 1989 - Ty Coon, President of Vice - -This General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Library General Public License instead of this License. 
- - -"CLASSPATH" EXCEPTION TO THE GPL VERSION 2 - -Certain source files distributed by Oracle are subject to the following clarification and special exception to the GPL Version 2, but only where Oracle has expressly included in the particular source file's header the words "Oracle designates this particular file as subject to the "Classpath" exception as provided by Oracle in the License file that accompanied this code." - -Linking this library statically or dynamically with other modules is making a combined work based on this library. Thus, the terms and conditions of the GNU General Public License Version 2 cover the whole combination. - -As a special exception, the copyright holders of this library give you permission to link this library with independent modules to produce an executable, regardless of the license terms of these independent modules, and to copy and distribute the resulting executable under terms of your choice, provided that you also meet, for each linked independent module, the terms and conditions of the license of that module. An independent module is a module which is not derived from or based on this library. If you modify this library, you may extend this exception to your version of the library, but you are not obligated to do so. If you do not wish to do so, delete this exception statement from your version. 
diff --git a/x-pack/plugin/inference/licenses/jaxb-NOTICE.txt b/x-pack/plugin/inference/licenses/jaxb-NOTICE.txt deleted file mode 100644 index 8d1c8b69c3fce..0000000000000 --- a/x-pack/plugin/inference/licenses/jaxb-NOTICE.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/x-pack/plugin/inference/licenses/joda-time-NOTICE.txt b/x-pack/plugin/inference/licenses/joda-time-NOTICE.txt deleted file mode 100644 index dffbcf31cacf6..0000000000000 --- a/x-pack/plugin/inference/licenses/joda-time-NOTICE.txt +++ /dev/null @@ -1,5 +0,0 @@ -============================================================================= -= NOTICE file corresponding to section 4d of the Apache License Version 2.0 = -============================================================================= -This product includes software developed by -Joda.org (http://www.joda.org/). diff --git a/x-pack/plugin/inference/licenses/netty-LICENSE.txt b/x-pack/plugin/inference/licenses/netty-LICENSE.txt new file mode 100644 index 0000000000000..d645695673349 --- /dev/null +++ b/x-pack/plugin/inference/licenses/netty-LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/x-pack/plugin/inference/licenses/netty-NOTICE.txt b/x-pack/plugin/inference/licenses/netty-NOTICE.txt new file mode 100644 index 0000000000000..5bbf91a14de23 --- /dev/null +++ b/x-pack/plugin/inference/licenses/netty-NOTICE.txt @@ -0,0 +1,116 @@ + + The Netty Project + ================= + +Please visit the Netty web site for more information: + + * http://netty.io/ + +Copyright 2011 The Netty Project + +The Netty Project licenses this file to you under the Apache License, +version 2.0 (the "License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at: + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations +under the License. + +Also, please refer to each LICENSE.<component>.txt file, which is located in +the 'license' directory of the distribution file, for the license terms of the +components that this product depends on. + +------------------------------------------------------------------------------- +This product contains the extensions to Java Collections Framework which has +been derived from the works by JSR-166 EG, Doug Lea, and Jason T. 
Greene: + + * LICENSE: + * license/LICENSE.jsr166y.txt (Public Domain) + * HOMEPAGE: + * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ + * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ + +This product contains a modified version of Robert Harder's Public Domain +Base64 Encoder and Decoder, which can be obtained at: + + * LICENSE: + * license/LICENSE.base64.txt (Public Domain) + * HOMEPAGE: + * http://iharder.sourceforge.net/current/java/base64/ + +This product contains a modified version of 'JZlib', a re-implementation of +zlib in pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD Style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product contains a modified version of 'Webbit', a Java event based +WebSocket and HTTP server: + + * LICENSE: + * license/LICENSE.webbit.txt (BSD License) + * HOMEPAGE: + * https://github.com/joewalnes/webbit + +This product optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + + * LICENSE: + * license/LICENSE.protobuf.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/protobuf/ + +This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +a temporary self-signed X.509 certificate when the JVM does not provide the +equivalent functionality. 
It can be obtained at: + + * LICENSE: + * license/LICENSE.bouncycastle.txt (MIT License) + * HOMEPAGE: + * http://www.bouncycastle.org/ + +This product optionally depends on 'SLF4J', a simple logging facade for Java, +which can be obtained at: + + * LICENSE: + * license/LICENSE.slf4j.txt (MIT License) + * HOMEPAGE: + * http://www.slf4j.org/ + +This product optionally depends on 'Apache Commons Logging', a logging +framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-logging.txt (Apache License 2.0) + * HOMEPAGE: + * http://commons.apache.org/logging/ + +This product optionally depends on 'Apache Log4J', a logging framework, +which can be obtained at: + + * LICENSE: + * license/LICENSE.log4j.txt (Apache License 2.0) + * HOMEPAGE: + * http://logging.apache.org/log4j/ + +This product optionally depends on 'JBoss Logging', a logging framework, +which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://anonsvn.jboss.org/repos/common/common-logging-spi/ + +This product optionally depends on 'Apache Felix', an open source OSGi +framework implementation, which can be obtained at: + + * LICENSE: + * license/LICENSE.felix.txt (Apache License 2.0) + * HOMEPAGE: + * http://felix.apache.org/ diff --git a/x-pack/plugin/inference/licenses/reactive-streams-LICENSE.txt b/x-pack/plugin/inference/licenses/reactive-streams-LICENSE.txt new file mode 100644 index 0000000000000..1e141c13ddba2 --- /dev/null +++ b/x-pack/plugin/inference/licenses/reactive-streams-LICENSE.txt @@ -0,0 +1,7 @@ +MIT No Attribution + +Copyright 2014 Reactive Streams + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom 
the Software is furnished to do so. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/x-pack/plugin/inference/licenses/reactive-streams-NOTICE.txt b/x-pack/plugin/inference/licenses/reactive-streams-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/slf4j-LICENSE.txt b/x-pack/plugin/inference/licenses/slf4j-LICENSE.txt new file mode 100644 index 0000000000000..a51675a21c10f --- /dev/null +++ b/x-pack/plugin/inference/licenses/slf4j-LICENSE.txt @@ -0,0 +1,23 @@ +Copyright (c) 2004-2022 QOS.ch Sarl (Switzerland) +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + diff --git a/x-pack/plugin/inference/licenses/slf4j-NOTICE.txt b/x-pack/plugin/inference/licenses/slf4j-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 49abe14dbf302..53cb6ac154ced 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -23,10 +23,15 @@ requires com.google.auth; requires com.google.api.client; requires com.google.gson; - requires aws.java.sdk.bedrockruntime; - requires aws.java.sdk.core; - requires com.fasterxml.jackson.databind; - requires org.joda.time; + requires software.amazon.awssdk.services.bedrockruntime; + requires software.amazon.awssdk.utils; + requires software.amazon.awssdk.core; + requires software.amazon.awssdk.auth; + requires software.amazon.awssdk.regions; + requires software.amazon.awssdk.http.nio.netty; + requires software.amazon.awssdk.profiles; + requires org.slf4j; + requires software.amazon.awssdk.retries.api; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java index 812e76129c420..23b6884ddc33a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java @@ -7,10 +7,10 @@ package org.elasticsearch.xpack.inference.external.amazonbedrock; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; -import com.amazonaws.services.bedrockruntime.model.ConverseResult; -import com.amazonaws.services.bedrockruntime.model.InvokeModelRequest; -import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; @@ -18,9 +18,9 @@ import java.time.Instant; public interface AmazonBedrockClient { - void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException; + void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException; - void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) + void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) throws ElasticsearchException; boolean isExpired(Instant currentTimestampMs); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java index c3d458925268c..b1486f4995b84 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java @@ -7,16 +7,18 @@ package org.elasticsearch.xpack.inference.external.amazonbedrock; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.services.bedrockruntime.AmazonBedrockRuntimeAsync; -import com.amazonaws.services.bedrockruntime.AmazonBedrockRuntimeAsyncClientBuilder; -import com.amazonaws.services.bedrockruntime.model.AmazonBedrockRuntimeException; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; -import com.amazonaws.services.bedrockruntime.model.ConverseResult; -import com.amazonaws.services.bedrockruntime.model.InvokeModelRequest; -import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.profiles.ProfileFile; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.SpecialPermission; @@ -24,25 +26,33 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; 
-import org.elasticsearch.xpack.core.common.socket.SocketAccess; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; +import org.slf4j.LoggerFactory; import java.security.AccessController; import java.security.PrivilegedExceptionAction; import java.time.Duration; import java.time.Instant; import java.util.Objects; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; /** * Not marking this as "final" so we can subclass it for mocking */ public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient { + static { + // we need to load SLF4j on this classloader to pick up the imported SLF4j-2.x + // otherwise, software.amazon.awssdk:netty-nio-client loads on ExtendedPluginsClassLoader and fails to find the classes + LoggerFactory.getLogger(AmazonBedrockInferenceClient.class); + } + // package-private for testing static final int CLIENT_CACHE_EXPIRY_MINUTES = 5; - private static final int DEFAULT_CLIENT_TIMEOUT_MS = 10000; + private static final Duration DEFAULT_CLIENT_TIMEOUT_MS = Duration.ofMillis(10000); - private final AmazonBedrockRuntimeAsync internalClient; + private final BedrockRuntimeAsyncClient internalClient; private volatile Instant expiryTimestamp; public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout) { @@ -60,68 +70,76 @@ protected AmazonBedrockInferenceClient(AmazonBedrockModel model, @Nullable TimeV } @Override - public void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException { + public void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException { try { - var responseFuture = internalClient.converseAsync(converseRequest); + var responseFuture = internalClient.converse(converseRequest); responseListener.onResponse(responseFuture.get()); - } catch (AmazonBedrockRuntimeException amazonBedrockRuntimeException) { - 
responseListener.onFailure( + } catch (Exception e) { + onFailure(responseListener, e, "converse"); + } + } + + private void onFailure(ActionListener listener, Throwable t, String method) { + var unwrappedException = t; + if (t instanceof CompletionException || t instanceof ExecutionException) { + unwrappedException = t.getCause() != null ? t.getCause() : t; + } + + if (unwrappedException instanceof BedrockRuntimeException amazonBedrockRuntimeException) { + listener.onFailure( new ElasticsearchException( - Strings.format("AmazonBedrock converse failure: [%s]", amazonBedrockRuntimeException.getMessage()), + Strings.format("AmazonBedrock %s failure: [%s]", method, amazonBedrockRuntimeException.getMessage()), amazonBedrockRuntimeException ) ); - } catch (ElasticsearchException elasticsearchException) { - // just throw the exception if we have one - responseListener.onFailure(elasticsearchException); - } catch (Exception e) { - responseListener.onFailure(new ElasticsearchException("Amazon Bedrock client converse call failed", e)); + } else if (unwrappedException instanceof ElasticsearchException elasticsearchException) { + listener.onFailure(elasticsearchException); + } else { + listener.onFailure(new ElasticsearchException(Strings.format("Amazon Bedrock %s call failed", method), unwrappedException)); } } @Override - public void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) + public void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) throws ElasticsearchException { try { - var responseFuture = internalClient.invokeModelAsync(invokeModelRequest); + var responseFuture = internalClient.invokeModel(invokeModelRequest); responseListener.onResponse(responseFuture.get()); - } catch (AmazonBedrockRuntimeException amazonBedrockRuntimeException) { - responseListener.onFailure( - new ElasticsearchException( - Strings.format("AmazonBedrock invoke model failure: [%s]", 
amazonBedrockRuntimeException.getMessage()), - amazonBedrockRuntimeException - ) - ); - } catch (ElasticsearchException elasticsearchException) { - // just throw the exception if we have one - responseListener.onFailure(elasticsearchException); } catch (Exception e) { - responseListener.onFailure(new ElasticsearchException(e)); + onFailure(responseListener, e, "invoke model"); } } // allow this to be overridden for test mocks - protected AmazonBedrockRuntimeAsync createAmazonBedrockClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + protected BedrockRuntimeAsyncClient createAmazonBedrockClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { var secretSettings = model.getSecretSettings(); - var credentials = new BasicAWSCredentials(secretSettings.accessKey.toString(), secretSettings.secretKey.toString()); - var credentialsProvider = new AWSStaticCredentialsProvider(credentials); - var clientConfig = timeout == null - ? new ClientConfiguration().withConnectionTimeout(DEFAULT_CLIENT_TIMEOUT_MS) - : new ClientConfiguration().withConnectionTimeout((int) timeout.millis()); var serviceSettings = model.getServiceSettings(); try { SpecialPermission.check(); - AmazonBedrockRuntimeAsyncClientBuilder builder = AccessController.doPrivileged( - (PrivilegedExceptionAction) () -> AmazonBedrockRuntimeAsyncClientBuilder.standard() - .withCredentials(credentialsProvider) - .withRegion(serviceSettings.region()) - .withClientConfiguration(clientConfig) - ); - - return SocketAccess.doPrivileged(builder::build); - } catch (AmazonBedrockRuntimeException amazonBedrockRuntimeException) { + return AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + var credentials = AwsBasicCredentials.create(secretSettings.accessKey.toString(), secretSettings.secretKey.toString()); + var credentialsProvider = StaticCredentialsProvider.create(credentials); + var clientConfig = timeout == null + ? 
NettyNioAsyncHttpClient.builder().connectionTimeout(DEFAULT_CLIENT_TIMEOUT_MS) + : NettyNioAsyncHttpClient.builder().connectionTimeout(Duration.ofMillis(timeout.millis())); + var override = ClientOverrideConfiguration.builder() + // disable profileFile, user credentials will always come from the configured Model Secrets + .defaultProfileFileSupplier(ProfileFile.aggregator()::build) + .defaultProfileFile(ProfileFile.aggregator().build()) + // each model request retries at most once, limit the impact a request can have on other request's availability + .retryPolicy(retryPolicy -> retryPolicy.numRetries(1)) + .retryStrategy(retryStrategy -> retryStrategy.maxAttempts(1)) + .build(); + return BedrockRuntimeAsyncClient.builder() + .credentialsProvider(credentialsProvider) + .region(Region.of(serviceSettings.region())) + .httpClientBuilder(clientConfig) + .overrideConfiguration(override) + .build(); + }); + } catch (BedrockRuntimeException amazonBedrockRuntimeException) { throw new ElasticsearchException( Strings.format("failed to create AmazonBedrockRuntime client: [%s]", amazonBedrockRuntimeException.getMessage()), amazonBedrockRuntimeException @@ -161,6 +179,6 @@ public int hashCode() { // make this package-private so only the cache can close it @Override void close() { - internalClient.shutdown(); + internalClient.close(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java index e245365c214af..21e5cfaf211e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java @@ -7,8 +7,6 @@ package 
org.elasticsearch.xpack.inference.external.amazonbedrock; -import com.amazonaws.http.IdleConnectionReaper; - import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; @@ -114,10 +112,6 @@ private void releaseCachedClients() { } finally { cacheLock.writeLock().unlock(); } - - // shutdown IdleConnectionReaper background thread - // it will be restarted on new client usage - IdleConnectionReaper.shutdown(); } // used for testing diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java deleted file mode 100644 index 83ebcb4563a8c..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock; - -import com.fasterxml.jackson.core.JsonGenerator; - -import java.io.IOException; - -/** - * This is needed as the input for the Amazon Bedrock SDK does not like - * the formatting of XContent JSON output - */ -public interface AmazonBedrockJsonWriter { - JsonGenerator writeJson(JsonGenerator generator) throws IOException; -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java index 6e2f2f6702005..aff01316838f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java @@ -7,8 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; -import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import org.elasticsearch.core.Nullable; @@ -29,35 +28,33 @@ public record AmazonBedrockAI21LabsCompletionRequestEntity( } @Override - public ConverseRequest addMessages(ConverseRequest request) { - return request.withMessages(getConverseMessageList(messages)); + public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { + return request.messages(getConverseMessageList(messages)); } @Override - public ConverseRequest addInferenceConfig(ConverseRequest request) { + public ConverseRequest.Builder 
addInferenceConfig(ConverseRequest.Builder request) { if (temperature == null && topP == null && maxTokenCount == null) { return request; } - InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + return request.inferenceConfig(config -> { + if (temperature != null) { + config.temperature(temperature.floatValue()); + } - if (temperature != null) { - inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); - } - - if (topP != null) { - inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); - } - - if (maxTokenCount != null) { - inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); - } + if (topP != null) { + config.topP(topP.floatValue()); + } - return request.withInferenceConfig(inferenceConfig); + if (maxTokenCount != null) { + config.maxTokens(maxTokenCount); + } + }); } @Override - public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { return request; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java index a8b0032af09c5..540012c221192 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java @@ -7,8 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; -import 
com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; @@ -31,40 +30,38 @@ public record AmazonBedrockAnthropicCompletionRequestEntity( } @Override - public ConverseRequest addMessages(ConverseRequest request) { - return request.withMessages(getConverseMessageList(messages)); + public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { + return request.messages(getConverseMessageList(messages)); } @Override - public ConverseRequest addInferenceConfig(ConverseRequest request) { + public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { if (temperature == null && topP == null && maxTokenCount == null) { return request; } - InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + return request.inferenceConfig(config -> { + if (temperature != null) { + config.temperature(temperature.floatValue()); + } - if (temperature != null) { - inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); - } - - if (topP != null) { - inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); - } - - if (maxTokenCount != null) { - inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); - } + if (topP != null) { + config.topP(topP.floatValue()); + } - return request.withInferenceConfig(inferenceConfig); + if (maxTokenCount != null) { + config.maxTokens(maxTokenCount); + } + }); } @Override - public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { if (topK == null) { return request; } String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); - return request.withAdditionalModelResponseFieldPaths(topKField); + return request.additionalModelResponseFieldPaths(topKField); } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java index f02f05f2d3b17..61e0504732462 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; @@ -52,11 +52,11 @@ public TaskType taskType() { } private ConverseRequest getConverseRequest() { - var converseRequest = new ConverseRequest().withModelId(amazonBedrockModel.model()); + var converseRequest = ConverseRequest.builder().modelId(amazonBedrockModel.model()); converseRequest = requestEntity.addMessages(converseRequest); converseRequest = requestEntity.addInferenceConfig(converseRequest); converseRequest = requestEntity.addAdditionalModelFields(converseRequest); - return converseRequest; + return converseRequest.build(); } public void executeChatCompletionRequest( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java index 17a264ef820ff..f1ae04ad39516 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java @@ -7,8 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; -import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; @@ -31,40 +30,38 @@ public record AmazonBedrockCohereCompletionRequestEntity( } @Override - public ConverseRequest addMessages(ConverseRequest request) { - return request.withMessages(getConverseMessageList(messages)); + public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { + return request.messages(getConverseMessageList(messages)); } @Override - public ConverseRequest addInferenceConfig(ConverseRequest request) { + public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { if (temperature == null && topP == null && maxTokenCount == null) { return request; } - InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + return request.inferenceConfig(config -> { + if (temperature != null) { + config.temperature(temperature.floatValue()); + } - if (temperature != null) { - inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); - } - - if (topP != null) { - inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); - } - - if (maxTokenCount != null) { - inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); - } + if (topP != null) { + config.topP(topP.floatValue()); + } - return request.withInferenceConfig(inferenceConfig); + if (maxTokenCount != 
null) { + config.maxTokens(maxTokenCount); + } + }); } @Override - public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { if (topK == null) { return request; } String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); - return request.withAdditionalModelResponseFieldPaths(topKField); + return request.additionalModelResponseFieldPaths(topKField); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java index fbd55e76e509b..d8e9fa43797cd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java @@ -7,12 +7,12 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; public interface AmazonBedrockConverseRequestEntity { - ConverseRequest addMessages(ConverseRequest request); + ConverseRequest.Builder addMessages(ConverseRequest.Builder request); - ConverseRequest addInferenceConfig(ConverseRequest request); + ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request); - ConverseRequest addAdditionalModelFields(ConverseRequest request); + ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java index 2cfb56a94b319..22e0d26a315a7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java @@ -7,23 +7,19 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ContentBlock; -import com.amazonaws.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.Message; -import java.util.ArrayList; import java.util.List; import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest.USER_ROLE; public final class AmazonBedrockConverseUtils { - public static List getConverseMessageList(List messages) { - List messageList = new ArrayList<>(); - for (String message : messages) { - var messageContent = new ContentBlock().withText(message); - var returnMessage = (new Message()).withRole(USER_ROLE).withContent(messageContent); - messageList.add(returnMessage); - } - return messageList; + public static List getConverseMessageList(List texts) { + return texts.stream() + .map(text -> ContentBlock.builder().text(text).build()) + .map(content -> Message.builder().role(USER_ROLE).content(content).build()) + .toList(); } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java index cdabdd4cbebff..c21791ced02cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java @@ -7,8 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; -import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import org.elasticsearch.core.Nullable; @@ -29,35 +28,33 @@ public record AmazonBedrockMetaCompletionRequestEntity( } @Override - public ConverseRequest addMessages(ConverseRequest request) { - return request.withMessages(getConverseMessageList(messages)); + public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { + return request.messages(getConverseMessageList(messages)); } @Override - public ConverseRequest addInferenceConfig(ConverseRequest request) { + public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { if (temperature == null && topP == null && maxTokenCount == null) { return request; } - InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + return request.inferenceConfig(config -> { + if (temperature != null) { + config.temperature(temperature.floatValue()); + } - if (temperature != null) { - inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); - } - - if (topP 
!= null) { - inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); - } - - if (maxTokenCount != null) { - inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); - } + if (topP != null) { + config.topP(topP.floatValue()); + } - return request.withInferenceConfig(inferenceConfig); + if (maxTokenCount != null) { + config.maxTokens(maxTokenCount); + } + }); } @Override - public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { return request; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java index c68eaa1b81f54..15931674cbabb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java @@ -7,8 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; -import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; @@ -31,40 +30,38 @@ public record AmazonBedrockMistralCompletionRequestEntity( } @Override - public ConverseRequest addMessages(ConverseRequest request) { - return request.withMessages(getConverseMessageList(messages)); + public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { 
+ return request.messages(getConverseMessageList(messages)); } @Override - public ConverseRequest addInferenceConfig(ConverseRequest request) { + public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { if (temperature == null && topP == null && maxTokenCount == null) { return request; } - InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + return request.inferenceConfig(config -> { + if (temperature != null) { + config.temperature(temperature.floatValue()); + } - if (temperature != null) { - inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); - } - - if (topP != null) { - inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); - } - - if (maxTokenCount != null) { - inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); - } + if (topP != null) { + config.topP(topP.floatValue()); + } - return request.withInferenceConfig(inferenceConfig); + if (maxTokenCount != null) { + config.maxTokens(maxTokenCount); + } + }); } @Override - public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { if (topK == null) { return request; } String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); - return request.withAdditionalModelResponseFieldPaths(topKField); + return request.additionalModelResponseFieldPaths(topKField); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java index d56035b80e9ef..e267592dfd0ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java @@ -7,8 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; -import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import org.elasticsearch.core.Nullable; @@ -29,35 +28,33 @@ public record AmazonBedrockTitanCompletionRequestEntity( } @Override - public ConverseRequest addMessages(ConverseRequest request) { - return request.withMessages(getConverseMessageList(messages)); + public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { + return request.messages(getConverseMessageList(messages)); } @Override - public ConverseRequest addInferenceConfig(ConverseRequest request) { + public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { if (temperature == null && topP == null && maxTokenCount == null) { return request; } - InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + return request.inferenceConfig(config -> { + if (temperature != null) { + config.temperature(temperature.floatValue()); + } - if (temperature != null) { - inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); - } - - if (topP != null) { - inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); - } - - if (maxTokenCount != null) { - inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); - } + if (topP != null) { + config.topP(topP.floatValue()); + } - return request.withInferenceConfig(inferenceConfig); + if (maxTokenCount != null) { + config.maxTokens(maxTokenCount); + } + }); } @Override - public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + public ConverseRequest.Builder 
addAdditionalModelFields(ConverseRequest.Builder request) { return request; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java index 96d3b3a3cc057..7d7cba53b5b41 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java @@ -7,8 +7,9 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; -import com.amazonaws.services.bedrockruntime.model.InvokeModelRequest; -import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.Nullable; @@ -35,7 +36,7 @@ public class AmazonBedrockEmbeddingsRequest extends AmazonBedrockRequest { private final Truncator truncator; private final Truncator.TruncationResult truncationResult; private final AmazonBedrockProvider provider; - private ActionListener listener = null; + private ActionListener listener = null; public AmazonBedrockEmbeddingsRequest( Truncator truncator, @@ -62,10 +63,10 @@ protected void executeRequest(AmazonBedrockBaseClient client) { var jsonBuilder = new AmazonBedrockJsonBuilder(requestEntity); var bodyAsString = jsonBuilder.getStringContent(); - var charset = StandardCharsets.UTF_8; - var bodyBuffer = charset.encode(bodyAsString); - - var invokeModelRequest = new 
InvokeModelRequest().withModelId(embeddingsModel.model()).withBody(bodyBuffer); + var invokeModelRequest = InvokeModelRequest.builder() + .modelId(embeddingsModel.model()) + .body(SdkBytes.fromString(bodyAsString, StandardCharsets.UTF_8)) + .build(); SocketAccess.doPrivileged(() -> client.invokeModel(invokeModelRequest, listener)); } catch (IOException e) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java index 5b3872e2c416a..19bd14dfa6f88 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java @@ -7,7 +7,8 @@ package org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.inference.InferenceServiceResults; @@ -16,13 +17,11 @@ import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponse; -import java.util.ArrayList; - public class AmazonBedrockChatCompletionResponse extends AmazonBedrockResponse { - private final ConverseResult result; + private final ConverseResponse result; - public AmazonBedrockChatCompletionResponse(ConverseResult responseResult) { + public 
AmazonBedrockChatCompletionResponse(ConverseResponse responseResult) { this.result = responseResult; } @@ -35,14 +34,14 @@ public InferenceServiceResults accept(AmazonBedrockRequest request) { throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]"); } - public static ChatCompletionResults fromResponse(ConverseResult response) { - var responseMessage = response.getOutput().getMessage(); - - var messageContents = responseMessage.getContent(); - var resultTexts = new ArrayList(); - for (var messageContent : messageContents) { - resultTexts.add(new ChatCompletionResults.Result(messageContent.getText())); - } + public static ChatCompletionResults fromResponse(ConverseResponse response) { + var resultTexts = response.output() + .message() + .content() + .stream() + .map(ContentBlock::text) + .map(ChatCompletionResults.Result::new) + .toList(); return new ChatCompletionResults(resultTexts); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java index a24f54c50eef3..bf4279fe3cc73 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; import org.elasticsearch.inference.InferenceServiceResults; import 
org.elasticsearch.xpack.inference.external.http.HttpResult; @@ -18,7 +18,7 @@ public class AmazonBedrockChatCompletionResponseHandler extends AmazonBedrockResponseHandler { - private ConverseResult responseResult; + private ConverseResponse responseResult; public AmazonBedrockChatCompletionResponseHandler() {} @@ -33,7 +33,7 @@ public String getRequestType() { return "Amazon Bedrock Chat Completion"; } - public void acceptChatCompletionResponseObject(ConverseResult response) { + public void acceptChatCompletionResponseObject(ConverseResponse response) { this.responseResult = response; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java index dfdf764871ccf..a56b67e80d616 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; @@ -17,7 +17,7 @@ import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseListener; -public class AmazonBedrockChatCompletionResponseListener extends AmazonBedrockResponseListener implements 
ActionListener { +public class AmazonBedrockChatCompletionResponseListener extends AmazonBedrockResponseListener implements ActionListener { public AmazonBedrockChatCompletionResponseListener( AmazonBedrockChatCompletionRequest request, @@ -28,7 +28,7 @@ public AmazonBedrockChatCompletionResponseListener( } @Override - public void onResponse(ConverseResult result) { + public void onResponse(ConverseResponse result) { ((AmazonBedrockChatCompletionResponseHandler) responseHandler).acceptChatCompletionResponseObject(result); inferenceResultsListener.onResponse(responseHandler.parseResult(request, (HttpResult) null)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java index 83fa790acbe68..1848e082dec46 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings; -import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; @@ -33,9 +33,9 @@ public class AmazonBedrockEmbeddingsResponse extends AmazonBedrockResponse { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Amazon Bedrock embeddings response"; - private final InvokeModelResult result; + private final InvokeModelResponse result; - public 
AmazonBedrockEmbeddingsResponse(InvokeModelResult invokeModelResult) { + public AmazonBedrockEmbeddingsResponse(InvokeModelResponse invokeModelResult) { this.result = invokeModelResult; } @@ -48,9 +48,9 @@ public InferenceServiceResults accept(AmazonBedrockRequest request) { throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]"); } - public static InferenceTextEmbeddingFloatResults fromResponse(InvokeModelResult response, AmazonBedrockProvider provider) { + public static InferenceTextEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) { var charset = StandardCharsets.UTF_8; - var bodyText = String.valueOf(charset.decode(response.getBody())); + var bodyText = String.valueOf(charset.decode(response.body().asByteBuffer())); var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java index a3fb68ee23486..8bbf09c968607 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings; -import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.inference.InferenceServiceResults; import 
org.elasticsearch.xpack.inference.external.http.HttpResult; @@ -18,7 +18,7 @@ public class AmazonBedrockEmbeddingsResponseHandler extends AmazonBedrockResponseHandler { - private InvokeModelResult invokeModelResult; + private InvokeModelResponse invokeModelResult; @Override public InferenceServiceResults parseResult(Request request, HttpResult result) throws RetryException { @@ -31,7 +31,7 @@ public String getRequestType() { return "Amazon Bedrock Embeddings"; } - public void acceptEmbeddingsResult(InvokeModelResult result) { + public void acceptEmbeddingsResult(InvokeModelResponse result) { this.invokeModelResult = result; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java index d93e8fd7bb132..415f1ae54b00d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings; -import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; @@ -16,7 +16,7 @@ import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseListener; -public class AmazonBedrockEmbeddingsResponseListener extends 
AmazonBedrockResponseListener implements ActionListener { +public class AmazonBedrockEmbeddingsResponseListener extends AmazonBedrockResponseListener implements ActionListener { public AmazonBedrockEmbeddingsResponseListener( AmazonBedrockEmbeddingsRequest request, @@ -27,8 +27,8 @@ public AmazonBedrockEmbeddingsResponseListener( } @Override - public void onResponse(InvokeModelResult result) { - ((AmazonBedrockEmbeddingsResponseHandler) responseHandler).acceptEmbeddingsResult(result); + public void onResponse(InvokeModelResponse response) { + ((AmazonBedrockEmbeddingsResponseHandler) responseHandler).acceptEmbeddingsResult(response); inferenceResultsListener.onResponse(responseHandler.parseResult(request, (HttpResult) null)); } diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy index a56e5401df4a0..8ec8ff9ad4ddc 100644 --- a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy +++ b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy @@ -21,4 +21,9 @@ grant { // gcs client opens socket connections for to access repository // also, AWS Bedrock client opens socket connections and needs resolve for to access to resources permission java.net.SocketPermission "*", "connect,resolve"; + + // AWS Clients always try to access the credentials and config files, even if we configure otherwise + permission java.io.FilePermission "${user.home}/.aws/credentials", "read"; + permission java.io.FilePermission "${user.home}/.aws/config", "read"; + permission java.util.PropertyPermission "http.proxyHost", "read"; }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java index 9326d39cb657c..8f09c53c99366 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java @@ -7,11 +7,12 @@ package org.elasticsearch.xpack.inference.external.amazonbedrock; -import com.amazonaws.services.bedrockruntime.model.ContentBlock; -import com.amazonaws.services.bedrockruntime.model.ConverseOutput; -import com.amazonaws.services.bedrockruntime.model.ConverseResult; -import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; -import com.amazonaws.services.bedrockruntime.model.Message; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseOutput; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.Message; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.support.PlainActionFuture; @@ -28,9 +29,8 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; -import java.nio.CharBuffer; import java.nio.charset.CharacterCodingException; -import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.List; import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator; @@ -139,18 +139,19 @@ public void testExecute_FailsProperly_WithElasticsearchException() { assertThat(exceptionThrown.getCause().getMessage(), containsString("test exception")); } - public static ConverseResult getTestConverseResult(String resultText) { - var message = new 
Message().withContent(new ContentBlock().withText(resultText)); - var converseOutput = new ConverseOutput().withMessage(message); - return new ConverseResult().withOutput(converseOutput); + public static ConverseResponse getTestConverseResult(String resultText) { + return ConverseResponse.builder() + .output( + ConverseOutput.builder().message(Message.builder().content(ContentBlock.builder().text(resultText).build()).build()).build() + ) + .build(); } - public static InvokeModelResult getTestInvokeResult(String resultJson) throws CharacterCodingException { - var result = new InvokeModelResult(); - result.setContentType("application/json"); - var encoder = Charset.forName("UTF-8").newEncoder(); - result.setBody(encoder.encode(CharBuffer.wrap(resultJson))); - return result; + public static InvokeModelResponse getTestInvokeResult(String resultJson) { + return InvokeModelResponse.builder() + .contentType("application/json") + .body(SdkBytes.fromString(resultJson, StandardCharsets.UTF_8)) + .build(); } public static final String TEST_AMAZON_TITAN_EMBEDDINGS_RESULT = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java index 912967a9012d7..d8164b3552da2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.inference.external.amazonbedrock; -import com.amazonaws.services.bedrockruntime.model.ConverseResult; -import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import 
software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.core.Nullable; @@ -18,27 +18,27 @@ import java.io.IOException; public class AmazonBedrockMockClientCache implements AmazonBedrockClientCache { - private ConverseResult converseResult = null; - private InvokeModelResult invokeModelResult = null; + private ConverseResponse converseResponse = null; + private InvokeModelResponse invokeModelResponse = null; private ElasticsearchException exceptionToThrow = null; public AmazonBedrockMockClientCache() {} public AmazonBedrockMockClientCache( - @Nullable ConverseResult converseResult, - @Nullable InvokeModelResult invokeModelResult, + @Nullable ConverseResponse converseResponse, + @Nullable InvokeModelResponse invokeModelResponse, @Nullable ElasticsearchException exceptionToThrow ) { - this.converseResult = converseResult; - this.invokeModelResult = invokeModelResult; + this.converseResponse = converseResponse; + this.invokeModelResponse = invokeModelResponse; this.exceptionToThrow = exceptionToThrow; } @Override public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, TimeValue timeout) { - var client = (AmazonBedrockMockInferenceClient) AmazonBedrockMockInferenceClient.create(model, timeout); - client.setConverseResult(converseResult); - client.setInvokeModelResult(invokeModelResult); + var client = AmazonBedrockMockInferenceClient.create(model, timeout); + client.setConverseResponse(converseResponse); + client.setInvokeModelResponse(invokeModelResponse); client.setExceptionToThrow(exceptionToThrow); return client; } @@ -48,12 +48,12 @@ public void close() throws IOException { // nothing to do } - public void setConverseResult(ConverseResult converseResult) { - this.converseResult = converseResult; + public void setConverseResponse(ConverseResponse converseResponse) { + this.converseResponse = converseResponse; } - public void 
setInvokeModelResult(InvokeModelResult invokeModelResult) { - this.invokeModelResult = invokeModelResult; + public void setInvokeModelResponse(InvokeModelResponse invokeModelResponse) { + this.invokeModelResponse = invokeModelResponse; } public void setExceptionToThrow(ElasticsearchException exceptionToThrow) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java index b0df8a40e2551..f4fbd5eb725ea 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.inference.external.amazonbedrock; -import com.amazonaws.services.bedrockruntime.model.ConverseResult; -import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; @@ -60,13 +60,13 @@ protected AmazonBedrockExecutor createExecutor( private void setCacheResult() { var mockCache = (AmazonBedrockMockClientCache) this.clientCache; var result = results.remove(); - if (result instanceof ConverseResult converseResult) { - mockCache.setConverseResult(converseResult); + if (result instanceof ConverseResponse converseResponse) { + mockCache.setConverseResponse(converseResponse); return; } - if (result instanceof InvokeModelResult invokeModelResult) { - mockCache.setInvokeModelResult(invokeModelResult); + if (result instanceof InvokeModelResponse 
invokeModelResponse) { + mockCache.setInvokeModelResponse(invokeModelResponse); return; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java index dcbf8dfcbff01..5584e90b3264d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java @@ -7,33 +7,28 @@ package org.elasticsearch.xpack.inference.external.amazonbedrock; -import com.amazonaws.services.bedrockruntime.AmazonBedrockRuntimeAsync; -import com.amazonaws.services.bedrockruntime.model.ConverseResult; -import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import java.util.concurrent.CompletableFuture; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; public class 
AmazonBedrockMockInferenceClient extends AmazonBedrockInferenceClient { - private ConverseResult converseResult = null; - private InvokeModelResult invokeModelResult = null; - private ElasticsearchException exceptionToThrow = null; + private CompletableFuture converseResponseFuture = CompletableFuture.completedFuture(null); + private CompletableFuture invokeModelResponseFuture = CompletableFuture.completedFuture(null); - private Future converseResultFuture = new MockConverseResultFuture(); - private Future invokeModelResultFuture = new MockInvokeResultFuture(); - - public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout) { + public static AmazonBedrockMockInferenceClient create(AmazonBedrockModel model, @Nullable TimeValue timeout) { return new AmazonBedrockMockInferenceClient(model, timeout); } @@ -42,92 +37,31 @@ protected AmazonBedrockMockInferenceClient(AmazonBedrockModel model, @Nullable T } public void setExceptionToThrow(ElasticsearchException exceptionToThrow) { - this.exceptionToThrow = exceptionToThrow; + if (exceptionToThrow != null) { + this.converseResponseFuture = new CompletableFuture<>(); + this.converseResponseFuture.completeExceptionally(exceptionToThrow); + this.invokeModelResponseFuture = new CompletableFuture<>(); + this.invokeModelResponseFuture.completeExceptionally(exceptionToThrow); + } } - public void setConverseResult(ConverseResult result) { - this.converseResult = result; + public void setConverseResponse(ConverseResponse result) { + this.converseResponseFuture = CompletableFuture.completedFuture(result); } - public void setInvokeModelResult(InvokeModelResult result) { - this.invokeModelResult = result; + public void setInvokeModelResponse(InvokeModelResponse result) { + this.invokeModelResponseFuture = CompletableFuture.completedFuture(result); } @Override - protected AmazonBedrockRuntimeAsync createAmazonBedrockClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { - var 
runtimeClient = mock(AmazonBedrockRuntimeAsync.class); - doAnswer(invocation -> invokeModelResultFuture).when(runtimeClient).invokeModelAsync(any()); - doAnswer(invocation -> converseResultFuture).when(runtimeClient).converseAsync(any()); + protected BedrockRuntimeAsyncClient createAmazonBedrockClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + var runtimeClient = mock(BedrockRuntimeAsyncClient.class); + doAnswer(invocation -> invokeModelResponseFuture).when(runtimeClient).invokeModel(any(InvokeModelRequest.class)); + doAnswer(invocation -> converseResponseFuture).when(runtimeClient).converse(any(ConverseRequest.class)); return runtimeClient; } @Override void close() {} - - private class MockConverseResultFuture implements Future { - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public boolean isDone() { - return false; - } - - @Override - public ConverseResult get() throws InterruptedException, ExecutionException { - if (exceptionToThrow != null) { - throw exceptionToThrow; - } - return converseResult; - } - - @Override - public ConverseResult get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { - if (exceptionToThrow != null) { - throw exceptionToThrow; - } - return converseResult; - } - } - - private class MockInvokeResultFuture implements Future { - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public boolean isDone() { - return false; - } - - @Override - public InvokeModelResult get() throws InterruptedException, ExecutionException { - if (exceptionToThrow != null) { - throw exceptionToThrow; - } - return invokeModelResult; - } - - @Override - public InvokeModelResult get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, 
TimeoutException { - if (exceptionToThrow != null) { - throw exceptionToThrow; - } - return invokeModelResult; - } - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java index b91aab5410048..10c8943c75f6c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java @@ -26,9 +26,9 @@ public class AmazonBedrockAI21LabsCompletionRequestEntityTests extends ESTestCas public void testRequestEntity_CreatesProperRequest() { var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); @@ -38,7 +38,7 @@ public void testRequestEntity_CreatesProperRequest() { public void testRequestEntity_CreatesProperRequest_WithTemperature() { var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), 1.0, null, null); var builtRequest = getConverseRequest("testmodel", 
request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); @@ -49,7 +49,7 @@ public void testRequestEntity_CreatesProperRequest_WithTemperature() { public void testRequestEntity_CreatesProperRequest_WithTopP() { var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, 1.0, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); @@ -60,7 +60,7 @@ public void testRequestEntity_CreatesProperRequest_WithTopP() { public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, null, 128); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java 
index 89d5fec7efba6..e8a3440a37294 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java @@ -27,9 +27,9 @@ public class AmazonBedrockAnthropicCompletionRequestEntityTests extends ESTestCa public void testRequestEntity_CreatesProperRequest() { var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); @@ -39,7 +39,7 @@ public void testRequestEntity_CreatesProperRequest() { public void testRequestEntity_CreatesProperRequest_WithTemperature() { var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); @@ -50,7 +50,7 @@ public void testRequestEntity_CreatesProperRequest_WithTemperature() { public void 
testRequestEntity_CreatesProperRequest_WithTopP() { var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); @@ -61,7 +61,7 @@ public void testRequestEntity_CreatesProperRequest_WithTopP() { public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, null, 128); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); @@ -72,7 +72,7 @@ public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { public void testRequestEntity_CreatesProperRequest_WithTopK() { var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java index 8df5c7f32e529..c8e844d000240 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java @@ -27,9 +27,9 @@ public class AmazonBedrockCohereCompletionRequestEntityTests extends ESTestCase public void testRequestEntity_CreatesProperRequest() { var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); @@ -39,7 +39,7 @@ public void testRequestEntity_CreatesProperRequest() { public void testRequestEntity_CreatesProperRequest_WithTemperature() { var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); 
assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); @@ -50,7 +50,7 @@ public void testRequestEntity_CreatesProperRequest_WithTemperature() { public void testRequestEntity_CreatesProperRequest_WithTopP() { var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); @@ -61,7 +61,7 @@ public void testRequestEntity_CreatesProperRequest_WithTopP() { public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, null, 128); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); @@ -72,7 +72,7 @@ public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { public void testRequestEntity_CreatesProperRequest_WithTopK() { var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); 
assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java index cbbe3c5554967..17c3b4488bae4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java @@ -7,70 +7,70 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import com.amazonaws.services.bedrockruntime.model.ContentBlock; -import com.amazonaws.services.bedrockruntime.model.ConverseRequest; -import com.amazonaws.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.Message; import org.elasticsearch.core.Strings; +import java.util.Collection; + public final class AmazonBedrockConverseRequestUtils { public static ConverseRequest getConverseRequest(String modelId, AmazonBedrockConverseRequestEntity requestEntity) { - var converseRequest = new ConverseRequest().withModelId(modelId); + var converseRequest = ConverseRequest.builder().modelId(modelId); converseRequest = requestEntity.addMessages(converseRequest); converseRequest = requestEntity.addInferenceConfig(converseRequest); converseRequest = 
requestEntity.addAdditionalModelFields(converseRequest); - return converseRequest; + return converseRequest.build(); } public static boolean doesConverseRequestHasMessage(ConverseRequest converseRequest, String expectedMessage) { - for (Message message : converseRequest.getMessages()) { - var content = message.getContent(); - for (ContentBlock contentBlock : content) { - if (contentBlock.getText().equals(expectedMessage)) { - return true; - } - } + if (expectedMessage == null) { + return false; } - return false; + return converseRequest.messages() + .stream() + .map(Message::content) + .flatMap(Collection::stream) + .map(ContentBlock::text) + .anyMatch(expectedMessage::equals); } public static boolean doesConverseRequestHaveAnyTemperatureInput(ConverseRequest converseRequest) { - return converseRequest.getInferenceConfig() != null - && converseRequest.getInferenceConfig().getTemperature() != null - && (converseRequest.getInferenceConfig().getTemperature().isNaN() == false); + return converseRequest.inferenceConfig() != null + && converseRequest.inferenceConfig().temperature() != null + && (converseRequest.inferenceConfig().temperature().isNaN() == false); } public static boolean doesConverseRequestHaveAnyTopPInput(ConverseRequest converseRequest) { - return converseRequest.getInferenceConfig() != null - && converseRequest.getInferenceConfig().getTopP() != null - && (converseRequest.getInferenceConfig().getTopP().isNaN() == false); + return converseRequest.inferenceConfig() != null + && converseRequest.inferenceConfig().topP() != null + && (converseRequest.inferenceConfig().topP().isNaN() == false); } public static boolean doesConverseRequestHaveAnyMaxTokensInput(ConverseRequest converseRequest) { - return converseRequest.getInferenceConfig() != null && converseRequest.getInferenceConfig().getMaxTokens() != null; + return converseRequest.inferenceConfig() != null && converseRequest.inferenceConfig().maxTokens() != null; } public static boolean 
doesConverseRequestHaveTemperatureInput(ConverseRequest converseRequest, Double temperature) { return doesConverseRequestHaveAnyTemperatureInput(converseRequest) - && converseRequest.getInferenceConfig().getTemperature().equals(temperature.floatValue()); + && converseRequest.inferenceConfig().temperature().equals(temperature.floatValue()); } public static boolean doesConverseRequestHaveTopPInput(ConverseRequest converseRequest, Double topP) { - return doesConverseRequestHaveAnyTopPInput(converseRequest) - && converseRequest.getInferenceConfig().getTopP().equals(topP.floatValue()); + return doesConverseRequestHaveAnyTopPInput(converseRequest) && converseRequest.inferenceConfig().topP().equals(topP.floatValue()); } public static boolean doesConverseRequestHaveMaxTokensInput(ConverseRequest converseRequest, Integer maxTokens) { - return doesConverseRequestHaveAnyMaxTokensInput(converseRequest) - && converseRequest.getInferenceConfig().getMaxTokens().equals(maxTokens); + return doesConverseRequestHaveAnyMaxTokensInput(converseRequest) && converseRequest.inferenceConfig().maxTokens().equals(maxTokens); } public static boolean doesConverseRequestHaveAnyTopKInput(ConverseRequest converseRequest) { - if (converseRequest.getAdditionalModelResponseFieldPaths() == null) { + if (converseRequest.additionalModelResponseFieldPaths() == null) { return false; } - for (String fieldPath : converseRequest.getAdditionalModelResponseFieldPaths()) { + for (String fieldPath : converseRequest.additionalModelResponseFieldPaths()) { if (fieldPath.contains("{\"top_k\":")) { return true; } @@ -84,7 +84,7 @@ public static boolean doesConverseRequestHaveTopKInput(ConverseRequest converseR } var checkString = Strings.format("{\"top_k\":%f}", topK.floatValue()); - for (String fieldPath : converseRequest.getAdditionalModelResponseFieldPaths()) { + for (String fieldPath : converseRequest.additionalModelResponseFieldPaths()) { if (fieldPath.contains(checkString)) { return true; } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java index fa482669a0bb2..25700f7c7aee1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java @@ -26,9 +26,9 @@ public class AmazonBedrockMetaCompletionRequestEntityTests extends ESTestCase { public void testRequestEntity_CreatesProperRequest() { var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); @@ -38,7 +38,7 @@ public void testRequestEntity_CreatesProperRequest() { public void testRequestEntity_CreatesProperRequest_WithTemperature() { var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), 1.0, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); 
assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); @@ -49,7 +49,7 @@ public void testRequestEntity_CreatesProperRequest_WithTemperature() { public void testRequestEntity_CreatesProperRequest_WithTopP() { var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, 1.0, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); @@ -60,7 +60,7 @@ public void testRequestEntity_CreatesProperRequest_WithTopP() { public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, null, 128); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java index 788625d3702b8..8e321b0cb33a7 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java @@ -27,9 +27,9 @@ public class AmazonBedrockMistralCompletionRequestEntityTests extends ESTestCase public void testRequestEntity_CreatesProperRequest() { var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); @@ -39,7 +39,7 @@ public void testRequestEntity_CreatesProperRequest() { public void testRequestEntity_CreatesProperRequest_WithTemperature() { var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); @@ -50,7 +50,7 @@ public void testRequestEntity_CreatesProperRequest_WithTemperature() { public void testRequestEntity_CreatesProperRequest_WithTopP() { var request = new 
AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); @@ -61,7 +61,7 @@ public void testRequestEntity_CreatesProperRequest_WithTopP() { public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, null, 128); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); @@ -72,7 +72,7 @@ public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { public void testRequestEntity_CreatesProperRequest_WithTopK() { var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java index 79fa387876c8b..8d1c15499bfb6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java @@ -26,9 +26,9 @@ public class AmazonBedrockTitanCompletionRequestEntityTests extends ESTestCase { public void testRequestEntity_CreatesProperRequest() { var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); @@ -38,7 +38,7 @@ public void testRequestEntity_CreatesProperRequest() { public void testRequestEntity_CreatesProperRequest_WithTemperature() { var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), 1.0, null, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); 
assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); @@ -49,7 +49,7 @@ public void testRequestEntity_CreatesProperRequest_WithTemperature() { public void testRequestEntity_CreatesProperRequest_WithTopP() { var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, 1.0, null); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); @@ -60,7 +60,7 @@ public void testRequestEntity_CreatesProperRequest_WithTopP() { public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, null, 128); var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(builtRequest.modelId(), is("testmodel")); assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); From a37344481f206612e3392acc1ffa4655d00a98ea Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:38:30 +1100 Subject: [PATCH 15/25] Mute org.elasticsearch.kibana.KibanaThreadPoolIT testBlockedThreadPoolsRejectUserRequests #113939 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index e63499aef6c18..4e3c58b50e2fb 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -301,6 +301,9 @@ tests: - class: org.elasticsearch.search.retriever.RankDocsRetrieverBuilderTests method: testRewrite issue: 
https://github.com/elastic/elasticsearch/issues/114467 +- class: org.elasticsearch.kibana.KibanaThreadPoolIT + method: testBlockedThreadPoolsRejectUserRequests + issue: https://github.com/elastic/elasticsearch/issues/113939 # Examples: # From 4bfbc4100a2bf124c23ef449027f174240e74527 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:41:32 +1100 Subject: [PATCH 16/25] Mute org.elasticsearch.ingest.geoip.DatabaseNodeServiceIT testGzippedDatabase #113752 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 4e3c58b50e2fb..284e83adcc597 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -304,6 +304,9 @@ tests: - class: org.elasticsearch.kibana.KibanaThreadPoolIT method: testBlockedThreadPoolsRejectUserRequests issue: https://github.com/elastic/elasticsearch/issues/113939 +- class: org.elasticsearch.ingest.geoip.DatabaseNodeServiceIT + method: testGzippedDatabase + issue: https://github.com/elastic/elasticsearch/issues/113752 # Examples: # From 1b8224aa24217194a59bf8060e9641cb3f701a35 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:41:33 +1100 Subject: [PATCH 17/25] Mute org.elasticsearch.xpack.rank.rrf.RRFRankClientYamlTestSuiteIT test {yaml=rrf/700_rrf_retriever_search_api_compatibility/rrf retriever with top-level collapse} #114331 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 284e83adcc597..33a965d5c7631 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -307,6 +307,9 @@ tests: - class: org.elasticsearch.ingest.geoip.DatabaseNodeServiceIT method: testGzippedDatabase issue: https://github.com/elastic/elasticsearch/issues/113752 +- class: org.elasticsearch.xpack.rank.rrf.RRFRankClientYamlTestSuiteIT + method: test {yaml=rrf/700_rrf_retriever_search_api_compatibility/rrf 
retriever with top-level collapse} + issue: https://github.com/elastic/elasticsearch/issues/114331 # Examples: # From 79d832d24a3b37df47d4864c202f25f54d6bcf70 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:42:56 +1100 Subject: [PATCH 18/25] Mute org.elasticsearch.threadpool.SimpleThreadPoolIT testThreadPoolMetrics #108320 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 33a965d5c7631..9819e424d242f 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -310,6 +310,9 @@ tests: - class: org.elasticsearch.xpack.rank.rrf.RRFRankClientYamlTestSuiteIT method: test {yaml=rrf/700_rrf_retriever_search_api_compatibility/rrf retriever with top-level collapse} issue: https://github.com/elastic/elasticsearch/issues/114331 +- class: org.elasticsearch.threadpool.SimpleThreadPoolIT + method: testThreadPoolMetrics + issue: https://github.com/elastic/elasticsearch/issues/108320 # Examples: # From d16d7b3c797f3cae9818ded82bbe4493a9ee9d9f Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Fri, 11 Oct 2024 16:27:38 +1100 Subject: [PATCH 19/25] Mute org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT test {p0=indices.split/40_routing_partition_size/nested} #113842 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 9819e424d242f..1884b914de359 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -313,6 +313,9 @@ tests: - class: org.elasticsearch.threadpool.SimpleThreadPoolIT method: testThreadPoolMetrics issue: https://github.com/elastic/elasticsearch/issues/108320 +- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT + method: test {p0=indices.split/40_routing_partition_size/nested} + issue: https://github.com/elastic/elasticsearch/issues/113842 # Examples: # From 
ee0f8ec26f9aafc7a068aa18e62498992365ecda Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Fri, 11 Oct 2024 16:27:48 +1100 Subject: [PATCH 20/25] Mute org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT test {p0=indices.split/40_routing_partition_size/more than 1} #113841 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 1884b914de359..a9be9d9164e1c 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -316,6 +316,9 @@ tests: - class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT method: test {p0=indices.split/40_routing_partition_size/nested} issue: https://github.com/elastic/elasticsearch/issues/113842 +- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT + method: test {p0=indices.split/40_routing_partition_size/more than 1} + issue: https://github.com/elastic/elasticsearch/issues/113841 # Examples: # From 49829975247a7b946512ad2961af00e586c014ee Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Fri, 11 Oct 2024 16:33:47 +1100 Subject: [PATCH 21/25] Mute org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT test {categorize.Categorize ASYNC} #113721 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index a9be9d9164e1c..65c321834a1e2 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -319,6 +319,9 @@ tests: - class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT method: test {p0=indices.split/40_routing_partition_size/more than 1} issue: https://github.com/elastic/elasticsearch/issues/113841 +- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT + method: test {categorize.Categorize ASYNC} + issue: https://github.com/elastic/elasticsearch/issues/113721 # Examples: # From 14ba77f21d776466e28028c6d882b1676bc35270 Mon Sep 17 00:00:00 
2001 From: Kostas Krikellas <131142368+kkrik-es@users.noreply.github.com> Date: Fri, 11 Oct 2024 09:14:08 +0300 Subject: [PATCH 22/25] [TEST] Add coverage for field caps and ES|QL to LogsDB QA testing (#114505) (#114562) * Add coverage for field caps and ES|QL to LogsDB QA testing * address comments * address comments * address comments (cherry picked from commit 34da953571ff70c7da680e4e08e4d4730092539c) --- .../logsdb/qa/AbstractChallengeRestTest.java | 27 +++++++ ...ardVersusLogsIndexModeChallengeRestIT.java | 80 +++++++++++++++++++ 2 files changed, 107 insertions(+) diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/AbstractChallengeRestTest.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/AbstractChallengeRestTest.java index 88a33d502633b..6464b4e966823 100644 --- a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/AbstractChallengeRestTest.java +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/AbstractChallengeRestTest.java @@ -276,6 +276,33 @@ private Response query(final SearchSourceBuilder search, final Supplier return client.performRequest(request); } + public Response esqlBaseline(final String query) throws IOException { + return esql(query, this::getBaselineDataStreamName); + } + + public Response esqlContender(final String query) throws IOException { + return esql(query, this::getContenderDataStreamName); + } + + private Response esql(final String query, final Supplier dataStreamNameSupplier) throws IOException { + final Request request = new Request("POST", "/_query"); + request.setJsonEntity("{\"query\": \"" + query.replace("$index", dataStreamNameSupplier.get()) + "\"}"); + return client.performRequest(request); + } + + public Response fieldCapsBaseline() throws IOException { + return fieldCaps(this::getBaselineDataStreamName); + } + + public Response fieldCapsContender() throws IOException { + 
return fieldCaps(this::getContenderDataStreamName); + } + + private Response fieldCaps(final Supplier dataStreamNameSupplier) throws IOException { + final Request request = new Request("GET", "/" + dataStreamNameSupplier.get() + "/_field_caps?fields=*"); + return client.performRequest(request); + } + public String getBaselineDataStreamName() { return baselineDataStreamName; } diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeChallengeRestIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeChallengeRestIT.java index 0b6cc38aff37a..43efdbdcf8b1c 100644 --- a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeChallengeRestIT.java +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeChallengeRestIT.java @@ -42,6 +42,7 @@ import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.TreeMap; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -271,6 +272,51 @@ public void testDateHistogramAggregation() throws IOException { assertTrue(matchResult.getMessage(), matchResult.isMatch()); } + public void testEsqlSource() throws IOException { + int numberOfDocuments = ESTestCase.randomIntBetween(100, 200); + final List documents = generateDocuments(numberOfDocuments); + + indexDocuments(documents); + + final String query = "FROM $index METADATA _source, _id | KEEP _source, _id | LIMIT " + numberOfDocuments; + final MatchResult matchResult = Matcher.matchSource() + .mappings(getContenderMappings(), getBaselineMappings()) + .settings(getContenderSettings(), getBaselineSettings()) + .expected(getEsqlSourceResults(esqlBaseline(query))) + .ignoringSort(true) + .isEqualTo(getEsqlSourceResults(esqlContender(query))); + 
assertTrue(matchResult.getMessage(), matchResult.isMatch()); + } + + public void testEsqlTermsAggregation() throws IOException { + int numberOfDocuments = ESTestCase.randomIntBetween(100, 200); + final List documents = generateDocuments(numberOfDocuments); + + indexDocuments(documents); + + final String query = "FROM $index | STATS count(*) BY host.name | SORT host.name | LIMIT " + numberOfDocuments; + final MatchResult matchResult = Matcher.mappings(getContenderMappings(), getBaselineMappings()) + .settings(getContenderSettings(), getBaselineSettings()) + .expected(getEsqlStatsResults(esqlBaseline(query))) + .ignoringSort(true) + .isEqualTo(getEsqlStatsResults(esqlContender(query))); + assertTrue(matchResult.getMessage(), matchResult.isMatch()); + } + + public void testFieldCaps() throws IOException { + int numberOfDocuments = ESTestCase.randomIntBetween(20, 50); + final List documents = generateDocuments(numberOfDocuments); + + indexDocuments(documents); + + final MatchResult matchResult = Matcher.mappings(getContenderMappings(), getBaselineMappings()) + .settings(getContenderSettings(), getBaselineSettings()) + .expected(getFields(fieldCapsBaseline())) + .ignoringSort(true) + .isEqualTo(getFields(fieldCapsContender())); + assertTrue(matchResult.getMessage(), matchResult.isMatch()); + } + @Override public Response indexBaselineDocuments(CheckedSupplier, IOException> documentsSupplier) throws IOException { var response = super.indexBaselineDocuments(documentsSupplier); @@ -329,6 +375,40 @@ private static List> getQueryHits(final Response response) t .toList(); } + @SuppressWarnings("unchecked") + private static Map getFields(final Response response) throws IOException { + final Map map = XContentHelper.convertToMap(XContentType.JSON.xContent(), response.getEntity().getContent(), true); + final Map fields = (Map) map.get("fields"); + assertThat(fields.size(), greaterThan(0)); + return new TreeMap<>(fields); + } + + @SuppressWarnings("unchecked") + private static 
List> getEsqlSourceResults(final Response response) throws IOException { + final Map map = XContentHelper.convertToMap(XContentType.JSON.xContent(), response.getEntity().getContent(), true); + final List> values = (List>) map.get("values"); + assertThat(values.size(), greaterThan(0)); + + // Results contain a list of [source, id] lists. + return values.stream() + .sorted(Comparator.comparingInt((List value) -> Integer.parseInt((String) value.get(1)))) + .map(value -> (Map) value.get(0)) + .toList(); + } + + @SuppressWarnings("unchecked") + private static List> getEsqlStatsResults(final Response response) throws IOException { + final Map map = XContentHelper.convertToMap(XContentType.JSON.xContent(), response.getEntity().getContent(), true); + final List> values = (List>) map.get("values"); + assertThat(values.size(), greaterThan(0)); + + // Results contain a list of [agg value, group name] lists. + return values.stream() + .sorted(Comparator.comparing((List value) -> (String) value.get(1))) + .map(value -> Map.of((String) value.get(1), value.get(0))) + .toList(); + } + @SuppressWarnings("unchecked") private static List> getAggregationBuckets(final Response response, final String aggName) throws IOException { final Map map = XContentHelper.convertToMap(XContentType.JSON.xContent(), response.getEntity().getContent(), true); From 979cbd08db973840de05c60758d320c4612ed77f Mon Sep 17 00:00:00 2001 From: Salvatore Campagna <93581129+salvatore-campagna@users.noreply.github.com> Date: Fri, 11 Oct 2024 08:45:00 +0200 Subject: [PATCH 23/25] LogsDB `host` and `timestamp` mappings tests (#114001) (#114455) Here we are testing mappings of `host` and `timestamp` fields as they are used as default fields to sort on when using LogsDB. LogsDB uses a `host.name` field mapped as a `keyword` and a `@timestamp` field, required by data streams. Some mappings throw errors as a result of incompatibilities when trying to merge object fields. Such errors are expected. 
(cherry picked from commit c4815b3416279f7d952a2904e3efb23fdc0ba755) --- x-pack/plugin/logsdb/build.gradle | 6 + .../test/10_logsdb_default_mapping.yml | 254 ++++++++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 x-pack/plugin/logsdb/src/yamlRestTest/resources/rest-api-spec/test/10_logsdb_default_mapping.yml diff --git a/x-pack/plugin/logsdb/build.gradle b/x-pack/plugin/logsdb/build.gradle index 466cf69243c8e..929d7dad2f5e6 100644 --- a/x-pack/plugin/logsdb/build.gradle +++ b/x-pack/plugin/logsdb/build.gradle @@ -23,6 +23,12 @@ base { archivesName = 'x-pack-logsdb' } +restResources { + restApi { + include 'bulk', 'search', '_common', 'indices', 'index', 'cluster', 'data_stream', 'ingest', 'cat', 'capabilities' + } +} + dependencies { compileOnly project(path: xpackModule('core')) testImplementation(testArtifact(project(xpackModule('core')))) diff --git a/x-pack/plugin/logsdb/src/yamlRestTest/resources/rest-api-spec/test/10_logsdb_default_mapping.yml b/x-pack/plugin/logsdb/src/yamlRestTest/resources/rest-api-spec/test/10_logsdb_default_mapping.yml new file mode 100644 index 0000000000000..8346221c01066 --- /dev/null +++ b/x-pack/plugin/logsdb/src/yamlRestTest/resources/rest-api-spec/test/10_logsdb_default_mapping.yml @@ -0,0 +1,254 @@ +--- +create logsdb data stream with host.name as keyword: + - requires: + cluster_features: [ "mapper.keyword_normalizer_synthetic_source" ] + reason: support for normalizer on keyword fields + + - do: + cluster.put_component_template: + name: "logsdb-mappings" + body: + template: + settings: + mode: "logsdb" + mappings: + properties: + host.name: + type: "keyword" + + - do: + indices.put_index_template: + name: "logsdb-index-template" + body: + index_patterns: ["logsdb"] + data_stream: {} + composed_of: ["logsdb-mappings"] + + - do: + indices.create_data_stream: + name: "logsdb" + + - is_true: acknowledged + +--- +create logsdb data stream with host.name as keyword and timestamp as date: + - requires: + 
cluster_features: [ "mapper.keyword_normalizer_synthetic_source" ] + reason: support for normalizer on keyword fields + + - do: + cluster.put_component_template: + name: "logsdb-mappings" + body: + template: + settings: + mode: "logsdb" + mappings: + properties: + host.name: + type: "keyword" + "@timestamp": + type: "date" + + - do: + indices.put_index_template: + name: "logsdb-index-template" + body: + index_patterns: ["logsdb"] + data_stream: {} + composed_of: ["logsdb-mappings"] + + - do: + indices.create_data_stream: + name: "logsdb" + + - is_true: acknowledged + +--- +create logsdb data stream with host as keyword: + - requires: + cluster_features: [ "mapper.keyword_normalizer_synthetic_source" ] + reason: support for normalizer on keyword fields + + - do: + cluster.put_component_template: + name: "logsdb-mappings" + body: + template: + settings: + mode: "logsdb" + mappings: + properties: + host: + type: "keyword" + + - do: + indices.put_index_template: + name: "logsdb-index-template" + body: + index_patterns: ["logsdb"] + data_stream: {} + composed_of: ["logsdb-mappings"] + + - do: + catch: bad_request + indices.create_data_stream: + name: "logsdb" + + - match: { error.type: "mapper_parsing_exception" } + - match: { error.reason: "Failed to parse mapping: can't merge a non object mapping [host] with an object mapping" } + +--- +create logsdb data stream with host as text and multi fields: + - requires: + cluster_features: [ "mapper.keyword_normalizer_synthetic_source" ] + reason: support for normalizer on keyword fields + + - do: + cluster.put_component_template: + name: "logsdb-mappings" + body: + template: + settings: + mode: "logsdb" + mappings: + properties: + host: + type: "text" + fields: + keyword: + ignore_above: 256 + type: "keyword" + "@timestamp": + type: "date" + format: "strict_date_optional_time" + + - do: + indices.put_index_template: + name: "logsdb-index-template" + body: + index_patterns: ["logsdb"] + data_stream: {} + composed_of: 
["logsdb-mappings"] + + - do: + catch: bad_request + indices.create_data_stream: + name: "logsdb" + + - match: { error.type: "mapper_parsing_exception" } + - match: { error.reason: "Failed to parse mapping: can't merge a non object mapping [host] with an object mapping" } + +--- +create logsdb data stream with host as text: + - requires: + cluster_features: ["mapper.keyword_normalizer_synthetic_source"] + reason: "Support for normalizer on keyword fields" + + - do: + cluster.put_component_template: + name: "logsdb-mappings" + body: + template: + settings: + mode: "logsdb" + mappings: + properties: + host: + type: "text" + "@timestamp": + type: "date" + format: "strict_date_optional_time" + + - do: + indices.put_index_template: + name: "logsdb-index-template" + body: + index_patterns: ["logsdb"] + data_stream: {} + composed_of: ["logsdb-mappings"] + + - do: + catch: bad_request + indices.create_data_stream: + name: "logsdb" + + - match: { error.type: "mapper_parsing_exception" } + - match: { error.reason: "Failed to parse mapping: can't merge a non object mapping [host] with an object mapping" } + +--- +create logsdb data stream with host as text and name as double: + - requires: + cluster_features: ["mapper.keyword_normalizer_synthetic_source"] + reason: "Support for normalizer on keyword fields" + + - do: + cluster.put_component_template: + name: "logsdb-mappings" + body: + template: + settings: + mode: "logsdb" + mappings: + properties: + host: + type: "text" + fields: + name: + type: "double" + "@timestamp": + type: "date" + format: "strict_date_optional_time" + + - do: + indices.put_index_template: + name: "logsdb-index-template" + body: + index_patterns: ["logsdb"] + data_stream: {} + composed_of: ["logsdb-mappings"] + + - do: + catch: bad_request + indices.create_data_stream: + name: "logsdb" + + - match: { error.type: "mapper_parsing_exception" } + - match: { error.reason: "Failed to parse mapping: can't merge a non object mapping [host] with an object 
mapping" } + +--- +create logsdb data stream with timestamp object mapping: + - requires: + cluster_features: ["mapper.keyword_normalizer_synthetic_source"] + reason: "Support for normalizer on keyword fields" + + - do: + cluster.put_component_template: + name: "logsdb-mappings" + body: + template: + settings: + mode: "logsdb" + mappings: + properties: + host: + properties: + name: + type: "keyword" + "@timestamp": + properties: + date: + type: "date" + format: "strict_date_optional_time" + + - do: + catch: bad_request + indices.put_index_template: + name: "logsdb-index-template" + body: + index_patterns: ["logsdb"] + data_stream: {} + composed_of: ["logsdb-mappings"] + + - match: { error.type: "illegal_argument_exception" } + - match: { error.reason: "composable template [logsdb-index-template] template after composition with component templates [logsdb-mappings] is invalid" } From 62d0765fa44758cf75419dfc1607f1f1ca28204a Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Fri, 11 Oct 2024 11:07:41 +0300 Subject: [PATCH 24/25] Backporting 114502 to 8.x branch (#114568) --- muted-tests.yml | 3 - .../RankDocsRetrieverBuilderTests.java | 75 +++++++++---------- 2 files changed, 37 insertions(+), 41 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 65c321834a1e2..f39e37bd202fb 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -298,9 +298,6 @@ tests: - class: org.elasticsearch.xpack.inference.InferenceRestIT method: test {p0=inference/40_semantic_text_query/Query a field that uses the default ELSER 2 endpoint} issue: https://github.com/elastic/elasticsearch/issues/114376 -- class: org.elasticsearch.search.retriever.RankDocsRetrieverBuilderTests - method: testRewrite - issue: https://github.com/elastic/elasticsearch/issues/114467 - class: org.elasticsearch.kibana.KibanaThreadPoolIT method: testBlockedThreadPoolsRejectUserRequests issue: https://github.com/elastic/elasticsearch/issues/113939 diff --git 
a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java index bcb93b100ea48..384564ac01e2a 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java @@ -25,13 +25,11 @@ import java.util.List; import java.util.function.Supplier; -import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.search.vectors.KnnSearchBuilderTests.randomVector; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.Mockito.mock; public class RankDocsRetrieverBuilderTests extends ESTestCase { @@ -48,7 +46,7 @@ private Supplier rankDocsSupplier() { return () -> rankDocs; } - private List innerRetrievers() { + private List innerRetrievers(QueryRewriteContext queryRewriteContext) throws IOException { List retrievers = new ArrayList<>(); int numRetrievers = randomIntBetween(1, 10); for (int i = 0; i < numRetrievers; i++) { @@ -56,9 +54,14 @@ private List innerRetrievers() { StandardRetrieverBuilder standardRetrieverBuilder = new StandardRetrieverBuilder(); standardRetrieverBuilder.queryBuilder = RandomQueryBuilder.createQuery(random()); if (randomBoolean()) { - standardRetrieverBuilder.preFilterQueryBuilders = preFilters(); + standardRetrieverBuilder.preFilterQueryBuilders = preFilters(queryRewriteContext); } - retrievers.add(standardRetrieverBuilder); + // RankDocsRetrieverBuilder assumes that the inner retrievers are already rewritten + StandardRetrieverBuilder rewritten = (StandardRetrieverBuilder) Rewriteable.rewrite( + standardRetrieverBuilder, + queryRewriteContext + ); + retrievers.add(rewritten); } else { KnnRetrieverBuilder 
knnRetrieverBuilder = new KnnRetrieverBuilder( randomAlphaOfLength(10), @@ -69,30 +72,40 @@ private List innerRetrievers() { randomFloat() ); if (randomBoolean()) { - knnRetrieverBuilder.preFilterQueryBuilders = preFilters(); + knnRetrieverBuilder.preFilterQueryBuilders = preFilters(queryRewriteContext); } knnRetrieverBuilder.rankDocs = rankDocsSupplier().get(); - retrievers.add(knnRetrieverBuilder); + // RankDocsRetrieverBuilder assumes that the inner retrievers are already rewritten + KnnRetrieverBuilder rewritten = (KnnRetrieverBuilder) Rewriteable.rewrite(knnRetrieverBuilder, queryRewriteContext); + retrievers.add(rewritten); } } return retrievers; } - private List preFilters() { + private List preFilters(QueryRewriteContext queryRewriteContext) throws IOException { List preFilters = new ArrayList<>(); int numPreFilters = randomInt(10); for (int i = 0; i < numPreFilters; i++) { - preFilters.add(RandomQueryBuilder.createQuery(random())); + QueryBuilder filter = RandomQueryBuilder.createQuery(random()); + QueryBuilder rewritten = Rewriteable.rewrite(filter, queryRewriteContext); + preFilters.add(rewritten); } return preFilters; } - private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder() { - return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(), rankDocsSupplier(), preFilters()); + private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException { + return new RankDocsRetrieverBuilder( + randomIntBetween(1, 100), + innerRetrievers(queryRewriteContext), + rankDocsSupplier(), + preFilters(queryRewriteContext) + ); } - public void testExtractToSearchSourceBuilder() { - RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder(); + public void testExtractToSearchSourceBuilder() throws IOException { + QueryRewriteContext queryRewriteContext = new QueryRewriteContext(parserConfig(), null, () -> 0L); + RankDocsRetrieverBuilder retriever = 
createRandomRankDocsRetrieverBuilder(queryRewriteContext); SearchSourceBuilder source = new SearchSourceBuilder(); if (randomBoolean()) { source.aggregation(new TermsAggregationBuilder("name").field("field")); @@ -115,8 +128,9 @@ public void testExtractToSearchSourceBuilder() { assertNull(source.postFilter()); } - public void testTopDocsQuery() { - RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder(); + public void testTopDocsQuery() throws IOException { + QueryRewriteContext queryRewriteContext = new QueryRewriteContext(parserConfig(), null, () -> 0L); + RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder(queryRewriteContext); QueryBuilder topDocs = retriever.topDocsQuery(); assertNotNull(topDocs); assertThat(topDocs, instanceOf(BoolQueryBuilder.class)); @@ -124,7 +138,8 @@ public void testTopDocsQuery() { } public void testRewrite() throws IOException { - RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder(); + QueryRewriteContext queryRewriteContext = new QueryRewriteContext(parserConfig(), null, () -> 0L); + RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder(queryRewriteContext); boolean compoundAdded = false; if (randomBoolean()) { compoundAdded = true; @@ -136,29 +151,13 @@ public boolean isCompound() { }); } SearchSourceBuilder source = new SearchSourceBuilder().retriever(retriever); - QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); - int size = source.size() < 0 ? 
DEFAULT_SIZE : source.size(); - if (retriever.rankWindowSize < size) { - if (compoundAdded) { - expectThrows(AssertionError.class, () -> Rewriteable.rewrite(source, queryRewriteContext)); - } + if (compoundAdded) { + expectThrows(AssertionError.class, () -> Rewriteable.rewrite(source, queryRewriteContext)); } else { - if (compoundAdded) { - expectThrows(AssertionError.class, () -> Rewriteable.rewrite(source, queryRewriteContext)); - } else { - SearchSourceBuilder rewrittenSource = Rewriteable.rewrite(source, queryRewriteContext); - assertNull(rewrittenSource.retriever()); - assertTrue(rewrittenSource.knnSearch().isEmpty()); - assertThat(rewrittenSource.query(), instanceOf(RankDocsQueryBuilder.class)); - if (rewrittenSource.query() instanceof BoolQueryBuilder) { - BoolQueryBuilder bq = (BoolQueryBuilder) rewrittenSource.query(); - assertThat(bq.filter().size(), equalTo(retriever.preFilterQueryBuilders.size())); - assertThat(bq.must().size(), equalTo(1)); - assertThat(bq.must().get(0), instanceOf(BoolQueryBuilder.class)); - assertThat(bq.should().size(), equalTo(1)); - assertThat(bq.should().get(0), instanceOf(RankDocsQueryBuilder.class)); - } - } + SearchSourceBuilder rewrittenSource = Rewriteable.rewrite(source, queryRewriteContext); + assertNull(rewrittenSource.retriever()); + assertTrue(rewrittenSource.knnSearch().isEmpty()); + assertThat(rewrittenSource.query(), instanceOf(RankDocsQueryBuilder.class)); } } } From a0c5af3f078c753c5eef5a20c1a9fdf3d58db3f4 Mon Sep 17 00:00:00 2001 From: David Turner Date: Fri, 11 Oct 2024 09:13:46 +0100 Subject: [PATCH 25/25] Ensure clean thread context in `MasterService` (#114512) (#114569) `ThreadContext#stashContext` doesn't guarantee to give a clean thread context, but it's important we don't allow the callers' thread contexts to leak into the cluster state update. This commit captures the desired thread context at startup rather than using `stashContext` when forking the processor. 
--- docs/changelog/114512.yaml | 5 +++++ .../cluster/service/MasterService.java | 15 +++++++++++++-- .../org/elasticsearch/node/NodeConstruction.java | 9 ++++++++- .../rollover/MetadataRolloverServiceTests.java | 3 +++ .../ingest/ReservedPipelineActionTests.java | 2 ++ .../elasticsearch/cluster/ClusterModuleTests.java | 5 ++--- .../MetadataIndexTemplateServiceTests.java | 3 +++ .../cluster/service/MasterServiceTests.java | 13 +++++++++++-- .../test/AbstractBuilderTestCase.java | 6 +++--- .../xpack/ml/MlSingleNodeTestCase.java | 2 ++ ...TransportGetTrainedModelsStatsActionTests.java | 2 ++ .../ingest/InferenceProcessorFactoryTests.java | 2 ++ .../job/persistence/JobResultsPersisterTests.java | 2 ++ .../task/OpenJobPersistentTasksExecutorTests.java | 2 ++ .../xpack/ml/support/BaseMlIntegTestCase.java | 2 ++ .../persistence/ResultsPersisterServiceTests.java | 2 ++ .../TransportOpenIdConnectLogoutActionTests.java | 7 ++++++- ...TransportSamlInvalidateSessionActionTests.java | 7 ++++++- .../saml/TransportSamlLogoutActionTests.java | 7 ++++++- .../xpack/security/authc/TokenServiceTests.java | 8 +++++++- .../TransformCheckpointServiceNodeTests.java | 2 +- .../TransportGetCheckpointNodeActionTests.java | 6 +++--- .../TransformPersistentTasksExecutorTests.java | 5 ++--- .../transform/transforms/TransformTaskTests.java | 6 ++++-- 24 files changed, 99 insertions(+), 24 deletions(-) create mode 100644 docs/changelog/114512.yaml diff --git a/docs/changelog/114512.yaml b/docs/changelog/114512.yaml new file mode 100644 index 0000000000000..10dea3a2cbac1 --- /dev/null +++ b/docs/changelog/114512.yaml @@ -0,0 +1,5 @@ +pr: 114512 +summary: Ensure clean thread context in `MasterService` +area: Cluster Coordination +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java index e3bb675f94b3a..f756933567683 100644 --- 
a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java +++ b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java @@ -109,6 +109,7 @@ public class MasterService extends AbstractLifecycleComponent { protected final ThreadPool threadPool; private final TaskManager taskManager; + private final ThreadContext.StoredContext clusterStateUpdateContext; private volatile ExecutorService threadPoolExecutor; private final AtomicInteger totalQueueSize = new AtomicInteger(); @@ -129,6 +130,7 @@ public MasterService(Settings settings, ClusterSettings clusterSettings, ThreadP this.threadPool = threadPool; this.taskManager = taskManager; + this.clusterStateUpdateContext = getClusterStateUpdateContext(threadPool.getThreadContext()); final var queuesByPriorityBuilder = new EnumMap(Priority.class); for (final var priority : Priority.values()) { @@ -138,6 +140,15 @@ public MasterService(Settings settings, ClusterSettings clusterSettings, ThreadP this.unbatchedExecutor = new UnbatchedExecutor(); } + private static ThreadContext.StoredContext getClusterStateUpdateContext(ThreadContext threadContext) { + try (var ignored = threadContext.newStoredContext()) { + // capture the context in which to run all cluster state updates here where we know it to be very clean + assert threadContext.isDefaultContext() : "must only create MasterService in a clean ThreadContext"; + threadContext.markAsSystemContext(); + return threadContext.newStoredContext(); + } + } + private void setSlowTaskLoggingThreshold(TimeValue slowTaskLoggingThreshold) { this.slowTaskLoggingThreshold = slowTaskLoggingThreshold; } @@ -1330,8 +1341,8 @@ private void forkQueueProcessor() { assert totalQueueSize.get() > 0; final var threadContext = threadPool.getThreadContext(); - try (var ignored = threadContext.stashContext()) { - threadContext.markAsSystemContext(); + try (var ignored = threadContext.newStoredContext()) { + clusterStateUpdateContext.restore(); 
threadPoolExecutor.execute(queuesProcessor); } } diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 64afb06a2aed5..8ee51044e5f88 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -268,7 +268,14 @@ static NodeConstruction prepareConstruction( constructor.loadLoggingDataProviders(); TelemetryProvider telemetryProvider = constructor.createTelemetryProvider(settings); ThreadPool threadPool = constructor.createThreadPool(settings, telemetryProvider.getMeterRegistry()); - SettingsModule settingsModule = constructor.validateSettings(initialEnvironment.settings(), settings, threadPool); + + final SettingsModule settingsModule; + try (var ignored = threadPool.getThreadContext().newStoredContext()) { + // If any deprecated settings are in use then we add warnings to the thread context response headers, but we're not + // computing a response here so these headers aren't relevant and eventually just get dropped after possibly leaking into + // places they shouldn't. Best to explicitly drop them now to protect against such leakage. 
+ settingsModule = constructor.validateSettings(initialEnvironment.settings(), settings, threadPool); + } SearchModule searchModule = constructor.createSearchModule(settingsModule.getSettings(), threadPool, telemetryProvider); constructor.createClientAndRegistries(settingsModule.getSettings(), threadPool, searchModule); diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java index 2d3c9eec00d29..848e46f2b3366 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java @@ -34,6 +34,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; @@ -63,6 +64,7 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class MetadataRolloverServiceTests extends ESTestCase { @@ -833,6 +835,7 @@ public void testRolloverClusterStateForDataStreamNoTemplate() throws Exception { final TestTelemetryPlugin telemetryPlugin = new TestTelemetryPlugin(); ThreadPool testThreadPool = mock(ThreadPool.class); + when(testThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); MetadataRolloverService rolloverService = DataStreamTestHelper.getMetadataRolloverService( dataStream, testThreadPool, diff --git a/server/src/test/java/org/elasticsearch/action/ingest/ReservedPipelineActionTests.java 
b/server/src/test/java/org/elasticsearch/action/ingest/ReservedPipelineActionTests.java index f8dfdcfae57d1..9729b653ae3d2 100644 --- a/server/src/test/java/org/elasticsearch/action/ingest/ReservedPipelineActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/ingest/ReservedPipelineActionTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.ingest.FakeProcessor; import org.elasticsearch.ingest.IngestInfo; @@ -81,6 +82,7 @@ public void setup() { threadPool = mock(ThreadPool.class); when(threadPool.generic()).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); Client client = mock(Client.class); ingestService = new IngestService( diff --git a/server/src/test/java/org/elasticsearch/cluster/ClusterModuleTests.java b/server/src/test/java/org/elasticsearch/cluster/ClusterModuleTests.java index 58d4650fe6628..52be7004209bb 100644 --- a/server/src/test/java/org/elasticsearch/cluster/ClusterModuleTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/ClusterModuleTests.java @@ -46,7 +46,6 @@ import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.injection.guice.ModuleTestCase; import org.elasticsearch.plugins.ClusterPlugin; -import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.telemetry.TelemetryProvider; import org.elasticsearch.test.gateway.TestGatewayAllocator; import org.elasticsearch.threadpool.TestThreadPool; @@ -88,8 +87,8 @@ public void setUp() throws Exception { clusterService = new ClusterService( Settings.EMPTY, new ClusterSettings(Settings.EMPTY, 
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), - null, - (TaskManager) null + threadPool, + null ); } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateServiceTests.java index 873b185e6be28..ac3f08e6e29fb 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateServiceTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.IndexScopedSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.env.Environment; import org.elasticsearch.health.node.selection.HealthNodeTaskExecutor; @@ -77,6 +78,7 @@ import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.sameInstance; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class MetadataIndexTemplateServiceTests extends ESSingleNodeTestCase { @@ -2473,6 +2475,7 @@ public void testAddIndexTemplateWithDeprecatedComponentTemplate() throws Excepti private static List putTemplate(NamedXContentRegistry xContentRegistry, PutRequest request) { ThreadPool testThreadPool = mock(ThreadPool.class); + when(testThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); ClusterService clusterService = ClusterServiceUtils.createClusterService(testThreadPool); MetadataCreateIndexService createIndexService = new MetadataCreateIndexService( Settings.EMPTY, diff --git a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java index 98f17e7958b82..16e07f4b9c85e 100644 --- 
a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java @@ -68,6 +68,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.EnumMap; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -262,7 +263,15 @@ public void testThreadContext() throws InterruptedException { final CountDownLatch latch = new CountDownLatch(1); try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) { - final Map expectedHeaders = Collections.singletonMap("test", "test"); + + final var expectedHeaders = new HashMap(); + expectedHeaders.put(randomIdentifier(), randomIdentifier()); + for (final var copiedHeader : Task.HEADERS_TO_COPY) { + if (randomBoolean()) { + expectedHeaders.put(copiedHeader, randomIdentifier()); + } + } + final Map> expectedResponseHeaders = Collections.singletonMap( "testResponse", Collections.singletonList("testResponse") @@ -1342,7 +1351,6 @@ public void testAcking() { .build(); final var deterministicTaskQueue = new DeterministicTaskQueue(); final var threadPool = deterministicTaskQueue.getThreadPool(); - threadPool.getThreadContext().markAsSystemContext(); try ( var masterService = createMasterService( true, @@ -1351,6 +1359,7 @@ public void testAcking() { new StoppableExecutorServiceWrapper(threadPool.generic()) ) ) { + threadPool.getThreadContext().markAsSystemContext(); final var responseHeaderName = "test-response-header"; diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index 02ef43de29af8..77ff194e2681d 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -36,6 +36,7 @@ import 
org.elasticsearch.common.settings.IndexScopedSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsModule; +import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.core.IOUtils; import org.elasticsearch.env.Environment; @@ -81,7 +82,6 @@ import org.elasticsearch.script.ScriptModule; import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.SearchModule; -import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.transport.RemoteClusterAware; @@ -432,8 +432,8 @@ private static class ServiceHolder implements Closeable { ClusterService clusterService = new ClusterService( Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), - null, - (TaskManager) null + new DeterministicTaskQueue().getThreadPool(), + null ); client = (Client) Proxy.newProxyInstance( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java index c6e573fb3ea9c..8898cac495706 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.datastreams.DataStreamsPlugin; import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin; @@ -159,6 +160,7 @@ protected T blockingCall(Consumer> function) throws Except 
protected static ThreadPool mockThreadPool() { ThreadPool tp = mock(ThreadPool.class); + when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); ExecutorService executor = mock(ExecutorService.class); doAnswer(invocationOnMock -> { ((Runnable) invocationOnMock.getArguments()[0]).run(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java index 9232d32e40a97..7e88cad88dcec 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.IngestStats; @@ -112,6 +113,7 @@ public Map getProcessors(Processor.Parameters paramet public void setUpVariables() { ThreadPool tp = mock(ThreadPool.class); when(tp.generic()).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); + when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); client = mock(Client.class); Settings settings = Settings.builder().put("node.name", "InferenceProcessorFactoryTests_node").build(); ClusterSettings clusterSettings = new ClusterSettings( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 9adbb3b3dd89a..637a9f73cbcbb 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Tuple; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.PipelineConfiguration; @@ -88,6 +89,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { public void setUpVariables() { ThreadPool tp = mock(ThreadPool.class); when(tp.generic()).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); + when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); client = mock(Client.class); Settings settings = Settings.builder().put("node.name", "InferenceProcessorFactoryTests_node").build(); ClusterSettings clusterSettings = new ClusterSettings( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersisterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersisterTests.java index 7b0d9d3051dcc..2190d8af01f4d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersisterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersisterTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.cluster.service.MasterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; @@ -427,6 +428,7 @@ private static 
Answer withResponse(Response response) { private ResultsPersisterService buildResultsPersisterService(OriginSettingClient client) { ThreadPool tp = mock(ThreadPool.class); + when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); ClusterSettings clusterSettings = new ClusterSettings( Settings.EMPTY, new HashSet<>( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java index 0440a66bdbcaa..eb2e21d5fda6c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.shard.ShardId; @@ -90,6 +91,7 @@ public class OpenJobPersistentTasksExecutorTests extends ESTestCase { public void setUpMocks() { ThreadPool tp = mock(ThreadPool.class); when(tp.generic()).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); + when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); Settings settings = Settings.builder().put("node.name", "OpenJobPersistentTasksExecutorTests").build(); ClusterSettings clusterSettings = new ClusterSettings( settings, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java index 4ac028ec3af21..aeebfabdce704 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java @@ -30,6 +30,7 @@ import org.elasticsearch.cluster.routing.UnassignedInfo; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.datastreams.DataStreamsPlugin; import org.elasticsearch.health.node.selection.HealthNode; @@ -284,6 +285,7 @@ public void cleanup() throws Exception { protected static ThreadPool mockThreadPool() { ThreadPool tp = mock(ThreadPool.class); + when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); ExecutorService executor = mock(ExecutorService.class); doAnswer(invocationOnMock -> { ((Runnable) invocationOnMock.getArguments()[0]).run(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java index e109f2995d215..7a513f12bf302 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.cluster.service.MasterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; import org.elasticsearch.index.shard.ShardId; @@ -428,6 +429,7 @@ private static Answer withFailure(Exception failure) { public static ResultsPersisterService buildResultsPersisterService(OriginSettingClient 
client) { ThreadPool tp = mock(ThreadPool.class); + when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); ClusterSettings clusterSettings = new ClusterSettings( Settings.EMPTY, new HashSet<>( diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java index 95e818dc20c96..52fe73ab552d5 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java @@ -102,6 +102,7 @@ public void setup() throws Exception { .put("path.home", createTempDir()) .build(); final ThreadContext threadContext = new ThreadContext(settings); + final var defaultContext = threadContext.newStoredContext(); final ThreadPool threadPool = mock(ThreadPool.class); when(threadPool.getThreadContext()).thenReturn(threadContext); AuthenticationTestHelper.builder() @@ -174,7 +175,11 @@ public void setup() throws Exception { when(securityIndex.isAvailable(SecurityIndexManager.Availability.SEARCH_SHARDS)).thenReturn(true); when(securityIndex.defensiveCopy()).thenReturn(securityIndex); - final ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool); + final ClusterService clusterService; + try (var ignored = threadContext.newStoredContext()) { + defaultContext.restore(); + clusterService = ClusterServiceUtils.createClusterService(threadPool); + } final MockLicenseState licenseState = mock(MockLicenseState.class); when(licenseState.isAllowed(Security.TOKEN_SERVICE_FEATURE)).thenReturn(true); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java 
b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java index 540a0758db43a..269f8cb0471fc 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java @@ -151,6 +151,7 @@ public void setup() throws Exception { this.threadPool = new TestThreadPool("saml test thread pool", settings); final ThreadContext threadContext = threadPool.getThreadContext(); + final var defaultContext = threadContext.newStoredContext(); AuthenticationTestHelper.builder() .user(new User("kibana")) .realmRef(new RealmRef("realm", "type", "node")) @@ -278,7 +279,11 @@ protected void final MockLicenseState licenseState = mock(MockLicenseState.class); when(licenseState.isAllowed(Security.TOKEN_SERVICE_FEATURE)).thenReturn(true); - final ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool); + final ClusterService clusterService; + try (var ignored = threadContext.newStoredContext()) { + defaultContext.restore(); + clusterService = ClusterServiceUtils.createClusterService(threadPool); + } final SecurityContext securityContext = new SecurityContext(settings, threadContext); tokenService = new TokenService( settings, diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java index 490704eff216a..855f96e30ffa0 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java @@ -120,6 +120,7 @@ public void 
setup() throws Exception { this.threadPool = new TestThreadPool("saml logout test thread pool", settings); final ThreadContext threadContext = this.threadPool.getThreadContext(); + final var defaultContext = threadContext.newStoredContext(); AuthenticationTestHelper.builder() .user(new User("kibana")) .realmRef(new Authentication.RealmRef("realm", "type", "node")) @@ -207,7 +208,11 @@ public void setup() throws Exception { final MockLicenseState licenseState = mock(MockLicenseState.class); when(licenseState.isAllowed(Security.TOKEN_SERVICE_FEATURE)).thenReturn(true); - final ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool); + final ClusterService clusterService; + try (var ignored = threadContext.newStoredContext()) { + defaultContext.restore(); + clusterService = ClusterServiceUtils.createClusterService(threadPool); + } final SecurityContext securityContext = new SecurityContext(settings, threadContext); tokenService = new TokenService( settings, diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index e53fa83b89617..75c2507a1dc5f 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -138,6 +138,7 @@ public class TokenServiceTests extends ESTestCase { private static ThreadPool threadPool; + private static ThreadContext.StoredContext defaultContext; private static final Settings settings = Settings.builder() .put(Node.NODE_NAME_SETTING.getKey(), "TokenServiceTests") .put(XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.getKey(), true) @@ -217,7 +218,11 @@ public void setupClient() { // setup lifecycle service this.securityMainIndex = SecurityMocks.mockSecurityIndexManager(); this.securityTokensIndex = 
SecurityMocks.mockSecurityIndexManager(); - this.clusterService = ClusterServiceUtils.createClusterService(threadPool); + + try (var ignored = threadPool.getThreadContext().newStoredContext()) { + defaultContext.restore(); + this.clusterService = ClusterServiceUtils.createClusterService(threadPool); + } // License state (enabled by default) licenseState = mock(MockLicenseState.class); @@ -282,6 +287,7 @@ public static void startThreadPool() throws IOException { EsExecutors.TaskTrackingConfig.DO_NOT_TRACK ) ); + defaultContext = threadPool.getThreadContext().newStoredContext(); AuthenticationTestHelper.builder() .user(new User("foo")) .realmRef(new RealmRef("realm", "type", "node")) diff --git a/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCheckpointServiceNodeTests.java b/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCheckpointServiceNodeTests.java index 366f3e6f917bf..745ccae86816c 100644 --- a/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCheckpointServiceNodeTests.java +++ b/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCheckpointServiceNodeTests.java @@ -180,7 +180,7 @@ public void createComponents() { new ClusterService( Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), - null, + threadPool, mock(TaskManager.class) ), transformsConfigManager, diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransportGetCheckpointNodeActionTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransportGetCheckpointNodeActionTests.java index 950e593165f01..6b82c93a61752 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransportGetCheckpointNodeActionTests.java +++ 
b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/action/TransportGetCheckpointNodeActionTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexNotFoundException; @@ -25,7 +26,6 @@ import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskCancelHelper; import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.transform.action.GetCheckpointNodeAction; import org.elasticsearch.xpack.transform.transforms.scheduling.FakeClock; @@ -68,8 +68,8 @@ public void setUp() throws Exception { ClusterService clusterService = new ClusterService( Settings.builder().put("node.name", NODE_NAME).build(), new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), - null, - (TaskManager) null + new DeterministicTaskQueue().getThreadPool(), + null ); indicesService = mock(IndicesService.class); diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java index 7d572aaef2dee..70ded62b4c7d4 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformPersistentTasksExecutorTests.java @@ -34,7 +34,6 @@ import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import 
org.elasticsearch.persistent.PersistentTasksCustomMetadata.Assignment; import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -593,8 +592,8 @@ private TransformServices transformServices(TransformConfigManager configManager new ClusterService( Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), - null, - (TaskManager) null + threadPool, + null ), configManager, mockAuditor diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformTaskTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformTaskTests.java index 31bd365250e3c..67ce09c74e98c 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformTaskTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/TransformTaskTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; import org.elasticsearch.health.HealthStatus; import org.elasticsearch.persistent.PersistentTaskParams; @@ -107,6 +108,7 @@ public void tearDownClient() { public void testStopOnFailedTaskWithStoppedIndexer() { Clock clock = Clock.systemUTC(); ThreadPool threadPool = mock(ThreadPool.class); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); when(threadPool.executor("generic")).thenReturn(mock(ExecutorService.class)); TransformConfig transformConfig = TransformConfigTests.randomTransformConfigWithoutHeaders(); @@ -193,8 +195,8 @@ private TransformServices transformServices(Clock 
clock, TransformAuditor audito new ClusterService( Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), - null, - (TaskManager) null + threadPool, + null ), transformsConfigManager, auditor