mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: fix tests
This commit is contained in:
parent
0cda7943fa
commit
23fa2e560a
@ -76,6 +76,7 @@ def mock_services() -> InvocationServices:
|
|||||||
session_queue=None, # type: ignore
|
session_queue=None, # type: ignore
|
||||||
urls=None, # type: ignore
|
urls=None, # type: ignore
|
||||||
workflow_records=None, # type: ignore
|
workflow_records=None, # type: ignore
|
||||||
|
workflow_image_records=None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,6 +81,7 @@ def mock_services() -> InvocationServices:
|
|||||||
session_queue=None, # type: ignore
|
session_queue=None, # type: ignore
|
||||||
urls=None, # type: ignore
|
urls=None, # type: ignore
|
||||||
workflow_records=None, # type: ignore
|
workflow_records=None, # type: ignore
|
||||||
|
workflow_image_records=None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
from typing import Any, Callable, Union
|
from typing import Any, Callable, Union
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -15,12 +16,12 @@ from invokeai.app.invocations.image import ImageField
|
|||||||
# Define test invocations before importing anything that uses invocations
|
# Define test invocations before importing anything that uses invocations
|
||||||
@invocation_output("test_list_output")
|
@invocation_output("test_list_output")
|
||||||
class ListPassThroughInvocationOutput(BaseInvocationOutput):
|
class ListPassThroughInvocationOutput(BaseInvocationOutput):
|
||||||
collection: list[ImageField] = Field(default_factory=list)
|
collection: list[ImageField] = OutputField(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@invocation("test_list")
|
@invocation("test_list")
|
||||||
class ListPassThroughInvocation(BaseInvocation):
|
class ListPassThroughInvocation(BaseInvocation):
|
||||||
collection: list[ImageField] = Field(default_factory=list)
|
collection: list[ImageField] = InputField(default_factory=list)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
|
||||||
return ListPassThroughInvocationOutput(collection=self.collection)
|
return ListPassThroughInvocationOutput(collection=self.collection)
|
||||||
@ -28,12 +29,12 @@ class ListPassThroughInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@invocation_output("test_prompt_output")
|
@invocation_output("test_prompt_output")
|
||||||
class PromptTestInvocationOutput(BaseInvocationOutput):
|
class PromptTestInvocationOutput(BaseInvocationOutput):
|
||||||
prompt: str = Field(default="")
|
prompt: str = OutputField(default="")
|
||||||
|
|
||||||
|
|
||||||
@invocation("test_prompt")
|
@invocation("test_prompt")
|
||||||
class PromptTestInvocation(BaseInvocation):
|
class PromptTestInvocation(BaseInvocation):
|
||||||
prompt: str = Field(default="")
|
prompt: str = InputField(default="")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||||
return PromptTestInvocationOutput(prompt=self.prompt)
|
return PromptTestInvocationOutput(prompt=self.prompt)
|
||||||
@ -47,13 +48,13 @@ class ErrorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@invocation_output("test_image_output")
|
@invocation_output("test_image_output")
|
||||||
class ImageTestInvocationOutput(BaseInvocationOutput):
|
class ImageTestInvocationOutput(BaseInvocationOutput):
|
||||||
image: ImageField = Field()
|
image: ImageField = OutputField()
|
||||||
|
|
||||||
|
|
||||||
@invocation("test_text_to_image")
|
@invocation("test_text_to_image")
|
||||||
class TextToImageTestInvocation(BaseInvocation):
|
class TextToImageTestInvocation(BaseInvocation):
|
||||||
prompt: str = Field(default="")
|
prompt: str = InputField(default="")
|
||||||
prompt2: str = Field(default="")
|
prompt2: str = InputField(default="")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
@ -61,8 +62,8 @@ class TextToImageTestInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@invocation("test_image_to_image")
|
@invocation("test_image_to_image")
|
||||||
class ImageToImageTestInvocation(BaseInvocation):
|
class ImageToImageTestInvocation(BaseInvocation):
|
||||||
prompt: str = Field(default="")
|
prompt: str = InputField(default="")
|
||||||
image: Union[ImageField, None] = Field(default=None)
|
image: Union[ImageField, None] = InputField(default=None)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
@ -70,12 +71,12 @@ class ImageToImageTestInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@invocation_output("test_prompt_collection_output")
|
@invocation_output("test_prompt_collection_output")
|
||||||
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
||||||
collection: list[str] = Field(default_factory=list)
|
collection: list[str] = OutputField(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@invocation("test_prompt_collection")
|
@invocation("test_prompt_collection")
|
||||||
class PromptCollectionTestInvocation(BaseInvocation):
|
class PromptCollectionTestInvocation(BaseInvocation):
|
||||||
collection: list[str] = Field()
|
collection: list[str] = InputField()
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||||
@ -83,12 +84,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@invocation_output("test_any_output")
|
@invocation_output("test_any_output")
|
||||||
class AnyTypeTestInvocationOutput(BaseInvocationOutput):
|
class AnyTypeTestInvocationOutput(BaseInvocationOutput):
|
||||||
value: Any = Field()
|
value: Any = OutputField()
|
||||||
|
|
||||||
|
|
||||||
@invocation("test_any")
|
@invocation("test_any")
|
||||||
class AnyTypeTestInvocation(BaseInvocation):
|
class AnyTypeTestInvocation(BaseInvocation):
|
||||||
value: Any = Field(default=None)
|
value: Any = InputField(default=None)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
|
||||||
return AnyTypeTestInvocationOutput(value=self.value)
|
return AnyTypeTestInvocationOutput(value=self.value)
|
||||||
@ -96,7 +97,7 @@ class AnyTypeTestInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@invocation("test_polymorphic")
|
@invocation("test_polymorphic")
|
||||||
class PolymorphicStringTestInvocation(BaseInvocation):
|
class PolymorphicStringTestInvocation(BaseInvocation):
|
||||||
value: Union[str, list[str]] = Field(default="")
|
value: Union[str, list[str]] = InputField(default="")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||||
if isinstance(self.value, str):
|
if isinstance(self.value, str):
|
||||||
|
Loading…
Reference in New Issue
Block a user