Skip to content

Commit

Permalink
Add TraceBlackboxEvaluator
Browse files Browse the repository at this point in the history
This patch adds TraceBlackboxEvaluator, an evaluator designed for trace
based cost modelling. It implements the BlackboxEvaluator class, special
casing everything that is needed.

Reviewers: mtrofin

Reviewed By: mtrofin

Pull Request: #419
  • Loading branch information
boomanaiden154 authored Jan 28, 2025
1 parent 96614ea commit 3a4a297
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 3 deletions.
64 changes: 62 additions & 2 deletions compiler_opt/es/blackbox_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_results(
raise NotImplementedError()

@abc.abstractmethod
def set_baseline(self) -> None:
def set_baseline(self, pool: FixedWorkerPool) -> None:
raise NotImplementedError()

def get_rewards(
Expand Down Expand Up @@ -101,5 +101,65 @@ def get_results(

return futures

def set_baseline(self) -> None:
def set_baseline(self, pool: FixedWorkerPool) -> None:
del pool # Unused.
pass


@gin.configurable
class TraceBlackboxEvaluator(BlackboxEvaluator):
"""A blackbox evaluator that utilizes trace based cost modelling."""

def __init__(self, train_corpus: corpus.Corpus,
est_type: blackbox_optimizers.EstimatorType, bb_trace_path: str,
function_index_path: str):
self._train_corpus = train_corpus
self._est_type = est_type
self._bb_trace_path = bb_trace_path
self._function_index_path = function_index_path

self._baseline: Optional[float] = None

def get_results(
self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy]
) -> List[concurrent.futures.Future]:
job_args = []
for perturbation in perturbations:
job_args.append({
'modules': self._train_corpus.module_specs,
'function_index_path': self._function_index_path,
'bb_trace_path': self._bb_trace_path,
'tflite_policy': perturbation
})

_, futures = buffered_scheduler.schedule_on_worker_pool(
action=lambda w, args: w.compile_corpus_and_evaluate(**args),
jobs=job_args,
worker_pool=pool)
concurrent.futures.wait(
futures, return_when=concurrent.futures.ALL_COMPLETED)
return futures

def set_baseline(self, pool: FixedWorkerPool) -> None:
if self._baseline is not None:
raise RuntimeError('The baseline has already been set.')

job_args = [{
'modules': self._train_corpus.module_specs,
'function_index_path': self._function_index_path,
'bb_trace_path': self._bb_trace_path,
'tflite_policy': None,
}]

_, futures = buffered_scheduler.schedule_on_worker_pool(
action=lambda w, args: w.compile_corpus_and_evaluate(**args),
jobs=job_args,
worker_pool=pool)

concurrent.futures.wait(
futures, return_when=concurrent.futures.ALL_COMPLETED)
if len(futures) != 1:
raise ValueError('Expected to have one result for setting the baseline,'
f' got {len(futures)}')

self._baseline = futures[0].result()
25 changes: 25 additions & 0 deletions compiler_opt/es/blackbox_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,28 @@ def test_get_rewards(self):
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(None, 5, 5, None)
rewards = evaluator.get_rewards(results)
self.assertEqual(rewards, [None, 2])

def test_trace_get_results(self):
with local_worker_manager.LocalWorkerPoolManager(
blackbox_test_utils.ESTraceWorker, count=3, arg='', kwarg='') as pool:
perturbations = [b'00', b'01', b'10']
test_corpus = corpus.create_corpus_for_testing(
location=self.create_tempdir(),
elements=[corpus.ModuleSpec(name='name1', size=1)])
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
test_corpus, 5, 'fake_bb_trace_path', 'fake_function_index_path')
results = evaluator.get_results(pool, perturbations)
self.assertSequenceAlmostEqual([result.result() for result in results],
[1.0, 1.0, 1.0])

def test_trace_set_baseline(self):
with local_worker_manager.LocalWorkerPoolManager(
blackbox_test_utils.ESTraceWorker, count=1, arg='', kwarg='') as pool:
test_corpus = corpus.create_corpus_for_testing(
location=self.create_tempdir(),
elements=[corpus.ModuleSpec(name='name1', size=1)])
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
test_corpus, 5, 'fake_bb_trace_path', 'fake_function_index_path')
evaluator.set_baseline(pool)
# pylint: disable=protected-access
self.assertAlmostEqual(evaluator._baseline, 10)
25 changes: 24 additions & 1 deletion compiler_opt/es/blackbox_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
"""Test facilities for Blackbox classes."""

from typing import List
from typing import List, Collection, Optional

import gin

Expand All @@ -41,3 +41,26 @@ def compile(self, policy: policy_saver.Policy,
return self.function_value
else:
return 0.0


class ESTraceWorker(worker.Worker):
"""Temporary placeholder worker.
This is a test worker for TraceBlackboxEvaluator that expects a slightly
different interface than other workers.
"""

def __init__(self, arg, *, kwarg):
del arg # Unused.
del kwarg # Unused.
self._function_value = 0.0

def compile_corpus_and_evaluate(
self, modules: Collection[corpus.ModuleSpec], function_index_path: str,
bb_trace_path: str,
tflite_policy: Optional[policy_saver.Policy]) -> float:
if modules and function_index_path and bb_trace_path and tflite_policy:
self._function_value += 1
return self._function_value
else:
return 10

0 comments on commit 3a4a297

Please sign in to comment.