From dcab336a6bd885d007b6ca65902d074b189b9a24 Mon Sep 17 00:00:00 2001 From: Philippe Hamel Date: Mon, 3 Mar 2025 07:26:48 -0800 Subject: [PATCH] Add install instruction for fusion_transport_surrogates (QLKNN_7_11). Fixing a few commands in the README that were not quite right. Also add tests for qlknn_model_wrapper when fusion_transport_surrogates is installed PiperOrigin-RevId: 732915526 --- README.md | 29 ++++++- docs/installation.rst | 35 ++++++++ .../tests/qlknn_model_wrapper.py | 81 +++++++++++++++++++ 3 files changed, 141 insertions(+), 4 deletions(-) create mode 100644 torax/transport_model/tests/qlknn_model_wrapper.py diff --git a/README.md b/README.md index 06b616a5..baf5fcef 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,8 @@ sudo apt-get install python3-tk ### How to install +#### Prepare a virtual environment + Install virtualenv (if not already installed): ```shell @@ -102,7 +104,9 @@ Activate the virtual env: source toraxvenv/bin/activate ``` -Download QLKNN dependencies: +#### Install QLKNN-hyper + +Download QLKNN-hyper dependencies: ```shell git clone https://gitlab.com/qualikiz-group/qlknn-hyper.git @@ -120,6 +124,24 @@ echo export TORAX_QLKNN_MODEL_PATH="$PWD"/qlknn-hyper >> ~/.bashrc ``` The above command only needs to be run once on a given system. +#### (Optional) Install QLKNN_7_11 + +Optionally, you can instead use QLKNN_7_11, a more recent surrogate model: + +```shell +git clone https://github.com/google-deepmind/fusion_transport_surrogates.git +pip install -e ./fusion_transport_surrogates +export TORAX_QLKNN_MODEL_PATH="$PWD"/fusion_transport_surrogates/fusion_transport_surrogates/models/qlknn_7_11.qlknn +``` + +We recommend automating the variable export. If using bash, run: + +```shell +echo export TORAX_QLKNN_MODEL_PATH="$PWD"/fusion_transport_surrogates/fusion_transport_surrogates/models/qlknn_7_11.qlknn >> ~/.bashrc +``` + +#### Install TORAX + Download and install the TORAX codebase via http: ```shell @@ -177,8 +199,7 @@ completed. To run more involved, ITER-inspired simulations, run: ```shell -python3 run_simulation_main.py - --config='torax.examples.iterhybrid_rampup' +python3 run_simulation_main.py --config='torax.examples.iterhybrid_rampup' ``` and @@ -192,7 +213,7 @@ run command, and environment variables. For example, for increased output verbosity, can run with the `--log_progress` flag. ```shell -python3 run_simulation_main.py +python3 run_simulation_main.py \ --config='torax.examples.iterhybrid_rampup' --log_progress ``` diff --git a/docs/installation.rst b/docs/installation.rst index 11582834..612235ed 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -15,9 +15,15 @@ Make sure that tkinter is installed: sudo apt-get install python3-tk .. _how_to_install: + How to install ============== +.. _prepare_virtualenv: + +Prepare a virtual environment +----------------------------- + Install virtualenv (if not already installed): .. code-block:: console @@ -48,6 +54,11 @@ Activate the virtual env: It is convenient to set up an alias for the above command. +.. _install_qlknn_hyper: + +Install QLKNN-hyper +------------------- + Download QLKNN dependencies: .. code-block:: console @@ -66,6 +77,30 @@ It is recommended to automate the environment variable export. For example, if u The above command only needs to be run once on a given system. +.. _install_qlknn_7_11: + +(Optional) Install QLKNN_7_11 +----------------------------- + +Optionally, you can instead use QLKNN_7_11, a more recent surrogate model: + +.. code-block:: console + + git clone https://github.com/google-deepmind/fusion_transport_surrogates.git + pip install -e ./fusion_transport_surrogates + export TORAX_QLKNN_MODEL_PATH="$PWD"/fusion_transport_surrogates/fusion_transport_surrogates/models/qlknn_7_11.qlknn + +We recommend automating the variable export. If using bash, run: + +.. code-block:: console + + echo export TORAX_QLKNN_MODEL_PATH="$PWD"/fusion_transport_surrogates/fusion_transport_surrogates/models/qlknn_7_11.qlknn >> ~/.bashrc + +.. install_torax: + +Install TORAX +------------- + The following may optionally be added to ~/.bashrc and will cause jax to store compiled programs to the filesystem, avoiding recompilation in some cases: diff --git a/torax/transport_model/tests/qlknn_model_wrapper.py b/torax/transport_model/tests/qlknn_model_wrapper.py new file mode 100644 index 00000000..9181ed45 --- /dev/null +++ b/torax/transport_model/tests/qlknn_model_wrapper.py @@ -0,0 +1,81 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for qlknn_model_wrapper.""" + +import tempfile + +from absl.testing import absltest +from absl.testing import parameterized +import jax.numpy as jnp +from torax.transport_model import qlknn_model_wrapper +# pylint: disable=g-import-not-at-top +try: + from fusion_transport_surrogates import qlknn_model_test_utils +except ImportError: + qlknn_model_test_utils = None +# pylint: enable=g-import-not-at-top + + +def _get_test_flux_name_map(): + if qlknn_model_test_utils is None: + return {} + return dict( + (flux_name, f'torax_{flux_name}') + for flux_name in qlknn_model_test_utils.get_test_flux_map().keys() + ) + + +class QlknnModelWrapperTest(parameterized.TestCase): + """Tests for qlknn_model_wrapper.""" + + def setUp(self): + super().setUp() + if qlknn_model_test_utils is None: + self.skipTest('fusion_transport_surrogates is not available.') + # Create a test model on disk to be loaded by the wrapper. + self._config = qlknn_model_test_utils.get_test_model_config() + self._batch_dim = 10 + batch_dims = (1, self._batch_dim) + model = qlknn_model_test_utils.init_model(self._config, batch_dims) + self._model_file = tempfile.NamedTemporaryFile( + 'wb', suffix='.pkl', delete=False + ) + self._flux_name_map = _get_test_flux_name_map() + model.export_model(self._model_file.name) + self._qlknn_model_wrapper = qlknn_model_wrapper.QLKNNModelWrapper( + path=self._model_file.name, + flux_name_map=self._flux_name_map + ) + + def test_predict_shape(self): + """Tests model output shape.""" + inputs = jnp.empty((self._batch_dim, len(self._config.input_names))) + outputs = self._qlknn_model_wrapper.predict(inputs) + self.assertLen(outputs, len(self._flux_name_map)) + for output in outputs.values(): + self.assertEqual(output.shape, (self._batch_dim, 1)) + + def test_predict_names(self): + """Tests model output names are the TORAX flux names.""" + inputs = jnp.empty((self._batch_dim, len(self._config.input_names))) + outputs = self._qlknn_model_wrapper.predict(inputs) + for flux_name in self._flux_name_map.values(): + self.assertIn(flux_name, outputs) + + # TODO(b/381134347): Add tests for get_model_inputs_from_qualikiz_inputs + # and inputs_and_ranges. + + +if __name__ == '__main__': + absltest.main()