Skip to content

Commit

Permalink
Reorganize physics package.
Browse files Browse the repository at this point in the history
Previously, various physics calculations which were not directly associated with an existing component like sources, transport_models, were dispersed between physics.py and post_processing.py.

This refactor introduces a physics package and clearer organization of physics functions into thematic modules: collisions, charge_states, psi_calculations, and scaling_laws.

Drive-by: cleanup of unnecessary comments and docstring fixes

Upcoming PR will move "physics" calculations from post_processing into the physics package.

PiperOrigin-RevId: 732272746
  • Loading branch information
jcitrin authored and Torax team committed Mar 3, 2025
1 parent dcab336 commit 73e4f23
Show file tree
Hide file tree
Showing 37 changed files with 1,486 additions and 1,386 deletions.
2 changes: 1 addition & 1 deletion torax/config/tests/plasma_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax import charge_states
from torax import interpolated_param
from torax.config import plasma_composition
from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.physics import charge_states


class PlasmaCompositionTest(parameterized.TestCase):
Expand Down
13 changes: 11 additions & 2 deletions torax/core_profiles/formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@

_trapz = jax.scipy.integrate.trapezoid

# Using capitalized variables for physics notational conventions rather than
# Python style.
# pylint: disable=invalid-name


Expand Down Expand Up @@ -157,3 +155,14 @@ def calculate_psi_grad_constraint_from_Ip_tot(
* (16 * jnp.pi**3 * constants.CONSTANTS.mu0 * geo.Phib)
/ (geo.g2g3_over_rhon_face[-1] * geo.F_face[-1])
)


# TODO(b/377225415): generalize to arbitrary number of ions.
def get_main_ion_dilution_factor(
Zi: array_typing.ScalarFloat,
Zimp: array_typing.ArrayFloat,
Zeff: array_typing.ArrayFloat,
) -> jax.Array:
"""Calculates the main ion dilution factor based on a single assumed impurity and general main ion charge."""
return (Zimp - Zeff) / (Zi * (Zimp - Zi))

12 changes: 4 additions & 8 deletions torax/core_profiles/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,21 @@
from torax import array_typing
from torax import constants
from torax import math_utils
from torax import physics
from torax import state
from torax.config import runtime_params_slice
from torax.core_profiles import formulas
from torax.core_profiles import updaters
from torax.fvm import cell_variable
from torax.geometry import geometry
from torax.geometry import standard_geometry
from torax.physics import psi_calculations
from torax.sources import ohmic_heat_source
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib

_trapz = jax.scipy.integrate.trapezoid

# Using capitalized variables for physics notational conventions rather than
# Python style.
# pylint: disable=invalid-name


Expand Down Expand Up @@ -133,7 +131,7 @@ def initial_core_profiles(
)

