python3Packages.jaxlib: refactor to support Nix-based builds (#151909)

* python3Packages.jaxlib: rename to `jaxlib-bin`

Refactoring `jaxlib` to have a similar structure to `tensorflow` with the 'bin' and 'build' options.

* python3Packages.jaxlib: init the 'build' variant at 0.1.75

Similar to `tensorflow-build`, now there's an option to build `jaxlib` using Nix-provided environment and dependencies.

* python3Packages.jax: 0.2.24 -> 0.2.26

* Addressed review comments.

* Fixed `cudaSupport` missing property on some arches.

* Unified the versions of CUDA-related packages with TF.

Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
This commit is contained in:
Alexander Tsvyashchenko 2021-12-28 01:19:10 +01:00 committed by GitHub
parent 8efd318b10
commit be52722509
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 401 additions and 83 deletions

View file

@ -0,0 +1,12 @@
diff --git a/jax/experimental/compilation_cache/file_system_cache.py b/jax/experimental/compilation_cache/file_system_cache.py
index b85969de..92acd523 100644
--- a/jax/experimental/compilation_cache/file_system_cache.py
+++ b/jax/experimental/compilation_cache/file_system_cache.py
@@ -33,6 +33,7 @@ class FileSystemCache(CacheInterface):
path_to_key = os.path.join(self._path, key)
if os.path.exists(path_to_key):
with open(path_to_key, "rb") as file:
+ os.utime(file.fileno())
return file.read()
else:
return None

View file

@ -13,7 +13,7 @@
buildPythonPackage rec {
pname = "jax";
version = "0.2.25";
version = "0.2.26";
format = "setuptools";
disabled = pythonOlder "3.7";
@ -21,10 +21,15 @@ buildPythonPackage rec {
src = fetchFromGitHub {
owner = "google";
repo = pname;
rev = "jax-v${version}";
sha256 = "0f32is9896g4shfhjipj3rlgpjxci5y607lp8gxlgsdzdqfpckm2";
rev = "${pname}-v${version}";
sha256 = "155hhwgq6axdrj4x4hw72322qv1wc068n4cv4z2vf5jpl05fg93g";
};
patches = [
# See https://github.com/google/jax/issues/7944
./cache-fix.patch
];
# jaxlib is _not_ included in propagatedBuildInputs because there are
# different versions of jaxlib depending on the desired target hardware. The
# JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the

View file

@ -0,0 +1,90 @@
# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
# backend will require some additional work. Those wheels are located here:
# https://storage.googleapis.com/jax-releases/libtpu_releases.html.
# For future reference, the easiest way to test the GPU backend is to run
# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }"
# export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
# python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
# python -c "from jax import random; random.PRNGKey(0)"
# python -c "from jax import random; x = random.normal(random.PRNGKey(0), (100, 100)); x @ x"
# There's no convenient way to test the GPU backend in the derivation since the
# nix build environment blocks access to the GPU. See also:
# * https://github.com/google/jax/issues/971#issuecomment-508216439
# * https://github.com/google/jax/issues/5723#issuecomment-913038780
{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config
, fetchurl, isPy39, lib, stdenv
# propagatedBuildInputs
, absl-py, flatbuffers, scipy, cudatoolkit_11
# Options:
, cudaSupport ? config.cudaSupport or false
}:
assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1";
let
device = if cudaSupport then "gpu" else "cpu";
in
buildPythonPackage rec {
pname = "jaxlib";
version = "0.1.71";
format = "wheel";
# At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting
# all of them is a pain, so we focus on 3.9, the current nixpkgs python3
# version.
disabled = !isPy39;
src = {
cpu = fetchurl {
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
sha256 = "sha256:0rqhs6qabydizlv5d3rb20dbv6612rr7dqfniy9r6h4kazdinsn6";
};
gpu = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda111/jaxlib-${version}+cuda111-cp39-none-manylinux2010_x86_64.whl";
sha256 = "sha256:065kyzjsk9m84d138p99iymdiiicm1qz8a3iwxz8rspl43rwrw89";
};
}.${device};
# Prebuilt wheels are dynamically linked against things that nix can't find.
# Run `autoPatchelfHook` to automagically fix them.
nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath;
# Dynamic link dependencies
buildInputs = [ stdenv.cc.cc ];
# jaxlib contains shared libraries that open other shared libraries via dlopen
# and these implicit dependencies are not recognized by ldd or
# autoPatchelfHook. That means we need to sneak them into rpath. This step
# must be done after autoPatchelfHook and the automatic stripping of
# artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
# patchPhase. Dependencies:
# * libcudart.so.11.0 -> cudatoolkit_11.lib
# * libcublas.so.11 -> cudatoolkit_11
# * libcuda.so.1 -> opengl driver in /run/opengl-driver/lib
preInstallCheck = lib.optional cudaSupport ''
shopt -s globstar
addOpenGLRunpath $out/**/*.so
for file in $out/**/*.so; do
rpath=$(patchelf --print-rpath $file)
# For some reason `makeLibraryPath` on `cudatoolkit_11` maps to
# <cudatoolkit_11.lib>/lib which is different from <cudatoolkit_11>/lib.
patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib ]}" $file
done
'';
# pip dependencies and optionally cudatoolkit.
propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11;
pythonImportsCheck = [ "jaxlib" ];
meta = with lib; {
description = "XLA library for JAX";
homepage = "https://github.com/google/jax";
license = licenses.asl20;
maintainers = with maintainers; [ samuela ];
platforms = [ "x86_64-linux" ];
};
}

