From c749781151d246e15bdbc48f43ed0abbf14a8103 Mon Sep 17 00:00:00 2001 From: Plamen Totev Date: Sun, 30 Jun 2024 19:21:50 +0300 Subject: [PATCH] Update Deep CFR implementations to implement `Policy` The `player_id` argument is not needed, but still needs to be present in order to implement `Policy`. For example, without it `PolicyBot` breaks as it passes three arguments (self, the game state and the player id). --- open_spiel/python/algorithms/deep_cfr.py | 2 +- open_spiel/python/algorithms/deep_cfr_tf2.py | 2 +- open_spiel/python/jax/deep_cfr.py | 2 +- open_spiel/python/pytorch/deep_cfr.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/open_spiel/python/algorithms/deep_cfr.py b/open_spiel/python/algorithms/deep_cfr.py index c933de773d..b901c9f546 100644 --- a/open_spiel/python/algorithms/deep_cfr.py +++ b/open_spiel/python/algorithms/deep_cfr.py @@ -360,7 +360,7 @@ def _sample_action_from_advantage(self, state, player): return advantages, matched_regrets - def action_probabilities(self, state): + def action_probabilities(self, state, player_id=None): """Returns action probabilities dict for a single batch.""" cur_player = state.current_player() legal_actions = state.legal_actions(cur_player) diff --git a/open_spiel/python/algorithms/deep_cfr_tf2.py b/open_spiel/python/algorithms/deep_cfr_tf2.py index f085511670..93203f1d84 100644 --- a/open_spiel/python/algorithms/deep_cfr_tf2.py +++ b/open_spiel/python/algorithms/deep_cfr_tf2.py @@ -631,7 +631,7 @@ def _sample_action_from_advantage(self, state, player): info_state, legal_actions_mask, player) return advantages.numpy(), matched_regrets.numpy() - def action_probabilities(self, state): + def action_probabilities(self, state, player_id=None): """Returns action probabilities dict for a single batch.""" cur_player = state.current_player() legal_actions = state.legal_actions(cur_player) diff --git a/open_spiel/python/jax/deep_cfr.py b/open_spiel/python/jax/deep_cfr.py index 4bc9dbceea..62a4668b19
100644 --- a/open_spiel/python/jax/deep_cfr.py +++ b/open_spiel/python/jax/deep_cfr.py @@ -480,7 +480,7 @@ def _sample_action_from_advantage(self, state, player): info_state, legal_actions_mask, self._params_adv_network[player]) return advantages, matched_regrets - def action_probabilities(self, state): + def action_probabilities(self, state, player_id=None): """Returns action probabilities dict for a single batch.""" cur_player = state.current_player() legal_actions = state.legal_actions(cur_player) diff --git a/open_spiel/python/pytorch/deep_cfr.py b/open_spiel/python/pytorch/deep_cfr.py index b5681f2ef4..5cd96d1a89 100644 --- a/open_spiel/python/pytorch/deep_cfr.py +++ b/open_spiel/python/pytorch/deep_cfr.py @@ -416,7 +416,7 @@ def _sample_action_from_advantage(self, state, player): matched_regrets[max(legal_actions, key=lambda a: raw_advantages[a])] = 1 return advantages, matched_regrets - def action_probabilities(self, state): + def action_probabilities(self, state, player_id=None): """Computes action probabilities for the current player in state. Args: