diff --git a/docs/index.md b/docs/index.md index ac131c6a1..7155396cc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -543,6 +543,8 @@ In the index mapping, the `_meta` and `properties`field stores meta and schema i - `spark.datasource.flint.read.scroll_size`: default value is 100. - `spark.datasource.flint.read.scroll_duration`: default value is 5 minutes. scroll context keep alive duration. - `spark.datasource.flint.retry.max_retries`: max retries on failed HTTP request. default value is 3. Use 0 to disable retry. +- `spark.datasource.flint.retry.bulk.max_retries`: max retries on failed bulk request. default value is 10. Use 0 to disable retry. +- `spark.datasource.flint.retry.bulk.initial_backoff`: initial backoff in seconds for bulk request retry, default is 4. -- `spark.datasource.flint.retry.http_status_codes`: retryable HTTP response status code list. default value is "429,502" (429 Too Many Request and 502 Bad Gateway). +- `spark.datasource.flint.retry.http_status_codes`: retryable HTTP response status code list. default value is "429,500,502" (429 Too Many Request, 500 Internal Server Error and 502 Bad Gateway). - `spark.datasource.flint.retry.exception_class_names`: retryable exception class name list. by default no retry on any exception thrown. - `spark.datasource.flint.read.support_shard`: default is true. set to false if index does not support shard (AWS OpenSearch Serverless collection). Do not use in production, this setting will be removed in later version. 
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java index 597f441ec..6acfbe782 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java @@ -15,9 +15,12 @@ import dev.failsafe.event.ExecutionAttemptedEvent; import dev.failsafe.function.CheckedPredicate; import java.time.Duration; +import java.util.Arrays; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.logging.Logger; +import java.util.stream.Collectors; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.flint.core.http.handler.ExceptionClassNameFailurePredicate; import org.opensearch.flint.core.http.handler.HttpAOSSResultPredicate; @@ -41,8 +44,12 @@ public class FlintRetryOptions implements Serializable { */ public static final int DEFAULT_MAX_RETRIES = 3; public static final String MAX_RETRIES = "retry.max_retries"; + public static final int DEFAULT_BULK_MAX_RETRIES = 10; + public static final String BULK_MAX_RETRIES = "retry.bulk.max_retries"; + public static final int DEFAULT_BULK_INITIAL_BACKOFF = 4; + public static final String BULK_INITIAL_BACKOFF = "retry.bulk.initial_backoff"; - public static final String DEFAULT_RETRYABLE_HTTP_STATUS_CODES = "429,502"; + public static final String DEFAULT_RETRYABLE_HTTP_STATUS_CODES = "429,500,502"; public static final String RETRYABLE_HTTP_STATUS_CODES = "retry.http_status_codes"; /** @@ -90,9 +97,9 @@ public RetryPolicy getRetryPolicy() { public RetryPolicy getBulkRetryPolicy(CheckedPredicate resultPredicate) { return RetryPolicy.builder() // Using higher initial backoff to mitigate throttling quickly - .withBackoff(4, 30, SECONDS) + .withBackoff(getBulkInitialBackoff(), 30, SECONDS) .withJitter(Duration.ofMillis(100)) - .withMaxRetries(getMaxRetries()) + 
.withMaxRetries(getBulkMaxRetries()) // Do not retry on exception (will be handled by the other retry policy .handleIf((ex) -> false) .handleResultIf(resultPredicate) @@ -122,10 +129,27 @@ public int getMaxRetries() { } /** - * @return retryable HTTP status code list + * @return bulk maximum retry option value + */ + public int getBulkMaxRetries() { + return Integer.parseInt( + options.getOrDefault(BULK_MAX_RETRIES, String.valueOf(DEFAULT_BULK_MAX_RETRIES))); + } + + /** + * @return bulk request retry initial backoff (in seconds) option value */ - public String getRetryableHttpStatusCodes() { - return options.getOrDefault(RETRYABLE_HTTP_STATUS_CODES, DEFAULT_RETRYABLE_HTTP_STATUS_CODES); + public int getBulkInitialBackoff() { + return Integer.parseInt( + options.getOrDefault(BULK_INITIAL_BACKOFF, String.valueOf(DEFAULT_BULK_INITIAL_BACKOFF))); + } + + public Set<Integer> getRetryableHttpStatusCodes() { + String statusCodes = options.getOrDefault(RETRYABLE_HTTP_STATUS_CODES, DEFAULT_RETRYABLE_HTTP_STATUS_CODES); + return Arrays.stream(statusCodes.split(",")) + .map(String::trim) + .map(Integer::valueOf) + .collect(Collectors.toSet()); } /** diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/http/handler/HttpStatusCodeResultPredicate.java b/flint-core/src/main/scala/org/opensearch/flint/core/http/handler/HttpStatusCodeResultPredicate.java index fa82e3655..436a5cfad 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/http/handler/HttpStatusCodeResultPredicate.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/http/handler/HttpStatusCodeResultPredicate.java @@ -26,12 +26,8 @@ public class HttpStatusCodeResultPredicate implements CheckedPredicate { */ private final Set<Integer> retryableStatusCodes; - public HttpStatusCodeResultPredicate(String httpStatusCodes) { - this.retryableStatusCodes = - Arrays.stream(httpStatusCodes.split(",")) - .map(String::trim) - .map(Integer::valueOf) - .collect(Collectors.toSet()); + public HttpStatusCodeResultPredicate(Set<Integer> httpStatusCodes) { + 
this.retryableStatusCodes = httpStatusCodes; } @Override diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterImpl.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterImpl.java index 3dec19558..62a60d766 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterImpl.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/BulkRequestRateLimiterImpl.java @@ -19,12 +19,14 @@ public class BulkRequestRateLimiterImpl implements BulkRequestRateLimiter { private final long maxRate; private final long increaseStep; private final double decreaseRatio; + private final RequestRateMeter requestRateMeter; public BulkRequestRateLimiterImpl(FlintOptions flintOptions) { minRate = flintOptions.getBulkRequestMinRateLimitPerNode(); maxRate = flintOptions.getBulkRequestMaxRateLimitPerNode(); increaseStep = flintOptions.getBulkRequestRateLimitPerNodeIncreaseStep(); decreaseRatio = flintOptions.getBulkRequestRateLimitPerNodeDecreaseRatio(); + requestRateMeter = new RequestRateMeter(); LOG.info("Setting rate limit for bulk request to " + minRate + " documents/sec"); this.rateLimiter = RateLimiter.create(minRate); @@ -42,6 +44,7 @@ public void acquirePermit() { public void acquirePermit(int permits) { this.rateLimiter.acquire(permits); LOG.info("Acquired " + permits + " permits"); + requestRateMeter.addDataPoint(System.currentTimeMillis(), permits); } /** @@ -49,7 +52,18 @@ public void acquirePermit(int permits) { */ @Override public void increaseRate() { - setRate(getRate() + increaseStep); + if (isEstimatedCurrentRateCloseToLimit()) { + setRate(getRate() + increaseStep); + } else { + LOG.info("Rate increase was blocked."); + } + LOG.info("Current rate limit for bulk request is " + getRate() + " documents/sec"); + } + + private boolean isEstimatedCurrentRateCloseToLimit() { + long currentEstimatedRate = requestRateMeter.getCurrentEstimatedRate(); + 
LOG.info("Current estimated rate is " + currentEstimatedRate + " documents/sec"); + return getRate() * 0.8 < currentEstimatedRate; } /** diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkWrapper.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkWrapper.java index 935894a31..bbeb4fa4f 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkWrapper.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkWrapper.java @@ -11,6 +11,7 @@ import dev.failsafe.function.CheckedPredicate; import java.util.Arrays; import java.util.List; +import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; @@ -34,10 +35,12 @@ public class OpenSearchBulkWrapper { private final RetryPolicy retryPolicy; private final BulkRequestRateLimiter rateLimiter; + private final Set retryableStatusCodes; public OpenSearchBulkWrapper(FlintRetryOptions retryOptions, BulkRequestRateLimiter rateLimiter) { this.retryPolicy = retryOptions.getBulkRetryPolicy(bulkItemRetryableResultPredicate); this.rateLimiter = rateLimiter; + this.retryableStatusCodes = retryOptions.getRetryableHttpStatusCodes(); } /** @@ -50,7 +53,6 @@ public OpenSearchBulkWrapper(FlintRetryOptions retryOptions, BulkRequestRateLimi * @return Last result */ public BulkResponse bulk(RestHighLevelClient client, BulkRequest bulkRequest, RequestOptions options) { - rateLimiter.acquirePermit(bulkRequest.requests().size()); return bulkWithPartialRetry(client, bulkRequest, options); } @@ -69,11 +71,13 @@ private BulkResponse bulkWithPartialRetry(RestHighLevelClient client, BulkReques }) .get(() -> { requestCount.incrementAndGet(); + rateLimiter.acquirePermit(nextRequest.get().requests().size()); BulkResponse response = client.bulk(nextRequest.get(), options); if (!bulkItemRetryableResultPredicate.test(response)) { 
rateLimiter.increaseRate(); } else { + LOG.info("Bulk request failed. attempt = " + (requestCount.get() - 1)); rateLimiter.decreaseRate(); if (retryPolicy.getConfig().allowsRetries()) { nextRequest.set(getRetryableRequest(nextRequest.get(), response)); @@ -118,10 +122,10 @@ private static void verifyIdMatch(DocWriteRequest request, BulkItemResponse r /** * A predicate to decide if a BulkResponse is retryable or not. */ - private static final CheckedPredicate bulkItemRetryableResultPredicate = bulkResponse -> + private final CheckedPredicate bulkItemRetryableResultPredicate = bulkResponse -> bulkResponse.hasFailures() && isRetryable(bulkResponse); - private static boolean isRetryable(BulkResponse bulkResponse) { + private boolean isRetryable(BulkResponse bulkResponse) { if (Arrays.stream(bulkResponse.getItems()) .anyMatch(itemResp -> isItemRetryable(itemResp))) { LOG.info("Found retryable failure in the bulk response"); @@ -130,12 +134,23 @@ private static boolean isRetryable(BulkResponse bulkResponse) { return false; } - private static boolean isItemRetryable(BulkItemResponse itemResponse) { - return itemResponse.isFailed() && !isCreateConflict(itemResponse); + private boolean isItemRetryable(BulkItemResponse itemResponse) { + return itemResponse.isFailed() && !isCreateConflict(itemResponse) + && isFailureStatusRetryable(itemResponse); } private static boolean isCreateConflict(BulkItemResponse itemResp) { return itemResp.getOpType() == DocWriteRequest.OpType.CREATE && itemResp.getFailure().getStatus() == RestStatus.CONFLICT; } + + private boolean isFailureStatusRetryable(BulkItemResponse itemResp) { + if (retryableStatusCodes.contains(itemResp.getFailure().getStatus().getStatus())) { + return true; + } else { + LOG.info("Found non-retryable failure in bulk response: " + itemResp.getFailure().getStatus() + + ", " + itemResp.getFailure().toString()); + return false; + } + } } diff --git 
a/flint-core/src/main/scala/org/opensearch/flint/core/storage/RequestRateMeter.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/RequestRateMeter.java new file mode 100644 index 000000000..fddca5565 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/RequestRateMeter.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage; + +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; + +/** + * Track the current request rate based on the past requests within ESTIMATE_RANGE_DURATION_MSEC + * milliseconds period. + */ +public class RequestRateMeter { + private static final long ESTIMATE_RANGE_DURATION_MSEC = 3000; + + private static class DataPoint { + long timestamp; + long requestCount; + public DataPoint(long timestamp, long requestCount) { + this.timestamp = timestamp; + this.requestCount = requestCount; + } + } + + private Queue dataPoints = new LinkedList<>(); + private long currentSum = 0; + + public synchronized void addDataPoint(long timestamp, long requestCount) { + dataPoints.add(new DataPoint(timestamp, requestCount)); + currentSum += requestCount; + removeOldDataPoints(); + } + + public synchronized long getCurrentEstimatedRate() { + removeOldDataPoints(); + return currentSum * 1000 / ESTIMATE_RANGE_DURATION_MSEC; + } + + private synchronized void removeOldDataPoints() { + long curr = System.currentTimeMillis(); + while (!dataPoints.isEmpty() && dataPoints.peek().timestamp < curr - ESTIMATE_RANGE_DURATION_MSEC) { + currentSum -= dataPoints.remove().requestCount; + } + } +} diff --git a/flint-core/src/test/java/org/opensearch/flint/core/storage/RequestRateMeterTest.java b/flint-core/src/test/java/org/opensearch/flint/core/storage/RequestRateMeterTest.java new file mode 100644 index 000000000..b8f13a12d --- /dev/null +++ 
b/flint-core/src/test/java/org/opensearch/flint/core/storage/RequestRateMeterTest.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +class RequestRateMeterTest { + + private RequestRateMeter requestRateMeter; + + @BeforeEach + void setUp() { + requestRateMeter = new RequestRateMeter(); + } + + @Test + void testAddDataPoint() { + long timestamp = System.currentTimeMillis(); + requestRateMeter.addDataPoint(timestamp, 30); + assertEquals(10, requestRateMeter.getCurrentEstimatedRate()); + } + + @Test + void testAddDataPointRemoveOldDataPoint() { + long timestamp = System.currentTimeMillis(); + requestRateMeter.addDataPoint(timestamp - 4000, 30); + requestRateMeter.addDataPoint(timestamp, 90); + assertEquals(90 / 3, requestRateMeter.getCurrentEstimatedRate()); + } + + @Test + void testRemoveOldDataPoints() { + long currentTime = System.currentTimeMillis(); + requestRateMeter.addDataPoint(currentTime - 4000, 30); + requestRateMeter.addDataPoint(currentTime - 2000, 60); + requestRateMeter.addDataPoint(currentTime, 90); + + assertEquals((60 + 90)/3, requestRateMeter.getCurrentEstimatedRate()); + } + + @Test + void testGetCurrentEstimatedRate() { + long currentTime = System.currentTimeMillis(); + requestRateMeter.addDataPoint(currentTime - 2500, 30); + requestRateMeter.addDataPoint(currentTime - 1500, 60); + requestRateMeter.addDataPoint(currentTime - 500, 90); + + assertEquals((30 + 60 + 90)/3, requestRateMeter.getCurrentEstimatedRate()); + } + + @Test + void testEmptyRateMeter() { + assertEquals(0, requestRateMeter.getCurrentEstimatedRate()); + } + + @Test + void testSingleDataPoint() { + requestRateMeter.addDataPoint(System.currentTimeMillis(), 30); + assertEquals(30 / 3, requestRateMeter.getCurrentEstimatedRate()); + } +} \ No newline 
at end of file diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index faba2135f..8b83f16e7 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -136,6 +136,18 @@ object FlintSparkConf { .doc("max retries on failed HTTP request, 0 means retry is disabled, default is 3") .createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_MAX_RETRIES)) + val BULK_MAX_RETRIES = + FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.BULK_MAX_RETRIES}") + .datasourceOption() + .doc("max retries on failed bulk request, 0 means retry is disabled, default is 10") + .createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_BULK_MAX_RETRIES)) + + val BULK_INITIAL_BACKOFF = + FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.BULK_INITIAL_BACKOFF}") + .datasourceOption() + .doc("initial backoff in seconds for bulk request retry, default is 4s") + .createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_BULK_INITIAL_BACKOFF)) + val BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED = FlintConfig( s"spark.datasource.flint.${FlintOptions.BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED}") @@ -368,6 +380,8 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable SCHEME, AUTH, MAX_RETRIES, + BULK_MAX_RETRIES, + BULK_INITIAL_BACKOFF, RETRYABLE_HTTP_STATUS_CODES, BULK_REQUEST_RATE_LIMIT_PER_NODE_ENABLED, BULK_REQUEST_MIN_RATE_LIMIT_PER_NODE, diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala index ac9346562..8ea84aecc 100644 ---
a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala @@ -11,6 +11,7 @@ import scala.collection.JavaConverters._ import org.opensearch.flint.core.FlintOptions import org.opensearch.flint.core.http.FlintRetryOptions._ +import org.scalatest.matchers.must.Matchers.contain import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.FlintSuite @@ -46,7 +47,7 @@ class FlintSparkConfSuite extends FlintSuite { test("test retry options default values") { val retryOptions = FlintSparkConf().flintOptions().getRetryOptions retryOptions.getMaxRetries shouldBe DEFAULT_MAX_RETRIES - retryOptions.getRetryableHttpStatusCodes shouldBe DEFAULT_RETRYABLE_HTTP_STATUS_CODES + retryOptions.getRetryableHttpStatusCodes should contain theSameElementsAs Set(429, 500, 502) retryOptions.getRetryableExceptionClassNames shouldBe Optional.empty } @@ -60,7 +61,11 @@ class FlintSparkConfSuite extends FlintSuite { .getRetryOptions retryOptions.getMaxRetries shouldBe 5 - retryOptions.getRetryableHttpStatusCodes shouldBe "429,502,503,504" + retryOptions.getRetryableHttpStatusCodes should contain theSameElementsAs Set( + 429, + 502, + 503, + 504) retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException" }