diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py index f2dd21b0bf3..75e06cfaea7 100644 --- a/integration_tests/src/main/python/conditionals_test.py +++ b/integration_tests/src/main/python/conditionals_test.py @@ -14,7 +14,7 @@ import pytest -from asserts import assert_gpu_and_cpu_are_equal_collect +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql from data_gen import * from spark_session import is_before_spark_320, is_jvm_charset_utf8 from pyspark.sql.types import * @@ -296,3 +296,17 @@ def test_conditional_with_side_effects_unary_minus(data_gen, ansi_enabled): 'CASE WHEN a > -32768 THEN -a ELSE null END'), conf = {'spark.sql.ansi.enabled': ansi_enabled}) +def test_case_when_all_then_values_are_scalars(): + data_gen = [ + ("a", boolean_gen), + ("b", boolean_gen), + ("c", boolean_gen), + ("d", boolean_gen), + ("e", boolean_gen)] + assert_gpu_and_cpu_are_equal_sql( + lambda spark : gen_df(spark, data_gen), + "tab", + "select case when a then 'aaa' when b then 'bbb' when c then 'ccc' " + + "when d then 'ddd' when e then 'eee' else 'unknown' end from tab", + conf = {'spark.rapids.sql.case_when.fuse': 'true'}) + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index fa091894f14..ee402948088 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2024,7 +2024,7 @@ object GpuOverrides extends Logging { } else { None } - GpuCaseWhen(branches, elseValue) + GpuCaseWhen(branches, elseValue, conf.caseWhenFuseEnabled) } }), expr[If]( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index a29c816b67b..5cb527e33cd 100644 --- 
a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -2345,6 +2345,14 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. .stringConf .createOptional + val CASE_WHEN_FUSE = + conf("spark.rapids.sql.case_when.fuse") + .doc("If the number of WHEN branches is greater than 2 and all THEN/ELSE values in the " + + "CASE WHEN are string scalars, fuse mode improves performance. Enabled by default.") + .internal() + .booleanConf + .createWithDefault(true) + private def printSectionHeader(category: String): Unit = println(s"\n### $category") @@ -3162,6 +3170,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val isDeltaLowShuffleMergeEnabled: Boolean = get(ENABLE_DELTA_LOW_SHUFFLE_MERGE) + lazy val caseWhenFuseEnabled: Boolean = get(CASE_WHEN_FUSE) + private val optimizerDefaults = Map( // this is not accurate because CPU projections do have a cost due to appending values // to each row that is produced, but this needs to be a really small number because
Expression} -import org.apache.spark.sql.types.{BooleanType, DataType, DataTypes} +import org.apache.spark.sql.types.{BooleanType, DataType, DataTypes, StringType} import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.unsafe.types.UTF8String object GpuExpressionWithSideEffectUtils { @@ -47,7 +49,7 @@ object GpuExpressionWithSideEffectUtils { /** * Used to shortcircuit predicates and filter conditions. - * + * * @param nullsAsFalse when true, null values are considered false. * @param col the input being evaluated. * @return boolean. When nullsAsFalse is set, it returns True if none of the rows is true; @@ -182,9 +184,9 @@ case class GpuIf( predicateExpr: Expression, trueExpr: Expression, falseExpr: Expression) extends GpuConditionalExpression { - + import GpuExpressionWithSideEffectUtils._ - + @transient override lazy val inputTypesForMerging: Seq[DataType] = { Seq(trueExpr.dataType, falseExpr.dataType) @@ -314,7 +316,9 @@ case class GpuIf( case class GpuCaseWhen( branches: Seq[(Expression, Expression)], - elseValue: Option[Expression] = None) extends GpuConditionalExpression with Serializable { + elseValue: Option[Expression] = None, + caseWhenFuseEnabled: Boolean = true) + extends GpuConditionalExpression with Serializable { import GpuExpressionWithSideEffectUtils._ @@ -359,15 +363,60 @@ case class GpuCaseWhen( if (branchesWithSideEffects) { columnarEvalWithSideEffects(batch) } else { - // `elseRet` will be closed in `computeIfElse`. 
- val elseRet = elseValue - .map(_.columnarEvalAny(batch)) - .getOrElse(GpuScalar(null, branches.last._2.dataType)) - val any = branches.foldRight[Any](elseRet) { - case ((predicateExpr, trueExpr), falseRet) => - computeIfElse(batch, predicateExpr, trueExpr, falseRet) + if (caseWhenFuseEnabled && branches.size > 2 && + inputTypesForMerging.head == StringType && + (branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[GpuLiteral]) + ) { + // the number of WHEN branches is > 2; + // the return type is string; + // all THEN and ELSE expressions are scalars. + // Avoid using multiple `computeIfElse` calls, which would create multiple temp columns + + // 1. select the first true index from the bool columns; if none is true, the index is out of bounds + // e.g.: + // case when bool result column 0: true, false, false + // case when bool result column 1: false, true, false + // result is: [0, 1, 2] + val whenBoolCols = branches.safeMap(_._1.columnarEval(batch).getBase).toArray + val firstTrueIndex: ColumnVector = withResource(whenBoolCols) { _ => + CaseWhen.selectFirstTrueIndex(whenBoolCols) + } + + withResource(firstTrueIndex) { _ => + val thenElseScalars = (branches.map(_._2) ++ elseValue).map(_.columnarEvalAny(batch) + .asInstanceOf[GpuScalar]) + withResource(thenElseScalars) { _ => + // 2. generate a column to store all scalars + val scalarsBytes = thenElseScalars.map(ret => ret.getValue + .asInstanceOf[UTF8String].getBytes) + val scalarCol = ColumnVector.fromUTF8Strings(scalarsBytes: _*) + withResource(scalarCol) { _ => + + val finalRet = withResource(new Table(scalarCol)) { oneColumnTable => + // 3. 
execute the final select + // the default gather OutOfBoundsPolicy is nullify, + // if the index is out of bounds, return null + withResource(oneColumnTable.gather(firstTrueIndex)) { resultTable => + resultTable.getColumn(0).incRefCount() + } + } + // return final column vector + GpuColumnVector.from(finalRet, dataType) + } + } + } + } else { + // fold branches from tail to front + // `elseRet` will be closed in `computeIfElse`. + val elseRet = elseValue + .map(_.columnarEvalAny(batch)) + .getOrElse(GpuScalar(null, branches.last._2.dataType)) + val any = branches.foldRight[Any](elseRet) { + case ((predicateExpr, trueExpr), falseRet) => + computeIfElse(batch, predicateExpr, trueExpr, falseRet) + } + GpuExpressionsUtils.resolveColumnVector(any, batch.numRows()) } - GpuExpressionsUtils.resolveColumnVector(any, batch.numRows()) } }