From 5d86adec3718d826d76afe742e8222c9b9231a34 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 13 Jan 2025 11:26:55 -0800 Subject: [PATCH] Add test against JAX nightly This will catch potential incompatibilities introduced in JAX before they make it into a release. Also: update workflow definition to reflect current best practices. PiperOrigin-RevId: 715045067 --- .github/workflows/build.yml | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d7eded3b4..516296470 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -13,6 +13,14 @@ on: branches: - main +permissions: + contents: read + actions: write # to cancel previous workflows + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: build-checkpoint: name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" @@ -27,14 +35,12 @@ jobs: include: - python-version: "3.10" jax-version: "0.4.34" # keep in sync with minimum version in checkpoint/pyproject.toml + - python-version: "3.11" + jax-version: "nightly" steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@0.8.0 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v2 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -46,6 +52,8 @@ jobs: pip uninstall -y orbax if [[ "${{ matrix.jax-version }}" == "newest" ]]; then pip install -U jax jaxlib + elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then + pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html else pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" fi @@ -80,15 +88,11 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["newest", "0.4.34"] # keep in sync with minimum version in export/pyproject.toml + jax-version: ["nightly", "newest", "0.4.34"] # keep in sync with minimum version in export/pyproject.toml steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@0.8.0 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v2 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Extract branch name @@ -101,6 +105,8 @@ jobs: pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html if [[ "${{ matrix.jax-version }}" == "newest" ]]; then pip install -U jax jaxlib + elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then + pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html else pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" fi