Skip to content

Commit

Permalink
Merge pull request #61 from meom-group/lax-jit-imp
Browse files Browse the repository at this point in the history
improve the jit graph operations
  • Loading branch information
vadmbertr authored Apr 16, 2024
2 parents cfed7ea + a6768a9 commit cc60090
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 115 deletions.
58 changes: 21 additions & 37 deletions jaxparrow/cyclogeostrophy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import partial
import numbers
from typing import Literal, Union

from jax import jit, lax, value_and_grad
Expand Down Expand Up @@ -41,7 +40,6 @@ def cyclogeostrophy(
optim: Union[optax.GradientTransformation, str] = "sgd",
optim_kwargs: dict = None,
res_eps: float = RES_EPS_IT,
res_init: Union[float, Literal["same"]] = RES_INIT_IT,
use_res_filter: bool = False,
res_filter_size: int = RES_FILTER_SIZE_IT,
return_geos: bool = False,
Expand Down Expand Up @@ -90,12 +88,6 @@ def cyclogeostrophy(
When residuals are smaller, the iterative approach considers local convergence to cyclogeostrophy.
Defaults to ``RES_EPS_IT``
res_init : Union[float | Literal["same"]], optional
Residual initial value of the iterative approach.
When residuals are larger at the first iteration,
the iterative approach considers local divergence to cyclogeostrophy.
If equals to `same` (default) absolute values of the geostrophic velocities are used
use_res_filter : bool, optional
Use of a convolution filter for the iterative approach when computing the residuals [3]_ or not [2]_.
Expand Down Expand Up @@ -159,11 +151,22 @@ def cyclogeostrophy(
coriolis_factor_v = sanitize.sanitize_data(coriolis_factor_v, jnp.nan, mask)

if method == "variational":
if n_it is None:
n_it = N_IT_VAR
if isinstance(optim, str):
if optim_kwargs is None:
optim_kwargs = {"learning_rate": LR_VAR}
optim = getattr(optax, optim)(**optim_kwargs)
elif not isinstance(optim, optax.GradientTransformation):
raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such "
"an optimizer.")
res = _variational(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask,
n_it, optim, optim_kwargs, return_losses)
n_it, optim, return_losses)
elif method == "iterative":
if n_it is None:
n_it = N_IT_IT
res = _iterative(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask,
n_it, res_eps, res_init, use_res_filter, res_filter_size, return_losses)
n_it, res_eps, use_res_filter, res_filter_size, return_losses)
else:
raise ValueError("method should be one of [\"variational\", \"iterative\"]")

Expand Down Expand Up @@ -217,6 +220,7 @@ def _it_step(

# compute dist to u_cyclo and v_cyclo
res_np1 = jnp.abs(u_np1 - u_cyclo) + jnp.abs(v_np1 - v_cyclo)
res_np1 = sanitize.sanitize_data(res_np1, 0., mask)
res_np1 = lax.cond(
use_res_filter, # apply filter
lambda operands: jsp.signal.convolve(operands[0], operands[1], mode="same", method="fft") / operands[2],
Expand Down Expand Up @@ -248,7 +252,7 @@ def _it_step(
return u_cyclo, v_cyclo, mask_it, res_n, losses, i


@partial(jit, static_argnames=("n_it", "res_init", "res_filter_size"))
@partial(jit, static_argnames=("n_it", "res_filter_size"))
def _iterative(
u_geos_u: Float[Array, "lat lon"],
v_geos_v: Float[Array, "lat lon"],
Expand All @@ -259,22 +263,12 @@ def _iterative(
coriolis_factor_u: Float[Array, "lat lon"],
coriolis_factor_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"],
n_it: Union[int, None],
n_it: int,
res_eps: float,
res_init: Union[float, str],
use_res_filter: bool,
res_filter_size: int,
return_losses: bool
) -> [Float[Array, "lat lon"], ...]:
if n_it is None:
n_it = N_IT_IT
if res_init == "same":
res_n = jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v))
elif isinstance(res_init, numbers.Number):
res_n = res_init * jnp.ones_like(u_geos_u)
else:
raise ValueError("res_init should be equal to \"same\" or be a number.")

# used if applying a filter when computing stopping criteria
res_filter = jnp.ones((res_filter_size, res_filter_size))
res_weights = jsp.signal.convolve(jnp.ones_like(u_geos_u), res_filter, mode="same", method="fft")
Expand All @@ -294,7 +288,8 @@ def step_fn(pytree):
u_cyclo, v_cyclo, _, _, losses, _ = lax.while_loop( # noqa
lambda args: (args[-1] < n_it) | jnp.any(args[2] != 1),
step_fn,
(u_geos_u, v_geos_v, mask.astype(int), res_n, jnp.ones(n_it) * jnp.nan, 0)
(u_geos_u, v_geos_v, mask.astype(int), jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v)),
jnp.ones(n_it) * jnp.nan, 0)
)

return u_cyclo, v_cyclo, losses
Expand Down Expand Up @@ -377,7 +372,7 @@ def step_fn(pytree):
return u_cyclo_u, v_cyclo_v, losses


