diff --git a/README.md b/README.md
index 6daa6b2..4c38526 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,8 @@ These are optional configuration values that control how s3-shuffle behaves.
 
 - `spark.shuffle.s3.forceBatchFetch`: Force batch fetch for Shuffle Blocks (default: `false`)
 - `spark.shuffle.s3.supportsUnbuffer`: Streams can be unbuffered instead of closed (default: `true`, if Storage-backend is S3A, `false` otherwise).
+- `spark.shuffle.s3.prefetchBatchSize`: Prefetch batch size (default: `25`).
+- `spark.shuffle.s3.prefetchThreadPoolSize`: Prefetch thread pool size (default: `100`).
 
 ## Testing
 
diff --git a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
index 41c5910..3b5f6fc 100644
--- a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
+++ b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
@@ -34,6 +34,8 @@ class S3ShuffleDispatcher extends Logging {
   val alwaysCreateIndex: Boolean = conf.getBoolean("spark.shuffle.s3.alwaysCreateIndex", defaultValue = false)
   val useBlockManager: Boolean = conf.getBoolean("spark.shuffle.s3.useBlockManager", defaultValue = true)
   val forceBatchFetch: Boolean = conf.getBoolean("spark.shuffle.s3.forceBatchFetch", defaultValue = false)
+  val prefetchBatchSize: Int = conf.getInt("spark.shuffle.s3.prefetchBatchSize", defaultValue = 25)
+  val prefetchThreadPoolSize: Int = conf.getInt("spark.shuffle.s3.prefetchThreadPoolSize", defaultValue = 100)
 
   val appDir = f"/${startTime}-${appId}/"
   val fs: FileSystem = FileSystem.get(URI.create(rootDir), {
@@ -46,6 +48,8 @@ class S3ShuffleDispatcher extends Logging {
     logInfo(s"- spark.shuffle.s3.alwaysCreateIndex=${alwaysCreateIndex}")
     logInfo(s"- spark.shuffle.s3.useBlockManager=${useBlockManager}")
     logInfo(s"- spark.shuffle.s3.forceBatchFetch=${forceBatchFetch}")
+    logInfo(s"- spark.shuffle.s3.prefetchBatchSize=${prefetchBatchSize}")
+    logInfo(s"- spark.shuffle.s3.prefetchThreadPoolSize=${prefetchThreadPoolSize}")
 
   def removeRoot(): Boolean = {
     Range(0, 10).map(idx => {
diff --git a/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala b/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala
index c0e6d2c..58d9b02 100644
--- a/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala
+++ b/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala
@@ -119,7 +119,7 @@ class S3ShuffleReader[K, C](
       }(S3ShuffleReader.asyncExecutionContext)
     }
 
-    val recordIter = slidingPrefetchIterator(recordIterPromise, 25).flatten
+    val recordIter = slidingPrefetchIterator(recordIterPromise, dispatcher.prefetchBatchSize).flatten
 
     // Update the context task metrics for each record read.
     val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
@@ -196,6 +196,6 @@ class S3ShuffleReader[K, C](
 }
 
 object S3ShuffleReader {
-  private val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("s3-shuffle-reader-async-thread-pool", 100)
-  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)
+  private lazy val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("s3-shuffle-reader-async-thread-pool", S3ShuffleDispatcher.get.prefetchThreadPoolSize)
+  private implicit lazy val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)
 }
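
For anyone picking up these options, the sketch below shows one way to override the new defaults at application startup. Only the key names come from this patch; the values shown (10 and 40) and the comments on their meaning are illustrative.

```scala
import org.apache.spark.SparkConf

// Illustrative only: override the shipped defaults (25 / 100) for a
// memory-constrained cluster. The key names come from this change; the
// values here are hypothetical.
val conf = new SparkConf()
  .set("spark.shuffle.s3.prefetchBatchSize", "10")      // size of the prefetch window per reader
  .set("spark.shuffle.s3.prefetchThreadPoolSize", "40") // pool shared by all readers in the executor JVM
```

The same pair can equivalently be passed to `spark-submit` as `--conf` flags.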
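To make concrete what `prefetchBatchSize` bounds, here is a minimal, self-contained sketch of a sliding prefetch iterator in the spirit of `slidingPrefetchIterator`: at most `batchSize` tasks run ahead of the consumer, and results come back in submission order. This is an assumption-laden illustration of the general pattern, not the project's actual implementation; `PrefetchSketch` and `slidingPrefetch` are hypothetical names.

```scala
import scala.collection.mutable
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object PrefetchSketch {
  /** Keep up to `batchSize` tasks in flight ahead of the consumer and
    * yield their results in submission order. */
  def slidingPrefetch[A](tasks: Iterator[() => A], batchSize: Int)
                        (implicit ec: ExecutionContext): Iterator[A] =
    new Iterator[A] {
      private val inFlight = mutable.Queue.empty[Future[A]]

      // Top up the window so that at most `batchSize` futures are pending.
      private def fill(): Unit =
        while (inFlight.size < batchSize && tasks.hasNext) {
          val task = tasks.next() // pull on the caller's thread to avoid races
          inFlight.enqueue(Future(task()))
        }

      fill()

      override def hasNext: Boolean = inFlight.nonEmpty

      override def next(): A = {
        val result = Await.result(inFlight.dequeue(), Duration.Inf)
        fill() // slide the window forward as each result is consumed
        result
      }
    }
}
```

A larger window hides more storage latency but buffers more data in executor memory at once, which is why this change exposes the window size as a tunable instead of hardcoding 25.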