Calculate task metric aggregates on-the-fly to reduce memory usage #1543

Merged · 16 commits · Feb 21, 2025
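The gist of the change: instead of retaining every per-task accumulator update in a taskUpdatesMap and aggregating in a second pass, each accumulator folds task updates into per-stage aggregates as they arrive, so memory scales with the number of stages rather than the number of tasks. A minimal sketch of that pattern follows; AccumInfoSketch and StageAgg are hypothetical names used only for illustration, and the real AccumInfo must additionally keep enough per-stage data to answer the median queries this diff relies on. The accessor names mirror the ones that appear in the diff below; their bodies are guessed.

import scala.collection.mutable

// Hypothetical running aggregate for one stage: O(1) state per stage in
// place of one (taskId -> value) entry per task.
final case class StageAgg(
    var min: Long = Long.MaxValue,
    var max: Long = Long.MinValue,
    var count: Long = 0L,
    var total: Long = 0L) {
  def add(v: Long): Unit = {
    if (v < min) min = v
    if (v > max) max = v
    count += 1
    total += v
  }
}

// Hypothetical sketch, not the tool's actual class.
class AccumInfoSketch {
  private val byStage = mutable.Map.empty[Int, StageAgg]

  // Called once per task-end event: fold the update immediately instead of
  // recording (taskId -> value) for a later aggregation pass.
  def addTaskUpdate(stageId: Int, value: Long): Unit =
    byStage.getOrElseUpdate(stageId, StageAgg()).add(value)

  def containsStage(stageId: Int): Boolean = byStage.contains(stageId)
  def getStageIds: Iterable[Int] = byStage.keys
  def getMaxForStage(stageId: Int): Option[Long] = byStage.get(stageId).map(_.max)
  def getTotalForStage(stageId: Int): Option[Long] = byStage.get(stageId).map(_.total)
  def getTotalAcrossStages: Long = byStage.values.map(_.total).sum
}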
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster,
import org.apache.spark.sql.rapids.tool.{AppBase, RDDCheckHelper, SqlPlanInfoGraphBuffer, SqlPlanInfoGraphEntry}
import org.apache.spark.sql.rapids.tool.profiling.ApplicationInfo
import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
-import org.apache.spark.sql.rapids.tool.store.{AccumInfo, DataSourceRecord}
+import org.apache.spark.sql.rapids.tool.store.DataSourceRecord
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

/**
@@ -108,32 +108,6 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
IODiagnosticMetricsMap(key) += accum
}

-/**
-* Retrieves the task IDs associated with a specific stage.
-*
-* @param stageId The ID of the stage.
-* @return A seq of task IDs corresponding to the given stage ID.
-*/
-private def getStageTaskIds(stageId: Int): Seq[Long] = {
-app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId)(breakOut).distinct
-}
-
-/**
-* Retrieves task update values from the accumulator info for the specified stage ID.
-*
-* @param accumInfo AccumInfo object containing the task updates map.
-* @param stageId The stage ID for which task updates need to be retrived.
-* @return An array of task update values (`Long`) corresponding to the tasks
-* in the specified stage.
-*/
-private def filterAccumTaskUpdatesForStage(accumInfo: AccumInfo, stageTaskIds: Seq[Long])
-: Array[Long] = {
-stageTaskIds.collect {
-case taskId if accumInfo.taskUpdatesMap.contains(taskId) =>
-accumInfo.taskUpdatesMap(taskId)
-}(breakOut)
-}

/**
* Connects Operators to Stages using AccumulatorIDs.
* TODO: This function can be fused in the visitNode function to avoid the extra iteration.
@@ -406,7 +380,6 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
// TODO: currently if stage IDs is empty, the result is skipped
val stageIds = sqlAccums.head.stageIds
stageIds.flatMap { stageId =>
-val stageTaskIds = getStageTaskIds(stageId)
val nodeName = sqlAccums.head.nodeName

// Initialize a map to store statistics for each IO metric
@@ -419,15 +392,19 @@
val accumInfoOpt = app.accumManager.accumInfoMap.get(sqlAccum.accumulatorId)

val metricStats: Option[StatisticsMetrics] = accumInfoOpt.flatMap { accumInfo =>
-if (!accumInfo.stageValuesMap.contains(stageId)) {
+if (!accumInfo.containsStage(stageId)) {
None
} else if (stageIds.size == 1) {
// Skip computing statistics when there is only one stage
-Some(StatisticsMetrics(sqlAccum.min, sqlAccum.median, sqlAccum.max, sqlAccum.total))
+Some(StatisticsMetrics(
+  min = sqlAccum.min,
+  med = sqlAccum.median,
+  max = sqlAccum.max,
+  count = 0,
+  total = sqlAccum.total))
} else {
// Retrieve task updates which correspond to the current stage
-val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageTaskIds)
-StatisticsMetrics.createOptionalFromArr(filteredTaskUpdates)
+accumInfo.calculateAccStatsForStage(stageId)
}
}

@@ -476,13 +453,9 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
def generateStageLevelAccums(): Seq[AccumProfileResults] = {
app.accumManager.accumInfoMap.flatMap { accumMapEntry =>
val accumInfo = accumMapEntry._2
-accumInfo.stageValuesMap.keys.flatMap( stageId => {
-// Retrieve task updates correspond to the current stage
-val filteredTaskUpdates =
-filterAccumTaskUpdatesForStage(accumInfo, getStageTaskIds(stageId))
-
-StatisticsMetrics.createOptionalFromArr(filteredTaskUpdates) match {
+accumInfo.getStageIds.flatMap( stageId => {
+// Get the task updates that belong to that stage
+accumInfo.calculateAccStatsForStage(stageId) match {
case Some(stat) =>
// Reuse AccumProfileResults to avoid generating allocating new objects
val accumProfileResults = AccumProfileResults(
…
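For reference, a hedged sketch of what the new calculateAccStatsForStage plausibly computes at the call sites above: None when the stage is unknown, otherwise a full StatisticsMetrics over the stage's task values. The explicit stageValues argument and the sort-based median are assumptions; the real method is a member of AccumInfo, operates on its internal per-stage state, and likely selects the median with the in-place helper imported by StatisticsMetrics.

// Illustrative stand-in only, not the tool's actual implementation.
def calculateAccStatsForStage(
    stageValues: Map[Int, Array[Long]],
    stageId: Int): Option[StatisticsMetrics] = {
  stageValues.get(stageId).filter(_.nonEmpty).map { values =>
    val sorted = values.sorted
    StatisticsMetrics(
      min = sorted.head,
      med = sorted(sorted.length / 2),
      max = sorted.last,
      count = sorted.length,
      total = sorted.sum)
  }
}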
@@ -399,14 +399,14 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {

val accumHelperObj = if (app.isPhoton) { // If this a photon app, use the photonHelper
// For max peak memory, we need to look at the accumulators at the task level.
-val peakMemoryValues = tasksInStage.flatMap { taskModel =>
-photonPeakMemoryAccumInfos.flatMap { accumInfo =>
-accumInfo.taskUpdatesMap.get(taskModel.taskId)
-}
+// We leverage the stage level metrics and get the max task update from it
+val peakMemoryValues = photonPeakMemoryAccumInfos.flatMap { accumInfo =>
+accumInfo.getMaxForStage(sm.stageInfo.stageId)
}
// For sum of shuffle write time, we need to look at the accumulators at the stage level.
+// We get the values associated with all tasks for a stage
val shuffleWriteValues = photonShuffleWriteTimeAccumInfos.flatMap { accumInfo =>
-accumInfo.stageValuesMap.get(sm.stageInfo.stageId)
+accumInfo.getTotalForStage(sm.stageInfo.stageId)
}
new AggAccumPhotonHelper(shuffleWriteValues, peakMemoryValues)
} else {
…
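Net effect on the Photon path, shown as a usage-style sketch: the per-task scan over tasksInStage is gone, and each accumulator answers for the whole stage in one call. Only the two accessor names come from the diff; the trait and function around them are hypothetical scaffolding.

// Hypothetical minimal interface; the diff guarantees only these accessors.
trait StageAggAccess {
  def getMaxForStage(stageId: Int): Option[Long]
  def getTotalForStage(stageId: Int): Option[Long]
}

// Before: look up every task's update in every accumulator's taskUpdatesMap.
// After: one stage-level read per accumulator.
def photonStageInputs(
    peakMemAccums: Seq[StageAggAccess],
    shuffleWriteAccums: Seq[StageAggAccess],
    stageId: Int): (Seq[Long], Seq[Long]) = {
  val peakMemoryValues = peakMemAccums.flatMap(_.getMaxForStage(stageId))
  val shuffleWriteValues = shuffleWriteAccums.flatMap(_.getTotalForStage(stageId))
  (peakMemoryValues, shuffleWriteValues)
}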
@@ -1,5 +1,5 @@
/*
-* Copyright (c) 2024, NVIDIA CORPORATION.
+* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,12 +19,12 @@ package com.nvidia.spark.rapids.tool.analysis
import org.apache.spark.sql.rapids.tool.util.InPlaceMedianArrView.{chooseMidpointPivotInPlace, findMedianInPlace}

// Store (min, median, max, total) for a given metric
-case class StatisticsMetrics(min: Long, med: Long, max: Long, total: Long)
+case class StatisticsMetrics(min: Long, med: Long, max: Long, count: Long, total: Long)

object StatisticsMetrics {
// a static variable used to represent zero-statistics instead of allocating a dummy record
// on every calculation.
-val ZERO_RECORD: StatisticsMetrics = StatisticsMetrics(0L, 0L, 0L, 0L)
+val ZERO_RECORD: StatisticsMetrics = StatisticsMetrics(0L, 0L, 0L, 0L, 0L)

def createFromArr(arr: Array[Long]): StatisticsMetrics = {
if (arr.isEmpty) {
@@ -43,7 +43,7 @@
}
totalV += v
}
-StatisticsMetrics(minV, medV, maxV, totalV)
+StatisticsMetrics(minV, medV, maxV, arr.length, totalV)
}

def createOptionalFromArr(arr: Array[Long]): Option[StatisticsMetrics] = {
…
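A quick usage check of the widened record, assuming createFromArr keeps its existing contract and the new count field is simply the number of aggregated values:

// Note: createFromArr selects the median with the in-place helper, so the
// input array may be reordered by the call; pass a copy if order matters.
val stats = StatisticsMetrics.createFromArr(Array(4L, 1L, 9L, 2L))
assert(stats.min == 1L && stats.max == 9L)
assert(stats.count == 4L && stats.total == 16L)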
@@ -276,21 +276,15 @@ object GenerateTimeline

val semMetricsNs = semWaitIds.toList
.flatMap(app.accumManager.accumInfoMap.get)
-.flatMap(_.taskUpdatesMap.values).sum
+.map(_.getTotalAcrossStages).sum

val semMetricsMs = app.accumManager.accumInfoMap.flatMap {
case (_, accumInfo: AccumInfo)
if accumInfo.infoRef.name == AccumNameRef.NAMES_TABLE.get("gpuSemaphoreWait") =>
-Some(accumInfo.taskUpdatesMap.values.sum)
+Some(accumInfo.getTotalAcrossStages)
case _ => None
}.sum

-val readMetrics = readTimeIds.toList.flatMap(app.accumManager.accumInfoMap.get)
-
-val opMetrics = opTimeIds.toList.flatMap(app.accumManager.accumInfoMap.get)
-
-val writeMetrics = writeTimeIds.toList.flatMap(app.accumManager.accumInfoMap.get)
-
app.taskManager.getAllTasks().foreach { tc =>
val host = tc.host
val execId = tc.executorId
@@ -300,11 +294,9 @@
val finishTime = tc.finishTime
val duration = tc.duration
val semTimeMs = ( semMetricsNs / 1000000) + semMetricsMs
-val readTimeMs = readMetrics.flatMap(_.taskUpdatesMap.get(taskId)).sum / 1000000 +
-tc.sr_fetchWaitTime
-val opTimeMs = opMetrics.flatMap(_.taskUpdatesMap.get(taskId)).sum / 1000000
-val writeTimeMs = writeMetrics.flatMap(_.taskUpdatesMap.get(taskId)).sum / 1000000 +
-tc.sw_writeTime / 1000000
+val readTimeMs = tc.sr_fetchWaitTime
+val opTimeMs = 0L
+val writeTimeMs = tc.sw_writeTime / 1000000
val taskInfo = new TimelineTaskInfo(stageId, taskId, launchTime, finishTime, duration,
tc.executorDeserializeTime, readTimeMs, semTimeMs, opTimeMs, writeTimeMs)
val execHost = s"$execId/$host"
…
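The timeline change rests on a simple equivalence: every task update is folded into exactly one stage's total, so summing per-stage totals (getTotalAcrossStages in the diff) reproduces the old sum over taskUpdatesMap.values. A toy illustration with made-up numbers:

// Hypothetical per-task updates, tagged with the stage that owns each task.
val taskUpdates = Seq((0, 40L), (0, 80L), (1, 30L), (1, 50L), (3, 50L))

// Old shape: one entry per task, summed directly.
val oldTotal = taskUpdates.map(_._2).sum // 250

// New shape: fold into per-stage totals as events arrive, then sum those.
val totalsByStage = taskUpdates.foldLeft(Map.empty[Int, Long]) {
  case (acc, (stageId, v)) => acc.updated(stageId, acc.getOrElse(stageId, 0L) + v)
}
val newTotal = totalsByStage.values.sum // 250: same result, far less state

assert(oldTotal == newTotal)

Note that the second hunk also stops reconstructing per-task read/op/write times from retained accumulator updates: the timeline now falls back to the task-level sr_fetchWaitTime and sw_writeTime fields and reports opTimeMs as zero.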