Skip to content

Commit

Permalink
python312Packages.jax: 0.4.38 -> 0.5.0 (NixOS#374810)
Browse files Browse the repository at this point in the history
  • Loading branch information
GaetanLepage authored Jan 24, 2025
2 parents d7f55a7 + 01f1d45 commit 2dc232a
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 49 deletions.
14 changes: 5 additions & 9 deletions pkgs/development/python-modules/flax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,15 @@ buildPythonPackage rec {
"flax/nnx/examples/*"
# See https://github.com/google/flax/issues/3232.
"tests/jax_utils_test.py"
# Too old version of tensorflow:
# ModuleNotFoundError: No module named 'keras.api._v2'
"tests/tensorboard_test.py"
];

disabledTests =
[
# ValueError: Checkpoint path should be absolute
"test_overwrite_checkpoints0"
# Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
# TODO: Re-enable when jax>0.4.28 will be available in nixpkgs
"test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
"test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
# Failing with AssertionError since the jax update to 0.5.0
"test_basic_demo_single"
"test_batch_norm_multi_init"
"test_multimetric"
"test_split_merge"
]
++ lib.optionals stdenv.hostPlatform.isDarwin [
# SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ let
srcs = {
"x86_64-linux" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl";
hash = "sha256-g75MWfvPMAd6YAhdmOfVncc4sckeDWKOSsF3n94VrCs=";
hash = "sha256-0jgzwbiF2WwnZAAOlQUvK1gnx31JLqaPZ+kDoTJlbbs=";
};
"aarch64-linux" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl";
Expand Down
8 changes: 4 additions & 4 deletions pkgs/development/python-modules/jax-cuda12-plugin/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ let
"3.10-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp310";
hash = "sha256-nULpmc1k3VZ8FJ7Wj3k5K6iGRDZCGLtjbNzvoBl8kv4=";
hash = "sha256-D0Q6azcpjt+weW/NvR+GzoWksIS2vT8fUKT7/Wfe2Gs=";
};
"3.10-aarch64-linux" = getSrcFromPypi {
platform = "manylinux2014_aarch64";
Expand All @@ -49,7 +49,7 @@ let
"3.11-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp311";
hash = "sha256-cEZUOG8OYAoCgdquqViCqmekfttoOTthsbFzx+jKdKg=";
hash = "sha256-qYE1oCIwZLj1xoU+It3BpOOGIVLTf7aF8Nve/+DIASI=";
};
"3.11-aarch64-linux" = getSrcFromPypi {
platform = "manylinux2014_aarch64";
Expand All @@ -59,7 +59,7 @@ let
"3.12-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp312";
hash = "sha256-Ufas/3Ew63LrsCU039NYGg9eoGlx3lLX68Ia1Nh/5x4=";
hash = "sha256-QwWN/FZdjJ2mn0fNTkuVxJXxaG8onvRYTCtygD5vFgc=";
};
"3.12-aarch64-linux" = getSrcFromPypi {
platform = "manylinux2014_aarch64";
Expand All @@ -69,7 +69,7 @@ let
"3.13-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp313";
hash = "sha256-CSKKTCtEO3aozZqOwikGAInEzINuBiSWh1ptb9xm0x8=";
hash = "sha256-3zbEsXbi01qCqfOM13zDadJx5gBR43GgqO9FFD+PWLY=";
};
"3.13-aarch64-linux" = getSrcFromPypi {
platform = "manylinux2014_aarch64";
Expand Down
5 changes: 3 additions & 2 deletions pkgs/development/python-modules/jax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ let
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.38";
version = "0.5.0";
pyproject = true;

src = fetchFromGitHub {
owner = "google";
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jax tags!
tag = "jax-v${version}";
hash = "sha256-H8I9Mkz6Ia1RxJmnuJOSevLGHN2J8ey59ZTlFx8YfnA=";
hash = "sha256-D6n9Z34nrCbBd9IS8YW6uio5Yi9GLCo9PViO3YYbkQ8=";
};

build-system = [ setuptools ];
Expand Down Expand Up @@ -154,6 +154,7 @@ buildPythonPackage rec {
"testInAxesPyTreePrefixMismatchErrorKwargs"
"testOutAxesPyTreePrefixMismatchError"
"test_tree_map"
"test_tree_prefix_error"
"test_vjp_rule_inconsistent_pytree_structures_error"
"test_vmap_in_axes_tree_prefix_error"
"test_vmap_mismatched_axis_sizes_error_message_issue_705"
Expand Down
46 changes: 13 additions & 33 deletions pkgs/development/python-modules/jaxlib/bin.nix
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
}:

let
version = "0.4.38";
version = "0.5.0";
inherit (python) pythonVersion;

# As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
Expand Down Expand Up @@ -49,85 +49,65 @@ let
"3.10-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp310";
hash = "sha256-Ya7MuaJ8Z/24RQ9jVyQAGc1FEcudYqROR2R1bThIU60=";
hash = "sha256-dEQLYyEHM2QA1Pl6Fkgddn8T6pFMU7oU5UTG/aVIGbM=";
};
"3.10-aarch64-linux" = getSrcFromPypi {
platform = "manylinux2014_aarch64";
dist = "cp310";
hash = "sha256-7hnBY6j98IOdTBi4il+/tOcxunxDdBbT5Ug+Vwu3ZOQ=";
hash = "sha256-Wy7+Pf6/GKhMRR04A6yITuJCAhwRE7J5wT9LvDeMPcA=";
};
"3.10-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp310";
hash = "sha256-MLL1LLUNdHNK8vR3wlM6elg+O7eyyKzes2Hud9lAV3o=";
};
"3.10-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp310";
hash = "sha256-VcGbnT8zpvxZ9kSqWiH7oCY5zN13bLSptVJmJfV4Of8=";
hash = "sha256-G4psQ0XxN/OHZQ3i28SIwgJRt0ErVd1kjhpPE7z1B/s=";
};

