diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md
index 25bba0dbd90..0d4a3267e30 100644
--- a/docs/additional-functionality/advanced_configs.md
+++ b/docs/additional-functionality/advanced_configs.md
@@ -60,7 +60,6 @@ Name | Description | Default Value | Applicable at
spark.rapids.shuffle.ucx.activeMessages.forceRndv|Set to true to force 'rndv' mode for all UCX Active Messages. This should only be required with UCX 1.10.x. UCX 1.11.x deployments should set this to false.|false|Startup
spark.rapids.shuffle.ucx.managementServerHost|The host to be used to start the management server|null|Startup
spark.rapids.shuffle.ucx.useWakeup|When set to true, use UCX's event-based progress (epoll) in order to wake up the progress thread when needed, instead of a hot loop.|true|Startup
-spark.rapids.sql.agg.fallbackAlgorithm|When agg cannot be done in a single pass, use sort-based fallback or repartition-based fallback.|sort|Runtime
spark.rapids.sql.agg.skipAggPassReductionRatio|In non-final aggregation stages, if the previous pass has a row reduction ratio greater than this value, the next aggregation pass will be skipped. Setting this to 1 essentially disables this feature.|1.0|Runtime
spark.rapids.sql.allowMultipleJars|Allow multiple rapids-4-spark, spark-rapids-jni, and cudf jars on the classpath. Spark will take the first one it finds, so the version may not be the one expected. Possible values are ALWAYS: allow all jars, SAME_REVISION: only allow jars with the same revision, NEVER: do not allow multiple jars at all.|SAME_REVISION|Startup
spark.rapids.sql.castDecimalToFloat.enabled|Casting from decimal to floating point types on the GPU returns results that have tiny differences compared to results returned from the CPU.|true|Runtime
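
As a quick illustration of how the runtime-applicable settings in the table above can be changed (a minimal sketch, not part of this change; the key names come from the rows above and the values are arbitrary examples):

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch: runtime-applicable settings can be changed on an existing
// session; startup-applicable settings must be supplied before the plugin loads.
val spark = SparkSession.builder().getOrCreate()

// With 0.8, a non-final aggregation pass is skipped when the previous pass
// kept more than 80% of its input rows. The default of 1.0 disables the skip.
spark.conf.set("spark.rapids.sql.agg.skipAggPassReductionRatio", "0.8")

// Keep GPU decimal-to-float casts enabled despite the tiny result differences
// noted in the table.
spark.conf.set("spark.rapids.sql.castDecimalToFloat.enabled", "true")
```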
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala
index 96254b9f38d..926f770a683 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala
@@ -16,7 +16,7 @@
package com.nvidia.spark.rapids
import scala.collection.mutable
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.collection.mutable.ArrayBuffer
import scala.util.control.ControlThrowable
import com.nvidia.spark.rapids.RapidsPluginImplicits._
@@ -43,8 +43,7 @@ object Arm extends ArmScalaSpecificImpl {
}
/** Executes the provided code block and then closes the sequence of resources */
- def withResource[T <: AutoCloseable, V](r: scala.collection.Seq[T])
- (block: scala.collection.Seq[T] => V): V = {
+ def withResource[T <: AutoCloseable, V](r: Seq[T])(block: Seq[T] => V): V = {
try {
block(r)
} finally {
@@ -135,20 +134,6 @@ object Arm extends ArmScalaSpecificImpl {
}
}
- /** Executes the provided code block, closing the resources only if an exception occurs */
- def closeOnExcept[T <: AutoCloseable, V](r: ListBuffer[T])(block: ListBuffer[T] => V): V = {
- try {
- block(r)
- } catch {
- case t: ControlThrowable =>
- // Don't close for these cases..
- throw t
- case t: Throwable =>
- r.safeClose(t)
- throw t
- }
- }
-
/** Executes the provided code block, closing the resources only if an exception occurs */
def closeOnExcept[T <: AutoCloseable, V](r: mutable.Queue[T])(block: mutable.Queue[T] => V): V = {
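
For context, a minimal usage sketch of the simplified `withResource` overload above; the resource class is hypothetical and exists only for illustration:

```scala
import com.nvidia.spark.rapids.Arm.withResource

// Hypothetical resource type, standing in for any AutoCloseable (e.g. a GPU buffer).
class FakeHandle(val id: Int) extends AutoCloseable {
  override def close(): Unit = println(s"closing handle $id")
}

// Every handle in the sequence is closed after the block finishes, whether the
// block returns normally or throws.
val ids: Seq[Int] =
  withResource(Seq(new FakeHandle(1), new FakeHandle(2), new FakeHandle(3))) { handles =>
    handles.map(_.id)
  }
```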
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
index b28101f3442..7e6a1056d01 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
@@ -16,9 +16,11 @@
package com.nvidia.spark.rapids
+import java.util
+
import scala.annotation.tailrec
+import scala.collection.JavaConverters.collectionAsScalaIterableConverter
import scala.collection.mutable
-import scala.collection.mutable.ListBuffer
import ai.rapids.cudf
import ai.rapids.cudf.{NvtxColor, NvtxRange}
@@ -44,11 +46,11 @@ import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.{ExplainUtils, SortExec, SparkPlan}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, CudfAggregate, GpuAggregateExpression, GpuToCpuAggregateBufferConverter}
-import org.apache.spark.sql.rapids.execution.{GpuBatchSubPartitioner, GpuShuffleMeta, TrampolineUtil}
+import org.apache.spark.sql.rapids.execution.{GpuShuffleMeta, TrampolineUtil}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
-object AggregateUtils extends Logging {
+object AggregateUtils {
private val aggs = List("min", "max", "avg", "sum", "count", "first", "last")
@@ -84,10 +86,9 @@ object AggregateUtils extends Logging {
/**
* Computes a target input batch size based on the assumption that computation can consume up to
* 4X the configured batch size.
- *
- * @param confTargetSize user-configured maximum desired batch size
- * @param inputTypes input batch schema
- * @param outputTypes output batch schema
+ * @param confTargetSize user-configured maximum desired batch size
+ * @param inputTypes input batch schema
+ * @param outputTypes output batch schema
* @param isReductionOnly true if this is a reduction-only aggregation without grouping
* @return maximum target batch size to keep computation under the 4X configured batch limit
*/
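
To make the sizing contract above concrete, a hypothetical call sketch follows; the argument names mirror the scaladoc, while the method name, its visibility from the calling scope, and the schema and sizes are assumptions made only for illustration:

```scala
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType}

// Ask for an input batch size that keeps the aggregation's intermediate memory
// within roughly 4x of a 1 GiB configured target batch size.
val targetInputSize: Long = AggregateUtils.computeTargetBatchSize(
  confTargetSize = 1L << 30,                            // configured maximum batch size
  inputTypes = Seq(IntegerType, DoubleType),            // grouping key + value column
  outputTypes = Seq(IntegerType, DoubleType, LongType), // e.g. key, sum, count
  isReductionOnly = false)
```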
@@ -98,7 +99,6 @@ object AggregateUtils extends Logging {
isReductionOnly: Boolean): Long = {
def typesToSize(types: Seq[DataType]): Long =
types.map(GpuBatchUtils.estimateGpuMemory(_, nullable = false, rowCount = 1)).sum
-
val inputRowSize = typesToSize(inputTypes)
val outputRowSize = typesToSize(outputTypes)
// The cudf hash table implementation allocates four 32-bit integers per input row.
@@ -124,129 +124,6 @@ object AggregateUtils extends Logging {
// Finally compute the input target batching size taking into account the cudf row limits
Math.min(inputRowSize * maxRows, Int.MaxValue)
}
-
-
- /**
- * Concatenate batches together and perform a merge aggregation on the result. The input batches
- * will be closed as part of this operation.
- *
- * @param batches batches to concatenate and merge aggregate
- * @return lazy spillable batch which has NOT been marked spillable
- */
- private def concatenateAndMerge(
- batches: mutable.Buffer[SpillableColumnarBatch],
- metrics: GpuHashAggregateMetrics,
- concatAndMergeHelper: AggHelper): SpillableColumnarBatch = {
- // TODO: concatenateAndMerge (and calling code) could output a sequence
- // of batches for the partial aggregate case. This would be done in case
- // a retry failed a certain number of times.
- val concatBatch = withResource(batches) { _ =>
- val concatSpillable = concatenateBatches(metrics, batches.toSeq)
- withResource(concatSpillable) {
- _.getColumnarBatch()
- }
- }
- computeAggregateAndClose(metrics, concatBatch, concatAndMergeHelper)
- }
-
- /**
- * Perform a single pass over the aggregated batches attempting to merge adjacent batches.
- *
- * @return true if at least one merge operation occurred
- */
- private def mergePass(
- aggregatedBatches: mutable.Buffer[SpillableColumnarBatch],
- targetMergeBatchSize: Long,
- helper: AggHelper,
- metrics: GpuHashAggregateMetrics
- ): Boolean = {
- val batchesToConcat: mutable.ArrayBuffer[SpillableColumnarBatch] = mutable.ArrayBuffer.empty
- var wasBatchMerged = false
- // Current size in bytes of the batches targeted for the next concatenation
- var concatSize: Long = 0L
- var batchesLeftInPass = aggregatedBatches.size
-
- while (batchesLeftInPass > 0) {
- closeOnExcept(batchesToConcat) { _ =>
- var isConcatSearchFinished = false
- // Old batches are picked up at the front of the queue and freshly merged batches are
- // appended to the back of the queue. Although tempting to allow the pass to "wrap around"
- // and pick up batches freshly merged in this pass, it's avoided to prevent changing the
- // order of aggregated batches.
- while (batchesLeftInPass > 0 && !isConcatSearchFinished) {
- val candidate = aggregatedBatches.head
- val potentialSize = concatSize + candidate.sizeInBytes
- isConcatSearchFinished = concatSize > 0 && potentialSize > targetMergeBatchSize
- if (!isConcatSearchFinished) {
- batchesLeftInPass -= 1
- batchesToConcat += aggregatedBatches.remove(0)
- concatSize = potentialSize
- }
- }
- }
-
- val mergedBatch = if (batchesToConcat.length > 1) {
- wasBatchMerged = true
- concatenateAndMerge(batchesToConcat, metrics, helper)
- } else {
- // Unable to find a neighboring buffer to produce a valid merge in this pass,
- // so simply put this buffer back on the queue for other passes.
- batchesToConcat.remove(0)
- }
-
- // Add the merged batch to the end of the aggregated batch queue. Only a single pass over
- // the batches is being performed due to the batch count check above, so the single-pass
- // loop will terminate before picking up this new batch.
- aggregatedBatches += mergedBatch
- batchesToConcat.clear()
- concatSize = 0
- }
-
- wasBatchMerged
- }
-
-
- /**
- * Attempt to merge adjacent batches in the aggregatedBatches queue until either there is only
- * one batch or merging adjacent batches would exceed the target batch size.
- */
- def tryMergeAggregatedBatches(
- aggregatedBatches: mutable.Buffer[SpillableColumnarBatch],
- isReductionOnly: Boolean,
- metrics: GpuHashAggregateMetrics,
- targetMergeBatchSize: Long,
- helper: AggHelper
- ): Unit = {
- while (aggregatedBatches.size > 1) {
- val concatTime = metrics.concatTime
- val opTime = metrics.opTime
- withResource(new NvtxWithMetrics("agg merge pass", NvtxColor.BLUE, concatTime,
- opTime)) { _ =>
- // continue merging as long as some batches are able to be combined
- if (!mergePass(aggregatedBatches, targetMergeBatchSize, helper, metrics))
- if (aggregatedBatches.size > 1 && isReductionOnly) {
- // We were unable to merge the aggregated batches within the target batch size limit,
- // which means normally we would fallback to a sort-based approach. However for
- // reduction-only aggregation there are no keys to use for a sort. The only way this
- // can work is if all batches are merged. This will exceed the target batch size limit,
- // but at this point it is either risk an OOM/cudf error and potentially work or
- // not work at all.
- logWarning(s"Unable to merge reduction-only aggregated batches within " +
- s"target batch limit of $targetMergeBatchSize, attempting to merge remaining " +
- s"${aggregatedBatches.size} batches beyond limit")
- withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { batchesToConcat =>
- aggregatedBatches.foreach(b => batchesToConcat += b)
- aggregatedBatches.clear()
- val batch = concatenateAndMerge(batchesToConcat, metrics, helper)
- // batch does not need to be marked spillable since it is the last and only batch
- // and will be immediately retrieved on the next() call.
- aggregatedBatches += batch
- }
- }
- return
- }
- }
- }
}
/** Utility class to hold all of the metrics related to hash aggregation */
@@ -258,7 +135,6 @@ case class GpuHashAggregateMetrics(
computeAggTime: GpuMetric,
concatTime: GpuMetric,
sortTime: GpuMetric,
- repartitionTime: GpuMetric,
numAggOps: GpuMetric,
numPreSplits: GpuMetric,
singlePassTasks: GpuMetric,
@@ -835,8 +711,6 @@ object GpuAggFinalPassIterator {
* @param useTieredProject user-specified option to enable tiered projections
* @param allowNonFullyAggregatedOutput if allowed to skip third pass Agg
* @param skipAggPassReductionRatio skip if the ratio of rows after a pass is bigger than this value
- * @param aggFallbackAlgorithm use sort-based fallback or repartition-based fallback
- * for oversize agg
* @param localInputRowsCount metric to track the number of input rows processed locally
*/
class GpuMergeAggregateIterator(
@@ -852,17 +726,15 @@ class GpuMergeAggregateIterator(
useTieredProject: Boolean,
allowNonFullyAggregatedOutput: Boolean,
skipAggPassReductionRatio: Double,
- aggFallbackAlgorithm: String,
localInputRowsCount: LocalGpuMetric)
extends Iterator[ColumnarBatch] with AutoCloseable with Logging {
private[this] val isReductionOnly = groupingExpressions.isEmpty
private[this] val targetMergeBatchSize = computeTargetMergeBatchSize(configuredTargetBatchSize)
- private[this] val aggregatedBatches = ListBuffer.empty[SpillableColumnarBatch]
+ private[this] val aggregatedBatches = new util.ArrayDeque[SpillableColumnarBatch]
private[this] var outOfCoreIter: Option[GpuOutOfCoreSortIterator] = None
- private[this] var repartitionIter: Option[RepartitionAggregateIterator] = None
/** Iterator for fetching aggregated batches either if:
- * 1. a sort-based/repartition-based fallback has occurred
+ * 1. a sort-based fallback has occurred
* 2. skip third pass agg has occurred
**/
private[this] var fallbackIter: Option[Iterator[ColumnarBatch]] = None
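
The `java.util.ArrayDeque` that now backs `aggregatedBatches` is used as a plain FIFO queue in the code below (`getFirst`/`removeFirst`/`pop` at the head, `add`/`addLast` at the tail, `asScala` for iteration). A small self-contained sketch of those operations, with `Long` values standing in for spillable batches:

```scala
import java.util
import scala.collection.JavaConverters.collectionAsScalaIterableConverter

// Longs stand in for SpillableColumnarBatch; only the deque operations matter here.
val queue = new util.ArrayDeque[Long]()
queue.add(10L)                                    // enqueue at the tail
queue.addLast(20L)                                // also enqueues at the tail
val head = queue.getFirst                         // peek at the head without removing: 10
val removed = queue.removeFirst()                 // dequeue from the head: 10 (pop() is equivalent)
val remaining = queue.asScala.foldLeft(0L)(_ + _) // iterate what is left: 20
```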
@@ -880,7 +752,7 @@ class GpuMergeAggregateIterator(
override def hasNext: Boolean = {
fallbackIter.map(_.hasNext).getOrElse {
// reductions produce a result even if the input is empty
- hasReductionOnlyBatch || aggregatedBatches.nonEmpty || firstPassIter.hasNext
+ hasReductionOnlyBatch || !aggregatedBatches.isEmpty || firstPassIter.hasNext
}
}
@@ -897,11 +769,9 @@ class GpuMergeAggregateIterator(
if (isReductionOnly ||
skipAggPassReductionRatio * localInputRowsCount.value >= rowsAfterFirstPassAgg) {
// second pass agg
- AggregateUtils.tryMergeAggregatedBatches(
- aggregatedBatches, isReductionOnly,
- metrics, targetMergeBatchSize, concatAndMergeHelper)
+ tryMergeAggregatedBatches()
- val rowsAfterSecondPassAgg = aggregatedBatches.foldLeft(0L) {
+ val rowsAfterSecondPassAgg = aggregatedBatches.asScala.foldLeft(0L) {
(totalRows, batch) => totalRows + batch.numRows()
}
shouldSkipThirdPassAgg =
@@ -914,7 +784,7 @@ class GpuMergeAggregateIterator(
}
}
- if (aggregatedBatches.size > 1) {
+ if (aggregatedBatches.size() > 1) {
// Unable to merge to a single output, so must fall back
if (allowNonFullyAggregatedOutput && shouldSkipThirdPassAgg) {
// skip third pass agg, return the aggregated batches directly
@@ -922,23 +792,17 @@ class GpuMergeAggregateIterator(
s"${skipAggPassReductionRatio * 100}% of " +
s"rows after first pass, skip the third pass agg")
fallbackIter = Some(new Iterator[ColumnarBatch] {
- override def hasNext: Boolean = aggregatedBatches.nonEmpty
+ override def hasNext: Boolean = !aggregatedBatches.isEmpty
override def next(): ColumnarBatch = {
- withResource(aggregatedBatches.remove(0)) { spillableBatch =>
+ withResource(aggregatedBatches.pop()) { spillableBatch =>
spillableBatch.getColumnarBatch()
}
}
})
} else {
// fallback to sort agg, this is the third pass agg
- aggFallbackAlgorithm.toLowerCase match {
- case "repartition" =>
- fallbackIter = Some(buildRepartitionFallbackIterator())
- case "sort" => fallbackIter = Some(buildSortFallbackIterator())
- case _ => throw new IllegalArgumentException(
- s"Unsupported aggregation fallback algorithm: $aggFallbackAlgorithm")
- }
+ fallbackIter = Some(buildSortFallbackIterator())
}
fallbackIter.get.next()
} else if (aggregatedBatches.isEmpty) {
@@ -951,7 +815,7 @@ class GpuMergeAggregateIterator(
} else {
// this will be the last batch
hasReductionOnlyBatch = false
- withResource(aggregatedBatches.remove(0)) { spillableBatch =>
+ withResource(aggregatedBatches.pop()) { spillableBatch =>
spillableBatch.getColumnarBatch()
}
}
@@ -959,12 +823,10 @@ class GpuMergeAggregateIterator(
}
override def close(): Unit = {
- aggregatedBatches.foreach(_.safeClose())
+ aggregatedBatches.forEach(_.safeClose())
aggregatedBatches.clear()
outOfCoreIter.foreach(_.close())
outOfCoreIter = None
- repartitionIter.foreach(_.close())
- repartitionIter = None
fallbackIter = None
hasReductionOnlyBatch = false
}
@@ -981,161 +843,133 @@ class GpuMergeAggregateIterator(
while (firstPassIter.hasNext) {
val batch = firstPassIter.next()
rowsAfter += batch.numRows()
- aggregatedBatches += batch
+ aggregatedBatches.add(batch)
}
rowsAfter
}
- private lazy val concatAndMergeHelper =
- new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions,
- forceMerge = true, useTieredProject = useTieredProject)
-
- private def cbIteratorStealingFromBuffer(input: ListBuffer[SpillableColumnarBatch]) = {
- val aggregatedBatchIter = new Iterator[ColumnarBatch] {
- override def hasNext: Boolean = input.nonEmpty
-
- override def next(): ColumnarBatch = {
- withResource(input.remove(0)) { spillable =>
- spillable.getColumnarBatch()
+ /**
+ * Attempt to merge adjacent batches in the aggregatedBatches queue until either there is only
+ * one batch or merging adjacent batches would exceed the target batch size.
+ */
+ private def tryMergeAggregatedBatches(): Unit = {
+ while (aggregatedBatches.size() > 1) {
+ val concatTime = metrics.concatTime
+ val opTime = metrics.opTime
+ withResource(new NvtxWithMetrics("agg merge pass", NvtxColor.BLUE, concatTime,
+ opTime)) { _ =>
+ // continue merging as long as some batches are able to be combined
+ if (!mergePass()) {
+ if (aggregatedBatches.size() > 1 && isReductionOnly) {
+ // We were unable to merge the aggregated batches within the target batch size limit,
+            // which means normally we would fall back to a sort-based approach. However, for
+            // reduction-only aggregation there are no keys to use for a sort. The only way this
+            // can work is if all batches are merged. This will exceed the target batch size
+            // limit, but at this point the choice is to either risk an OOM/cudf error and
+            // potentially work, or not work at all.
+ logWarning(s"Unable to merge reduction-only aggregated batches within " +
+ s"target batch limit of $targetMergeBatchSize, attempting to merge remaining " +
+ s"${aggregatedBatches.size()} batches beyond limit")
+ withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { batchesToConcat =>
+ aggregatedBatches.forEach(b => batchesToConcat += b)
+ aggregatedBatches.clear()
+ val batch = concatenateAndMerge(batchesToConcat)
+ // batch does not need to be marked spillable since it is the last and only batch
+ // and will be immediately retrieved on the next() call.
+ aggregatedBatches.add(batch)
+ }
+ }
+ return
}
}
}
- aggregatedBatchIter
}
- private case class RepartitionAggregateIterator(
- inputBatches: ListBuffer[SpillableColumnarBatch],
- hashKeys: Seq[GpuExpression],
- targetSize: Long,
- opTime: GpuMetric,
- repartitionTime: GpuMetric) extends Iterator[ColumnarBatch]
- with AutoCloseable {
-
- case class AggregatePartition(batches: ListBuffer[SpillableColumnarBatch], seed: Int)
- extends AutoCloseable {
- override def close(): Unit = {
- batches.safeClose()
- }
-
- def totalRows(): Long = batches.map(_.numRows()).sum
-
- def totalSize(): Long = batches.map(_.sizeInBytes).sum
+ /**
+ * Perform a single pass over the aggregated batches attempting to merge adjacent batches.
+ * @return true if at least one merge operation occurred
+ */
+ private def mergePass(): Boolean = {
+ val batchesToConcat: mutable.ArrayBuffer[SpillableColumnarBatch] = mutable.ArrayBuffer.empty
+ var wasBatchMerged = false
+ // Current size in bytes of the batches targeted for the next concatenation
+ var concatSize: Long = 0L
+ var batchesLeftInPass = aggregatedBatches.size()
- def split(): ListBuffer[AggregatePartition] = {
- withResource(new NvtxWithMetrics("agg repartition", NvtxColor.CYAN, repartitionTime)) { _ =>
- if (seed > hashSeed + 20) {
- throw new IllegalStateException("At most repartition 3 times for a partition")
- }
- val totalSize = batches.map(_.sizeInBytes).sum
- val newSeed = seed + 10
- val iter = cbIteratorStealingFromBuffer(batches)
- withResource(new GpuBatchSubPartitioner(
- iter, hashKeys, computeNumPartitions(totalSize), newSeed, "aggRepartition")) {
- partitioner =>
- closeOnExcept(ListBuffer.empty[AggregatePartition]) { partitions =>
- preparePartitions(newSeed, partitioner, partitions)
- partitions
- }
+ while (batchesLeftInPass > 0) {
+ closeOnExcept(batchesToConcat) { _ =>
+ var isConcatSearchFinished = false
+ // Old batches are picked up at the front of the queue and freshly merged batches are
+ // appended to the back of the queue. Although tempting to allow the pass to "wrap around"
+ // and pick up batches freshly merged in this pass, it's avoided to prevent changing the
+ // order of aggregated batches.
+ while (batchesLeftInPass > 0 && !isConcatSearchFinished) {
+ val candidate = aggregatedBatches.getFirst
+ val potentialSize = concatSize + candidate.sizeInBytes
+ isConcatSearchFinished = concatSize > 0 && potentialSize > targetMergeBatchSize
+ if (!isConcatSearchFinished) {
+ batchesLeftInPass -= 1
+ batchesToConcat += aggregatedBatches.removeFirst()
+ concatSize = potentialSize
}
}
}
- }
- private def preparePartitions(
- newSeed: Int,
- partitioner: GpuBatchSubPartitioner,
- partitions: ListBuffer[AggregatePartition]): Unit = {
- (0 until partitioner.partitionsCount).foreach { id =>
- val buffer = ListBuffer.empty[SpillableColumnarBatch]
- buffer ++= partitioner.releaseBatchesByPartition(id)
- val newPart = AggregatePartition.apply(buffer, newSeed)
- if (newPart.totalRows() > 0) {
- partitions += newPart
- } else {
- newPart.safeClose()
- }
+ val mergedBatch = if (batchesToConcat.length > 1) {
+ wasBatchMerged = true
+ concatenateAndMerge(batchesToConcat)
+ } else {
+ // Unable to find a neighboring buffer to produce a valid merge in this pass,
+ // so simply put this buffer back on the queue for other passes.
+ batchesToConcat.remove(0)
}
- }
- private[this] def computeNumPartitions(totalSize: Long): Int = {
- Math.floorDiv(totalSize, targetMergeBatchSize).toInt + 1
- }
-
- private val hashSeed = 100
- private val aggPartitions = ListBuffer.empty[AggregatePartition]
- private val deferredAggPartitions = ListBuffer.empty[AggregatePartition]
- deferredAggPartitions += AggregatePartition.apply(inputBatches, hashSeed)
-
- override def hasNext: Boolean = aggPartitions.nonEmpty || deferredAggPartitions.nonEmpty
-
- override def next(): ColumnarBatch = {
- withResource(new NvtxWithMetrics("RepartitionAggregateIterator.next",
- NvtxColor.BLUE, opTime)) { _ =>
- if (aggPartitions.isEmpty && deferredAggPartitions.nonEmpty) {
- val headDeferredPartition = deferredAggPartitions.remove(0)
- withResource(headDeferredPartition) { _ =>
- aggPartitions ++= headDeferredPartition.split()
- }
- return next()
- }
-
- val headPartition = aggPartitions.remove(0)
- if (headPartition.totalSize() > targetMergeBatchSize) {
- deferredAggPartitions += headPartition
- return next()
- }
-
- withResource(headPartition) { _ =>
- val batchSizeBeforeMerge = headPartition.batches.size
- AggregateUtils.tryMergeAggregatedBatches(
- headPartition.batches, isReductionOnly, metrics,
- targetMergeBatchSize, concatAndMergeHelper)
- if (headPartition.batches.size != 1) {
- throw new IllegalStateException(
- "Expected a single batch after tryMergeAggregatedBatches, but got " +
- s"${headPartition.batches.size} batches. Before merge, there were " +
- s"$batchSizeBeforeMerge batches.")
- }
- headPartition.batches.head.getColumnarBatch()
- }
- }
+ // Add the merged batch to the end of the aggregated batch queue. Only a single pass over
+ // the batches is being performed due to the batch count check above, so the single-pass
+ // loop will terminate before picking up this new batch.
+ aggregatedBatches.addLast(mergedBatch)
+ batchesToConcat.clear()
+ concatSize = 0
}
- override def close(): Unit = {
- aggPartitions.foreach(_.safeClose())
- deferredAggPartitions.foreach(_.safeClose())
- }
+ wasBatchMerged
}
+ private lazy val concatAndMergeHelper =
+ new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions,
+ forceMerge = true, useTieredProject = useTieredProject)
- /** Build an iterator that uses a sort-based approach to merge aggregated batches together. */
- private def buildRepartitionFallbackIterator(): Iterator[ColumnarBatch] = {
- logInfo(s"Falling back to repartition-based aggregation with " +
- s"${aggregatedBatches.size} batches")
- metrics.numTasksFallBacked += 1
-
- val groupingAttributes = groupingExpressions.map(_.toAttribute)
- val aggBufferAttributes = groupingAttributes ++
- aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-
- val hashKeys: Seq[GpuExpression] =
- GpuBindReferences.bindGpuReferences(groupingAttributes, aggBufferAttributes.toSeq)
-
-
- repartitionIter = Some(RepartitionAggregateIterator(
- aggregatedBatches,
- hashKeys,
- targetMergeBatchSize,
- opTime = metrics.opTime,
- repartitionTime = metrics.repartitionTime))
- repartitionIter.get
+ /**
+ * Concatenate batches together and perform a merge aggregation on the result. The input batches
+ * will be closed as part of this operation.
+ * @param batches batches to concatenate and merge aggregate
+ * @return lazy spillable batch which has NOT been marked spillable
+ */
+ private def concatenateAndMerge(
+ batches: mutable.ArrayBuffer[SpillableColumnarBatch]): SpillableColumnarBatch = {
+ // TODO: concatenateAndMerge (and calling code) could output a sequence
+ // of batches for the partial aggregate case. This would be done in case
+ // a retry failed a certain number of times.
+ val concatBatch = withResource(batches) { _ =>
+ val concatSpillable = concatenateBatches(metrics, batches.toSeq)
+ withResource(concatSpillable) { _.getColumnarBatch() }
+ }
+ computeAggregateAndClose(metrics, concatBatch, concatAndMergeHelper)
}
/** Build an iterator that uses a sort-based approach to merge aggregated batches together. */
private def buildSortFallbackIterator(): Iterator[ColumnarBatch] = {
- logInfo(s"Falling back to sort-based aggregation with ${aggregatedBatches.size} batches")
+ logInfo(s"Falling back to sort-based aggregation with ${aggregatedBatches.size()} batches")
metrics.numTasksFallBacked += 1
- val aggregatedBatchIter = cbIteratorStealingFromBuffer(aggregatedBatches)
+ val aggregatedBatchIter = new Iterator[ColumnarBatch] {
+ override def hasNext: Boolean = !aggregatedBatches.isEmpty
+
+ override def next(): ColumnarBatch = {
+ withResource(aggregatedBatches.removeFirst()) { spillable =>
+ spillable.getColumnarBatch()
+ }
+ }
+ }
if (isReductionOnly) {
// Normally this should never happen because `tryMergeAggregatedBatches` should have done
@@ -1498,8 +1332,7 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan](
conf.forceSinglePassPartialSortAgg,
allowSinglePassAgg,
allowNonFullyAggregatedOutput,
- conf.skipAggPassReductionRatio,
- conf.aggFallbackAlgorithm)
+ conf.skipAggPassReductionRatio)
}
}
@@ -1587,8 +1420,7 @@ abstract class GpuTypedImperativeSupportedAggregateExecMeta[INPUT <: BaseAggrega
false,
false,
false,
- 1,
- conf.aggFallbackAlgorithm)
+ 1)
} else {
super.convertToGpu()
}
@@ -1941,8 +1773,6 @@ object GpuHashAggregateExecBase {
* (can omit non fully aggregated data for non-final
* stage of aggregation)
* @param skipAggPassReductionRatio skip if the ratio of rows after a pass is bigger than this value
- * @param aggFallbackAlgorithm use sort-based fallback or repartition-based fallback for
- * oversize agg
*/
case class GpuHashAggregateExec(
requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -1957,8 +1787,7 @@ case class GpuHashAggregateExec(
forceSinglePassAgg: Boolean,
allowSinglePassAgg: Boolean,
allowNonFullyAggregatedOutput: Boolean,
- skipAggPassReductionRatio: Double,
- aggFallbackAlgorithm: String
+ skipAggPassReductionRatio: Double
) extends ShimUnaryExecNode with GpuExec {
// lifted directly from `BaseAggregateExec.inputAttributes`, edited comment.
@@ -1980,7 +1809,6 @@ case class GpuHashAggregateExec(
AGG_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_AGG_TIME),
CONCAT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_CONCAT_TIME),
SORT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_SORT_TIME),
- REPARTITION_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_REPARTITION_TIME),
"NUM_AGGS" -> createMetric(DEBUG_LEVEL, "num agg operations"),
"NUM_PRE_SPLITS" -> createMetric(DEBUG_LEVEL, "num pre splits"),
"NUM_TASKS_SINGLE_PASS" -> createMetric(MODERATE_LEVEL, "number of single pass tasks"),
@@ -2011,7 +1839,6 @@ case class GpuHashAggregateExec(
computeAggTime = gpuLongMetric(AGG_TIME),
concatTime = gpuLongMetric(CONCAT_TIME),
sortTime = gpuLongMetric(SORT_TIME),
- repartitionTime = gpuLongMetric(REPARTITION_TIME),
numAggOps = gpuLongMetric("NUM_AGGS"),
numPreSplits = gpuLongMetric("NUM_PRE_SPLITS"),
singlePassTasks = gpuLongMetric("NUM_TASKS_SINGLE_PASS"),
@@ -2046,8 +1873,7 @@ case class GpuHashAggregateExec(
boundGroupExprs, aggregateExprs, aggregateAttrs, resultExprs, modeInfo,
localEstimatedPreProcessGrowth, alreadySorted, expectedOrdering,
postBoundReferences, targetBatchSize, aggMetrics, useTieredProject,
- localForcePre, localAllowPre, allowNonFullyAggregatedOutput, skipAggPassReductionRatio,
- aggFallbackAlgorithm)
+ localForcePre, localAllowPre, allowNonFullyAggregatedOutput, skipAggPassReductionRatio)
}
}
@@ -2165,8 +1991,7 @@ class DynamicGpuPartialSortAggregateIterator(
forceSinglePassAgg: Boolean,
allowSinglePassAgg: Boolean,
allowNonFullyAggregatedOutput: Boolean,
- skipAggPassReductionRatio: Double,
- aggFallbackAlgorithm: String
+ skipAggPassReductionRatio: Double
) extends Iterator[ColumnarBatch] {
private var aggIter: Option[Iterator[ColumnarBatch]] = None
private[this] val isReductionOnly = boundGroupExprs.outputTypes.isEmpty
@@ -2267,7 +2092,6 @@ class DynamicGpuPartialSortAggregateIterator(
useTiered,
allowNonFullyAggregatedOutput,
skipAggPassReductionRatio,
- aggFallbackAlgorithm,
localInputRowsMetrics)
GpuAggFinalPassIterator.makeIter(mergeIter, postBoundReferences, metrics)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
index 1cbf899c04d..d83f20113b2 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
@@ -61,7 +61,6 @@ object GpuMetric extends Logging {
val COLLECT_TIME = "collectTime"
val CONCAT_TIME = "concatTime"
val SORT_TIME = "sortTime"
- val REPARTITION_TIME = "repartitionTime"
val AGG_TIME = "computeAggTime"
val JOIN_TIME = "joinTime"
val FILTER_TIME = "filterTime"
@@ -96,7 +95,6 @@ object GpuMetric extends Logging {
val DESCRIPTION_COLLECT_TIME = "collect batch time"
val DESCRIPTION_CONCAT_TIME = "concat batch time"
val DESCRIPTION_SORT_TIME = "sort time"
- val DESCRIPTION_REPARTITION_TIME = "repartition time spent in agg"
val DESCRIPTION_AGG_TIME = "aggregation time"
val DESCRIPTION_JOIN_TIME = "join time"
val DESCRIPTION_FILTER_TIME = "filter time"
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 46c2806140e..aad4f05b334 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1517,13 +1517,6 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
.checkValue(v => v >= 0 && v <= 1, "The ratio value must be in [0, 1].")
.createWithDefault(1.0)
- val FALLBACK_ALGORITHM_FOR_OVERSIZE_AGG = conf("spark.rapids.sql.agg.fallbackAlgorithm")
- .doc("When agg cannot be done in a single pass, use sort-based fallback or " +
- "repartition-based fallback.")
- .stringConf
- .checkValues(Set("sort", "repartition"))
- .createWithDefault("sort")
-
val FORCE_SINGLE_PASS_PARTIAL_SORT_AGG: ConfEntryWithDefault[Boolean] =
conf("spark.rapids.sql.agg.forceSinglePassPartialSort")
.doc("Force a single pass partial sort agg to happen in all cases that it could, " +
@@ -3086,8 +3079,6 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
lazy val skipAggPassReductionRatio: Double = get(SKIP_AGG_PASS_REDUCTION_RATIO)
- lazy val aggFallbackAlgorithm: String = get(FALLBACK_ALGORITHM_FOR_OVERSIZE_AGG)
-
lazy val isRegExpEnabled: Boolean = get(ENABLE_REGEXP)
lazy val maxRegExpStateMemory: Long = {