diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py
index 91060f80..6c6498dd 100644
--- a/compiler_opt/rl/compilation_runner.py
+++ b/compiler_opt/rl/compilation_runner.py
@@ -30,7 +30,7 @@
 import tensorflow as tf
 
 _COMPILATION_TIMEOUT = flags.DEFINE_integer(
-    'compilation_timeout', 60,
+    'compilation_timeout', 120,
     'Max duration (in seconds) after which we cancel any compilation job.')
 _QUIET = flags.DEFINE_bool(
     'quiet', True, 'Whether or not to compile quietly (hiding info logging)')
diff --git a/compiler_opt/rl/registry.py b/compiler_opt/rl/registry.py
index 85f8097a..a04d9275 100644
--- a/compiler_opt/rl/registry.py
+++ b/compiler_opt/rl/registry.py
@@ -33,6 +33,7 @@
 # to trigger gin registration.
 import compiler_opt.rl.inlining  # pylint: disable=unused-import
 import compiler_opt.rl.regalloc  # pylint: disable=unused-import
+import compiler_opt.rl.unroll  # pylint: disable=unused-import
 
 
 types = tfa.typing.types
diff --git a/compiler_opt/rl/unroll/__init__.py b/compiler_opt/rl/unroll/__init__.py
new file mode 100644
index 00000000..b292fb9f
--- /dev/null
+++ b/compiler_opt/rl/unroll/__init__.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Implementation of the 'loop unroll' problem."""
+
+import gin
+
+from compiler_opt.rl import problem_configuration
+from compiler_opt.rl.unroll import config
+from compiler_opt.rl.unroll import unroll_runner
+
+
+@gin.register(module='configs')
+class LoopUnrollConfig(problem_configuration.ProblemConfiguration):
+  """Expose the loop unroll components."""
+
+  def get_runner_type(self):
+    return unroll_runner.LoopUnrollRunner
+
+  def get_signature_spec(self):
+    return config.get_unroll_signature_spec()
+
+  def get_preprocessing_layer_creator(self):
+    return config.get_observation_processing_layer_creator()
+
+  def get_nonnormalized_features(self):
+    return config.get_nonnormalized_features()
diff --git a/compiler_opt/rl/unroll/config.py b/compiler_opt/rl/unroll/config.py
new file mode 100644
index 00000000..203f4eaf
--- /dev/null
+++ b/compiler_opt/rl/unroll/config.py
@@ -0,0 +1,76 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Loop unroll training config.""" + +import gin +import tensorflow as tf +from tf_agents.specs import tensor_spec +from tf_agents.trajectories import time_step +from compiler_opt.rl import feature_ops + + +# pylint: disable=g-complex-comprehension +@gin.configurable() +def get_unroll_signature_spec(): + """Returns (time_step_spec, action_spec) for LLVM loop unroll.""" + # LINT.IfChange + observation_spec = dict( + (key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key)) + for key in ('loop_size', 'trip_count', 'is_innermost_loop', + 'preheader_blocksize', 'bb_count', 'num_of_loop_latch', + 'load_inst_count', 'store_inst_count', 'logical_inst_count', + 'cast_inst_count')) + reward_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='reward') + time_step_spec = time_step.time_step_spec(observation_spec, reward_spec) + action_spec = tensor_spec.BoundedTensorSpec( + dtype=tf.int64, shape=(), name='unroll_count') + + return time_step_spec, action_spec + + +@gin.configurable +def get_observation_processing_layer_creator(quantile_file_dir=None, + with_sqrt=True, + with_z_score_normalization=True, + eps=1e-8): + """Wrapper for observation_processing_layer.""" + quantile_map = feature_ops.build_quantile_map(quantile_file_dir) + + def observation_processing_layer(obs_spec): + """Creates the layer to process observation given obs_spec.""" + + # I guess we discard rewards when observation? + if obs_spec.name in ('icache_pressure', 'latency'): + return tf.keras.layers.Lambda(feature_ops.discard_fn) + + # for boolean features, use feature_ops.identity_fn + if obs_spec.name in ('is_innermost_loop'): + return tf.keras.layers.Lambda(feature_ops.identity_fn) + + # Do we need to define some layer here to normalize 'loop_size' + # and instruction count features (e.g. 'load_inst_count'). + # Bigger loops expect more instruction counts, and we need to + # normalize this? 
+
+    quantile = quantile_map[obs_spec.name]
+    return tf.keras.layers.Lambda(
+        feature_ops.get_normalize_fn(quantile, with_sqrt,
+                                     with_z_score_normalization, eps))
+
+  return observation_processing_layer
+
+
+def get_nonnormalized_features():
+  return ['reward', 'is_innermost_loop']
diff --git a/compiler_opt/rl/unroll/gin_configs/behavioral_cloning_nn_agent.gin b/compiler_opt/rl/unroll/gin_configs/behavioral_cloning_nn_agent.gin
new file mode 100644
index 00000000..d38f63aa
--- /dev/null
+++ b/compiler_opt/rl/unroll/gin_configs/behavioral_cloning_nn_agent.gin
@@ -0,0 +1,32 @@
+import gin.tf.external_configurables
+import compiler_opt.rl.constant
+import compiler_opt.rl.gin_external_configurables
+import compiler_opt.rl.unroll.config
+import tf_agents.agents.behavioral_cloning.behavioral_cloning_agent
+import tf_agents.networks.q_network
+
+include 'compiler_opt/rl/unroll/gin_configs/common.gin'
+
+train_eval.agent_name=%constant.AgentName.BEHAVIORAL_CLONE
+train_eval.num_iterations=100000
+train_eval.batch_size=64
+train_eval.train_sequence_length=1
+
+unroll.config.get_observation_processing_layer_creator.with_sqrt = False
+unroll.config.get_observation_processing_layer_creator.with_z_score_normalization = False
+
+create_agent.policy_network = @q_network.QNetwork
+
+QNetwork.preprocessing_combiner=@tf.keras.layers.Concatenate()
+QNetwork.fc_layer_params=(40, 40, 20)
+QNetwork.dropout_layer_params=(0.2, 0.2, 0.2)
+QNetwork.activation_fn=@tf.keras.activations.relu
+
+tf.train.AdamOptimizer.learning_rate = 0.001
+tf.train.AdamOptimizer.epsilon = 0.0003125
+
+BehavioralCloningAgent.optimizer = @tf.train.AdamOptimizer()
+BehavioralCloningAgent.epsilon_greedy = 0.1
+BehavioralCloningAgent.gradient_clipping = None
+BehavioralCloningAgent.debug_summaries = True
+BehavioralCloningAgent.summarize_grads_and_vars = True
diff --git a/compiler_opt/rl/unroll/gin_configs/common.gin b/compiler_opt/rl/unroll/gin_configs/common.gin
new file mode 100644
index 00000000..7d213ba8
--- /dev/null
+++ b/compiler_opt/rl/unroll/gin_configs/common.gin
@@ -0,0 +1,11 @@
+config_registry.get_configuration.implementation=@configs.LoopUnrollConfig
+
+clang_path=None
+llvm_objcopy_path=None
+parse_reward_script_path=None
+latency_coefficient=None
+
+runners.LoopUnrollRunner.clang_path=%clang_path
+runners.LoopUnrollRunner.llvm_objcopy_path=%llvm_objcopy_path
+runners.LoopUnrollRunner.parse_reward_script_path=%parse_reward_script_path
+runners.LoopUnrollRunner.latency_coefficient=%latency_coefficient
diff --git a/compiler_opt/rl/unroll/unroll_runner.py b/compiler_opt/rl/unroll/unroll_runner.py
new file mode 100644
index 00000000..70753614
--- /dev/null
+++ b/compiler_opt/rl/unroll/unroll_runner.py
@@ -0,0 +1,194 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for collect data of loop unroll.""" + +import base64 +import io +import os +import tempfile +from typing import Dict, Optional, Tuple + +import gin +import tensorflow as tf + +from google.protobuf import struct_pb2 # pytype: disable=pyi-error +from compiler_opt.rl import compilation_runner +from compiler_opt.rl import corpus + + +@gin.configurable(module='runners') +class LoopUnrollRunner(compilation_runner.CompilationRunner): + """Class for collecting data for loop partial unroll. + + Usage: + runner = LoopUnrollRunner( + clang_path, llvm_objcopy_path, parse_reward_script_path, + moving_average_decay_rate) + policy_reward = unroll.collect_data( + ir_path, tf_policy_path, default_reward, moving_average_reward) + """ + + def __init__(self, llvm_objcopy_path: str, parse_reward_script_path: str, + latency_coefficient: str, *args, **kwargs): + super().__init__(*args, **kwargs) + self._llvm_objcopy_path = llvm_objcopy_path + self._parse_reward_script_path = parse_reward_script_path + self._latency_coefficient = float(latency_coefficient) + + def compile_fn( + self, module_spec: corpus.ModuleSpec, tf_policy_path: str, + reward_only: bool, cancellation_manager: Optional[ + compilation_runner.WorkerCancellationManager] + ) -> Dict[str, Tuple[tf.train.SequenceExample, float]]: + """Run loop unroll for the given IR file under the given policy. + + Args: + module_spec: a ModuleSpec. + tf_policy_path: path to TF policy directory on local disk. + reward_only: whether to only return reward (icache pressure and latency) + cancellation_manager: handler for early termination by killing any running + processes + + Returns: + For loop unroll, the result is in module level. IWS and Latency is + already weighted by the probability to be executed, checkout + parse_reward.py and code embedded under AsmPrinter.cpp for more detail). + + Since the reward is calculated at late stage in a compiler that is after + inlining some functions may be inlined and not be found for some loops, + so we sum all functions into a single float, reward_total. + + The function returns in the format: + { + "loop1_key": (loop1_features, reward_total), + "loop2_key": (loop2_features, reward_total), + ..., + "loopN_key": (loopN_features, reward_total) + } + - reward_total: sum of IWS and Latency of all functions in this module + + Early return: + The function early returns when the compiled module doesn't record any + logs or the log file doesn't record any loop. This happens when + `LoopUnrollPass` is not triggered or no loop triggered "partial unroll" + in the pass. + """ + working_dir = tempfile.mkdtemp() + + # The compiler will log input feature (loop properties) and decision + # (unroll count) into the specified log path + log_path = os.path.join(working_dir, 'log') + + # The compilation will generate object files, and our augmentation under + # AsmPrinter.cpp will create section data `llvm_block_data`. 
+    object_path = os.path.join(working_dir, 'object')
+    # llvm-objcopy extracts the section data from the object file into data.
+    data_path = os.path.join(working_dir, 'data')
+    # The reward-parsing script parses data into parsed_reward.
+    parsed_reward_path = os.path.join(working_dir, 'parsed_reward')
+
+    try:
+      # Construct the command line to execute clang.
+      command_line = []
+
+      # Parameters for MLGO unroll.
+      command_line.extend([self._clang_path] + list(module_spec.exec_cmd) + [
+          '-mllvm', '-mlgo-unroll-mode=training', '-mllvm',
+          '-mlgo-unroll-training-log=' + log_path, '-mllvm', '-calc-reward',
+          '-o', object_path
+      ])
+
+      # Under training mode:
+      # if a model path is provided, the compiler uses
+      # ModelUnderTrainingRunner; otherwise it uses NoInferenceModelRunner.
+      if tf_policy_path:
+        command_line.extend(
+            ['-mllvm', '-mlgo-unroll-train-model=' + tf_policy_path])
+
+      print('Command to execute clang: ', command_line)
+
+      # Run clang.
+      compilation_runner.start_cancellable_process(command_line,
+                                                   self._compilation_timeout,
+                                                   cancellation_manager)
+
+      # A module may not generate a log if none of its loops reaches the
+      # loop unroll decision. Return early if log_path cannot be found.
+      if not os.path.exists(log_path):
+        print('Early return, log file not found.')
+        return {}
+
+      # A log file may be empty when none of the loops reaches a partial
+      # unroll decision. Return early if the log file was created but
+      # contains nothing.
+      if os.path.getsize(log_path) == 0:
+        print('Early return, log file contains nothing.')
+        return {}
+
+      # Run llvm-objcopy to get the section data.
+      command_line = [
+          self._llvm_objcopy_path,
+          '--dump-section=.llvm_block_data.=' + data_path, object_path
+      ]
+      print('Command to get section data: ', command_line)
+      compilation_runner.start_cancellable_process(command_line,
+                                                   self._compilation_timeout,
+                                                   cancellation_manager)
+
+      # Run parse_reward.py to get the reward.
+      command_line = [
+          self._parse_reward_script_path, data_path, parsed_reward_path
+      ]
+      print('Command to parse reward: ', command_line)
+      compilation_runner.start_cancellable_process(command_line,
+                                                   self._compilation_timeout,
+                                                   cancellation_manager)
+
+      # Sum the rewards of all functions into a single float.
+      reward_total = 0
+      with io.open(parsed_reward_path, 'r', encoding='utf-8') as reward_f:
+        for line in reward_f.readlines():
+          line = line[:-1]  # Strip the trailing newline.
+          items = line.split(',')
+          assert len(items) == 3
+          # function_name = items[0] (commented out because currently unused)
+          iws = float(items[1])
+          latency = float(items[2])
+          reward_total = reward_total + (
+              iws + latency * self._latency_coefficient)
+
+      if reward_only:
+        return {'default': (None, reward_total)}
+
+      result = {}
+
+      # Read the training log and fill it into result.
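+      # The training log is expected to be a serialized protobuf Struct
+      # whose fields map each loop's key to a base64-encoded, serialized
+      # tf.train.SequenceExample (this matches how the values are decoded
+      # below).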
+      sequence_examples = struct_pb2.Struct()
+      with io.open(log_path, 'rb') as log_f:
+        sequence_examples.ParseFromString(log_f.read())
+
+      for key, value in sequence_examples.fields.items():
+        entry = tf.train.SequenceExample()
+        entry.ParseFromString(base64.b64decode(value.string_value))
+
+        if not entry.HasField('feature_lists'):
+          continue
+
+        result[key] = (entry, reward_total)
+
+    finally:
+      tf.io.gfile.rmtree(working_dir)
+
+    return result
diff --git a/compiler_opt/tools/sparse_bucket_generator.py b/compiler_opt/tools/sparse_bucket_generator.py
index d8e66217..2bb95a04 100644
--- a/compiler_opt/tools/sparse_bucket_generator.py
+++ b/compiler_opt/tools/sparse_bucket_generator.py
@@ -170,7 +170,7 @@ def main(_) -> None:
   parser_fn = create_tfrecord_parser_fn(sequence_features)
   dataset = dataset.map(parser_fn, num_parallel_calls=tf.data.AUTOTUNE)
   data_list = np.array(list(dataset.as_numpy_iterator()), dtype=object)
-  data_list = np.transpose(data_list, [1, 0])
+  data_list = np.transpose(data_list, [1, 0, 2])
 
   with mp.Pool(FLAGS.parallelism) as pool:
     feature_names = list(sorted(sequence_features))
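For reference, LoopUnrollRunner.compile_fn folds the per-function rewards emitted by parse_reward.py into a single reward_total. Assuming each output line has the form `function_name,iws,latency` (as the compile_fn docstring describes), the aggregation is equivalent to the small sketch below; the path and coefficient used in the usage comment are illustrative only.

# Minimal sketch of the reward aggregation in LoopUnrollRunner.compile_fn,
# assuming parse_reward.py writes one 'function_name,iws,latency' line per
# function.
def total_reward(parsed_reward_path: str, latency_coefficient: float) -> float:
  total = 0.0
  with open(parsed_reward_path, encoding='utf-8') as f:
    for line in f:
      _name, iws, latency = line.strip().split(',')
      total += float(iws) + float(latency) * latency_coefficient
  return total

# e.g. total_reward('/tmp/parsed_reward', 0.5)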