From 6bff360532ceee13b268c0c8229212aa585a043a Mon Sep 17 00:00:00 2001 From: Matt Wittmann Date: Thu, 11 Aug 2022 21:43:18 -0700 Subject: [PATCH] python310Packages.jax: 0.3.6 -> 0.3.16 python310Packages.jaxlib: 0.3.0 -> 0.3.15 --- .../python-modules/jax/cache-fix.patch | 12 ------ .../python-modules/jax/default.nix | 22 ++++------- .../python-modules/jaxlib/default.nix | 39 +++++++++++-------- 3 files changed, 30 insertions(+), 43 deletions(-) delete mode 100644 pkgs/development/python-modules/jax/cache-fix.patch diff --git a/pkgs/development/python-modules/jax/cache-fix.patch b/pkgs/development/python-modules/jax/cache-fix.patch deleted file mode 100644 index 5db5319485f8..000000000000 --- a/pkgs/development/python-modules/jax/cache-fix.patch +++ /dev/null @@ -1,12 +0,0 @@ -diff --git a/jax/experimental/compilation_cache/file_system_cache.py b/jax/experimental/compilation_cache/file_system_cache.py -index b85969de..92acd523 100644 ---- a/jax/experimental/compilation_cache/file_system_cache.py -+++ b/jax/experimental/compilation_cache/file_system_cache.py -@@ -33,6 +33,7 @@ class FileSystemCache(CacheInterface): - path_to_key = os.path.join(self._path, key) - if os.path.exists(path_to_key): - with open(path_to_key, "rb") as file: -+ os.utime(file.fileno()) - return file.read() - else: - return None diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index a302341c3141..9970783aa3bc 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -2,10 +2,11 @@ , absl-py , blas , buildPythonPackage +, etils , fetchFromGitHub -, fetchpatch , jaxlib , lapack +, matplotlib , numpy , opt-einsum , pytestCheckHook @@ -20,7 +21,7 @@ let in buildPythonPackage rec { pname = "jax"; - version = "0.3.6"; + version = "0.3.16"; format = "setuptools"; disabled = pythonOlder "3.7"; @@ -29,34 +30,25 @@ buildPythonPackage rec { owner = "google"; repo = pname; rev = "jax-v${version}"; - hash = "sha256-eGdAEZFHadNTHgciP4KMYHdwksz9g6un0Ar+A/KV5TE="; + hash = "sha256-4idh7boqBXSO9vEHxEcrzXjBIrKmmXiCf6cXh7En1/I="; }; - patches = [ - # See https://github.com/google/jax/issues/7944 - ./cache-fix.patch - - # See https://github.com/google/jax/issues/10292 - (fetchpatch { - url = "https://github.com/google/jax/commit/cadc8046d56e0c1433cf48a2f106947d5f4ecbfd.patch"; - hash = "sha256-jrpIqt4LzWAswt/Cpwtfa5d1Yn31HcXkVH3ETmaigA0="; - }) - ]; - # jaxlib is _not_ included in propagatedBuildInputs because there are # different versions of jaxlib depending on the desired target hardware. The # JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the # CPU wheel is packaged. propagatedBuildInputs = [ absl-py + etils numpy opt-einsum scipy typing-extensions - ]; + ] ++ etils.optional-dependencies.epath; checkInputs = [ jaxlib + matplotlib pytestCheckHook pytest-xdist ]; diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index eee432f71853..456c9108593e 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -9,11 +9,14 @@ , buildBazelPackage , buildPythonPackage , cctools +, curl , cython , fetchFromGitHub , git , IOKit , jsoncpp +, nsync +, openssl , pybind11 , setuptools , symlinkJoin @@ -50,7 +53,7 @@ let inherit (cudaPackages) cudatoolkit cudnn nccl; pname = "jaxlib"; - version = "0.3.0"; + version = "0.3.15"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; @@ -93,7 +96,7 @@ let owner = "google"; repo = "jax"; rev = "${pname}-v${version}"; - sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72"; + sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s="; }; nativeBuildInputs = [ @@ -103,15 +106,19 @@ let setuptools wheel which + ] ++ lib.optionals stdenv.isDarwin [ + cctools ]; buildInputs = [ + curl double-conversion giflib grpc jsoncpp libjpeg_turbo numpy + openssl pkgs.flatbuffers pkgs.protobuf pybind11 @@ -124,6 +131,8 @@ let cudnn ] ++ lib.optionals stdenv.isDarwin [ IOKit + ] ++ lib.optionals (!stdenv.isDarwin) [ + nsync ]; postPatch = '' @@ -149,6 +158,7 @@ let build --action_env=PYENV_ROOT build --python_path="${python}/bin/python" build --distinct_host_configuration=false + build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" '' + lib.optionalString cudaSupport '' build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}" build --action_env CUDNN_INSTALL_PATH="${cudnn}" @@ -163,7 +173,7 @@ let # Copy-paste from TF derivation. # Most of these are not really used in jaxlib compilation but it's simpler to keep it # 'as is' so that it's more compatible with TF derivation. - TF_SYSTEM_LIBS = lib.concatStringsSep "," [ + TF_SYSTEM_LIBS = lib.concatStringsSep "," ([ "absl_py" "astor_archive" "astunparse_archive" @@ -179,7 +189,6 @@ let "cython" "dill_archive" "double_conversion" - "enum34_archive" "flatbuffers" "functools32_archive" "gast_archive" @@ -190,11 +199,9 @@ let "libjpeg_turbo" "lmdb" "nasm" - # "nsync" # not packaged in nixpkgs "opt_einsum_archive" "org_sqlite" "pasta" - "pcre" "png" "pybind11" "six_archive" @@ -204,7 +211,9 @@ let "typing_extensions_archive" "wrapt" "zlib" - ]; + ] ++ lib.optionals (!stdenv.isDarwin) [ + "nsync" # fails to build on darwin + ]); # Make sure Bazel knows about our configuration flags during fetching so that the # relevant dependencies can be downloaded. @@ -226,9 +235,11 @@ let fetchAttrs = { sha256 = if cudaSupport then - "sha256-Ald+vplRx/DDG/7TfHAqD4Gktb1BGnf7FSCCJzSI0eo=" + "sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk=" + else if stdenv.isDarwin then + "sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34=" else - "sha256-eK5IjTAncDarkWYKnXrEo7kw7J7iOH7in2L2GabnFYo="; + "sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A="; }; buildAttrs = { @@ -239,15 +250,10 @@ let # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # loading multiple extensions in the same python program due to duplicate protobuf DBs. # 3) Patch python path in the compiler driver. - # 4) Patch tensorflow sources to work with later versions of protobuf. See - # https://github.com/google/jax/issues/9534. Note that this should be - # removed on the next release after 0.3.0. preBuild = '' - for src in ./jaxlib/*.{cc,h}; do + for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do sed -i 's@include/pybind11@pybind11@g' $src done - substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \ - --replace "status.message()" "std::string{status.message()}" '' + lib.optionalString cudaSupport '' patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl '' + lib.optionalString stdenv.isDarwin '' @@ -275,7 +281,7 @@ let }; platformTag = if stdenv.targetPlatform.isLinux then - "manylinux2010_${stdenv.targetPlatform.linuxArch}" + "manylinux2014_${stdenv.targetPlatform.linuxArch}" else if stdenv.system == "x86_64-darwin" then "macosx_10_9_${stdenv.targetPlatform.linuxArch}" else if stdenv.system == "aarch64-darwin" then @@ -306,6 +312,7 @@ buildPythonPackage { propagatedBuildInputs = [ absl-py + curl double-conversion flatbuffers giflib