From 3a4a297e44b9e698b2ee0a5ce97c9df87d25bc9d Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Tue, 28 Jan 2025 12:08:23 -0800 Subject: [PATCH] Add TraceBlackboxEvaluator This patch adds TraceBlackboxEvaluator, an evaluator designed for trace based cost modelling. It implements the BlackboxEvaluator class, special casing everything that is needed. Reviewers: mtrofin Reviewed By: mtrofin Pull Request: https://github.com/google/ml-compiler-opt/pull/419 --- compiler_opt/es/blackbox_evaluator.py | 64 +++++++++++++++++++++- compiler_opt/es/blackbox_evaluator_test.py | 25 +++++++++ compiler_opt/es/blackbox_test_utils.py | 25 ++++++++- 3 files changed, 111 insertions(+), 3 deletions(-) diff --git a/compiler_opt/es/blackbox_evaluator.py b/compiler_opt/es/blackbox_evaluator.py index 6a6cf12c..4e55cc77 100644 --- a/compiler_opt/es/blackbox_evaluator.py +++ b/compiler_opt/es/blackbox_evaluator.py @@ -42,7 +42,7 @@ def get_results( raise NotImplementedError() @abc.abstractmethod - def set_baseline(self) -> None: + def set_baseline(self, pool: FixedWorkerPool) -> None: raise NotImplementedError() def get_rewards( @@ -101,5 +101,65 @@ def get_results( return futures - def set_baseline(self) -> None: + def set_baseline(self, pool: FixedWorkerPool) -> None: + del pool # Unused. pass + + +@gin.configurable +class TraceBlackboxEvaluator(BlackboxEvaluator): + """A blackbox evaluator that utilizes trace based cost modelling.""" + + def __init__(self, train_corpus: corpus.Corpus, + est_type: blackbox_optimizers.EstimatorType, bb_trace_path: str, + function_index_path: str): + self._train_corpus = train_corpus + self._est_type = est_type + self._bb_trace_path = bb_trace_path + self._function_index_path = function_index_path + + self._baseline: Optional[float] = None + + def get_results( + self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy] + ) -> List[concurrent.futures.Future]: + job_args = [] + for perturbation in perturbations: + job_args.append({ + 'modules': self._train_corpus.module_specs, + 'function_index_path': self._function_index_path, + 'bb_trace_path': self._bb_trace_path, + 'tflite_policy': perturbation + }) + + _, futures = buffered_scheduler.schedule_on_worker_pool( + action=lambda w, args: w.compile_corpus_and_evaluate(**args), + jobs=job_args, + worker_pool=pool) + concurrent.futures.wait( + futures, return_when=concurrent.futures.ALL_COMPLETED) + return futures + + def set_baseline(self, pool: FixedWorkerPool) -> None: + if self._baseline is not None: + raise RuntimeError('The baseline has already been set.') + + job_args = [{ + 'modules': self._train_corpus.module_specs, + 'function_index_path': self._function_index_path, + 'bb_trace_path': self._bb_trace_path, + 'tflite_policy': None, + }] + + _, futures = buffered_scheduler.schedule_on_worker_pool( + action=lambda w, args: w.compile_corpus_and_evaluate(**args), + jobs=job_args, + worker_pool=pool) + + concurrent.futures.wait( + futures, return_when=concurrent.futures.ALL_COMPLETED) + if len(futures) != 1: + raise ValueError('Expected to have one result for setting the baseline,' + f' got {len(futures)}') + + self._baseline = futures[0].result() diff --git a/compiler_opt/es/blackbox_evaluator_test.py b/compiler_opt/es/blackbox_evaluator_test.py index ab188892..8eca1bf0 100644 --- a/compiler_opt/es/blackbox_evaluator_test.py +++ b/compiler_opt/es/blackbox_evaluator_test.py @@ -50,3 +50,28 @@ def test_get_rewards(self): evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(None, 5, 5, None) rewards = evaluator.get_rewards(results) self.assertEqual(rewards, [None, 2]) + + def test_trace_get_results(self): + with local_worker_manager.LocalWorkerPoolManager( + blackbox_test_utils.ESTraceWorker, count=3, arg='', kwarg='') as pool: + perturbations = [b'00', b'01', b'10'] + test_corpus = corpus.create_corpus_for_testing( + location=self.create_tempdir(), + elements=[corpus.ModuleSpec(name='name1', size=1)]) + evaluator = blackbox_evaluator.TraceBlackboxEvaluator( + test_corpus, 5, 'fake_bb_trace_path', 'fake_function_index_path') + results = evaluator.get_results(pool, perturbations) + self.assertSequenceAlmostEqual([result.result() for result in results], + [1.0, 1.0, 1.0]) + + def test_trace_set_baseline(self): + with local_worker_manager.LocalWorkerPoolManager( + blackbox_test_utils.ESTraceWorker, count=1, arg='', kwarg='') as pool: + test_corpus = corpus.create_corpus_for_testing( + location=self.create_tempdir(), + elements=[corpus.ModuleSpec(name='name1', size=1)]) + evaluator = blackbox_evaluator.TraceBlackboxEvaluator( + test_corpus, 5, 'fake_bb_trace_path', 'fake_function_index_path') + evaluator.set_baseline(pool) + # pylint: disable=protected-access + self.assertAlmostEqual(evaluator._baseline, 10) diff --git a/compiler_opt/es/blackbox_test_utils.py b/compiler_opt/es/blackbox_test_utils.py index a3ee957d..3a0a2ca4 100644 --- a/compiler_opt/es/blackbox_test_utils.py +++ b/compiler_opt/es/blackbox_test_utils.py @@ -14,7 +14,7 @@ # limitations under the License. """Test facilities for Blackbox classes.""" -from typing import List +from typing import List, Collection, Optional import gin @@ -41,3 +41,26 @@ def compile(self, policy: policy_saver.Policy, return self.function_value else: return 0.0 + + +class ESTraceWorker(worker.Worker): + """Temporary placeholder worker. + + This is a test worker for TraceBlackboxEvaluator that expects a slightly + different interface than other workers. + """ + + def __init__(self, arg, *, kwarg): + del arg # Unused. + del kwarg # Unused. + self._function_value = 0.0 + + def compile_corpus_and_evaluate( + self, modules: Collection[corpus.ModuleSpec], function_index_path: str, + bb_trace_path: str, + tflite_policy: Optional[policy_saver.Policy]) -> float: + if modules and function_index_path and bb_trace_path and tflite_policy: + self._function_value += 1 + return self._function_value + else: + return 10