From 3dbb0e1bfb6c88388f962c7b58587ee04d59c74c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Sep 2023 19:16:44 +1000 Subject: [PATCH] feat(tests): add tests for node versions --- invokeai/app/invocations/baseinvocation.py | 9 +++++- tests/nodes/test_node_graph.py | 36 +++++++++++++++++++--- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 540571762f..65a8734690 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -32,6 +32,10 @@ if TYPE_CHECKING: from ..services.invocation_services import InvocationServices +class InvalidVersionError(ValueError): + pass + + class FieldDescriptions: denoising_start = "When to start denoising, expressed a percentage of total steps" denoising_end = "When to stop denoising, expressed a percentage of total steps" @@ -605,7 +609,10 @@ def invocation( if category is not None: cls.UIConfig.category = category if version is not None: - semver.Version.parse(version) # raises ValueError if invalid semver + try: + semver.Version.parse(version) + except ValueError as e: + raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e cls.UIConfig.version = version # Add the invocation type to the pydantic model of the invocation diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index 56bf823d14..0e1be8f343 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -1,4 +1,10 @@ -from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvalidVersionError, + invocation, + invocation_output, +) from .test_nodes import ( ImageToImageTestInvocation, TextToImageTestInvocation, @@ -616,18 +622,38 @@ def test_invocation_decorator(): title = "Test Invocation" tags = ["first", "second", "third"] category = "category" + version = "1.2.3" - @invocation(invocation_type, title=title, tags=tags, category=category) - class Test(BaseInvocation): + @invocation(invocation_type, title=title, tags=tags, category=category, version=version) + class TestInvocation(BaseInvocation): def invoke(self): pass - schema = Test.schema() + schema = TestInvocation.schema() assert schema.get("title") == title assert schema.get("tags") == tags assert schema.get("category") == category - assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added) + assert schema.get("version") == version + assert TestInvocation(id="1").type == invocation_type # type: ignore (type is dynamically added) + + +def test_invocation_version_must_be_semver(): + invocation_type = "test_invocation" + valid_version = "1.0.0" + invalid_version = "not_semver" + + @invocation(invocation_type, version=valid_version) + class ValidVersionInvocation(BaseInvocation): + def invoke(self): + pass + + with pytest.raises(InvalidVersionError): + + @invocation(invocation_type, version=invalid_version) + class InvalidVersionInvocation(BaseInvocation): + def invoke(self): + pass def test_invocation_output_decorator():