View file

@ -1,90 +1,285 @@
# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
# backend will require some additional work. Those wheels are located here:
# https://storage.googleapis.com/jax-releases/libtpu_releases.html.
{ lib
, pkgs
, stdenv
# For future reference, the easiest way to test the GPU backend is to run
# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }"
# export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
# python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
# python -c "from jax import random; random.PRNGKey(0)"
# python -c "from jax import random; x = random.normal(random.PRNGKey(0), (100, 100)); x @ x"
# There's no convenient way to test the GPU backend in the derivation since the
# nix build environment blocks access to the GPU. See also:
# * https://github.com/google/jax/issues/971#issuecomment-508216439
# * https://github.com/google/jax/issues/5723#issuecomment-913038780
# Build-time dependencies:
, addOpenGLRunpath
, bazel_4
, binutils
, buildBazelPackage
, buildPythonPackage
, cython
, fetchFromGitHub
, git
, jsoncpp
, pybind11
, setuptools
, symlinkJoin
, wheel
, which
{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config
, fetchurl, isPy39, lib, stdenv
# propagatedBuildInputs
, absl-py, flatbuffers, scipy, cudatoolkit_11
# Options:
, cudaSupport ? config.cudaSupport or false
# Build-time and runtime CUDA dependencies:
, cudatoolkit ? null
, cudnn ? null
, nccl ? null
# Python dependencies:
, absl-py
, flatbuffers
, numpy
, scipy
, six
# Runtime dependencies:
, double-conversion
, giflib
, grpc
, libjpeg_turbo
, python
, snappy
, zlib
# CUDA flags:
, cudaCapabilities ? [ "sm_35" "sm_50" "sm_60" "sm_70" "sm_75" "compute_80" ]
, cudaSupport ? false
# MKL:
, mklSupport ? true
}:
assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1";
let
device = if cudaSupport then "gpu" else "cpu";
in
buildPythonPackage rec {
pname = "jaxlib";
version = "0.1.71";
version = "0.1.75";
meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
homepage = "https://github.com/google/jax";
license = licenses.asl20;
maintainers = with maintainers; [ ndl ];
};
cudatoolkit_joined = symlinkJoin {
name = "${cudatoolkit.name}-merged";
paths = [
cudatoolkit.lib
cudatoolkit.out
] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
# for some reason some of the required libs are in the targets/x86_64-linux
# directory; not sure why but this works around it
"${cudatoolkit}/targets/${stdenv.system}"
];
};
cudatoolkit_cc_joined = symlinkJoin {
name = "${cudatoolkit.cc.name}-merged";
paths = [
cudatoolkit.cc
binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
];
};
bazel-build = buildBazelPackage {
name = "bazel-build-${pname}-${version}";
bazel = bazel_4;
src = fetchFromGitHub {
owner = "google";
repo = "jax";
rev = "${pname}-v${version}";
sha256 = "01ks4djbpjsxjy2zwdwv3h00sgwi4ps3jz75swddrw2f56zjdmw4";
};
nativeBuildInputs = [
cython
pkgs.flatbuffers
git
setuptools
wheel
which
];
buildInputs = [
double-conversion
giflib
grpc
jsoncpp
libjpeg_turbo
numpy
pkgs.flatbuffers
pkgs.protobuf
pybind11
scipy
six
snappy
zlib
] ++ lib.optionals cudaSupport [
cudatoolkit
cudnn
];
postPatch = ''
rm -f .bazelversion
'';
bazelTarget = "//build:build_wheel";
removeRulesCC = false;
GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";
preConfigure = ''
# dummy ldconfig
mkdir dummy-ldconfig
echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
chmod +x dummy-ldconfig/ldconfig
export PATH="$PWD/dummy-ldconfig:$PATH"
cat <<CFG > ./.jax_configure.bazelrc
build --strategy=Genrule=standalone
build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
build --action_env=PYENV_ROOT
build --python_path="${python}/bin/python"
build --distinct_host_configuration=false
'' + lib.optionalString cudaSupport ''
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${lib.concatStringsSep "," cudaCapabilities}"
'' + ''
CFG
'';
# Copy-paste from TF derivation.
# Most of these are not really used in jaxlib compilation but it's simpler to keep it
# 'as is' so that it's more compatible with TF derivation.
TF_SYSTEM_LIBS = lib.concatStringsSep "," [
"absl_py"
"astor_archive"
"astunparse_archive"
"boringssl"
# Not packaged in nixpkgs
# "com_github_googleapis_googleapis"
# "com_github_googlecloudplatform_google_cloud_cpp"
"com_github_grpc_grpc"
"com_google_protobuf"
# Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
# "com_googlesource_code_re2"
"curl"
"cython"
"dill_archive"
"double_conversion"
"enum34_archive"
"flatbuffers"
"functools32_archive"
"gast_archive"
"gif"
"hwloc"
"icu"
"jsoncpp_git"
"libjpeg_turbo"
"lmdb"
"nasm"
# "nsync" # not packaged in nixpkgs
"opt_einsum_archive"
"org_sqlite"
"pasta"
"pcre"
"png"
"pybind11"
"six_archive"
"snappy"
"tblib_archive"
"termcolor_archive"
"typing_extensions_archive"
"wrapt"
"zlib"
];
# Make sure Bazel knows about our configuration flags during fetching so that the
# relevant dependencies can be downloaded.
bazelFetchFlags = bazel-build.bazelBuildFlags;
bazelBuildFlags = [
"-c opt"
] ++ lib.optional (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
"--config=avx_posix"
] ++ lib.optional cudaSupport [
"--config=cuda"
] ++ lib.optional mklSupport [
"--config=mkl_open_source_only"
];
fetchAttrs = {
sha256 =
if cudaSupport then
"1lyipbflqd1y5cdj4hdml5h1inbr0wwfgp6xw5p5623qv3im16lh"
else
"09kapzpfwnlr6ghmgwac232bqf2a57mm1brz4cvfx8mlg8bbaw63";
};
buildAttrs = {
outputs = [ "out" ];
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
# 1) Fix pybind11 include paths.
# 2) Force static protobuf linkage to prevent crashes on loading multiple extensions
# in the same python program due to duplicate protobuf DBs.
# 3) Patch python path in the compiler driver.
preBuild = ''
for src in ./jaxlib/*.{cc,h}; do
sed -i 's@include/pybind11@pybind11@g' $src
done
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
'' + lib.optionalString cudaSupport ''
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
'';
installPhase = ''
./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch}
'';
};
inherit meta;
};
in
buildPythonPackage {
inherit meta pname version;
format = "wheel";
# At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting
# all of them is a pain, so we focus on 3.9, the current nixpkgs python3
# version.
disabled = !isPy39;
src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";
src = {
cpu = fetchurl {
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
sha256 = "sha256:0rqhs6qabydizlv5d3rb20dbv6612rr7dqfniy9r6h4kazdinsn6";
};
gpu = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda111/jaxlib-${version}+cuda111-cp39-none-manylinux2010_x86_64.whl";
sha256 = "sha256:065kyzjsk9m84d138p99iymdiiicm1qz8a3iwxz8rspl43rwrw89";
};
}.${device};
# Prebuilt wheels are dynamically linked against things that nix can't find.
# Run `autoPatchelfHook` to automagically fix them.
nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath;
# Dynamic link dependencies
buildInputs = [ stdenv.cc.cc ];
# jaxlib contains shared libraries that open other shared libraries via dlopen
# and these implicit dependencies are not recognized by ldd or
# autoPatchelfHook. That means we need to sneak them into rpath. This step
# must be done after autoPatchelfHook and the automatic stripping of
# artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
# patchPhase. Dependencies:
# * libcudart.so.11.0 -> cudatoolkit_11.lib
# * libcublas.so.11 -> cudatoolkit_11
# * libcuda.so.1 -> opengl driver in /run/opengl-driver/lib
preInstallCheck = lib.optional cudaSupport ''
shopt -s globstar
addOpenGLRunpath $out/**/*.so
for file in $out/**/*.so; do
rpath=$(patchelf --print-rpath $file)
# For some reason `makeLibraryPath` on `cudatoolkit_11` maps to
# <cudatoolkit_11.lib>/lib which is different from <cudatoolkit_11>/lib.
patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib ]}" $file
postInstall = lib.optionalString cudaSupport ''
find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
addOpenGLRunpath "$lib"
patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
done
'';
# pip dependencies and optionally cudatoolkit.
propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11;
nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;
propagatedBuildInputs = [
absl-py
double-conversion
flatbuffers
giflib
grpc
jsoncpp
libjpeg_turbo
numpy
scipy
six
snappy
];
pythonImportsCheck = [ "jaxlib" ];
meta = with lib; {
description = "XLA library for JAX";
homepage = "https://github.com/google/jax";
license = licenses.asl20;
maintainers = with maintainers; [ samuela ];
platforms = [ "x86_64-linux" ];
};
# Without it there are complaints about libcudart.so.11.0 not being found
# because RPATH path entries added above are stripped.
dontPatchELF = cudaSupport;
}

