The class `ExplorationWithPolicy` selects which states to explore while compiling a module. This is part of collecting trajectories for the imitation learning algorithm.
Showing 2 changed files with 258 additions and 0 deletions.
compiler_opt/rl/generate_bc_trajectories.py
@@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for running compilation and collect data for behavior cloning.""" | ||

from typing import Callable, Dict, List, Optional

import numpy as np
import tensorflow as tf
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step


class ExplorationWithPolicy:
  """Policy which selects states for exploration.

  Exploration is facilitated in the following way. First the policy plays
  all actions from the replay_prefix. At each subsequent state the policy
  computes a gap, which is the difference between the probabilities of the
  most likely and the second most likely action according to the randomized
  exploration policy (distr). If the current gap is smaller than the
  previously maintained gap, the gap is updated and the exploration step is
  set to the current step. The trajectory is completed by following the
  policy given to the constructor.

  Attributes:
    replay_prefix: a replay buffer of actions
    policy: policy to follow after exhausting the replay buffer
    explore_policy: randomized policy which is used to compute the gap
    curr_step: current step of the trajectory
    explore_step: current candidate for exploration step
    gap: current difference at explore step between probability of most likely
      action according to explore_policy and second most likely action
    explore_on_features: dict of feature names and functions which specify
      when to explore on the respective feature
  """

  def __init__(
      self,
      replay_prefix: List[np.ndarray],
      policy: Callable[[time_step.TimeStep], np.ndarray],
      explore_policy: Callable[[time_step.TimeStep], policy_step.PolicyStep],
      explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
                                                       bool]]] = None,
  ):
    self.replay_prefix = replay_prefix
    self.policy = policy
    self.explore_policy = explore_policy
    self.curr_step = 0
    self.explore_step = 0
    self.gap = np.inf
    self.explore_on_features = explore_on_features
    self._stop_exploration = False

  def _compute_gap(self, distr: np.ndarray) -> np.float32:
    if distr.shape[0] < 2:
      return np.inf
    sorted_distr = np.sort(distr)
    return sorted_distr[-1] - sorted_distr[-2]

  def get_advice(self, state: time_step.TimeStep) -> np.ndarray:
    """Action function for the policy.

    Args:
      state: current state in the trajectory

    Returns:
      policy_action: action to take at the current state.
    """
    if self.curr_step < len(self.replay_prefix):
      self.curr_step += 1
      return np.array(self.replay_prefix[self.curr_step - 1])
    policy_action = self.policy(state)
    # explore_policy(state) should play at least one action per state, so
    # self.explore_policy(state).action.logits should have at least one entry.
    distr = tf.nn.softmax(self.explore_policy(state).action.logits).numpy()[0]
    curr_gap = self._compute_gap(distr)
    # explore_step is selected based on the smallest gap encountered while
    # playing self.policy. This logic can be changed to implement other types
    # of exploration.
    if (not self._stop_exploration and distr.shape[0] > 1 and
        self.gap > curr_gap):
      self.gap = curr_gap
      self.explore_step = self.curr_step
    if not self._stop_exploration and self.explore_on_features is not None:
      for feature_name, explore_on_feature in self.explore_on_features.items():
        if explore_on_feature(state.observation[feature_name]):
          self.explore_step = self.curr_step
          self._stop_exploration = True
          break
    self.curr_step += 1
    return policy_action
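To make the gap mechanism above concrete, here is a small standalone illustration (not part of the commit) of what `_compute_gap` does with the softmaxed logits of `explore_policy`: a confidently peaked distribution yields a large gap, while a near-uniform one yields a gap close to 0, which is what flags a state as a good exploration candidate.

```python
import numpy as np
import tensorflow as tf

# Peaked distribution: the top two action probabilities are far apart,
# so the gap is large and the state is a poor exploration candidate.
peaked = tf.nn.softmax(tf.constant([[0.0, 3.0]])).numpy()[0]       # ~[0.05, 0.95]
print(np.sort(peaked)[-1] - np.sort(peaked)[-2])                   # ~0.91

# Near-uniform distribution: the gap is close to 0, so this is the kind of
# uncertain state that ExplorationWithPolicy records in explore_step.
uniformish = tf.nn.softmax(tf.constant([[0.0, 0.1]])).numpy()[0]   # ~[0.48, 0.52]
print(np.sort(uniformish)[-1] - np.sort(uniformish)[-2])           # ~0.05
```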
compiler_opt/rl/generate_bc_trajectories_test.py
@@ -0,0 +1,154 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for compiler_opt.rl.generate_bc_trajectories."""

from typing import List

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step

from compiler_opt.rl import generate_bc_trajectories

_eps = 1e-5


def _get_state_list() -> List[time_step.TimeStep]:

  state_0 = time_step.TimeStep(
      discount=tf.constant(np.array([0.]), dtype=tf.float32),
      observation={
          'feature_1': tf.constant(np.array([0]), dtype=tf.int64),
          'feature_2': tf.constant(np.array([50]), dtype=tf.int64),
          'feature_3': tf.constant(np.array([0]), dtype=tf.int64),
      },
      reward=tf.constant(np.array([0]), dtype=tf.float32),
      step_type=tf.constant(np.array([0]), dtype=tf.int32))
  state_1 = time_step.TimeStep(
      discount=tf.constant(np.array([0.]), dtype=tf.float32),
      observation={
          'feature_1': tf.constant(np.array([1]), dtype=tf.int64),
          'feature_2': tf.constant(np.array([25]), dtype=tf.int64),
          'feature_3': tf.constant(np.array([0]), dtype=tf.int64),
      },
      reward=tf.constant(np.array([0]), dtype=tf.float32),
      step_type=tf.constant(np.array([0]), dtype=tf.int32))
  state_2 = time_step.TimeStep(
      discount=tf.constant(np.array([0.]), dtype=tf.float32),
      observation={
          'feature_1': tf.constant(np.array([0]), dtype=tf.int64),
          'feature_2': tf.constant(np.array([25]), dtype=tf.int64),
          'feature_3': tf.constant(np.array([1]), dtype=tf.int64),
      },
      reward=tf.constant(np.array([0]), dtype=tf.float32),
      step_type=tf.constant(np.array([0]), dtype=tf.int32))
  state_3 = time_step.TimeStep(
      discount=tf.constant(np.array([0.]), dtype=tf.float32),
      observation={
          'feature_1': tf.constant(np.array([0]), dtype=tf.int64),
          'feature_2': tf.constant(np.array([25]), dtype=tf.int64),
          'feature_3': tf.constant(np.array([0]), dtype=tf.int64),
      },
      reward=tf.constant(np.array([0]), dtype=tf.float32),
      step_type=tf.constant(np.array([0]), dtype=tf.int32))

  return [state_0, state_1, state_2, state_3]


def _policy(state: time_step.TimeStep) -> np.ndarray:
  feature_sum = np.array([0])
  for feature in state.observation.values():
    feature_sum += feature.numpy()
  return np.mod(feature_sum, 5)


def _explore_policy(state: time_step.TimeStep) -> policy_step.PolicyStep:
  probs = [
      0.5 * float(state.observation['feature_3'].numpy()),
      1 - 0.5 * float(state.observation['feature_3'].numpy())
  ]
  logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1] + _eps))]]
  return policy_step.PolicyStep(
      action=tfp.distributions.Categorical(logits=logits))


class ExplorationWithPolicyTest(tf.test.TestCase):

  def test_explore_policy(self):
    prob = 1.
    state = _get_state_list()[3]
    logits = [[0.0, tf.math.log(prob / (1.0 - prob + _eps))]]
    action = tfp.distributions.Categorical(logits=logits)
    self.assertAllClose(action.logits, _explore_policy(state).action.logits)

  def test_explore_with_gap(self):
    explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
        replay_prefix=[np.array([1])],
        policy=_policy,
        explore_policy=_explore_policy,
    )
    for state in _get_state_list():
      _ = explore_with_policy.get_advice(state)[0]

    self.assertAllClose(0, explore_with_policy.gap, atol=2 * _eps)
    self.assertEqual(2, explore_with_policy.explore_step)

    explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
        replay_prefix=[np.array([1]),
                       np.array([1]),
                       np.array([1])],
        policy=_policy,
        explore_policy=_explore_policy,
    )
    for state in _get_state_list():
      _ = explore_with_policy.get_advice(state)[0]

    self.assertAllClose(1, explore_with_policy.gap, atol=2 * _eps)
    self.assertEqual(3, explore_with_policy.explore_step)

  def test_explore_with_feature(self):

    def explore_on_feature_1_val(feature_val):
      return feature_val.numpy()[0] > 0

    def explore_on_feature_2_val(feature_val):
      return feature_val.numpy()[0] > 25

    explore_on_features = {
        'feature_1': explore_on_feature_1_val,
        'feature_2': explore_on_feature_2_val
    }

    explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
        replay_prefix=[],
        policy=_policy,
        explore_policy=_explore_policy,
        explore_on_features=explore_on_features)
    for state in _get_state_list():
      _ = explore_with_policy.get_advice(state)[0]

    self.assertEqual(0, explore_with_policy.explore_step)

    explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
        replay_prefix=[np.array([1])],
        policy=_policy,
        explore_policy=_explore_policy,
        explore_on_features=explore_on_features,
    )
    for state in _get_state_list():
      _ = explore_with_policy.get_advice(state)[0]

    self.assertEqual(1, explore_with_policy.explore_step)
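As a rough sketch of how these pieces might be combined downstream (using the test helpers above; the branching step itself is an illustrative assumption and is not part of this commit), one trajectory can be collected, and a second trajectory can then be started from the selected exploration step by replaying the actions taken before it:

```python
# Collect advice for each state of one trajectory.
explorer = generate_bc_trajectories.ExplorationWithPolicy(
    replay_prefix=[],
    policy=_policy,
    explore_policy=_explore_policy,
)
actions = [explorer.get_advice(state) for state in _get_state_list()]

# Branch a new trajectory at the most uncertain (smallest-gap) step by
# replaying the actions taken before it; how to deviate at that step is
# outside the scope of this sketch.
branch_prefix = actions[:explorer.explore_step]
explorer_branch = generate_bc_trajectories.ExplorationWithPolicy(
    replay_prefix=branch_prefix,
    policy=_policy,
    explore_policy=_explore_policy,
)
for state in _get_state_list():
  _ = explorer_branch.get_advice(state)
```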