diff --git a/pkgs/development/python-modules/orbax-checkpoint/default.nix b/pkgs/development/python-modules/orbax-checkpoint/default.nix new file mode 100644 index 000000000000..0f9d467335ce --- /dev/null +++ b/pkgs/development/python-modules/orbax-checkpoint/default.nix @@ -0,0 +1,78 @@ +{ lib +, absl-py +, buildPythonPackage +, cached-property +, etils +, fetchPypi +, flit-core +, importlib-resources +, jax +, jaxlib +, msgpack +, nest-asyncio +, numpy +, protobuf +, pytest-xdist +, pytestCheckHook +, pythonOlder +, pyyaml +, tensorstore +, typing-extensions +}: + +buildPythonPackage rec { + pname = "orbax-checkpoint"; + version = "0.5.3"; + pyproject = true; + + disabled = pythonOlder "3.9"; + + src = fetchPypi { + pname = "orbax_checkpoint"; + inherit version; + hash = "sha256-FXKQTLv+hROSfg2A+AtzDg7y9oAzLTwoENhENTKTi0U="; + }; + + nativeBuildInputs = [ + flit-core + ]; + + propagatedBuildInputs = [ + absl-py + cached-property + etils + importlib-resources + jax + jaxlib + msgpack + nest-asyncio + numpy + protobuf + pyyaml + tensorstore + typing-extensions + ]; + + nativeCheckInputs = [ + pytest-xdist + pytestCheckHook + ]; + + pythonImportsCheck = [ + "orbax" + ]; + + disabledTestPaths = [ + # Circular dependency flax + "orbax/checkpoint/transform_utils_test.py" + "orbax/checkpoint/utils_test.py" + ]; + + meta = with lib; { + description = "Orbax provides common utility libraries for JAX users"; + homepage = "https://github.com/google/orbax/tree/main/checkpoint"; + changelog = "https://github.com/google/orbax/blob/${version}/CHANGELOG.md"; + license = licenses.asl20; + maintainers = with maintainers; [fab ]; + }; +} diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index 214fd82a327e..3bd7e9174ad6 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -8952,6 +8952,8 @@ self: super: with self; { oras = callPackage ../development/python-modules/oras { }; + orbax-checkpoint = callPackage ../development/python-modules/orbax-checkpoint { }; + orderedmultidict = callPackage ../development/python-modules/orderedmultidict { }; ordered-set = callPackage ../development/python-modules/ordered-set { };