ctranslate2: add CUDA/cuDNN support

Enabled by opting into `config.cudaSupport`.
This commit is contained in:
Martin Weinelt 2023-10-29 20:26:48 +01:00
parent 4f971ebf3c
commit 479138acc5
No known key found for this signature in database
GPG key ID: 87C1E9888F856759
2 changed files with 18 additions and 1 deletions

View file

@ -5,6 +5,9 @@
, darwin # Accelerate , darwin # Accelerate
, llvmPackages # openmp , llvmPackages # openmp
, withMkl ? false, mkl , withMkl ? false, mkl
, withCUDA ? false
, withCuDNN ? false
, cudaPackages
# Enabling both withOneDNN and withOpenblas is broken # Enabling both withOneDNN and withOpenblas is broken
# https://github.com/OpenNMT/CTranslate2/issues/1294 # https://github.com/OpenNMT/CTranslate2/issues/1294
, withOneDNN ? false, oneDNN , withOneDNN ? false, oneDNN
@ -33,6 +36,8 @@ stdenv.mkDerivation rec {
nativeBuildInputs = [ nativeBuildInputs = [
cmake cmake
] ++ lib.optionals withCUDA [
cudaPackages.cuda_nvcc
]; ];
cmakeFlags = [ cmakeFlags = [
@ -40,6 +45,8 @@ stdenv.mkDerivation rec {
# https://github.com/OpenNMT/CTranslate2/blob/54810350e662ebdb01ecbf8e4a746f02aeff1dd7/python/tools/prepare_build_environment_linux.sh#L53 # https://github.com/OpenNMT/CTranslate2/blob/54810350e662ebdb01ecbf8e4a746f02aeff1dd7/python/tools/prepare_build_environment_linux.sh#L53
# https://github.com/OpenNMT/CTranslate2/blob/59d223abcc7e636c1c2956e62482bc3299cc7766/python/tools/prepare_build_environment_macos.sh#L12 # https://github.com/OpenNMT/CTranslate2/blob/59d223abcc7e636c1c2956e62482bc3299cc7766/python/tools/prepare_build_environment_macos.sh#L12
"-DOPENMP_RUNTIME=COMP" "-DOPENMP_RUNTIME=COMP"
"-DWITH_CUDA=${cmakeBool withCUDA}"
"-DWITH_CUDNN=${cmakeBool withCuDNN}"
"-DWITH_DNNL=${cmakeBool withOneDNN}" "-DWITH_DNNL=${cmakeBool withOneDNN}"
"-DWITH_OPENBLAS=${cmakeBool withOpenblas}" "-DWITH_OPENBLAS=${cmakeBool withOpenblas}"
"-DWITH_RUY=${cmakeBool withRuy}" "-DWITH_RUY=${cmakeBool withRuy}"
@ -49,6 +56,12 @@ stdenv.mkDerivation rec {
buildInputs = lib.optionals withMkl [ buildInputs = lib.optionals withMkl [
mkl mkl
] ++ lib.optionals withCUDA [
cudaPackages.cuda_cudart
cudaPackages.libcublas
cudaPackages.libcurand
] ++ lib.optionals withCuDNN [
cudaPackages.cudnn
] ++ lib.optionals withOneDNN [ ] ++ lib.optionals withOneDNN [
oneDNN oneDNN
] ++ lib.optionals withOpenblas [ ] ++ lib.optionals withOpenblas [

View file

@ -20949,7 +20949,11 @@ with pkgs;
cpp-jwt = callPackage ../development/libraries/cpp-jwt { }; cpp-jwt = callPackage ../development/libraries/cpp-jwt { };
ctranslate2 = callPackage ../development/libraries/ctranslate2 { }; ctranslate2 = callPackage ../development/libraries/ctranslate2 {
stdenv = if pkgs.config.cudaSupport then gcc11Stdenv else stdenv;
withCUDA = pkgs.config.cudaSupport;
withCuDNN = pkgs.config.cudaSupport;
};
ubus = callPackage ../development/libraries/ubus { }; ubus = callPackage ../development/libraries/ubus { };