diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index e2d435e621..171cdfdb6f 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -76,6 +76,7 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + workflow_image_records=None, # type: ignore ) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 9774f07fdd..25b02955b0 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -81,6 +81,7 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + workflow_image_records=None, # type: ignore ) diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py index 7807a56879..1d7f2e4194 100644 --- a/tests/nodes/test_nodes.py +++ b/tests/nodes/test_nodes.py @@ -1,11 +1,12 @@ from typing import Any, Callable, Union -from pydantic import Field from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + InputField, InvocationContext, + OutputField, invocation, invocation_output, ) @@ -15,12 +16,12 @@ from invokeai.app.invocations.image import ImageField # Define test invocations before importing anything that uses invocations @invocation_output("test_list_output") class ListPassThroughInvocationOutput(BaseInvocationOutput): - collection: list[ImageField] = Field(default_factory=list) + collection: list[ImageField] = OutputField(default_factory=list) @invocation("test_list") class ListPassThroughInvocation(BaseInvocation): - collection: list[ImageField] = Field(default_factory=list) + collection: list[ImageField] = InputField(default_factory=list) def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: return ListPassThroughInvocationOutput(collection=self.collection) @@ -28,12 +29,12 @@ class ListPassThroughInvocation(BaseInvocation): @invocation_output("test_prompt_output") class PromptTestInvocationOutput(BaseInvocationOutput): - prompt: str = Field(default="") + prompt: str = OutputField(default="") @invocation("test_prompt") class PromptTestInvocation(BaseInvocation): - prompt: str = Field(default="") + prompt: str = InputField(default="") def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: return PromptTestInvocationOutput(prompt=self.prompt) @@ -47,13 +48,13 @@ class ErrorInvocation(BaseInvocation): @invocation_output("test_image_output") class ImageTestInvocationOutput(BaseInvocationOutput): - image: ImageField = Field() + image: ImageField = OutputField() @invocation("test_text_to_image") class TextToImageTestInvocation(BaseInvocation): - prompt: str = Field(default="") - prompt2: str = Field(default="") + prompt: str = InputField(default="") + prompt2: str = InputField(default="") def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -61,8 +62,8 @@ class TextToImageTestInvocation(BaseInvocation): @invocation("test_image_to_image") class ImageToImageTestInvocation(BaseInvocation): - prompt: str = Field(default="") - image: Union[ImageField, None] = Field(default=None) + prompt: str = InputField(default="") + image: Union[ImageField, None] = InputField(default=None) def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -70,12 +71,12 @@ class ImageToImageTestInvocation(BaseInvocation): @invocation_output("test_prompt_collection_output") class PromptCollectionTestInvocationOutput(BaseInvocationOutput): - collection: list[str] = Field(default_factory=list) + collection: list[str] = OutputField(default_factory=list) @invocation("test_prompt_collection") class PromptCollectionTestInvocation(BaseInvocation): - collection: list[str] = Field() + collection: list[str] = InputField() def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) @@ -83,12 +84,12 @@ class PromptCollectionTestInvocation(BaseInvocation): @invocation_output("test_any_output") class AnyTypeTestInvocationOutput(BaseInvocationOutput): - value: Any = Field() + value: Any = OutputField() @invocation("test_any") class AnyTypeTestInvocation(BaseInvocation): - value: Any = Field(default=None) + value: Any = InputField(default=None) def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: return AnyTypeTestInvocationOutput(value=self.value) @@ -96,7 +97,7 @@ class AnyTypeTestInvocation(BaseInvocation): @invocation("test_polymorphic") class PolymorphicStringTestInvocation(BaseInvocation): - value: Union[str, list[str]] = Field(default="") + value: Union[str, list[str]] = InputField(default="") def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: if isinstance(self.value, str):