mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests: fix tests for new invocation context
This commit is contained in:
parent
248176604f
commit
8dc1207790
@ -21,7 +21,6 @@ from invokeai.app.services.invocation_processor.invocation_processor_default imp
|
|||||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
|
||||||
from invokeai.app.services.shared.graph import (
|
from invokeai.app.services.shared.graph import (
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
Graph,
|
Graph,
|
||||||
@ -86,12 +85,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
|
|||||||
print(f"invoking {n.id}: {type(n)}")
|
print(f"invoking {n.id}: {type(n)}")
|
||||||
o = n.invoke(
|
o = n.invoke(
|
||||||
InvocationContext(
|
InvocationContext(
|
||||||
queue_batch_id="1",
|
conditioning=None, config=None, data=None, images=None, latents=None, logger=None, models=None, util=None
|
||||||
queue_item_id=1,
|
|
||||||
queue_id=DEFAULT_QUEUE_ID,
|
|
||||||
services=services,
|
|
||||||
graph_execution_state_id="1",
|
|
||||||
workflow=None,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
g.complete(n.id, o)
|
g.complete(n.id, o)
|
||||||
|
@ -3,7 +3,6 @@ from typing import Any, Callable, Union
|
|||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InvocationContext,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -21,7 +20,7 @@ class ListPassThroughInvocationOutput(BaseInvocationOutput):
|
|||||||
class ListPassThroughInvocation(BaseInvocation):
|
class ListPassThroughInvocation(BaseInvocation):
|
||||||
collection: list[ImageField] = InputField(default=[])
|
collection: list[ImageField] = InputField(default=[])
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
|
def invoke(self, context) -> ListPassThroughInvocationOutput:
|
||||||
return ListPassThroughInvocationOutput(collection=self.collection)
|
return ListPassThroughInvocationOutput(collection=self.collection)
|
||||||
|
|
||||||
|
|
||||||
@ -34,13 +33,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput):
|
|||||||
class PromptTestInvocation(BaseInvocation):
|
class PromptTestInvocation(BaseInvocation):
|
||||||
prompt: str = InputField(default="")
|
prompt: str = InputField(default="")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
def invoke(self, context) -> PromptTestInvocationOutput:
|
||||||
return PromptTestInvocationOutput(prompt=self.prompt)
|
return PromptTestInvocationOutput(prompt=self.prompt)
|
||||||
|
|
||||||
|
|
||||||
@invocation("test_error", version="1.0.0")
|
@invocation("test_error", version="1.0.0")
|
||||||
class ErrorInvocation(BaseInvocation):
|
class ErrorInvocation(BaseInvocation):
|
||||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
def invoke(self, context) -> PromptTestInvocationOutput:
|
||||||
raise Exception("This invocation is supposed to fail")
|
raise Exception("This invocation is supposed to fail")
|
||||||
|
|
||||||
|
|
||||||
@ -54,7 +53,7 @@ class TextToImageTestInvocation(BaseInvocation):
|
|||||||
prompt: str = InputField(default="")
|
prompt: str = InputField(default="")
|
||||||
prompt2: str = InputField(default="")
|
prompt2: str = InputField(default="")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
def invoke(self, context) -> ImageTestInvocationOutput:
|
||||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
|
|
||||||
|
|
||||||
@ -63,7 +62,7 @@ class ImageToImageTestInvocation(BaseInvocation):
|
|||||||
prompt: str = InputField(default="")
|
prompt: str = InputField(default="")
|
||||||
image: Union[ImageField, None] = InputField(default=None)
|
image: Union[ImageField, None] = InputField(default=None)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
def invoke(self, context) -> ImageTestInvocationOutput:
|
||||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
|
|
||||||
|
|
||||||
@ -76,7 +75,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
|||||||
class PromptCollectionTestInvocation(BaseInvocation):
|
class PromptCollectionTestInvocation(BaseInvocation):
|
||||||
collection: list[str] = InputField()
|
collection: list[str] = InputField()
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
def invoke(self, context) -> PromptCollectionTestInvocationOutput:
|
||||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||||
|
|
||||||
|
|
||||||
@ -89,7 +88,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput):
|
|||||||
class AnyTypeTestInvocation(BaseInvocation):
|
class AnyTypeTestInvocation(BaseInvocation):
|
||||||
value: Any = InputField(default=None)
|
value: Any = InputField(default=None)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
|
def invoke(self, context) -> AnyTypeTestInvocationOutput:
|
||||||
return AnyTypeTestInvocationOutput(value=self.value)
|
return AnyTypeTestInvocationOutput(value=self.value)
|
||||||
|
|
||||||
|
|
||||||
@ -97,7 +96,7 @@ class AnyTypeTestInvocation(BaseInvocation):
|
|||||||
class PolymorphicStringTestInvocation(BaseInvocation):
|
class PolymorphicStringTestInvocation(BaseInvocation):
|
||||||
value: Union[str, list[str]] = InputField(default="")
|
value: Union[str, list[str]] = InputField(default="")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
def invoke(self, context) -> PromptCollectionTestInvocationOutput:
|
||||||
if isinstance(self.value, str):
|
if isinstance(self.value, str):
|
||||||
return PromptCollectionTestInvocationOutput(collection=[self.value])
|
return PromptCollectionTestInvocationOutput(collection=[self.value])
|
||||||
return PromptCollectionTestInvocationOutput(collection=self.value)
|
return PromptCollectionTestInvocationOutput(collection=self.value)
|
||||||
|
Loading…
Reference in New Issue
Block a user