Skip to content

Commit

Permalink
update the code comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
  • Loading branch information
bangtianliu committed Jan 17, 2025
1 parent f188afe commit 859cfc6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
15 changes: 10 additions & 5 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,12 +881,15 @@ def calculate_speedup(
Speedup is defined as the ratio of the candidate's runtime to the average baseline time
for the corresponding device as:
speedup = candidate_runtime / avg_baseline_time
speedup = candidate_runtime / avg_baseline_time (or fallback_baseline)
If no valid baseline times are available, the candidate'sruntime is used directly as:
If no valid baseline times are available for a specific device, the fallback baseline is used.
The fallback baseline is calculated as the average of all valid baseline times across devices.
speedup = candidate_runtime
If no valid baseline times are available across all devices, the candidate's runtime is
used directly as:
speedup = candidate_runtime
The speedup values are sorted in ascending order to select the top-performing candidates.
"""
if not self.is_valid():
Expand Down Expand Up @@ -915,7 +918,7 @@ def calculate_speedup(
speedup_by_candidate[candidate.candidate_id] = candidate.time / baseline_avg
return speedup_by_candidate

def get_top_candidates(
def sort_candidates_with_speedup(
self,
speedup_by_candidate: dict[int, float],
) -> list[tuple[int, float]]:
Expand Down Expand Up @@ -1054,7 +1057,9 @@ def benchmark(
logging.warning("Baseline run failed.")

speedup_result = baseline_handler.calculate_speedup(candidate_results)
all_candidates_with_speedup = baseline_handler.get_top_candidates(speedup_result)
all_candidates_with_speedup = baseline_handler.sort_candidates_with_speedup(
speedup_result
)
top_candidates_with_speedup = all_candidates_with_speedup[:num_candidates]

if baseline_handler.is_valid():
Expand Down
7 changes: 3 additions & 4 deletions tuner/tuner/libtuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def test_baseline_result_handler_speedup():
4: 0.2 / 0.875,
}

all_candidates_with_speedup = handler.get_top_candidates(speedup)
all_candidates_with_speedup = handler.sort_candidates_with_speedup(speedup)
assert all_candidates_with_speedup == [
(4, 0.2 / 0.875),
(1, 0.4 / 0.9),
Expand All @@ -290,8 +290,7 @@ def test_baseline_result_handler_speedup():
7: 0.8 / 1.2,
}

all_candidates_with_speedup = handler.get_top_candidates(speedup)
print(all_candidates_with_speedup)
all_candidates_with_speedup = handler.sort_candidates_with_speedup(speedup)
assert all_candidates_with_speedup == [
(5, 0.6 / 0.9),
(7, 0.8 / 1.2),
Expand All @@ -307,7 +306,7 @@ def test_baseline_result_handler_speedup():
6: 0.4,
7: 0.8,
}
all_candidates_with_speedup = handler.get_top_candidates(speedup)
all_candidates_with_speedup = handler.sort_candidates_with_speedup(speedup)
assert all_candidates_with_speedup == [
(6, 0.4),
(5, 0.6),
Expand Down

0 comments on commit 859cfc6

Please sign in to comment.