View file

@ -100,6 +100,12 @@ let
disabledIf = x: drv: if x then disabled drv else drv;
# CUDA-related packages that are compatible with the currently packaged version
# of TensorFlow, used to keep these versions in sync in related packages like `jaxlib`.
tensorflow_compat_cudatoolkit = pkgs.cudatoolkit_11_2;
tensorflow_compat_cudnn = pkgs.cudnn_cudatoolkit_11_2;
tensorflow_compat_nccl = pkgs.nccl_cudatoolkit_11;
in {
inherit pkgs stdenv;
@ -4053,7 +4059,17 @@ in {
jax = callPackage ../development/python-modules/jax { };
jaxlib = callPackage ../development/python-modules/jaxlib { };
jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { };
jaxlib-build = callPackage ../development/python-modules/jaxlib {
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
cudaSupport = pkgs.config.cudaSupport or false;
cudatoolkit = tensorflow_compat_cudatoolkit;
cudnn = tensorflow_compat_cudnn;
nccl = tensorflow_compat_nccl;
};
jaxlib = self.jaxlib-build;
JayDeBeApi = callPackage ../development/python-modules/JayDeBeApi { };
@ -9453,16 +9469,16 @@ in {
tensorflow-bin = callPackage ../development/python-modules/tensorflow/bin.nix {
cudaSupport = pkgs.config.cudaSupport or false;
cudatoolkit = pkgs.cudatoolkit_11_2;
cudnn = pkgs.cudnn_cudatoolkit_11_2;
cudatoolkit = tensorflow_compat_cudatoolkit;
cudnn = tensorflow_compat_cudnn;
};
tensorflow-build = callPackage ../development/python-modules/tensorflow {
inherit (pkgs.darwin) cctools;
cudaSupport = pkgs.config.cudaSupport or false;
cudatoolkit = pkgs.cudatoolkit_11_2;
cudnn = pkgs.cudnn_cudatoolkit_11_2;
nccl = pkgs.nccl_cudatoolkit_11;
cudatoolkit = tensorflow_compat_cudatoolkit;
cudnn = tensorflow_compat_cudnn;
nccl = tensorflow_compat_nccl;
inherit (pkgs.darwin.apple_sdk.frameworks) Foundation Security;
flatbuffers-core = pkgs.flatbuffers;
flatbuffers-python = self.flatbuffers;