diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py
index f2dd21b0bf3..75e06cfaea7 100644
--- a/integration_tests/src/main/python/conditionals_test.py
+++ b/integration_tests/src/main/python/conditionals_test.py
@@ -14,7 +14,7 @@
 
 import pytest
 
-from asserts import assert_gpu_and_cpu_are_equal_collect
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql
 from data_gen import *
 from spark_session import is_before_spark_320, is_jvm_charset_utf8
 from pyspark.sql.types import *
@@ -296,3 +296,17 @@ def test_conditional_with_side_effects_unary_minus(data_gen, ansi_enabled):
             'CASE WHEN a > -32768 THEN -a ELSE null END'),
         conf = {'spark.sql.ansi.enabled': ansi_enabled})
 
+def test_case_when_all_then_values_are_scalars():
+    data_gen = [
+        ("a", boolean_gen),
+        ("b", boolean_gen),
+        ("c", boolean_gen),
+        ("d", boolean_gen),
+        ("e", boolean_gen)]
+    assert_gpu_and_cpu_are_equal_sql(
+        lambda spark : gen_df(spark, data_gen),
+        "tab",
+        "select case when a then 'aaa' when b then 'bbb' when c then 'ccc' " +
+        "when d then 'ddd' when e then 'eee' else 'unknown' end from tab",
+        conf = {'spark.rapids.sql.case_when.fuse': 'true'})
+
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 09983933c6d..5cb527e33cd 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -2345,9 +2345,10 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
     .stringConf
     .createOptional
 
-  val CASE_WHEN_FUSE =
+  val CASE_WHEN_FUSE =
     conf("spark.rapids.sql.case_when.fuse")
-      .doc("")
+      .doc("If when branches is greater than 2 and all then/else values in case when are string " +
+        "scalar, fuse mode improves the performance. By default this is enabled.")
       .internal()
       .booleanConf
       .createWithDefault(true)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
index 79888c4e73e..1c0ff3f01f5 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
@@ -22,7 +22,6 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.jni.CaseWhen
 import com.nvidia.spark.rapids.shims.ShimExpression
 
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, Expression}
 import org.apache.spark.sql.types.{BooleanType, DataType, DataTypes, StringType}
@@ -318,7 +317,8 @@ case class GpuIf(
 case class GpuCaseWhen(
     branches: Seq[(Expression, Expression)],
     elseValue: Option[Expression] = None,
-    caseWhenFuseEnabled: Boolean = true) extends GpuConditionalExpression with Serializable with Logging {
+    caseWhenFuseEnabled: Boolean = true)
+  extends GpuConditionalExpression with Serializable {
 
   import GpuExpressionWithSideEffectUtils._
 
@@ -363,34 +363,50 @@
     if (branchesWithSideEffects) {
       columnarEvalWithSideEffects(batch)
     } else {
-      if (caseWhenFuseEnabled &&
+      if (caseWhenFuseEnabled && branches.size > 2 &&
         inputTypesForMerging.head == StringType &&
         (branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[GpuLiteral])
       ) {
-        // return type is string type; all the then and else exprs are Scalars
-        // avoid to use multiple `computeIfElse`s which will create multiple temp columns
-        logWarning("==================== Running case with experimental =========== ")
-
+        // when branches size > 2;
+        // return type is string type;
+        // all the then and else expressions are Scalars.
+        // Avoid to use multiple `computeIfElse`s which will create multiple temp columns
+
+        // 1. select first true index from bool columns, if no true, index will be out of bound
+        // e.g.:
+        //   case when bool result column 0: true, false, false
+        //   case when bool result column 1: false, true, false
+        //   result is: [0, 1, 2]
         val whenBoolCols = branches.safeMap(_._1.columnarEval(batch).getBase).toArray
         val firstTrueIndex: ColumnVector = withResource(whenBoolCols) { _ =>
           CaseWhen.selectFirstTrueIndex(whenBoolCols)
         }
         withResource(firstTrueIndex) { _ =>
-          val thenElseScalars = (branches.map(_._2) ++ elseValue).map(_.columnarEvalAny(batch).asInstanceOf[GpuScalar])
+          val thenElseScalars = (branches.map(_._2) ++ elseValue).map(_.columnarEvalAny(batch)
+            .asInstanceOf[GpuScalar])
           withResource(thenElseScalars) { _ =>
-            // generate a column to store all scalars
-            val scalarsBytes = thenElseScalars.map(ret => ret.getValue.asInstanceOf[UTF8String].getBytes)
+            // 2. generate a column to store all scalars
+            val scalarsBytes = thenElseScalars.map(ret => ret.getValue
+              .asInstanceOf[UTF8String].getBytes)
             val scalarCol = ColumnVector.fromUTF8Strings(scalarsBytes: _*)
             withResource(scalarCol) { _ =>
-              // execute final select
-              val finalRet = CaseWhen.selectFromIndex(scalarCol, firstTrueIndex)
+
+              val finalRet = withResource(new Table(scalarCol)) { oneColumnTable =>
+                // 3. execute final select
+                // default gather OutOfBoundsPolicy is nullify,
+                // If index is out of bound, return null
+                withResource(oneColumnTable.gather(firstTrueIndex)) { resultTable =>
+                  resultTable.getColumn(0).incRefCount()
+                }
+              }
               // return final column vector
               GpuColumnVector.from(finalRet, dataType)
            }
          }
        }
      } else {
+        // execute from tail to front recursively
         // `elseRet` will be closed in `computeIfElse`.
         val elseRet = elseValue
           .map(_.columnarEvalAny(batch))