Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in Geometry where z_magnetic_axis isn't accessible under JIT. #502

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions torax/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class Geometry:
rho_hires: chex.Array
vpr_hires: chex.Array
Phibdot: chex.Array
_z_magnetic_axis: chex.Array
z_magnetic_axis: chex.Array

@property
def rho_norm(self) -> chex.Array:
Expand Down Expand Up @@ -247,18 +247,6 @@ def g1_over_vpr2_face(self) -> jax.Array:
self.g1_face[1:] / self.vpr_face[1:] ** 2, # avoid div by zero on-axis
))

@property
def z_magnetic_axis(self) -> chex.Numeric:
if self.geometry_type in [
GeometryType.CHEASE.value,
GeometryType.CIRCULAR.value,
]:
logging.warning(
'z_magnetic_axis is not defined for CHEASE or CIRCULAR geometry type.'
' Returning 0.',
)
return self._z_magnetic_axis


@chex.dataclass(frozen=True)
class GeometryProvider:
Expand Down Expand Up @@ -306,7 +294,7 @@ class GeometryProvider:
rho_hires_norm: interpolated_param.InterpolatedVarSingleAxis
rho_hires: interpolated_param.InterpolatedVarSingleAxis
vpr_hires: interpolated_param.InterpolatedVarSingleAxis
_z_magnetic_axis: interpolated_param.InterpolatedVarSingleAxis
z_magnetic_axis: interpolated_param.InterpolatedVarSingleAxis

@classmethod
def create_provider(
Expand Down Expand Up @@ -645,7 +633,7 @@ def build_circular_geometry(
# and geo_t_plus_dt are provided, and set to be the same for geo_t and
# geo_t_plus_dt for each given time interval.
Phibdot=np.asarray(0.0),
_z_magnetic_axis=np.asarray(0.0),
z_magnetic_axis=np.asarray(0.0),
)


Expand Down Expand Up @@ -830,6 +818,10 @@ def from_chease(
rhon = np.sqrt(Phi / Phi[-1])
vpr = 4 * np.pi * Phi[-1] * rhon / (F * flux_surf_avg_1_over_R2)

logging.warning(
'z_magnetic_axis field is not present in CHEASE data, setting to 0.0'
)

return cls(
geometry_type=GeometryType.CHEASE,
Ip_from_parameters=Ip_from_parameters,
Expand Down Expand Up @@ -1525,7 +1517,7 @@ def build_standard_geometry(
# and geo_t_plus_dt are provided, and set to be the same for geo_t and
# geo_t_plus_dt for each given time interval.
Phibdot=np.asarray(0.0),
_z_magnetic_axis=intermediate.z_magnetic_axis,
z_magnetic_axis=intermediate.z_magnetic_axis,
)


Expand Down
12 changes: 12 additions & 0 deletions torax/tests/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,18 @@ def test_build_geometry_from_eqdsk(self):
intermediate = geometry.StandardGeometryIntermediates.from_eqdsk()
geometry.build_standard_geometry(intermediate)

def test_geometry_objects_can_be_used_in_jax_jitted_functions(self):
"""Test public API of geometry objects can be used in jitted functions."""
geo = geometry.build_circular_geometry()

@jax.jit
def f(geo: geometry.Geometry):
for field in dir(geo):
if not field.startswith('_'):
getattr(geo, field)

f(geo)


def face_to_cell(n_rho, face):
cell = np.zeros(n_rho)
Expand Down