diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/DataWritingCommandExecParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/DataWritingCommandExecParser.scala index 86e553ac1..4c4439eb9 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/DataWritingCommandExecParser.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/DataWritingCommandExecParser.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2022-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ package com.nvidia.spark.rapids.tool.planparser import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker import org.apache.spark.sql.execution.ui.SparkPlanGraphNode +import org.apache.spark.sql.rapids.tool.store.{WriteOperationMetaBuilder, WriteOperationMetadataTrait} +import org.apache.spark.sql.rapids.tool.util.StringUtils case class DataWritingCommandExecParser( node: SparkPlanGraphNode, @@ -106,13 +108,13 @@ object DataWritingCommandExecParser { // This used for expressions that do not show the format as part of the description. private val specialWriteFormatMap = Map[String, String]( // if appendDataExecV1 is not deltaLakeProvider, then we want to mark it as unsupported - appendDataExecV1 -> "unknown", + appendDataExecV1 -> StringUtils.UNKNOWN_EXTRACT, // if overwriteByExprExecV1 is not deltaLakeProvider, then we want to mark it as unsupported - overwriteByExprExecV1 -> "unknown", + overwriteByExprExecV1 -> StringUtils.UNKNOWN_EXTRACT, // if atomicReplaceTableExec is not deltaLakeProvider, then we want to mark it as unsupported - atomicReplaceTableExec -> "unknown", + atomicReplaceTableExec -> StringUtils.UNKNOWN_EXTRACT, // if atomicCreateTableExec is not deltaLakeProvider, then we want to mark it as unsupported - atomicCreateTableExec -> "unknown" + atomicCreateTableExec -> StringUtils.UNKNOWN_EXTRACT ) // Checks whether a node is a write CMD Exec @@ -175,4 +177,180 @@ object DataWritingCommandExecParser { parsedString.split(",")(0) // return third parameter from the input string } } + + /** + * Extracts metadata information from a write operation node description. + * This method is specifically designed to parse the description of + * `InsertIntoHadoopFsRelationCommand` nodes and extract relevant details + * such as the output path, data format, write mode, catalog information, + * and output columns. + * An example of the pattern is: + * Execute InsertIntoHadoopFsRelationCommand /path/to/warehouse/database/table, false, format, + * [key1=value1, key2=value2], Append, `SparkCatalog`.`database`.`table`, ClassName, + * [outputColumns] + * + * The method performs the following steps: + * — Extracts the output path and data format from the node description. + * — Determines the write mode (e.g., Append, Overwrite) based on specific keywords in the + * description. + * — Extracts catalog information (database and table name) from the output path. + * — Extracts the output columns if available in the description. + * — Builds and returns a `WriteOperationMetadataTrait` object encapsulating the extracted + * metadata. + * + * This method includes error handling to ensure graceful fallback to default values + * (e.g., `UNKNOWN_EXTRACT`) in case of unexpected input or parsing errors. 
+   *
+   * @param execName The name of the execution command (e.g., `InsertIntoHadoopFsRelationCommand`).
+   * @param nodeDescr The description of the node, typically containing details about the write
+   *                  operation.
+   * @return A `WriteOperationMetadataTrait` object containing the extracted metadata.
+   */
+  private def extractWriteOpRecord(
+      execName: String, nodeDescr: String): WriteOperationMetadataTrait = {
+    // Helper function to extract catalog information (database and table name) from the output
+    // path.
+    def extractCatalog(path: String): (String, String) = {
+      try {
+        // The location path contains the database and the table as the last 2 entries.
+        // Example: gs:///path/to/warehouse/database/table
+        // Split the URI into parts by "/"
+        val pathParts = path.split("/").filter(_.nonEmpty)
+        if (pathParts.length >= 2) {
+          // Extract the last two parts as database and table name
+          val database = pathParts(pathParts.length - 2)
+          val tableName = pathParts.last
+          (database, tableName)
+        } else {
+          // If not enough parts, return UNKNOWN_EXTRACT
+          (StringUtils.UNKNOWN_EXTRACT, StringUtils.UNKNOWN_EXTRACT)
+        }
+      } catch {
+        // Handle any unexpected errors gracefully
+        case _: Exception => (StringUtils.UNKNOWN_EXTRACT, StringUtils.UNKNOWN_EXTRACT)
+      }
+    }
+
+    // Helper function to extract the output path and data format from the node description.
+    def extractPathAndFormat(args: Array[String]): (String, String) = {
+      // This method expects the arguments to be nodeDescr.split(",", 3)
+      // `Execute cmd path/to/warehouse/db/table, false, parquet, [write options],.*`.
+      // — 1st arg is always the cmd followed by the path.
+      // — 2nd arg is a boolean argument that we do not care about.
+      // — 3rd arg is either the format, or the list of write options.
+
+      // Extract the path from the first argument
+      val path =
+        args.headOption.map(_.split("\\s+").last.trim).getOrElse(StringUtils.UNKNOWN_EXTRACT)
+      // Extract the data format from the third argument
+      val thirdArg = args.lift(2).getOrElse("").trim
+      val format = if (thirdArg.startsWith("[")) {
+        // Optional parameter is present in the eventlog. Get the fourth parameter by skipping the
+        // optional parameter string.
+        thirdArg.split("(?<=],)")
+          .map(_.trim).lift(1).getOrElse("").split(",").headOption.getOrElse("").trim
+      } else {
+        thirdArg.split(",").headOption.getOrElse("").trim
+      }
+      (path, format)
+    }
+
+    // Helper function to determine the write mode (e.g., Append, Overwrite) from the description.
+    def extractWriteMode(description: String): String = {
+      val modes = Map(
+        ", Append," -> "Append",
+        ", Overwrite," -> "Overwrite",
+        ", ErrorIfExists," -> "ErrorIfExists",
+        ", Ignore," -> "Ignore"
+      )
+      // Match the description against known write modes
+      modes.collectFirst { case (key, mode) if description.contains(key) => mode }
+        .getOrElse(StringUtils.UNKNOWN_EXTRACT)
+    }
+
+    // Helper function to extract output columns from the node description.
+    def extractOutputColumns(description: String): Option[String] = {
+      // The output columns are found as the last sequence inside a bracket. This method uses a
+      // regex to match on string values inside a bracket. Then it picks the last one.
+      // Use a regular expression to find column definitions enclosed in square brackets.
+ val columnsRegex = """\[(.*?)\]""".r + columnsRegex.findAllMatchIn(description).map(_.group(1)).toList.lastOption + .map(_.replaceAll(",\\s+", ";")) // Replace commas with semicolons for better readability + } + + // Parse the node description into arguments + val splitArgs = nodeDescr.split(",", 3) + + // Extract the output path and data format + val (path, format) = extractPathAndFormat(splitArgs) + + // Extract the write mode (e.g., Append, Overwrite) + val writeMode = extractWriteMode(nodeDescr) + + // Extract catalog information (database and table name) from the output path + val (catalogDB, catalogTable) = extractCatalog(path) + + // Extract the output columns, if available + val outColumns = extractOutputColumns(nodeDescr) + + // Build and return the metadata object encapsulating all extracted information + WriteOperationMetaBuilder.build( + execName = execName, + dataFormat = format, + outputPath = Option(path), + outputColumns = outColumns, + writeMode = writeMode, + tableName = catalogTable, + dataBaseName = catalogDB, + fullDescr = Some(nodeDescr) + ) + } + + /** + * Extracts metadata information from a given SparkPlanGraphNode representing a write operation. + * + * This method determines the type of write operation (e.g., Delta Lake or other supported + * commands) and extracts relevant metadata such as execution name, data format, output path, + * write mode, and catalog information. It uses helper methods to parse the node description + * and build a metadata object encapsulating the extracted details. + * + * The method performs the following steps: + * 1. Retrieves the node description from the provided SparkPlanGraphNode. + * 2. Checks if the node is a Delta Lake write operation using DeltaLakeHelper. + * 3. Retrieves the appropriate command wrapper (logical or physical) for the node. + * 4. If the command is `InsertIntoHadoopFsRelationCommand`, it invokes a specialized method + * `extractWriteOpRecord` to extract detailed metadata. + * 5. For other commands, it builds a metadata object using the command wrapper's information. + * 6. If no command wrapper is found, it falls back to building a metadata object with minimal + * information. + * + * @param node The SparkPlanGraphNode representing the write operation. + * @return A WriteOperationMetadataTrait object containing the extracted metadata. + */ + def getWriteOpMetaFromNode(node: SparkPlanGraphNode): WriteOperationMetadataTrait = { + // Determine the appropriate command wrapper based on whether the node is a Delta Lake write + // operation. + val cmdWrapper = if (DeltaLakeHelper.acceptsWriteOp(node)) { + DeltaLakeHelper.getWriteCMDWrapper(node) + } else { + getWriteCMDWrapper(node) + } + // Process the command wrapper to extract metadata + cmdWrapper match { + case Some(cmdWrapper) => + // If the command is InsertIntoHadoopFsRelationCommand, extract detailed metadata. + if (cmdWrapper.execName == DataWritingCommandExecParser.insertIntoHadoopCMD) { + extractWriteOpRecord(cmdWrapper.execName, node.desc) + } else { + // For other commands, build metadata using the command wrapper information. + WriteOperationMetaBuilder.build( + execName = cmdWrapper.execName, + dataFormat = cmdWrapper.dataFormat, + fullDescr = Some(node.desc)) + } + case _ => + // No command wrapper is found, build metadata with minimal information. 
+ WriteOperationMetaBuilder.buildNoMeta(Some(node.desc)) + } + } } diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/DeltaLakeHelper.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/DeltaLakeHelper.scala index 3822706ca..2edf0b550 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/DeltaLakeHelper.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/DeltaLakeHelper.scala @@ -20,6 +20,7 @@ import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker import org.apache.spark.sql.execution.ui.{SparkPlanGraphCluster, SparkPlanGraphNode} import org.apache.spark.sql.rapids.tool.SqlPlanInfoGraphEntry +import org.apache.spark.sql.rapids.tool.util.StringUtils // A class used to handle the DL writeOps such as: // - AppendDataExecV1 @@ -135,8 +136,6 @@ object DeltaLakeHelper { def parseNode(node: SparkPlanGraphNode, checker: PluginTypeChecker, sqlID: Long): ExecInfo = { - val opExec = new DLWriteWithFormatAndSchemaParser(node, checker, sqlID) - opExec.parse node match { case n if acceptsWriteOp(n) => val opExec = new DLWriteWithFormatAndSchemaParser(node, checker, sqlID) @@ -145,6 +144,25 @@ object DeltaLakeHelper { } } + /** + * Get the write command wrapper for the given node deltaLake exec node. + * This method should be called only if the node passes the `acceptsWriteOp` check. + * @param node the deltaLake write exec + * @return the write command wrapper + */ + def getWriteCMDWrapper(node: SparkPlanGraphNode): Option[DataWritingCmdWrapper] = { + val wcmd = exclusiveDeltaExecs.find(node.name.contains(_)) match { + case Some(cmd) => cmd + case _ => + deltaExecsFromSpark.find(node.name.contains(_)) match { + case Some(cmd) => cmd + case _ => StringUtils.UNKNOWN_EXTRACT + } + } + // The format must be delta + Some(DataWritingCmdWrapper(wcmd, DataWritingCommandExecParser.dataWriteCMD, getWriteFormat)) + } + // Kept for future use if we find that SerDe library can be used to deduce any information to // reflect on the support of the Op def getSerdeLibrary(nodeDesc: String): Option[String] = { diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HiveParseHelper.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HiveParseHelper.scala index 8384e39e6..1c95cceee 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HiveParseHelper.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/HiveParseHelper.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.tool.planparser import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.ui.SparkPlanGraphNode -import org.apache.spark.sql.rapids.tool.util.EventUtils +import org.apache.spark.sql.rapids.tool.util.{EventUtils, StringUtils} // A wrapper class to map between case class HiveScanSerdeClasses(className: String, format: String) extends Logging { @@ -68,7 +68,7 @@ object HiveParseHelper extends Logging { } def getHiveFormatFromSimpleStr(str: String): String = { - LOADED_SERDE_CLASSES.find(_.accepts(str)).map(_.format).getOrElse("unknown") + LOADED_SERDE_CLASSES.find(_.accepts(str)).map(_.format).getOrElse(StringUtils.UNKNOWN_EXTRACT) } // Given a "scan hive" NodeGraph, construct the MetaData based on the SerDe class. 
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ReadParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ReadParser.scala index 0efb36a45..d3ab0865a 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ReadParser.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ReadParser.scala @@ -22,6 +22,7 @@ import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.ui.SparkPlanGraphNode +import org.apache.spark.sql.rapids.tool.util.StringUtils case class ReadMetaData(schema: String, location: String, format: String, tags: Map[String, String] = ReadParser.DEFAULT_METAFIELD_MAP) { @@ -60,7 +61,7 @@ object ReadParser extends Logging { val METAFIELD_TAG_FORMAT = "Format" val METAFIELD_TAG_LOCATION = "Location" - val UNKNOWN_METAFIELD: String = "unknown" + val UNKNOWN_METAFIELD: String = StringUtils.UNKNOWN_EXTRACT val DEFAULT_METAFIELD_MAP: Map[String, String] = collection.immutable.Map( METAFIELD_TAG_DATA_FILTERS -> UNKNOWN_METAFIELD, METAFIELD_TAG_PUSHED_FILTERS -> UNKNOWN_METAFIELD, diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorRefBase.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorRefBase.scala index 04bd9a190..fb1d25bf3 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorRefBase.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/ops/OperatorRefBase.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,8 +31,8 @@ import org.apache.spark.sql.rapids.tool.util.StringUtils class OperatorRefBase(val value: String, val opType: OpTypes.OpType) extends OperatorRefTrait { // Preformatted values for CSV output to avoid reformatting multiple times. 
- val csvValue: String = StringUtils.reformatCSVString(value) - val csvOpType: String = StringUtils.reformatCSVString(opType.toString) + lazy val csvValue: String = StringUtils.reformatCSVString(value) + lazy val csvOpType: String = StringUtils.reformatCSVString(opType.toString) override def getOpName: String = value override def getOpNameCSV: String = csvValue diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationSummaryInfo.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationSummaryInfo.scala index 93d17f981..d8c4332eb 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationSummaryInfo.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationSummaryInfo.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.tool.profiling import com.nvidia.spark.rapids.SparkRapidsBuildInfoEvent import com.nvidia.spark.rapids.tool.AppSummaryInfoBaseProvider +import com.nvidia.spark.rapids.tool.views.WriteOpProfileResult case class ApplicationSummaryInfo( appInfo: Seq[AppInfoProfileResults], @@ -47,7 +48,8 @@ case class ApplicationSummaryInfo( ioMetrics: Seq[IOAnalysisProfileResult], sysProps: Seq[RapidsPropertyProfileResult], sqlCleanedAlignedIds: Seq[SQLCleanAndAlignIdsProfileResult], - sparkRapidsBuildInfo: Seq[SparkRapidsBuildInfoEvent]) + sparkRapidsBuildInfo: Seq[SparkRapidsBuildInfoEvent], + writeOpsInfo: Seq[WriteOpProfileResult]) trait AppInfoPropertyGetter { // returns all the properties (i.e., spark) diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CollectInformation.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CollectInformation.scala index 6d677b072..86ef34bdb 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CollectInformation.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CollectInformation.scala @@ -56,6 +56,11 @@ class CollectInformation(apps: Seq[ApplicationInfo]) extends Logging { ProfDataSourceView.getRawView(apps, cachedSqlAccum) } + // get the write records information + def getWriteOperationInfo: Seq[WriteOpProfileResult] = { + ProfWriteOpsView.getRawView(apps) + } + // get executor related information def getExecutorInfo: Seq[ExecutorInfoProfileResult] = { ProfExecutorView.getRawView(apps) diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala index 620509804..53e7573bc 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala @@ -399,13 +399,15 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea val endTime = System.currentTimeMillis() logInfo(s"Took ${endTime - startTime}ms to Process [${appInfo.head.appId}]") (ApplicationSummaryInfo(appInfo, dsInfo, - collect.getExecutorInfo, collect.getJobInfo, rapidsProps, - rapidsJar, sqlMetrics, stageMetrics, analysis.jobAggs, analysis.stageAggs, - analysis.sqlAggs, analysis.sqlDurAggs, analysis.taskShuffleSkew, - failedTasks, failedStages, failedJobs, removedBMs, removedExecutors, - unsupportedOps, sparkProps, collect.getSQLToStage, wholeStage, maxTaskInputInfo, - appLogPath, analysis.ioAggs, systemProps, sqlIdAlign, sparkRapidsBuildInfo), - compareRes, DiagnosticSummaryInfo(analysis.stageDiagnostics, collect.getIODiagnosticMetrics)) + collect.getExecutorInfo, collect.getJobInfo, rapidsProps, + 
rapidsJar, sqlMetrics, stageMetrics, analysis.jobAggs, analysis.stageAggs, + analysis.sqlAggs, analysis.sqlDurAggs, analysis.taskShuffleSkew, + failedTasks, failedStages, failedJobs, removedBMs, removedExecutors, + unsupportedOps, sparkProps, collect.getSQLToStage, wholeStage, maxTaskInputInfo, + appLogPath, analysis.ioAggs, systemProps, sqlIdAlign, sparkRapidsBuildInfo, + collect.getWriteOperationInfo), + compareRes, + DiagnosticSummaryInfo(analysis.stageDiagnostics, collect.getIODiagnosticMetrics)) } /** @@ -502,7 +504,8 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea appsSum.flatMap(_.ioMetrics).sortBy(_.appIndex), combineProps("system", appsSum).sortBy(_.key), appsSum.flatMap(_.sqlCleanedAlignedIds).sortBy(_.appIndex), - appsSum.flatMap(_.sparkRapidsBuildInfo) + appsSum.flatMap(_.sparkRapidsBuildInfo), + appsSum.flatMap(_.writeOpsInfo).sortBy(_.appIndex) ) Seq(reduced) } else { @@ -546,6 +549,8 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea Some(AGG_DESCRIPTION(SQL_AGG_LABEL))) profileOutputWriter.write(IO_LABEL, app.ioMetrics) profileOutputWriter.write(SQL_DUR_LABEL, app.durAndCpuMet) + // writeOps are generated in only CSV format + profileOutputWriter.writeCSVTable(ProfWriteOpsView.getLabel, app.writeOpsInfo) val skewHeader = TASK_SHUFFLE_SKEW val skewTableDesc = AGG_DESCRIPTION(TASK_SHUFFLE_SKEW) profileOutputWriter.write(skewHeader, app.skewInfo, tableDesc = Some(skewTableDesc)) diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualSQLPlanAnalyzer.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualSQLPlanAnalyzer.scala index cbaff4443..5873c05d7 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualSQLPlanAnalyzer.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualSQLPlanAnalyzer.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,7 @@ package com.nvidia.spark.rapids.tool.qualification import com.nvidia.spark.rapids.tool.analysis.AppSQLPlanAnalyzer -import com.nvidia.spark.rapids.tool.planparser.DataWritingCommandExecParser -import org.apache.spark.sql.execution.ui.SparkPlanGraphNode import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo /** @@ -35,14 +33,4 @@ import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo class QualSQLPlanAnalyzer( app: QualificationAppInfo, appIndex: Integer) extends AppSQLPlanAnalyzer(app, appIndex) { - override def visitNode(visitor: SQLPlanVisitorContext, - node: SparkPlanGraphNode): Unit = { - super.visitNode(visitor, node) - // Get the write data format - if (!app.perSqlOnly) { - DataWritingCommandExecParser.getWriteCMDWrapper(node).map { wWrapper => - app.writeDataFormat += wWrapper.dataFormat - } - } - } } diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/views/JobView.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/views/JobView.scala index 9edbcb45d..7d7ca1160 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/views/JobView.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/views/JobView.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2025, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ trait AppFailedJobsViewTrait extends ViewableTrait[FailedJobsProfileResults] { } jobsFailed.map { case (id, jc) => val failureStr = jc.failedReason.getOrElse("") - FailedJobsProfileResults(index, id, jc.jobResult.getOrElse("Unknown"), + FailedJobsProfileResults(index, id, jc.jobResult.getOrElse(StringUtils.UNKNOWN_EXTRACT), StringUtils.renderStr(failureStr, doEscapeMetaCharacters = false, maxLength = 0)) }.toSeq } diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/views/QualRawReportGenerator.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/views/QualRawReportGenerator.scala index 0fd30525a..d331ee85b 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/views/QualRawReportGenerator.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/views/QualRawReportGenerator.scala @@ -109,6 +109,8 @@ object QualRawReportGenerator extends Logging { pWriter.write(QualAppFailedJobView.getLabel, QualAppFailedJobView.getRawView(Seq(app))) pWriter.write(QualRemovedBLKMgrView.getLabel, QualRemovedBLKMgrView.getRawView(Seq(app))) pWriter.write(QualRemovedExecutorView.getLabel, QualRemovedExecutorView.getRawView(Seq(app))) + // we only need to write the CSV report of the WriteOps + pWriter.writeCSVTable(QualWriteOpsView.getLabel, QualWriteOpsView.getRawView(Seq(app))) } catch { case e: Exception => logError(s"Error generating raw metrics for ${app.appId}: ${e.getMessage}") diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/views/WriteOpsView.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/views/WriteOpsView.scala new file mode 100644 index 000000000..6bea4fdf2 --- /dev/null +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/views/WriteOpsView.scala @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.tool.views + +import scala.collection.breakOut + +import com.nvidia.spark.rapids.tool.analysis.{ProfAppIndexMapperTrait, QualAppIndexMapperTrait} +import com.nvidia.spark.rapids.tool.profiling.ProfileResult + +import org.apache.spark.sql.rapids.tool.AppBase +import org.apache.spark.sql.rapids.tool.store.WriteOperationRecord +import org.apache.spark.sql.rapids.tool.util.StringUtils + +/** + * Represents a single write operation profiling result. + * This case class implements the `ProfileResult` trait and provides methods + * to convert the result into sequences of strings for display or CSV export. + * + * @param appIndex The index of the application this result belongs to. + * @param record The write operation record containing metadata and details. + */ +case class WriteOpProfileResult( + appIndex: Int, + record: WriteOperationRecord) extends ProfileResult { + + /** + * Defines the headers for the output display. 
+ */ + override val outputHeaders: Seq[String] = { + Seq("appIndex", "sqlID", "sqlPlanVersion", "nodeId", "fromFinalPlan", "execName", "format", + "location", "tableName", "dataBase", "outputColumns", "writeMode", "fullDescription") + } + + /** + * Converts the profiling result into a sequence of strings for display. + * Escapes special characters in the description and truncates long strings. + */ + override def convertToSeq: Seq[String] = { + Seq(appIndex.toString, + record.sqlID.toString, + record.version.toString, + record.nodeId.toString, + record.fromFinalPlan.toString, + // Extract metadata information + record.operationMeta.execName(), + record.operationMeta.dataFormat(), + record.operationMeta.outputPath(), + record.operationMeta.table(), + record.operationMeta.dataBase(), + record.operationMeta.outputColumns(), + record.operationMeta.writeMode(), + // Escape special characters in the description + StringUtils.renderStr(record.operationMeta.fullDescr(), doEscapeMetaCharacters = true, + maxLength = 500, showEllipses = true)) + } + + /** + * Converts the profiling result into a sequence of strings formatted for CSV output. + * Escapes special characters and truncates long descriptions to a maximum length. + */ + override def convertToCSVSeq: Seq[String] = { + Seq(appIndex.toString, + record.sqlID.toString, + record.version.toString, + record.nodeId.toString, + record.fromFinalPlan.toString, + // Extract metadata information for CSV + record.operationMeta.execNameCSV, + record.operationMeta.formatCSV, + StringUtils.reformatCSVString(record.operationMeta.outputPath()), + StringUtils.reformatCSVString(record.operationMeta.table()), + StringUtils.reformatCSVString(record.operationMeta.dataBase()), + StringUtils.reformatCSVString(record.operationMeta.outputColumns()), + record.operationMeta.writeMode(), + StringUtils.reformatCSVString( + // Escape special characters in the description and trim at 500 characters. + StringUtils.renderStr(record.operationMeta.fullDescr(), doEscapeMetaCharacters = true, + maxLength = 500, showEllipses = true))) + } +} + +/** + * A trait for creating views of write operation profiling results. + * This trait provides methods to extract raw results, sort them, and label the view. + */ +trait WriteOpsViewTrait extends ViewableTrait[WriteOpProfileResult] { + + /** + * Returns the label for the view. + */ + override def getLabel: String = "Write Operations" + + /** + * Extracts raw write operation records from the given application and maps them + * to `WriteOpProfileResult` instances. + * + * @param app The application containing write operation records. + * @param index The index of the application. + * @return A sequence of `WriteOpProfileResult` instances. + */ + def getRawView(app: AppBase, index: Int): Seq[WriteOpProfileResult] = { + app.getWriteOperationRecords().map { w => + WriteOpProfileResult(index, w) + }(breakOut) + } + + /** + * Sorts the write operation profiling results by application index, SQL ID, + * plan version, and node ID. + * + * @param rows The sequence of profiling results to sort. + * @return A sorted sequence of profiling results. + */ + override def sortView(rows: Seq[WriteOpProfileResult]): Seq[WriteOpProfileResult] = { + rows.sortBy(cols => (cols.appIndex, cols.record.sqlID, cols.record.version, + cols.record.nodeId)) + } +} + +/** + * A view for write operation profiling results specific to Qualification workflows. + * Extends `WriteOpsViewTrait` and implements `QualAppIndexMapperTrait` for customization. 
+ */ +object QualWriteOpsView extends WriteOpsViewTrait with QualAppIndexMapperTrait { + // Placeholder for future customization specific to Qualification workflows. +} + +/** + * A view for write operation profiling results specific to Profiling workflows. + * Extends `WriteOpsViewTrait` and implements `ProfAppIndexMapperTrait` for customization. + */ +object ProfWriteOpsView extends WriteOpsViewTrait with ProfAppIndexMapperTrait { + // Placeholder for future customization specific to Profiling workflows. +} diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala index bcc4d22ac..3e0f3929e 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala @@ -37,8 +37,8 @@ import org.apache.spark.rapids.tool.benchmarks.RuntimeInjector import org.apache.spark.scheduler.{SparkListenerEvent, StageInfo} import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraphNode -import org.apache.spark.sql.rapids.tool.store.{AccumManager, DataSourceRecord, SQLPlanModelManager, StageModel, StageModelManager, TaskModelManager} -import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph, UTF8Source} +import org.apache.spark.sql.rapids.tool.store.{AccumManager, DataSourceRecord, SQLPlanModelManager, StageModel, StageModelManager, TaskModelManager, WriteOperationRecord} +import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, StringUtils, ToolsPlanGraph, UTF8Source} import org.apache.spark.util.Utils abstract class AppBase( @@ -145,6 +145,14 @@ abstract class AppBase( appMetaData.flatMap(_.duration) } + def getWriteOperationRecords(): Iterable[WriteOperationRecord] = { + sqlManager.getWriteOperationRecords() + } + + def getWriteDataFormats(): Set[String] = { + sqlManager.getWriteFormats() + } + // Returns a boolean true/false. This is used to check whether processing an eventlog was // successful. def isAppMetaDefined: Boolean = appMetaData.isDefined @@ -634,16 +642,16 @@ object AppBase { def handleException(e: Exception, path: EventLogInfo): FailureApp = { val (status, message): (String, String) = e match { case incorrectStatusEx: IncorrectAppStatusException => - ("unknown", incorrectStatusEx.getMessage) + (StringUtils.UNKNOWN_EXTRACT, incorrectStatusEx.getMessage) case skippedEx: AppEventlogProcessException => ("skipped", skippedEx.getMessage) case _: com.fasterxml.jackson.core.JsonParseException => - ("unknown", s"Error parsing JSON: ${path.eventLog.toString}") + (StringUtils.UNKNOWN_EXTRACT, s"Error parsing JSON: ${path.eventLog.toString}") case _: IllegalArgumentException => - ("unknown", s"Error parsing file: ${path.eventLog.toString}") + (StringUtils.UNKNOWN_EXTRACT, s"Error parsing file: ${path.eventLog.toString}") case ue: Exception => // catch all exceptions and skip that file - ("unknown", s"Got unexpected exception processing file:" + + (StringUtils.UNKNOWN_EXTRACT, s"Got unexpected exception processing file:" + s"${path.eventLog.toString}. 
${ue.getMessage} ") } diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala index 5d0f3074c..6b68c15c5 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.rapids.tool.qualification import java.util.concurrent.TimeUnit.NANOSECONDS import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.collection.mutable import com.nvidia.spark.rapids.tool.{EventLogInfo, Platform} import com.nvidia.spark.rapids.tool.planparser.{ExecInfo, PlanInfo, SQLPlanParser} @@ -47,9 +46,6 @@ class QualificationAppInfo( var lastJobEndTime: Option[Long] = None var lastSQLEndTime: Option[Long] = None - // Keeps track of the WriteDataFormats used in the WriteExecs - // Use LinkedHashSet to preserve Order of insertion and avoid duplicates - val writeDataFormat: mutable.AbstractSet[String] = mutable.LinkedHashSet[String]() val sqlIDToTaskEndSum: HashMap[Long, StageTaskQualificationSummary] = HashMap.empty[Long, StageTaskQualificationSummary] @@ -505,7 +501,7 @@ class QualificationAppInfo( val typeString = types.mkString(":").replace(",", ":") s"${format}[$typeString]" }.toSeq - val writeFormat = writeFormatNotSupported(writeDataFormat) + val writeFormat = writeFormatNotSupported(getWriteDataFormats()) val (allComplexTypes, nestedComplexTypes) = reportComplexTypes val problems = getPotentialProblemsForDf @@ -827,7 +823,7 @@ class QualificationAppInfo( } } - private def writeFormatNotSupported(writeFormat: mutable.AbstractSet[String]): Seq[String] = { + private def writeFormatNotSupported(writeFormat: Set[String]): Seq[String] = { // Filter unsupported write data format val unSupportedWriteFormat = pluginTypeChecker.getUnsupportedWriteFormat(writeFormat) unSupportedWriteFormat.toSeq diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModel.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModel.scala index d6c67ec7b..6fb9dbe42 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModel.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModel.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, NVIDIA CORPORATION. + * Copyright (c) 2024-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -78,6 +78,8 @@ class SQLPlanModel(val id: Long) {
   // After adding a new version, reset the previous plan if necessary
   protected def resetPreviousPlan(): Unit = {
     plan.resetFinalFlag()
+    // call any cleanup code necessary for the plan
+    plan.cleanUpPlan()
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModelManager.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModelManager.scala
index ff2e01011..398ee76f9 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModelManager.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModelManager.scala
@@ -16,8 +16,7 @@
 
 package org.apache.spark.sql.rapids.tool.store
 
-import scala.collection.immutable
-import scala.collection.mutable
+import scala.collection.{breakOut, immutable, mutable}
 
 import org.apache.spark.sql.execution.SparkPlanInfo
 
@@ -126,4 +125,21 @@ class SQLPlanModelManager {
   def getPlanInfos: immutable.Map[Long, SparkPlanInfo] = {
     immutable.SortedMap[Long, SparkPlanInfo]() ++ sqlPlans.mapValues(_.planInfo)
   }
+
+  /**
+   * Gets all the writeRecords of the final plan of the SQL.
+   * @return Iterable of WriteOperationRecord representing the write operations.
+   */
+  def getWriteOperationRecords(): Iterable[WriteOperationRecord] = {
+    sqlPlans.values.flatMap(_.plan.writeRecords)
+  }
+
+  /**
+   * Converts the writeOperations into a String set representing the formats of the writeOps.
+   * This only pulls the information from the final plan of the SQL.
+   * @return a set of write formats
+   */
+  def getWriteFormats(): Set[String] = {
+    sqlPlans.values.flatMap(_.plan.getWriteDataFormats)(breakOut)
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModelWithDSCaching.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModelWithDSCaching.scala
index 7b21f8a19..b172f9938 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModelWithDSCaching.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanModelWithDSCaching.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -36,6 +36,8 @@ class SQLPlanModelWithDSCaching(sqlId: Long) extends SQLPlanModel(sqlId) {
     plan.resetFinalFlag()
     // cache the datasource records from previous plan if any
     cachedDataSources ++= plan.getAllReadDS
+    // call any cleanup code necessary for the plan
+    plan.cleanUpPlan()
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanVersion.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanVersion.scala
index e175a8c82..2f845e42e 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanVersion.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/SQLPlanVersion.scala
@@ -16,7 +16,9 @@
 
 package org.apache.spark.sql.rapids.tool.store
 
-import com.nvidia.spark.rapids.tool.planparser.ReadParser
+import scala.collection.breakOut
+
+import com.nvidia.spark.rapids.tool.planparser.{DataWritingCommandExecParser, ReadParser}
 
 import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.ui.SparkPlanGraph
@@ -44,6 +46,57 @@ class SQLPlanVersion(
     val physicalPlanDescription: String,
     var isFinal: Boolean = true) {
 
+  // Used to cache the Spark graph for that plan, to avoid re-building the graph multiple times.
+  // This graph can be used and then cleaned up at the end of the execution.
+  // It has to be accessed through getPlanGraph(), which synchronizes on this object to avoid
+  // races between threads.
+  private var sparkGraph: Option[SparkPlanGraph] = None
+
+  private def getPlanGraph(): SparkPlanGraph = {
+    this.synchronized {
+      if (sparkGraph.isEmpty) {
+        sparkGraph = Some(ToolsPlanGraph(planInfo))
+      }
+      sparkGraph.get
+    }
+  }
+
+  /**
+   * Builds the list of write records for this plan.
+   * This works by looping on all the nodes and filtering the write execs.
+   * @return the list of write records for this plan, if any.
+   */
+  private def initWriteOperationRecords(): Iterable[WriteOperationRecord] = {
+    getPlanGraph().allNodes
+      // pick only nodes that are DataWritingCommandExec
+      .filter(node => DataWritingCommandExecParser.isWritingCmdExec(node.name.stripSuffix("$")))
+      .map { n =>
+        // extract the metadata and convert it to a store record.
+        val opMeta = DataWritingCommandExecParser.getWriteOpMetaFromNode(n)
+        WriteOperationRecord(sqlId, version, n.id, operationMeta = opMeta)
+      }
+  }
+
+  // Captures the write operations for this plan. This is lazy because we do not need
+  // to construct this until we need it.
+  lazy val writeRecords: Iterable[WriteOperationRecord] = {
+    initWriteOperationRecords()
+  }
+
+  // Converts the writeRecords into a set of write formats.
+  def getWriteDataFormats: Set[String] = {
+    writeRecords.map(_.operationMeta.dataFormat())(breakOut)
+  }
+
+  /**
+   * Resets any data structures that have been used, in order to free memory.
+   */
+  def cleanUpPlan(): Unit = {
+    this.synchronized {
+      sparkGraph = None
+    }
+  }
+
   def resetFinalFlag(): Unit = {
     // This flag depends on the AQE events sequence.
     // It does not set that field using the substring of the physicalPlanDescription
@@ -68,7 +121,7 @@ class SQLPlanVersion(
    * @return all the read datasources V1 recursively that are read by this plan including.
*/ def getReadDSV1(planGraph: Option[SparkPlanGraph] = None): Iterable[DataSourceRecord] = { - val graph = planGraph.getOrElse(ToolsPlanGraph(planInfo)) + val graph = planGraph.getOrElse(getPlanGraph()) getPlansWithSchema.flatMap { plan => val meta = plan.metadata // TODO: Improve the extraction of ReaSchema using RegEx (ReadSchema):\s(.*?)(\.\.\.|,\s|$) @@ -103,7 +156,7 @@ class SQLPlanVersion( * @return List of DataSourceRecord for all the V2 DataSources read by this plan. */ def getReadDSV2(planGraph: Option[SparkPlanGraph] = None): Iterable[DataSourceRecord] = { - val graph = planGraph.getOrElse(ToolsPlanGraph(planInfo)) + val graph = planGraph.getOrElse(getPlanGraph()) graph.allNodes.filter(ReadParser.isDataSourceV2Node).map { node => val res = ReadParser.parseReadNode(node) DataSourceRecord( @@ -125,7 +178,7 @@ class SQLPlanVersion( * @return Iterable of DataSourceRecord */ def getAllReadDS: Iterable[DataSourceRecord] = { - val planGraph = Option(ToolsPlanGraph(planInfo)) + val planGraph = Option(getPlanGraph()) getReadDSV1(planGraph) ++ getReadDSV2(planGraph) } } diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/WriteOperationStore.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/WriteOperationStore.scala new file mode 100644 index 000000000..023c274ae --- /dev/null +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/WriteOperationStore.scala @@ -0,0 +1,235 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.tool.store + + +import java.util.concurrent.ConcurrentHashMap + +import scala.util.control.NonFatal + +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.rapids.tool.util.StringUtils + +/** + * Represents a unique reference for a name, with a reformatted CSV value. + * @param value The original name value. + */ +case class UniqueNameRef(value: String) { + // Lazily reformats the name value into a CSV-compatible string. + lazy val csvValue: String = StringUtils.reformatCSVString(value) +} + +/** + * Trait defining metadata for write operations. + * This trait provides default implementations for metadata fields + * related to write operations, which can be overridden by subclasses. + */ +trait WriteOperationMetadataTrait { + def execName(): String = StringUtils.UNKNOWN_EXTRACT // Name of the execution + def dataFormat(): String = StringUtils.UNKNOWN_EXTRACT // Data format (e.g., CSV, Parquet) + def outputPath(): String = StringUtils.UNKNOWN_EXTRACT // Output path for the write operation + def outputColumns(): String = StringUtils.UNKNOWN_EXTRACT // Output columns involved + def writeMode(): String = StringUtils.UNKNOWN_EXTRACT // Save mode (e.g., Overwrite, Append) + def table(): String = StringUtils.UNKNOWN_EXTRACT // Table name (if applicable) + def dataBase(): String = StringUtils.UNKNOWN_EXTRACT // Database name (if applicable) + def fullDescr(): String = "..." 
// Full description of the operation + def execNameCSV: String // CSV-compatible execution name + def formatCSV: String // CSV-compatible data format +} + +/** + * Metadata implementation for write operations with a specific format. + * @param writeExecName The execution name reference. + * @param format The data format reference. + * @param descr Optional description of the operation. + */ +class WriteOperationMetaWithFormat( + writeExecName: UniqueNameRef, + format: UniqueNameRef, + descr: Option[String]) extends WriteOperationMetadataTrait { + override def dataFormat(): String = format.value + override def fullDescr(): String = descr.getOrElse("") + override def execName(): String = writeExecName.value + override def execNameCSV: String = writeExecName.csvValue + override def formatCSV: String = format.csvValue +} + +/** + * Metadata implementation for write operations with additional details. + * @param writeExecName The execution name reference. + * @param format The data format reference. + * @param outputPathValue Optional output path. + * @param outputColumnsValue Optional output columns. + * @param saveMode Optional save mode. + * @param tableName Table name (if applicable). + * @param dataBaseName Database name (if applicable). + * @param descr Optional description of the operation. + */ +case class WriteOperationMeta( + writeExecName: UniqueNameRef, + format: UniqueNameRef, + outputPathValue: Option[String], + outputColumnsValue: Option[String], + saveMode: Option[SaveMode], + tableName: String, + dataBaseName: String, + descr: Option[String]) extends WriteOperationMetaWithFormat( + writeExecName, format, descr) { + override def writeMode(): String = { + saveMode match { + case Some(w) => w.toString + case _ => StringUtils.UNKNOWN_EXTRACT + } + } + override def outputPath(): String = outputPathValue.getOrElse(StringUtils.UNKNOWN_EXTRACT) + override def outputColumns(): String = outputColumnsValue.getOrElse(StringUtils.UNKNOWN_EXTRACT) + override def table(): String = tableName + override def dataBase(): String = dataBaseName +} + +/** + * Represents a record of a write operation. + * @param sqlID The SQL ID associated with the operation. + * @param version The version of the operation. + * @param nodeId The node ID in the execution plan. + * @param operationMeta Metadata for the write operation. + * @param fromFinalPlan Indicates if the metadata is from the final execution plan. + */ +case class WriteOperationRecord( + sqlID: Long, + version: Int, + nodeId: Long, + operationMeta: WriteOperationMetadataTrait, + fromFinalPlan: Boolean = true) + +/** + * Builder object for creating instances of WriteOperationMetadataTrait. + * Provides utility methods to construct metadata objects with various levels of detail. 
+ */ +object WriteOperationMetaBuilder { + // Default unknown name reference + private val UNKNOWN_NAME_REF = UniqueNameRef(StringUtils.UNKNOWN_EXTRACT) + + // Default unknown metadata + private val UNKNOWN_WRITE_META = + new WriteOperationMetaWithFormat(UNKNOWN_NAME_REF, UNKNOWN_NAME_REF, None) + + // Concurrent hash map to store data format references + private val DATA_FORMAT_TABLE: ConcurrentHashMap[String, UniqueNameRef] = { + val initMap = new ConcurrentHashMap[String, UniqueNameRef]() + initMap.put(StringUtils.UNKNOWN_EXTRACT, UNKNOWN_NAME_REF) + initMap + } + + // Concurrent hash map to store execution name references + private val WRITE_EXEC_TABLE: ConcurrentHashMap[String, UniqueNameRef] = { + val initMap = new ConcurrentHashMap[String, UniqueNameRef]() + initMap.put(StringUtils.UNKNOWN_EXTRACT, UNKNOWN_NAME_REF) + initMap + } + + /** + * Returns a default value if the input string is null or empty. + * @param value The input string. + * @return The default value or the input string. + */ + private def defaultIfUnknown(value: String): String = { + if (value == null || value.isEmpty) StringUtils.UNKNOWN_EXTRACT else value + } + + /** + * Retrieves or creates a UniqueNameRef for the given data format. + * @param name The data format name. + * @return A UniqueNameRef for the data format. + */ + private def getOrCreateFormatRef(name: String): UniqueNameRef = { + DATA_FORMAT_TABLE.computeIfAbsent(defaultIfUnknown(name), k => UniqueNameRef(k)) + } + + /** + * Retrieves or creates a UniqueNameRef for the given execution name. + * @param name The execution name. + * @return A UniqueNameRef for the execution name. + */ + private def getOrCreateExecRef(name: String): UniqueNameRef = { + WRITE_EXEC_TABLE.computeIfAbsent(defaultIfUnknown(name), k => UniqueNameRef(k)) + } + + /** + * Converts a string to a SaveMode, if possible. + * @param name The string representation of the save mode. + * @return An Option containing the SaveMode, or None if conversion fails. + */ + private def getSaveModeFromString(name: String): Option[SaveMode] = { + val str = defaultIfUnknown(name) + try { + Some(SaveMode.valueOf(str)) + } catch { // Failed to convert the string to SaveMode. + case NonFatal(_) => None + } + } + + /** + * Builds a WriteOperationMetadataTrait with detailed metadata. + * @param execName The execution name. + * @param dataFormat The data format. + * @param outputPath Optional output path. + * @param outputColumns Optional output columns. + * @param writeMode The save mode. + * @param tableName The table name. + * @param dataBaseName The database name. + * @param fullDescr Optional full description. + * @return A WriteOperationMetadataTrait instance. + */ + def build(execName: String, dataFormat: String, outputPath: Option[String], + outputColumns: Option[String], + writeMode: String, + tableName: String, + dataBaseName: String, + fullDescr: Option[String]): WriteOperationMetadataTrait = { + WriteOperationMeta(getOrCreateExecRef(execName), getOrCreateFormatRef(dataFormat), + outputPath, outputColumns, getSaveModeFromString(writeMode), + defaultIfUnknown(tableName), defaultIfUnknown(dataBaseName), + fullDescr) + } + + /** + * Builds a WriteOperationMetadataTrait with minimal metadata. + * @param execName The execution name. + * @param dataFormat The data format. + * @param fullDescr Optional full description. + * @return A WriteOperationMetadataTrait instance. 
+ */ + def build(execName: String, dataFormat: String, + fullDescr: Option[String]): WriteOperationMetadataTrait = { + new WriteOperationMetaWithFormat(getOrCreateExecRef(execName), + getOrCreateFormatRef(dataFormat), fullDescr) + } + + /** + * Builds a WriteOperationMetadataTrait with no metadata. + * @param fullDescr Optional full description. + * @return A WriteOperationMetadataTrait instance with unknown metadata. + */ + def buildNoMeta(fullDescr: Option[String]): WriteOperationMetadataTrait = { + if (fullDescr.isDefined) { + new WriteOperationMetaWithFormat(UNKNOWN_NAME_REF, UNKNOWN_NAME_REF, fullDescr) + } else { + UNKNOWN_WRITE_META + } + } +} diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/StringUtils.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/StringUtils.scala index f3c99eb49..277858ed7 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/StringUtils.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/StringUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2023-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ import org.apache.spark.internal.Logging * strings. */ object StringUtils extends Logging { + // Constant used to replace the unknown values + val UNKNOWN_EXTRACT: String = "unknown" // Regular expression for duration-format 'H+:MM:SS.FFF' // Note: this is not time-of-day. Hours can be larger than 12. private val regExDurationFormat = "^(\\d+):([0-5]\\d):([0-5]\\d\\.\\d+)$" diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/WriteOperationParserSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/WriteOperationParserSuite.scala new file mode 100644 index 000000000..e87c4644c --- /dev/null +++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/WriteOperationParserSuite.scala @@ -0,0 +1,227 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.tool.planparser + +import org.scalatest.FunSuite + +import org.apache.spark.sql.execution.ui +import org.apache.spark.sql.execution.ui.SparkPlanGraphNode +import org.apache.spark.sql.rapids.tool.store.WriteOperationMetadataTrait +import org.apache.spark.sql.rapids.tool.util.StringUtils + + +class WriteOperationParserSuite extends FunSuite { + + /** + * Helper method to test `getWriteOpMetaFromNode`. + * + * @param node The input `SparkPlanGraphNode`. + * @param expectedExecName The expected execution name. + * @param expectedDataFormat The expected data format. + * @param expectedOutputPath The expected output path (optional). + * @param expectedOutputColumns The expected output columns (optional). + * @param expectedWriteMode The expected write mode. + * @param expectedTableName The expected table name. 
+ * @param expectedDatabaseName The expected database name. + */ + private def testGetWriteOpMetaFromNode( + node: SparkPlanGraphNode, + expectedExecName: String, + expectedDataFormat: String, + expectedOutputPath: String, + expectedOutputColumns: String, + expectedWriteMode: String, + expectedTableName: String, + expectedDatabaseName: String): Unit = { + + val metadata: WriteOperationMetadataTrait = + DataWritingCommandExecParser.getWriteOpMetaFromNode(node) + + assert(metadata.execName() == expectedExecName, "execName") + assert(metadata.dataFormat() == expectedDataFormat, "dataFormat") + assert(metadata.outputPath() == expectedOutputPath, "outputPath") + assert(metadata.outputColumns() == expectedOutputColumns, "outputColumns") + assert(metadata.writeMode() == expectedWriteMode, "writeMode") + assert(metadata.table() == expectedTableName, "tableName") + assert(metadata.dataBase() == expectedDatabaseName, "databaseName") + } + // scalastyle:off line.size.limit + test("InsertIntoHadoopFsRelationCommand - Common case") { + val node = new SparkPlanGraphNode( + id = 1, + name = "Execute InsertIntoHadoopFsRelationCommand", + desc = "Execute InsertIntoHadoopFsRelationCommand gs://path/to/database/table1, " + + "false, Parquet, " + + "[serialization.format=1, mergeschema=false, __hive_compatible_bucketed_table_insertion__=true], " + + "Append, `spark_catalog`.`database`.`table`, org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe, " + + "org.apache.spark.sql.execution.datasources.InMemoryFileIndex(gs://path/to/database/table1), " + + "[col01, col02, col03]", + Seq.empty + ) + testGetWriteOpMetaFromNode( + node, + expectedExecName = "InsertIntoHadoopFsRelationCommand", + expectedDataFormat = "Parquet", + expectedOutputPath = "gs://path/to/database/table1", + expectedOutputColumns = "col01;col02;col03", + expectedWriteMode = "Append", + expectedTableName = "table1", + expectedDatabaseName = "database" + ) + } + + test("getWriteOpMetaFromNode - Unknown command") { + val node = new SparkPlanGraphNode( + id = 2, + name = "UnknownWrite", + desc = "Some random description", + Seq.empty + ) + testGetWriteOpMetaFromNode( + node, + expectedExecName = StringUtils.UNKNOWN_EXTRACT, + expectedDataFormat = StringUtils.UNKNOWN_EXTRACT, + expectedOutputPath = StringUtils.UNKNOWN_EXTRACT, + expectedOutputColumns = StringUtils.UNKNOWN_EXTRACT, + expectedWriteMode = StringUtils.UNKNOWN_EXTRACT, + expectedTableName = StringUtils.UNKNOWN_EXTRACT, + expectedDatabaseName = StringUtils.UNKNOWN_EXTRACT + ) + } + + test("AppendDataExecV1 - delta format") { + val node = new SparkPlanGraphNode( + id = 3, + name = "AppendDataExecV1", + // the description should include Delta keywords; otherwise it would be considered a spark Op. 
+ desc = + s"""|AppendDataExecV1 [num_affected_rows#18560L, num_inserted_rows#18561L], DeltaTableV2(org.apache.spark.sql.SparkSession@5aa5327e,abfss://abfs_path,Some(CatalogTable( + |Catalog: spark_catalog + |Database: database + |Table: tableName + |Owner: root + |Created Time: Wed Sep 15 16:47:47 UTC 2021 + |Last Access: UNKNOWN + |Created By: Spark 3.1.1 + |Type: EXTERNAL + |Provider: delta + |Table Properties: [bucketing_version=2, delta.lastCommitTimestamp=1631724453000, delta.lastUpdateVersion=0, delta.minReaderVersion=1, delta.minWriterVersion=2] + |Location: abfss://abfs_path + |Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe + |InputFormat: org.apache.hadoop.mapred.SequenceFileInputFormat + |OutputFormat: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat + |Partition Provider: Catalog + |Schema: root + | |-- field_00: string (nullable = true) + | |-- field_01: string (nullable = true) + | |-- field_02: string (nullable = true) + | |-- field_03: string (nullable = true) + | |-- field_04: string (nullable = true) + | |-- field_05: string (nullable = true) + | |-- field_06: string (nullable = true) + | |-- field_07: string (nullable = true) + | |-- field_08: string (nullable = true) + |)),Some(spark_catalog.adl.tableName),None,Map()), Project [from_unixtime(unix_timestamp(current_timestamp(), yyyy-MM-dd HH:mm:ss, Some(Etc/UTC), false), yyyy-MM-dd HH:mm:ss, Some(Etc/UTC)) AS field_00#15200, 20240112 AS field_01#15201, load_func00 AS field_02#15202, completed AS field_03#15203, from_unixtime(unix_timestamp(current_timestamp(), yyyy-MM-dd HH:mm:ss, Some(Etc/UTC), false), yyyy-MM-dd HH:mm:ss, Some(Etc/UTC)) AS field_04#15204, from_unixtime(unix_timestamp(current_timestamp(), yyyy-MM-dd HH:mm:ss, Some(Etc/UTC), false), yyyy-MM-dd HH:mm:ss, Some(Etc/UTC)) AS field_05#15205, ddsdmsp AS field_06#15206, rename_01 AS field_07#15207, from_unixtime(unix_timestamp(current_timestamp(), yyyy-MM-dd HH:mm:ss, Some(Etc/UTC), false), yyyyMMdd, Some(Etc/UTC)) AS field_08#15208], org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy$$$$Lambda$$12118/1719387317@5a111bfa, com.databricks.sql.transaction.tahoe.catalog.WriteIntoDeltaBuilder$$$$anon$$1@24257336 + |""".stripMargin, + Seq.empty + ) + testGetWriteOpMetaFromNode( + node, + expectedExecName = "AppendDataExecV1", + expectedDataFormat = DeltaLakeHelper.getWriteFormat, // Special handling for DeltaLake + expectedOutputPath = StringUtils.UNKNOWN_EXTRACT, + expectedOutputColumns = StringUtils.UNKNOWN_EXTRACT, + expectedWriteMode = StringUtils.UNKNOWN_EXTRACT, + expectedTableName = StringUtils.UNKNOWN_EXTRACT, + expectedDatabaseName = StringUtils.UNKNOWN_EXTRACT + ) + } + + test("InsertIntoHadoopFsRelationCommand - Empty output columns") { + val node = new SparkPlanGraphNode( + id = 5, + name = "Execute InsertIntoHadoopFsRelationCommand", + desc = "Execute InsertIntoHadoopFsRelationCommand gs://path/to/database/table1, " + + "false, Parquet, " + + "[serialization.format=1, mergeschema=false, __hive_compatible_bucketed_table_insertion__=true], " + + "Append, `spark_catalog`.`database`.`table`, org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe, " + + "org.apache.spark.sql.execution.datasources.InMemoryFileIndex(gs://path/to/database/table1), " + + "[]", + Seq.empty + ) + testGetWriteOpMetaFromNode( + node, + expectedExecName = "InsertIntoHadoopFsRelationCommand", + expectedDataFormat = "Parquet", + expectedOutputPath = "gs://path/to/database/table1", + expectedOutputColumns = "", + expectedWriteMode = 
"Append", + expectedTableName = "table1", + expectedDatabaseName = "database" + ) + } + + test("InsertIntoHadoopFsRelationCommand - Format is 4th element") { + val node = new SparkPlanGraphNode( + id = 5, + name = "Execute InsertIntoHadoopFsRelationCommand", + desc = "Execute InsertIntoHadoopFsRelationCommand gs://path/to/database/table1, " + + "false, [paths=(path)], Parquet, " + + "[serialization.format=1, mergeschema=false, __hive_compatible_bucketed_table_insertion__=true], " + + "Append, `spark_catalog`.`database`.`table`, org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe, " + + "org.apache.spark.sql.execution.datasources.InMemoryFileIndex(gs://path/to/database/table1), " + + "[col01]", + Seq.empty + ) + testGetWriteOpMetaFromNode( + node, + expectedExecName = "InsertIntoHadoopFsRelationCommand", + expectedDataFormat = "Parquet", + expectedOutputPath = "gs://path/to/database/table1", + expectedOutputColumns = "col01", + expectedWriteMode = "Append", + expectedTableName = "table1", + expectedDatabaseName = "database" + ) + } + + test("InsertIntoHadoopFsRelationCommand - Long schema") { + // Long schema will show up as ellipses in the description + val node = new ui.SparkPlanGraphNode( + id = 5, + name = "Execute InsertIntoHadoopFsRelationCommand", + desc = "Execute InsertIntoHadoopFsRelationCommand gs://path/to/database/table1, " + + "false, [paths=(path)], Parquet, " + + "[serialization.format=1, mergeschema=false, __hive_compatible_bucketed_table_insertion__=true], " + + "Append, spark_catalog`.`database`.`table`, org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe, " + + "org.apache.spark.sql.execution.datasources.InMemoryFileIndex(gs://path/to/database/table1), " + + "[col01, col02, col03, ... 4 more fields]", + Seq.empty + ) + testGetWriteOpMetaFromNode( + node, + expectedExecName = "InsertIntoHadoopFsRelationCommand", + expectedDataFormat = "Parquet", + expectedOutputPath = "gs://path/to/database/table1", + expectedOutputColumns = "col01;col02;col03;... 4 more fields", + expectedWriteMode = "Append", + expectedTableName = "table1", + expectedDatabaseName = "database" + ) + } + // scalastyle:on line.size.limit +}