diff --git a/compiler_opt/rl/generate_bc_trajectories.py b/compiler_opt/rl/generate_bc_trajectories.py
new file mode 100644
index 00000000..79d65f2d
--- /dev/null
+++ b/compiler_opt/rl/generate_bc_trajectories.py
@@ -0,0 +1,104 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for running compilation and collecting data for behavior cloning."""
+
+from typing import Callable, Dict, List, Optional
+
+import numpy as np
+import tensorflow as tf
+from tf_agents.trajectories import policy_step
+from tf_agents.trajectories import time_step
+
+
+class ExplorationWithPolicy:
+  """Policy which selects states for exploration.
+
+  Exploration is facilitated in the following way. First the policy plays
+  all actions from the replay_prefix. At each following state the policy
+  computes a gap, which is the difference between the probability of the most
+  likely action and that of the second most likely action according to the
+  randomized exploration policy (distr). If the current gap is smaller than
+  the previously maintained gap, the gap is updated and the exploration state
+  is set to the current state. The trajectory is completed by following the
+  policy given to the constructor.
+
+  Attributes:
+    replay_prefix: a replay buffer of actions
+    policy: policy to follow after exhausting the replay buffer
+    explore_policy: randomized policy which is used to compute the gap
+    curr_step: current step of the trajectory
+    explore_step: current candidate for exploration step
+    gap: current difference at explore step between probability of most likely
+      action according to explore_policy and second most likely action
+    explore_on_features: dict of feature names and functions which specify
+      when to explore on the respective feature
+  """
+
+  def __init__(
+      self,
+      replay_prefix: List[np.ndarray],
+      policy: Callable[[time_step.TimeStep], np.ndarray],
+      explore_policy: Callable[[time_step.TimeStep], policy_step.PolicyStep],
+      explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
+                                                       bool]]] = None,
+  ):
+    self.replay_prefix = replay_prefix
+    self.policy = policy
+    self.explore_policy = explore_policy
+    self.curr_step = 0
+    self.explore_step = 0
+    self.gap = np.inf
+    self.explore_on_features = explore_on_features
+    self._stop_exploration = False
+
+  def _compute_gap(self, distr: np.ndarray) -> np.float32:
+    if distr.shape[0] < 2:
+      return np.inf
+    sorted_distr = np.sort(distr)
+    return sorted_distr[-1] - sorted_distr[-2]
+
+  def get_advice(self, state: time_step.TimeStep) -> np.ndarray:
+    """Action function for the policy.
+
+    Args:
+      state: current state in the trajectory
+
+    Returns:
+      policy_action: action to take at the current state.
+ + """ + if self.curr_step < len(self.replay_prefix): + self.curr_step += 1 + return np.array(self.replay_prefix[self.curr_step - 1]) + policy_action = self.policy(state) + # explore_policy(state) should play at least one action per state and so + # self.explore_policy(state).action.logits should have at least one entry + distr = tf.nn.softmax(self.explore_policy(state).action.logits).numpy()[0] + curr_gap = self._compute_gap(distr) + # selecting explore_step is done based on smallest encountered gap in the + # play of self.policy. This logic can be changed to have different type + # of exploration. + if (not self._stop_exploration and distr.shape[0] > 1 and + self.gap > curr_gap): + self.gap = curr_gap + self.explore_step = self.curr_step + if not self._stop_exploration and self.explore_on_features is not None: + for feature_name, explore_on_feature in self.explore_on_features.items(): + if explore_on_feature(state.observation[feature_name]): + self.explore_step = self.curr_step + self._stop_exploration = True + break + self.curr_step += 1 + return policy_action diff --git a/compiler_opt/rl/generate_bc_trajectories_test.py b/compiler_opt/rl/generate_bc_trajectories_test.py new file mode 100644 index 00000000..a4ae9b6c --- /dev/null +++ b/compiler_opt/rl/generate_bc_trajectories_test.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for compiler_opt.rl.generate_bc_trajectories.""" + +from typing import List + +import numpy as np +import tensorflow as tf +import tensorflow_probability as tfp +from tf_agents.trajectories import policy_step +from tf_agents.trajectories import time_step + +from compiler_opt.rl import generate_bc_trajectories + +_eps = 1e-5 + + +def _get_state_list() -> List[time_step.TimeStep]: + + state_0 = time_step.TimeStep( + discount=tf.constant(np.array([0.]), dtype=tf.float32), + observation={ + 'feature_1': tf.constant(np.array([0]), dtype=tf.int64), + 'feature_2': tf.constant(np.array([50]), dtype=tf.int64), + 'feature_3': tf.constant(np.array([0]), dtype=tf.int64), + }, + reward=tf.constant(np.array([0]), dtype=tf.float32), + step_type=tf.constant(np.array([0]), dtype=tf.int32)) + state_1 = time_step.TimeStep( + discount=tf.constant(np.array([0.]), dtype=tf.float32), + observation={ + 'feature_1': tf.constant(np.array([1]), dtype=tf.int64), + 'feature_2': tf.constant(np.array([25]), dtype=tf.int64), + 'feature_3': tf.constant(np.array([0]), dtype=tf.int64), + }, + reward=tf.constant(np.array([0]), dtype=tf.float32), + step_type=tf.constant(np.array([0]), dtype=tf.int32)) + state_2 = time_step.TimeStep( + discount=tf.constant(np.array([0.]), dtype=tf.float32), + observation={ + 'feature_1': tf.constant(np.array([0]), dtype=tf.int64), + 'feature_2': tf.constant(np.array([25]), dtype=tf.int64), + 'feature_3': tf.constant(np.array([1]), dtype=tf.int64), + }, + reward=tf.constant(np.array([0]), dtype=tf.float32), + step_type=tf.constant(np.array([0]), dtype=tf.int32)) + state_3 = time_step.TimeStep( + discount=tf.constant(np.array([0.]), dtype=tf.float32), + observation={ + 'feature_1': tf.constant(np.array([0]), dtype=tf.int64), + 'feature_2': tf.constant(np.array([25]), dtype=tf.int64), + 'feature_3': tf.constant(np.array([0]), dtype=tf.int64), + }, + reward=tf.constant(np.array([0]), dtype=tf.float32), + step_type=tf.constant(np.array([0]), dtype=tf.int32)) + + return [state_0, state_1, state_2, state_3] + + +def _policy(state: time_step.TimeStep) -> np.ndarray: + feature_sum = np.array([0]) + for feature in state.observation.values(): + feature_sum += feature.numpy() + return np.mod(feature_sum, 5) + + +def _explore_policy(state: time_step.TimeStep) -> policy_step.PolicyStep: + probs = [ + 0.5 * float(state.observation['feature_3'].numpy()), + 1 - 0.5 * float(state.observation['feature_3'].numpy()) + ] + logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1] + _eps))]] + return policy_step.PolicyStep( + action=tfp.distributions.Categorical(logits=logits)) + + +class ExplorationWithPolicyTest(tf.test.TestCase): + + def test_explore_policy(self): + prob = 1. 
+    state = _get_state_list()[3]
+    logits = [[0.0, tf.math.log(prob / (1.0 - prob + _eps))]]
+    action = tfp.distributions.Categorical(logits=logits)
+    self.assertAllClose(action.logits, _explore_policy(state).action.logits)
+
+  def test_explore_with_gap(self):
+    explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
+        replay_prefix=[np.array([1])],
+        policy=_policy,
+        explore_policy=_explore_policy,
+    )
+    for state in _get_state_list():
+      _ = explore_with_policy.get_advice(state)[0]
+
+    self.assertAllClose(0, explore_with_policy.gap, atol=2 * _eps)
+    self.assertEqual(2, explore_with_policy.explore_step)
+
+    explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
+        replay_prefix=[np.array([1]),
+                       np.array([1]),
+                       np.array([1])],
+        policy=_policy,
+        explore_policy=_explore_policy,
+    )
+    for state in _get_state_list():
+      _ = explore_with_policy.get_advice(state)[0]
+
+    self.assertAllClose(1, explore_with_policy.gap, atol=2 * _eps)
+    self.assertEqual(3, explore_with_policy.explore_step)
+
+  def test_explore_with_feature(self):
+
+    def explore_on_feature_1_val(feature_val):
+      return feature_val.numpy()[0] > 0
+
+    def explore_on_feature_2_val(feature_val):
+      return feature_val.numpy()[0] > 25
+
+    explore_on_features = {
+        'feature_1': explore_on_feature_1_val,
+        'feature_2': explore_on_feature_2_val
+    }
+
+    explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
+        replay_prefix=[],
+        policy=_policy,
+        explore_policy=_explore_policy,
+        explore_on_features=explore_on_features)
+    for state in _get_state_list():
+      _ = explore_with_policy.get_advice(state)[0]
+    self.assertEqual(0, explore_with_policy.explore_step)
+
+    explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
+        replay_prefix=[np.array([1])],
+        policy=_policy,
+        explore_policy=_explore_policy,
+        explore_on_features=explore_on_features,
+    )
+
+    for state in _get_state_list():
+      _ = explore_with_policy.get_advice(state)[0]
+    self.assertEqual(1, explore_with_policy.explore_step)
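
As a sanity check on the gap logic (an illustration only, not part of the patch): get_advice softmaxes the explore_policy logits and _compute_gap takes top-1 minus top-2 probability, so a flatter distribution yields a smaller gap and a more attractive exploration point.

import numpy as np
import tensorflow as tf

# Hypothetical two-action logits, as an explore_policy might return them.
logits = [[0.0, 1.0]]
distr = tf.nn.softmax(logits).numpy()[0]    # -> [0.2689, 0.7311]
sorted_distr = np.sort(distr)
gap = sorted_distr[-1] - sorted_distr[-2]   # -> ~0.4621
# Logits of [[0.0, 0.1]] would give a gap of ~0.05: the closer explore_policy
# is to indifferent at a state, the more likely that state becomes explore_step.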