python3Packages.jaxlib-bin: 0.3.0 -> 0.3.22

* Drop support for python versions <3.10
* Drop support for cuDNN <8.2
* Add support for aarch64-darwin
This commit is contained in:
Samuel Ainsworth 2022-10-20 14:17:13 -07:00
parent 9ad4161823
commit 55b3f2ad0b

View file

@ -3,7 +3,7 @@
# https://storage.googleapis.com/jax-releases/libtpu_releases.html.
# For future reference, the easiest way to test the GPU backend is to run
# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }"
# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib-bin.override { cudaSupport = true; }"
# export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
# python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
# python -c "from jax import random; random.PRNGKey(0)"
@ -35,46 +35,32 @@ let
inherit (cudaPackages) cudatoolkit cudnn;
in
# There are no jaxlib wheels targeting cudnn <8.0.5, and although there are
# wheels for cudatoolkit <11.1, we don't support them.
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
let
version = "0.3.0";
version = "0.3.22";
pythonVersion = python.pythonVersion;
# Find new releases at https://storage.googleapis.com/jax-releases. When
# upgrading, you can get these hashes from prefetch.sh.
# Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
# When upgrading, you can get these hashes from prefetch.sh. See
# https://github.com/google/jax/issues/12879 as to why this specific URL is
# the correct index.
cpuSrcs = {
"3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q=";
"x86_64-linux" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-w2wo0jk+1BdEkNwfSZRQbebdI4Ac8Kgn0MB0cIMcWU4=";
};
"3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl";
hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ=";
"aarch64-darwin" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
hash = "sha256-7Ir55ZhBkccqfoa56WVBF8QwFAC2ws4KFHDkfVw6zm0=";
};
};
gpuSrcs = {
"3.9-805" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl";
hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw=";
};
"3.9-82" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl";
hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4=";
};
"3.10-805" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl";
hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U=";
};
"3.10-82" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl";
hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go=";
};
gpuSrc = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-rabU62p4fF7Tu/6t8LNYZdf6YO06jGry/JtyFZeamCs=";
};
in
buildPythonPackage rec {
@ -82,23 +68,16 @@ buildPythonPackage rec {
inherit version;
format = "wheel";
# At the time of writing (2022-03-03), there are releases for <=3.10.
# Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs
# python3 version, and 3.10.
disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10");
# At the time of writing (2022-10-19), there are releases for <=3.10.
# Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs
# python version.
disabled = !(pythonVersion == "3.10");
src =
if !cudaSupport then cpuSrcs."${pythonVersion}" else
let
# jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and
# 8.2. Try to use 8.2 whenever possible.
cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805";
in
gpuSrcs."${pythonVersion}-${cudnnVersion}";
src = if !cudaSupport then cpuSrcs."${stdenv.hostPlatform.system}" else gpuSrc;
# Prebuilt wheels are dynamically linked against things that nix can't find.
# Run `autoPatchelfHook` to automagically fix them.
nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath;
nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ];
# Dynamic link dependencies
buildInputs = [ stdenv.cc.cc ];
@ -142,6 +121,6 @@ buildPythonPackage rec {
sourceProvenance = with sourceTypes; [ binaryNativeCode ];
license = licenses.asl20;
maintainers = with maintainers; [ samuela ];
platforms = [ "x86_64-linux" ];
platforms = [ "aarch64-darwin" "x86_64-linux" ];
};
}