@partial(jit, static_argnames=("n_it", "optim", "optim_kwargs"))
@partial(jit, static_argnames=("n_it", "optim"))
def _variational(
u_geos_u: Float[Array, "lat lon"],
v_geos_v: Float[Array, "lat lon"],
Expand All @@ -388,21 +383,10 @@ def _variational(
coriolis_factor_u: Float[Array, "lat lon"],
coriolis_factor_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"],
n_it: Union[int, None],
optim: Union[optax.GradientTransformation, str],
optim_kwargs: Union[dict, None],
n_it: int,
optim: optax.GradientTransformation,
return_losses: bool
) -> [Float[Array, "lat lon"], ...]:
if n_it is None:
n_it = N_IT_VAR
if isinstance(optim, str):
if optim_kwargs is None:
optim_kwargs = {"learning_rate": LR_VAR}
optim = getattr(optax, optim)(**optim_kwargs)
elif not isinstance(optim, optax.GradientTransformation):
raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such an "
"optimizer.")

# define loss partial: freeze constant over iterations
loss_fn = partial(
_var_loss_fn,
Expand Down
11 changes: 7 additions & 4 deletions jaxparrow/tools/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from .geometry import compute_spatial_step, compute_coriolis_factor
from .operators import derivative, interpolation
from .sanitize import sanitize_data
from .sanitize import init_mask, sanitize_data


def advection(
Expand Down Expand Up @@ -175,6 +175,9 @@ def normalized_relative_vorticity(
The normalised relative vorticity,
on the F grid (if ``interpolate=False``), or the T grid (if ``interpolate=True``)
"""
# Make sure the mask is initialized
mask = init_mask(u, mask)

# Compute spatial step and Coriolis factor
_, dy_u = compute_spatial_step(lat_u, lon_u)
dx_v, _ = compute_spatial_step(lat_v, lon_v)
Expand Down Expand Up @@ -202,13 +205,13 @@ def normalized_relative_vorticity(
return w


def eddy_kinetic_energy(
def kinetic_energy(
u: Float[Array, "lat lon"],
v: Float[Array, "lat lon"],
interpolate: bool = True
) -> Float[Array, "lat lon"]:
"""
Computes the Eddy Kinetic Energy (EKE) of a velocity field,
Computes the Kinetic Energy (KE) of a velocity field,
possibly on a C-grid (following NEMO convention [1]_) if ``interpolate=True``.
Parameters
Expand All @@ -227,7 +230,7 @@ def eddy_kinetic_energy(
Returns
-------
eke : Float[Array, "lat lon"]
The Eddy Kinetic Energy on the T grid
The Kinetic Energy on the T grid
"""
if interpolate:
# interpolate to the T point
Expand Down
6 changes: 3 additions & 3 deletions jaxparrow/tools/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ def derivative(
field : Float[Array, "lat lon"]
Interpolated field
"""
def do_derivate(field_b, field_f, pad_left):
def do_differentiate(field_b, field_f, pad_left):
field_b, field_f = handle_land_boundary(field_b, field_f, pad_left)
return field_f - field_b

def axis0(_field, pad_left):
field_b, field_f = _field[:-1, :], _field[1:, :]
midpoint_values = do_derivate(field_b, field_f, pad_left)
midpoint_values = do_differentiate(field_b, field_f, pad_left)

_field = lax.cond(
pad_left,
Expand All @@ -124,7 +124,7 @@ def axis0(_field, pad_left):

def axis1(_field, pad_left):
field_b, field_f = _field[:, :-1], _field[:, 1:]
midpoint_values = do_derivate(field_b, field_f, pad_left)
midpoint_values = do_differentiate(field_b, field_f, pad_left)

_field = lax.cond(
pad_left,
Expand Down
9 changes: 6 additions & 3 deletions jaxparrow/tools/sanitize.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,17 @@ def handle_land_boundary(
Replaces the non-finite values of ``field1`` (``field2``) with values of ``field2`` (``field1``), element-wise.
It allows to introduce less non-finite values when applying grid operators.
In such cases, ``field1`` and ``field2`` are left and right shifted versions of a field.
In such cases, ``field1`` and ``field2`` are left and right shifted versions of a field (along one of the axes).
Parameters
----------
field1 : Float[Array, "lat lon"]
A field
field2 : Float[Array, "lat lon"]
Another field
pad_left : bool
If `True`, apply padding in the `left` direction (i.e. `West` or `South`) ;
if `False`, apply padding in the `right` direction (i.e. `East` or `North`).
Returns
-------
Expand All @@ -102,8 +105,8 @@ def sanitize_grid_np(
Sanitizes (unstructured) grids by interpolated and extrapolated `nan` or masked values to avoid spurious
(`0`, `nan`, `inf`) spatial steps and Coriolis factors.
Helper function written using ``numpy`` and ``scipy``, and as such not used internally,
because incompatible with ``jax.vmap``.
Helper function written using pure ``numpy`` and ``scipy``, and as such not used internally,
because incompatible with ``jax.vmap`` and likes.
Should be used before calling ``jaxparrow.geostrophy`` or ``jaxparrow.cyclogeostrophy``
in case of suspicious latitudes or longitudes T grids.
Expand Down
Loading

0 comments on commit cc60090

Please sign in to comment.