python310Packages.jax: 0.3.6 -> 0.3.16
python310Packages.jaxlib: 0.3.0 -> 0.3.15
This commit is contained in:
parent
5e053ae4a5
commit
6bff360532
3 changed files with 30 additions and 43 deletions
|
@ -1,12 +0,0 @@
|
|||
diff --git a/jax/experimental/compilation_cache/file_system_cache.py b/jax/experimental/compilation_cache/file_system_cache.py
|
||||
index b85969de..92acd523 100644
|
||||
--- a/jax/experimental/compilation_cache/file_system_cache.py
|
||||
+++ b/jax/experimental/compilation_cache/file_system_cache.py
|
||||
@@ -33,6 +33,7 @@ class FileSystemCache(CacheInterface):
|
||||
path_to_key = os.path.join(self._path, key)
|
||||
if os.path.exists(path_to_key):
|
||||
with open(path_to_key, "rb") as file:
|
||||
+ os.utime(file.fileno())
|
||||
return file.read()
|
||||
else:
|
||||
return None
|
|
@ -2,10 +2,11 @@
|
|||
, absl-py
|
||||
, blas
|
||||
, buildPythonPackage
|
||||
, etils
|
||||
, fetchFromGitHub
|
||||
, fetchpatch
|
||||
, jaxlib
|
||||
, lapack
|
||||
, matplotlib
|
||||
, numpy
|
||||
, opt-einsum
|
||||
, pytestCheckHook
|
||||
|
@ -20,7 +21,7 @@ let
|
|||
in
|
||||
buildPythonPackage rec {
|
||||
pname = "jax";
|
||||
version = "0.3.6";
|
||||
version = "0.3.16";
|
||||
format = "setuptools";
|
||||
|
||||
disabled = pythonOlder "3.7";
|
||||
|
@ -29,34 +30,25 @@ buildPythonPackage rec {
|
|||
owner = "google";
|
||||
repo = pname;
|
||||
rev = "jax-v${version}";
|
||||
hash = "sha256-eGdAEZFHadNTHgciP4KMYHdwksz9g6un0Ar+A/KV5TE=";
|
||||
hash = "sha256-4idh7boqBXSO9vEHxEcrzXjBIrKmmXiCf6cXh7En1/I=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
# See https://github.com/google/jax/issues/7944
|
||||
./cache-fix.patch
|
||||
|
||||
# See https://github.com/google/jax/issues/10292
|
||||
(fetchpatch {
|
||||
url = "https://github.com/google/jax/commit/cadc8046d56e0c1433cf48a2f106947d5f4ecbfd.patch";
|
||||
hash = "sha256-jrpIqt4LzWAswt/Cpwtfa5d1Yn31HcXkVH3ETmaigA0=";
|
||||
})
|
||||
];
|
||||
|
||||
# jaxlib is _not_ included in propagatedBuildInputs because there are
|
||||
# different versions of jaxlib depending on the desired target hardware. The
|
||||
# JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
|
||||
# CPU wheel is packaged.
|
||||
propagatedBuildInputs = [
|
||||
absl-py
|
||||
etils
|
||||
numpy
|
||||
opt-einsum
|
||||
scipy
|
||||
typing-extensions
|
||||
];
|
||||
] ++ etils.optional-dependencies.epath;
|
||||
|
||||
checkInputs = [
|
||||
jaxlib
|
||||
matplotlib
|
||||
pytestCheckHook
|
||||
pytest-xdist
|
||||
];
|
||||
|
|
|
@ -9,11 +9,14 @@
|
|||
, buildBazelPackage
|
||||
, buildPythonPackage
|
||||
, cctools
|
||||
, curl
|
||||
, cython
|
||||
, fetchFromGitHub
|
||||
, git
|
||||
, IOKit
|
||||
, jsoncpp
|
||||
, nsync
|
||||
, openssl
|
||||
, pybind11
|
||||
, setuptools
|
||||
, symlinkJoin
|
||||
|
@ -50,7 +53,7 @@ let
|
|||
inherit (cudaPackages) cudatoolkit cudnn nccl;
|
||||
|
||||
pname = "jaxlib";
|
||||
version = "0.3.0";
|
||||
version = "0.3.15";
|
||||
|
||||
meta = with lib; {
|
||||
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
|
||||
|
@ -93,7 +96,7 @@ let
|
|||
owner = "google";
|
||||
repo = "jax";
|
||||
rev = "${pname}-v${version}";
|
||||
sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
|
||||
sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
|
@ -103,15 +106,19 @@ let
|
|||
setuptools
|
||||
wheel
|
||||
which
|
||||
] ++ lib.optionals stdenv.isDarwin [
|
||||
cctools
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
curl
|
||||
double-conversion
|
||||
giflib
|
||||
grpc
|
||||
jsoncpp
|
||||
libjpeg_turbo
|
||||
numpy
|
||||
openssl
|
||||
pkgs.flatbuffers
|
||||
pkgs.protobuf
|
||||
pybind11
|
||||
|
@ -124,6 +131,8 @@ let
|
|||
cudnn
|
||||
] ++ lib.optionals stdenv.isDarwin [
|
||||
IOKit
|
||||
] ++ lib.optionals (!stdenv.isDarwin) [
|
||||
nsync
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
|
@ -149,6 +158,7 @@ let
|
|||
build --action_env=PYENV_ROOT
|
||||
build --python_path="${python}/bin/python"
|
||||
build --distinct_host_configuration=false
|
||||
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
|
||||
'' + lib.optionalString cudaSupport ''
|
||||
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
|
||||
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
|
||||
|
@ -163,7 +173,7 @@ let
|
|||
# Copy-paste from TF derivation.
|
||||
# Most of these are not really used in jaxlib compilation but it's simpler to keep it
|
||||
# 'as is' so that it's more compatible with TF derivation.
|
||||
TF_SYSTEM_LIBS = lib.concatStringsSep "," [
|
||||
TF_SYSTEM_LIBS = lib.concatStringsSep "," ([
|
||||
"absl_py"
|
||||
"astor_archive"
|
||||
"astunparse_archive"
|
||||
|
@ -179,7 +189,6 @@ let
|
|||
"cython"
|
||||
"dill_archive"
|
||||
"double_conversion"
|
||||
"enum34_archive"
|
||||
"flatbuffers"
|
||||
"functools32_archive"
|
||||
"gast_archive"
|
||||
|
@ -190,11 +199,9 @@ let
|
|||
"libjpeg_turbo"
|
||||
"lmdb"
|
||||
"nasm"
|
||||
# "nsync" # not packaged in nixpkgs
|
||||
"opt_einsum_archive"
|
||||
"org_sqlite"
|
||||
"pasta"
|
||||
"pcre"
|
||||
"png"
|
||||
"pybind11"
|
||||
"six_archive"
|
||||
|
@ -204,7 +211,9 @@ let
|
|||
"typing_extensions_archive"
|
||||
"wrapt"
|
||||
"zlib"
|
||||
];
|
||||
] ++ lib.optionals (!stdenv.isDarwin) [
|
||||
"nsync" # fails to build on darwin
|
||||
]);
|
||||
|
||||
# Make sure Bazel knows about our configuration flags during fetching so that the
|
||||
# relevant dependencies can be downloaded.
|
||||
|
@ -226,9 +235,11 @@ let
|
|||
fetchAttrs = {
|
||||
sha256 =
|
||||
if cudaSupport then
|
||||
"sha256-Ald+vplRx/DDG/7TfHAqD4Gktb1BGnf7FSCCJzSI0eo="
|
||||
"sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk="
|
||||
else if stdenv.isDarwin then
|
||||
"sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34="
|
||||
else
|
||||
"sha256-eK5IjTAncDarkWYKnXrEo7kw7J7iOH7in2L2GabnFYo=";
|
||||
"sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A=";
|
||||
};
|
||||
|
||||
buildAttrs = {
|
||||
|
@ -239,15 +250,10 @@ let
|
|||
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
|
||||
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
|
||||
# 3) Patch python path in the compiler driver.
|
||||
# 4) Patch tensorflow sources to work with later versions of protobuf. See
|
||||
# https://github.com/google/jax/issues/9534. Note that this should be
|
||||
# removed on the next release after 0.3.0.
|
||||
preBuild = ''
|
||||
for src in ./jaxlib/*.{cc,h}; do
|
||||
for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
|
||||
sed -i 's@include/pybind11@pybind11@g' $src
|
||||
done
|
||||
substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
|
||||
--replace "status.message()" "std::string{status.message()}"
|
||||
'' + lib.optionalString cudaSupport ''
|
||||
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
'' + lib.optionalString stdenv.isDarwin ''
|
||||
|
@ -275,7 +281,7 @@ let
|
|||
};
|
||||
platformTag =
|
||||
if stdenv.targetPlatform.isLinux then
|
||||
"manylinux2010_${stdenv.targetPlatform.linuxArch}"
|
||||
"manylinux2014_${stdenv.targetPlatform.linuxArch}"
|
||||
else if stdenv.system == "x86_64-darwin" then
|
||||
"macosx_10_9_${stdenv.targetPlatform.linuxArch}"
|
||||
else if stdenv.system == "aarch64-darwin" then
|
||||
|
@ -306,6 +312,7 @@ buildPythonPackage {
|
|||
|
||||
propagatedBuildInputs = [
|
||||
absl-py
|
||||
curl
|
||||
double-conversion
|
||||
flatbuffers
|
||||
giflib
|
||||
|
|
Loading…
Reference in a new issue