diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index a041edbd2cc6..1d3fb481b338 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -3,7 +3,7 @@ # https://storage.googleapis.com/jax-releases/libtpu_releases.html. # For future reference, the easiest way to test the GPU backend is to run -# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }" +# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib-bin.override { cudaSupport = true; }" # export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 # python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'" # python -c "from jax import random; random.PRNGKey(0)" @@ -35,46 +35,32 @@ let inherit (cudaPackages) cudatoolkit cudnn; in -# There are no jaxlib wheels targeting cudnn <8.0.5, and although there are -# wheels for cudatoolkit <11.1, we don't support them. assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; -assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5"; +assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; let - version = "0.3.0"; + version = "0.3.22"; pythonVersion = python.pythonVersion; - # Find new releases at https://storage.googleapis.com/jax-releases. When - # upgrading, you can get these hashes from prefetch.sh. + # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. + # When upgrading, you can get these hashes from prefetch.sh. See + # https://github.com/google/jax/issues/12879 as to why this specific URL is + # the correct index. cpuSrcs = { - "3.9" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; - hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q="; + "x86_64-linux" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "sha256-w2wo0jk+1BdEkNwfSZRQbebdI4Ac8Kgn0MB0cIMcWU4="; }; - "3.10" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl"; - hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ="; + "aarch64-darwin" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; + hash = "sha256-7Ir55ZhBkccqfoa56WVBF8QwFAC2ws4KFHDkfVw6zm0="; }; }; - gpuSrcs = { - "3.9-805" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; - hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw="; - }; - "3.9-82" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; - hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4="; - }; - "3.10-805" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl"; - hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U="; - }; - "3.10-82" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl"; - hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go="; - }; + gpuSrc = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "sha256-rabU62p4fF7Tu/6t8LNYZdf6YO06jGry/JtyFZeamCs="; }; in buildPythonPackage rec { @@ -82,23 +68,16 @@ buildPythonPackage rec { inherit version; format = "wheel"; - # At the time of writing (2022-03-03), there are releases for <=3.10. - # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs - # python3 version, and 3.10. - disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10"); + # At the time of writing (2022-10-19), there are releases for <=3.10. + # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs + # python version. + disabled = !(pythonVersion == "3.10"); - src = - if !cudaSupport then cpuSrcs."${pythonVersion}" else - let - # jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and - # 8.2. Try to use 8.2 whenever possible. - cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805"; - in - gpuSrcs."${pythonVersion}-${cudnnVersion}"; + src = if !cudaSupport then cpuSrcs."${stdenv.hostPlatform.system}" else gpuSrc; # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. - nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath; + nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; # Dynamic link dependencies buildInputs = [ stdenv.cc.cc ]; @@ -142,6 +121,6 @@ buildPythonPackage rec { sourceProvenance = with sourceTypes; [ binaryNativeCode ]; license = licenses.asl20; maintainers = with maintainers; [ samuela ]; - platforms = [ "x86_64-linux" ]; + platforms = [ "aarch64-darwin" "x86_64-linux" ]; }; }