Skip to content

Commit

Permalink
Add test against JAX nightly
Browse files Browse the repository at this point in the history
This will catch potential incompatibilities introduced in JAX before they make it into a release.

Also: update workflow definition to reflect current best practices.
PiperOrigin-RevId: 715045067
  • Loading branch information
Jake VanderPlas authored and Orbax Authors committed Jan 13, 2025
1 parent 34d79d4 commit 5d86ade
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ on:
branches:
- main

permissions:
contents: read
actions: write # to cancel previous workflows

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

jobs:
build-checkpoint:
name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
Expand All @@ -27,14 +35,12 @@ jobs:
include:
- python-version: "3.10"
jax-version: "0.4.34" # keep in sync with minimum version in checkpoint/pyproject.toml
- python-version: "3.11"
jax-version: "nightly"
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand All @@ -46,6 +52,8 @@ jobs:
pip uninstall -y orbax
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
pip install -U jax jaxlib
elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
else
pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
Expand Down Expand Up @@ -80,15 +88,11 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["newest", "0.4.34"] # keep in sync with minimum version in export/pyproject.toml
jax-version: ["nightly", "newest", "0.4.34"] # keep in sync with minimum version in export/pyproject.toml
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Extract branch name
Expand All @@ -101,6 +105,8 @@ jobs:
pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
pip install -U jax jaxlib
elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
else
pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
Expand Down

0 comments on commit 5d86ade

Please sign in to comment.