Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Case when perf #31

Merged
merged 2 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion integration_tests/src/main/python/conditionals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql
from data_gen import *
from spark_session import is_before_spark_320, is_jvm_charset_utf8
from pyspark.sql.types import *
Expand Down Expand Up @@ -296,3 +296,17 @@ def test_conditional_with_side_effects_unary_minus(data_gen, ansi_enabled):
'CASE WHEN a > -32768 THEN -a ELSE null END'),
conf = {'spark.sql.ansi.enabled': ansi_enabled})

def test_case_when_all_then_values_are_scalars():
    """Run a five-branch CASE WHEN whose THEN/ELSE values are all string literals.

    Five boolean columns (a-e) drive the WHEN predicates, and the query is
    executed with 'spark.rapids.sql.case_when.fuse' set to 'true'.
    """
    predicate_columns = ["a", "b", "c", "d", "e"]
    # One boolean generator per predicate column.
    data_gen = [(column, boolean_gen) for column in predicate_columns]
    query = (
        "select case when a then 'aaa' when b then 'bbb' when c then 'ccc' "
        "when d then 'ddd' when e then 'eee' else 'unknown' end from tab")
    fuse_conf = {'spark.rapids.sql.case_when.fuse': 'true'}

    def make_df(spark):
        return gen_df(spark, data_gen)

    assert_gpu_and_cpu_are_equal_sql(make_df, "tab", query, conf = fuse_conf)

