diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala index e3489904466..74edcf54f99 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala @@ -32,8 +32,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, ExprId} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, UnaryLike} -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.rapids.GpuTaskMetrics import org.apache.spark.sql.vectorized.ColumnarBatch @@ -406,7 +406,7 @@ trait GpuExec extends SparkPlan { if (loreDumpOperator.exists(o => o.equals(className)) || loreDumpLOREIds.split(',').contains(myLoreId) ) { - val childAsGpuExec = this.asInstanceOf[UnaryLike[SparkPlan]].child.asInstanceOf[GpuExec] + val childAsGpuExec = this.asInstanceOf[UnaryExecNode].child.asInstanceOf[GpuExec] childAsGpuExec.shouldDumpOutput = true childAsGpuExec.dumpForLOREId = myLoreId val childPlanId = childAsGpuExec.id diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/profiling/DumpedExecReplayer.scala b/tests/src/test/scala/com/nvidia/spark/rapids/profiling/DumpedExecReplayer.scala index 4a756608bd9..750bff91bd5 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/profiling/DumpedExecReplayer.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/profiling/DumpedExecReplayer.scala @@ -23,7 +23,7 @@ import com.nvidia.spark.rapids.GpuExec import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.execution.UnaryExecNode object DumpedExecReplayer extends Logging { @@ -56,9 +56,9 @@ object DumpedExecReplayer extends Logging { // restore SparkPlan val restoredExec = deserializeObject[GpuExec](planMetaPath) - if (!restoredExec.isInstanceOf[UnaryLike[_]]) throw new IllegalStateException( + if (!restoredExec.isInstanceOf[UnaryExecNode]) throw new IllegalStateException( s"For now, restored exec only supports UnaryLike: ${restoredExec.getClass}") - val unaryLike = restoredExec.asInstanceOf[UnaryLike[_]] + val unaryLike = restoredExec.asInstanceOf[UnaryExecNode] if (!unaryLike.child.isInstanceOf[GpuExec]) throw new IllegalStateException( s"For now, restored exec's child only supports GpuExec: " +