ctranslate2: add CUDA/cuDNN support
Enabled by opting into `config.cudaSupport`.
This commit is contained in:
parent
4f971ebf3c
commit
479138acc5
2 changed files with 18 additions and 1 deletions
|
@ -5,6 +5,9 @@
|
||||||
, darwin # Accelerate
|
, darwin # Accelerate
|
||||||
, llvmPackages # openmp
|
, llvmPackages # openmp
|
||||||
, withMkl ? false, mkl
|
, withMkl ? false, mkl
|
||||||
|
, withCUDA ? false
|
||||||
|
, withCuDNN ? false
|
||||||
|
, cudaPackages
|
||||||
# Enabling both withOneDNN and withOpenblas is broken
|
# Enabling both withOneDNN and withOpenblas is broken
|
||||||
# https://github.com/OpenNMT/CTranslate2/issues/1294
|
# https://github.com/OpenNMT/CTranslate2/issues/1294
|
||||||
, withOneDNN ? false, oneDNN
|
, withOneDNN ? false, oneDNN
|
||||||
|
@ -33,6 +36,8 @@ stdenv.mkDerivation rec {
|
||||||
|
|
||||||
nativeBuildInputs = [
|
nativeBuildInputs = [
|
||||||
cmake
|
cmake
|
||||||
|
] ++ lib.optionals withCUDA [
|
||||||
|
cudaPackages.cuda_nvcc
|
||||||
];
|
];
|
||||||
|
|
||||||
cmakeFlags = [
|
cmakeFlags = [
|
||||||
|
@ -40,6 +45,8 @@ stdenv.mkDerivation rec {
|
||||||
# https://github.com/OpenNMT/CTranslate2/blob/54810350e662ebdb01ecbf8e4a746f02aeff1dd7/python/tools/prepare_build_environment_linux.sh#L53
|
# https://github.com/OpenNMT/CTranslate2/blob/54810350e662ebdb01ecbf8e4a746f02aeff1dd7/python/tools/prepare_build_environment_linux.sh#L53
|
||||||
# https://github.com/OpenNMT/CTranslate2/blob/59d223abcc7e636c1c2956e62482bc3299cc7766/python/tools/prepare_build_environment_macos.sh#L12
|
# https://github.com/OpenNMT/CTranslate2/blob/59d223abcc7e636c1c2956e62482bc3299cc7766/python/tools/prepare_build_environment_macos.sh#L12
|
||||||
"-DOPENMP_RUNTIME=COMP"
|
"-DOPENMP_RUNTIME=COMP"
|
||||||
|
"-DWITH_CUDA=${cmakeBool withCUDA}"
|
||||||
|
"-DWITH_CUDNN=${cmakeBool withCuDNN}"
|
||||||
"-DWITH_DNNL=${cmakeBool withOneDNN}"
|
"-DWITH_DNNL=${cmakeBool withOneDNN}"
|
||||||
"-DWITH_OPENBLAS=${cmakeBool withOpenblas}"
|
"-DWITH_OPENBLAS=${cmakeBool withOpenblas}"
|
||||||
"-DWITH_RUY=${cmakeBool withRuy}"
|
"-DWITH_RUY=${cmakeBool withRuy}"
|
||||||
|
@ -49,6 +56,12 @@ stdenv.mkDerivation rec {
|
||||||
|
|
||||||
buildInputs = lib.optionals withMkl [
|
buildInputs = lib.optionals withMkl [
|
||||||
mkl
|
mkl
|
||||||
|
] ++ lib.optionals withCUDA [
|
||||||
|
cudaPackages.cuda_cudart
|
||||||
|
cudaPackages.libcublas
|
||||||
|
cudaPackages.libcurand
|
||||||
|
] ++ lib.optionals withCuDNN [
|
||||||
|
cudaPackages.cudnn
|
||||||
] ++ lib.optionals withOneDNN [
|
] ++ lib.optionals withOneDNN [
|
||||||
oneDNN
|
oneDNN
|
||||||
] ++ lib.optionals withOpenblas [
|
] ++ lib.optionals withOpenblas [
|
||||||
|
|
|
@ -20949,7 +20949,11 @@ with pkgs;
|
||||||
|
|
||||||
cpp-jwt = callPackage ../development/libraries/cpp-jwt { };
|
cpp-jwt = callPackage ../development/libraries/cpp-jwt { };
|
||||||
|
|
||||||
ctranslate2 = callPackage ../development/libraries/ctranslate2 { };
|
ctranslate2 = callPackage ../development/libraries/ctranslate2 {
|
||||||
|
stdenv = if pkgs.config.cudaSupport then gcc11Stdenv else stdenv;
|
||||||
|
withCUDA = pkgs.config.cudaSupport;
|
||||||
|
withCuDNN = pkgs.config.cudaSupport;
|
||||||
|
};
|
||||||
|
|
||||||
ubus = callPackage ../development/libraries/ubus { };
|
ubus = callPackage ../development/libraries/ubus { };
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue