mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests: fix tests for new invocation context
This commit is contained in:
parent
248176604f
commit
8dc1207790
@ -21,7 +21,6 @@ from invokeai.app.services.invocation_processor.invocation_processor_default imp
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import (
|
||||
CollectInvocation,
|
||||
Graph,
|
||||
@ -86,12 +85,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
|
||||
print(f"invoking {n.id}: {type(n)}")
|
||||
o = n.invoke(
|
||||
InvocationContext(
|
||||
queue_batch_id="1",
|
||||
queue_item_id=1,
|
||||
queue_id=DEFAULT_QUEUE_ID,
|
||||
services=services,
|
||||
graph_execution_state_id="1",
|
||||
workflow=None,
|
||||
conditioning=None, config=None, data=None, images=None, latents=None, logger=None, models=None, util=None
|
||||
)
|
||||
)
|
||||
g.complete(n.id, o)
|
||||
|
@ -3,7 +3,6 @@ from typing import Any, Callable, Union
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -21,7 +20,7 @@ class ListPassThroughInvocationOutput(BaseInvocationOutput):
|
||||
class ListPassThroughInvocation(BaseInvocation):
|
||||
collection: list[ImageField] = InputField(default=[])
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
|
||||
def invoke(self, context) -> ListPassThroughInvocationOutput:
|
||||
return ListPassThroughInvocationOutput(collection=self.collection)
|
||||
|
||||
|
||||
@ -34,13 +33,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput):
|
||||
class PromptTestInvocation(BaseInvocation):
|
||||
prompt: str = InputField(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||
def invoke(self, context) -> PromptTestInvocationOutput:
|
||||
return PromptTestInvocationOutput(prompt=self.prompt)
|
||||
|
||||
|
||||
@invocation("test_error", version="1.0.0")
|
||||
class ErrorInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||
def invoke(self, context) -> PromptTestInvocationOutput:
|
||||
raise Exception("This invocation is supposed to fail")
|
||||
|
||||
|
||||
@ -54,7 +53,7 @@ class TextToImageTestInvocation(BaseInvocation):
|
||||
prompt: str = InputField(default="")
|
||||
prompt2: str = InputField(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
def invoke(self, context) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
|
||||
|
||||
@ -63,7 +62,7 @@ class ImageToImageTestInvocation(BaseInvocation):
|
||||
prompt: str = InputField(default="")
|
||||
image: Union[ImageField, None] = InputField(default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
def invoke(self, context) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
|
||||
|
||||
@ -76,7 +75,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
||||
class PromptCollectionTestInvocation(BaseInvocation):
|
||||
collection: list[str] = InputField()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||
def invoke(self, context) -> PromptCollectionTestInvocationOutput:
|
||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||
|
||||
|
||||
@ -89,7 +88,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput):
|
||||
class AnyTypeTestInvocation(BaseInvocation):
|
||||
value: Any = InputField(default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
|
||||
def invoke(self, context) -> AnyTypeTestInvocationOutput:
|
||||
return AnyTypeTestInvocationOutput(value=self.value)
|
||||
|
||||
|
||||
@ -97,7 +96,7 @@ class AnyTypeTestInvocation(BaseInvocation):
|
||||
class PolymorphicStringTestInvocation(BaseInvocation):
|
||||
value: Union[str, list[str]] = InputField(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||
def invoke(self, context) -> PromptCollectionTestInvocationOutput:
|
||||
if isinstance(self.value, str):
|
||||
return PromptCollectionTestInvocationOutput(collection=[self.value])
|
||||
return PromptCollectionTestInvocationOutput(collection=self.value)
|
||||
|
Loading…
Reference in New Issue
Block a user