python310Packages.jax: 0.3.6 -> 0.3.16

python310Packages.jaxlib: 0.3.0 -> 0.3.15
This commit is contained in:
Matt Wittmann 2022-08-11 21:43:18 -07:00
parent 5e053ae4a5
commit 6bff360532
No known key found for this signature in database
GPG key ID: AEAD22D9F95347ED
3 changed files with 30 additions and 43 deletions

View file

@ -1,12 +0,0 @@
diff --git a/jax/experimental/compilation_cache/file_system_cache.py b/jax/experimental/compilation_cache/file_system_cache.py
index b85969de..92acd523 100644
--- a/jax/experimental/compilation_cache/file_system_cache.py
+++ b/jax/experimental/compilation_cache/file_system_cache.py
@@ -33,6 +33,7 @@ class FileSystemCache(CacheInterface):
path_to_key = os.path.join(self._path, key)
if os.path.exists(path_to_key):
with open(path_to_key, "rb") as file:
+ os.utime(file.fileno())
return file.read()
else:
return None

View file

@ -2,10 +2,11 @@
, absl-py
, blas
, buildPythonPackage
, etils
, fetchFromGitHub
, fetchpatch
, jaxlib
, lapack
, matplotlib
, numpy
, opt-einsum
, pytestCheckHook
@ -20,7 +21,7 @@ let
in
buildPythonPackage rec {
pname = "jax";
version = "0.3.6";
version = "0.3.16";
format = "setuptools";
disabled = pythonOlder "3.7";
@ -29,34 +30,25 @@ buildPythonPackage rec {
owner = "google";
repo = pname;
rev = "jax-v${version}";
hash = "sha256-eGdAEZFHadNTHgciP4KMYHdwksz9g6un0Ar+A/KV5TE=";
hash = "sha256-4idh7boqBXSO9vEHxEcrzXjBIrKmmXiCf6cXh7En1/I=";
};
patches = [
# See https://github.com/google/jax/issues/7944
./cache-fix.patch
# See https://github.com/google/jax/issues/10292
(fetchpatch {
url = "https://github.com/google/jax/commit/cadc8046d56e0c1433cf48a2f106947d5f4ecbfd.patch";
hash = "sha256-jrpIqt4LzWAswt/Cpwtfa5d1Yn31HcXkVH3ETmaigA0=";
})
];
# jaxlib is _not_ included in propagatedBuildInputs because there are
# different versions of jaxlib depending on the desired target hardware. The
# JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
# CPU wheel is packaged.
propagatedBuildInputs = [
absl-py
etils
numpy
opt-einsum
scipy
typing-extensions
];
] ++ etils.optional-dependencies.epath;
checkInputs = [
jaxlib
matplotlib
pytestCheckHook
pytest-xdist
];

View file

@ -9,11 +9,14 @@
, buildBazelPackage
, buildPythonPackage
, cctools
, curl
, cython
, fetchFromGitHub
, git
, IOKit
, jsoncpp
, nsync
, openssl
, pybind11
, setuptools
, symlinkJoin
@ -50,7 +53,7 @@ let
inherit (cudaPackages) cudatoolkit cudnn nccl;
pname = "jaxlib";
version = "0.3.0";
version = "0.3.15";
meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@ -93,7 +96,7 @@ let
owner = "google";
repo = "jax";
rev = "${pname}-v${version}";
sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s=";
};
nativeBuildInputs = [
@ -103,15 +106,19 @@ let
setuptools
wheel
which
] ++ lib.optionals stdenv.isDarwin [
cctools
];
buildInputs = [
curl
double-conversion
giflib
grpc
jsoncpp
libjpeg_turbo
numpy
openssl
pkgs.flatbuffers
pkgs.protobuf
pybind11
@ -124,6 +131,8 @@ let
cudnn
] ++ lib.optionals stdenv.isDarwin [
IOKit
] ++ lib.optionals (!stdenv.isDarwin) [
nsync
];
postPatch = ''
@ -149,6 +158,7 @@ let
build --action_env=PYENV_ROOT
build --python_path="${python}/bin/python"
build --distinct_host_configuration=false
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
'' + lib.optionalString cudaSupport ''
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
@ -163,7 +173,7 @@ let
# Copy-paste from TF derivation.
# Most of these are not really used in jaxlib compilation but it's simpler to keep it
# 'as is' so that it's more compatible with TF derivation.
TF_SYSTEM_LIBS = lib.concatStringsSep "," [
TF_SYSTEM_LIBS = lib.concatStringsSep "," ([
"absl_py"
"astor_archive"
"astunparse_archive"
@ -179,7 +189,6 @@ let
"cython"
"dill_archive"
"double_conversion"
"enum34_archive"
"flatbuffers"
"functools32_archive"
"gast_archive"
@ -190,11 +199,9 @@ let
"libjpeg_turbo"
"lmdb"
"nasm"
# "nsync" # not packaged in nixpkgs
"opt_einsum_archive"
"org_sqlite"
"pasta"
"pcre"
"png"
"pybind11"
"six_archive"
@ -204,7 +211,9 @@ let
"typing_extensions_archive"
"wrapt"
"zlib"
];
] ++ lib.optionals (!stdenv.isDarwin) [
"nsync" # fails to build on darwin
]);
# Make sure Bazel knows about our configuration flags during fetching so that the
# relevant dependencies can be downloaded.
@ -226,9 +235,11 @@ let
fetchAttrs = {
sha256 =
if cudaSupport then
"sha256-Ald+vplRx/DDG/7TfHAqD4Gktb1BGnf7FSCCJzSI0eo="
"sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk="
else if stdenv.isDarwin then
"sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34="
else
"sha256-eK5IjTAncDarkWYKnXrEo7kw7J7iOH7in2L2GabnFYo=";
"sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A=";
};
buildAttrs = {
@ -239,15 +250,10 @@ let
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
# 3) Patch python path in the compiler driver.
# 4) Patch tensorflow sources to work with later versions of protobuf. See
# https://github.com/google/jax/issues/9534. Note that this should be
# removed on the next release after 0.3.0.
preBuild = ''
for src in ./jaxlib/*.{cc,h}; do
for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
sed -i 's@include/pybind11@pybind11@g' $src
done
substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
--replace "status.message()" "std::string{status.message()}"
'' + lib.optionalString cudaSupport ''
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
'' + lib.optionalString stdenv.isDarwin ''
@ -275,7 +281,7 @@ let
};
platformTag =
if stdenv.targetPlatform.isLinux then
"manylinux2010_${stdenv.targetPlatform.linuxArch}"
"manylinux2014_${stdenv.targetPlatform.linuxArch}"
else if stdenv.system == "x86_64-darwin" then
"macosx_10_9_${stdenv.targetPlatform.linuxArch}"
else if stdenv.system == "aarch64-darwin" then
@ -306,6 +312,7 @@ buildPythonPackage {
propagatedBuildInputs = [
absl-py
curl
double-conversion
flatbuffers
giflib