Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor physics calculations out of post_processing.py #786

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,25 @@
import jax
from jax import numpy as jnp
from torax import array_typing
from torax import charge_states
from torax import constants
from torax import jax_utils
from torax import math_utils
from torax import physics
from torax import state
from torax.config import numerics
from torax.config import profile_conditions
from torax.config import runtime_params_slice
from torax.fvm import cell_variable
from torax.geometry import geometry
from torax.geometry import standard_geometry
from torax.physics import charge_states
from torax.physics import psi_calculations
from torax.sources import ohmic_heat_source
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib

_trapz = jax.scipy.integrate.trapezoid

# Using capitalized variables for physics notational conventions rather than
# Python style.
# pylint: disable=invalid-name


Expand Down Expand Up @@ -254,8 +252,8 @@ def get_ion_density_and_charge_states(
Zeff = dynamic_runtime_params_slice.plasma_composition.Zeff
Zeff_face = dynamic_runtime_params_slice.plasma_composition.Zeff_face

dilution_factor = physics.get_main_ion_dilution_factor(Zi, Zimp, Zeff)
dilution_factor_edge = physics.get_main_ion_dilution_factor(
dilution_factor = get_main_ion_dilution_factor(Zi, Zimp, Zeff)
dilution_factor_edge = get_main_ion_dilution_factor(
Zi_face[-1], Zimp_face[-1], Zeff_face[-1]
)

Expand Down Expand Up @@ -288,8 +286,6 @@ def _prescribe_currents(
) -> state.Currents:
"""Creates the initial Currents from a given bootstrap profile."""

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot
f_bootstrap = bootstrap_profile.I_bootstrap / (Ip * 1e6)

Expand Down Expand Up @@ -342,7 +338,7 @@ def _calculate_currents_from_psi(
source_profiles: source_profiles_lib.SourceProfiles,
) -> state.Currents:
"""Creates the initial Currents using psi to calculate jtot."""
jtot, jtot_face, Ip_profile_face = physics.calc_jtot_from_psi(
jtot, jtot_face, Ip_profile_face = psi_calculations.calc_jtot(
geo,
core_profiles.psi,
)
Expand Down Expand Up @@ -642,7 +638,7 @@ def _init_psi_psidot_vloop_and_current(
currents.jtot_hires,
use_vloop_lcfs_boundary_condition=use_vloop_bc,
)
_, _, Ip_profile_face = physics.calc_jtot_from_psi(
_, _, Ip_profile_face = psi_calculations.calc_jtot(
geo,
psi,
)
Expand Down Expand Up @@ -780,7 +776,7 @@ def initial_core_profiles(
)

# Set psi as source of truth and recalculate jtot, q, s
return physics.update_jtot_q_face_s_face(
return psi_calculations.update_jtot_q_face_s_face(
geo=geo,
core_profiles=core_profiles,
)
Expand Down Expand Up @@ -965,7 +961,7 @@ def compute_boundary_conditions_for_t_plus_dt(
Te=Te_bound_right,
)

dilution_factor_edge = physics.get_main_ion_dilution_factor(
dilution_factor_edge = get_main_ion_dilution_factor(
Zi_edge,
Zimp_edge,
dynamic_runtime_params_slice_t_plus_dt.plasma_composition.Zeff_face[-1],
Expand Down Expand Up @@ -1063,3 +1059,13 @@ def _get_jtot_hires(
johm_hires = jformula_hires * Cohm_hires
jtot_hires = johm_hires + external_current_hires + j_bootstrap_hires
return jtot_hires


# TODO(b/377225415): generalize to arbitrary number of ions.
def get_main_ion_dilution_factor(
Zi: array_typing.ScalarFloat,
Zimp: array_typing.ArrayFloat,
Zeff: array_typing.ArrayFloat,
) -> jax.Array:
"""Calculates the main ion dilution factor based on a single assumed impurity and general main ion charge."""
return (Zimp - Zeff) / (Zi * (Zimp - Zi))
192 changes: 192 additions & 0 deletions torax/core_profiles/formulas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common functions used for working with core profiles."""

import jax
from jax import numpy as jnp
from torax import array_typing
from torax import constants
from torax import state
from torax.geometry import geometry

_trapz = jax.scipy.integrate.trapezoid


def compute_pressure(
core_profiles: state.CoreProfiles,
) -> tuple[array_typing.ArrayFloat, ...]:
"""Calculates pressure from density and temperatures on the face grid.

Args:
core_profiles: CoreProfiles object containing information on temperatures
and densities.

Returns:
pressure_thermal_el_face: Electron thermal pressure [Pa]
pressure_thermal_ion_face: Ion thermal pressure [Pa]
pressure_thermal_tot_face: Total thermal pressure [Pa]
"""
ne = core_profiles.ne.face_value()
ni = core_profiles.ni.face_value()
nimp = core_profiles.nimp.face_value()
temp_ion = core_profiles.temp_ion.face_value()
temp_el = core_profiles.temp_el.face_value()
prefactor = constants.CONSTANTS.keV2J * core_profiles.nref
pressure_thermal_el_face = ne * temp_el * prefactor
pressure_thermal_ion_face = (ni + nimp) * temp_ion * prefactor
pressure_thermal_tot_face = (
pressure_thermal_el_face + pressure_thermal_ion_face
)
return (
pressure_thermal_el_face,
pressure_thermal_ion_face,
pressure_thermal_tot_face,
)


def compute_pprime(
core_profiles: state.CoreProfiles,
) -> array_typing.ArrayFloat:
r"""Calculates total pressure gradient with respect to poloidal flux.

Args:
core_profiles: CoreProfiles object containing information on temperatures
and densities.

Returns:
pprime: Total pressure gradient :math:`\partial p / \partial \psi`
with respect to the normalized toroidal flux coordinate, on the face grid.
"""

prefactor = constants.CONSTANTS.keV2J * core_profiles.nref

ne = core_profiles.ne.face_value()
ni = core_profiles.ni.face_value()
nimp = core_profiles.nimp.face_value()
temp_ion = core_profiles.temp_ion.face_value()
temp_el = core_profiles.temp_el.face_value()
dne_drhon = core_profiles.ne.face_grad()
dni_drhon = core_profiles.ni.face_grad()
dnimp_drhon = core_profiles.nimp.face_grad()
dti_drhon = core_profiles.temp_ion.face_grad()
dte_drhon = core_profiles.temp_el.face_grad()
dpsi_drhon = core_profiles.psi.face_grad()

dptot_drhon = prefactor * (
ne * dte_drhon
+ ni * dti_drhon
+ nimp * dti_drhon
+ dne_drhon * temp_el
+ dni_drhon * temp_ion
+ dnimp_drhon * temp_ion
)
# Calculate on-axis value with L'Hôpital's rule.
pprime_face_axis = jnp.expand_dims(dptot_drhon[1] / dpsi_drhon[1], axis=0)

# Zero on-axis due to boundary conditions. Avoid division by zero.
pprime_face = jnp.concatenate(
[pprime_face_axis, dptot_drhon[1:] / dpsi_drhon[1:]]
)

return pprime_face


# pylint: disable=invalid-name
def compute_FFprime(
core_profiles: state.CoreProfiles,
geo: geometry.Geometry,
) -> array_typing.ArrayFloat:
r"""Calculates FF', an output quantity used for equilibrium coupling.

Calculation is based on the following formulation of the magnetic
equilibrium equation:
:math:`-j_{tor} = 2\pi (Rp' + \frac{1}{\mu_0 R}FF')`

And following division by R and flux surface averaging:

:math:`-\langle \frac{j_{tor}}{R} \rangle = 2\pi (p' +
\langle\frac{1}{R^2}\rangle\frac{FF'}{\mu_0})`

Args:
core_profiles: CoreProfiles object containing information on temperatures
and densities.
geo: Magnetic equilibrium.

Returns:
FFprime: F is the toroidal flux function, and F' is its derivative with
respect to the poloidal flux.
"""

mu0 = constants.CONSTANTS.mu0
pprime = compute_pprime(core_profiles)
# g3 = <1/R^2>
g3 = geo.g3_face
jtor_over_R = core_profiles.currents.jtot_face / geo.Rmaj

FFprime_face = -(jtor_over_R / (2 * jnp.pi) + pprime) * mu0 / g3
return FFprime_face


# pylint: enable=invalid-name


def compute_stored_thermal_energy(
p_el: array_typing.ArrayFloat,
p_ion: array_typing.ArrayFloat,
p_tot: array_typing.ArrayFloat,
geo: geometry.Geometry,
) -> tuple[array_typing.ScalarFloat, ...]:
"""Calculates stored thermal energy from pressures.

Args:
p_el: Electron pressure [Pa]
p_ion: Ion pressure [Pa]
p_tot: Total pressure [Pa]
geo: Geometry object

Returns:
wth_el: Electron thermal stored energy [J]
wth_ion: Ion thermal stored energy [J]
wth_tot: Total thermal stored energy [J]
"""
wth_el = _trapz(1.5 * p_el * geo.vpr_face, geo.rho_face_norm)
wth_ion = _trapz(1.5 * p_ion * geo.vpr_face, geo.rho_face_norm)
wth_tot = _trapz(1.5 * p_tot * geo.vpr_face, geo.rho_face_norm)

return wth_el, wth_ion, wth_tot


def calculate_greenwald_fraction(
ne_avg: array_typing.ScalarFloat,
core_profiles: state.CoreProfiles,
geo: geometry.Geometry,
) -> array_typing.ScalarFloat:
"""Calculates the Greenwald fraction from the averaged electron density.

Different averaging can be used, e.g. volume-averaged or line-averaged.

Args:
ne_avg: Averaged electron density [nref m^-3]
core_profiles: CoreProfiles object containing information on currents and
densities.
geo: Geometry object

Returns:
fgw: Greenwald density fraction
"""
# gw_limit is in units of 10^20 m^-3 when Ip is in MA and Rmin is in m.
gw_limit = core_profiles.currents.Ip_total * 1e-6 / (jnp.pi * geo.Rmin**2)
fgw = ne_avg * core_profiles.nref / (gw_limit * 1e20)
return fgw
Loading
Loading