Merge pull request #276800 from SomeoneSerge/fix/cuda-no-throw

cudaPackages: eliminate exceptions
This commit is contained in:
Someone 2023-12-26 04:54:10 +00:00 committed by GitHub
commit 86b7775ff3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 47 additions and 35 deletions

View file

@ -87,7 +87,7 @@ attrsets.filterAttrs (attr: _: (builtins.hasAttr attr prev)) {
cuda_nvcc = prev.cuda_nvcc.overrideAttrs ( cuda_nvcc = prev.cuda_nvcc.overrideAttrs (
oldAttrs: { oldAttrs: {
outputs = oldAttrs.outputs ++ [ "lib" ]; outputs = oldAttrs.outputs ++ lists.optionals (!(builtins.elem "lib" oldAttrs.outputs)) [ "lib" ];
# Patch the nvcc.profile. # Patch the nvcc.profile.
# Syntax: # Syntax:

View file

@ -143,7 +143,7 @@ let
else if nixSystem == "x86_64-windows" then else if nixSystem == "x86_64-windows" then
"windows-x86_64" "windows-x86_64"
else else
builtins.throw "Unsupported Nix system: ${nixSystem}"; "unsupported";
# Maps NVIDIA redist arch to Nix system. # Maps NVIDIA redist arch to Nix system.
# It is imperative that we include the boolean condition based on jetsonTargets to ensure # It is imperative that we include the boolean condition based on jetsonTargets to ensure
@ -163,7 +163,7 @@ let
else if redistArch == "windows-x86_64" then else if redistArch == "windows-x86_64" then
"x86_64-windows" "x86_64-windows"
else else
builtins.throw "Unsupported NVIDIA redist arch: ${redistArch}"; "unsupported-${redistArch}";
formatCapabilities = formatCapabilities =
{ {
@ -175,9 +175,10 @@ let
# archNames :: List String # archNames :: List String
# E.g. [ "Turing" "Ampere" ] # E.g. [ "Turing" "Ampere" ]
#
# Unknown architectures are rendered as sm_XX gencode flags.
archNames = lists.unique ( archNames = lists.unique (
lists.map (cap: cudaComputeCapabilityToName.${cap} or (throw "missing cuda compute capability")) lists.map (cap: cudaComputeCapabilityToName.${cap} or "sm_${dropDot cap}") cudaCapabilities
cudaCapabilities
); );
# realArches :: List String # realArches :: List String

View file

@ -77,7 +77,7 @@ backendStdenv.mkDerivation (
false false
featureRelease; featureRelease;
# Order is important here so we use a list. # Order is important here so we use a list.
additionalOutputs = builtins.filter hasOutput [ possibleOutputs = [
"bin" "bin"
"lib" "lib"
"static" "static"
@ -86,8 +86,10 @@ backendStdenv.mkDerivation (
"sample" "sample"
"python" "python"
]; ];
additionalOutputs =
if redistArch == "unsupported" then possibleOutputs else builtins.filter hasOutput possibleOutputs;
# The out output is special -- it's the default output and we always include it. # The out output is special -- it's the default output and we always include it.
outputs = ["out"] ++ additionalOutputs; outputs = [ "out" ] ++ additionalOutputs;
in in
outputs; outputs;
@ -115,10 +117,14 @@ backendStdenv.mkDerivation (
brokenConditions = {}; brokenConditions = {};
src = fetchurl { src = fetchurl {
url = "https://developer.download.nvidia.com/compute/${redistName}/redist/${ url =
redistribRelease.${redistArch}.relative_path if (builtins.hasAttr redistArch redistribRelease) then
}"; "https://developer.download.nvidia.com/compute/${redistName}/redist/${
inherit (redistribRelease.${redistArch}) sha256; redistribRelease.${redistArch}.relative_path
}"
else
"cannot-construct-an-url-for-the-${redistArch}-platform";
sha256 = redistribRelease.${redistArch}.sha256 or lib.fakeHash;
}; };
postPatch = '' postPatch = ''
@ -283,9 +289,9 @@ backendStdenv.mkDerivation (
( (
redistArch: redistArch:
let let
nixSystem = builtins.tryEval (flags.getNixSystem redistArch); nixSystem = flags.getNixSystem redistArch;
in in
if nixSystem.success then [nixSystem.value] else [] lists.optionals (!(strings.hasPrefix "unsupported-" nixSystem)) [ nixSystem ]
) )
supportedRedistArchs; supportedRedistArchs;
broken = lists.any trivial.id (attrsets.attrValues finalAttrs.brokenConditions); broken = lists.any trivial.id (attrsets.attrValues finalAttrs.brokenConditions);

View file

@ -59,9 +59,12 @@ let
# - Releases: ../modules/${pname}/releases/releases.nix # - Releases: ../modules/${pname}/releases/releases.nix
# - Package: ../modules/${pname}/releases/package.nix # - Package: ../modules/${pname}/releases/package.nix
# FIXME: do this at the module system level
propagatePlatforms = lib.mapAttrs (platform: subset: map (r: r // { inherit platform; }) subset);
# All releases across all platforms # All releases across all platforms
# See ../modules/${pname}/releases/releases.nix # See ../modules/${pname}/releases/releases.nix
allReleases = evaluatedModules.config.${pname}.releases; releaseSets = propagatePlatforms evaluatedModules.config.${pname}.releases;
# Compute versioned attribute name to be used in this package set # Compute versioned attribute name to be used in this package set
# Patch version changes should not break the build, so we only use major and minor # Patch version changes should not break the build, so we only use major and minor
@ -72,20 +75,22 @@ let
# isSupported :: Package -> Bool # isSupported :: Package -> Bool
isSupported = isSupported =
package: package:
strings.versionAtLeast cudaVersion package.minCudaVersion !(strings.hasPrefix "unsupported" package.platform)
&& strings.versionAtLeast cudaVersion package.minCudaVersion
&& strings.versionAtLeast package.maxCudaVersion cudaVersion; && strings.versionAtLeast package.maxCudaVersion cudaVersion;
# Get all of the packages for our given platform. # Get all of the packages for our given platform.
redistArch = flags.getRedistArch hostPlatform.system; redistArch = flags.getRedistArch hostPlatform.system;
allReleases = builtins.concatMap (xs: xs) (builtins.attrValues releaseSets);
# All the supported packages we can build for our platform. # All the supported packages we can build for our platform.
# supportedPackages :: List (AttrSet Packages) # perSystemReleases :: List Package
supportedPackages = builtins.filter isSupported (allReleases.${redistArch} or []); perSystemReleases = releaseSets.${redistArch} or [ ];
# newestToOldestSupportedPackage :: List (AttrSet Packages) preferable =
newestToOldestSupportedPackage = lists.reverseList supportedPackages; p1: p2: (isSupported p2 -> isSupported p1) && (strings.versionAtLeast p1.version p2.version);
newest = builtins.head (builtins.sort preferable allReleases);
nameOfNewest = computeName (builtins.head newestToOldestSupportedPackage);
# A function which takes the `final` overlay and the `package` being built and returns # A function which takes the `final` overlay and the `package` being built and returns
# a function to be consumed via `overrideAttrs`. # a function to be consumed via `overrideAttrs`.
@ -120,11 +125,9 @@ let
attrsets.nameValuePair name fixedDrv; attrsets.nameValuePair name fixedDrv;
# versionedDerivations :: AttrSet Derivation # versionedDerivations :: AttrSet Derivation
versionedDerivations = builtins.listToAttrs (lists.map buildPackage newestToOldestSupportedPackage); versionedDerivations = builtins.listToAttrs (lists.map buildPackage perSystemReleases);
defaultDerivation = attrsets.optionalAttrs (versionedDerivations != {}) { defaultDerivation = { ${pname} = (buildPackage newest).value; };
${pname} = versionedDerivations.${nameOfNewest};
};
in in
versionedDerivations // defaultDerivation; versionedDerivations // defaultDerivation;
in in

View file

@ -16,6 +16,13 @@ let
strings strings
versions versions
; ;
targetArch =
if hostPlatform.isx86_64 then
"x86_64-linux-gnu"
else if hostPlatform.isAarch64 then
"aarch64-linux-gnu"
else
"unsupported";
in in
finalAttrs: prevAttrs: { finalAttrs: prevAttrs: {
# Useful for inspecting why something went wrong. # Useful for inspecting why something went wrong.
@ -58,18 +65,9 @@ finalAttrs: prevAttrs: {
# We need to look inside the extracted output to get the files we need. # We need to look inside the extracted output to get the files we need.
sourceRoot = "TensorRT-${finalAttrs.version}"; sourceRoot = "TensorRT-${finalAttrs.version}";
buildInputs = prevAttrs.buildInputs ++ [finalAttrs.passthru.cudnn.lib]; buildInputs = prevAttrs.buildInputs ++ [ finalAttrs.passthru.cudnn.lib ];
preInstall = preInstall =
let
targetArch =
if hostPlatform.isx86_64 then
"x86_64-linux-gnu"
else if hostPlatform.isAarch64 then
"aarch64-linux-gnu"
else
throw "Unsupported architecture";
in
(prevAttrs.preInstall or "") (prevAttrs.preInstall or "")
+ '' + ''
# Replace symlinks to bin and lib with the actual directories from targets. # Replace symlinks to bin and lib with the actual directories from targets.
@ -107,6 +105,9 @@ finalAttrs: prevAttrs: {
}; };
meta = prevAttrs.meta // { meta = prevAttrs.meta // {
badPlatforms =
prevAttrs.meta.badPlatforms or [ ]
++ lib.optionals (targetArch == "unsupported") [ hostPlatform.system ];
homepage = "https://developer.nvidia.com/tensorrt"; homepage = "https://developer.nvidia.com/tensorrt";
maintainers = prevAttrs.meta.maintainers ++ [maintainers.aidalgol]; maintainers = prevAttrs.meta.maintainers ++ [maintainers.aidalgol];
}; };

View file

@ -480,6 +480,7 @@ let
}; };
meta = with lib; { meta = with lib; {
badPlatforms = lib.optionals cudaSupport lib.platforms.darwin;
changelog = "https://github.com/tensorflow/tensorflow/releases/tag/v${version}"; changelog = "https://github.com/tensorflow/tensorflow/releases/tag/v${version}";
description = "Computation using data flow graphs for scalable machine learning"; description = "Computation using data flow graphs for scalable machine learning";
homepage = "http://tensorflow.org"; homepage = "http://tensorflow.org";