mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: fix tests
This commit is contained in:
parent
0cda7943fa
commit
23fa2e560a
@ -76,6 +76,7 @@ def mock_services() -> InvocationServices:
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
workflow_image_records=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
@ -81,6 +81,7 @@ def mock_services() -> InvocationServices:
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
workflow_image_records=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -15,12 +16,12 @@ from invokeai.app.invocations.image import ImageField
|
||||
# Define test invocations before importing anything that uses invocations
|
||||
@invocation_output("test_list_output")
|
||||
class ListPassThroughInvocationOutput(BaseInvocationOutput):
|
||||
collection: list[ImageField] = Field(default_factory=list)
|
||||
collection: list[ImageField] = OutputField(default_factory=list)
|
||||
|
||||
|
||||
@invocation("test_list")
|
||||
class ListPassThroughInvocation(BaseInvocation):
|
||||
collection: list[ImageField] = Field(default_factory=list)
|
||||
collection: list[ImageField] = InputField(default_factory=list)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
|
||||
return ListPassThroughInvocationOutput(collection=self.collection)
|
||||
@ -28,12 +29,12 @@ class ListPassThroughInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("test_prompt_output")
|
||||
class PromptTestInvocationOutput(BaseInvocationOutput):
|
||||
prompt: str = Field(default="")
|
||||
prompt: str = OutputField(default="")
|
||||
|
||||
|
||||
@invocation("test_prompt")
|
||||
class PromptTestInvocation(BaseInvocation):
|
||||
prompt: str = Field(default="")
|
||||
prompt: str = InputField(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||
return PromptTestInvocationOutput(prompt=self.prompt)
|
||||
@ -47,13 +48,13 @@ class ErrorInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("test_image_output")
|
||||
class ImageTestInvocationOutput(BaseInvocationOutput):
|
||||
image: ImageField = Field()
|
||||
image: ImageField = OutputField()
|
||||
|
||||
|
||||
@invocation("test_text_to_image")
|
||||
class TextToImageTestInvocation(BaseInvocation):
|
||||
prompt: str = Field(default="")
|
||||
prompt2: str = Field(default="")
|
||||
prompt: str = InputField(default="")
|
||||
prompt2: str = InputField(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
@ -61,8 +62,8 @@ class TextToImageTestInvocation(BaseInvocation):
|
||||
|
||||
@invocation("test_image_to_image")
|
||||
class ImageToImageTestInvocation(BaseInvocation):
|
||||
prompt: str = Field(default="")
|
||||
image: Union[ImageField, None] = Field(default=None)
|
||||
prompt: str = InputField(default="")
|
||||
image: Union[ImageField, None] = InputField(default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
@ -70,12 +71,12 @@ class ImageToImageTestInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("test_prompt_collection_output")
|
||||
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
||||
collection: list[str] = Field(default_factory=list)
|
||||
collection: list[str] = OutputField(default_factory=list)
|
||||
|
||||
|
||||
@invocation("test_prompt_collection")
|
||||
class PromptCollectionTestInvocation(BaseInvocation):
|
||||
collection: list[str] = Field()
|
||||
collection: list[str] = InputField()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||
@ -83,12 +84,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("test_any_output")
|
||||
class AnyTypeTestInvocationOutput(BaseInvocationOutput):
|
||||
value: Any = Field()
|
||||
value: Any = OutputField()
|
||||
|
||||
|
||||
@invocation("test_any")
|
||||
class AnyTypeTestInvocation(BaseInvocation):
|
||||
value: Any = Field(default=None)
|
||||
value: Any = InputField(default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
|
||||
return AnyTypeTestInvocationOutput(value=self.value)
|
||||
@ -96,7 +97,7 @@ class AnyTypeTestInvocation(BaseInvocation):
|
||||
|
||||
@invocation("test_polymorphic")
|
||||
class PolymorphicStringTestInvocation(BaseInvocation):
|
||||
value: Union[str, list[str]] = Field(default="")
|
||||
value: Union[str, list[str]] = InputField(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||
if isinstance(self.value, str):
|
||||
|
Loading…
Reference in New Issue
Block a user