Skip to content

Commit

Permalink
refine basic usage
Browse files Browse the repository at this point in the history
  • Loading branch information
vadmbertr committed Sep 6, 2024
1 parent a331bcc commit bd75e2c
Showing 1 changed file with 29 additions and 28 deletions.
57 changes: 29 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,62 +5,63 @@
![Tests](https://github.com/meom-group/jaxparrow/actions/workflows/python-package.yml/badge.svg)
[![Docs](https://github.com/meom-group/jaxparrow/actions/workflows/python-documentation.yml/badge.svg)](https://jaxparrow.readthedocs.io/)

***jaxparrow*** implements a novel approach based on a variational formulation to compute the inversion of the cyclogeostrophic balance.
`jaxparrow` implements a novel approach based on a variational formulation to compute the inversion of the cyclogeostrophic balance.

It leverages the power of [JAX](https://jax.readthedocs.io/en/latest/), to efficiently solve the inversion as an optimization problem.
It leverages the power of [`JAX`](https://jax.readthedocs.io/en/latest/), to efficiently solve the inversion as an optimization problem.
Given the Sea Surface Height (SSH) field of an ocean system, **jaxparrow** estimates the velocity field that best satisfies the cyclogeostrophic balance.

See the full [documentation](https://jaxparrow.readthedocs.io/en/latest/)!

## Installation

The package is Pip-installable:
`jaxparrow` is Pip-installable:
```shell
pip install jaxparrow
```

**<ins>However</ins>**, users with access to GPUs or TPUs should first install JAX separately in order to fully benefit from its high-performance computing capacities.
**<ins>However</ins>**, users with access to GPUs or TPUs should first install `JAX` separately in order to fully benefit from its high-performance computing capacities.
See [JAX instructions](https://jax.readthedocs.io/en/latest/installation.html). \
By default, **jaxparrow** will install a CPU-only version of JAX if no other version is already present in the Python environment.
By default, `jaxparrow` will install a CPU-only version of JAX if no other version is already present in the Python environment.

## Usage

### As a package

Two functions are directly available from `jaxparrow`:
The function you are most probably looking for is `cyclogeostrophy`.
It computes the cyclogeostrophic velocity field (returned as two `2darray`) from:
- a SSH field (a `2darray`),
- the latitude and longitude grids at the T points (two `2darray`).

- `geostrophy` computes the geostrophic velocity field (returns two `2darray`) from:
- a SSH field (a `2darray`),
- the latitude and longitude at the T points (two `2darray`),
- an optional mask grid (one `2darray`).
- `cyclogeostrophy` computes the cyclogeostrophic velocity field (returns two `2darray`) from:
- a SSH field (a `2darray`),
- the latitude and longitude at the T points (two `2darray`),
- an optional mask grid (one `2darray`).
In a Python script, assuming that the input grids have already been initialised / imported, estimating the cyclogeostrophic velocities for a single timestamp would resort to:

*Because **jaxparrow** uses [C-grids](https://xgcm.readthedocs.io/en/latest/grids.html) the velocity fields are represented on two grids (U and V), and the SSH on one grid (T).*
```python
from jaxparrow import cyclogeostrophy

u_cyclo_2d, v_cyclo_2d = cyclogeostrophy(ssh_2d, lat_2d, lon_2d)
```

In a Python script, assuming that the input grids have already been initialised / imported, it would resort to:
*Because `jaxparrow` uses [C-grids](https://xgcm.readthedocs.io/en/latest/grids.html) the velocity fields are represented on two grids (U and V), and the tracer fields (such as SSH) on one grid (T).* \
We provide functions computing some kinematics (such as velocities magnitude, normalized relative vorticity, or kinematic energy) accounting for these gridding system:

```python
from jaxparrow import cyclogeostrophy, geostrophy

u_geos, v_geos = geostrophy(ssh_t=ssh,
lat_t=lat, lon_t=lon,
mask=mask)
u_cyclo, v_cyclo = cyclogeostrophy(ssh_t=ssh,
lat_t=lat, lon_t=lon,
mask=mask)
from jaxparrow.tools.kinematics import magnitude

uv_cyclo_2d, v_cyclo_2d = magnitude(u_cyclo_2d, v_cyclo_2d, interpolate=True)
```

To vectorise the application of the `geostrophy` and `cyclogeostrophy` functions across an added time dimension, one aims to utilize `vmap`.
However, this necessitates avoiding the use of `np.ma.masked_array`.
Hence, our functions accommodate mask `array` as parameter to effectively consider masked regions.
To vectorise the estimation of the cyclogeostrophy across a first time dimension, one aims to use `jax.vmap`.

```python
import jax

vmap_cyclogeostrophy = jax.vmap(cyclogeostrophy, in_axes=(0, None, None))
u_cyclo_3d, v_cyclo_3d = vmap_cyclogeostrophy(ssh_3d, lat_2d, lon_2d)
```

By default, the `cyclogeostrophy` function relies on our variational method.
Its `method` argument provides the ability to use an iterative method instead, either the one described by [Penven *et al.*](https://doi.org/10.1016/j.dsr2.2013.10.015), or the one by [Ioannou *et al.*](https://doi.org/10.1029/2019JC015031).
Additional arguments also give a finer control over the three approaches hyperparameters. \
See **jaxparrow** [API documentation](https://jaxparrow.readthedocs.io/en/latest/api.html) for more details.
See `jaxparrow` [API documentation](https://jaxparrow.readthedocs.io/en/latest/api.html) for more details.

[Notebooks](https://jaxparrow.readthedocs.io/en/latest/examples.html) are available as step-by-step examples.

Expand Down

0 comments on commit bd75e2c

Please sign in to comment.