-
Notifications
You must be signed in to change notification settings - Fork 95
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add loop unroll into MLGO #180
Open
eopXD
wants to merge
2
commits into
google:main
Choose a base branch
from
eopXD:loop-unroll
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Implementation of the 'loop unroll' problem.""" | ||
|
||
import gin | ||
|
||
from compiler_opt.rl import problem_configuration | ||
from compiler_opt.rl.unroll import config | ||
from compiler_opt.rl.unroll import unroll_runner | ||
|
||
|
||
@gin.register(module='configs') | ||
class LoopUnrollConfig(problem_configuration.ProblemConfiguration): | ||
"""Expose the regalloc eviction components.""" | ||
|
||
def get_runner_type(self): | ||
return unroll_runner.LoopUnrollRunner | ||
|
||
def get_signature_spec(self): | ||
return config.get_unroll_signature_spec() | ||
|
||
def get_preprocessing_layer_creator(self): | ||
return config.get_observation_processing_layer_creator() | ||
|
||
def get_nonnormalized_features(self): | ||
return config.get_nonnormalized_features() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Loop unroll training config.""" | ||
|
||
import gin | ||
import tensorflow as tf | ||
from tf_agents.specs import tensor_spec | ||
from tf_agents.trajectories import time_step | ||
from compiler_opt.rl import feature_ops | ||
|
||
|
||
# pylint: disable=g-complex-comprehension | ||
@gin.configurable() | ||
def get_unroll_signature_spec(): | ||
"""Returns (time_step_spec, action_spec) for LLVM loop unroll.""" | ||
# LINT.IfChange | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. delete # LINT.IfChange |
||
observation_spec = dict( | ||
(key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key)) | ||
for key in ('loop_size', 'trip_count', 'is_innermost_loop', | ||
'preheader_blocksize', 'bb_count', 'num_of_loop_latch', | ||
'load_inst_count', 'store_inst_count', 'logical_inst_count', | ||
'cast_inst_count')) | ||
reward_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='reward') | ||
time_step_spec = time_step.time_step_spec(observation_spec, reward_spec) | ||
action_spec = tensor_spec.BoundedTensorSpec( | ||
dtype=tf.int64, shape=(), name='unroll_count') | ||
|
||
return time_step_spec, action_spec | ||
|
||
|
||
@gin.configurable | ||
def get_observation_processing_layer_creator(quantile_file_dir=None, | ||
with_sqrt=True, | ||
with_z_score_normalization=True, | ||
eps=1e-8): | ||
"""Wrapper for observation_processing_layer.""" | ||
quantile_map = feature_ops.build_quantile_map(quantile_file_dir) | ||
|
||
def observation_processing_layer(obs_spec): | ||
"""Creates the layer to process observation given obs_spec.""" | ||
|
||
# I guess we discard rewards when observation? | ||
if obs_spec.name in ('icache_pressure', 'latency'): | ||
return tf.keras.layers.Lambda(feature_ops.discard_fn) | ||
|
||
# for boolean features, use feature_ops.identity_fn | ||
if obs_spec.name in ('is_innermost_loop'): | ||
return tf.keras.layers.Lambda(feature_ops.identity_fn) | ||
|
||
# Do we need to define some layer here to normalize 'loop_size' | ||
# and instruction count features (e.g. 'load_inst_count'). | ||
# Bigger loops expect more instruction counts, and we need to | ||
# normalize this? | ||
|
||
quantile = quantile_map[obs_spec.name] | ||
return tf.keras.layers.Lambda( | ||
feature_ops.get_normalize_fn(quantile, with_sqrt, | ||
with_z_score_normalization, eps)) | ||
|
||
return observation_processing_layer | ||
|
||
|
||
def get_nonnormalized_features(): | ||
return ['reward', 'is_innermost_loop'] |
32 changes: 32 additions & 0 deletions
32
compiler_opt/rl/unroll/gin_configs/behavioral_cloning_nn_agent.gin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import gin.tf.external_configurables | ||
import compiler_opt.rl.constant | ||
import compiler_opt.rl.gin_external_configurables | ||
import compiler_opt.rl.unroll.config | ||
import tf_agents.agents.behavioral_cloning.behavioral_cloning_agent | ||
import tf_agents.networks.q_network | ||
|
||
include 'compiler_opt/rl/unroll/gin_configs/common.gin' | ||
|
||
train_eval.agent_name=%constant.AgentName.BEHAVIORAL_CLONE | ||
train_eval.num_iterations=100000 | ||
train_eval.batch_size=64 | ||
train_eval.train_sequence_length=1 | ||
|
||
unroll.config.get_observation_processing_layer_creator.with_sqrt = False | ||
unroll.config.get_observation_processing_layer_creator.with_z_score_normalization = False | ||
|
||
create_agent.policy_network = @q_network.QNetwork | ||
|
||
QNetwork.preprocessing_combiner=@tf.keras.layers.Concatenate() | ||
QNetwork.fc_layer_params=(40, 40, 20) | ||
QNetwork.dropout_layer_params=(0.2, 0.2, 0.2) | ||
QNetwork.activation_fn=@tf.keras.activations.relu | ||
|
||
tf.train.AdamOptimizer.learning_rate = 0.001 | ||
tf.train.AdamOptimizer.epsilon = 0.0003125 | ||
|
||
BehavioralCloningAgent.optimizer = @tf.train.AdamOptimizer() | ||
BehavioralCloningAgent.epsilon_greedy = 0.1 | ||
BehavioralCloningAgent.gradient_clipping = None | ||
BehavioralCloningAgent.debug_summaries = True | ||
BehavioralCloningAgent.summarize_grads_and_vars = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
config_registry.get_configuration.implementation=@configs.LoopUnrollConfig | ||
|
||
clang_path=None | ||
llvm_objcopy_path=None | ||
parse_reward_script_path=None | ||
latency_coefficient=None | ||
|
||
runners.LoopUnrollRunner.clang_path=%clang_path | ||
runners.LoopUnrollRunner.llvm_objcopy_path=%llvm_objcopy_path | ||
runners.LoopUnrollRunner.parse_reward_script_path=%parse_reward_script_path | ||
runners.LoopUnrollRunner.latency_coefficient=%latency_coefficient |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Module for collect data of loop unroll.""" | ||
|
||
import base64 | ||
import io | ||
import os | ||
import tempfile | ||
from typing import Dict, Optional, Tuple | ||
|
||
import gin | ||
import tensorflow as tf | ||
|
||
from google.protobuf import struct_pb2 # pytype: disable=pyi-error | ||
from compiler_opt.rl import compilation_runner | ||
from compiler_opt.rl import corpus | ||
|
||
|
||
@gin.configurable(module='runners') | ||
class LoopUnrollRunner(compilation_runner.CompilationRunner): | ||
"""Class for collecting data for loop partial unroll. | ||
|
||
Usage: | ||
runner = LoopUnrollRunner( | ||
clang_path, llvm_objcopy_path, parse_reward_script_path, | ||
moving_average_decay_rate) | ||
policy_reward = unroll.collect_data( | ||
ir_path, tf_policy_path, default_reward, moving_average_reward) | ||
""" | ||
|
||
def __init__(self, llvm_objcopy_path: str, parse_reward_script_path: str, | ||
latency_coefficient: str, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self._llvm_objcopy_path = llvm_objcopy_path | ||
self._parse_reward_script_path = parse_reward_script_path | ||
self._latency_coefficient = float(latency_coefficient) | ||
|
||
def compile_fn( | ||
self, command_line: corpus.FullyQualifiedCmdLine, tf_policy_path: str, | ||
reward_only: bool, cancellation_manager: Optional[ | ||
compilation_runner.WorkerCancellationManager] | ||
) -> Dict[str, Tuple[tf.train.SequenceExample, float]]: | ||
"""Run loop unroll for the given IR file under the given policy. | ||
|
||
Args: | ||
command_line: the fully qualified command line. | ||
tf_policy_path: path to TF policy directory on local disk. | ||
reward_only: whether to only return reward (icache pressure and latency) | ||
cancellation_manager: handler for early termination by killing any running | ||
processes | ||
|
||
Returns: | ||
For loop unroll, the result is in module level. IWS and Latency is | ||
already weighted by the probability to be executed, checkout | ||
parse_reward.py and code embedded under AsmPrinter.cpp for more detail). | ||
|
||
Since the reward is calculated at late stage in a compiler that is after | ||
inlining some functions may be inlined and not be found for some loops, | ||
so we sum all functions into a single float, reward_total. | ||
|
||
The function returns in the format: | ||
{ | ||
"loop1_key": (loop1_features, reward_total), | ||
"loop2_key": (loop2_features, reward_total), | ||
..., | ||
"loopN_key": (loopN_features, reward_total) | ||
} | ||
- reward_total: sum of IWS and Latency of all functions in this module | ||
|
||
Early return: | ||
The function early returns when the compiled module doesn't record any | ||
logs or the log file doesn't record any loop. This happens when | ||
`LoopUnrollPass` is not triggered or no loop triggered "partial unroll" | ||
in the pass. | ||
""" | ||
working_dir = tempfile.mkdtemp() | ||
|
||
# The compiler will log input feature (loop properties) and decision | ||
# (unroll count) into the specified log path | ||
log_path = os.path.join(working_dir, 'log') | ||
|
||
# The compilation will generate object files, and our augmentation under | ||
# AsmPrinter.cpp will create section data `llvm_block_data`. | ||
object_path = os.path.join(working_dir, 'object') | ||
# llvm-objcopy extracts the section data from object to data | ||
data_path = os.path.join(working_dir, 'data') | ||
# Reward parsing script parses data into parsed_reward | ||
parsed_reward_path = os.path.join(working_dir, 'parsed_reward') | ||
|
||
try: | ||
# Construct command to execute clang | ||
command_line = [] | ||
|
||
# parameters for MLGO unroll | ||
command_line.extend([self._clang_path] + list(command_line) + [ | ||
'-mllvm', '-mlgo-unroll-mode=training', '-mllvm', | ||
'-mlgo-unroll-training-log=' + | ||
log_path, '-mllvm', '-calc-reward', '-o', object_path | ||
]) | ||
|
||
# Under `training mode`... | ||
# If model path is provided, compiler will use ModelUnderTrainingRunner | ||
# Otherwise, compiler will use NoInferenceModelRunner | ||
if tf_policy_path: | ||
command_line.extend( | ||
['-mllvm', 'mlgo-unroll-train-model=' + tf_policy_path]) | ||
|
||
print('Command to execute clang: ', command_line) | ||
|
||
# run clang | ||
compilation_runner.start_cancellable_process(command_line, | ||
self._compilation_timeout, | ||
cancellation_manager) | ||
|
||
# A module may not generate a log if none of the loops go into the | ||
# LoopUnroll decision. Early return here if log_path cannot be found. | ||
if not os.path.exists(log_path): | ||
print('Early return, log file not found.') | ||
return {} | ||
|
||
# A log file may not have anything inside when none of the loops goes | ||
# into PartialUnroll decision. Early return a log file is created but | ||
# nothing inside. | ||
if os.path.getsize(log_path) == 0: | ||
print('Early return, log file contains nothing.') | ||
return {} | ||
|
||
# Run llvm-objcopy to get section data | ||
command_line = [ | ||
self._llvm_objcopy_path, | ||
'--dump-section=.llvm_block_data.=' + data_path, object_path | ||
] | ||
print('Command to get section data: ', command_line) | ||
compilation_runner.start_cancellable_process(command_line, | ||
self._compilation_timeout, | ||
cancellation_manager) | ||
|
||
# Run parse_reward.py to get reward | ||
command_line = [ | ||
self._parse_reward_script_path, data_path, parsed_reward_path | ||
] | ||
print('Command to parse reward: ', command_line) | ||
compilation_runner.start_cancellable_process(command_line, | ||
self._compilation_timeout, | ||
cancellation_manager) | ||
|
||
# Sum rewards of all functions into a single float | ||
reward_total = 0 | ||
with io.open(parsed_reward_path, 'r', encoding='utf-8') as reward_f: | ||
for line in reward_f.readlines(): | ||
line = line[:-1] # strip end-line | ||
items = line.split(',') | ||
assert len(items) == 3 | ||
# function_name = items[0] (commented out because currently unused) | ||
iws = float(items[1]) | ||
latency = float(items[2]) | ||
reward_total = reward_total + ( | ||
iws + latency * self._latency_coefficient) | ||
|
||
if reward_only: | ||
return {'default': (None, reward_total)} | ||
|
||
result = {} | ||
|
||
# Read training log, fill them in to result. | ||
sequence_examples = struct_pb2.Struct() | ||
with io.open(log_path, 'rb') as log_f: | ||
sequence_examples.ParseFromString(log_f.read()) | ||
|
||
for key, value in sequence_examples.fields.items(): | ||
entry = tf.train.SequenceExample() | ||
entry.ParseFromString(base64.b64decode(value.string_value)) | ||
|
||
if not entry.HasField('feature_lists'): | ||
continue | ||
|
||
result[key] = (entry, reward_total) | ||
|
||
finally: | ||
tf.io.gfile.rmtree(working_dir) | ||
|
||
return result |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you revert it back?