# Set psi as source of truth and recalculate jtot, q, s
return physics.update_jtot_q_face_s_face(
return psi_calculations.update_jtot_q_face_s_face(
geo=geo,
core_profiles=core_profiles,
)
Expand All @@ -147,8 +145,6 @@ def _prescribe_currents(
) -> state.Currents:
"""Creates the initial Currents from a given bootstrap profile."""

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot
f_bootstrap = bootstrap_profile.I_bootstrap / (Ip * 1e6)

Expand Down Expand Up @@ -201,7 +197,7 @@ def _calculate_currents_from_psi(
source_profiles: source_profiles_lib.SourceProfiles,
) -> state.Currents:
"""Creates the initial Currents using psi to calculate jtot."""
jtot, jtot_face, Ip_profile_face = physics.calc_jtot_from_psi(
jtot, jtot_face, Ip_profile_face = psi_calculations.calc_jtot(
geo,
core_profiles.psi,
)
Expand Down Expand Up @@ -474,7 +470,7 @@ def _init_psi_psidot_vloop_and_current(
currents.jtot_hires,
use_vloop_lcfs_boundary_condition=use_vloop_bc,
)
_, _, Ip_profile_face = physics.calc_jtot_from_psi(
_, _, Ip_profile_face = psi_calculations.calc_jtot(
geo,
psi,
)
Expand Down
19 changes: 19 additions & 0 deletions torax/core_profiles/tests/formulas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,25 @@ def test_ne_core_profile_setter_with_fGW(
np.all(np.isclose(ratio, ratio[0]))
self.assertNotEqual(ratio[0], 1.0)

# TODO(b/377225415): generalize to arbitrary number of ions.
# pylint: disable=invalid-name
@parameterized.parameters([
dict(Zi=1.0, Zimp=10.0, Zeff=1.0, expected=1.0),
dict(Zi=1.0, Zimp=5.0, Zeff=1.0, expected=1.0),
dict(Zi=2.0, Zimp=10.0, Zeff=2.0, expected=0.5),
dict(Zi=2.0, Zimp=5.0, Zeff=2.0, expected=0.5),
dict(Zi=1.0, Zimp=10.0, Zeff=1.9, expected=0.9),
dict(Zi=2.0, Zimp=10.0, Zeff=3.6, expected=0.4),
])
def test_get_main_ion_dilution_factor(self, Zi, Zimp, Zeff, expected):
"""Unit test of `get_main_ion_dilution_factor`."""
np.testing.assert_allclose(
formulas.get_main_ion_dilution_factor(Zi, Zimp, Zeff),
expected,
)

# pylint: enable=invalid-name


if __name__ == '__main__':
absltest.main()
45 changes: 20 additions & 25 deletions torax/core_profiles/tests/initialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from torax.config import runtime_params_slice as runtime_params_slice_lib
from torax.core_profiles import initialization
from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.geometry import standard_geometry
from torax.sources import generic_current_source
from torax.sources import source_models as source_models_lib
from torax.sources import source_profiles
Expand All @@ -41,7 +40,6 @@ def setUp(self):
self.geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry()

@parameterized.parameters([
dict(references_getter=torax_refs.circular_references),
dict(references_getter=torax_refs.chease_references_Ip_from_chease),
dict(
references_getter=torax_refs.chease_references_Ip_from_runtime_params
Expand All @@ -67,29 +65,26 @@ def test_update_psi_from_j(
},
)
)
if isinstance(geo, standard_geometry.StandardGeometry):
psi = geo.psi_from_Ip
else:
bootstrap = source_profiles.BootstrapCurrentProfile.zero_profile(geo)
external_current = generic_current_source.calculate_generic_current(
mock.ANY,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
source_name=generic_current_source.GenericCurrentSource.SOURCE_NAME,
unused_state=mock.ANY,
unused_calculated_source_profiles=mock.ANY,
)[0]
currents = initialization._prescribe_currents(
bootstrap_profile=bootstrap,
external_current=external_current,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
)
psi = initialization._update_psi_from_j(
dynamic_runtime_params_slice.profile_conditions.Ip_tot,
geo,
currents.jtot_hires,
).value
bootstrap = source_profiles.BootstrapCurrentProfile.zero_profile(geo)
external_current = generic_current_source.calculate_generic_current(
mock.ANY,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
source_name=generic_current_source.GenericCurrentSource.SOURCE_NAME,
unused_state=mock.ANY,
unused_calculated_source_profiles=mock.ANY,
)[0]
currents = initialization._prescribe_currents(
bootstrap_profile=bootstrap,
external_current=external_current,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
)
psi = initialization._update_psi_from_j(
dynamic_runtime_params_slice.profile_conditions.Ip_tot,
geo,
currents.jtot_hires,
).value
np.testing.assert_allclose(psi, references.psi.value)

@parameterized.parameters(
Expand Down
22 changes: 20 additions & 2 deletions torax/core_profiles/tests/updaters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from jax import numpy as jnp
import numpy as np
from torax import jax_utils
from torax import physics
from torax import state
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
Expand Down Expand Up @@ -91,7 +90,7 @@ def test_get_ion_density_and_charge_states(self):

Zeff = dynamic_runtime_params_slice.plasma_composition.Zeff

dilution_factor = physics.get_main_ion_dilution_factor(Zi, Zimp, Zeff)
dilution_factor = formulas.get_main_ion_dilution_factor(Zi, Zimp, Zeff)
np.testing.assert_allclose(
ni.value,
expected_value * dilution_factor,
Expand Down Expand Up @@ -355,6 +354,25 @@ def test_compute_boundary_conditions_Ti(
expected_Ti_bound_right,
)

# TODO(b/377225415): generalize to arbitrary number of ions.
# pylint: disable=invalid-name
@parameterized.parameters([
dict(Zi=1.0, Zimp=10.0, Zeff=1.0, expected=1.0),
dict(Zi=1.0, Zimp=5.0, Zeff=1.0, expected=1.0),
dict(Zi=2.0, Zimp=10.0, Zeff=2.0, expected=0.5),
dict(Zi=2.0, Zimp=5.0, Zeff=2.0, expected=0.5),
dict(Zi=1.0, Zimp=10.0, Zeff=1.9, expected=0.9),
dict(Zi=2.0, Zimp=10.0, Zeff=3.6, expected=0.4),
])
def test_get_main_ion_dilution_factor(self, Zi, Zimp, Zeff, expected):
"""Unit test of `get_main_ion_dilution_factor`."""
np.testing.assert_allclose(
formulas.get_main_ion_dilution_factor(Zi, Zimp, Zeff),
expected,
)

# pylint: enable=invalid-name


if __name__ == '__main__':
absltest.main()
11 changes: 4 additions & 7 deletions torax/core_profiles/updaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,17 @@
import jax
from jax import numpy as jnp
from torax import array_typing
from torax import charge_states
from torax import jax_utils
from torax import physics
from torax import state
from torax.config import runtime_params_slice
from torax.core_profiles import formulas
from torax.fvm import cell_variable
from torax.geometry import geometry
from torax.physics import charge_states

_trapz = jax.scipy.integrate.trapezoid


# Using capitalized variables for physics notational conventions rather than
# Python style.
# pylint: disable=invalid-name
def _get_charge_states(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
Expand Down Expand Up @@ -128,8 +125,8 @@ def get_ion_density_and_charge_states(
Zeff = dynamic_runtime_params_slice.plasma_composition.Zeff
Zeff_face = dynamic_runtime_params_slice.plasma_composition.Zeff_face

dilution_factor = physics.get_main_ion_dilution_factor(Zi, Zimp, Zeff)
dilution_factor_edge = physics.get_main_ion_dilution_factor(
dilution_factor = formulas.get_main_ion_dilution_factor(Zi, Zimp, Zeff)
dilution_factor_edge = formulas.get_main_ion_dilution_factor(
Zi_face[-1], Zimp_face[-1], Zeff_face[-1]
)

Expand Down Expand Up @@ -347,7 +344,7 @@ def compute_boundary_conditions_for_t_plus_dt(
Te=Te_bound_right,
)

dilution_factor_edge = physics.get_main_ion_dilution_factor(
dilution_factor_edge = formulas.get_main_ion_dilution_factor(
Zi_edge,
Zimp_edge,
dynamic_runtime_params_slice_t_plus_dt.plasma_composition.Zeff_face[-1],
Expand Down
18 changes: 16 additions & 2 deletions torax/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import jax.numpy as jnp
from torax import constants
from torax import jax_utils
from torax import physics
from torax import state
from torax.config import config_args
from torax.config import runtime_params_slice
Expand Down Expand Up @@ -400,7 +399,7 @@ def _calc_coeffs_full(

# Boolean mask for enforcing internal temperature boundary conditions to
# model the pedestal.
mask = physics.internal_boundary(
mask = _internal_boundary(
geo,
pedestal_model_output.rho_norm_ped_top,
dynamic_runtime_params_slice.profile_conditions.set_pedestal,
Expand Down Expand Up @@ -784,3 +783,18 @@ def _calc_coeffs_reduced(
transient_in_cell=transient_in_cell,
)
return coeffs


# pylint: disable=invalid-name
def _internal_boundary(
geo: geometry.Geometry,
Ped_top: jax.Array,
set_pedestal: jax.Array,
) -> jax.Array:
# Create Boolean mask FiPy CellVariable with True where the internal boundary
# condition is
# find index closest to pedestal top.
idx = jnp.abs(geo.rho_norm - Ped_top).argmin()
mask_np = jnp.zeros(len(geo.rho), dtype=bool)
mask_np = jnp.where(set_pedestal, mask_np.at[idx].set(True), mask_np)
return mask_np
9 changes: 5 additions & 4 deletions torax/orchestration/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
import jax
import jax.numpy as jnp
from torax import jax_utils
from torax import physics
from torax import post_processing
from torax import state
from torax.config import runtime_params_slice
from torax.core_profiles import updaters
from torax.geometry import geometry
from torax.geometry import geometry_provider as geometry_provider_lib
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.physics import psi_calculations
from torax.sources import ohmic_heat_source
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
Expand Down Expand Up @@ -525,7 +525,7 @@ def finalize_output(
"""

# Update total current, q, and s profiles based on new psi
output_state.core_profiles = physics.update_jtot_q_face_s_face(
output_state.core_profiles = psi_calculations.update_jtot_q_face_s_face(
geo=geo_t_plus_dt,
core_profiles=output_state.core_profiles,
)
Expand Down Expand Up @@ -660,7 +660,8 @@ def _update_current_distribution(
bootstrap_profile = core_sources.j_bootstrap
# Needed for the case where no psi sources are present.
external_current = jnp.zeros_like(
core_profiles.currents.external_current_source)
core_profiles.currents.external_current_source
)
external_current += sum(core_sources.psi.values())

johm = (
Expand Down Expand Up @@ -699,7 +700,7 @@ def _update_psidot(
resistivity_multiplier=dynamic_runtime_params_slice.numerics.resistivity_mult,
psi=core_profiles.psi,
geo=geo,
)
),
)

new_core_profiles = dataclasses.replace(
Expand Down
4 changes: 2 additions & 2 deletions torax/pedestal_model/set_pped_tpedratio_nped.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
from torax import array_typing
from torax import constants
from torax import interpolated_param
from torax import physics
from torax import state
from torax.config import runtime_params_slice
from torax.core_profiles import formulas
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model
from torax.pedestal_model import runtime_params as runtime_params_lib
Expand Down Expand Up @@ -130,7 +130,7 @@ def _call_implementation(
Zeff_ped = Zeff[ped_idx]
Zi_ped = Zi[ped_idx]
Zimp_ped = Zimp[ped_idx]
dilution_factor_ped = physics.get_main_ion_dilution_factor(
dilution_factor_ped = formulas.get_main_ion_dilution_factor(
Zi_ped, Zimp_ped, Zeff_ped
)
# Calculate ni and nimp.
Expand Down
Loading

0 comments on commit 73e4f23

Please sign in to comment.