python3Packages.{jaxlibWithCuda, jaxlib-bin}: add ptxas to $out/bin
This commit is contained in:
parent
73ad5f9e14
commit
f9588d6966
2 changed files with 15 additions and 3 deletions
|
@ -120,9 +120,15 @@ buildPythonPackage rec {
|
||||||
done
|
done
|
||||||
'';
|
'';
|
||||||
|
|
||||||
# pip dependencies and optionally cudatoolkit. Note that cudatoolkit is
|
propagatedBuildInputs = [ absl-py flatbuffers scipy ];
|
||||||
# necessary since jaxlib looks for "ptxas" in $PATH.
|
|
||||||
propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11;
|
# Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
|
||||||
|
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
|
||||||
|
# more info.
|
||||||
|
postInstall = lib.optional cudaSupport ''
|
||||||
|
mkdir -p $out/bin
|
||||||
|
ln -s ${cudatoolkit_11}/bin/ptxas $out/bin/ptxas
|
||||||
|
'';
|
||||||
|
|
||||||
pythonImportsCheck = [ "jaxlib" ];
|
pythonImportsCheck = [ "jaxlib" ];
|
||||||
|
|
||||||
|
|
|
@ -259,7 +259,13 @@ buildPythonPackage {
|
||||||
|
|
||||||
src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";
|
src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";
|
||||||
|
|
||||||
|
# Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
|
||||||
|
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
|
||||||
|
# more info.
|
||||||
postInstall = lib.optionalString cudaSupport ''
|
postInstall = lib.optionalString cudaSupport ''
|
||||||
|
mkdir -p $out/bin
|
||||||
|
ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
|
||||||
|
|
||||||
find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
|
find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
|
||||||
addOpenGLRunpath "$lib"
|
addOpenGLRunpath "$lib"
|
||||||
patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
|
patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
|
||||||
|
|
Loading…
Reference in a new issue