feat(tests): add tests for node versions

This commit is contained in:
psychedelicious 2023-09-04 19:16:44 +10:00
parent d6317bc53f
commit 3dbb0e1bfb
2 changed files with 39 additions and 6 deletions

View File

@ -32,6 +32,10 @@ if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
class InvalidVersionError(ValueError):
pass
class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps"
@ -605,7 +609,10 @@ def invocation(
if category is not None:
cls.UIConfig.category = category
if version is not None:
semver.Version.parse(version) # raises ValueError if invalid semver
try:
semver.Version.parse(version)
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version
# Add the invocation type to the pydantic model of the invocation

View File

@ -1,4 +1,10 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvalidVersionError,
invocation,
invocation_output,
)
from .test_nodes import (
ImageToImageTestInvocation,
TextToImageTestInvocation,
@ -616,18 +622,38 @@ def test_invocation_decorator():
title = "Test Invocation"
tags = ["first", "second", "third"]
category = "category"
version = "1.2.3"
@invocation(invocation_type, title=title, tags=tags, category=category)
class Test(BaseInvocation):
@invocation(invocation_type, title=title, tags=tags, category=category, version=version)
class TestInvocation(BaseInvocation):
def invoke(self):
pass
schema = Test.schema()
schema = TestInvocation.schema()
assert schema.get("title") == title
assert schema.get("tags") == tags
assert schema.get("category") == category
assert Test(id="1").type == invocation_type # type: ignore (type is dynamically added)
assert schema.get("version") == version
assert TestInvocation(id="1").type == invocation_type # type: ignore (type is dynamically added)
def test_invocation_version_must_be_semver():
invocation_type = "test_invocation"
valid_version = "1.0.0"
invalid_version = "not_semver"
@invocation(invocation_type, version=valid_version)
class ValidVersionInvocation(BaseInvocation):
def invoke(self):
pass
with pytest.raises(InvalidVersionError):
@invocation(invocation_type, version=invalid_version)
class InvalidVersionInvocation(BaseInvocation):
def invoke(self):
pass
def test_invocation_output_decorator():