[BC] ExplorationWithPolicy (#378)
The class ```ExplorationWithPolicy``` selects which states to explore while compiling a module. This is part of collecting trajectories for the imitation learning algorithm.
tvmarino authored Oct 8, 2024
1 parent 6efc0a8 commit d687b9b
Showing 2 changed files with 258 additions and 0 deletions.
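For quick orientation before the diff: the "gap" the new class tracks is simply the difference between the two largest action probabilities under the exploration policy. The snippet below illustrates that computation; the array values are invented for the example.

```python
import numpy as np

# Hypothetical softmax output of an exploration policy over three actions.
distr = np.array([0.1, 0.6, 0.3], dtype=np.float32)

# The gap tracked by ExplorationWithPolicy: top-1 minus top-2 probability.
sorted_distr = np.sort(distr)
gap = sorted_distr[-1] - sorted_distr[-2]  # ~0.3; smaller gaps mark better exploration candidates
print(gap)
```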
104 changes: 104 additions & 0 deletions compiler_opt/rl/generate_bc_trajectories.py
@@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for running compilation and collect data for behavior cloning."""

from typing import Callable, Dict, List, Optional

import numpy as np
import tensorflow as tf
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step


class ExplorationWithPolicy:
"""Policy which selects states for exploration.
Exploration is facilitated in the following way. First the policy plays
all actions from the replay_prefix. At the following state the policy computes
a gap which is difference between the most likely action and the second most
likely action according to the randomized exploration policy (distr).
If the current gap is smaller than previously maintained gap, the gap is
updated and the exploration state is set to the current state.
The trajectory is completed by following following the policy from the
constructor.
Attributes:
replay_prefix: a replay buffer of actions
policy: policy to follow after exhausting the replay buffer
explore_policy: randomized policy which is used to compute the gap
curr_step: current step of the trajectory
explore_step: current candidate for exploration step
gap: current difference at explore step between probability of most likely
action according to explore_policy and second most likely action
explore_on_features: dict of feature names and functions which specify
when to explore on the respective feature
"""

def __init__(
self,
replay_prefix: List[np.ndarray],
policy: Callable[[time_step.TimeStep], np.ndarray],
explore_policy: Callable[[time_step.TimeStep], policy_step.PolicyStep],
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
bool]]] = None,
):
self.replay_prefix = replay_prefix
self.policy = policy
self.explore_policy = explore_policy
self.curr_step = 0
self.explore_step = 0
self.gap = np.inf
self.explore_on_features = explore_on_features
self._stop_exploration = False

def _compute_gap(self, distr: np.ndarray) -> np.float32:
if distr.shape[0] < 2:
return np.inf
sorted_distr = np.sort(distr)
return sorted_distr[-1] - sorted_distr[-2]

def get_advice(self, state: time_step.TimeStep) -> np.ndarray:
"""Action function for the policy.
Args:
state: current state in the trajectory
Returns:
policy_action: action to take at the current state.
"""
if self.curr_step < len(self.replay_prefix):
self.curr_step += 1
return np.array(self.replay_prefix[self.curr_step - 1])
policy_action = self.policy(state)
# explore_policy(state) should play at least one action per state and so
# self.explore_policy(state).action.logits should have at least one entry
distr = tf.nn.softmax(self.explore_policy(state).action.logits).numpy()[0]
curr_gap = self._compute_gap(distr)
    # Selecting explore_step is based on the smallest gap encountered while
    # playing self.policy. This logic can be changed to implement a different
    # type of exploration.
if (not self._stop_exploration and distr.shape[0] > 1 and
self.gap > curr_gap):
self.gap = curr_gap
self.explore_step = self.curr_step
if not self._stop_exploration and self.explore_on_features is not None:
for feature_name, explore_on_feature in self.explore_on_features.items():
if explore_on_feature(state.observation[feature_name]):
self.explore_step = self.curr_step
self._stop_exploration = True
break
self.curr_step += 1
return policy_action
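The test file below exercises the class through tf.test assertions; for readers skimming the diff, here is a minimal usage sketch. It is an illustration only: it imports the module added in this commit, but `_toy_policy`, `_toy_explore_policy`, and the single `'feature'` observation are invented for the example and do not correspond to a real compilation environment.

```python
# Minimal usage sketch (toy policies and observations, not a real compiler environment).
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step

from compiler_opt.rl import generate_bc_trajectories


def _toy_policy(state: time_step.TimeStep) -> np.ndarray:
  # Deterministic policy followed once the replay prefix is exhausted.
  return np.array([0])


def _toy_explore_policy(state: time_step.TimeStep) -> policy_step.PolicyStep:
  # Randomized policy whose logits determine the gap at each state.
  p = 0.5 + 0.5 * float(state.observation['feature'].numpy()[0])
  logits = [[0.0, float(tf.math.log(p / (1.0 - p + 1e-5)))]]
  return policy_step.PolicyStep(
      action=tfp.distributions.Categorical(logits=logits))


explorer = generate_bc_trajectories.ExplorationWithPolicy(
    replay_prefix=[np.array([1])],  # the first action is replayed verbatim
    policy=_toy_policy,
    explore_policy=_toy_explore_policy,
)

for flag in [0, 0, 1]:  # a toy trajectory of three states
  state = time_step.TimeStep(
      discount=tf.constant([0.], dtype=tf.float32),
      observation={'feature': tf.constant([flag], dtype=tf.int64)},
      reward=tf.constant([0.], dtype=tf.float32),
      step_type=tf.constant([0], dtype=tf.int32))
  _ = explorer.get_advice(state)

# The state with the smallest gap between the top two action probabilities
# becomes the exploration candidate: here the second state (curr_step == 1).
print(explorer.explore_step, explorer.gap)
```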
154 changes: 154 additions & 0 deletions compiler_opt/rl/generate_bc_trajectories_test.py
@@ -0,0 +1,154 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for compiler_opt.rl.generate_bc_trajectories."""

from typing import List

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step

from compiler_opt.rl import generate_bc_trajectories

_eps = 1e-5


def _get_state_list() -> List[time_step.TimeStep]:

state_0 = time_step.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'feature_1': tf.constant(np.array([0]), dtype=tf.int64),
'feature_2': tf.constant(np.array([50]), dtype=tf.int64),
'feature_3': tf.constant(np.array([0]), dtype=tf.int64),
},
reward=tf.constant(np.array([0]), dtype=tf.float32),
step_type=tf.constant(np.array([0]), dtype=tf.int32))
state_1 = time_step.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'feature_1': tf.constant(np.array([1]), dtype=tf.int64),
'feature_2': tf.constant(np.array([25]), dtype=tf.int64),
'feature_3': tf.constant(np.array([0]), dtype=tf.int64),
},
reward=tf.constant(np.array([0]), dtype=tf.float32),
step_type=tf.constant(np.array([0]), dtype=tf.int32))
state_2 = time_step.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'feature_1': tf.constant(np.array([0]), dtype=tf.int64),
'feature_2': tf.constant(np.array([25]), dtype=tf.int64),
'feature_3': tf.constant(np.array([1]), dtype=tf.int64),
},
reward=tf.constant(np.array([0]), dtype=tf.float32),
step_type=tf.constant(np.array([0]), dtype=tf.int32))
state_3 = time_step.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'feature_1': tf.constant(np.array([0]), dtype=tf.int64),
'feature_2': tf.constant(np.array([25]), dtype=tf.int64),
'feature_3': tf.constant(np.array([0]), dtype=tf.int64),
},
reward=tf.constant(np.array([0]), dtype=tf.float32),
step_type=tf.constant(np.array([0]), dtype=tf.int32))

return [state_0, state_1, state_2, state_3]


def _policy(state: time_step.TimeStep) -> np.ndarray:
feature_sum = np.array([0])
for feature in state.observation.values():
feature_sum += feature.numpy()
return np.mod(feature_sum, 5)


def _explore_policy(state: time_step.TimeStep) -> policy_step.PolicyStep:
probs = [
0.5 * float(state.observation['feature_3'].numpy()),
1 - 0.5 * float(state.observation['feature_3'].numpy())
]
logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1] + _eps))]]
return policy_step.PolicyStep(
action=tfp.distributions.Categorical(logits=logits))


class ExplorationWithPolicyTest(tf.test.TestCase):

def test_explore_policy(self):
prob = 1.
state = _get_state_list()[3]
logits = [[0.0, tf.math.log(prob / (1.0 - prob + _eps))]]
action = tfp.distributions.Categorical(logits=logits)
self.assertAllClose(action.logits, _explore_policy(state).action.logits)

def test_explore_with_gap(self):
explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
replay_prefix=[np.array([1])],
policy=_policy,
explore_policy=_explore_policy,
)
for state in _get_state_list():
_ = explore_with_policy.get_advice(state)[0]

self.assertAllClose(0, explore_with_policy.gap, atol=2 * _eps)
self.assertEqual(2, explore_with_policy.explore_step)

explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
replay_prefix=[np.array([1]),
np.array([1]),
np.array([1])],
policy=_policy,
explore_policy=_explore_policy,
)
for state in _get_state_list():
_ = explore_with_policy.get_advice(state)[0]

self.assertAllClose(1, explore_with_policy.gap, atol=2 * _eps)
self.assertEqual(3, explore_with_policy.explore_step)

def test_explore_with_feature(self):

def explore_on_feature_1_val(feature_val):
return feature_val.numpy()[0] > 0

def explore_on_feature_2_val(feature_val):
return feature_val.numpy()[0] > 25

explore_on_features = {
'feature_1': explore_on_feature_1_val,
'feature_2': explore_on_feature_2_val
}

explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
replay_prefix=[],
policy=_policy,
explore_policy=_explore_policy,
explore_on_features=explore_on_features)
for state in _get_state_list():
_ = explore_with_policy.get_advice(state)[0]
self.assertEqual(0, explore_with_policy.explore_step)

explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
replay_prefix=[np.array([1])],
policy=_policy,
explore_policy=_explore_policy,
explore_on_features=explore_on_features,
)

for state in _get_state_list():
_ = explore_with_policy.get_advice(state)[0]
self.assertEqual(1, explore_with_policy.explore_step)
