tests: fix tests for new invocation context

This commit is contained in:
psychedelicious 2024-01-13 23:23:38 +11:00 committed by Brandon Rising
parent 248176604f
commit 8dc1207790
2 changed files with 9 additions and 16 deletions

View File

@ -21,7 +21,6 @@ from invokeai.app.services.invocation_processor.invocation_processor_default imp
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import (
CollectInvocation,
Graph,
@ -86,12 +85,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
print(f"invoking {n.id}: {type(n)}")
o = n.invoke(
InvocationContext(
queue_batch_id="1",
queue_item_id=1,
queue_id=DEFAULT_QUEUE_ID,
services=services,
graph_execution_state_id="1",
workflow=None,
conditioning=None, config=None, data=None, images=None, latents=None, logger=None, models=None, util=None
)
)
g.complete(n.id, o)

View File

@ -3,7 +3,6 @@ from typing import Any, Callable, Union
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
invocation,
invocation_output,
)
@ -21,7 +20,7 @@ class ListPassThroughInvocationOutput(BaseInvocationOutput):
class ListPassThroughInvocation(BaseInvocation):
collection: list[ImageField] = InputField(default=[])
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
def invoke(self, context) -> ListPassThroughInvocationOutput:
return ListPassThroughInvocationOutput(collection=self.collection)
@ -34,13 +33,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput):
class PromptTestInvocation(BaseInvocation):
prompt: str = InputField(default="")
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
def invoke(self, context) -> PromptTestInvocationOutput:
return PromptTestInvocationOutput(prompt=self.prompt)
@invocation("test_error", version="1.0.0")
class ErrorInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
def invoke(self, context) -> PromptTestInvocationOutput:
raise Exception("This invocation is supposed to fail")
@ -54,7 +53,7 @@ class TextToImageTestInvocation(BaseInvocation):
prompt: str = InputField(default="")
prompt2: str = InputField(default="")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
def invoke(self, context) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@ -63,7 +62,7 @@ class ImageToImageTestInvocation(BaseInvocation):
prompt: str = InputField(default="")
image: Union[ImageField, None] = InputField(default=None)
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
def invoke(self, context) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@ -76,7 +75,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
class PromptCollectionTestInvocation(BaseInvocation):
collection: list[str] = InputField()
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
def invoke(self, context) -> PromptCollectionTestInvocationOutput:
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
@ -89,7 +88,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput):
class AnyTypeTestInvocation(BaseInvocation):
value: Any = InputField(default=None)
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
def invoke(self, context) -> AnyTypeTestInvocationOutput:
return AnyTypeTestInvocationOutput(value=self.value)
@ -97,7 +96,7 @@ class AnyTypeTestInvocation(BaseInvocation):
class PolymorphicStringTestInvocation(BaseInvocation):
value: Union[str, list[str]] = InputField(default="")
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
def invoke(self, context) -> PromptCollectionTestInvocationOutput:
if isinstance(self.value, str):
return PromptCollectionTestInvocationOutput(collection=[self.value])
return PromptCollectionTestInvocationOutput(collection=self.value)