onnxruntime: add CUDA support

This commit is contained in:
Michal Koutenský 2024-02-10 17:30:19 +01:00
parent 5a7fbfb3ce
commit 5902d04c38
2 changed files with 85 additions and 12 deletions

View file

@ -1,7 +1,7 @@
{ stdenv
{ config
, stdenv
, lib
, fetchFromGitHub
, fetchFromGitLab
, Foundation
, abseil-cpp
, cmake
@ -18,10 +18,22 @@
, iconv
, protobuf_21
, pythonSupport ? true
}:
, cudaSupport ? config.cudaSupport
, cudaPackages ? {}
}@inputs:
let
# Version shared with the derivation below.
version = "1.16.3";
# Poison the plain name: any leftover reference to `stdenv` in this file
# fails at evaluation time instead of silently bypassing effectiveStdenv.
stdenv = throw "Use effectiveStdenv instead";
# The original `stdenv` argument stays reachable through the `@inputs` binding.
effectiveStdenv =
if cudaSupport
then cudaPackages.backendStdenv
else inputs.stdenv;
inherit (cudaPackages.cudaFlags) cudaCapabilities;
# Each capability passed through cudaFlags.dropDot, e.g. [ "80" "86" "90" ]
cudaArchitectures = map cudaPackages.cudaFlags.dropDot cudaCapabilities;
# Semicolon-separated form handed to CMAKE_CUDA_ARCHITECTURES.
cudaArchitecturesString = lib.concatStringsSep ";" cudaArchitectures;
howard-hinnant-date = fetchFromGitHub {
owner = "HowardHinnant";
repo = "date";
@ -74,10 +86,17 @@ let
rev = "refs/tags/v1.14.1";
hash = "sha256-ZVSdk6LeAiZpQrrzLxphMbc1b3rNUMpcxcXPP8s/5tE=";
};
# NVIDIA CUTLASS sources, pre-fetched so the CUDA build can be pointed at them
# through FETCHCONTENT_SOURCE_DIR_CUTLASS.
cutlass = fetchFromGitHub {
owner = "NVIDIA";
repo = "cutlass";
rev = "v3.0.0";
# SRI value, so use `hash` like the other fetchers in this file.
hash = "sha256-YPD5Sy6SvByjIcGtgeGH80TEKg2BtqJWSg46RvnJChY=";
};
in
stdenv.mkDerivation rec {
effectiveStdenv.mkDerivation rec {
pname = "onnxruntime";
version = "1.16.3";
inherit version;
src = fetchFromGitHub {
owner = "microsoft";
@ -96,6 +115,10 @@ stdenv.mkDerivation rec {
# - use MakeAvailable instead of the low-level Populate,
# - use Eigen3::Eigen as the target name (as declared by libeigen/eigen).
./0001-eigen-allow-dependency-injection.patch
] ++ lib.optionals cudaSupport [
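# Upstream applies patches/gsl/1064.patch to its vendored GSL when building with CUDA; this patch makes the CUDA build look up Microsoft.GSL via find_package instead, and 1064.patch is applied to the GSL dependency on the Nix side.
# FIND_PACKAGE_ARGS for CUDA was added upstream in https://github.com/microsoft/onnxruntime/commit/87744e5, so this patch can probably be dropped after upgrading to 1.17.0.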
./nvcc-gsl.patch
];
nativeBuildInputs = [
@ -109,7 +132,9 @@ stdenv.mkDerivation rec {
pythonOutputDistHook
setuptools
wheel
]);
]) ++ lib.optionals cudaSupport [
cudaPackages.cuda_nvcc
];
buildInputs = [
eigen
@ -118,16 +143,24 @@ stdenv.mkDerivation rec {
nlohmann_json
microsoft-gsl
] ++ lib.optionals pythonSupport (with python3Packages; [
gtest'
numpy
pybind11
packaging
]) ++ lib.optionals stdenv.isDarwin [
]) ++ lib.optionals effectiveStdenv.isDarwin [
Foundation
iconv
];
] ++ lib.optionals cudaSupport (with cudaPackages; [
cuda_cccl # cub/cub.cuh
libcublas # cublas_v2.h
libcurand # curand.h
libcusparse # cusparse.h
libcufft # cufft.h
cudnn # cudnn.h
cuda_cudart
]);
nativeCheckInputs = lib.optionals pythonSupport (with python3Packages; [
gtest'
pytest
sympy
onnx
@ -159,23 +192,31 @@ stdenv.mkDerivation rec {
"-Donnxruntime_BUILD_UNIT_TESTS=ON"
"-Donnxruntime_ENABLE_LTO=ON"
"-Donnxruntime_USE_FULL_PROTOBUF=OFF"
(lib.cmakeBool "onnxruntime_USE_CUDA" cudaSupport)
(lib.cmakeBool "onnxruntime_USE_NCCL" cudaSupport)
] ++ lib.optionals pythonSupport [
"-Donnxruntime_ENABLE_PYTHON=ON"
] ++ lib.optionals cudaSupport [
# lib.cmakeFeature takes string values, so the cutlass source and the cudnn
# package are interpolated to their store paths instead of passed as derivations.
(lib.cmakeFeature "FETCHCONTENT_SOURCE_DIR_CUTLASS" "${cutlass}")
(lib.cmakeFeature "onnxruntime_CUDNN_HOME" "${cudaPackages.cudnn}")
(lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" cudaArchitecturesString)
];
env = lib.optionalAttrs stdenv.cc.isClang {
env = lib.optionalAttrs effectiveStdenv.cc.isClang {
NIX_CFLAGS_COMPILE = toString [
"-Wno-error=deprecated-declarations"
"-Wno-error=unused-but-set-variable"
];
};
doCheck = true;
doCheck = !cudaSupport;
requiredSystemFeatures = lib.optionals cudaSupport [ "big-parallel" ];
postPatch = ''
substituteInPlace cmake/libonnxruntime.pc.cmake.in \
--replace-fail '$'{prefix}/@CMAKE_INSTALL_ @CMAKE_INSTALL_
'' + lib.optionalString (stdenv.hostPlatform.system == "aarch64-linux") ''
'' + lib.optionalString (effectiveStdenv.hostPlatform.system == "aarch64-linux") ''
# https://github.com/NixOS/nixpkgs/pull/226734#issuecomment-1663028691
rm -v onnxruntime/test/optimizer/nhwc_transformer_test.cc
'';

View file

@ -0,0 +1,32 @@
diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
index 9effd1a2db..faff5e8de7 100644
--- a/cmake/external/onnxruntime_external_deps.cmake
+++ b/cmake/external/onnxruntime_external_deps.cmake
@@ -280,21 +280,12 @@ if (NOT WIN32)
endif()
endif()
-if(onnxruntime_USE_CUDA)
- FetchContent_Declare(
- GSL
- URL ${DEP_URL_microsoft_gsl}
- URL_HASH SHA1=${DEP_SHA1_microsoft_gsl}
- PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/gsl/1064.patch
- )
-else()
- FetchContent_Declare(
- GSL
- URL ${DEP_URL_microsoft_gsl}
- URL_HASH SHA1=${DEP_SHA1_microsoft_gsl}
- FIND_PACKAGE_ARGS 4.0 NAMES Microsoft.GSL
- )
-endif()
+FetchContent_Declare(
+ GSL
+ URL ${DEP_URL_microsoft_gsl}
+ URL_HASH SHA1=${DEP_SHA1_microsoft_gsl}
+ FIND_PACKAGE_ARGS 4.0 NAMES Microsoft.GSL
+)
FetchContent_Declare(
safeint