fix: fix tests

This commit is contained in:
psychedelicious 2023-10-18 18:27:29 +11:00
parent 0cda7943fa
commit 23fa2e560a
3 changed files with 18 additions and 15 deletions

View File

@ -76,6 +76,7 @@ def mock_services() -> InvocationServices:
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
workflow_image_records=None, # type: ignore
)

View File

@ -81,6 +81,7 @@ def mock_services() -> InvocationServices:
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
workflow_image_records=None, # type: ignore
)

View File

@ -1,11 +1,12 @@
from typing import Any, Callable, Union
from pydantic import Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
invocation,
invocation_output,
)
@ -15,12 +16,12 @@ from invokeai.app.invocations.image import ImageField
# Define test invocations before importing anything that uses invocations
@invocation_output("test_list_output")
class ListPassThroughInvocationOutput(BaseInvocationOutput):
collection: list[ImageField] = Field(default_factory=list)
collection: list[ImageField] = OutputField(default_factory=list)
@invocation("test_list")
class ListPassThroughInvocation(BaseInvocation):
collection: list[ImageField] = Field(default_factory=list)
collection: list[ImageField] = InputField(default_factory=list)
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
return ListPassThroughInvocationOutput(collection=self.collection)
@ -28,12 +29,12 @@ class ListPassThroughInvocation(BaseInvocation):
@invocation_output("test_prompt_output")
class PromptTestInvocationOutput(BaseInvocationOutput):
prompt: str = Field(default="")
prompt: str = OutputField(default="")
@invocation("test_prompt")
class PromptTestInvocation(BaseInvocation):
prompt: str = Field(default="")
prompt: str = InputField(default="")
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
return PromptTestInvocationOutput(prompt=self.prompt)
@ -47,13 +48,13 @@ class ErrorInvocation(BaseInvocation):
@invocation_output("test_image_output")
class ImageTestInvocationOutput(BaseInvocationOutput):
image: ImageField = Field()
image: ImageField = OutputField()
@invocation("test_text_to_image")
class TextToImageTestInvocation(BaseInvocation):
prompt: str = Field(default="")
prompt2: str = Field(default="")
prompt: str = InputField(default="")
prompt2: str = InputField(default="")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@ -61,8 +62,8 @@ class TextToImageTestInvocation(BaseInvocation):
@invocation("test_image_to_image")
class ImageToImageTestInvocation(BaseInvocation):
prompt: str = Field(default="")
image: Union[ImageField, None] = Field(default=None)
prompt: str = InputField(default="")
image: Union[ImageField, None] = InputField(default=None)
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@ -70,12 +71,12 @@ class ImageToImageTestInvocation(BaseInvocation):
@invocation_output("test_prompt_collection_output")
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
collection: list[str] = Field(default_factory=list)
collection: list[str] = OutputField(default_factory=list)
@invocation("test_prompt_collection")
class PromptCollectionTestInvocation(BaseInvocation):
collection: list[str] = Field()
collection: list[str] = InputField()
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
@ -83,12 +84,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
@invocation_output("test_any_output")
class AnyTypeTestInvocationOutput(BaseInvocationOutput):
value: Any = Field()
value: Any = OutputField()
@invocation("test_any")
class AnyTypeTestInvocation(BaseInvocation):
value: Any = Field(default=None)
value: Any = InputField(default=None)
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
return AnyTypeTestInvocationOutput(value=self.value)
@ -96,7 +97,7 @@ class AnyTypeTestInvocation(BaseInvocation):
@invocation("test_polymorphic")
class PolymorphicStringTestInvocation(BaseInvocation):
value: Union[str, list[str]] = Field(default="")
value: Union[str, list[str]] = InputField(default="")
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
if isinstance(self.value, str):