Merge pull request #285037 from breakds/PR/breakds/fix_jaxlib_bin_and_cuda
jaxlib-bin: use correct cuda releases
This commit is contained in:
commit
e8a69d497b
1 changed files with 27 additions and 8 deletions
|
@ -33,7 +33,7 @@
|
|||
}:
|
||||
|
||||
let
|
||||
inherit (cudaPackagesGoogle) cudatoolkit cudnn;
|
||||
inherit (cudaPackagesGoogle) cudatoolkit cudnn cudaVersion;
|
||||
|
||||
version = "0.4.23";
|
||||
|
||||
|
@ -118,26 +118,44 @@ let
|
|||
};
|
||||
};
|
||||
|
||||
# Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
|
||||
# Note that the prebuilt jaxlib binary requires specific version of CUDA to
|
||||
# work. The cuda12 jaxlib binaries only works with CUDA 12.2, and cuda11
|
||||
# jaxlib binaries only works with CUDA 11.8. This is why we need to find a
|
||||
# binary that matches the provided cudaVersion.
|
||||
gpuSrcVersionString = "cuda${cudaVersion}-${pythonVersion}";
|
||||
|
||||
# Find new releases at https://storage.googleapis.com/jax-releases
|
||||
# When upgrading, you can get these hashes from prefetch.sh. See
|
||||
# https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
|
||||
gpuSrcs = {
|
||||
"3.9" = fetchurl {
|
||||
"cuda12.2-3.9" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI=";
|
||||
};
|
||||
"3.10" = fetchurl {
|
||||
"cuda12.2-3.10" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg=";
|
||||
};
|
||||
"3.11" = fetchurl {
|
||||
"cuda12.2-3.11" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow=";
|
||||
};
|
||||
"3.12" = fetchurl {
|
||||
"cuda12.2-3.12" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo=";
|
||||
};
|
||||
"cuda11.8-3.9" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-m2Y5p12gF3OaADu+aGw5RjcKFrj9RB8xzNWnKNpSz60=";
|
||||
};
|
||||
"cuda11.8-3.10" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
|
||||
hash = "osha256-aQ7iX3o0kQ4liPexv7dkBVWVTUpaty83L083MybGkf0=";
|
||||
};
|
||||
"cuda11.8-3.11" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-uIEyjEmv0HBaiYVl5PuICTI9XnH4zAfQ1l9tjALRcP4=";
|
||||
};
|
||||
};
|
||||
|
||||
in
|
||||
|
@ -154,7 +172,7 @@ buildPythonPackage {
|
|||
(
|
||||
cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}"
|
||||
or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
|
||||
) else gpuSrcs."${pythonVersion}";
|
||||
) else gpuSrcs."${gpuSrcVersionString}";
|
||||
|
||||
# Prebuilt wheels are dynamically linked against things that nix can't find.
|
||||
# Run `autoPatchelfHook` to automagically fix them.
|
||||
|
@ -212,6 +230,7 @@ buildPythonPackage {
|
|||
broken =
|
||||
!(cudaSupport -> (cudaPackagesGoogle ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1")
|
||||
|| !(cudaSupport -> (cudaPackagesGoogle ? cudnn) && lib.versionAtLeast cudnn.version "8.2")
|
||||
|| !(cudaSupport -> stdenv.isLinux);
|
||||
|| !(cudaSupport -> stdenv.isLinux)
|
||||
|| !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"));
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue