diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index fab1fa4598..9cc30e43e1 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -21,7 +21,6 @@ from invokeai.app.services.invocation_processor.invocation_processor_default imp from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService -from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.shared.graph import ( CollectInvocation, Graph, @@ -86,12 +85,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B print(f"invoking {n.id}: {type(n)}") o = n.invoke( InvocationContext( - queue_batch_id="1", - queue_item_id=1, - queue_id=DEFAULT_QUEUE_ID, - services=services, - graph_execution_state_id="1", - workflow=None, + conditioning=None, config=None, data=None, images=None, latents=None, logger=None, models=None, util=None ) ) g.complete(n.id, o) diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index e71daad3f3..559457c0e1 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Union from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -21,7 +20,7 @@ class ListPassThroughInvocationOutput(BaseInvocationOutput): class ListPassThroughInvocation(BaseInvocation): collection: list[ImageField] = InputField(default=[]) - def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: + def invoke(self, context) -> ListPassThroughInvocationOutput: return ListPassThroughInvocationOutput(collection=self.collection) @@ -34,13 +33,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput): class PromptTestInvocation(BaseInvocation): prompt: str = InputField(default="") - def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: + def invoke(self, context) -> PromptTestInvocationOutput: return PromptTestInvocationOutput(prompt=self.prompt) @invocation("test_error", version="1.0.0") class ErrorInvocation(BaseInvocation): - def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: + def invoke(self, context) -> PromptTestInvocationOutput: raise Exception("This invocation is supposed to fail") @@ -54,7 +53,7 @@ class TextToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") prompt2: str = InputField(default="") - def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + def invoke(self, context) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -63,7 +62,7 @@ class ImageToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") image: Union[ImageField, None] = InputField(default=None) - def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + def invoke(self, context) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -76,7 +75,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocation(BaseInvocation): collection: list[str] = InputField() - def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: + def invoke(self, context) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) @@ -89,7 +88,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput): class AnyTypeTestInvocation(BaseInvocation): value: Any = InputField(default=None) - def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: + def invoke(self, context) -> AnyTypeTestInvocationOutput: return AnyTypeTestInvocationOutput(value=self.value) @@ -97,7 +96,7 @@ class AnyTypeTestInvocation(BaseInvocation): class PolymorphicStringTestInvocation(BaseInvocation): value: Union[str, list[str]] = InputField(default="") - def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: + def invoke(self, context) -> PromptCollectionTestInvocationOutput: if isinstance(self.value, str): return PromptCollectionTestInvocationOutput(collection=[self.value]) return PromptCollectionTestInvocationOutput(collection=self.value)