fix: fix tests

This commit is contained in:
psychedelicious 2023-10-18 18:27:29 +11:00
parent 0cda7943fa
commit 23fa2e560a
3 changed files with 18 additions and 15 deletions

View File

@ -76,6 +76,7 @@ def mock_services() -> InvocationServices:
session_queue=None, # type: ignore session_queue=None, # type: ignore
urls=None, # type: ignore urls=None, # type: ignore
workflow_records=None, # type: ignore workflow_records=None, # type: ignore
workflow_image_records=None, # type: ignore
) )

View File

@ -81,6 +81,7 @@ def mock_services() -> InvocationServices:
session_queue=None, # type: ignore session_queue=None, # type: ignore
urls=None, # type: ignore urls=None, # type: ignore
workflow_records=None, # type: ignore workflow_records=None, # type: ignore
workflow_image_records=None, # type: ignore
) )

View File

@ -1,11 +1,12 @@
from typing import Any, Callable, Union from typing import Any, Callable, Union
from pydantic import Field
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
InputField,
InvocationContext, InvocationContext,
OutputField,
invocation, invocation,
invocation_output, invocation_output,
) )
@ -15,12 +16,12 @@ from invokeai.app.invocations.image import ImageField
# Define test invocations before importing anything that uses invocations # Define test invocations before importing anything that uses invocations
@invocation_output("test_list_output") @invocation_output("test_list_output")
class ListPassThroughInvocationOutput(BaseInvocationOutput): class ListPassThroughInvocationOutput(BaseInvocationOutput):
collection: list[ImageField] = Field(default_factory=list) collection: list[ImageField] = OutputField(default_factory=list)
@invocation("test_list") @invocation("test_list")
class ListPassThroughInvocation(BaseInvocation): class ListPassThroughInvocation(BaseInvocation):
collection: list[ImageField] = Field(default_factory=list) collection: list[ImageField] = InputField(default_factory=list)
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
return ListPassThroughInvocationOutput(collection=self.collection) return ListPassThroughInvocationOutput(collection=self.collection)
@ -28,12 +29,12 @@ class ListPassThroughInvocation(BaseInvocation):
@invocation_output("test_prompt_output") @invocation_output("test_prompt_output")
class PromptTestInvocationOutput(BaseInvocationOutput): class PromptTestInvocationOutput(BaseInvocationOutput):
prompt: str = Field(default="") prompt: str = OutputField(default="")
@invocation("test_prompt") @invocation("test_prompt")
class PromptTestInvocation(BaseInvocation): class PromptTestInvocation(BaseInvocation):
prompt: str = Field(default="") prompt: str = InputField(default="")
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
return PromptTestInvocationOutput(prompt=self.prompt) return PromptTestInvocationOutput(prompt=self.prompt)
@ -47,13 +48,13 @@ class ErrorInvocation(BaseInvocation):
@invocation_output("test_image_output") @invocation_output("test_image_output")
class ImageTestInvocationOutput(BaseInvocationOutput): class ImageTestInvocationOutput(BaseInvocationOutput):
image: ImageField = Field() image: ImageField = OutputField()
@invocation("test_text_to_image") @invocation("test_text_to_image")
class TextToImageTestInvocation(BaseInvocation): class TextToImageTestInvocation(BaseInvocation):
prompt: str = Field(default="") prompt: str = InputField(default="")
prompt2: str = Field(default="") prompt2: str = InputField(default="")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@ -61,8 +62,8 @@ class TextToImageTestInvocation(BaseInvocation):
@invocation("test_image_to_image") @invocation("test_image_to_image")
class ImageToImageTestInvocation(BaseInvocation): class ImageToImageTestInvocation(BaseInvocation):
prompt: str = Field(default="") prompt: str = InputField(default="")
image: Union[ImageField, None] = Field(default=None) image: Union[ImageField, None] = InputField(default=None)
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@ -70,12 +71,12 @@ class ImageToImageTestInvocation(BaseInvocation):
@invocation_output("test_prompt_collection_output") @invocation_output("test_prompt_collection_output")
class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
collection: list[str] = Field(default_factory=list) collection: list[str] = OutputField(default_factory=list)
@invocation("test_prompt_collection") @invocation("test_prompt_collection")
class PromptCollectionTestInvocation(BaseInvocation): class PromptCollectionTestInvocation(BaseInvocation):
collection: list[str] = Field() collection: list[str] = InputField()
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
@ -83,12 +84,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
@invocation_output("test_any_output") @invocation_output("test_any_output")
class AnyTypeTestInvocationOutput(BaseInvocationOutput): class AnyTypeTestInvocationOutput(BaseInvocationOutput):
value: Any = Field() value: Any = OutputField()
@invocation("test_any") @invocation("test_any")
class AnyTypeTestInvocation(BaseInvocation): class AnyTypeTestInvocation(BaseInvocation):
value: Any = Field(default=None) value: Any = InputField(default=None)
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
return AnyTypeTestInvocationOutput(value=self.value) return AnyTypeTestInvocationOutput(value=self.value)
@ -96,7 +97,7 @@ class AnyTypeTestInvocation(BaseInvocation):
@invocation("test_polymorphic") @invocation("test_polymorphic")
class PolymorphicStringTestInvocation(BaseInvocation): class PolymorphicStringTestInvocation(BaseInvocation):
value: Union[str, list[str]] = Field(default="") value: Union[str, list[str]] = InputField(default="")
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
if isinstance(self.value, str): if isinstance(self.value, str):