diff --git a/src/main/scala/org/apache/spark/shuffle/ConcurrentObjectMap.scala b/src/main/scala/org/apache/spark/shuffle/ConcurrentObjectMap.scala index f7101f5..1f09a50 100644 --- a/src/main/scala/org/apache/spark/shuffle/ConcurrentObjectMap.scala +++ b/src/main/scala/org/apache/spark/shuffle/ConcurrentObjectMap.scala @@ -20,13 +20,17 @@ class ConcurrentObjectMap[K, V] { } def getOrElsePut(key: K, op: K => V): V = { - val l = valueLocks.get(key).getOrElse({ - lock.synchronized { - valueLocks.getOrElseUpdate(key, { - new Object() - }) - } - }) + val l = valueLocks + .get(key) + .getOrElse({ + lock.synchronized { + valueLocks.getOrElseUpdate( + key, { + new Object() + } + ) + } + }) l.synchronized { return map.getOrElseUpdate(key, op(key)) } diff --git a/src/main/scala/org/apache/spark/shuffle/S3MeasureOutputStream.scala b/src/main/scala/org/apache/spark/shuffle/S3MeasureOutputStream.scala index ecd211f..db9b28e 100644 --- a/src/main/scala/org/apache/spark/shuffle/S3MeasureOutputStream.scala +++ b/src/main/scala/org/apache/spark/shuffle/S3MeasureOutputStream.scala @@ -11,7 +11,6 @@ class S3MeasureOutputStream(var out: OutputStream, label: String = "") extends O private var timings: Long = 0 private var bytes: Long = 0 - private def checkOpen(): Unit = { if (!isOpen) { throw new IOException("The stream is already closed!") @@ -58,7 +57,9 @@ class S3MeasureOutputStream(var out: OutputStream, label: String = "") extends O val sAt = tc.stageAttemptNumber() val t = timings / 1000000 val bw = bytes.toDouble / (t.toDouble / 1000) / (1024 * 1024) - logInfo(s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " + - s"Writing ${label} ${bytes} took ${t} ms (${bw} MiB/s)") + logInfo( + s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " + + s"Writing ${label} ${bytes} took ${t} ms (${bw} MiB/s)" + ) } } diff --git a/src/main/scala/org/apache/spark/shuffle/S3ShuffleDataIO.scala b/src/main/scala/org/apache/spark/shuffle/S3ShuffleDataIO.scala index b1ad17d..9e7cd94 100644 --- a/src/main/scala/org/apache/spark/shuffle/S3ShuffleDataIO.scala +++ b/src/main/scala/org/apache/spark/shuffle/S3ShuffleDataIO.scala @@ -36,9 +36,9 @@ class S3ShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { } override def createSingleFileMapOutputWriter( - shuffleId: Int, - mapId: Long - ): Optional[SingleSpillShuffleMapOutputWriter] = { + shuffleId: Int, + mapId: Long + ): Optional[SingleSpillShuffleMapOutputWriter] = { Optional.of(new S3SingleSpillShuffleMapOutputWriter(shuffleId, mapId)) } } @@ -67,4 +67,3 @@ class S3ShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { } } } - diff --git a/src/main/scala/org/apache/spark/shuffle/S3ShuffleMapOutputWriter.scala b/src/main/scala/org/apache/spark/shuffle/S3ShuffleMapOutputWriter.scala index 6a5ff25..920d908 100644 --- a/src/main/scala/org/apache/spark/shuffle/S3ShuffleMapOutputWriter.scala +++ b/src/main/scala/org/apache/spark/shuffle/S3ShuffleMapOutputWriter.scala @@ -19,19 +19,18 @@ import java.nio.ByteBuffer import java.nio.channels.{Channels, WritableByteChannel} import java.util.Optional -/** - * Implements the ShuffleMapOutputWriter interface. It stores the shuffle output in one - * shuffle block. - * - * This file is based on Spark "LocalDiskShuffleMapOutputWriter.java". - */ +/** Implements the ShuffleMapOutputWriter interface. It stores the shuffle output in one shuffle block. + * + * This file is based on Spark "LocalDiskShuffleMapOutputWriter.java". 
+ */ class S3ShuffleMapOutputWriter( - conf: SparkConf, - shuffleId: Int, - mapId: Long, - numPartitions: Int, - ) extends ShuffleMapOutputWriter with Logging { + conf: SparkConf, + shuffleId: Int, + mapId: Long, + numPartitions: Int +) extends ShuffleMapOutputWriter + with Logging { val dispatcher = S3ShuffleDispatcher.get /* Target block for writing */ @@ -44,7 +43,8 @@ class S3ShuffleMapOutputWriter( def initStream(): Unit = { if (stream == null) { stream = dispatcher.createBlock(shuffleBlock) - bufferedStream = new S3MeasureOutputStream(new BufferedOutputStream(stream, dispatcher.bufferSize), shuffleBlock.name) + bufferedStream = + new S3MeasureOutputStream(new BufferedOutputStream(stream, dispatcher.bufferSize), shuffleBlock.name) } } @@ -59,10 +59,11 @@ class S3ShuffleMapOutputWriter( private var totalBytesWritten: Long = 0 private var lastPartitionWriterId: Int = -1 - /** - * @param reducePartitionId Monotonically increasing, as per contract in ShuffleMapOutputWriter. - * @return An instance of the ShufflePartitionWriter exposing the single output stream. - */ + /** @param reducePartitionId + * Monotonically increasing, as per contract in ShuffleMapOutputWriter. + * @return + * An instance of the ShufflePartitionWriter exposing the single output stream. + */ override def getPartitionWriter(reducePartitionId: Int): ShufflePartitionWriter = { if (reducePartitionId <= lastPartitionWriterId) { throw new RuntimeException("Precondition: Expect a monotonically increasing reducePartitionId.") @@ -81,19 +82,21 @@ class S3ShuffleMapOutputWriter( new S3ShufflePartitionWriter(reducePartitionId) } - /** - * Close all writers and the shuffle block. - * - * @param checksums Ignored. - * @return - */ + /** Close all writers and the shuffle block. + * + * @param checksums + * Ignored. + * @return + */ override def commitAllPartitions(checksums: Array[Long]): MapOutputCommitMessage = { if (bufferedStream != null) { bufferedStream.flush() } if (stream != null) { if (stream.getPos != totalBytesWritten) { - throw new RuntimeException(f"S3ShuffleMapOutputWriter: Unexpected output length ${stream.getPos}, expected: ${totalBytesWritten}.") + throw new RuntimeException( + f"S3ShuffleMapOutputWriter: Unexpected output length ${stream.getPos}, expected: ${totalBytesWritten}." 
+ ) } } if (bufferedStreamAsChannel != null) { @@ -198,8 +201,7 @@ class S3ShuffleMapOutputWriter( } } - private class S3ShufflePartitionWriterChannel(reduceId: Int) - extends WritableByteChannelWrapper { + private class S3ShufflePartitionWriterChannel(reduceId: Int) extends WritableByteChannelWrapper { private val partChannel = new S3PartitionWritableByteChannel(bufferedStreamAsChannel) override def channel(): WritableByteChannel = { @@ -216,8 +218,7 @@ class S3ShuffleMapOutputWriter( } } - private class S3PartitionWritableByteChannel(channel: WritableByteChannel) - extends WritableByteChannel { + private class S3PartitionWritableByteChannel(channel: WritableByteChannel) extends WritableByteChannel { private var count: Long = 0 @@ -229,8 +230,7 @@ class S3ShuffleMapOutputWriter( channel.isOpen() } - override def close(): Unit = { - } + override def close(): Unit = {} override def write(x: ByteBuffer): Int = { var c = 0 diff --git a/src/main/scala/org/apache/spark/shuffle/S3ShuffleWriter.scala b/src/main/scala/org/apache/spark/shuffle/S3ShuffleWriter.scala index 14d3224..8d0543e 100644 --- a/src/main/scala/org/apache/spark/shuffle/S3ShuffleWriter.scala +++ b/src/main/scala/org/apache/spark/shuffle/S3ShuffleWriter.scala @@ -19,4 +19,3 @@ class S3ShuffleWriter[K, V](writer: ShuffleWriter[K, V]) extends ShuffleWriter[K override def getPartitionLengths(): Array[Long] = writer.getPartitionLengths() } - diff --git a/src/main/scala/org/apache/spark/shuffle/S3SingleSpillShuffleMapOutputWriter.scala b/src/main/scala/org/apache/spark/shuffle/S3SingleSpillShuffleMapOutputWriter.scala index 09592bc..49738b3 100644 --- a/src/main/scala/org/apache/spark/shuffle/S3SingleSpillShuffleMapOutputWriter.scala +++ b/src/main/scala/org/apache/spark/shuffle/S3SingleSpillShuffleMapOutputWriter.scala @@ -15,15 +15,17 @@ import org.apache.spark.util.Utils import java.io.{File, FileInputStream} import java.nio.file.{Files, Path} -class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends SingleSpillShuffleMapOutputWriter with Logging { +class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) + extends SingleSpillShuffleMapOutputWriter + with Logging { private lazy val dispatcher = S3ShuffleDispatcher.get override def transferMapSpillFile( - mapSpillFile: File, - partitionLengths: Array[Long], - checksums: Array[Long] - ): Unit = { + mapSpillFile: File, + partitionLengths: Array[Long], + checksums: Array[Long] + ): Unit = { val block = ShuffleDataBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) if (dispatcher.rootIsLocal) { @@ -44,8 +46,10 @@ class S3SingleSpillShuffleMapOutputWriter(shuffleId: Int, mapId: Long) extends S val sAt = tc.stageAttemptNumber() val t = timings / 1000000 val bw = bytes.toDouble / (t.toDouble / 1000) / (1024 * 1024) - logInfo(s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " + - s"Writing ${block.name} ${bytes} took ${t} ms (${bw} MiB/s)") + logInfo( + s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " + + s"Writing ${block.name} ${bytes} took ${t} ms (${bw} MiB/s)" + ) } else { // Copy using a stream. 
val in = new FileInputStream(mapSpillFile) diff --git a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala index 85d8ada..1a934e8 100644 --- a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala +++ b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala @@ -20,9 +20,8 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration.Duration import scala.concurrent.{Await, Future} -/** - * Helper class that configures Hadoop FS. - */ +/** Helper class that configures Hadoop FS. + */ class S3ShuffleDispatcher extends Logging { val executorId: String = SparkEnv.get.executorId val conf: SparkConf = SparkEnv.get.conf @@ -40,11 +39,15 @@ class S3ShuffleDispatcher extends Logging { val useSparkShuffleFetch: Boolean = conf.getBoolean("spark.shuffle.s3.useSparkShuffleFetch", defaultValue = false) private val fallbackStoragePath_ = conf.get(STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH) val fallbackStoragePath = if (fallbackStoragePath_.isEmpty && useSparkShuffleFetch) { - throw new SparkException(s"spark.shuffle.s3.useSparkShuffleFetch is set, but no ${STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH}") + throw new SparkException( + s"spark.shuffle.s3.useSparkShuffleFetch is set, but no ${STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH}" + ) } else { fallbackStoragePath_.getOrElse(s"${STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH} is not set.") } - private val rootDir_ = if (useSparkShuffleFetch) fallbackStoragePath else conf.get("spark.shuffle.s3.rootDir", defaultValue = "sparkS3shuffle/") + private val rootDir_ = + if (useSparkShuffleFetch) fallbackStoragePath + else conf.get("spark.shuffle.s3.rootDir", defaultValue = "sparkS3shuffle/") val rootDir: String = if (rootDir_.endsWith("/")) rootDir_ else rootDir_ + "/" val rootIsLocal: Boolean = URI.create(rootDir).getScheme == "file" @@ -66,9 +69,11 @@ class S3ShuffleDispatcher extends Logging { val checksumAlgorithm: String = SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) val checksumEnabled: Boolean = SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED) - val fs: FileSystem = FileSystem.get(URI.create(rootDir), { - SparkHadoopUtil.newConfiguration(conf) - }) + val fs: FileSystem = FileSystem.get( + URI.create(rootDir), { + SparkHadoopUtil.newConfiguration(conf) + } + ) val canSetReadahead = fs.hasPathCapability(new Path(rootDir), StreamCapabilities.READAHEAD) @@ -97,16 +102,18 @@ class S3ShuffleDispatcher extends Logging { logInfo(s"- ${config.SHUFFLE_CHECKSUM_ENABLED.key}=${checksumEnabled}") def removeRoot(): Boolean = { - Range(0, folderPrefixes).map(idx => { - Future { - val prefix = f"${rootDir}${idx}/${appId}" - try { - fs.delete(new Path(prefix), true) - } catch { - case _: IOException => logDebug(s"Unable to delete prefix ${prefix}") + Range(0, folderPrefixes) + .map(idx => { + Future { + val prefix = f"${rootDir}${idx}/${appId}" + try { + fs.delete(new Path(prefix), true) + } catch { + case _: IOException => logDebug(s"Unable to delete prefix ${prefix}") + } } - } - }).map(Await.result(_, Duration.Inf)) + }) + .map(Await.result(_, Duration.Inf)) true } @@ -124,10 +131,10 @@ class S3ShuffleDispatcher extends Logging { } if (useSparkShuffleFetch) { blockId match { - case ShuffleDataBlockId(_, _, _) => - case ShuffleIndexBlockId(_, _, _) => + case ShuffleDataBlockId(_, _, _) => + case ShuffleIndexBlockId(_, _, _) => case ShuffleChecksumBlockId(_, _, _) => - case _ => throw new 
SparkException(s"Unsupported block id type: ${blockId.name}") + case _ => throw new SparkException(s"Unsupported block id type: ${blockId.name}") } val hash = JavaUtils.nonNegativeHash(blockId.name) return new Path(f"${rootDir}${appId}/${shuffleId}/${hash}/${blockId.name}") @@ -146,35 +153,40 @@ class S3ShuffleDispatcher extends Logging { name.endsWith(".index") } } - Range(0, folderPrefixes).map(idx => { - Future { - val path = new Path(f"${rootDir}${idx}/${appId}/${shuffleId}/") - try { - fs.listStatus(path, shuffleIndexFilter).map(v => { - BlockId.apply(v.getPath.getName).asInstanceOf[ShuffleIndexBlockId] - }) - } catch { - case _: IOException => Array.empty[ShuffleIndexBlockId] + Range(0, folderPrefixes) + .map(idx => { + Future { + val path = new Path(f"${rootDir}${idx}/${appId}/${shuffleId}/") + try { + fs.listStatus(path, shuffleIndexFilter) + .map(v => { + BlockId.apply(v.getPath.getName).asInstanceOf[ShuffleIndexBlockId] + }) + } catch { + case _: IOException => Array.empty[ShuffleIndexBlockId] + } } - } - }).flatMap(Await.result(_, Duration.Inf)).toArray + }) + .flatMap(Await.result(_, Duration.Inf)) + .toArray } def removeShuffle(shuffleId: Int): Unit = { - Range(0, folderPrefixes).map(idx => { - val path = new Path(f"${rootDir}${idx}/${appId}/${shuffleId}/") - Future { - fs.delete(path, true) - } - }).foreach(Await.result(_, Duration.Inf)) + Range(0, folderPrefixes) + .map(idx => { + val path = new Path(f"${rootDir}${idx}/${appId}/${shuffleId}/") + Future { + fs.delete(path, true) + } + }) + .foreach(Await.result(_, Duration.Inf)) } - /** - * Open a block for reading. - * - * @param blockId - * @return - */ + /** Open a block for reading. + * + * @param blockId + * @return + */ def openBlock(blockId: BlockId): FSDataInputStream = { val status = getFileStatusCached(blockId) val builder = fs.openFile(status.getPath).withFileStatus(status) @@ -188,41 +200,44 @@ class S3ShuffleDispatcher extends Logging { private val cachedFileStatus = new ConcurrentObjectMap[BlockId, FileStatus]() def getFileStatusCached(blockId: BlockId): FileStatus = { - cachedFileStatus.getOrElsePut(blockId, (value: BlockId) => { - fs.getFileStatus(getPath(value)) - }) + cachedFileStatus.getOrElsePut( + blockId, + (value: BlockId) => { + fs.getFileStatus(getPath(value)) + } + ) } def closeCachedBlocks(shuffleIndex: Int): Unit = { - val filter = (blockId: BlockId) => blockId match { - case RDDBlockId(_, _) => false - case ShuffleBlockId(shuffleId, _, _) => shuffleId == shuffleIndex - case ShuffleBlockBatchId(shuffleId, _, _, _) => shuffleId == shuffleIndex - case ShuffleBlockChunkId(shuffleId, _, _, _) => shuffleId == shuffleIndex - case ShuffleDataBlockId(shuffleId, _, _) => shuffleId == shuffleIndex - case ShuffleIndexBlockId(shuffleId, _, _) => shuffleId == shuffleIndex - case ShuffleChecksumBlockId(shuffleId, _, _) => shuffleId == shuffleIndex - case ShufflePushBlockId(shuffleId, _, _, _) => shuffleId == shuffleIndex - case ShuffleMergedBlockId(shuffleId, _, _) => shuffleId == shuffleIndex - case ShuffleMergedDataBlockId(_, shuffleId, _, _) => shuffleId == shuffleIndex - case ShuffleMergedIndexBlockId(_, shuffleId, _, _) => shuffleId == shuffleIndex - case ShuffleMergedMetaBlockId(_, shuffleId, _, _) => shuffleId == shuffleIndex - case BroadcastBlockId(_, _) => false - case TaskResultBlockId(_) => false - case StreamBlockId(_, _) => false - case TempLocalBlockId(_) => false - case TempShuffleBlockId(_) => false - case TestBlockId(_) => false - } + val filter = (blockId: BlockId) => + blockId match { + case 
RDDBlockId(_, _) => false + case ShuffleBlockId(shuffleId, _, _) => shuffleId == shuffleIndex + case ShuffleBlockBatchId(shuffleId, _, _, _) => shuffleId == shuffleIndex + case ShuffleBlockChunkId(shuffleId, _, _, _) => shuffleId == shuffleIndex + case ShuffleDataBlockId(shuffleId, _, _) => shuffleId == shuffleIndex + case ShuffleIndexBlockId(shuffleId, _, _) => shuffleId == shuffleIndex + case ShuffleChecksumBlockId(shuffleId, _, _) => shuffleId == shuffleIndex + case ShufflePushBlockId(shuffleId, _, _, _) => shuffleId == shuffleIndex + case ShuffleMergedBlockId(shuffleId, _, _) => shuffleId == shuffleIndex + case ShuffleMergedDataBlockId(_, shuffleId, _, _) => shuffleId == shuffleIndex + case ShuffleMergedIndexBlockId(_, shuffleId, _, _) => shuffleId == shuffleIndex + case ShuffleMergedMetaBlockId(_, shuffleId, _, _) => shuffleId == shuffleIndex + case BroadcastBlockId(_, _) => false + case TaskResultBlockId(_) => false + case StreamBlockId(_, _) => false + case TempLocalBlockId(_) => false + case TempShuffleBlockId(_) => false + case TestBlockId(_) => false + } cachedFileStatus.remove(filter, _) } - /** - * Open a block for writing. - * - * @param blockId - * @return - */ + /** Open a block for writing. + * + * @param blockId + * @return + */ def createBlock(blockId: BlockId): FSDataOutputStream = { fs.create(getPath(blockId)) } diff --git a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleHelper.scala b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleHelper.scala index aa9777d..8cb38a0 100644 --- a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleHelper.scala +++ b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleHelper.scala @@ -15,11 +15,10 @@ object S3ShuffleHelper extends Logging { private val cachedChecksums = new ConcurrentObjectMap[ShuffleChecksumBlockId, Array[Long]]() private val cachedArrayLengths = new ConcurrentObjectMap[ShuffleIndexBlockId, Array[Long]]() - /** - * Purge cached shuffle indices. - * - * @param shuffleIndex - */ + /** Purge cached shuffle indices. + * + * @param shuffleIndex + */ def purgeCachedDataForShuffle(shuffleIndex: Int): Unit = { if (dispatcher.cachePartitionLengths) { val filter = (block: ShuffleIndexBlockId) => block.shuffleId == shuffleIndex @@ -36,13 +35,12 @@ object S3ShuffleHelper extends Logging { cachedArrayLengths.clear() } - /** - * Write partitionLengths for block with shuffleId and mapId at 0. - * - * @param shuffleId - * @param mapId - * @param partitionLengths - */ + /** Write partitionLengths for block with shuffleId and mapId at 0. + * + * @param shuffleId + * @param mapId + * @param partitionLengths + */ def writePartitionLengths(shuffleId: Int, mapId: Long, partitionLengths: Array[Long]): Unit = { val accumulated = Array[Long](0) ++ partitionLengths.tail.scan(partitionLengths.head)(_ + _) writeArrayAsBlock(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID), accumulated) @@ -60,23 +58,21 @@ object S3ShuffleHelper extends Logging { out.close() } - /** - * Get the cached partition length for shuffle index at shuffleId and mapId - * - * @param shuffleId - * @param mapId - * @return - */ + /** Get the cached partition length for shuffle index at shuffleId and mapId + * + * @param shuffleId + * @param mapId + * @return + */ def getPartitionLengths(shuffleId: Int, mapId: Long): Array[Long] = { getPartitionLengths(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } - /** - * Get the cached partition length for the shuffleIndex block. 
- * - * @param blockId - * @return - */ + /** Get the cached partition length for the shuffleIndex block. + * + * @param blockId + * @return + */ def getPartitionLengths(blockId: ShuffleIndexBlockId): Array[Long] = { if (dispatcher.cachePartitionLengths) { return cachedArrayLengths.getOrElsePut(blockId, readBlockAsArray) @@ -109,7 +105,9 @@ object S3ShuffleHelper extends Logging { private def readBlockAsArray(blockId: BlockId): Array[Long] = { val stat = dispatcher.getFileStatusCached(blockId) val fileLength = stat.getLen.toInt - val input = new DataInputStream(new BufferedInputStream(dispatcher.openBlock(blockId), math.min(fileLength, dispatcher.bufferSize))) + val input = new DataInputStream( + new BufferedInputStream(dispatcher.openBlock(blockId), math.min(fileLength, dispatcher.bufferSize)) + ) val count = fileLength / 8 if (fileLength % 8 != 0) { throw new SparkException(s"Unexpected file length when reading ${blockId.name}") diff --git a/src/main/scala/org/apache/spark/shuffle/sort/S3ShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/sort/S3ShuffleManager.scala index 8f36539..28d744d 100644 --- a/src/main/scala/org/apache/spark/shuffle/sort/S3ShuffleManager.scala +++ b/src/main/scala/org/apache/spark/shuffle/sort/S3ShuffleManager.scala @@ -38,28 +38,23 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration.Duration import scala.concurrent.{Await, Future} - -/** - * This class was adapted from Apache Spark: SortShuffleManager.scala - */ +/** This class was adapted from Apache Spark: SortShuffleManager.scala + */ private[spark] class S3ShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { val versionString = s"${SparkS3ShuffleBuild.name}-${SparkS3ShuffleBuild.version} " + s"for ${SparkS3ShuffleBuild.sparkVersion}_${SparkS3ShuffleBuild.scalaVersion}" logInfo(s"Configured S3ShuffleManager (${versionString}).") private lazy val dispatcher = S3ShuffleDispatcher.get private lazy val shuffleExecutorComponents = S3ShuffleManager.loadShuffleExecutorComponents(conf) - /** - * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles. - */ + + /** A mapping from shuffle ids to the task ids of mappers producing output for those shuffles. + */ override lazy val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) private val registeredShuffleIds = new mutable.HashSet[Int]() - /** - * Obtains a [[ShuffleHandle]] to pass to tasks. - */ - override def registerShuffle[K, V, C]( - shuffleId: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + /** Obtains a [[ShuffleHandle]] to pass to tasks. + */ + override def registerShuffle[K, V, C](shuffleId: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { registeredShuffleIds.add(shuffleId) if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't @@ -68,13 +63,11 @@ private[spark] class S3ShuffleManager(conf: SparkConf) extends ShuffleManager wi // together the spilled files, which would happen with the normal code path. The downside is // having multiple files open at a time and thus more memory allocated to buffers. 
logInfo(f"Using BypassMergeSortShuffleWriter for ${shuffleId}") - new BypassMergeSortShuffleHandle[K, V]( - shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + new BypassMergeSortShuffleHandle[K, V](shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: logInfo(f"Using UnsafeShuffleWriter for ${shuffleId}") - new SerializedShuffleHandle[K, V]( - shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + new SerializedShuffleHandle[K, V](shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) } else { logInfo(f"Using SortShuffleWriter for ${shuffleId}") // Otherwise, buffer map outputs in a deserialized form: @@ -83,40 +76,55 @@ private[spark] class S3ShuffleManager(conf: SparkConf) extends ShuffleManager wi } override def getReader[K, C]( - handle: ShuffleHandle, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - context: TaskContext, - metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter + ): ShuffleReader[K, C] = { if (dispatcher.useSparkShuffleFetch) { val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( - handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + handle.shuffleId, + startMapIndex, + endMapIndex, + startPartition, + endPartition + ) val canEnableBatchFetch = true return new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, shouldBatchFetch = - canEnableBatchFetch && SortShuffleManager.canUseBatchFetch(startPartition, endPartition, context)) + canEnableBatchFetch && SortShuffleManager.canUseBatchFetch(startPartition, endPartition, context) + ) } new S3ShuffleReader( conf, handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - context, metrics, - startMapIndex, endMapIndex, - startPartition, endPartition, - shouldBatchFetch = SortShuffleManager.canUseBatchFetch(startPartition, endPartition, context)) + context, + metrics, + startMapIndex, + endMapIndex, + startPartition, + endPartition, + shouldBatchFetch = SortShuffleManager.canUseBatchFetch(startPartition, endPartition, context) + ) } /** Get a writer for a given partition. Called on executors by map tasks. 
*/ override def getWriter[K, V]( - handle: ShuffleHandle, - mapId: Long, - context: TaskContext, - metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter + ): ShuffleWriter[K, V] = { val env = SparkEnv.get val writer = handle match { - case unsafeShuffleHandle: SerializedShuffleHandle[K@unchecked, V@unchecked] => + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( env.blockManager, context.taskMemoryManager(), @@ -125,8 +133,9 @@ private[spark] class S3ShuffleManager(conf: SparkConf) extends ShuffleManager wi context, env.conf, metrics, - shuffleExecutorComponents) - case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K@unchecked, V@unchecked] => + shuffleExecutorComponents + ) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, bypassMergeSortHandle, @@ -134,8 +143,8 @@ private[spark] class S3ShuffleManager(conf: SparkConf) extends ShuffleManager wi env.conf, metrics, shuffleExecutorComponents - ) - case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] => + ) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents) } new S3ShuffleWriter[K, V](writer) @@ -166,11 +175,10 @@ private[spark] class S3ShuffleManager(conf: SparkConf) extends ShuffleManager wi /** Shut down this ShuffleManager. */ override def stop(): Unit = { val cleanupRequired = registeredShuffleIds.nonEmpty - registeredShuffleIds.foreach( - shuffleId => { - purgeCaches(shuffleId) - registeredShuffleIds.remove(shuffleId) - }) + registeredShuffleIds.foreach(shuffleId => { + purgeCaches(shuffleId) + registeredShuffleIds.remove(shuffleId) + }) if (cleanupRequired) { if (dispatcher.cleanupShuffleFiles) { logInfo(f"Cleaning up shuffle files in ${dispatcher.rootDir}.") @@ -186,15 +194,13 @@ private[spark] class S3ShuffleManager(conf: SparkConf) extends ShuffleManager wi private[spark] object S3ShuffleManager { private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { if (conf.get("spark.shuffle.sort.io.plugin.class") != "org.apache.spark.shuffle.S3ShuffleDataIO") { - throw new RuntimeException("\"spark.shuffle.sort.io.plugin.class\" needs to be set to \"org.apache.spark.shuffle.S3ShuffleDataIO\" in order for this plugin to work!") + throw new RuntimeException( + "\"spark.shuffle.sort.io.plugin.class\" needs to be set to \"org.apache.spark.shuffle.S3ShuffleDataIO\" in order for this plugin to work!" 
+ ) } val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() - val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX) - .toMap - executorComponents.initializeExecutor( - conf.getAppId, - SparkEnv.get.executorId, - extraConfigs.asJava) + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap + executorComponents.initializeExecutor(conf.getAppId, SparkEnv.get.executorId, extraConfigs.asJava) executorComponents } } diff --git a/src/main/scala/org/apache/spark/storage/S3BufferedInputStreamAdaptor.scala b/src/main/scala/org/apache/spark/storage/S3BufferedInputStreamAdaptor.scala index 3059312..fc83b18 100644 --- a/src/main/scala/org/apache/spark/storage/S3BufferedInputStreamAdaptor.scala +++ b/src/main/scala/org/apache/spark/storage/S3BufferedInputStreamAdaptor.scala @@ -4,7 +4,9 @@ import org.apache.spark.internal.Logging import java.io.{BufferedInputStream, EOFException, InputStream} -class S3BufferedInputStreamAdaptor(inputStream: InputStream, bufferSize: Int, onClose: (Int) => Unit) extends InputStream with Logging { +class S3BufferedInputStreamAdaptor(inputStream: InputStream, bufferSize: Int, onClose: (Int) => Unit) + extends InputStream + with Logging { private var bufferedStream = new BufferedInputStream(inputStream, bufferSize) @@ -24,7 +26,6 @@ class S3BufferedInputStreamAdaptor(inputStream: InputStream, bufferSize: Int, on } } - def read(): Int = synchronized { checkOpen() bufferedStream.read() diff --git a/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala b/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala index d205030..dfd9aed 100644 --- a/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala +++ b/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala @@ -13,7 +13,9 @@ import java.io.{BufferedInputStream, InputStream} import java.util import java.util.concurrent.atomic.AtomicLong -class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)], maxBufferSize: Long) extends Iterator[(BlockId, InputStream)] with Logging { +class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)], maxBufferSize: Long) + extends Iterator[(BlockId, InputStream)] + with Logging { private val startTime = System.nanoTime() @volatile private var memoryUsage: Long = 0 @@ -123,8 +125,7 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)] if (memoryUsage + bsize > maxBufferSize) { try { wait() - } - catch { + } catch { case _: InterruptedException => } } else { @@ -173,10 +174,12 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)] val bs = bR / r // Threads val ta = desiredActiveThreads.get() - logInfo(s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " + - s"${bR} bytes, ${tW} ms waiting (${atW} avg), " + - s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s). " + - s"Total: ${tR} ms - ${wPer}% waiting. ${ta} active threads.") + logInfo( + s"Statistics: Stage ${sId}.${sAt} TID ${tc.taskAttemptId()} -- " + + s"${bR} bytes, ${tW} ms waiting (${atW} avg), " + + s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s). " + + s"Total: ${tR} ms - ${wPer}% waiting. ${ta} active threads." 
+ ) } catch { case e: Exception => logError(f"Unable to print statistics: ${e.getMessage}.") } diff --git a/src/main/scala/org/apache/spark/storage/S3ChecksumValidationStream.scala b/src/main/scala/org/apache/spark/storage/S3ChecksumValidationStream.scala index 4983f10..39ac769 100644 --- a/src/main/scala/org/apache/spark/storage/S3ChecksumValidationStream.scala +++ b/src/main/scala/org/apache/spark/storage/S3ChecksumValidationStream.scala @@ -12,17 +12,16 @@ import org.apache.spark.shuffle.helper.S3ShuffleHelper import java.io.InputStream import java.util.zip.Checksum -/** - * Validates checksum stored for blockId on stream with checksumAlgorithm. - */ -class S3ChecksumValidationStream( - blockId: BlockId, - stream: InputStream, - checksumAlgorithm: String) extends InputStream with Logging { +/** Validates checksum stored for blockId on stream with checksumAlgorithm. + */ +class S3ChecksumValidationStream(blockId: BlockId, stream: InputStream, checksumAlgorithm: String) + extends InputStream + with Logging { private val (shuffleId: Int, mapId: Long, startReduceId: Int, endReduceId: Int) = blockId match { case ShuffleBlockId(shuffleId, mapId, reduceId) => (shuffleId, mapId, reduceId, reduceId + 1) - case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, endReduceId) => (shuffleId, mapId, startReduceId, endReduceId) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, endReduceId) => + (shuffleId, mapId, startReduceId, endReduceId) case _ => throw new SparkException(s"S3ChecksumValidationStream does not support block type ${blockId}") } @@ -33,7 +32,7 @@ class S3ChecksumValidationStream( private var pos: Long = 0 private var reduceId: Int = startReduceId - private var blockLength: Long = lengths(reduceId+1) - lengths(reduceId) + private var blockLength: Long = lengths(reduceId + 1) - lengths(reduceId) private def eof(): Boolean = reduceId > endReduceId @@ -77,7 +76,7 @@ class S3ChecksumValidationStream( pos = 0 reduceId += 1 if (reduceId < endReduceId) { - blockLength = lengths(reduceId+1) - lengths(reduceId) + blockLength = lengths(reduceId + 1) - lengths(reduceId) if (blockLength == 0) { validateChecksum() } diff --git a/src/main/scala/org/apache/spark/storage/S3ShuffleBlockIterator.scala b/src/main/scala/org/apache/spark/storage/S3ShuffleBlockIterator.scala index 8029652..fd9f4ab 100644 --- a/src/main/scala/org/apache/spark/storage/S3ShuffleBlockIterator.scala +++ b/src/main/scala/org/apache/spark/storage/S3ShuffleBlockIterator.scala @@ -8,8 +8,8 @@ package org.apache.spark.storage import org.apache.spark.shuffle.helper.{S3ShuffleDispatcher, S3ShuffleHelper} class S3ShuffleBlockIterator( - shuffleBlocks: Iterator[BlockId], - ) extends Iterator[(BlockId, S3ShuffleBlockStream)] { + shuffleBlocks: Iterator[BlockId] +) extends Iterator[(BlockId, S3ShuffleBlockStream)] { private val dispatcher = S3ShuffleDispatcher.get @@ -30,9 +30,8 @@ class S3ShuffleBlockIterator( do { val nextBlock = shuffleBlocks.next() - /** - * Ignore missing index files if `alwaysCreateIndex` is configured. - */ + /** Ignore missing index files if `alwaysCreateIndex` is configured. 
+ */ try { val stream = nextBlock match { case ShuffleBlockId(shuffleId, mapId, reduceId) => diff --git a/src/main/scala/org/apache/spark/storage/S3ShuffleBlockStream.scala b/src/main/scala/org/apache/spark/storage/S3ShuffleBlockStream.scala index be50b2b..cdad34f 100644 --- a/src/main/scala/org/apache/spark/storage/S3ShuffleBlockStream.scala +++ b/src/main/scala/org/apache/spark/storage/S3ShuffleBlockStream.scala @@ -11,16 +11,16 @@ import org.apache.spark.shuffle.helper.S3ShuffleDispatcher import java.io.{IOException, InputStream} -/** - * InputStream that reads data from a shuffleBlock, mapId and exposes an InputStream from startReduceId to endReduceId. - */ +/** InputStream that reads data from a shuffleBlock, mapId and exposes an InputStream from startReduceId to endReduceId. + */ class S3ShuffleBlockStream( - shuffleId: Int, - mapId: Long, - startReduceId: Int, - endReduceId: Int, - accumulatedPositions: Array[Long], - ) extends InputStream with Logging { + shuffleId: Int, + mapId: Long, + startReduceId: Int, + endReduceId: Int, + accumulatedPositions: Array[Long] +) extends InputStream + with Logging { private lazy val dispatcher = S3ShuffleDispatcher.get private lazy val blockId = ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID) private lazy val stream = { diff --git a/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala b/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala index 2bc0300..9c1fbba 100644 --- a/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala +++ b/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala @@ -32,21 +32,21 @@ import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, TaskContext} -/** - * This class was adapted from Apache Spark: BlockStoreShuffleReader. - */ +/** This class was adapted from Apache Spark: BlockStoreShuffleReader. + */ class S3ShuffleReader[K, C]( - conf: SparkConf, - handle: BaseShuffleHandle[K, _, C], - context: TaskContext, - readMetrics: ShuffleReadMetricsReporter, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - serializerManager: SerializerManager = SparkEnv.get.serializerManager, - shouldBatchFetch: Boolean - ) extends ShuffleReader[K, C] with Logging { + conf: SparkConf, + handle: BaseShuffleHandle[K, _, C], + context: TaskContext, + readMetrics: ShuffleReadMetricsReporter, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + shouldBatchFetch: Boolean +) extends ShuffleReader[K, C] + with Logging { private val dispatcher = S3ShuffleDispatcher.get private val dep = handle.dependency @@ -64,29 +64,37 @@ class S3ShuffleReader[K, C]( val doBatchFetch = shouldBatchFetch && serializerRelocatable && (!compressed || codecConcatenation) && !ioEncryption if (shouldBatchFetch && !doBatchFetch) { - logDebug("The feature tag of continuous shuffle block fetching is set to true, but " + - "we can not enable the feature because other conditions are not satisfied. " + - s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + - s"codec concatenation: $codecConcatenation, io encryption: $ioEncryption.") + logDebug( + "The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. 
" + + s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation, io encryption: $ioEncryption." + ) } doBatchFetch } override def read(): Iterator[Product2[K, C]] = { val serializerInstance = dep.serializer.newInstance() - val blocks = computeShuffleBlocks(handle.shuffleId, - startMapIndex, endMapIndex, - startPartition, endPartition, - doBatchFetch = fetchContinousBlocksInBatch, - useBlockManager = dispatcher.useBlockManager) + val blocks = computeShuffleBlocks( + handle.shuffleId, + startMapIndex, + endMapIndex, + startPartition, + endPartition, + doBatchFetch = fetchContinousBlocksInBatch, + useBlockManager = dispatcher.useBlockManager + ) val wrappedStreams = new S3ShuffleBlockIterator(blocks) - val filteredStream = wrappedStreams.filterNot(_._2.maxBytes == 0).map(f => { - readMetrics.incRemoteBytesRead(f._2.maxBytes) // increase byte count. - readMetrics.incRemoteBlocksFetched(1) - f - }) + val filteredStream = wrappedStreams + .filterNot(_._2.maxBytes == 0) + .map(f => { + readMetrics.incRemoteBytesRead(f._2.maxBytes) // increase byte count. + readMetrics.incRemoteBlocksFetched(1) + f + }) val recordIter = new S3BufferedPrefetchIterator(filteredStream, maxBufferSizeTask) .flatMap(s => { val stream = s._2 @@ -107,7 +115,8 @@ class S3ShuffleReader[K, C]( readMetrics.incRecordsRead(1) record }, - context.taskMetrics().mergeShuffleReadMetrics()) + context.taskMetrics().mergeShuffleReadMetrics() + ) // An interruptible iterator must be used here in order to support task cancellation val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) @@ -141,7 +150,7 @@ class S3ShuffleReader[K, C]( resultIter match { case _: InterruptibleIterator[Product2[K, C]] => resultIter - case _ => + case _ => // Use another interruptible iterator here to support task cancellation as aggregator // or(and) sorter may have consumed previous interruptible iterator. 
new InterruptibleIterator[Product2[K, C]](context, resultIter) @@ -149,27 +158,40 @@ class S3ShuffleReader[K, C]( } private def computeShuffleBlocks( - shuffleId: Int, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - doBatchFetch: Boolean, - useBlockManager: Boolean - ): Iterator[BlockId] = { + shuffleId: Int, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + doBatchFetch: Boolean, + useBlockManager: Boolean + ): Iterator[BlockId] = { if (useBlockManager) { val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( - handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) - blocksByAddress.map(f => f._2.map(info => FetchBlockInfo(info._1, info._2, info._3))) - .flatMap(info => ShuffleBlockFetcherIterator.mergeContinuousShuffleBlockIdsIfNeeded(info, doBatchFetch)) - .map(_.blockId) + handle.shuffleId, + startMapIndex, + endMapIndex, + startPartition, + endPartition + ) + blocksByAddress + .map(f => f._2.map(info => FetchBlockInfo(info._1, info._2, info._3))) + .flatMap(info => ShuffleBlockFetcherIterator.mergeContinuousShuffleBlockIdsIfNeeded(info, doBatchFetch)) + .map(_.blockId) } else { - val indices = dispatcher.listShuffleIndices(shuffleId).filter( - block => block.mapId >= startMapIndex && block.mapId < endMapIndex) + val indices = dispatcher + .listShuffleIndices(shuffleId) + .filter(block => block.mapId >= startMapIndex && block.mapId < endMapIndex) if (doBatchFetch || dispatcher.forceBatchFetch) { indices.map(block => ShuffleBlockBatchId(block.shuffleId, block.mapId, startPartition, endPartition)).toIterator } else { - indices.flatMap(block => Range(startPartition, endPartition).map(partition => ShuffleBlockId(block.shuffleId, block.mapId, partition))).toIterator + indices + .flatMap(block => + Range(startPartition, endPartition).map(partition => + ShuffleBlockId(block.shuffleId, block.mapId, partition) + ) + ) + .toIterator } } } diff --git a/src/test/scala-2.12/org/apache/spark/shuffle/S3ShuffleManagerTest.scala b/src/test/scala-2.12/org/apache/spark/shuffle/S3ShuffleManagerTest.scala index ec9e532..75752fa 100644 --- a/src/test/scala-2.12/org/apache/spark/shuffle/S3ShuffleManagerTest.scala +++ b/src/test/scala-2.12/org/apache/spark/shuffle/S3ShuffleManagerTest.scala @@ -62,9 +62,10 @@ class S3ShuffleManagerTest { val sc = new SparkContext(conf) try { // Test copied from: src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala - val rdd = sc.parallelize(1 to 5, 4) - .map(key => (KeyClass(), ValueClass())) - .groupByKey() + val rdd = sc + .parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .groupByKey() val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] assert(!dep.mapSideCombine, "Test requires that no map-side aggregator is defined") assert(dep.keyClassName == classOf[KeyClass].getName) @@ -83,12 +84,13 @@ class S3ShuffleManagerTest { val numValues = 10000 val numMaps = 3 - val rdd = sc.parallelize(0 until numValues, numMaps) - .map(t => { - val rand = scala.util.Random - (t) -> rand.nextInt(numValues) - }) - .sortBy(_._2, ascending = true) + val rdd = sc + .parallelize(0 until numValues, numMaps) + .map(t => { + val rand = scala.util.Random + (t) -> rand.nextInt(numValues) + }) + .sortBy(_._2, ascending = true) val result = rdd.collect() var previous = result(0)._2 @@ -109,13 +111,12 @@ class S3ShuffleManagerTest { try { val numValuesPerPartition = 100000 val numPartitions = 20 - val dataset = sc.parallelize(0 
until numPartitions, numPartitions).mapPartitionsWithIndex { - case (index, _) => - Iterator.tabulate(numValuesPerPartition) { offset => - val key = offset - val value = offset*index - (key, value*2) - } + val dataset = sc.parallelize(0 until numPartitions, numPartitions).mapPartitionsWithIndex { case (index, _) => + Iterator.tabulate(numValuesPerPartition) { offset => + val key = offset + val value = offset * index + (key, value * 2) + } } def convert_value(v: Int) = { @@ -156,14 +157,13 @@ class S3ShuffleManagerTest { val numValuesPerPartition = 10000 val numPartitions = 5 - val dataset = sc.parallelize(0 until numPartitions).mapPartitionsWithIndex { - case (index, _) => - val rand = scala.util.Random - Iterator.tabulate(numValuesPerPartition) { offset => - val key = rand.nextInt() - val value = rand.nextInt() - (key, value) - } + val dataset = sc.parallelize(0 until numPartitions).mapPartitionsWithIndex { case (index, _) => + val rand = scala.util.Random + Iterator.tabulate(numValuesPerPartition) { offset => + val key = rand.nextInt() + val value = rand.nextInt() + (key, value) + } } val sorted = dataset.sortByKey(true, numPartitions - 1) val result = sorted.collect() @@ -195,12 +195,14 @@ class S3ShuffleManagerTest { val metrics = stageMetrics.aggregateStageMetrics(s"spark_measure_test_${timestamp}") // get all of the stats val (runTime, bytesRead, recordsRead, bytesWritten, recordsWritten) = - metrics.select("elapsedTime", "bytesRead", - "recordsRead", "bytesWritten", "recordsWritten") - .take(1) - .map(r => (r.getLong(0), r.getLong(1), r.getLong(2), r.getLong(3), - r.getLong(4))).head - println(f"Elapsed: ${runTime}, bytesRead: ${bytesRead}, recordsRead: ${recordsRead}, bytesWritten ${bytesWritten}, recordsWritten: ${recordsWritten}") + metrics + .select("elapsedTime", "bytesRead", "recordsRead", "bytesWritten", "recordsWritten") + .take(1) + .map(r => (r.getLong(0), r.getLong(1), r.getLong(2), r.getLong(3), r.getLong(4))) + .head + println( + f"Elapsed: ${runTime}, bytesRead: ${bytesRead}, recordsRead: ${recordsRead}, bytesWritten ${bytesWritten}, recordsWritten: ${recordsWritten}" + ) spark.stop() spark.close() } @@ -213,9 +215,10 @@ class S3ShuffleManagerTest { val numMaps = 3 val numPartitions = 5 - val rdd = sc.parallelize(0 until numValues, numMaps) - .map(t => ((t / 2) -> (t * 2).longValue())) - .foldByKey(0, numPartitions)((v1, v2) => v1 + v2) + val rdd = sc + .parallelize(0 until numValues, numMaps) + .map(t => ((t / 2) -> (t * 2).longValue())) + .foldByKey(0, numPartitions)((v1, v2) => v1 + v2) val result = rdd.collect() assert(result.length === numValues / 2)
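
Usage sketch (illustrative, not part of the patch): a minimal driver that wires this plugin into a job. The two plugin class names and the spark.shuffle.sort.io.plugin.class / spark.shuffle.s3.rootDir keys are taken from the sources above; spark.shuffle.manager is the standard Spark setting for a custom ShuffleManager, and the master, app name, and rootDir values are placeholder assumptions (a file:// rootDir keeps the example self-contained, since the dispatcher also accepts local paths).

import org.apache.spark.{SparkConf, SparkContext}

object S3ShufflePluginExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")                 // placeholder; usually supplied via spark-submit
      .setAppName("s3-shuffle-demo")
      // Route shuffle through the plugin's ShuffleManager.
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.S3ShuffleManager")
      // Required: S3ShuffleManager refuses to start unless this plugin class is configured.
      .set("spark.shuffle.sort.io.plugin.class", "org.apache.spark.shuffle.S3ShuffleDataIO")
      // Where shuffle blocks are stored; defaults to "sparkS3shuffle/" if unset.
      .set("spark.shuffle.s3.rootDir", "file:///tmp/spark-s3-shuffle/")

    val sc = new SparkContext(conf)
    try {
      // Any wide transformation exercises the shuffle write/read path provided by the plugin.
      val counts = sc
        .parallelize(1 to 100000, 8)
        .map(i => (i % 10, 1L))
        .reduceByKey(_ + _)
        .collect()
      counts.sortBy(_._1).foreach { case (k, v) => println(s"$k -> $v") }
    } finally {
      sc.stop()
    }
  }
}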