[𝘀𝗽𝗿] initial version
Created using spr 1.3.4
boomanaiden154 committed Feb 9, 2025
2 parents a8559c1 + f8e7509 commit 0cc3969
Showing 2 changed files with 44 additions and 44 deletions.
40 changes: 20 additions & 20 deletions compiler_opt/es/es_trainer_lib.py
@@ -15,6 +15,7 @@
"""Local ES trainer."""

from absl import flags, logging
+import enum
import functools
import gin
import tensorflow as tf
@@ -32,22 +33,12 @@

FLAGS = flags.FLAGS

-_BETA1 = flags.DEFINE_float("beta1", 0.9,
-                            "Beta1 for ADAM gradient ascent optimizer.")
-_BETA2 = flags.DEFINE_float("beta2", 0.999,
-                            "Beta2 for ADAM gradient ascent optimizer.")
_GRAD_REG_ALPHA = flags.DEFINE_float(
"grad_reg_alpha", 0.01,
"Weight of regularization term in regression gradient.")
_GRAD_REG_TYPE = flags.DEFINE_string(
"grad_reg_type", "ridge",
"Regularization method to use with regression gradient.")
-_GRADIENT_ASCENT_OPTIMIZER_TYPE = flags.DEFINE_string(
-    "gradient_ascent_optimizer_type", None,
-    "Gradient ascent optimization algorithm: 'momentum' or 'adam'")
-flags.mark_flag_as_required("gradient_ascent_optimizer_type")
-_MOMENTUM = flags.DEFINE_float(
-    "momentum", 0.0, "Momentum for momentum gradient ascent optimizer.")
_OUTPUT_PATH = flags.DEFINE_string("output_path", "",
"Path to write all output")
_PRETRAINED_POLICY_PATH = flags.DEFINE_string(
@@ -61,11 +52,22 @@
"List of paths to training corpora")


+@gin.constants_from_enum(module='es_trainer_lib')
+class GradientAscentOptimizerType(enum.Enum):
+  MOMENTUM = 1
+  ADAM = 2
+
+
@gin.configurable
def train(additional_compilation_flags=(),
delete_compilation_flags=(),
replace_compilation_flags=(),
-          worker_class=None):
+          worker_class=None,
+          beta1=0.9,
+          beta2=0.999,
+          momentum=0.0,
+          gradient_ascent_optimizer_type=GradientAscentOptimizerType.ADAM,
+          worker_manager_class=local_worker_manager.LocalWorkerPoolManager):
"""Train with ES."""

if not _TRAIN_CORPORA.value:
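Because beta1, beta2, momentum, and the optimizer type are now gin parameters on train() instead of absl flags, a run would set them through gin bindings. A minimal sketch, assuming the configurable is registered under the name 'train'; the binding keys below are illustrative and not taken from this commit:

import gin

from compiler_opt.es import es_trainer_lib

# Illustrative bindings; the exact keys depend on how gin scopes this
# module's configurables.
gin.bind_parameter('train.beta1', 0.9)
gin.bind_parameter('train.beta2', 0.999)
gin.bind_parameter('train.gradient_ascent_optimizer_type',
                   es_trainer_lib.GradientAscentOptimizerType.ADAM)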
@@ -131,21 +133,20 @@ def train(additional_compilation_flags=(),
# TODO(linzinan): delete all unused parameters.

# ------------------ GRADIENT ASCENT OPTIMIZERS ------------------------------
-  if _GRADIENT_ASCENT_OPTIMIZER_TYPE.value == "momentum":
+  if (gradient_ascent_optimizer_type == GradientAscentOptimizerType.MOMENTUM):
logging.info("Running momentum gradient ascent optimizer")
# You can obtain a vanilla gradient ascent optimizer by setting momentum=0.0
# and setting step_size to the desired learning rate.
gradient_ascent_optimizer = (
gradient_ascent_optimization_algorithms.MomentumOptimizer(
-            learner_config.step_size, _MOMENTUM.value))
-  elif _GRADIENT_ASCENT_OPTIMIZER_TYPE.value == "adam":
+            learner_config.step_size, momentum))
+  elif (gradient_ascent_optimizer_type == GradientAscentOptimizerType.ADAM):
logging.info("Running Adam gradient ascent optimizer")
gradient_ascent_optimizer = (
gradient_ascent_optimization_algorithms.AdamOptimizer(
-            learner_config.step_size, _BETA1.value, _BETA2.value))
+            learner_config.step_size, beta1, beta2))
else:
logging.info("No gradient ascent \
optimizer selected. Stopping.")
logging.info("No gradient ascent optimizer selected. Stopping.")
return
# ----------------------------------------------------------------------------
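For reference, the two optimizer choices differ only in how a gradient estimate becomes a parameter update. The sketch below shows the standard gradient ascent update rules; it is illustrative only and is not the gradient_ascent_optimization_algorithms implementation used above.

import numpy as np

def momentum_ascent_step(theta, grad, velocity, step_size, momentum=0.0):
  # With momentum=0.0 this reduces to vanilla gradient ascent with
  # learning rate step_size, as the comment in the diff notes.
  velocity = momentum * velocity + grad
  return theta + step_size * velocity, velocity

def adam_ascent_step(theta, grad, m, v, t, step_size, beta1=0.9, beta2=0.999,
                     eps=1e-8):
  # Standard Adam moment estimates with bias correction; the step is added
  # rather than subtracted because we are ascending.
  m = beta1 * m + (1 - beta1) * grad
  v = beta2 * v + (1 - beta2) * grad * grad
  m_hat = m / (1 - beta1**t)
  v_hat = v / (1 - beta2**t)
  return theta + step_size * m_hat / (np.sqrt(v_hat) + eps), m, v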

@@ -215,9 +216,8 @@ def train(additional_compilation_flags=(),
logging.info("Ready to train: running for %d steps.",
learner_config.total_steps)

-  with local_worker_manager.LocalWorkerPoolManager(
-      worker_class, learner_config.total_num_perturbations, arg="",
-      kwarg="") as pool:
+  with worker_manager_class(worker_class,
+                            learner_config.total_num_perturbations) as pool:
for _ in range(learner_config.total_steps):
learner.run_step(pool)
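The pool is now built from the injected worker_manager_class rather than a hard-coded LocalWorkerPoolManager, so any manager exposing the same constructor and context-manager shape can be supplied via gin. A minimal sketch of that shape, using a hypothetical stand-in class that is not part of this commit:

class FakePoolManager:
  """Hypothetical drop-in for LocalWorkerPoolManager, e.g. in tests."""

  def __init__(self, worker_class, num_workers):
    # train() constructs the manager as worker_manager_class(
    #     worker_class, learner_config.total_num_perturbations).
    self._pool = [worker_class() for _ in range(num_workers)]

  def __enter__(self):
    # Whatever __enter__ returns is the `pool` handed to learner.run_step().
    return self._pool

  def __exit__(self, exc_type, exc_value, traceback):
    self._pool.clear()
    return False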

48 changes: 24 additions & 24 deletions compiler_opt/es/regalloc_trace/regalloc_trace_worker.py
@@ -110,30 +110,30 @@ def _build_corpus(self, modules: Collection[corpus.ModuleSpec],
else:
tflite_policy_dir = None

-with concurrent.futures.ThreadPoolExecutor(
-max_workers=self._thread_count) as thread_pool:
-compile_futures = [
-thread_pool.submit(self._compile_module, module, output_directory,
-tflite_policy_dir) for module in modules
-]
-
-for future in compile_futures:
-if future.exception() is not None:
-raise future.exception()
-
-# Write out a corpus description. basic_block_trace_model uses a corpus
-# description JSON to know which object files to load, so we need to emit
-# one before performing evaluation.
-corpus_description_path = os.path.join(output_directory,
-"corpus_description.json")
-corpus_description = {
-"modules": [module_spec.name for module_spec in modules]
-}
-
-with open(
-corpus_description_path, "w",
-encoding="utf-8") as corpus_description_file:
-json.dump(corpus_description, corpus_description_file)
+with concurrent.futures.ThreadPoolExecutor(
+max_workers=self._thread_count) as thread_pool:
+compile_futures = [
+thread_pool.submit(self._compile_module, module, output_directory,
+tflite_policy_dir) for module in modules
+]
+
+for future in compile_futures:
+if future.exception() is not None:
+raise future.exception()
+
+# Write out a corpus description. basic_block_trace_model uses a corpus
+# description JSON to know which object files to load, so we need to emit
+# one before performing evaluation.
+corpus_description_path = os.path.join(output_directory,
+"corpus_description.json")
+corpus_description = {
+"modules": [module_spec.name for module_spec in modules]
+}
+
+with open(
+corpus_description_path, "w",
+encoding="utf-8") as corpus_description_file:
+json.dump(corpus_description, corpus_description_file)

def _evaluate_corpus(self, module_directory: str, function_index_path: str,
bb_trace_path: str):
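For context, the corpus description that _build_corpus writes above is just a list of module names; for two hypothetical modules the file produced by json.dump would look like the output of this sketch:

import json

# Illustrative module names only; real names come from corpus.ModuleSpec.
corpus_description = {"modules": ["libfoo/a.o", "libfoo/b.o"]}
print(json.dumps(corpus_description))
# -> {"modules": ["libfoo/a.o", "libfoo/b.o"]}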