ctranslate2: add CUDA/cuDNN support

Enabled by opting into `config.cudaSupport`.
This commit is contained in:
Martin Weinelt 2023-10-29 20:26:48 +01:00
parent 4f971ebf3c
commit 479138acc5
No known key found for this signature in database
GPG key ID: 87C1E9888F856759
2 changed files with 18 additions and 1 deletions

View file

@ -5,6 +5,9 @@
, darwin # Accelerate
, llvmPackages # openmp
, withMkl ? false, mkl
, withCUDA ? false
, withCuDNN ? false
, cudaPackages
# Enabling both withOneDNN and withOpenblas is broken
# https://github.com/OpenNMT/CTranslate2/issues/1294
, withOneDNN ? false, oneDNN
@ -33,6 +36,8 @@ stdenv.mkDerivation rec {
nativeBuildInputs = [
cmake
] ++ lib.optionals withCUDA [
cudaPackages.cuda_nvcc
];
cmakeFlags = [
@ -40,6 +45,8 @@ stdenv.mkDerivation rec {
# https://github.com/OpenNMT/CTranslate2/blob/54810350e662ebdb01ecbf8e4a746f02aeff1dd7/python/tools/prepare_build_environment_linux.sh#L53
# https://github.com/OpenNMT/CTranslate2/blob/59d223abcc7e636c1c2956e62482bc3299cc7766/python/tools/prepare_build_environment_macos.sh#L12
"-DOPENMP_RUNTIME=COMP"
"-DWITH_CUDA=${cmakeBool withCUDA}"
"-DWITH_CUDNN=${cmakeBool withCuDNN}"
"-DWITH_DNNL=${cmakeBool withOneDNN}"
"-DWITH_OPENBLAS=${cmakeBool withOpenblas}"
"-DWITH_RUY=${cmakeBool withRuy}"
@ -49,6 +56,12 @@ stdenv.mkDerivation rec {
buildInputs = lib.optionals withMkl [
mkl
] ++ lib.optionals withCUDA [
cudaPackages.cuda_cudart
cudaPackages.libcublas
cudaPackages.libcurand
] ++ lib.optionals withCuDNN [
cudaPackages.cudnn
] ++ lib.optionals withOneDNN [
oneDNN
] ++ lib.optionals withOpenblas [

View file

@ -20949,7 +20949,11 @@ with pkgs;
cpp-jwt = callPackage ../development/libraries/cpp-jwt { };
ctranslate2 = callPackage ../development/libraries/ctranslate2 { };
ctranslate2 = callPackage ../development/libraries/ctranslate2 {
stdenv = if pkgs.config.cudaSupport then gcc11Stdenv else stdenv;
withCUDA = pkgs.config.cudaSupport;
withCuDNN = pkgs.config.cudaSupport;
};
ubus = callPackage ../development/libraries/ubus { };