Skip to content

Commit

Permalink
Updates to add TF now that 2.16.0rc0 supports Python 3.12
Browse files Browse the repository at this point in the history
  • Loading branch information
lanctot committed Feb 24, 2024
1 parent e2b6eab commit 62f0baa
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 13 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ jobs:
OPEN_SPIEL_BUILD_WITH_ORTOOLS: "OFF"
OPEN_SPIEL_BUILD_WITH_ORTOOLS_DOWNLOAD_URL: ""
# Disable due to not yet available for Python 3.12
# Missing deps: tensorflow, qdldl (needed by cvxpy and osqp)
# Still must change versions everything in python_extra_deps.sh
# Missing deps: qdldl (needed by cvxpy and osqp)
# Note also some tests in python/CMakeLists.txt disabled for this reason.
OPEN_SPIEL_BUILD_WITH_TENSORFLOW: "OFF"
OPEN_SPIEL_BUILD_WITH_MISC: "OFF"
Expand Down
6 changes: 2 additions & 4 deletions open_spiel/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,7 @@ set(PYTHON_TESTS ${PYTHON_TESTS}
# Add Jax tests if it is enabled.
if (OPEN_SPIEL_ENABLE_JAX)
set (PYTHON_TESTS ${PYTHON_TESTS}
# Temporarily disabled: depends on TF and TF doesn't run on Python 3.12.
# jax/deep_cfr_jax_test.py
jax/deep_cfr_jax_test.py
jax/dqn_jax_test.py
jax/nfsp_jax_test.py
jax/opponent_shaping_jax_test.py
Expand All @@ -295,8 +294,7 @@ if (OPEN_SPIEL_ENABLE_PYTORCH)
pytorch/deep_cfr_pytorch_test.py
pytorch/eva_pytorch_test.py
pytorch/losses/rl_losses_pytorch_test.py
# Temporarily disabled: depends on TF and TF doesn't run on Python 3.12.
# pytorch/policy_gradient_pytorch_test.py
pytorch/policy_gradient_pytorch_test.py
pytorch/ppo_pytorch_test.py
pytorch/neurd_pytorch_test.py
)
Expand Down
15 changes: 8 additions & 7 deletions open_spiel/python/algorithms/rcfr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import numpy as np
# Note: this import needs to come before Tensorflow to fix a malloc error.
import pyspiel # pylint: disable=g-bad-import-order
import tensorflow.compat.v1 as tf
# import tensorflow.compat.v1 as tf
import tensorflow as tf

from open_spiel.python.algorithms import rcfr

# Temporarily disable TF2 behavior while the code is not updated.
tf.disable_v2_behavior()
#tf.disable_v2_behavior()

tf.enable_eager_execution()
#tf.enable_eager_execution()

_GAME = pyspiel.load_game('kuhn_poker')
_BOOLEANS = [False, True]
Expand All @@ -46,7 +47,7 @@ class RcfrTest(parameterized.TestCase, tf.test.TestCase):

def setUp(self):
super(RcfrTest, self).setUp()
tf.random.set_random_seed(42)
tf.random.set_seed(42)

def test_with_one_hot_action_features_single_state_vector(self):
information_state_features = [1., 2., 3.]
Expand Down Expand Up @@ -476,7 +477,7 @@ def test_rcfr_functions(self):
data = data.batch(12)
data = data.repeat(num_epochs)

optimizer = tf.keras.optimizers.Adam(lr=0.005, amsgrad=True)
optimizer = tf.optimizers.Adam(learning_rate=0.005, amsgrad=True)

for x, y in data:
optimizer.minimize(
Expand Down Expand Up @@ -504,7 +505,7 @@ def _train(model, data):
data = data.batch(12)
data = data.repeat(num_epochs)

optimizer = tf.keras.optimizers.Adam(lr=0.005, amsgrad=True)
optimizer = tf.optimizers.Adam(learning_rate=0.005, amsgrad=True)

for x, y in data:
optimizer.minimize(
Expand Down Expand Up @@ -565,7 +566,7 @@ def _train(model, data):
data = data.batch(12)
data = data.repeat(num_epochs)

optimizer = tf.keras.optimizers.Adam(lr=0.005, amsgrad=True)
optimizer = tf.optimizers.Adam(learning_rate=0.005, amsgrad=True)

for x, y in data:
optimizer.minimize(
Expand Down
1 change: 0 additions & 1 deletion open_spiel/scripts/find_tensorflow.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

read -r -d '' TESTSCRIPT << EOT
import tensorflow as tf
import tensorflow_probability
print(tf.__version__)
EOT

Expand Down

0 comments on commit 62f0baa

Please sign in to comment.