From 479138acc513e9d770d3e10ab45d778392017358 Mon Sep 17 00:00:00 2001 From: Martin Weinelt Date: Sun, 29 Oct 2023 20:26:48 +0100 Subject: [PATCH] ctranslate2: add CUDA/cuDNN support Enabled by opting into `config.cudaSupport`. --- pkgs/development/libraries/ctranslate2/default.nix | 13 +++++++++++++ pkgs/top-level/all-packages.nix | 6 +++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pkgs/development/libraries/ctranslate2/default.nix b/pkgs/development/libraries/ctranslate2/default.nix index f9408818e37f..04bbcf6e6b10 100644 --- a/pkgs/development/libraries/ctranslate2/default.nix +++ b/pkgs/development/libraries/ctranslate2/default.nix @@ -5,6 +5,9 @@ , darwin # Accelerate , llvmPackages # openmp , withMkl ? false, mkl +, withCUDA ? false +, withCuDNN ? false +, cudaPackages # Enabling both withOneDNN and withOpenblas is broken # https://github.com/OpenNMT/CTranslate2/issues/1294 , withOneDNN ? false, oneDNN @@ -33,6 +36,8 @@ stdenv.mkDerivation rec { nativeBuildInputs = [ cmake + ] ++ lib.optionals withCUDA [ + cudaPackages.cuda_nvcc ]; cmakeFlags = [ @@ -40,6 +45,8 @@ stdenv.mkDerivation rec { # https://github.com/OpenNMT/CTranslate2/blob/54810350e662ebdb01ecbf8e4a746f02aeff1dd7/python/tools/prepare_build_environment_linux.sh#L53 # https://github.com/OpenNMT/CTranslate2/blob/59d223abcc7e636c1c2956e62482bc3299cc7766/python/tools/prepare_build_environment_macos.sh#L12 "-DOPENMP_RUNTIME=COMP" + "-DWITH_CUDA=${cmakeBool withCUDA}" + "-DWITH_CUDNN=${cmakeBool withCuDNN}" "-DWITH_DNNL=${cmakeBool withOneDNN}" "-DWITH_OPENBLAS=${cmakeBool withOpenblas}" "-DWITH_RUY=${cmakeBool withRuy}" @@ -49,6 +56,12 @@ stdenv.mkDerivation rec { buildInputs = lib.optionals withMkl [ mkl + ] ++ lib.optionals withCUDA [ + cudaPackages.cuda_cudart + cudaPackages.libcublas + cudaPackages.libcurand + ] ++ lib.optionals withCuDNN [ + cudaPackages.cudnn ] ++ lib.optionals withOneDNN [ oneDNN ] ++ lib.optionals withOpenblas [ diff --git a/pkgs/top-level/all-packages.nix b/pkgs/top-level/all-packages.nix index b0150c6bca6d..7fc9d2f0386d 100644 --- a/pkgs/top-level/all-packages.nix +++ b/pkgs/top-level/all-packages.nix @@ -20949,7 +20949,11 @@ with pkgs; cpp-jwt = callPackage ../development/libraries/cpp-jwt { }; - ctranslate2 = callPackage ../development/libraries/ctranslate2 { }; + ctranslate2 = callPackage ../development/libraries/ctranslate2 { + stdenv = if pkgs.config.cudaSupport then gcc11Stdenv else stdenv; + withCUDA = pkgs.config.cudaSupport; + withCuDNN = pkgs.config.cudaSupport; + }; ubus = callPackage ../development/libraries/ubus { };