"3.11-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp311";
hash = "sha256-J1H/cDfWqZfQvg53zEvjgcWp+buLMU7bdVwTpv2Wn0U=";
hash = "sha256-CRE+8Vgro018vEQP7bMY9IVbWbd2cRqKuiRzyXJ9MCU=";
};
"3.11-aarch64-linux" = getSrcFromPypi {
platform = "manylinux2014_aarch64";
dist = "cp311";
hash = "sha256-Q9tYxMQnYnKWNmpWwQMY4fAPUDaQ4X+Uu0NEKT4ZleA=";
hash = "sha256-YwiNv6qFu1bNUhqSWjRy/XMosY7JPC2P+oWvMxCVyZU=";
};
"3.11-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp311";
hash = "sha256-P7DqrnNpFXr+y+rVCq8p5z/936d6IzXXIb2XlPPFEOQ=";
};
"3.11-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp311";
hash = "sha256-tn/eq9bf7Qi3do873/tSEWAIX4MFZpvRl77vYdCN4Is=";
hash = "sha256-bNdi7RYjEySZ+nAcQgNEYQLgqcgsojGUuHKI90bRKik=";
};

"3.12-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp312";
hash = "sha256-2tbAqWVnwG0IPARp/sQPIBIQsJk2W9aYvjGm0uyI/Vk=";
hash = "sha256-+YDHM+mMmYqNqHyajMYbZybQvmZ6WL1mTB1xe0tOrnU=";
};
"3.12-aarch64-linux" = getSrcFromPypi {
platform = "manylinux2014_aarch64";
dist = "cp312";
hash = "sha256-SW9FsOABojQTCc0MdK8LZwU33O15wWjLIwz8x3PwqoY";
hash = "sha256-S0sBr7Dd7JbAg1a/8rtoXdvpf9/+Ttbi2DSzCrqXLyI=";
};
"3.12-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp312";
hash = "sha256-8zvK/jLJelYuz2iU18QWdMgMCs3t+lQj1Jr1EUcUmHQ=";
};
"3.12-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp312";
hash = "sha256-P+/qmF8EFYFvO7r9PwOkNwUCde+brJpywTFOFkSsV8E=";
hash = "sha256-c+M1cVdgxW5jUQnWFCZDWl1/RvM2OhFdrqCUJ9XNDv0=";
};

"3.13-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp313";
hash = "sha256-LOd7qM2pJZpLypevwcci5CkabEY6Y/jTcsbtyFEX1iU=";
hash = "sha256-Ee7wHTfA8cUwYmW3byB/EALRNIDe0uMf1j7HaRLJPKI=";
};
"3.13-aarch64-linux" = getSrcFromPypi {
platform = "manylinux2014_aarch64";
dist = "cp313";
hash = "sha256-JIzKN3Hr8ksHD0lwE2TOraM+YTlEWwbHgsylrFrZK/Q=";
hash = "sha256-fZsXp+oZNV1F7Nsv8NtdcHqG8MWoYtlLibRWjWxFMRo=";
};
"3.13-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp313";
hash = "sha256-b+MmuK82Y4fdR8zzElg7Kxf+0ScSybdKZIsYoTy9ur8=";
};
"3.13-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp313";
hash = "sha256-QeVa5YGKiC5XiehI9vFmh6wTK8+7Wl+hFKXRi3jQXy0=";
hash = "sha256-7RjqcWHQOqj9TRtVSUiC8hQg79/qaOXymMSuvPKsPzQ=";
};
};
in
Expand Down
24 changes: 24 additions & 0 deletions pkgs/development/python-modules/numpyro/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,30 @@ buildPythonPackage rec {
"test_kl_dirichlet_dirichlet"
"test_kl_univariate"
"test_mean_var"
# since jax update to 0.5.0
"test_analytic_kl_2"
"test_analytic_kl_3"
"test_apply_kernel"
"test_beta_bernoulli"
"test_biject_to"
"test_bijective_transforms"
"test_change_point_x64"
"test_cholesky_update"
"test_dais_vae"
"test_discrete_gibbs_multiple_sites_chain"
"test_entropy_categorical"
"test_gaussian_model"
"test_get_proposal_loc_and_scale"
"test_guide_plate_contraction"
"test_kernel_forward"
"test_laplace_approximation_warning"
"test_log_prob_gradient"
"test_logistic_regression"
"test_logistic_regression_x64"
"test_scale"
"test_scan_svi"
"test_stein_particle_loss"
"test_weight_convergence"

# Tests want to download data
"data_load"
Expand Down

0 comments on commit 2dc232a

Please sign in to comment.