diff --git a/pkgs/development/python-modules/py-deprecate/default.nix b/pkgs/development/python-modules/py-deprecate/default.nix index 12ea2b7ce9e..ff921789bbf 100644 --- a/pkgs/development/python-modules/py-deprecate/default.nix +++ b/pkgs/development/python-modules/py-deprecate/default.nix @@ -28,7 +28,7 @@ buildPythonPackage { pythonImportsCheck = [ "deprecate" ]; meta = with lib; { - description = "A python module for marking deprecated functions or classes and re-routing to the new successors' instance. Used by torchmetrics"; + description = "A module for marking deprecated functions or classes and re-routing to the new successors' instance. Used by torchmetrics"; homepage = "https://borda.github.io/pyDeprecate/"; license = licenses.asl20; maintainers = with maintainers; [ diff --git a/pkgs/development/python-modules/torchmetrics/default.nix b/pkgs/development/python-modules/torchmetrics/default.nix new file mode 100644 index 00000000000..fc2f6cf7535 --- /dev/null +++ b/pkgs/development/python-modules/torchmetrics/default.nix @@ -0,0 +1,76 @@ +{ lib +, buildPythonPackage +, fetchFromGitHub +, cloudpickle +, scikit-learn +, scikitimage +, packaging +, psutil +, py-deprecate +, pytorch +, pytestCheckHook +, torchmetrics +}: + +let + pname = "torchmetrics"; + version = "0.8.1"; +in +buildPythonPackage { + inherit pname version; + + src = fetchFromGitHub { + owner = "PyTorchLightning"; + repo = "metrics"; + rev = "v${version}"; + hash = "sha256-AryEhYAeC97dO2pgHoz0Y9F//DVdX6RfCa80gI56iz4="; + }; + + propagatedBuildInputs = [ + packaging + py-deprecate + ]; + + # Let the user bring their own instance + buildInputs = [ + pytorch + ]; + + checkInputs = [ + scikit-learn + scikitimage + cloudpickle + psutil + pytestCheckHook + ]; + + # A cyclic dependency in: integrations/test_lightning.py + doCheck = false; + passthru.tests.check = torchmetrics.overridePythonAttrs (_: { + doCheck = true; + }); + + disabledTestPaths = [ + # These require too many "leftpad-level" dependencies + "tests/text" + "tests/audio" + "tests/image" + + # A few non-deterministic things like test_check_compute_groups_is_faster + "tests/bases/test_collections.py" + ]; + + pythonImportsCheck = [ + "torchmetrics" + ]; + + meta = with lib; { + description = "Machine learning metrics for distributed, scalable PyTorch applications (used in pytorch-lightning)"; + homepage = "https://torchmetrics.readthedocs.io"; + license = licenses.asl20; + maintainers = with maintainers; [ + SomeoneSerge + ]; + }; +} + diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index a5607542ac6..e96811a0207 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -10362,6 +10362,8 @@ in { torchgpipe = callPackage ../development/python-modules/torchgpipe { }; + torchmetrics = callPackage ../development/python-modules/torchmetrics { }; + torchinfo = callPackage ../development/python-modules/torchinfo { }; torchvision = callPackage ../development/python-modules/torchvision { };