Skip to content

Commit

Permalink
Merge pull request #84 from meom-group/tools-doc
Browse files Browse the repository at this point in the history
Improve documentation and usage instructions
  • Loading branch information
vadmbertr authored Sep 6, 2024
2 parents 94d6097 + bd75e2c commit 694278b
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 40 deletions.
57 changes: 29 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,62 +5,63 @@
![Tests](https://github.com/meom-group/jaxparrow/actions/workflows/python-package.yml/badge.svg)
[![Docs](https://github.com/meom-group/jaxparrow/actions/workflows/python-documentation.yml/badge.svg)](https://jaxparrow.readthedocs.io/)

***jaxparrow*** implements a novel approach based on a variational formulation to compute the inversion of the cyclogeostrophic balance.
`jaxparrow` implements a novel approach based on a variational formulation to compute the inversion of the cyclogeostrophic balance.

It leverages the power of [JAX](https://jax.readthedocs.io/en/latest/), to efficiently solve the inversion as an optimization problem.
It leverages the power of [`JAX`](https://jax.readthedocs.io/en/latest/), to efficiently solve the inversion as an optimization problem.
Given the Sea Surface Height (SSH) field of an ocean system, **jaxparrow** estimates the velocity field that best satisfies the cyclogeostrophic balance.

See the full [documentation](https://jaxparrow.readthedocs.io/en/latest/)!

## Installation

The package is Pip-installable:
`jaxparrow` is Pip-installable:
```shell
pip install jaxparrow
```

**<ins>However</ins>**, users with access to GPUs or TPUs should first install JAX separately in order to fully benefit from its high-performance computing capacities.
**<ins>However</ins>**, users with access to GPUs or TPUs should first install `JAX` separately in order to fully benefit from its high-performance computing capacities.
See [JAX instructions](https://jax.readthedocs.io/en/latest/installation.html). \
By default, **jaxparrow** will install a CPU-only version of JAX if no other version is already present in the Python environment.
By default, `jaxparrow` will install a CPU-only version of JAX if no other version is already present in the Python environment.

## Usage

### As a package

Two functions are directly available from `jaxparrow`:
The function you are most probably looking for is `cyclogeostrophy`.
It computes the cyclogeostrophic velocity field (returned as two `2darray`) from:
- a SSH field (a `2darray`),
- the latitude and longitude grids at the T points (two `2darray`).

- `geostrophy` computes the geostrophic velocity field (returns two `2darray`) from:
- a SSH field (a `2darray`),
- the latitude and longitude at the T points (two `2darray`),
- an optional mask grid (one `2darray`).
- `cyclogeostrophy` computes the cyclogeostrophic velocity field (returns two `2darray`) from:
- a SSH field (a `2darray`),
- the latitude and longitude at the T points (two `2darray`),
- an optional mask grid (one `2darray`).
In a Python script, assuming that the input grids have already been initialised / imported, estimating the cyclogeostrophic velocities for a single timestamp would resort to:

*Because **jaxparrow** uses [C-grids](https://xgcm.readthedocs.io/en/latest/grids.html) the velocity fields are represented on two grids (U and V), and the SSH on one grid (T).*
```python
from jaxparrow import cyclogeostrophy

u_cyclo_2d, v_cyclo_2d = cyclogeostrophy(ssh_2d, lat_2d, lon_2d)
```

In a Python script, assuming that the input grids have already been initialised / imported, it would resort to:
*Because `jaxparrow` uses [C-grids](https://xgcm.readthedocs.io/en/latest/grids.html) the velocity fields are represented on two grids (U and V), and the tracer fields (such as SSH) on one grid (T).* \
We provide functions computing some kinematics (such as velocities magnitude, normalized relative vorticity, or kinematic energy) accounting for these gridding system:

```python
from jaxparrow import cyclogeostrophy, geostrophy

u_geos, v_geos = geostrophy(ssh_t=ssh,
lat_t=lat, lon_t=lon,
mask=mask)
u_cyclo, v_cyclo = cyclogeostrophy(ssh_t=ssh,
lat_t=lat, lon_t=lon,
mask=mask)
from jaxparrow.tools.kinematics import magnitude

uv_cyclo_2d, v_cyclo_2d = magnitude(u_cyclo_2d, v_cyclo_2d, interpolate=True)
```

To vectorise the application of the `geostrophy` and `cyclogeostrophy` functions across an added time dimension, one aims to utilize `vmap`.
However, this necessitates avoiding the use of `np.ma.masked_array`.
Hence, our functions accommodate mask `array` as parameter to effectively consider masked regions.
To vectorise the estimation of the cyclogeostrophy across a first time dimension, one aims to use `jax.vmap`.

```python
import jax

vmap_cyclogeostrophy = jax.vmap(cyclogeostrophy, in_axes=(0, None, None))
u_cyclo_3d, v_cyclo_3d = vmap_cyclogeostrophy(ssh_3d, lat_2d, lon_2d)
```

By default, the `cyclogeostrophy` function relies on our variational method.
Its `method` argument provides the ability to use an iterative method instead, either the one described by [Penven *et al.*](https://doi.org/10.1016/j.dsr2.2013.10.015), or the one by [Ioannou *et al.*](https://doi.org/10.1029/2019JC015031).
Additional arguments also give a finer control over the three approaches hyperparameters. \
See **jaxparrow** [API documentation](https://jaxparrow.readthedocs.io/en/latest/api.html) for more details.
See `jaxparrow` [API documentation](https://jaxparrow.readthedocs.io/en/latest/api.html) for more details.

[Notebooks](https://jaxparrow.readthedocs.io/en/latest/examples.html) are available as step-by-step examples.

Expand Down
46 changes: 38 additions & 8 deletions jaxparrow/tools/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,47 @@ def cyclogeostrophic_imbalance(
coriolis_factor_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"]
) -> [Float[Array, "lat lon"], Float[Array, "lat lon"]]:
"""
Computes the cyclogeostrophic imbalance of a 2d velocity field, on a C-grid (following NEMO convention [1]_).
Parameters
----------
u_geos_u : Float[Array, "lat lon"]
U component of the geostrophic velocity field (on the U grid)
v_geos_v : Float[Array, "lat lon"]
V component of the geostrophic velocity field (on the V grid)
u_cyclo_u : Float[Array, "lat lon"]
U component of the cyclogeostrophic velocity field (on the U grid)
v_cyclo_v : Float[Array, "lat lon"]
V component of the cyclogeostrophic velocity field (on the V grid)
dx_u : Float[Array, "lat lon"]
Spatial steps in meters along `x` (on the U grid)
dx_v : Float[Array, "lat lon"]
Spatial steps in meters along `x` (on the V grid)
dy_u : Float[Array, "lat lon"]
Spatial steps in meters along `y` (on the U grid)
dy_v : Float[Array, "lat lon"]
Spatial steps in meters along `y` (on the V grid)
coriolis_factor_u : Float[Array, "lat lon"]
Coriolis factor (on the U grid)
coriolis_factor_v : Float[Array, "lat lon"]
Coriolis factor (on the V grid)
mask : Float[Array, "lat lon"], optional
Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)
Returns
-------
u_imbalance_u : Float[Array, "lat lon"]
U component of the cyclogeostrophic imbalance, on the U grid
v_imbalance_v : Float[Array, "lat lon"]
V component of the cyclogeostrophic imbalance, on the V grid
"""
u_adv_v, v_adv_u = advection(u_cyclo_u, v_cyclo_v, dx_u, dx_v, dy_u, dy_v, mask)

u_imbalance = u_cyclo_u + v_adv_u / coriolis_factor_u - u_geos_u
v_imbalance = v_cyclo_v - u_adv_v / coriolis_factor_v - v_geos_v
u_imbalance_u = u_cyclo_u + v_adv_u / coriolis_factor_u - u_geos_u
v_imbalance_v = v_cyclo_v - u_adv_v / coriolis_factor_v - v_geos_v

return u_imbalance, v_imbalance
return u_imbalance_u, v_imbalance_v


def magnitude(
Expand Down Expand Up @@ -211,11 +246,6 @@ def normalized_relative_vorticity(
dx_v, _ = compute_spatial_step(lat_v, lon_v)
f_u = compute_coriolis_factor(lat_u)

# Handle spurious data and apply mask
# dy_u = sanitize_data(dy_u, jnp.nan, mask)
# dx_v = sanitize_data(dx_v, jnp.nan, mask)
# f_u = sanitize_data(f_u, jnp.nan, mask)

# Compute the normalized relative vorticity
du_dy_f = derivative(u, dy_u, mask, axis=0, padding="right") # (U(j), U(j+1)) -> F(j)
dv_dx_f = derivative(v, dx_v, mask, axis=1, padding="right") # (V(i), V(i+1)) -> F(i)
Expand Down
8 changes: 4 additions & 4 deletions jaxparrow/tools/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def axis0(pad_left):

arr = lax.cond(
pad_left,
lambda: field.at[1:, :].set(midpoint_values),
lambda: field.at[:-1, :].set(midpoint_values)
lambda: jnp.pad(midpoint_values, pad_width=((1, 0), (0, 0)), mode="edge"),
lambda: jnp.pad(midpoint_values, pad_width=((0, 1), (0, 0)), mode="edge")
)

return arr
Expand All @@ -64,8 +64,8 @@ def axis1(pad_left):

arr = lax.cond(
pad_left,
lambda: field.at[:, 1:].set(midpoint_values),
lambda: field.at[:, :-1].set(midpoint_values)
lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (1, 0)), mode="edge"),
lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (0, 1)), mode="edge")
)

return arr
Expand Down

0 comments on commit 694278b

Please sign in to comment.