[BC] ExplorationWithPolicy #378

Merged — 34 commits, Oct 8, 2024
Commits
a06d452
Combine two tf_agents policies with timestep spec
tvmarino Sep 3, 2024
15876f7
Merge branch 'main' of https://github.com/google/ml-compiler-opt
tvmarino Sep 3, 2024
9bc8c05
Added licence.
tvmarino Sep 6, 2024
86e4d12
yapf . -ir
tvmarino Sep 6, 2024
27dee69
yapf . -ir
tvmarino Sep 6, 2024
47f5efc
Fixed pylint errors.
tvmarino Sep 6, 2024
bf12cee
Merge branch 'policy_combiner' of https://github.com/tvmarino/ml-comp…
tvmarino Sep 6, 2024
f5b6b6f
yapf . -ir
tvmarino Sep 6, 2024
35d9e8c
Fixed super without arguments pylint error.
tvmarino Sep 6, 2024
5d6783d
yapf . -ir
tvmarino Sep 6, 2024
7997f14
Fixing pytype annotations.
tvmarino Sep 6, 2024
59d3677
Fixed pytype errors. Addressed comments.
tvmarino Sep 6, 2024
6d8c0c7
Addressed comments.
tvmarino Sep 9, 2024
3b0cefd
Resolved _distribution and common.gin comments.
tvmarino Sep 9, 2024
78460ce
Fixed Aiden's nits.
tvmarino Sep 9, 2024
5b5d67b
Merge branch 'google:main' into behavior_cloning
tvmarino Sep 30, 2024
6be7186
Patch to env.py and compilation_runner.py which adds working_dir to
tvmarino Oct 1, 2024
6342dda
Fixed comments.
tvmarino Oct 1, 2024
3082ae7
Fixed pylint.
tvmarino Oct 1, 2024
2e26243
Fixed a nit
tvmarino Oct 2, 2024
56fa72a
Added interactive only mode for env.py which
tvmarino Oct 2, 2024
463d813
Merge branch 'google:main' into behavior_cloning
tvmarino Oct 2, 2024
5568aaf
Improved _get_clang_generator documentation in env.py.
tvmarino Oct 2, 2024
0c20688
Merge branch 'behavior_cloning' of https://github.com/tvmarino/ml-com…
tvmarino Oct 2, 2024
ac307fc
Address a nit.
tvmarino Oct 2, 2024
07a77ce
Fixed pylint.
tvmarino Oct 2, 2024
49af23a
Class which defi
tvmarino Oct 3, 2024
b7e6fb2
Fix an Optional.
tvmarino Oct 3, 2024
958f3b6
yapf -ir .
tvmarino Oct 3, 2024
dad551c
Trying to fix TimeStep problem.
tvmarino Oct 3, 2024
2504a95
Fix typecheck problem
tvmarino Oct 4, 2024
6268902
Merge branch 'google:main' into behavior_cloning
tvmarino Oct 4, 2024
9a19f99
Addressing mtrofin comments.
tvmarino Oct 7, 2024
3fd273c
Addressing mtrofin comments.
tvmarino Oct 8, 2024
104 changes: 104 additions & 0 deletions compiler_opt/rl/generate_bc_trajectories.py
@@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for running compilation and collect data for behavior cloning."""

from typing import Callable, Dict, List, Optional

import numpy as np
import tensorflow as tf
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step


class ExplorationWithPolicy:
"""Policy which selects states for exploration.

Exploration is facilitated in the following way. First the policy plays
all actions from the replay_prefix. At each subsequent state the policy
computes a gap, which is the difference between the probabilities of the most
likely and the second most likely actions according to the randomized
exploration policy (distr). If the current gap is smaller than the previously
maintained gap, the gap is updated and the exploration step is set to the
current step. The trajectory is completed by following the policy supplied to
the constructor.

Attributes:
replay_prefix: a replay buffer of actions
policy: policy to follow after exhausting the replay buffer
explore_policy: randomized policy which is used to compute the gap
curr_step: current step of the trajectory
explore_step: current candidate for exploration step
gap: current difference at the explore step between the probabilities of the
most likely and the second most likely actions according to explore_policy
explore_on_features: dict of feature names and functions which specify
when to explore on the respective feature
"""

def __init__(
self,
replay_prefix: List[np.ndarray],
policy: Callable[[time_step.TimeStep], np.ndarray],
explore_policy: Callable[[time_step.TimeStep], policy_step.PolicyStep],
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
bool]]] = None,
):
self.replay_prefix = replay_prefix
self.policy = policy
self.explore_policy = explore_policy
self.curr_step = 0
self.explore_step = 0
self.gap = np.inf
self.explore_on_features = explore_on_features
self._stop_exploration = False

def _compute_gap(self, distr: np.ndarray) -> np.float32:
if distr.shape[0] < 2:
return np.inf
sorted_distr = np.sort(distr)
return sorted_distr[-1] - sorted_distr[-2]

def get_advice(self, state: time_step.TimeStep) -> np.ndarray:
"""Action function for the policy.

Args:
state: current state in the trajectory

Returns:
policy_action: action to take at the current state.

"""
if self.curr_step < len(self.replay_prefix):
self.curr_step += 1
return np.array(self.replay_prefix[self.curr_step - 1])
policy_action = self.policy(state)
# explore_policy(state) should play at least one action per state and so
# self.explore_policy(state).action.logits should have at least one entry
distr = tf.nn.softmax(self.explore_policy(state).action.logits).numpy()[0]
curr_gap = self._compute_gap(distr)
# selecting explore_step is based on the smallest gap encountered while
# playing self.policy. This logic can be changed to implement a different
# type of exploration.
if (not self._stop_exploration and distr.shape[0] > 1 and
self.gap > curr_gap):
self.gap = curr_gap
self.explore_step = self.curr_step
if not self._stop_exploration and self.explore_on_features is not None:
for feature_name, explore_on_feature in self.explore_on_features.items():
if explore_on_feature(state.observation[feature_name]):
self.explore_step = self.curr_step
self._stop_exploration = True
break
self.curr_step += 1
return policy_action
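
For orientation, here is a minimal usage sketch of ExplorationWithPolicy. It is not part of this PR: greedy_policy, stochastic_policy, make_state, and the 'feature' observation are hypothetical stand-ins, chosen only to show how the replay prefix is consumed first and how the smallest softmax gap selects the exploration step. The test file below exercises the real behavior.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step

from compiler_opt.rl import generate_bc_trajectories


def greedy_policy(state: time_step.TimeStep) -> np.ndarray:
  # Hypothetical policy followed outside the replay prefix: always advises 0.
  del state
  return np.array([0])


def stochastic_policy(state: time_step.TimeStep) -> policy_step.PolicyStep:
  # Hypothetical exploration policy over two actions; the logit margin (and
  # hence the gap) shrinks as the 'feature' observation approaches zero.
  x = tf.cast(state.observation['feature'], tf.float32)
  logits = tf.stack([tf.zeros_like(x), -x], axis=-1)
  return policy_step.PolicyStep(
      action=tfp.distributions.Categorical(logits=logits))


def make_state(value: int) -> time_step.TimeStep:
  return time_step.TimeStep(
      discount=tf.constant([0.], dtype=tf.float32),
      observation={'feature': tf.constant([value], dtype=tf.int64)},
      reward=tf.constant([0.], dtype=tf.float32),
      step_type=tf.constant([0], dtype=tf.int32))


explorer = generate_bc_trajectories.ExplorationWithPolicy(
    replay_prefix=[np.array([1])],  # the first step replays this action
    policy=greedy_policy,
    explore_policy=stochastic_policy)

for value in (5, 2, 0):
  _ = explorer.get_advice(make_state(value))

# The smallest gap is seen at the last state: feature == 0 gives logits
# [0, 0], i.e. a uniform distribution, so that step becomes the candidate.
print(explorer.explore_step, explorer.gap)  # 2 0.0
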
154 changes: 154 additions & 0 deletions compiler_opt/rl/generate_bc_trajectories_test.py
@@ -0,0 +1,154 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for compiler_opt.rl.generate_bc_trajectories."""

from typing import List

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step

from compiler_opt.rl import generate_bc_trajectories

_eps = 1e-5


def _get_state_list() -> List[time_step.TimeStep]:

state_0 = time_step.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'feature_1': tf.constant(np.array([0]), dtype=tf.int64),
'feature_2': tf.constant(np.array([50]), dtype=tf.int64),
'feature_3': tf.constant(np.array([0]), dtype=tf.int64),
},
reward=tf.constant(np.array([0]), dtype=tf.float32),
step_type=tf.constant(np.array([0]), dtype=tf.int32))
state_1 = time_step.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'feature_1': tf.constant(np.array([1]), dtype=tf.int64),
'feature_2': tf.constant(np.array([25]), dtype=tf.int64),
'feature_3': tf.constant(np.array([0]), dtype=tf.int64),
},
reward=tf.constant(np.array([0]), dtype=tf.float32),
step_type=tf.constant(np.array([0]), dtype=tf.int32))
state_2 = time_step.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'feature_1': tf.constant(np.array([0]), dtype=tf.int64),
'feature_2': tf.constant(np.array([25]), dtype=tf.int64),
'feature_3': tf.constant(np.array([1]), dtype=tf.int64),
},
reward=tf.constant(np.array([0]), dtype=tf.float32),
step_type=tf.constant(np.array([0]), dtype=tf.int32))
state_3 = time_step.TimeStep(
discount=tf.constant(np.array([0.]), dtype=tf.float32),
observation={
'feature_1': tf.constant(np.array([0]), dtype=tf.int64),
'feature_2': tf.constant(np.array([25]), dtype=tf.int64),
'feature_3': tf.constant(np.array([0]), dtype=tf.int64),
},
reward=tf.constant(np.array([0]), dtype=tf.float32),
step_type=tf.constant(np.array([0]), dtype=tf.int32))

return [state_0, state_1, state_2, state_3]


def _policy(state: time_step.TimeStep) -> np.ndarray:
feature_sum = np.array([0])
for feature in state.observation.values():
feature_sum += feature.numpy()
return np.mod(feature_sum, 5)


def _explore_policy(state: time_step.TimeStep) -> policy_step.PolicyStep:
probs = [
0.5 * float(state.observation['feature_3'].numpy()),
1 - 0.5 * float(state.observation['feature_3'].numpy())
]
logits = [[0.0, tf.math.log(probs[1] / (1.0 - probs[1] + _eps))]]
return policy_step.PolicyStep(
action=tfp.distributions.Categorical(logits=logits))


class ExplorationWithPolicyTest(tf.test.TestCase):

def test_explore_policy(self):
prob = 1.
state = _get_state_list()[3]
logits = [[0.0, tf.math.log(prob / (1.0 - prob + _eps))]]
action = tfp.distributions.Categorical(logits=logits)
self.assertAllClose(action.logits, _explore_policy(state).action.logits)

def test_explore_with_gap(self):
explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
replay_prefix=[np.array([1])],
policy=_policy,
explore_policy=_explore_policy,
)
for state in _get_state_list():
_ = explore_with_policy.get_advice(state)[0]

self.assertAllClose(0, explore_with_policy.gap, atol=2 * _eps)
self.assertEqual(2, explore_with_policy.explore_step)

explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
replay_prefix=[np.array([1]),
np.array([1]),
np.array([1])],
policy=_policy,
explore_policy=_explore_policy,
)
for state in _get_state_list():
_ = explore_with_policy.get_advice(state)[0]

self.assertAllClose(1, explore_with_policy.gap, atol=2 * _eps)
self.assertEqual(3, explore_with_policy.explore_step)

def test_explore_with_feature(self):

def explore_on_feature_1_val(feature_val):
return feature_val.numpy()[0] > 0

def explore_on_feature_2_val(feature_val):
return feature_val.numpy()[0] > 25

explore_on_features = {
'feature_1': explore_on_feature_1_val,
'feature_2': explore_on_feature_2_val
}

explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
replay_prefix=[],
policy=_policy,
explore_policy=_explore_policy,
explore_on_features=explore_on_features)
for state in _get_state_list():
_ = explore_with_policy.get_advice(state)[0]
self.assertEqual(0, explore_with_policy.explore_step)

explore_with_policy = generate_bc_trajectories.ExplorationWithPolicy(
replay_prefix=[np.array([1])],
policy=_policy,
explore_policy=_explore_policy,
explore_on_features=explore_on_features,
)

for state in _get_state_list():
_ = explore_with_policy.get_advice(state)[0]
self.assertEqual(1, explore_with_policy.explore_step)
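
A standalone sanity check (not part of the PR) of why these assertions hold: _explore_policy builds logits [0, log(p / (1 - p + _eps))], and their softmax recovers roughly [1 - p, p]. So the gap is close to 1 when feature_3 == 0 (p = 1) and close to 0 when feature_3 == 1 (p = 0.5), which is what test_explore_with_gap relies on.

import numpy as np
import tensorflow as tf

_eps = 1e-5
for p in (1.0, 0.5):  # feature_3 == 0 -> p = 1.0, feature_3 == 1 -> p = 0.5
  logits = [[0.0, float(np.log(p / (1.0 - p + _eps)))]]
  probs = tf.nn.softmax(logits).numpy()[0]
  # Gap as computed by ExplorationWithPolicy._compute_gap: largest minus
  # second-largest probability.
  gap = np.sort(probs)[-1] - np.sort(probs)[-2]
  print(p, probs, gap)  # p=1.0 -> gap ~= 1.0; p=0.5 -> gap ~= 0.0
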