From 26ba5a906f68e787d7391a7bbdc9151df37bbd6e Mon Sep 17 00:00:00 2001 From: Jonathan Citrin Date: Fri, 28 Feb 2025 15:05:50 -0800 Subject: [PATCH] Reorganize physics package. Previously, various physics calculations which were not directly associated with an existing component like sources, transport_models, were dispersed between physics.py and post_processing.py. This refactor introduces a physics package and clearer organization of physics functions into thematic modules: collisions, charge_states, psi_calculations, and scaling_laws. Drive-by: cleanup of unnecessary comments and docstring fixes Upcoming PR will move "physics" calculations from post_processing into the physics package. PiperOrigin-RevId: 732272746 --- torax/config/tests/plasma_composition.py | 2 +- torax/core_profiles/formulas.py | 13 +- torax/core_profiles/initialization.py | 12 +- torax/core_profiles/tests/formulas_test.py | 19 + .../tests/initialization_test.py | 58 +- torax/core_profiles/tests/updaters_test.py | 22 +- torax/core_profiles/updaters.py | 11 +- torax/fvm/calc_coeffs.py | 18 +- torax/orchestration/step_function.py | 9 +- .../pedestal_model/set_pped_tpedratio_nped.py | 4 +- torax/physics.py | 712 ------------------ torax/physics/__init__.py | 15 + torax/{ => physics}/charge_states.py | 0 torax/physics/collisions.py | 234 ++++++ torax/physics/psi_calculations.py | 285 +++++++ torax/physics/scaling_laws.py | 243 ++++++ .../tests/charge_states_tests.py} | 5 +- torax/physics/tests/collisions_tests.py | 89 +++ torax/physics/tests/psi_calculations_tests.py | 168 +++++ torax/physics/tests/scaling_laws_tests.py | 216 ++++++ torax/post_processing.py | 27 +- torax/sources/bootstrap_current_source.py | 13 +- torax/sources/bremsstrahlung_heat_sink.py | 2 - .../sources/cyclotron_radiation_heat_sink.py | 2 - torax/sources/formulas.py | 2 - torax/sources/fusion_heat_source.py | 6 +- torax/sources/generic_ion_el_heat_source.py | 2 - torax/sources/ion_cyclotron_source.py | 8 +- torax/sources/ohmic_heat_source.py | 4 +- torax/sources/qei_source.py | 4 +- torax/tests/physics.py | 556 -------------- torax/tests/post_processing.py | 2 +- torax/tests/test_lib/explicit_stepper.py | 8 +- torax/tests/test_lib/torax_refs.py | 88 +++ torax/transport_model/bohm_gyrobohm.py | 2 - torax/transport_model/critical_gradient.py | 15 +- .../qualikiz_based_transport_model.py | 9 +- 37 files changed, 1488 insertions(+), 1397 deletions(-) delete mode 100644 torax/physics.py create mode 100644 torax/physics/__init__.py rename torax/{ => physics}/charge_states.py (100%) create mode 100644 torax/physics/collisions.py create mode 100644 torax/physics/psi_calculations.py create mode 100644 torax/physics/scaling_laws.py rename torax/{tests/charge_states.py => physics/tests/charge_states_tests.py} (98%) create mode 100644 torax/physics/tests/collisions_tests.py create mode 100644 torax/physics/tests/psi_calculations_tests.py create mode 100644 torax/physics/tests/scaling_laws_tests.py delete mode 100644 torax/tests/physics.py diff --git a/torax/config/tests/plasma_composition.py b/torax/config/tests/plasma_composition.py index bdf3fbba..763ab15b 100644 --- a/torax/config/tests/plasma_composition.py +++ b/torax/config/tests/plasma_composition.py @@ -17,10 +17,10 @@ from absl.testing import absltest from absl.testing import parameterized import numpy as np -from torax import charge_states from torax import interpolated_param from torax.config import plasma_composition from torax.geometry import pydantic_model as geometry_pydantic_model +from torax.physics import charge_states class PlasmaCompositionTest(parameterized.TestCase): diff --git a/torax/core_profiles/formulas.py b/torax/core_profiles/formulas.py index f24e885e..8d739cca 100644 --- a/torax/core_profiles/formulas.py +++ b/torax/core_profiles/formulas.py @@ -25,8 +25,6 @@ _trapz = jax.scipy.integrate.trapezoid -# Using capitalized variables for physics notational conventions rather than -# Python style. # pylint: disable=invalid-name @@ -157,3 +155,14 @@ def calculate_psi_grad_constraint_from_Ip_tot( * (16 * jnp.pi**3 * constants.CONSTANTS.mu0 * geo.Phib) / (geo.g2g3_over_rhon_face[-1] * geo.F_face[-1]) ) + + +# TODO(b/377225415): generalize to arbitrary number of ions. +def get_main_ion_dilution_factor( + Zi: array_typing.ScalarFloat, + Zimp: array_typing.ArrayFloat, + Zeff: array_typing.ArrayFloat, +) -> jax.Array: + """Calculates the main ion dilution factor based on a single assumed impurity and general main ion charge.""" + return (Zimp - Zeff) / (Zi * (Zimp - Zi)) + diff --git a/torax/core_profiles/initialization.py b/torax/core_profiles/initialization.py index 91b3f967..43a33b3d 100644 --- a/torax/core_profiles/initialization.py +++ b/torax/core_profiles/initialization.py @@ -20,7 +20,6 @@ from torax import array_typing from torax import constants from torax import math_utils -from torax import physics from torax import state from torax.config import runtime_params_slice from torax.core_profiles import formulas @@ -28,6 +27,7 @@ from torax.fvm import cell_variable from torax.geometry import geometry from torax.geometry import standard_geometry +from torax.physics import psi_calculations from torax.sources import ohmic_heat_source from torax.sources import source_models as source_models_lib from torax.sources import source_profile_builders @@ -35,8 +35,6 @@ _trapz = jax.scipy.integrate.trapezoid -# Using capitalized variables for physics notational conventions rather than -# Python style. # pylint: disable=invalid-name @@ -133,7 +131,7 @@ def initial_core_profiles( ) # Set psi as source of truth and recalculate jtot, q, s - return physics.update_jtot_q_face_s_face( + return psi_calculations.update_jtot_q_face_s_face( geo=geo, core_profiles=core_profiles, ) @@ -147,8 +145,6 @@ def _prescribe_currents( ) -> state.Currents: """Creates the initial Currents from a given bootstrap profile.""" - # Many variables throughout this function are capitalized based on physics - # notational conventions rather than on Google Python style Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot f_bootstrap = bootstrap_profile.I_bootstrap / (Ip * 1e6) @@ -201,7 +197,7 @@ def _calculate_currents_from_psi( source_profiles: source_profiles_lib.SourceProfiles, ) -> state.Currents: """Creates the initial Currents using psi to calculate jtot.""" - jtot, jtot_face, Ip_profile_face = physics.calc_jtot_from_psi( + jtot, jtot_face, Ip_profile_face = psi_calculations.calc_jtot( geo, core_profiles.psi, ) @@ -474,7 +470,7 @@ def _init_psi_psidot_vloop_and_current( currents.jtot_hires, use_vloop_lcfs_boundary_condition=use_vloop_bc, ) - _, _, Ip_profile_face = physics.calc_jtot_from_psi( + _, _, Ip_profile_face = psi_calculations.calc_jtot( geo, psi, ) diff --git a/torax/core_profiles/tests/formulas_test.py b/torax/core_profiles/tests/formulas_test.py index 52a2ebe2..98350294 100644 --- a/torax/core_profiles/tests/formulas_test.py +++ b/torax/core_profiles/tests/formulas_test.py @@ -279,6 +279,25 @@ def test_ne_core_profile_setter_with_fGW( np.all(np.isclose(ratio, ratio[0])) self.assertNotEqual(ratio[0], 1.0) + # TODO(b/377225415): generalize to arbitrary number of ions. + # pylint: disable=invalid-name + @parameterized.parameters([ + dict(Zi=1.0, Zimp=10.0, Zeff=1.0, expected=1.0), + dict(Zi=1.0, Zimp=5.0, Zeff=1.0, expected=1.0), + dict(Zi=2.0, Zimp=10.0, Zeff=2.0, expected=0.5), + dict(Zi=2.0, Zimp=5.0, Zeff=2.0, expected=0.5), + dict(Zi=1.0, Zimp=10.0, Zeff=1.9, expected=0.9), + dict(Zi=2.0, Zimp=10.0, Zeff=3.6, expected=0.4), + ]) + def test_get_main_ion_dilution_factor(self, Zi, Zimp, Zeff, expected): + """Unit test of `get_main_ion_dilution_factor`.""" + np.testing.assert_allclose( + formulas.get_main_ion_dilution_factor(Zi, Zimp, Zeff), + expected, + ) + + # pylint: enable=invalid-name + if __name__ == '__main__': absltest.main() diff --git a/torax/core_profiles/tests/initialization_test.py b/torax/core_profiles/tests/initialization_test.py index 478176db..a47b1786 100644 --- a/torax/core_profiles/tests/initialization_test.py +++ b/torax/core_profiles/tests/initialization_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable from unittest import mock from absl.testing import absltest @@ -24,7 +23,6 @@ from torax.config import runtime_params_slice as runtime_params_slice_lib from torax.core_profiles import initialization from torax.geometry import pydantic_model as geometry_pydantic_model -from torax.geometry import standard_geometry from torax.sources import generic_current_source from torax.sources import source_models as source_models_lib from torax.sources import source_profiles @@ -40,18 +38,9 @@ def setUp(self): jax_utils.enable_errors(True) self.geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry() - @parameterized.parameters([ - dict(references_getter=torax_refs.circular_references), - dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict( - references_getter=torax_refs.chease_references_Ip_from_runtime_params - ), - ]) - def test_update_psi_from_j( - self, references_getter: Callable[[], torax_refs.References] - ): + def test_update_psi_from_j(self): """Compare `update_psi_from_j` function to a reference implementation.""" - references = references_getter() + references = torax_refs.circular_references() runtime_params = references.runtime_params source_runtime_params = generic_current_source.RuntimeParams() @@ -67,29 +56,26 @@ def test_update_psi_from_j( }, ) ) - if isinstance(geo, standard_geometry.StandardGeometry): - psi = geo.psi_from_Ip - else: - bootstrap = source_profiles.BootstrapCurrentProfile.zero_profile(geo) - external_current = generic_current_source.calculate_generic_current( - mock.ANY, - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - geo=geo, - source_name=generic_current_source.GenericCurrentSource.SOURCE_NAME, - unused_state=mock.ANY, - unused_calculated_source_profiles=mock.ANY, - )[0] - currents = initialization._prescribe_currents( - bootstrap_profile=bootstrap, - external_current=external_current, - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - geo=geo, - ) - psi = initialization._update_psi_from_j( - dynamic_runtime_params_slice.profile_conditions.Ip_tot, - geo, - currents.jtot_hires, - ).value + bootstrap = source_profiles.BootstrapCurrentProfile.zero_profile(geo) + external_current = generic_current_source.calculate_generic_current( + mock.ANY, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + geo=geo, + source_name=generic_current_source.GenericCurrentSource.SOURCE_NAME, + unused_state=mock.ANY, + unused_calculated_source_profiles=mock.ANY, + )[0] + currents = initialization._prescribe_currents( + bootstrap_profile=bootstrap, + external_current=external_current, + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + geo=geo, + ) + psi = initialization._update_psi_from_j( + dynamic_runtime_params_slice.profile_conditions.Ip_tot, + geo, + currents.jtot_hires, + ).value np.testing.assert_allclose(psi, references.psi.value) @parameterized.parameters( diff --git a/torax/core_profiles/tests/updaters_test.py b/torax/core_profiles/tests/updaters_test.py index f0102f90..29457e92 100644 --- a/torax/core_profiles/tests/updaters_test.py +++ b/torax/core_profiles/tests/updaters_test.py @@ -18,7 +18,6 @@ from jax import numpy as jnp import numpy as np from torax import jax_utils -from torax import physics from torax import state from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params @@ -91,7 +90,7 @@ def test_get_ion_density_and_charge_states(self): Zeff = dynamic_runtime_params_slice.plasma_composition.Zeff - dilution_factor = physics.get_main_ion_dilution_factor(Zi, Zimp, Zeff) + dilution_factor = formulas.get_main_ion_dilution_factor(Zi, Zimp, Zeff) np.testing.assert_allclose( ni.value, expected_value * dilution_factor, @@ -355,6 +354,25 @@ def test_compute_boundary_conditions_Ti( expected_Ti_bound_right, ) + # TODO(b/377225415): generalize to arbitrary number of ions. + # pylint: disable=invalid-name + @parameterized.parameters([ + dict(Zi=1.0, Zimp=10.0, Zeff=1.0, expected=1.0), + dict(Zi=1.0, Zimp=5.0, Zeff=1.0, expected=1.0), + dict(Zi=2.0, Zimp=10.0, Zeff=2.0, expected=0.5), + dict(Zi=2.0, Zimp=5.0, Zeff=2.0, expected=0.5), + dict(Zi=1.0, Zimp=10.0, Zeff=1.9, expected=0.9), + dict(Zi=2.0, Zimp=10.0, Zeff=3.6, expected=0.4), + ]) + def test_get_main_ion_dilution_factor(self, Zi, Zimp, Zeff, expected): + """Unit test of `get_main_ion_dilution_factor`.""" + np.testing.assert_allclose( + formulas.get_main_ion_dilution_factor(Zi, Zimp, Zeff), + expected, + ) + + # pylint: enable=invalid-name + if __name__ == '__main__': absltest.main() diff --git a/torax/core_profiles/updaters.py b/torax/core_profiles/updaters.py index d22e2688..f5850d63 100644 --- a/torax/core_profiles/updaters.py +++ b/torax/core_profiles/updaters.py @@ -22,20 +22,17 @@ import jax from jax import numpy as jnp from torax import array_typing -from torax import charge_states from torax import jax_utils -from torax import physics from torax import state from torax.config import runtime_params_slice from torax.core_profiles import formulas from torax.fvm import cell_variable from torax.geometry import geometry +from torax.physics import charge_states _trapz = jax.scipy.integrate.trapezoid -# Using capitalized variables for physics notational conventions rather than -# Python style. # pylint: disable=invalid-name def _get_charge_states( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, @@ -128,8 +125,8 @@ def get_ion_density_and_charge_states( Zeff = dynamic_runtime_params_slice.plasma_composition.Zeff Zeff_face = dynamic_runtime_params_slice.plasma_composition.Zeff_face - dilution_factor = physics.get_main_ion_dilution_factor(Zi, Zimp, Zeff) - dilution_factor_edge = physics.get_main_ion_dilution_factor( + dilution_factor = formulas.get_main_ion_dilution_factor(Zi, Zimp, Zeff) + dilution_factor_edge = formulas.get_main_ion_dilution_factor( Zi_face[-1], Zimp_face[-1], Zeff_face[-1] ) @@ -347,7 +344,7 @@ def compute_boundary_conditions_for_t_plus_dt( Te=Te_bound_right, ) - dilution_factor_edge = physics.get_main_ion_dilution_factor( + dilution_factor_edge = formulas.get_main_ion_dilution_factor( Zi_edge, Zimp_edge, dynamic_runtime_params_slice_t_plus_dt.plasma_composition.Zeff_face[-1], diff --git a/torax/fvm/calc_coeffs.py b/torax/fvm/calc_coeffs.py index 14a36f81..07bf9b9f 100644 --- a/torax/fvm/calc_coeffs.py +++ b/torax/fvm/calc_coeffs.py @@ -23,7 +23,6 @@ import jax.numpy as jnp from torax import constants from torax import jax_utils -from torax import physics from torax import state from torax.config import config_args from torax.config import runtime_params_slice @@ -400,7 +399,7 @@ def _calc_coeffs_full( # Boolean mask for enforcing internal temperature boundary conditions to # model the pedestal. - mask = physics.internal_boundary( + mask = _internal_boundary( geo, pedestal_model_output.rho_norm_ped_top, dynamic_runtime_params_slice.profile_conditions.set_pedestal, @@ -784,3 +783,18 @@ def _calc_coeffs_reduced( transient_in_cell=transient_in_cell, ) return coeffs + + +# pylint: disable=invalid-name +def _internal_boundary( + geo: geometry.Geometry, + Ped_top: jax.Array, + set_pedestal: jax.Array, +) -> jax.Array: + # Create Boolean mask FiPy CellVariable with True where the internal boundary + # condition is + # find index closest to pedestal top. + idx = jnp.abs(geo.rho_norm - Ped_top).argmin() + mask_np = jnp.zeros(len(geo.rho), dtype=bool) + mask_np = jnp.where(set_pedestal, mask_np.at[idx].set(True), mask_np) + return mask_np diff --git a/torax/orchestration/step_function.py b/torax/orchestration/step_function.py index 7f7f0eac..98c850dd 100644 --- a/torax/orchestration/step_function.py +++ b/torax/orchestration/step_function.py @@ -22,7 +22,6 @@ import jax import jax.numpy as jnp from torax import jax_utils -from torax import physics from torax import post_processing from torax import state from torax.config import runtime_params_slice @@ -30,6 +29,7 @@ from torax.geometry import geometry from torax.geometry import geometry_provider as geometry_provider_lib from torax.pedestal_model import pedestal_model as pedestal_model_lib +from torax.physics import psi_calculations from torax.sources import ohmic_heat_source from torax.sources import source_profile_builders from torax.sources import source_profiles as source_profiles_lib @@ -525,7 +525,7 @@ def finalize_output( """ # Update total current, q, and s profiles based on new psi - output_state.core_profiles = physics.update_jtot_q_face_s_face( + output_state.core_profiles = psi_calculations.update_jtot_q_face_s_face( geo=geo_t_plus_dt, core_profiles=output_state.core_profiles, ) @@ -660,7 +660,8 @@ def _update_current_distribution( bootstrap_profile = core_sources.j_bootstrap # Needed for the case where no psi sources are present. external_current = jnp.zeros_like( - core_profiles.currents.external_current_source) + core_profiles.currents.external_current_source + ) external_current += sum(core_sources.psi.values()) johm = ( @@ -699,7 +700,7 @@ def _update_psidot( resistivity_multiplier=dynamic_runtime_params_slice.numerics.resistivity_mult, psi=core_profiles.psi, geo=geo, - ) + ), ) new_core_profiles = dataclasses.replace( diff --git a/torax/pedestal_model/set_pped_tpedratio_nped.py b/torax/pedestal_model/set_pped_tpedratio_nped.py index 6f2dbf1c..6d1ddce2 100644 --- a/torax/pedestal_model/set_pped_tpedratio_nped.py +++ b/torax/pedestal_model/set_pped_tpedratio_nped.py @@ -23,9 +23,9 @@ from torax import array_typing from torax import constants from torax import interpolated_param -from torax import physics from torax import state from torax.config import runtime_params_slice +from torax.core_profiles import formulas from torax.geometry import geometry from torax.pedestal_model import pedestal_model from torax.pedestal_model import runtime_params as runtime_params_lib @@ -130,7 +130,7 @@ def _call_implementation( Zeff_ped = Zeff[ped_idx] Zi_ped = Zi[ped_idx] Zimp_ped = Zimp[ped_idx] - dilution_factor_ped = physics.get_main_ion_dilution_factor( + dilution_factor_ped = formulas.get_main_ion_dilution_factor( Zi_ped, Zimp_ped, Zeff_ped ) # Calculate ni and nimp. diff --git a/torax/physics.py b/torax/physics.py deleted file mode 100644 index 55951538..00000000 --- a/torax/physics.py +++ /dev/null @@ -1,712 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Physics calculations. - -This module contains problem-specific calculations that set up e.g. -coefficients on terms in a differential equation, as opposed to more -general differential equation solver functionality. -""" -import dataclasses - -import chex -import jax -from jax import numpy as jnp -from torax import array_typing -from torax import constants -from torax import jax_utils -from torax import state -from torax.fvm import cell_variable -from torax.geometry import geometry - -_trapz = jax.scipy.integrate.trapezoid - -# Many variable names in this file use scientific or mathematical notation, so -# disable pylint complaints. -# pylint: disable=invalid-name - - -# TODO(b/377225415): generalize to arbitrary number of ions. -def get_main_ion_dilution_factor( - Zi: array_typing.ScalarFloat, - Zimp: array_typing.ArrayFloat, - Zeff: array_typing.ArrayFloat, -) -> jax.Array: - """Calculates the main ion dilution factor based on a single assumed impurity and general main ion charge.""" - return (Zimp - Zeff) / (Zi * (Zimp - Zi)) - - -@jax_utils.jit -def update_jtot_q_face_s_face( - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, -) -> state.CoreProfiles: - """Updates jtot, jtot_face, q_face, and s_face.""" - - jtot, jtot_face, Ip_profile_face = calc_jtot_from_psi( - geo, - core_profiles.psi, - ) - q_face, _ = calc_q_from_psi( - geo=geo, - psi=core_profiles.psi, - ) - s_face = calc_s_from_psi( - geo, - core_profiles.psi, - ) - currents = dataclasses.replace( - core_profiles.currents, - jtot=jtot, - jtot_face=jtot_face, - Ip_profile_face=Ip_profile_face, - ) - new_core_profiles = dataclasses.replace( - core_profiles, - currents=currents, - q_face=q_face, - s_face=s_face, - ) - return new_core_profiles - - -def coll_exchange( - core_profiles: state.CoreProfiles, - nref: float, - Qei_mult: float, -) -> jax.Array: - """Computes collisional ion-electron heat exchange coefficient. - - Args: - core_profiles: Core plasma profiles. - nref: Reference value for normalization - Qei_mult: multiplier for ion-electron heat exchange term - - Returns: - Qei_coeff: ion-electron collisional heat exchange coefficient. - """ - # Calculate Coulomb logarithm - lambda_ei = _calculate_lambda_ei( - core_profiles.temp_el.value, core_profiles.ne.value * nref - ) - # ion-electron collisionality for Zeff=1. Ion charge and multiple ion effects - # are included in the Qei_coef calculation below. - log_tau_e_Z1 = _calculate_log_tau_e_Z1( - core_profiles.temp_el.value, - core_profiles.ne.value * nref, - lambda_ei, - ) - # pylint: disable=invalid-name - - weighted_Zeff = _calculate_weighted_Zeff(core_profiles) - - log_Qei_coef = ( - jnp.log(Qei_mult * 1.5 * core_profiles.ne.value * nref) - + jnp.log(constants.CONSTANTS.keV2J / constants.CONSTANTS.mp) - + jnp.log(2 * constants.CONSTANTS.me) - + jnp.log(weighted_Zeff) - - log_tau_e_Z1 - ) - Qei_coef = jnp.exp(log_Qei_coef) - return Qei_coef - - -# TODO(b/377225415): generalize to arbitrary number of ions. -def _calculate_weighted_Zeff( - core_profiles: state.CoreProfiles, -) -> jax.Array: - """Calculates ion mass weighted Zeff. Used for collisional heat exchange.""" - return ( - core_profiles.ni.value * core_profiles.Zi**2 / core_profiles.Ai - + core_profiles.nimp.value * core_profiles.Zimp**2 / core_profiles.Aimp - ) / core_profiles.ne.value - - -def _calculate_log_tau_e_Z1( - temp_el: jax.Array, - ne: jax.Array, - lambda_ei: jax.Array, -) -> jax.Array: - """Calculates log of electron-ion collision time for Z=1 plasma. - - See Wesson 3rd edition p729. Extension to multiple ions is context dependent - and implemented in calling functions. - - Args: - temp_el: Electron temperature in keV. - ne: Electron density in m^-3. - lambda_ei: Coulomb logarithm. - - Returns: - Log of electron-ion collision time. - """ - return ( - jnp.log(12 * jnp.pi**1.5 / (ne * lambda_ei)) - - 4 * jnp.log(constants.CONSTANTS.qe) - + 0.5 * jnp.log(constants.CONSTANTS.me / 2.0) - + 2 * jnp.log(constants.CONSTANTS.epsilon0) - + 1.5 * jnp.log(temp_el * constants.CONSTANTS.keV2J) - ) - - -def internal_boundary( - geo: geometry.Geometry, - Ped_top: jax.Array, - set_pedestal: jax.Array, -) -> jax.Array: - # Create Boolean mask FiPy CellVariable with True where the internal boundary - # condition is - # find index closest to pedestal top. - idx = jnp.abs(geo.rho_norm - Ped_top).argmin() - mask_np = jnp.zeros(len(geo.rho), dtype=bool) - mask_np = jnp.where(set_pedestal, mask_np.at[idx].set(True), mask_np) - return mask_np - - -def calc_q_from_psi( - geo: geometry.Geometry, - psi: cell_variable.CellVariable, -) -> tuple[chex.Array, chex.Array]: - """Calculates the q-profile (q) given current (jtot) and poloidal flux (psi). - - We don't simply pass a `CoreProfiles` instance because this needs to be called - before the first `CoreProfiles` is constructed; the output of this function is - used to populate the `q_face` field of the first `CoreProfiles`. - - Args: - geo: Magnetic geometry. - psi: Poloidal flux. - - Returns: - q_face: q at faces. - q: q at cell centers. - """ - # We calculate iota on the face grid but omit face 0, so inv_iota[0] - # corresponds to face 1. - # iota on face 0 is unused in this function, and would need to be implemented - # as a special case. - inv_iota = jnp.abs( - (2 * geo.Phib * geo.rho_face_norm[1:]) / psi.face_grad()[1:] - ) - - # Use L'Hôpital's rule to calculate iota on-axis, with psi_face_grad()[0]=0. - inv_iota0 = jnp.expand_dims( - jnp.abs((2 * geo.Phib * geo.drho_norm) / psi.face_grad()[1]), 0 - ) - - q_face = jnp.concatenate([inv_iota0, inv_iota]) - q_face *= geo.q_correction_factor - q = geometry.face_to_cell(q_face) - - return q_face, q - - -def calc_jtot_from_psi( - geo: geometry.Geometry, - psi: cell_variable.CellVariable, -) -> tuple[chex.Array, chex.Array, chex.Array]: - """Calculates FSA toroidal current density (jtot) from poloidal flux (psi). - - Calculation based on jtot = dI/dS - - Args: - geo: Torus geometry. - psi: Poloidal flux. - - Returns: - jtot: total current density [A/m2] on cell grid - jtot_face: total current density [A/m2] on face grid - Ip_profile_face: cumulative total plasma current profile [A] on face grid - """ - - # inside flux surface on face grid - # pylint: disable=invalid-name - Ip_profile_face = ( - psi.face_grad() - * geo.g2g3_over_rhon_face - * geo.F_face - / geo.Phib - / (16 * jnp.pi**3 * constants.CONSTANTS.mu0) - ) - - dI_tot_drhon = jnp.gradient(Ip_profile_face, geo.rho_face_norm) - - jtot_face_bulk = dI_tot_drhon[1:] / geo.spr_face[1:] - - # Set on-axis jtot according to L'Hôpital's rule, noting that I[0]=S[0]=0. - jtot_face_axis = Ip_profile_face[1] / geo.area_face[1] - - jtot_face = jnp.concatenate([jnp.array([jtot_face_axis]), jtot_face_bulk]) - jtot = geometry.face_to_cell(jtot_face) - - return jtot, jtot_face, Ip_profile_face - - -def calc_s_from_psi( - geo: geometry.Geometry, psi: cell_variable.CellVariable -) -> jax.Array: - """Calculates magnetic shear (s) from poloidal flux (psi). - - Args: - geo: Torus geometry. - psi: Poloidal flux. - - Returns: - s_face: Magnetic shear, on the face grid. - """ - - # iota (1/q) should have a /2*Phib but we drop it since will cancel out in - # the s calculation. - iota_scaled = jnp.abs((psi.face_grad()[1:] / geo.rho_face_norm[1:])) - - # on-axis iota_scaled from L'Hôpital's rule = dpsi_face_grad / drho_norm - # Using expand_dims to make it compatible with jnp.concatenate - iota_scaled0 = jnp.expand_dims( - jnp.abs(psi.face_grad()[1] / geo.drho_norm), axis=0 - ) - - iota_scaled = jnp.concatenate([iota_scaled0, iota_scaled]) - - s_face = ( - -geo.rho_face_norm - * jnp.gradient(iota_scaled, geo.rho_face_norm) - / iota_scaled - ) - - return s_face - - -def calc_s_from_psi_rmid( - geo: geometry.Geometry, psi: cell_variable.CellVariable -) -> jax.Array: - """Calculates magnetic shear (s) from poloidal flux (psi). - - Version taking the derivative of iota with respect to the midplane r, - in line with expectations from circular-derived models like QuaLiKiz. - - Args: - geo: Torus geometry. - psi: Poloidal flux. - - Returns: - s_face: Magnetic shear, on the face grid. - """ - - # iota (1/q) should have a /2*Phib but we drop it since will cancel out in - # the s calculation. - iota_scaled = jnp.abs((psi.face_grad()[1:] / geo.rho_face_norm[1:])) - - # on-axis iota_scaled from L'Hôpital's rule = dpsi_face_grad / drho_norm - # Using expand_dims to make it compatible with jnp.concatenate - iota_scaled0 = jnp.expand_dims( - jnp.abs(psi.face_grad()[1] / geo.drho_norm), axis=0 - ) - - iota_scaled = jnp.concatenate([iota_scaled0, iota_scaled]) - - rmid_face = (geo.Rout_face - geo.Rin_face) * 0.5 - - s_face = -rmid_face * jnp.gradient(iota_scaled, rmid_face) / iota_scaled - - return s_face - - -def _calc_bpol2( - geo: geometry.Geometry, psi: cell_variable.CellVariable -) -> jax.Array: - r"""Calculates square of poloidal field (Bp) from poloidal flux (psi). - - An identity for the poloidal magnetic field is: - B_p = 1/R \partial \psi / \partial \rho (\nabla \rho \times e_phi) - - Where e_phi is the unit vector pointing in the toroidal direction. - - Args: - geo: Torus geometry. - psi: Poloidal flux. - - Returns: - bpol2_face: Square of poloidal magnetic field, on the face grid. - """ - bpol2_bulk = ( - (psi.face_grad()[1:] / (2 * jnp.pi)) ** 2 - * geo.g2_face[1:] - / geo.vpr_face[1:] ** 2 - ) - bpol2_axis = jnp.array([0.0]) - bpol2_face = jnp.concatenate([bpol2_axis, bpol2_bulk]) - return bpol2_face - - -def calc_Wpol( - geo: geometry.Geometry, psi: cell_variable.CellVariable -) -> jax.Array: - """Calculates total magnetic energy (Wpol) from poloidal flux (psi).""" - bpol2 = _calc_bpol2(geo, psi) - Wpol = _trapz(bpol2 * geo.vpr_face, geo.rho_face_norm) / ( - 2 * constants.CONSTANTS.mu0 - ) - return Wpol - - -def calc_li3( - Rmaj: jax.Array, - Wpol: jax.Array, - Ip_total: jax.Array, -) -> jax.Array: - """Calculates li3 based on a formulation using Wpol. - - Normalized internal inductance is defined as: - li = _V / _LCFS where <>_V is a volume average and <>_LCFS is - the average at the last closed flux surface. - - We use the ITER convention for normalized internal inductance defined as: - li3 = 2*V*_V / (mu0^2 Ip^2*Rmaj) = 4 * Wpol / (mu0 Ip^2*Rmaj) - - Ip (total plasma current) enters through the integral form of Ampere's law. - Since Wpol also corresponds to a volume integral of the poloidal field, we - can define li3 with respect to Wpol. - - Args: - Rmaj: Major radius. - Wpol: Total magnetic energy. - Ip_total: Total plasma current. - - Returns: - li3: Normalized internal inductance, ITER convention. - """ - return 4 * Wpol / (constants.CONSTANTS.mu0 * Ip_total**2 * Rmaj) - - -def calc_nu_star( - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - nref: float, - Zeff_face: jax.Array, - coll_mult: float, -) -> jax.Array: - """Calculates nu star. - - Args: - geo: Torus geometry. - core_profiles: Core plasma profiles. - nref: Reference value for normalization - Zeff_face: Effective ion charge on face grid. - coll_mult: Collisionality multiplier in QLKNN for sensitivity testing. - - Returns: - nu_star: on face grid. - """ - - # Calculate Coulomb logarithm - lambda_ei_face = _calculate_lambda_ei( - core_profiles.temp_el.face_value(), core_profiles.ne.face_value() * nref - ) - - # ion_electron collisionality - log_tau_e_Z1 = _calculate_log_tau_e_Z1( - core_profiles.temp_el.face_value(), - core_profiles.ne.face_value() * nref, - lambda_ei_face, - ) - - nu_e = 1 / jnp.exp(log_tau_e_Z1) * Zeff_face * coll_mult - - # calculate bounce time - epsilon = geo.rho_face / geo.Rmaj - # to avoid divisions by zero - epsilon = jnp.clip(epsilon, constants.CONSTANTS.eps) - tau_bounce = ( - core_profiles.q_face - * geo.Rmaj - / ( - epsilon**1.5 - * jnp.sqrt( - core_profiles.temp_el.face_value() - * constants.CONSTANTS.keV2J - / constants.CONSTANTS.me - ) - ) - ) - # due to pathological on-axis epsilon=0 term - tau_bounce = tau_bounce.at[0].set(tau_bounce[1]) - - # calculate normalized collisionality - nustar = nu_e * tau_bounce - - return nustar - - -def _calculate_lambda_ei( - temp_el: jax.Array, - ne: jax.Array, -) -> jax.Array: - """Calculates Coulomb logarithm for electron-ion collisions. - - See Wesson 3rd edition p727. - - Args: - temp_el: Electron temperature in keV. - ne: Electron density in m^-3. - - Returns: - Coulomb logarithm. - """ - return 15.2 - 0.5 * jnp.log(ne / 1e20) + jnp.log(temp_el) - - -def fast_ion_fractional_heating_formula( - birth_energy: float | array_typing.ArrayFloat, - temp_el: array_typing.ArrayFloat, - fast_ion_mass: float, -) -> array_typing.ArrayFloat: - """Returns the fraction of heating that goes to the ions. - - From eq. 5 and eq. 26 in Mikkelsen Nucl. Tech. Fusion 237 4 1983. - Note there is a typo in eq. 26 where a `2x` term is missing in the numerator - of the log. - - Args: - birth_energy: Birth energy of the fast ions in keV. - temp_el: Electron temperature. - fast_ion_mass: Mass of the fast ions in amu. - - Returns: - The fraction of heating that goes to the ions. - """ - critical_energy = 10 * fast_ion_mass * temp_el # Eq. 5. - energy_ratio = birth_energy / critical_energy - - # Eq. 26. - x_squared = energy_ratio - x = jnp.sqrt(x_squared) - frac_i = ( - 2 - * ( - (1 / 6) * jnp.log((1.0 - x + x_squared) / (1.0 + 2.0 * x + x_squared)) - + (jnp.arctan((2.0 * x - 1.0) / jnp.sqrt(3)) + jnp.pi / 6) - / jnp.sqrt(3) - ) - / x_squared - ) - return frac_i - - -def calculate_plh_scaling_factor( - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, -) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: - """Calculates the H-mode transition power scalings. - - See Y.R. Martin and Tomonori Takizuka. - "Power requirement for accessing the H-mode in ITER." - Journal of Physics: Conference Series. Vol. 123. No. 1. IOP Publishing, 2008. - - Only valid for hydrogenic isotopes and mixtures (H, D, T). - Includes a simple inverse scaling of the factor to average isotope mass. - - For an overview see U Plank, U., et al. "Overview of L-to H-mode transition - experiments at ASDEX Upgrade." - Plasma Physics and Controlled Fusion 65.1 (2022): 014001. - - Args: - geo: Torus geometry. - core_profiles: Core plasma profiles. - - Returns: - Tuple of: P_LH scaling factor for high density branch, minimum P_LH, - P_LH = max(P_LH_min, P_LH_hi_dens) for practical use, and the density - corresponding to the P_LH_min. - """ - - line_avg_ne = _calculate_line_avg_density(geo, core_profiles) - # LH transition power for deuterium, in W. Eq 3 from Martin 2008. - P_LH_hi_dens_D = ( - 2.15 - * (line_avg_ne / 1e20) ** 0.782 - * geo.B0**0.772 - * geo.Rmin**0.975 - * geo.Rmaj**0.999 - * 1e6 - ) - - # Scale to average isotope mass. - A_deuterium = constants.ION_PROPERTIES_DICT['D']['A'] - P_LH_hi_dens = P_LH_hi_dens_D * A_deuterium / core_profiles.Ai - - # Calculate density (in nref) corresponding to P_LH_min from Eq 3 Ryter 2014 - ne_min_P_LH = ( - 0.7 - * (core_profiles.currents.Ip_profile_face[-1] / 1e6) ** 0.34 - * geo.Rmin**-0.95 - * geo.B0**0.62 - * (geo.Rmaj / geo.Rmin) ** 0.4 - * 1e19 - / core_profiles.nref - ) - # Calculate P_LH_min at ne_min from Eq 4 Ryter 2014 - P_LH_min_D = ( - 0.36 - * (core_profiles.currents.Ip_profile_face[-1] / 1e6) ** 0.27 - * geo.B0**1.25 - * geo.Rmaj**1.23 - * (geo.Rmaj / geo.Rmin) ** 0.08 - * 1e6 - ) - P_LH_min = P_LH_min_D * A_deuterium / core_profiles.Ai - P_LH = jnp.maximum(P_LH_min, P_LH_hi_dens) - return P_LH_hi_dens, P_LH_min, P_LH, ne_min_P_LH - - -def calculate_scaling_law_confinement_time( - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - Ploss: jax.Array, - scaling_law: str, -) -> jax.Array: - """Calculates the thermal energy confinement time for a given empirical scaling law. - - Args: - geo: Torus geometry. - core_profiles: Core plasma profiles. - Ploss: Plasma power loss in W. - scaling_law: Scaling law to use. - - Returns: - Thermal energy confinement time in s. - """ - scaling_params = { - 'H89P': { - # From Yushmanov et al, Nuclear Fusion, vol. 30, no. 10, pp. 4-6, 1990 - 'prefactor': 0.038128, - 'Ip_exponent': 0.85, - 'B_exponent': 0.2, - 'line_avg_ne_exponent': 0.1, - 'Ploss_exponent': -0.5, - 'R_exponent': 1.5, - 'inverse_aspect_ratio_exponent': 0.3, - 'elongation_exponent': 0.5, - 'effective_mass_exponent': 0.50, - 'triangularity_exponent': 0.0, - }, - 'H98': { - # H98 empirical confinement scaling law: - # ITER Physics Expert Groups on Confinement and Transport and - # Confinement Modelling and Database, Nucl. Fusion 39 2175, 1999 - # Doyle et al, Nucl. Fusion 47 (2007) S18–S127, Eq 30 - 'prefactor': 0.0562, - 'Ip_exponent': 0.93, - 'B_exponent': 0.15, - 'line_avg_ne_exponent': 0.41, - 'Ploss_exponent': -0.69, - 'R_exponent': 1.97, - 'inverse_aspect_ratio_exponent': 0.58, - 'elongation_exponent': 0.78, - 'effective_mass_exponent': 0.19, - 'triangularity_exponent': 0.0, - }, - 'H97L': { - # From the ITER L-mode confinement database. - # S.M. Kaye et al 1997 Nucl. Fusion 37 1303, Eq 7 - 'prefactor': 0.023, - 'Ip_exponent': 0.96, - 'B_exponent': 0.03, - 'line_avg_ne_exponent': 0.4, - 'Ploss_exponent': -0.73, - 'R_exponent': 1.83, - 'inverse_aspect_ratio_exponent': -0.06, - 'elongation_exponent': 0.64, - 'effective_mass_exponent': 0.20, - 'triangularity_exponent': 0.0, - }, - 'H20': { - # Updated ITER H-mode confinement database, using full dataset. - # G. Verdoolaege et al 2021 Nucl. Fusion 61 076006, Eq 7 - 'prefactor': 0.053, - 'Ip_exponent': 0.98, - 'B_exponent': 0.22, - 'line_avg_ne_exponent': 0.24, - 'Ploss_exponent': -0.669, - 'R_exponent': 1.71, - 'inverse_aspect_ratio_exponent': 0.35, - 'elongation_exponent': 0.80, - 'effective_mass_exponent': 0.20, - 'triangularity_exponent': 0.36, # (1+delta)^exponent - }, - } - - if scaling_law not in scaling_params: - raise ValueError(f'Unknown scaling law: {scaling_law}') - - params = scaling_params[scaling_law] - - scaled_Ip = core_profiles.currents.Ip_profile_face[-1] / 1e6 # convert to MA - scaled_Ploss = Ploss / 1e6 # convert to MW - B = geo.B0 - line_avg_ne = _calculate_line_avg_density(geo, core_profiles) / 1e19 - R = geo.Rmaj - inverse_aspect_ratio = geo.Rmin / geo.Rmaj - - # Effective elongation definition. This is a different definition than - # the standard definition used in geo.elongation. - elongation = geo.area_face[-1] / (jnp.pi * geo.Rmin**2) - # TODO(b/317360834): extend when multiple ions are supported. - effective_mass = core_profiles.Ai - triangularity = geo.delta_face[-1] - - tau_scaling = ( - params['prefactor'] - * scaled_Ip ** params['Ip_exponent'] - * B ** params['B_exponent'] - * line_avg_ne ** params['line_avg_ne_exponent'] - * scaled_Ploss ** params['Ploss_exponent'] - * R ** params['R_exponent'] - * inverse_aspect_ratio ** params['inverse_aspect_ratio_exponent'] - * elongation ** params['elongation_exponent'] - * effective_mass ** params['effective_mass_exponent'] - * (1 + triangularity) ** params['triangularity_exponent'] - ) - return tau_scaling - - -def _calculate_line_avg_density( - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, -) -> jax.Array: - """Calculates line-averaged electron density. - - Line-averaged electron density is poorly defined. In general, the definition - is machine-dependent and even shot-dependent since it depends on the usage of - a specific interferometry chord. Furthermore, even if we knew the specific - chord used, its calculation would depend on magnetic geometry information - beyond what is available in StandardGeometry. In lieu of a better solution, we - use line-averaged electron density defined on the outer midplane. - - Args: - geo: Torus geometry. - core_profiles: Core plasma profiles. - - Returns: - Line-averaged electron density. - """ - Rmin_out = geo.Rout_face[-1] - geo.Rout_face[0] - line_avg_ne = ( - core_profiles.nref - * _trapz(core_profiles.ne.face_value(), geo.Rout_face) - / Rmin_out - ) - return line_avg_ne - - -# pylint: enable=invalid-name diff --git a/torax/physics/__init__.py b/torax/physics/__init__.py new file mode 100644 index 00000000..b172883d --- /dev/null +++ b/torax/physics/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This package contains functionality related to calculating physics quantities.""" diff --git a/torax/charge_states.py b/torax/physics/charge_states.py similarity index 100% rename from torax/charge_states.py rename to torax/physics/charge_states.py diff --git a/torax/physics/collisions.py b/torax/physics/collisions.py new file mode 100644 index 00000000..0ad59b5e --- /dev/null +++ b/torax/physics/collisions.py @@ -0,0 +1,234 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Physics calculations related to collisional quantities. + +Functions: + - coll_exchange: Computes the collisional ion-electron heat exchange + coefficient (equipartion). + - calc_nu_star: Calculates the nu_star parameter: the electron-ion collision + frequency normalized by bounce frequency. + - fast_ion_fractional_heating_formula: Returns the fraction of heating that + goes to the ions according to Stix 1975 analyticlal formulas. + - _calculate_lambda_ei: Calculates the Coulomb logarithm for electron-ion + collisions. + - _calculate_weighted_Zeff: Calculates ion mass weighted Zeff used in + the equipartion calculation. + - _calculate_log_tau_e_Z1: Calculates log of electron-ion collision time for + Z=1 plasma. +""" + +import jax +from jax import numpy as jnp +from torax import array_typing +from torax import constants +from torax import state +from torax.geometry import geometry + +# pylint: disable=invalid-name + + +def coll_exchange( + core_profiles: state.CoreProfiles, + nref: float, + Qei_mult: float, +) -> jax.Array: + """Computes collisional ion-electron heat exchange coefficient (equipartion). + + Args: + core_profiles: Core plasma profiles. + nref: Reference value for normalization + Qei_mult: multiplier for ion-electron heat exchange term + + Returns: + Qei_coeff: ion-electron collisional heat exchange coefficient. + """ + # Calculate Coulomb logarithm + lambda_ei = _calculate_lambda_ei( + core_profiles.temp_el.value, core_profiles.ne.value * nref + ) + # ion-electron collisionality for Zeff=1. Ion charge and multiple ion effects + # are included in the Qei_coef calculation below. + log_tau_e_Z1 = _calculate_log_tau_e_Z1( + core_profiles.temp_el.value, + core_profiles.ne.value * nref, + lambda_ei, + ) + # pylint: disable=invalid-name + + weighted_Zeff = _calculate_weighted_Zeff(core_profiles) + + log_Qei_coef = ( + jnp.log(Qei_mult * 1.5 * core_profiles.ne.value * nref) + + jnp.log(constants.CONSTANTS.keV2J / constants.CONSTANTS.mp) + + jnp.log(2 * constants.CONSTANTS.me) + + jnp.log(weighted_Zeff) + - log_tau_e_Z1 + ) + Qei_coef = jnp.exp(log_Qei_coef) + return Qei_coef + + +def calc_nu_star( + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + nref: float, + Zeff_face: jax.Array, + coll_mult: float, +) -> jax.Array: + """Calculates nustar. + + Electron-ion collision frequency normalized by bounce frequency. + + Args: + geo: Torus geometry. + core_profiles: Core plasma profiles. + nref: Reference value for normalization + Zeff_face: Effective ion charge on face grid. + coll_mult: Collisionality multiplier in QLKNN for sensitivity testing. + + Returns: + nu_star: on face grid. + """ + + # Calculate Coulomb logarithm + lambda_ei_face = _calculate_lambda_ei( + core_profiles.temp_el.face_value(), core_profiles.ne.face_value() * nref + ) + + # ion_electron collisionality + log_tau_e_Z1 = _calculate_log_tau_e_Z1( + core_profiles.temp_el.face_value(), + core_profiles.ne.face_value() * nref, + lambda_ei_face, + ) + + nu_e = 1 / jnp.exp(log_tau_e_Z1) * Zeff_face * coll_mult + + # calculate bounce time + epsilon = geo.rho_face / geo.Rmaj + # to avoid divisions by zero + epsilon = jnp.clip(epsilon, constants.CONSTANTS.eps) + tau_bounce = ( + core_profiles.q_face + * geo.Rmaj + / ( + epsilon**1.5 + * jnp.sqrt( + core_profiles.temp_el.face_value() + * constants.CONSTANTS.keV2J + / constants.CONSTANTS.me + ) + ) + ) + # due to pathological on-axis epsilon=0 term + tau_bounce = tau_bounce.at[0].set(tau_bounce[1]) + + # calculate normalized collisionality + nustar = nu_e * tau_bounce + + return nustar + + +def fast_ion_fractional_heating_formula( + birth_energy: float | array_typing.ArrayFloat, + temp_el: array_typing.ArrayFloat, + fast_ion_mass: float, +) -> array_typing.ArrayFloat: + """Returns the fraction of heating that goes to the ions. + + From eq. 5 and eq. 26 in Mikkelsen Nucl. Tech. Fusion 237 4 1983. + Note there is a typo in eq. 26 where a `2x` term is missing in the numerator + of the log. + + Args: + birth_energy: Birth energy of the fast ions in keV. + temp_el: Electron temperature. + fast_ion_mass: Mass of the fast ions in amu. + + Returns: + The fraction of heating that goes to the ions. + """ + critical_energy = 10 * fast_ion_mass * temp_el # Eq. 5. + energy_ratio = birth_energy / critical_energy + + # Eq. 26. + x_squared = energy_ratio + x = jnp.sqrt(x_squared) + frac_i = ( + 2 + * ( + (1 / 6) * jnp.log((1.0 - x + x_squared) / (1.0 + 2.0 * x + x_squared)) + + (jnp.arctan((2.0 * x - 1.0) / jnp.sqrt(3)) + jnp.pi / 6) + / jnp.sqrt(3) + ) + / x_squared + ) + return frac_i + + +def _calculate_lambda_ei( + temp_el: jax.Array, + ne: jax.Array, +) -> jax.Array: + """Calculates Coulomb logarithm for electron-ion collisions. + + See Wesson 3rd edition p727. + + Args: + temp_el: Electron temperature in keV. + ne: Electron density in m^-3. + + Returns: + Coulomb logarithm. + """ + return 15.2 - 0.5 * jnp.log(ne / 1e20) + jnp.log(temp_el) + + +# TODO(b/377225415): generalize to arbitrary number of ions. +def _calculate_weighted_Zeff( + core_profiles: state.CoreProfiles, +) -> jax.Array: + """Calculates ion mass weighted Zeff. Used for collisional heat exchange.""" + return ( + core_profiles.ni.value * core_profiles.Zi**2 / core_profiles.Ai + + core_profiles.nimp.value * core_profiles.Zimp**2 / core_profiles.Aimp + ) / core_profiles.ne.value + + +def _calculate_log_tau_e_Z1( + temp_el: jax.Array, + ne: jax.Array, + lambda_ei: jax.Array, +) -> jax.Array: + """Calculates log of electron-ion collision time for Z=1 plasma. + + See Wesson 3rd edition p729. Extension to multiple ions is context dependent + and implemented in calling functions. + + Args: + temp_el: Electron temperature in keV. + ne: Electron density in m^-3. + lambda_ei: Coulomb logarithm. + + Returns: + Log of electron-ion collision time. + """ + return ( + jnp.log(12 * jnp.pi**1.5 / (ne * lambda_ei)) + - 4 * jnp.log(constants.CONSTANTS.qe) + + 0.5 * jnp.log(constants.CONSTANTS.me / 2.0) + + 2 * jnp.log(constants.CONSTANTS.epsilon0) + + 1.5 * jnp.log(temp_el * constants.CONSTANTS.keV2J) + ) diff --git a/torax/physics/psi_calculations.py b/torax/physics/psi_calculations.py new file mode 100644 index 00000000..1542b969 --- /dev/null +++ b/torax/physics/psi_calculations.py @@ -0,0 +1,285 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calculations related to derived quantities from poloidal flux (psi). + +Functions: + - update_jtot_q_face_s_face: Updates core profiles with + psi-derived quantities. + - calc_q: Calculates the q-profile (q). + - calc_jtot: Calculate flux-surface-averaged toroidal current density. + - calc_s: Calculates magnetic shear (s). + - calc_s_rmid: Calculates magnetic shear (s), using midplane r as radial + coordinate. + - calc_Wpol: Calculates total magnetic energy (Wpol). + - calc_li3: Calculates normalized internal inductance li3 (ITER convention). + - _calc_bpol2: Calculates square of poloidal field (Bp). +""" + +import dataclasses + +import chex +import jax +from jax import numpy as jnp +from torax import constants +from torax import jax_utils +from torax import state +from torax.fvm import cell_variable +from torax.geometry import geometry + +_trapz = jax.scipy.integrate.trapezoid + +# pylint: disable=invalid-name + + +@jax_utils.jit +def update_jtot_q_face_s_face( + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, +) -> state.CoreProfiles: + """Updates core profiles with psi-derived quantities. + + Args: + geo: Geometry object. + core_profiles: Core plasma profiles. + + Returns: + Updated core profiles with new jtot, jtot_face, Ip_profile_face, q_face, + and s_face. + """ + + jtot, jtot_face, Ip_profile_face = calc_jtot(geo, core_profiles.psi) + q_face, _ = calc_q(geo=geo, psi=core_profiles.psi) + s_face = calc_s(geo, core_profiles.psi) + + currents = dataclasses.replace( + core_profiles.currents, + jtot=jtot, + jtot_face=jtot_face, + Ip_profile_face=Ip_profile_face, + ) + new_core_profiles = dataclasses.replace( + core_profiles, + currents=currents, + q_face=q_face, + s_face=s_face, + ) + return new_core_profiles + + +def calc_q( + geo: geometry.Geometry, + psi: cell_variable.CellVariable, +) -> tuple[chex.Array, chex.Array]: + """Calculates the q-profile (q) given current (jtot) and poloidal flux (psi). + + Args: + geo: Magnetic geometry. + psi: Poloidal flux. + + Returns: + q_face: q at faces. + q: q at cell centers. + """ + # iota is standard terminology for 1/q + inv_iota = jnp.abs( + (2 * geo.Phib * geo.rho_face_norm[1:]) / psi.face_grad()[1:] + ) + + # Use L'Hôpital's rule to calculate iota on-axis, with psi_face_grad()[0]=0. + inv_iota0 = jnp.expand_dims( + jnp.abs((2 * geo.Phib * geo.drho_norm) / psi.face_grad()[1]), 0 + ) + + q_face = jnp.concatenate([inv_iota0, inv_iota]) + q_face *= geo.q_correction_factor + q = geometry.face_to_cell(q_face) + + return q_face, q + + +def calc_jtot( + geo: geometry.Geometry, + psi: cell_variable.CellVariable, +) -> tuple[chex.Array, chex.Array, chex.Array]: + """Calculate flux-surface-averaged toroidal current density from poloidal flux. + + Calculation based on jtot = dI/dS + + Args: + geo: Torus geometry. + psi: Poloidal flux. + + Returns: + jtot: total current density [A/m2] on cell grid + jtot_face: total current density [A/m2] on face grid + Ip_profile_face: cumulative total plasma current profile [A] on face grid + """ + + # pylint: disable=invalid-name + Ip_profile_face = ( + psi.face_grad() + * geo.g2g3_over_rhon_face + * geo.F_face + / geo.Phib + / (16 * jnp.pi**3 * constants.CONSTANTS.mu0) + ) + + dI_tot_drhon = jnp.gradient(Ip_profile_face, geo.rho_face_norm) + + jtot_face_bulk = dI_tot_drhon[1:] / geo.spr_face[1:] + + # Set on-axis jtot according to L'Hôpital's rule, noting that I[0]=S[0]=0. + jtot_face_axis = Ip_profile_face[1] / geo.area_face[1] + + jtot_face = jnp.concatenate([jnp.array([jtot_face_axis]), jtot_face_bulk]) + jtot = geometry.face_to_cell(jtot_face) + + return jtot, jtot_face, Ip_profile_face + + +def calc_s( + geo: geometry.Geometry, psi: cell_variable.CellVariable +) -> jax.Array: + """Calculates magnetic shear (s) from poloidal flux (psi). + + Args: + geo: Torus geometry. + psi: Poloidal flux. + + Returns: + s_face: Magnetic shear, on the face grid. + """ + + # iota (1/q) should have a /2*Phib but we drop it since will cancel out in + # the s calculation. + iota_scaled = jnp.abs((psi.face_grad()[1:] / geo.rho_face_norm[1:])) + + # on-axis iota_scaled from L'Hôpital's rule = dpsi_face_grad / drho_norm + # Using expand_dims to make it compatible with jnp.concatenate + iota_scaled0 = jnp.expand_dims( + jnp.abs(psi.face_grad()[1] / geo.drho_norm), axis=0 + ) + + iota_scaled = jnp.concatenate([iota_scaled0, iota_scaled]) + + s_face = ( + -geo.rho_face_norm + * jnp.gradient(iota_scaled, geo.rho_face_norm) + / iota_scaled + ) + + return s_face + + +def calc_s_rmid( + geo: geometry.Geometry, psi: cell_variable.CellVariable +) -> jax.Array: + """Calculates magnetic shear (s) from poloidal flux (psi). + + Version taking the derivative of iota with respect to the midplane r, + in line with expectations from circular-derived models like QuaLiKiz. + + Args: + geo: Torus geometry. + psi: Poloidal flux. + + Returns: + s_face: Magnetic shear, on the face grid. + """ + + # iota (1/q) should have a /2*Phib but we drop it since will cancel out in + # the s calculation. + iota_scaled = jnp.abs((psi.face_grad()[1:] / geo.rho_face_norm[1:])) + + # on-axis iota_scaled from L'Hôpital's rule = dpsi_face_grad / drho_norm + # Using expand_dims to make it compatible with jnp.concatenate + iota_scaled0 = jnp.expand_dims( + jnp.abs(psi.face_grad()[1] / geo.drho_norm), axis=0 + ) + + iota_scaled = jnp.concatenate([iota_scaled0, iota_scaled]) + + rmid_face = (geo.Rout_face - geo.Rin_face) * 0.5 + + s_face = -rmid_face * jnp.gradient(iota_scaled, rmid_face) / iota_scaled + + return s_face + + +def _calc_bpol2( + geo: geometry.Geometry, psi: cell_variable.CellVariable +) -> jax.Array: + r"""Calculates square of poloidal field (Bp) from poloidal flux (psi). + + An identity for the poloidal magnetic field is: + :math:`B_p = 1/R \partial \psi / \partial \rho (\nabla \rho \times e_\phi)` + + Where :math:`e_\phi` is the unit vector pointing in the toroidal direction. + + Args: + geo: Torus geometry. + psi: Poloidal flux. + + Returns: + bpol2_face: Square of poloidal magnetic field, on the face grid. + """ + bpol2_bulk = ( + (psi.face_grad()[1:] / (2 * jnp.pi)) ** 2 + * geo.g2_face[1:] + / geo.vpr_face[1:] ** 2 + ) + bpol2_axis = jnp.array([0.0]) + bpol2_face = jnp.concatenate([bpol2_axis, bpol2_bulk]) + return bpol2_face + + +def calc_Wpol( + geo: geometry.Geometry, psi: cell_variable.CellVariable +) -> jax.Array: + """Calculates total magnetic energy (Wpol) from poloidal flux (psi).""" + bpol2 = _calc_bpol2(geo, psi) + Wpol = _trapz(bpol2 * geo.vpr_face, geo.rho_face_norm) / ( + 2 * constants.CONSTANTS.mu0 + ) + return Wpol + + +def calc_li3( + Rmaj: jax.Array, + Wpol: jax.Array, + Ip_total: jax.Array, +) -> jax.Array: + """Calculates li3 based on a formulation using Wpol. + + Normalized internal inductance is defined as: + li = _V / _LCFS where <>_V is a volume average and <>_LCFS is + the average at the last closed flux surface. + + We use the ITER convention for normalized internal inductance defined as: + li3 = 2*V*_V / (mu0^2 Ip^2*Rmaj) = 4 * Wpol / (mu0 Ip^2*Rmaj) + + Ip (total plasma current) enters through the integral form of Ampere's law. + Since Wpol also corresponds to a volume integral of the poloidal field, we + can define li3 with respect to Wpol. + + Args: + Rmaj: Major radius. + Wpol: Total magnetic energy. + Ip_total: Total plasma current. + + Returns: + li3: Normalized internal inductance, ITER convention. + """ + return 4 * Wpol / (constants.CONSTANTS.mu0 * Ip_total**2 * Rmaj) diff --git a/torax/physics/scaling_laws.py b/torax/physics/scaling_laws.py new file mode 100644 index 00000000..93796910 --- /dev/null +++ b/torax/physics/scaling_laws.py @@ -0,0 +1,243 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calculations related to empirical scaling laws. + +Functions: + - calculate_plh_scaling_factor: Calculates the H-mode transition power + according to Martin 2008, and the density corresponding to the P_LH_min + according to Ryter 2014. + - calculate_scaling_law_confinement_time: Calculates the predicted + thermal energy confinement time from a given empirical scaling law. + - _calculate_line_avg_density: Calculates line-averaged electron density. +""" + +import jax +from jax import numpy as jnp +from torax import constants +from torax import state +from torax.geometry import geometry + +_trapz = jax.scipy.integrate.trapezoid + +# pylint: disable=invalid-name + + +def calculate_plh_scaling_factor( + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: + """Calculates the H-mode transition power scalings. + + See Y.R. Martin and Tomonori Takizuka. + "Power requirement for accessing the H-mode in ITER." + Journal of Physics: Conference Series. Vol. 123. No. 1. IOP Publishing, 2008. + + Only valid for hydrogenic isotopes and mixtures (H, D, T). + Includes a simple inverse scaling of the factor to average isotope mass. + + For an overview see U Plank, U., et al. "Overview of L-to H-mode transition + experiments at ASDEX Upgrade." + Plasma Physics and Controlled Fusion 65.1 (2022): 014001. + + Args: + geo: Torus geometry. + core_profiles: Core plasma profiles. + + Returns: + Tuple of: P_LH scaling factor for high density branch, minimum P_LH, + P_LH = max(P_LH_min, P_LH_hi_dens) for practical use, and the density + corresponding to the P_LH_min. + """ + + line_avg_ne = _calculate_line_avg_density(geo, core_profiles) + # LH transition power for deuterium, in W. Eq 3 from Martin 2008. + P_LH_hi_dens_D = ( + 2.15 + * (line_avg_ne / 1e20) ** 0.782 + * geo.B0**0.772 + * geo.Rmin**0.975 + * geo.Rmaj**0.999 + * 1e6 + ) + + # Scale to average isotope mass. + A_deuterium = constants.ION_PROPERTIES_DICT['D']['A'] + P_LH_hi_dens = P_LH_hi_dens_D * A_deuterium / core_profiles.Ai + + # Calculate density (in nref) corresponding to P_LH_min from Eq 3 Ryter 2014 + ne_min_P_LH = ( + 0.7 + * (core_profiles.currents.Ip_total / 1e6) ** 0.34 + * geo.Rmin**-0.95 + * geo.B0**0.62 + * (geo.Rmaj / geo.Rmin) ** 0.4 + * 1e19 + / core_profiles.nref + ) + # Calculate P_LH_min at ne_min from Eq 4 Ryter 2014 + P_LH_min_D = ( + 0.36 + * (core_profiles.currents.Ip_total / 1e6) ** 0.27 + * geo.B0**1.25 + * geo.Rmaj**1.23 + * (geo.Rmaj / geo.Rmin) ** 0.08 + * 1e6 + ) + P_LH_min = P_LH_min_D * A_deuterium / core_profiles.Ai + P_LH = jnp.maximum(P_LH_min, P_LH_hi_dens) + return P_LH_hi_dens, P_LH_min, P_LH, ne_min_P_LH + + +def calculate_scaling_law_confinement_time( + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + Ploss: jax.Array, + scaling_law: str, +) -> jax.Array: + """Calculates the thermal energy confinement time for a given empirical scaling law. + + Args: + geo: Torus geometry. + core_profiles: Core plasma profiles. + Ploss: Plasma power loss in W. + scaling_law: Scaling law to use. + + Returns: + Thermal energy confinement time in s. + """ + scaling_params = { + 'H89P': { + # From Yushmanov et al, Nuclear Fusion, vol. 30, no. 10, pp. 4-6, 1990 + 'prefactor': 0.038128, + 'Ip_exponent': 0.85, + 'B_exponent': 0.2, + 'line_avg_ne_exponent': 0.1, + 'Ploss_exponent': -0.5, + 'R_exponent': 1.5, + 'inverse_aspect_ratio_exponent': 0.3, + 'elongation_exponent': 0.5, + 'effective_mass_exponent': 0.50, + 'triangularity_exponent': 0.0, + }, + 'H98': { + # H98 empirical confinement scaling law: + # ITER Physics Expert Groups on Confinement and Transport and + # Confinement Modelling and Database, Nucl. Fusion 39 2175, 1999 + # Doyle et al, Nucl. Fusion 47 (2007) S18–S127, Eq 30 + 'prefactor': 0.0562, + 'Ip_exponent': 0.93, + 'B_exponent': 0.15, + 'line_avg_ne_exponent': 0.41, + 'Ploss_exponent': -0.69, + 'R_exponent': 1.97, + 'inverse_aspect_ratio_exponent': 0.58, + 'elongation_exponent': 0.78, + 'effective_mass_exponent': 0.19, + 'triangularity_exponent': 0.0, + }, + 'H97L': { + # From the ITER L-mode confinement database. + # S.M. Kaye et al 1997 Nucl. Fusion 37 1303, Eq 7 + 'prefactor': 0.023, + 'Ip_exponent': 0.96, + 'B_exponent': 0.03, + 'line_avg_ne_exponent': 0.4, + 'Ploss_exponent': -0.73, + 'R_exponent': 1.83, + 'inverse_aspect_ratio_exponent': -0.06, + 'elongation_exponent': 0.64, + 'effective_mass_exponent': 0.20, + 'triangularity_exponent': 0.0, + }, + 'H20': { + # Updated ITER H-mode confinement database, using full dataset. + # G. Verdoolaege et al 2021 Nucl. Fusion 61 076006, Eq 7 + 'prefactor': 0.053, + 'Ip_exponent': 0.98, + 'B_exponent': 0.22, + 'line_avg_ne_exponent': 0.24, + 'Ploss_exponent': -0.669, + 'R_exponent': 1.71, + 'inverse_aspect_ratio_exponent': 0.35, + 'elongation_exponent': 0.80, + 'effective_mass_exponent': 0.20, + 'triangularity_exponent': 0.36, # (1+delta)^exponent + }, + } + + if scaling_law not in scaling_params: + raise ValueError(f'Unknown scaling law: {scaling_law}') + + params = scaling_params[scaling_law] + + scaled_Ip = core_profiles.currents.Ip_total / 1e6 # convert to MA + scaled_Ploss = Ploss / 1e6 # convert to MW + B = geo.B0 + line_avg_ne = _calculate_line_avg_density(geo, core_profiles) / 1e19 + R = geo.Rmaj + inverse_aspect_ratio = geo.Rmin / geo.Rmaj + + # Effective elongation definition. This is a different definition than + # the standard definition used in geo.elongation. + elongation = geo.area_face[-1] / (jnp.pi * geo.Rmin**2) + # TODO(b/317360834): extend when multiple ions are supported. + effective_mass = core_profiles.Ai + triangularity = geo.delta_face[-1] + + tau_scaling = ( + params['prefactor'] + * scaled_Ip ** params['Ip_exponent'] + * B ** params['B_exponent'] + * line_avg_ne ** params['line_avg_ne_exponent'] + * scaled_Ploss ** params['Ploss_exponent'] + * R ** params['R_exponent'] + * inverse_aspect_ratio ** params['inverse_aspect_ratio_exponent'] + * elongation ** params['elongation_exponent'] + * effective_mass ** params['effective_mass_exponent'] + * (1 + triangularity) ** params['triangularity_exponent'] + ) + return tau_scaling + + +def _calculate_line_avg_density( + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, +) -> jax.Array: + """Calculates line-averaged electron density. + + Line-averaged electron density is poorly defined. In general, the definition + is machine-dependent and even shot-dependent since it depends on the usage of + a specific interferometry chord. Furthermore, even if we knew the specific + chord used, its calculation would depend on magnetic geometry information + beyond what is available in StandardGeometry. In lieu of a better solution, we + use line-averaged electron density defined on the outer midplane. + + Args: + geo: Torus geometry. + core_profiles: Core plasma profiles. + + Returns: + Line-averaged electron density. + """ + Rmin_out = geo.Rout_face[-1] - geo.Rout_face[0] + line_avg_ne = ( + core_profiles.nref + * _trapz(core_profiles.ne.face_value(), geo.Rout_face) + / Rmin_out + ) + return line_avg_ne + + +# pylint: enable=invalid-name diff --git a/torax/tests/charge_states.py b/torax/physics/tests/charge_states_tests.py similarity index 98% rename from torax/tests/charge_states.py rename to torax/physics/tests/charge_states_tests.py index c68e59ec..86337fc9 100644 --- a/torax/tests/charge_states.py +++ b/torax/physics/tests/charge_states_tests.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for module torax.charge_states.""" - from absl.testing import parameterized import numpy as np -from torax import charge_states from torax import constants from torax.config import plasma_composition +from torax.physics import charge_states # pylint: disable=invalid-name class ChargeStatesTest(parameterized.TestCase): - """Tests for impurity charge states.""" @parameterized.product( ion_symbol=[ diff --git a/torax/physics/tests/collisions_tests.py b/torax/physics/tests/collisions_tests.py new file mode 100644 index 00000000..2ece5d02 --- /dev/null +++ b/torax/physics/tests/collisions_tests.py @@ -0,0 +1,89 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock +from absl.testing import absltest +from absl.testing import parameterized +from jax import numpy as jnp +import numpy as np +from torax import state +from torax.fvm import cell_variable +from torax.physics import collisions + + +class CollisionsTest(parameterized.TestCase): + + def test_fast_ion_fractional_heating_formula_ion_heating_limit(self): + # Inertial energy small compared to critical energy. + birth_energy = 1e-3 + temp_el = jnp.array(0.1, dtype=jnp.float32) + fast_ion_mass = 1 + frac_i = collisions.fast_ion_fractional_heating_formula( + birth_energy, temp_el, fast_ion_mass + ) + np.testing.assert_allclose(frac_i, 1.0, atol=1e-3) + + def test_fast_ion_fractional_heating_formula_electron_heating_limit(self): + # Inertial energy large compared to critical energy. + birth_energy = 1e10 + temp_el = jnp.array(0.1, dtype=jnp.float32) + fast_ion_mass = 1 + frac_i = collisions.fast_ion_fractional_heating_formula( + birth_energy, temp_el, fast_ion_mass + ) + np.testing.assert_allclose(frac_i, 0.0, atol=1e-9) + + # TODO(b/377225415): generalize to arbitrary number of ions. + @parameterized.parameters([ + dict(Aimp=20.0, Zimp=10.0, Zi=1.0, Ai=1.0, ni=1.0, expected=1.0), + dict(Aimp=20.0, Zimp=10.0, Zi=1.0, Ai=2.0, ni=1.0, expected=0.5), + dict(Aimp=20.0, Zimp=10.0, Zi=2.0, Ai=4.0, ni=0.5, expected=0.5), + dict(Aimp=20.0, Zimp=10.0, Zi=1.0, Ai=2.0, ni=0.9, expected=0.5), + dict(Aimp=40.0, Zimp=20.0, Zi=1.0, Ai=2.0, ni=0.92, expected=0.5), + ]) + # pylint: disable=invalid-name + def test_calculate_weighted_Zeff(self, Aimp, Zimp, Zi, Ai, ni, expected): + """Compare `_calculate_weighted_Zeff` to a reference value.""" + ne = 1.0 + nimp = (ne - ni * Zi) / Zimp + core_profiles = mock.create_autospec( + state.CoreProfiles, + instance=True, + ne=cell_variable.CellVariable( + value=jnp.array(ne), + dr=jnp.array(1.0), + ), + ni=cell_variable.CellVariable( + value=jnp.array(ni), + dr=jnp.array(1.0), + ), + nimp=cell_variable.CellVariable( + value=jnp.array(nimp), + dr=jnp.array(1.0), + ), + Zi=Zi, + Ai=Ai, + Zimp=Zimp, + Aimp=Aimp, + ) + # pylint: enable=invalid-name + # pylint: disable=protected-access + np.testing.assert_allclose( + collisions._calculate_weighted_Zeff(core_profiles), expected + ) + # pylint: enable=protected-access + + +if __name__ == '__main__': + absltest.main() diff --git a/torax/physics/tests/psi_calculations_tests.py b/torax/physics/tests/psi_calculations_tests.py new file mode 100644 index 00000000..baf887ad --- /dev/null +++ b/torax/physics/tests/psi_calculations_tests.py @@ -0,0 +1,168 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable +from absl.testing import absltest +from absl.testing import parameterized +import jax +import numpy as np +from torax import constants +from torax.core_profiles import initialization +from torax.geometry import pydantic_model as geometry_pydantic_model +from torax.geometry import standard_geometry +from torax.physics import psi_calculations +from torax.tests.test_lib import torax_refs + + +_trapz = jax.scipy.integrate.trapezoid + + +class PsiCalculationsTest(torax_refs.ReferenceValueTest): + + @parameterized.parameters([ + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), + ]) + def test_calc_q(self, references_getter: Callable[[], torax_refs.References]): + """Compare `calc_q` function to a reference implementation.""" + references = references_getter() + + runtime_params = references.runtime_params + _, geo = ( + torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry( + runtime_params, + references.geometry_provider, + ) + ) + + q_face_calculated, _ = psi_calculations.calc_q( + geo, + references.psi, + ) + + q_face_expected = references.q + + np.testing.assert_allclose(q_face_calculated, q_face_expected, rtol=1e-5) + + @parameterized.parameters([ + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), + ]) + def test_calc_jtot( + self, references_getter: Callable[[], torax_refs.References] + ): + """Compare `calc_jtot` to a reference value.""" + references = references_getter() + geo = references.geometry_provider( + references.runtime_params.numerics.t_initial + ) + # pylint: disable=invalid-name + j, _, Ip_profile_face = psi_calculations.calc_jtot( + geo, + references.psi, + ) + # pylint: enable=invalid-name + np.testing.assert_allclose(j, references.jtot, rtol=1e-5) + + if references.Ip_from_parameters: + np.testing.assert_allclose( + Ip_profile_face[-1], + references.runtime_params.profile_conditions.Ip_tot * 1e6, + ) + else: + assert isinstance(geo, standard_geometry.StandardGeometry) + np.testing.assert_allclose( + Ip_profile_face[-1], + geo.Ip_profile_face[-1], + ) + + @parameterized.parameters([ + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict( + references_getter=torax_refs.chease_references_Ip_from_runtime_params + ), + ]) + def test_calc_s(self, references_getter: Callable[[], torax_refs.References]): + """Compare `calc_s` to a reference value.""" + references = references_getter() + geo = references.geometry_provider( + references.runtime_params.numerics.t_initial + ) + + s = psi_calculations.calc_s( + geo, + references.psi, + ) + + np.testing.assert_allclose(s, references.s, rtol=1e-5) + + # pylint: disable=invalid-name + def test_calc_Wpol(self): + """Compare `calc_Wpol` to an analytical formula in circular geometry.""" + + # Small inverse aspect ratio limit of circular geometry, such that we + # approximate the simplest form of circular geometry where the analytical + # Bpol formula is applicable. + geo = geometry_pydantic_model.CircularConfig( + n_rho=25, + elongation_LCFS=1.0, + Rmaj=100.0, + Rmin=1.0, + B0=5.0, + ).build_geometry() + Ip_tot = 15 + # calculate high resolution jtot consistent with total current profile + jtot_profile = (1 - geo.rho_hires_norm**2) ** 2 + denom = _trapz(jtot_profile * geo.spr_hires, geo.rho_hires_norm) + Ctot = Ip_tot * 1e6 / denom + jtot = jtot_profile * Ctot + # pylint: disable=protected-access + psi_cell_variable = initialization._update_psi_from_j( + Ip_tot, + geo, + jtot, + ) + _, _, Ip_profile_face = psi_calculations.calc_jtot( + geo, + psi_cell_variable, + ) + + # Analytical formula for Bpol in circular geometry (Ampere's law) + Bpol_bulk = ( + constants.CONSTANTS.mu0 + * Ip_profile_face[1:] + / (2 * np.pi * geo.rho_face[1:]) + ) + Bpol = np.concatenate([np.array([0.0]), Bpol_bulk]) + + expected_Wpol = _trapz(Bpol**2 * geo.vpr_face, geo.rho_face_norm) / ( + 2 * constants.CONSTANTS.mu0 + ) + + calculated_Wpol = psi_calculations.calc_Wpol(geo, psi_cell_variable) + + # Relatively low tolerence because the analytical formula is not exact for + # our circular geometry, but approximates it at low inverse aspect ratio. + np.testing.assert_allclose(calculated_Wpol, expected_Wpol, rtol=1e-3) + + +if __name__ == '__main__': + absltest.main() diff --git a/torax/physics/tests/scaling_laws_tests.py b/torax/physics/tests/scaling_laws_tests.py new file mode 100644 index 00000000..22fdbb82 --- /dev/null +++ b/torax/physics/tests/scaling_laws_tests.py @@ -0,0 +1,216 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for torax.physics.scaling_laws.""" + +import dataclasses +from unittest import mock +from absl.testing import absltest +from absl.testing import parameterized +from jax import numpy as jnp +import numpy as np +from torax import state +from torax.fvm import cell_variable +from torax.geometry import pydantic_model as geometry_pydantic_model +from torax.physics import scaling_laws + + +class ScalingLawsTest(parameterized.TestCase): + """Unit tests for the `torax.physics.scaling_laws` module.""" + + def test_calculate_plh_scaling_factor(self): + """Compare `calculate_plh_scaling_factor` to a reference value.""" + geo = geometry_pydantic_model.CircularConfig( + n_rho=25, + elongation_LCFS=1.0, + hires_fac=4, + Rmaj=6.0, + Rmin=2.0, + B0=5.0, + ).build_geometry() + + # Using mock.ANY instead of mock.create_autospec to maintain the Ip_total + # property needed in calculate_plh_scaling_factor. + core_profiles = state.CoreProfiles( + ne=cell_variable.CellVariable( + value=jnp.ones_like(geo.rho_norm) * 2, + left_face_grad_constraint=jnp.zeros(()), + right_face_grad_constraint=None, + right_face_constraint=jnp.array(2.0), + dr=geo.drho_norm, + ), + ni=mock.ANY, + nimp=mock.ANY, + temp_ion=mock.ANY, + temp_el=mock.ANY, + psi=mock.ANY, + psidot=mock.ANY, + vloop_lcfs=mock.ANY, + currents=state.Currents.zeros(geo), + q_face=mock.ANY, + s_face=mock.ANY, + Zi=mock.ANY, + Zi_face=mock.ANY, + Ai=3.0, + Zimp=mock.ANY, + Zimp_face=mock.ANY, + Aimp=mock.ANY, + nref=1e20, + ) + + core_profiles = dataclasses.replace( + core_profiles, + currents=dataclasses.replace( + core_profiles.currents, + Ip_profile_face=jnp.ones_like(geo.rho_face_norm) * 10e6, + ), + ) + # pylint: disable=invalid-name + P_LH_hi_dens, P_LH_min, P_LH, ne_min_P_LH = ( + scaling_laws.calculate_plh_scaling_factor(geo, core_profiles) + ) + expected_PLH_hi_dens = ( + 2.15 * 2**0.782 * 5**0.772 * 2**0.975 * 6**0.999 * (2.0141 / 3) + ) + expected_PLH_min = ( + 0.36 * 10**0.27 * 5**1.25 * 6**1.23 * 3**0.08 * (2.0141 / 3) + ) + expected_ne_min_P_LH = 0.7 * 10**0.34 * 5**0.62 * 2.0**-0.95 * 3**0.4 / 10 + # pylint: enable=invalid-name + np.testing.assert_allclose(P_LH_hi_dens / 1e6, expected_PLH_hi_dens) + np.testing.assert_allclose(P_LH_min / 1e6, expected_PLH_min) + np.testing.assert_allclose(ne_min_P_LH, expected_ne_min_P_LH) + np.testing.assert_allclose(P_LH, P_LH_hi_dens) + + @parameterized.parameters([ + dict(elongation_LCFS=1.0), + dict(elongation_LCFS=1.5), + ]) + # pylint: disable=invalid-name + def test_calculate_scaling_law_confinement_time(self, elongation_LCFS): + """Compare `calculate_scaling_law_confinement_time` to reference values.""" + geo = geometry_pydantic_model.CircularConfig( + n_rho=25, + elongation_LCFS=elongation_LCFS, + hires_fac=4, + Rmaj=6.0, + Rmin=2.0, + B0=5.0, + ).build_geometry() + # Using mock.ANY instead of mock.create_autospec to maintain the Ip_total + # property needed in calculate_plh_scaling_factor. + core_profiles = state.CoreProfiles( + ne=cell_variable.CellVariable( + value=jnp.ones_like(geo.rho_norm) * 2, + left_face_grad_constraint=jnp.zeros(()), + right_face_grad_constraint=None, + right_face_constraint=jnp.array(2.0), + dr=geo.drho_norm, + ), + ni=mock.ANY, + nimp=mock.ANY, + temp_ion=mock.ANY, + temp_el=mock.ANY, + psi=mock.ANY, + psidot=mock.ANY, + vloop_lcfs=mock.ANY, + currents=state.Currents.zeros(geo), + q_face=mock.ANY, + s_face=mock.ANY, + Zi=mock.ANY, + Zi_face=mock.ANY, + Ai=3.0, + Zimp=mock.ANY, + Zimp_face=mock.ANY, + Aimp=mock.ANY, + nref=1e20, + ) + core_profiles = dataclasses.replace( + core_profiles, + currents=dataclasses.replace( + core_profiles.currents, + Ip_profile_face=jnp.ones_like(geo.rho_face_norm) * 10e6, + ), + ) + Ploss = jnp.array(50e6) + + H89P = scaling_laws.calculate_scaling_law_confinement_time( + geo, core_profiles, Ploss, 'H89P' + ) + H98 = scaling_laws.calculate_scaling_law_confinement_time( + geo, core_profiles, Ploss, 'H98' + ) + H97L = scaling_laws.calculate_scaling_law_confinement_time( + geo, core_profiles, Ploss, 'H97L' + ) + H20 = scaling_laws.calculate_scaling_law_confinement_time( + geo, core_profiles, Ploss, 'H20' + ) + + expected_H89P = ( + 0.038128 + * 10**0.85 + * 5**0.2 + * 20**0.1 + * 50**-0.5 + * 6**1.5 + * (1 / 3) ** 0.3 + * 3**0.50 + * elongation_LCFS**0.50 + ) + + expected_H98 = ( + 0.0562 + * 10**0.93 + * 5**0.15 + * 20**0.41 + * 50**-0.69 + * 6**1.97 + * (1 / 3) ** 0.58 + * 3**0.19 + * elongation_LCFS**0.78 + ) + + expected_H97L = ( + 0.023 + * 10**0.96 + * 5**0.03 + * 20**0.4 + * 50**-0.73 + * 6**1.83 + * (1 / 3) ** -0.06 + * 3**0.20 + * elongation_LCFS**0.64 + ) + + expected_H20 = ( + 0.053 + * 10**0.98 + * 5**0.22 + * 20**0.24 + * 50**-0.669 + * 6**1.71 + * (1 / 3) ** 0.35 + * 3**0.20 + * elongation_LCFS**0.80 + ) + # pylint: enable=invalid-name + np.testing.assert_allclose(H89P, expected_H89P) + np.testing.assert_allclose(H98, expected_H98) + np.testing.assert_allclose(H97L, expected_H97L) + np.testing.assert_allclose(H20, expected_H20) + + +if __name__ == '__main__': + absltest.main() diff --git a/torax/post_processing.py b/torax/post_processing.py index db7872a7..171b89cf 100644 --- a/torax/post_processing.py +++ b/torax/post_processing.py @@ -23,9 +23,10 @@ from torax import constants from torax import jax_utils from torax import math_utils -from torax import physics from torax import state from torax.geometry import geometry +from torax.physics import psi_calculations +from torax.physics import scaling_laws from torax.sources import source_profiles _trapz = jax.scipy.integrate.trapezoid @@ -325,8 +326,8 @@ def _calculate_greenwald_fraction( Args: ne_avg: Averaged electron density [nref m^-3] - core_profiles: CoreProfiles object containing information on currents - and densities. + core_profiles: CoreProfiles object containing information on currents and + densities. geo: Geometry object Returns: @@ -334,9 +335,7 @@ def _calculate_greenwald_fraction( """ # gw_limit is in units of 10^20 m^-3 when Ip is in MA and Rmin is in m. gw_limit = ( - core_profiles.currents.Ip_total - * 1e-6 - / (jnp.pi * geo.Rmin ** 2) + core_profiles.currents.Ip_total * 1e-6 / (jnp.pi * geo.Rmin**2) ) fgw = ne_avg * core_profiles.nref / (gw_limit * 1e20) return fgw @@ -395,7 +394,7 @@ def make_outputs( ) P_LH_hi_dens, P_LH_min, P_LH, ne_min_P_LH = ( - physics.calculate_plh_scaling_factor(geo, sim_state.core_profiles) + scaling_laws.calculate_plh_scaling_factor(geo, sim_state.core_profiles) ) # Thermal energy confinement time is the stored energy divided by the total @@ -411,16 +410,16 @@ def make_outputs( # TODO(b/380848256): include dW/dt term tauE = W_thermal_tot / Ploss - tauH89P = physics.calculate_scaling_law_confinement_time( + tauH89P = scaling_laws.calculate_scaling_law_confinement_time( geo, sim_state.core_profiles, Ploss, 'H89P' ) - tauH98 = physics.calculate_scaling_law_confinement_time( + tauH98 = scaling_laws.calculate_scaling_law_confinement_time( geo, sim_state.core_profiles, Ploss, 'H98' ) - tauH97L = physics.calculate_scaling_law_confinement_time( + tauH97L = scaling_laws.calculate_scaling_law_confinement_time( geo, sim_state.core_profiles, Ploss, 'H97L' ) - tauH20 = physics.calculate_scaling_law_confinement_time( + tauH20 = scaling_laws.calculate_scaling_law_confinement_time( geo, sim_state.core_profiles, Ploss, 'H20' ) @@ -500,12 +499,12 @@ def make_outputs( ) fgw_ne_volume_avg = _calculate_greenwald_fraction( ne_volume_avg, sim_state.core_profiles, geo - ) + ) fgw_ne_line_avg = _calculate_greenwald_fraction( ne_line_avg, sim_state.core_profiles, geo ) - Wpol = physics.calc_Wpol(geo, sim_state.core_profiles.psi) - li3 = physics.calc_li3( + Wpol = psi_calculations.calc_Wpol(geo, sim_state.core_profiles.psi) + li3 = psi_calculations.calc_li3( geo.Rmaj, Wpol, sim_state.core_profiles.currents.Ip_profile_face[-1] ) diff --git a/torax/sources/bootstrap_current_source.py b/torax/sources/bootstrap_current_source.py index ecbc3419..5d7c260c 100644 --- a/torax/sources/bootstrap_current_source.py +++ b/torax/sources/bootstrap_current_source.py @@ -25,11 +25,11 @@ from jax.scipy import integrate from torax import constants from torax import jax_utils -from torax import physics from torax import state from torax.config import runtime_params_slice from torax.fvm import cell_variable from torax.geometry import geometry +from torax.physics import psi_calculations from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_profiles @@ -175,8 +175,6 @@ def calc_sauter_model( geo: geometry.Geometry, ) -> source_profiles.BootstrapCurrentProfile: """Calculates sigmaneo, j_bootstrap, and I_bootstrap.""" - # Many variables throughout this function are capitalized based on physics - # notational conventions rather than on Google Python style # pylint: disable=invalid-name # Formulas from Sauter PoP 1999. Future work can include Redl PoP 2021 @@ -211,8 +209,9 @@ def calc_sauter_model( # We don't store q_cell in the evolving core profiles, so we need to # recalculate it. - q_face, _ = physics.calc_q_from_psi( - geo=geo, psi=psi, + q_face, _ = psi_calculations.calc_q( + geo=geo, + psi=psi, ) nuestar = ( 6.921e-18 @@ -321,9 +320,7 @@ def calc_sauter_model( ) / (1 + 0.15 * nuistar**2 * ftrap**6) # calculate bootstrap current - prefactor = ( - -geo.F_face * bootstrap_multiplier * 2 * jnp.pi / geo.B0 - ) + prefactor = -geo.F_face * bootstrap_multiplier * 2 * jnp.pi / geo.B0 pe = true_ne_face * (temp_el.face_value()) * 1e3 * 1.6e-19 pi = true_ni_face * (temp_ion.face_value()) * 1e3 * 1.6e-19 diff --git a/torax/sources/bremsstrahlung_heat_sink.py b/torax/sources/bremsstrahlung_heat_sink.py index 674e7f6b..6ea117f1 100644 --- a/torax/sources/bremsstrahlung_heat_sink.py +++ b/torax/sources/bremsstrahlung_heat_sink.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Many variables throughout this function are capitalized based on physics -# notational conventions rather than on Google Python style # pylint: disable=invalid-name """Bremsstrahlung heat sink for electron heat equation..""" diff --git a/torax/sources/cyclotron_radiation_heat_sink.py b/torax/sources/cyclotron_radiation_heat_sink.py index 442c015f..53254a0e 100644 --- a/torax/sources/cyclotron_radiation_heat_sink.py +++ b/torax/sources/cyclotron_radiation_heat_sink.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Many variables throughout this function are capitalized based on physics -# notational conventions rather than on Google Python style # pylint: disable=invalid-name """Cyclotron radiation heat sink for electron heat equation..""" diff --git a/torax/sources/formulas.py b/torax/sources/formulas.py index c3ad194b..6b568971 100644 --- a/torax/sources/formulas.py +++ b/torax/sources/formulas.py @@ -16,8 +16,6 @@ import jax from jax import numpy as jnp from torax.geometry import geometry -# Many variables throughout this function are capitalized based on physics -# notational conventions rather than on Google Python style # pylint: disable=invalid-name diff --git a/torax/sources/fusion_heat_source.py b/torax/sources/fusion_heat_source.py index a409cb2d..98296312 100644 --- a/torax/sources/fusion_heat_source.py +++ b/torax/sources/fusion_heat_source.py @@ -23,10 +23,10 @@ import jax from jax import numpy as jnp from torax import constants -from torax import physics from torax import state from torax.config import runtime_params_slice from torax.geometry import geometry +from torax.physics import collisions from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_profiles @@ -77,8 +77,6 @@ def calc_fusion( # for DT calculated with the Bosch-Hale parameterization NF 1992. # T is in keV for the formula - # Many variables throughout this function are capitalized based on physics - # notational conventions rather than on Google Python style # pylint: disable=invalid-name Efus = 17.6 * 1e3 * constants.CONSTANTS.keV2J mrc2 = 1124656 @@ -129,7 +127,7 @@ def calc_fusion( # Fractional fusion power ions/electrons. birth_energy = 3520 # Birth energy of alpha particles is 3.52MeV. alpha_mass = 4.002602 - frac_i = physics.fast_ion_fractional_heating_formula( + frac_i = collisions.fast_ion_fractional_heating_formula( birth_energy, core_profiles.temp_el.value, alpha_mass, diff --git a/torax/sources/generic_ion_el_heat_source.py b/torax/sources/generic_ion_el_heat_source.py index f1c75f46..8684632a 100644 --- a/torax/sources/generic_ion_el_heat_source.py +++ b/torax/sources/generic_ion_el_heat_source.py @@ -29,8 +29,6 @@ from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_profiles -# Many variables throughout this function are capitalized based on physics -# notational conventions rather than on Google Python style # pylint: disable=invalid-name diff --git a/torax/sources/ion_cyclotron_source.py b/torax/sources/ion_cyclotron_source.py index 956df20d..31f73df7 100644 --- a/torax/sources/ion_cyclotron_source.py +++ b/torax/sources/ion_cyclotron_source.py @@ -30,10 +30,10 @@ import numpy as np from torax import array_typing from torax import interpolated_param -from torax import physics from torax import state from torax.config import runtime_params_slice from torax.geometry import geometry +from torax.physics import collisions from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_profiles @@ -289,8 +289,6 @@ def __call__( # pylint: disable=invalid-name -# Several variable names below follow physics notation matching so don't adhere -# to the lint guide. @dataclasses.dataclass class RuntimeParams(runtime_params_lib.RuntimeParams): """Runtime parameters for the ion cyclotron resonance source.""" @@ -460,7 +458,7 @@ def icrh_model_func( dynamic_source_runtime_params.Ptot / 1e6, # required in MW. ) helium3_mass = 3.016 - frac_ion_heating = physics.fast_ion_fractional_heating_formula( + frac_ion_heating = collisions.fast_ion_fractional_heating_formula( helium3_birth_energy, core_profiles.temp_el.value, helium3_mass, @@ -483,6 +481,8 @@ def icrh_model_func( source_ion += power_deposition_2T * dynamic_source_runtime_params.Ptot return (source_ion, source_el) + + # pylint: enable=invalid-name diff --git a/torax/sources/ohmic_heat_source.py b/torax/sources/ohmic_heat_source.py index 28e6c0e8..07beb8ad 100644 --- a/torax/sources/ohmic_heat_source.py +++ b/torax/sources/ohmic_heat_source.py @@ -22,13 +22,13 @@ import jax import jax.numpy as jnp from torax import constants -from torax import physics from torax import state from torax.config import runtime_params_slice from torax.fvm import cell_variable from torax.fvm import convection_terms from torax.fvm import diffusion_terms from torax.geometry import geometry +from torax.physics import psi_calculations from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_operations @@ -128,7 +128,7 @@ def ohmic_model_func( ' an explicit source.' ) - jtot, _, _ = physics.calc_jtot_from_psi( + jtot, _, _ = psi_calculations.calc_jtot( geo, core_profiles.psi, ) diff --git a/torax/sources/qei_source.py b/torax/sources/qei_source.py index 128cb4da..ff4e465b 100644 --- a/torax/sources/qei_source.py +++ b/torax/sources/qei_source.py @@ -22,10 +22,10 @@ import chex import jax from jax import numpy as jnp -from torax import physics from torax import state from torax.config import runtime_params_slice from torax.geometry import geometry +from torax.physics import collisions from torax.sources import runtime_params as runtime_params_lib from torax.sources import source from torax.sources import source_profiles @@ -137,7 +137,7 @@ def _model_based_qei( """Computes Qei via the coll_exchange model.""" assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) zeros = jnp.zeros_like(geo.rho_norm) - qei_coef = physics.coll_exchange( + qei_coef = collisions.coll_exchange( core_profiles=core_profiles, nref=dynamic_runtime_params_slice.numerics.nref, Qei_mult=dynamic_source_runtime_params.Qei_mult, diff --git a/torax/tests/physics.py b/torax/tests/physics.py deleted file mode 100644 index 7ff52f5d..00000000 --- a/torax/tests/physics.py +++ /dev/null @@ -1,556 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for torax.physics.""" - -import dataclasses -from typing import Callable - -from absl.testing import absltest -from absl.testing import parameterized -import jax -from jax import numpy as jnp -import numpy as np -from torax import constants -from torax import physics -from torax import state -from torax.core_profiles import initialization -from torax.fvm import cell_variable -from torax.geometry import pydantic_model as geometry_pydantic_model -from torax.geometry import standard_geometry -from torax.tests.test_lib import torax_refs - - -_trapz = jax.scipy.integrate.trapezoid - - -class PhysicsTest(torax_refs.ReferenceValueTest): - """Unit tests for the `torax.physics` module.""" - - @parameterized.parameters([ - dict(references_getter=torax_refs.circular_references), - dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict( - references_getter=torax_refs.chease_references_Ip_from_runtime_params - ), - ]) - def test_calc_q_from_psi( - self, references_getter: Callable[[], torax_refs.References] - ): - """Compare `calc_q_from_psi` function to a reference implementation.""" - references = references_getter() - - runtime_params = references.runtime_params - _, geo = ( - torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry( - runtime_params, - references.geometry_provider, - ) - ) - - q_face_jax, q_cell_jax = physics.calc_q_from_psi( - geo, - references.psi, - ) - - # Make ground truth - def calc_q_from_psi(geo): - """Reference implementation from PINT.""" - iota = np.zeros(geo.torax_mesh.nx + 1) # on face grid - # We use the reference value of psi here because the original code - # for calculating psi depends on FiPy, and we don't want to install that - iota[1:] = np.abs( - references.psi_face_grad[1:] / (2 * geo.Phib * geo.rho_face_norm[1:]) - ) - iota[0] = np.abs( - references.psi_face_grad[1] / (2 * geo.Phib * geo.drho_norm) - ) - q = 1 / iota - q *= geo.q_correction_factor - - def face_to_cell(face): - cell = np.zeros(geo.torax_mesh.nx) - cell[:] = 0.5 * (face[1:] + face[:-1]) - return cell - - q_cell = face_to_cell(q) - return q, q_cell - - q_face_np, q_cell_np = calc_q_from_psi(geo) - - np.testing.assert_allclose(q_face_jax, q_face_np) - np.testing.assert_allclose(q_cell_jax, q_cell_np) - - @parameterized.parameters([ - dict(references_getter=torax_refs.circular_references), - dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict( - references_getter=torax_refs.chease_references_Ip_from_runtime_params - ), - ]) - def test_calc_jtot_from_psi( - self, references_getter: Callable[[], torax_refs.References] - ): - """Compare `calc_jtot_from_psi` to a reference value.""" - references = references_getter() - geo = references.geometry_provider( - references.runtime_params.numerics.t_initial - ) - # pylint: disable=invalid-name - j, _, Ip_profile_face = physics.calc_jtot_from_psi( - geo, - references.psi, - ) - # pylint: enable=invalid-name - np.testing.assert_allclose(j, references.jtot, rtol=1e-5) - - if references.Ip_from_parameters: - np.testing.assert_allclose( - Ip_profile_face[-1], - references.runtime_params.profile_conditions.Ip_tot * 1e6, - ) - else: - assert isinstance(geo, standard_geometry.StandardGeometry) - np.testing.assert_allclose( - Ip_profile_face[-1], - geo.Ip_profile_face[-1], - ) - - @parameterized.parameters([ - dict(references_getter=torax_refs.circular_references), - dict(references_getter=torax_refs.chease_references_Ip_from_chease), - dict( - references_getter=torax_refs.chease_references_Ip_from_runtime_params - ), - ]) - def test_calc_s_from_psi( - self, references_getter: Callable[[], torax_refs.References] - ): - """Compare `calc_s_from_psi` to a reference value.""" - references = references_getter() - geo = references.geometry_provider( - references.runtime_params.numerics.t_initial - ) - - s = physics.calc_s_from_psi( - geo, - references.psi, - ) - - np.testing.assert_allclose(s, references.s, rtol=1e-5) - - def test_fast_ion_fractional_heating_formula(self): - """Compare `ion_heat_fraction` to a reference value.""" - # Inertial energy small compared to critical energy, all energy to ions. - birth_energy = 1e-3 - temp_el = jnp.array(0.1, dtype=jnp.float32) - fast_ion_mass = 1 - frac_i = physics.fast_ion_fractional_heating_formula( - birth_energy, temp_el, fast_ion_mass - ) - np.testing.assert_allclose(frac_i, 1.0, atol=1e-3) - - # Inertial energy large compared to critical energy, all energy to e-. - birth_energy = 1e10 - frac_i = physics.fast_ion_fractional_heating_formula( - birth_energy, temp_el, fast_ion_mass - ) - np.testing.assert_allclose(frac_i, 0.0, atol=1e-9) - - # TODO(b/377225415): generalize to arbitrary number of ions. - @parameterized.parameters([ - dict(Aimp=20.0, Zimp=10.0, Zi=1.0, Ai=1.0, ni=1.0, expected=1.0), - dict(Aimp=20.0, Zimp=10.0, Zi=1.0, Ai=2.0, ni=1.0, expected=0.5), - dict(Aimp=20.0, Zimp=10.0, Zi=2.0, Ai=4.0, ni=0.5, expected=0.5), - dict(Aimp=20.0, Zimp=10.0, Zi=1.0, Ai=2.0, ni=0.9, expected=0.5), - dict(Aimp=40.0, Zimp=20.0, Zi=1.0, Ai=2.0, ni=0.92, expected=0.5), - ]) - # pylint: disable=invalid-name - def test_calculate_weighted_Zeff(self, Aimp, Zimp, Zi, Ai, ni, expected): - """Compare `_calculate_weighted_Zeff` to a reference value.""" - references = torax_refs.circular_references() - geo = references.geometry_provider( - references.runtime_params.numerics.t_initial - ) - ne = 1.0 - nimp = (ne - ni * Zi) / Zimp - core_profiles = state.CoreProfiles( - ne=cell_variable.CellVariable( - value=jnp.array(ne), - dr=jnp.array(1.0), - ), - ni=cell_variable.CellVariable( - value=jnp.array(ni), - dr=jnp.array(1.0), - ), - nimp=cell_variable.CellVariable( - value=jnp.array(nimp), - dr=jnp.array(1.0), - ), - temp_ion=cell_variable.CellVariable( - value=jnp.array(0.0), - dr=jnp.array(1.0), - ), - temp_el=cell_variable.CellVariable( - value=jnp.array(0.0), - dr=jnp.array(1.0), - ), - psi=cell_variable.CellVariable( - value=jnp.array(0.0), - dr=jnp.array(1.0), - ), - psidot=cell_variable.CellVariable( - value=jnp.array(0.0), - dr=jnp.array(1.0), - ), - vloop_lcfs=jnp.array(0.0), - currents=state.Currents.zeros(geo), - q_face=jnp.array(0.0), - s_face=jnp.array(0.0), - Zi=Zi, - Zi_face=Zi, - Ai=Ai, - Zimp=Zimp, - Zimp_face=Zimp, - Aimp=Aimp, - nref=1e20, - ) - # pylint: enable=invalid-name - # pylint: disable=protected-access - np.testing.assert_allclose( - physics._calculate_weighted_Zeff(core_profiles), expected - ) - # pylint: enable=protected-access - - # TODO(b/377225415): generalize to arbitrary number of ions. - # pylint: disable=invalid-name - @parameterized.parameters([ - dict(Zi=1.0, Zimp=10.0, Zeff=1.0, expected=1.0), - dict(Zi=1.0, Zimp=5.0, Zeff=1.0, expected=1.0), - dict(Zi=2.0, Zimp=10.0, Zeff=2.0, expected=0.5), - dict(Zi=2.0, Zimp=5.0, Zeff=2.0, expected=0.5), - dict(Zi=1.0, Zimp=10.0, Zeff=1.9, expected=0.9), - dict(Zi=2.0, Zimp=10.0, Zeff=3.6, expected=0.4), - ]) - def test_get_main_ion_dilution_factor(self, Zi, Zimp, Zeff, expected): - """Unit test of `get_main_ion_dilution_factor`.""" - np.testing.assert_allclose( - physics.get_main_ion_dilution_factor(Zi, Zimp, Zeff), expected - ) - - # pylint: enable=invalid-name - - def test_calculate_plh_scaling_factor(self): - """Compare `calculate_plh_scaling_factor` to a reference value.""" - geo = geometry_pydantic_model.CircularConfig( - n_rho=25, - elongation_LCFS=1.0, - hires_fac=4, - Rmaj=6.0, - Rmin=2.0, - B0=5.0, - ).build_geometry() - core_profiles = state.CoreProfiles( - ne=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 2, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(2.0), - dr=geo.drho_norm, - ), - ni=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 1, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(1.0), - dr=geo.drho_norm, - ), - nimp=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 0, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(0.0), - dr=geo.drho_norm, - ), - temp_ion=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 0, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(0.0), - dr=geo.drho_norm, - ), - temp_el=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 0, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(0.0), - dr=geo.drho_norm, - ), - psi=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 0, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(0.0), - dr=geo.drho_norm, - ), - psidot=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 0, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(0.0), - dr=geo.drho_norm, - ), - vloop_lcfs=jnp.array(0.0), - currents=state.Currents.zeros(geo), - q_face=jnp.array(0.0), - s_face=jnp.array(0.0), - Zi=1.0, - Zi_face=1.0, - Ai=3.0, - Zimp=20, - Zimp_face=20, - Aimp=40, - nref=1e20, - ) - core_profiles = dataclasses.replace( - core_profiles, - currents=dataclasses.replace( - core_profiles.currents, - Ip_profile_face=jnp.ones_like(geo.rho_face_norm) * 10e6, - ), - ) - # pylint: disable=invalid-name - P_LH_hi_dens, P_LH_min, P_LH, ne_min_P_LH = ( - physics.calculate_plh_scaling_factor(geo, core_profiles) - ) - expected_PLH_hi_dens = ( - 2.15 * 2**0.782 * 5**0.772 * 2**0.975 * 6**0.999 * (2.0141 / 3) - ) - expected_PLH_min = ( - 0.36 * 10**0.27 * 5**1.25 * 6**1.23 * 3**0.08 * (2.0141 / 3) - ) - expected_ne_min_P_LH = 0.7 * 10**0.34 * 5**0.62 * 2.0**-0.95 * 3**0.4 / 10 - # pylint: enable=invalid-name - np.testing.assert_allclose(P_LH_hi_dens / 1e6, expected_PLH_hi_dens) - np.testing.assert_allclose(P_LH_min / 1e6, expected_PLH_min) - np.testing.assert_allclose(ne_min_P_LH, expected_ne_min_P_LH) - np.testing.assert_allclose(P_LH, P_LH_hi_dens) - - @parameterized.parameters([ - dict(elongation_LCFS=1.0), - dict(elongation_LCFS=1.5), - ]) - # pylint: disable=invalid-name - def test_calculate_scaling_law_confinement_time(self, elongation_LCFS): - """Compare `calculate_scaling_law_confinement_time` to reference values.""" - geo = geometry_pydantic_model.CircularConfig( - n_rho=25, - elongation_LCFS=elongation_LCFS, - hires_fac=4, - Rmaj=6.0, - Rmin=2.0, - B0=5.0, - ).build_geometry() - core_profiles = state.CoreProfiles( - ne=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 2, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(2.0), - dr=geo.drho_norm, - ), - ni=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 2, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(1.0), - dr=geo.drho_norm, - ), - nimp=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 0, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(0.0), - dr=geo.drho_norm, - ), - temp_ion=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 0, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(0.0), - dr=geo.drho_norm, - ), - temp_el=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 0, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(0.0), - dr=geo.drho_norm, - ), - psi=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 0, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(0.0), - dr=geo.drho_norm, - ), - psidot=cell_variable.CellVariable( - value=jnp.ones_like(geo.rho_norm) * 0, - left_face_grad_constraint=jnp.zeros(()), - right_face_grad_constraint=None, - right_face_constraint=jnp.array(0.0), - dr=geo.drho_norm, - ), - vloop_lcfs=jnp.array(0.0), - currents=state.Currents.zeros(geo), - q_face=jnp.array(0.0), - s_face=jnp.array(0.0), - Zi=1.0, - Zi_face=1.0, - Ai=3.0, - Zimp=20.0, - Zimp_face=20.0, - Aimp=40.0, - nref=1e20, - ) - core_profiles = dataclasses.replace( - core_profiles, - currents=dataclasses.replace( - core_profiles.currents, - Ip_profile_face=jnp.ones_like(geo.rho_face_norm) * 10e6, - ), - ) - Ploss = jnp.array(50e6) - - H89P = physics.calculate_scaling_law_confinement_time( - geo, core_profiles, Ploss, 'H89P' - ) - H98 = physics.calculate_scaling_law_confinement_time( - geo, core_profiles, Ploss, 'H98' - ) - H97L = physics.calculate_scaling_law_confinement_time( - geo, core_profiles, Ploss, 'H97L' - ) - H20 = physics.calculate_scaling_law_confinement_time( - geo, core_profiles, Ploss, 'H20' - ) - - expected_H89P = ( - 0.038128 - * 10**0.85 - * 5**0.2 - * 20**0.1 - * 50**-0.5 - * 6**1.5 - * (1 / 3) ** 0.3 - * 3**0.50 - * elongation_LCFS**0.50 - ) - - expected_H98 = ( - 0.0562 - * 10**0.93 - * 5**0.15 - * 20**0.41 - * 50**-0.69 - * 6**1.97 - * (1 / 3) ** 0.58 - * 3**0.19 - * elongation_LCFS**0.78 - ) - - expected_H97L = ( - 0.023 - * 10**0.96 - * 5**0.03 - * 20**0.4 - * 50**-0.73 - * 6**1.83 - * (1 / 3) ** -0.06 - * 3**0.20 - * elongation_LCFS**0.64 - ) - - expected_H20 = ( - 0.053 - * 10**0.98 - * 5**0.22 - * 20**0.24 - * 50**-0.669 - * 6**1.71 - * (1 / 3) ** 0.35 - * 3**0.20 - * elongation_LCFS**0.80 - ) - # pylint: enable=invalid-name - np.testing.assert_allclose(H89P, expected_H89P) - np.testing.assert_allclose(H98, expected_H98) - np.testing.assert_allclose(H97L, expected_H97L) - np.testing.assert_allclose(H20, expected_H20) - - # pylint: disable=invalid-name - def test_calc_Wpol(self): - """Compare `calc_Wpol` to an analytical formula in circular geometry.""" - - # Small inverse aspect ratio limit of circular geometry, such that we - # approximate the simplest form of circular geometry where the analytical - # Bpol formula is applicable. - geo = geometry_pydantic_model.CircularConfig( - n_rho=25, - elongation_LCFS=1.0, - Rmaj=100.0, - Rmin=1.0, - B0=5.0, - ).build_geometry() - Ip_tot = 15 - # calculate high resolution jtot consistent with total current profile - jtot_profile = (1 - geo.rho_hires_norm**2) ** 2 - denom = _trapz(jtot_profile * geo.spr_hires, geo.rho_hires_norm) - Ctot = Ip_tot * 1e6 / denom - jtot = jtot_profile * Ctot - # pylint: disable=protected-access - psi_cell_variable = initialization._update_psi_from_j( - Ip_tot, - geo, - jtot, - ) - _, _, Ip_profile_face = physics.calc_jtot_from_psi( - geo, - psi_cell_variable, - ) - - # Analytical formula for Bpol in circular geometry (Ampere's law) - Bpol_bulk = ( - constants.CONSTANTS.mu0 - * Ip_profile_face[1:] - / (2 * np.pi * geo.rho_face[1:]) - ) - Bpol = np.concatenate([np.array([0.0]), Bpol_bulk]) - - expected_Wpol = _trapz(Bpol**2 * geo.vpr_face, geo.rho_face_norm) / ( - 2 * constants.CONSTANTS.mu0 - ) - - calculated_Wpol = physics.calc_Wpol(geo, psi_cell_variable) - - # Relatively low tolerence because the analytical formula is not exact for - # our circular geometry, but approximates it at low inverse aspect ratio. - np.testing.assert_allclose(calculated_Wpol, expected_Wpol, rtol=1e-3) - - # pylint: enable=invalid-name - # pylint: enable=protected-access - - -if __name__ == '__main__': - absltest.main() diff --git a/torax/tests/post_processing.py b/torax/tests/post_processing.py index f03547ef..b23cbc03 100644 --- a/torax/tests/post_processing.py +++ b/torax/tests/post_processing.py @@ -180,7 +180,7 @@ def test_compute_stored_thermal_energy(self): ) # pylint: enable=protected-access - volume = np.trapz(geo.vpr_face, geo.rho_face_norm) + volume = np.trapezoid(geo.vpr_face, geo.rho_face_norm) np.testing.assert_allclose(wth_el, 1.5 * p_el[0] * volume) np.testing.assert_allclose(wth_ion, 1.5 * p_ion[0] * volume) diff --git a/torax/tests/test_lib/explicit_stepper.py b/torax/tests/test_lib/explicit_stepper.py index 2e612c18..2a46d293 100644 --- a/torax/tests/test_lib/explicit_stepper.py +++ b/torax/tests/test_lib/explicit_stepper.py @@ -23,12 +23,12 @@ import jax from jax import numpy as jnp from torax import constants -from torax import physics from torax import state from torax.config import runtime_params_slice from torax.core_profiles import updaters from torax.fvm import diffusion_terms from torax.geometry import geometry +from torax.physics import psi_calculations from torax.sources import source_operations from torax.sources import source_profile_builders from torax.sources import source_profiles @@ -67,8 +67,6 @@ def __call__( ]: """Applies a time step update. See Stepper.__call__ docstring.""" - # Many variables throughout this function are capitalized based on physics - # notational conventions rather than on Google Python style # pylint: disable=invalid-name # The explicit method is for testing purposes and @@ -130,11 +128,11 @@ def __call__( **updated_boundary_conditions['temp_ion'], ) - q_face, _ = physics.calc_q_from_psi( + q_face, _ = psi_calculations.calc_q( geo=geo_t, psi=core_profiles_t.psi, ) - s_face = physics.calc_s_from_psi(geo_t, core_profiles_t.psi) + s_face = psi_calculations.calc_s(geo_t, core_profiles_t.psi) # error isn't used for timestep adaptation for this method. # However, too large a timestep will lead to numerical instabilities. diff --git a/torax/tests/test_lib/torax_refs.py b/torax/tests/test_lib/torax_refs.py index 7e0f89cd..f1a7a4c8 100644 --- a/torax/tests/test_lib/torax_refs.py +++ b/torax/tests/test_lib/torax_refs.py @@ -48,6 +48,7 @@ class References: psi: fvm.cell_variable.CellVariable psi_face_grad: np.ndarray jtot: np.ndarray + q: np.ndarray s: np.ndarray Ip_from_parameters: bool # pylint: disable=invalid-name @@ -185,6 +186,34 @@ def circular_references() -> References: 5.96528658619718e03, -2.16166743039443e03, ]) + q = np.array([ + 0.65136284674552, + 0.65136284674552, + 0.66549200308491, + 0.68288435974473, + 0.70403585664969, + 0.72929778158408, + 0.75816199863277, + 0.78585816811502, + 0.79880472498816, + 0.78520449625846, + 0.76017886574672, + 0.75633510231278, + 0.79147108029917, + 0.86242114812994, + 0.95771945489906, + 1.06980964076381, + 1.19762252680817, + 1.34329397040352, + 1.50983774906477, + 1.70045978288947, + 1.91831678381418, + 2.16626740192572, + 2.44657621398095, + 2.76066019839069, + 3.10910068989888, + 3.49443220339156, + ]) s = np.array([ -0.0, 0.01061557184301, @@ -219,6 +248,7 @@ def circular_references() -> References: psi=psi, psi_face_grad=psi_face_grad, jtot=jtot, + q=q, s=s, Ip_from_parameters=True, ) @@ -335,6 +365,34 @@ def chease_references_Ip_from_chease() -> References: # pylint: disable=invalid 520019.67038794764, 580635.5973376503, ]) + q = np.array([ + 1.74778489687499, + 1.74778489687499, + 1.63019017120101, + 1.55135548433531, + 1.47944703389477, + 1.34587749007678, + 1.22208416710003, + 1.14174321035495, + 1.12073724167867, + 1.15719843686828, + 1.23446022617573, + 1.33693108455852, + 1.45471199796885, + 1.58807953262815, + 1.73889524213944, + 1.90938587558384, + 2.10182786843057, + 2.31924672085511, + 2.56737489962044, + 2.85095694625504, + 3.18252465224016, + 3.5949796873733, + 4.06580137921761, + 4.41721011791634, + 4.74589637288284, + 4.98383229828587, + ]) s = np.array([ -0.0, -0.03606779373088, @@ -369,6 +427,7 @@ def chease_references_Ip_from_chease() -> References: # pylint: disable=invalid psi=psi, psi_face_grad=psi_face_grad, jtot=jtot, + q=q, s=s, Ip_from_parameters=False, ) @@ -485,6 +544,34 @@ def chease_references_Ip_from_runtime_params() -> References: # pylint: disable 662741.7385489461, 739994.0178338734, ]) + q = np.array([ + 1.37139774532955, + 1.37139774532955, + 1.27912715646002, + 1.21726959491758, + 1.16084669815229, + 1.05604148352651, + 0.95890717122028, + 0.89586771645968, + 0.87938540325705, + 0.90799464514058, + 0.9686180341204, + 1.04902169500535, + 1.14143837590391, + 1.24608508423676, + 1.36442248625569, + 1.49819779849434, + 1.64919722386333, + 1.81979471817761, + 2.01448825599212, + 2.23700063727324, + 2.4971649202036, + 2.82079736847724, + 3.19022715803668, + 3.46595968828215, + 3.72386304343085, + 3.91055923940663, + ]) s = np.array([ -0.0, -0.03606779373088, @@ -519,6 +606,7 @@ def chease_references_Ip_from_runtime_params() -> References: # pylint: disable psi=psi, psi_face_grad=psi_face_grad, jtot=jtot, + q=q, s=s, Ip_from_parameters=True, ) diff --git a/torax/transport_model/bohm_gyrobohm.py b/torax/transport_model/bohm_gyrobohm.py index 9dda8be3..0356ee2f 100644 --- a/torax/transport_model/bohm_gyrobohm.py +++ b/torax/transport_model/bohm_gyrobohm.py @@ -139,8 +139,6 @@ def _call_implementation( coeffs: The transport coefficients """ del pedestal_model_outputs # Unused. - # Many variables throughout this function are capitalized based on physics - # notational conventions rather than on Google Python style # pylint: disable=invalid-name assert isinstance( dynamic_runtime_params_slice.transport, DynamicRuntimeParams diff --git a/torax/transport_model/critical_gradient.py b/torax/transport_model/critical_gradient.py index d7606e26..31af9c14 100644 --- a/torax/transport_model/critical_gradient.py +++ b/torax/transport_model/critical_gradient.py @@ -107,7 +107,14 @@ def _call_implementation( core_profiles: state.CoreProfiles, pedestal_model_outputs: pedestal_model_lib.PedestalModelOutput, ) -> state.CoreTransport: - """Calculates transport coefficients using the Critical Gradient Model. + r"""Calculates transport coefficients using the Critical Gradient Model. + + Uses critical normalized logarithmic ion temperature gradient + :math:`R/L_{Ti}|_crit` from Guo Romanelli 1993: + :math:`\chi_i = \chi_{GB} \chi_{stiff} H(R/L_{Ti} - R/L_{Ti})` + where :math:`\chi_{GB}` is the GyroBohm diffusivity, + :math:`\chi_{stiff}` is the stiffness parameter, and + :math:`H` is the Heaviside function. Args: dynamic_runtime_params_slice: Input runtime parameters that can change @@ -120,14 +127,8 @@ def _call_implementation( coeffs: The transport coefficients """ - # Many variables throughout this function are capitalized based on physics - # notational conventions rather than on Google Python style # pylint: disable=invalid-name - # ITG critical gradient model. R/LTi_crit from Guo Romanelli 1993 - # chi_i = chiGB * chistiff * H(R/LTi - - # R/LTi_crit)*(R/LTi - R/LTi_crit)^alpha - constants = constants_module.CONSTANTS assert isinstance( dynamic_runtime_params_slice.transport, DynamicRuntimeParams diff --git a/torax/transport_model/qualikiz_based_transport_model.py b/torax/transport_model/qualikiz_based_transport_model.py index cc74df3d..5ea7f64c 100644 --- a/torax/transport_model/qualikiz_based_transport_model.py +++ b/torax/transport_model/qualikiz_based_transport_model.py @@ -16,9 +16,10 @@ import chex from jax import numpy as jnp from torax import constants as constants_module -from torax import physics from torax import state from torax.geometry import geometry +from torax.physics import collisions +from torax.physics import psi_calculations from torax.transport_model import quasilinear_transport_model from torax.transport_model import runtime_params as runtime_params_lib @@ -144,11 +145,11 @@ def _prepare_qualikiz_inputs( # Calculate q and s. # Need to recalculate since in the nonlinear solver psi has intermediate # states in the iterative solve. - q, _ = physics.calc_q_from_psi( + q, _ = psi_calculations.calc_q( geo=geo, psi=core_profiles.psi, ) - smag = physics.calc_s_from_psi_rmid( + smag = psi_calculations.calc_s_rmid( geo, core_profiles.psi, ) @@ -165,7 +166,7 @@ def _prepare_qualikiz_inputs( ) # logarithm of normalized collisionality - nu_star = physics.calc_nu_star( + nu_star = collisions.calc_nu_star( geo=geo, core_profiles=core_profiles, nref=nref,