Skip to content

Commit

Permalink
Merge pull request #1243 from plamentotev:deepcfr
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 656457757
Change-Id: If1559e628ba14ed4b0cc3a7807ada6134bd3a82f
  • Loading branch information
lanctot committed Jul 30, 2024
2 parents 380bc11 + c749781 commit 2eed8b6
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
3 changes: 2 additions & 1 deletion open_spiel/python/algorithms/deep_cfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,9 @@ def _sample_action_from_advantage(self, state, player):

return advantages, matched_regrets

- def action_probabilities(self, state):
+ def action_probabilities(self, state, player_id=None):
"""Returns action probabilities dict for a single batch."""
+ del player_id # unused
cur_player = state.current_player()
legal_actions = state.legal_actions(cur_player)
info_state_vector = np.array(state.information_state_tensor())
Expand Down
3 changes: 2 additions & 1 deletion open_spiel/python/algorithms/deep_cfr_tf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,8 +631,9 @@ def _sample_action_from_advantage(self, state, player):
info_state, legal_actions_mask, player)
return advantages.numpy(), matched_regrets.numpy()

- def action_probabilities(self, state):
+ def action_probabilities(self, state, player_id=None):
"""Returns action probabilities dict for a single batch."""
+ del player_id # unused
cur_player = state.current_player()
legal_actions = state.legal_actions(cur_player)
legal_actions_mask = tf.constant(
Expand Down
3 changes: 2 additions & 1 deletion open_spiel/python/jax/deep_cfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,9 @@ def _sample_action_from_advantage(self, state, player):
info_state, legal_actions_mask, self._params_adv_network[player])
return advantages, matched_regrets

- def action_probabilities(self, state):
+ def action_probabilities(self, state, player_id=None):
"""Returns action probabilities dict for a single batch."""
+ del player_id # unused
cur_player = state.current_player()
legal_actions = state.legal_actions(cur_player)
info_state_vector = jnp.array(
Expand Down
4 changes: 3 additions & 1 deletion open_spiel/python/pytorch/deep_cfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,15 +416,17 @@ def _sample_action_from_advantage(self, state, player):
matched_regrets[max(legal_actions, key=lambda a: raw_advantages[a])] = 1
return advantages, matched_regrets

- def action_probabilities(self, state):
+ def action_probabilities(self, state, player_id=None):
"""Computes action probabilities for the current player in state.
Args:
state: (pyspiel.State) The state to compute probabilities for.
+ player_id: unused, but needed to implement the Policy API.
Returns:
(dict) action probabilities for a single batch.
"""
+ del player_id
cur_player = state.current_player()
legal_actions = state.legal_actions(cur_player)
info_state_vector = np.array(state.information_state_tensor())
Expand Down

0 comments on commit 2eed8b6

Please sign in to comment.