Original file line number Diff line number Diff line change
Expand Up @@ -2024,7 +2024,7 @@ object GpuOverrides extends Logging {
} else {
None
}
GpuCaseWhen(branches, elseValue)
GpuCaseWhen(branches, elseValue, conf.caseWhenFuseEnabled)
}
}),
expr[If](
Expand Down
10 changes: 10 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2345,6 +2345,14 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
.stringConf
.createOptional

// Internal boolean config, enabled by default, keyed by "spark.rapids.sql.case_when.fuse".
// Per its doc text it targets CASE WHEN expressions that have more than two WHEN branches
// and whose THEN/ELSE values are all string scalars.
val CASE_WHEN_FUSE =
conf("spark.rapids.sql.case_when.fuse")
.doc("If when branches is greater than 2 and all then/else values in case when are string " +
"scalar, fuse mode improves the performance. By default this is enabled.")
.internal()
.booleanConf
.createWithDefault(true)

private def printSectionHeader(category: String): Unit =
println(s"\n### $category")

Expand Down Expand Up @@ -3162,6 +3170,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val isDeltaLowShuffleMergeEnabled: Boolean = get(ENABLE_DELTA_LOW_SHUFFLE_MERGE)

// Lazily read value of the CASE_WHEN_FUSE config entry for this RapidsConf instance.
lazy val caseWhenFuseEnabled: Boolean = get(CASE_WHEN_FUSE)

private val optimizerDefaults = Map(
// this is not accurate because CPU projections do have a cost due to appending values
// to each row that is produced, but this needs to be a really small number because
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ package com.nvidia.spark.rapids
import ai.rapids.cudf.{BinaryOp, ColumnVector, DType, NullPolicy, Scalar, ScanAggregation, ScanType, Table, UnaryOp}
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.CaseWhen
import com.nvidia.spark.rapids.shims.ShimExpression

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, Expression}
import org.apache.spark.sql.types.{BooleanType, DataType, DataTypes}
import org.apache.spark.sql.types.{BooleanType, DataType, DataTypes, StringType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String

object GpuExpressionWithSideEffectUtils {

Expand All @@ -47,7 +49,7 @@ object GpuExpressionWithSideEffectUtils {

/**
* Used to shortcircuit predicates and filter conditions.
*
*
* @param nullsAsFalse when true, null values are considered false.
* @param col the input being evaluated.
* @return boolean. When nullsAsFalse is set, it returns True if none of the rows is true;
Expand Down Expand Up @@ -182,9 +184,9 @@ case class GpuIf(
predicateExpr: Expression,
trueExpr: Expression,
falseExpr: Expression) extends GpuConditionalExpression {

import GpuExpressionWithSideEffectUtils._

@transient
override lazy val inputTypesForMerging: Seq[DataType] = {
Seq(trueExpr.dataType, falseExpr.dataType)
Expand Down Expand Up @@ -314,7 +316,9 @@ case class GpuIf(

case class GpuCaseWhen(
branches: Seq[(Expression, Expression)],
elseValue: Option[Expression] = None) extends GpuConditionalExpression with Serializable {
elseValue: Option[Expression] = None,
caseWhenFuseEnabled: Boolean = true)
extends GpuConditionalExpression with Serializable {

import GpuExpressionWithSideEffectUtils._

Expand Down Expand Up @@ -359,15 +363,60 @@ case class GpuCaseWhen(
if (branchesWithSideEffects) {
columnarEvalWithSideEffects(batch)
} else {
// `elseRet` will be closed in `computeIfElse`.
val elseRet = elseValue
.map(_.columnarEvalAny(batch))
.getOrElse(GpuScalar(null, branches.last._2.dataType))
val any = branches.foldRight[Any](elseRet) {
case ((predicateExpr, trueExpr), falseRet) =>
computeIfElse(batch, predicateExpr, trueExpr, falseRet)
if (caseWhenFuseEnabled && branches.size > 2 &&
    inputTypesForMerging.head == StringType &&
    (branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[GpuLiteral])
) {
  // Fused path. It is taken only when:
  //  - fusing is enabled and there are more than 2 branches;
  //  - the return type is string;
  //  - every THEN value and the ELSE value (if any) is a literal.
  // It avoids chaining `computeIfElse` calls, each of which creates a temporary column.

  // 1. For every row select the index of the first WHEN predicate that is true.
  //    If no predicate is true the index is out of bounds (== number of branches).
  //    e.g.:
  //      predicate column 0: true,  false, false
  //      predicate column 1: false, true,  false
  //      result is:          [0, 1, 2]
  val whenBoolCols = branches.safeMap(_._1.columnarEval(batch).getBase).toArray
  val firstTrueIndex: ColumnVector = withResource(whenBoolCols) { _ =>
    CaseWhen.selectFirstTrueIndex(whenBoolCols)
  }

  withResource(firstTrueIndex) { _ =>
    val thenElseScalars = (branches.map(_._2) ++ elseValue).map(_.columnarEvalAny(batch)
        .asInstanceOf[GpuScalar])
    withResource(thenElseScalars) { _ =>
      // 2. Build one string column holding all THEN values followed by the ELSE value.
      //    A literal may be a null string (e.g. `WHEN b THEN null`); its value is null, so
      //    calling `getBytes` on it directly would throw a NullPointerException. Pass a null
      //    byte array instead so that the corresponding row of the column is null.
      val scalarsBytes: Seq[Array[Byte]] = thenElseScalars.map { scalar =>
        Option(scalar.getValue).map(_.asInstanceOf[UTF8String].getBytes).orNull
      }
      withResource(ColumnVector.fromUTF8Strings(scalarsBytes: _*)) { scalarCol =>
        val finalRet = withResource(new Table(scalarCol)) { oneColumnTable =>
          // 3. Gather the final values by index.
          //    The default gather OutOfBoundsPolicy is nullify: an out of bounds index
          //    (no predicate matched and there is no ELSE value) produces null.
          withResource(oneColumnTable.gather(firstTrueIndex)) { resultTable =>
            resultTable.getColumn(0).incRefCount()
          }
        }
        // Wrap the gathered cudf column as the result of this expression.
        GpuColumnVector.from(finalRet, dataType)
      }
    }
  }
} else {
  // General path: fold the branches from the last one to the first one, nesting one
  // if/else per branch.
  // `elseRet` will be closed in `computeIfElse`.
  val elseRet = elseValue
    .map(_.columnarEvalAny(batch))
    .getOrElse(GpuScalar(null, branches.last._2.dataType))
  val any = branches.foldRight[Any](elseRet) {
    case ((predicateExpr, trueExpr), falseRet) =>
      computeIfElse(batch, predicateExpr, trueExpr, falseRet)
  }
  GpuExpressionsUtils.resolveColumnVector(any, batch.numRows())
}
GpuExpressionsUtils.resolveColumnVector(any, batch.numRows())
}
}

Expand Down