diff --git a/.github/container/Dockerfile.pax.amd64 b/.github/container/Dockerfile.pax.amd64 index f6072f53a..e8ae18291 100644 --- a/.github/container/Dockerfile.pax.amd64 +++ b/.github/container/Dockerfile.pax.amd64 @@ -15,14 +15,10 @@ ARG REPO_PAXML=https://github.com/google/paxml.git ARG REPO_PRAXIS=https://github.com/google/praxis.git ARG REF_PAXML=main ARG REF_PRAXIS=main -ARG REPO_TE=https://github.com/NVIDIA/TransformerEngine.git -# TODO: This is a temporary pinning of TE as the API in TE no longer matches the TE patch -# This should be reverted to main ASAP -ARG REF_TE=7976bd003fcf084dd068069b92a9a79b1743316a RUN <<"EOF" bash -ex install-pax.sh --defer --from_paxml ${REPO_PAXML} --from_praxis ${REPO_PRAXIS} --ref_paxml ${REF_PAXML} --ref_praxis ${REF_PRAXIS} install-flax.sh --defer -install-te.sh --defer --from ${REPO_TE} --ref ${REF_TE} +install-te.sh --defer if [[ -f /opt/requirements-defer.txt ]]; then # SKIP_HEAD_INSTALLS avoids having to install jax from Github source so that diff --git a/.github/container/Dockerfile.t5x b/.github/container/Dockerfile.t5x index 657459706..0bf63b291 100644 --- a/.github/container/Dockerfile.t5x +++ b/.github/container/Dockerfile.t5x @@ -14,9 +14,7 @@ ENV NVTE_FRAMEWORK=jax ARG REPO_T5X=https://github.com/google-research/t5x.git ARG REF_T5X=main ARG REPO_TE=https://github.com/NVIDIA/TransformerEngine.git -# TODO: This is a temporary pinning of TE as the API in TE no longer matches the TE patch -# This should be reverted to main ASAP -ARG REF_TE=7976bd003fcf084dd068069b92a9a79b1743316a +ARG REF_TE=main RUN <<"EOF" bash -ex install-t5x.sh --defer --from ${REPO_T5X} --ref ${REF_T5X} install-te.sh --defer --from ${REPO_TE} --ref ${REF_TE} diff --git a/.github/workflows/_build_t5x.yaml b/.github/workflows/_build_t5x.yaml index fffbb7731..ed0c1a628 100644 --- a/.github/workflows/_build_t5x.yaml +++ b/.github/workflows/_build_t5x.yaml @@ -32,9 +32,7 @@ on: type: string description: Git commit, tag, or branch for TE required: false - # TODO: This is a temporary pinning of TE as the API in TE no longer matches the TE patch - # This should be reverted to main ASAP - default: 7976bd003fcf084dd068069b92a9a79b1743316a + default: main outputs: DOCKER_TAGS: description: "Tags of the image built" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d128a17de..ef0a54432 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,9 +25,7 @@ on: description: 'TE source: #' type: string required: true - # TODO: This is a temporary pinning of TE as the API in TE no longer matches the TE patch - # This should be reverted to main ASAP - default: 'https://github.com/NVIDIA/TransformerEngine.git#7976bd003fcf084dd068069b92a9a79b1743316a' + default: 'https://github.com/NVIDIA/TransformerEngine.git#main' SRC_T5X: description: 'T5X source: #' type: string @@ -96,9 +94,7 @@ jobs: # default values are for `pull_request`` event types parse_git_src JAX "${{ inputs.SRC_JAX }}" "https://github.com/google/jax.git#main" parse_git_src XLA "${{ inputs.SRC_XLA }}" "https://github.com/openxla/xla.git#main" - # TODO: This is a temporary pinning of TE as the API in TE no longer matches the TE patch - # This should be reverted to main ASAP - parse_git_src TE "${{ inputs.SRC_TE }}" "https://github.com/NVIDIA/TransformerEngine.git#7976bd003fcf084dd068069b92a9a79b1743316a" + parse_git_src TE "${{ inputs.SRC_TE }}" "https://github.com/NVIDIA/TransformerEngine.git#main" parse_git_src T5X "${{ inputs.SRC_T5X }}" "https://github.com/google-research/t5x.git#main" parse_git_src PAXML "${{ inputs.SRC_PAXML }}" "https://github.com/google/paxml.git#main" parse_git_src PRAXIS "${{ inputs.SRC_PRAXIS }}" "https://github.com/google/praxis.git#main"