Compare commits

...

4 Commits

5 changed files with 89 additions and 79 deletions

View File

@ -7,7 +7,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_images import BoardImagesServiceABC from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices from invokeai.app.services.restoration_services import RestorationServices
@ -22,46 +22,47 @@ class InvocationServices:
"""Services that can be used by invocations""" """Services that can be used by invocations"""
# TODO: Just forward-declared everything due to circular dependencies. Fix structure. # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
events: "EventServiceBase"
latents: "LatentsStorageBase"
queue: "InvocationQueueABC"
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
images: "ImageServiceABC"
boards: "BoardServiceABC"
board_images: "BoardImagesServiceABC" board_images: "BoardImagesServiceABC"
graph_library: "ItemStorageABC"["LibraryGraph"] boards: "BoardServiceABC"
configuration: "InvokeAISettings"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"]
images: "ImageServiceABC"
latents: "LatentsStorageBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
queue: "InvocationQueueABC"
restoration: "RestorationServices"
def __init__( def __init__(
self, self,
model_manager: "ModelManager",
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
images: "ImageServiceABC",
boards: "BoardServiceABC",
board_images: "BoardImagesServiceABC", board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC", boards: "BoardServiceABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: "RestorationServices",
configuration: "InvokeAISettings", configuration: "InvokeAISettings",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
graph_library: "ItemStorageABC"["LibraryGraph"],
images: "ImageServiceABC",
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
queue: "InvocationQueueABC",
restoration: "RestorationServices",
): ):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.boards = boards
self.board_images = board_images self.board_images = board_images
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration
self.boards = boards self.boards = boards
self.boards = boards
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
self.graph_library = graph_library
self.images = images
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.processor = processor
self.queue = queue
self.restoration = restoration

View File

@ -1,5 +1,5 @@
from .test_invoker import create_edge from .test_invoker import create_edge
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation from .test_nodes import TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
@ -15,7 +15,7 @@ import pytest
def simple_graph(): def simple_graph():
g = Graph() g = Graph()
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
g.add_node(ImageTestInvocation(id = "2")) g.add_node(TextToImageTestInvocation(id = "2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt")) g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g return g

View File

@ -1,4 +1,4 @@
from .test_nodes import ErrorInvocation, ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until from .test_nodes import ErrorInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@ -13,7 +13,7 @@ import pytest
def simple_graph(): def simple_graph():
g = Graph() g = Graph()
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
g.add_node(ImageTestInvocation(id = "2")) g.add_node(TextToImageTestInvocation(id = "2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt")) g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g return g
@ -25,6 +25,8 @@ def mock_services() -> InvocationServices:
events = TestEventService(), events = TestEventService(),
logger = None, # type: ignore logger = None, # type: ignore
images = None, # type: ignore images = None, # type: ignore
boards = None, # type: ignore
board_images= None, # type: ignore
latents = None, # type: ignore latents = None, # type: ignore
queue = MemoryInvocationQueue(), queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph]( graph_library=SqliteItemStorage[LibraryGraph](

View File

@ -1,6 +1,5 @@
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation from invokeai.app.invocations.upscale import UpscaleInvocation
from invokeai.app.invocations.image import * from invokeai.app.invocations.image import *
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
@ -18,7 +17,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
# Tests # Tests
def test_connections_are_compatible(): def test_connections_are_compatible():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "image" from_field = "image"
to_node = UpscaleInvocation(id = "2") to_node = UpscaleInvocation(id = "2")
to_field = "image" to_field = "image"
@ -28,7 +27,7 @@ def test_connections_are_compatible():
assert result == True assert result == True
def test_connections_are_incompatible(): def test_connections_are_incompatible():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "image" from_field = "image"
to_node = UpscaleInvocation(id = "2") to_node = UpscaleInvocation(id = "2")
to_field = "strength" to_field = "strength"
@ -38,7 +37,7 @@ def test_connections_are_incompatible():
assert result == False assert result == False
def test_connections_incompatible_with_invalid_fields(): def test_connections_incompatible_with_invalid_fields():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "invalid_field" from_field = "invalid_field"
to_node = UpscaleInvocation(id = "2") to_node = UpscaleInvocation(id = "2")
to_field = "image" to_field = "image"
@ -56,28 +55,28 @@ def test_connections_incompatible_with_invalid_fields():
def test_graph_can_add_node(): def test_graph_can_add_node():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
assert n.id in g.nodes assert n.id in g.nodes
def test_graph_fails_to_add_node_with_duplicate_id(): def test_graph_fails_to_add_node_with_duplicate_id():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second") n2 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi the second")
with pytest.raises(NodeAlreadyInGraphError): with pytest.raises(NodeAlreadyInGraphError):
g.add_node(n2) g.add_node(n2)
def test_graph_updates_node(): def test_graph_updates_node():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second")
g.add_node(n2) g.add_node(n2)
nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated") nu = TextToImageTestInvocation(id = "1", prompt = "Banana sushi updated")
g.update_node("1", nu) g.update_node("1", nu)
@ -85,7 +84,7 @@ def test_graph_updates_node():
def test_graph_fails_to_update_node_if_type_changes(): def test_graph_fails_to_update_node_if_type_changes():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n2) g.add_node(n2)
@ -97,14 +96,14 @@ def test_graph_fails_to_update_node_if_type_changes():
def test_graph_allows_non_conflicting_id_change(): def test_graph_allows_non_conflicting_id_change():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n2) g.add_node(n2)
e1 = create_edge(n.id,"image",n2.id,"image") e1 = create_edge(n.id,"image",n2.id,"image")
g.add_edge(e1) g.add_edge(e1)
nu = TextToImageInvocation(id = "3", prompt = "Banana sushi") nu = TextToImageTestInvocation(id = "3", prompt = "Banana sushi")
g.update_node("1", nu) g.update_node("1", nu)
with pytest.raises(NodeNotFoundError): with pytest.raises(NodeNotFoundError):
@ -117,18 +116,18 @@ def test_graph_allows_non_conflicting_id_change():
def test_graph_fails_to_update_node_id_if_conflict(): def test_graph_fails_to_update_node_id_if_conflict():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second")
g.add_node(n2) g.add_node(n2)
nu = TextToImageInvocation(id = "2", prompt = "Banana sushi") nu = TextToImageTestInvocation(id = "2", prompt = "Banana sushi")
with pytest.raises(NodeAlreadyInGraphError): with pytest.raises(NodeAlreadyInGraphError):
g.update_node("1", nu) g.update_node("1", nu)
def test_graph_adds_edge(): def test_graph_adds_edge():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -148,7 +147,7 @@ def test_graph_fails_to_add_edge_with_cycle():
def test_graph_fails_to_add_edge_with_long_cycle(): def test_graph_fails_to_add_edge_with_long_cycle():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3") n3 = UpscaleInvocation(id = "3")
g.add_node(n1) g.add_node(n1)
@ -164,7 +163,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
def test_graph_fails_to_add_edge_with_missing_node_id(): def test_graph_fails_to_add_edge_with_missing_node_id():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -177,7 +176,7 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
def test_graph_fails_to_add_edge_when_destination_exists(): def test_graph_fails_to_add_edge_when_destination_exists():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3") n3 = UpscaleInvocation(id = "3")
g.add_node(n1) g.add_node(n1)
@ -194,7 +193,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
def test_graph_fails_to_add_edge_with_mismatched_types(): def test_graph_fails_to_add_edge_with_mismatched_types():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -204,8 +203,8 @@ def test_graph_fails_to_add_edge_with_mismatched_types():
def test_graph_connects_collector(): def test_graph_connects_collector():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2") n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi 2")
n3 = CollectInvocation(id = "3") n3 = CollectInvocation(id = "3")
n4 = ListPassThroughInvocation(id = "4") n4 = ListPassThroughInvocation(id = "4")
g.add_node(n1) g.add_node(n1)
@ -224,7 +223,7 @@ def test_graph_connects_collector():
def test_graph_collector_invalid_with_varying_input_types(): def test_graph_collector_invalid_with_varying_input_types():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2") n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2")
n3 = CollectInvocation(id = "3") n3 = CollectInvocation(id = "3")
g.add_node(n1) g.add_node(n1)
@ -282,7 +281,7 @@ def test_graph_connects_iterator():
g = Graph() g = Graph()
n1 = ListPassThroughInvocation(id = "1") n1 = ListPassThroughInvocation(id = "1")
n2 = IterateInvocation(id = "2") n2 = IterateInvocation(id = "2")
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
g.add_node(n3) g.add_node(n3)
@ -298,7 +297,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
g = Graph() g = Graph()
n1 = ListPassThroughInvocation(id = "1") n1 = ListPassThroughInvocation(id = "1")
n2 = IterateInvocation(id = "2") n2 = IterateInvocation(id = "2")
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi")
n4 = ListPassThroughInvocation(id = "4") n4 = ListPassThroughInvocation(id = "4")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -316,7 +315,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
def test_graph_iterator_invalid_if_input_not_list(): def test_graph_iterator_invalid_if_input_not_list():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = IterateInvocation(id = "2") n2 = IterateInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -344,7 +343,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
def test_graph_validates(): def test_graph_validates():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -355,7 +354,7 @@ def test_graph_validates():
def test_graph_invalid_if_edges_reference_missing_nodes(): def test_graph_invalid_if_edges_reference_missing_nodes():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.nodes[n1.id] = n1 g.nodes[n1.id] = n1
e1 = create_edge("1","image","2","image") e1 = create_edge("1","image","2","image")
g.edges.append(e1) g.edges.append(e1)
@ -367,7 +366,7 @@ def test_graph_invalid_if_subgraph_invalid():
n1 = GraphInvocation(id = "1") n1 = GraphInvocation(id = "1")
n1.graph = Graph() n1.graph = Graph()
n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi") n1_1 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi")
n1.graph.nodes[n1_1.id] = n1_1 n1.graph.nodes[n1_1.id] = n1_1
e1 = create_edge("1","image","2","image") e1 = create_edge("1","image","2","image")
n1.graph.edges.append(e1) n1.graph.edges.append(e1)
@ -391,7 +390,7 @@ def test_graph_invalid_if_has_cycle():
def test_graph_invalid_with_invalid_connection(): def test_graph_invalid_with_invalid_connection():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.nodes[n1.id] = n1 g.nodes[n1.id] = n1
g.nodes[n2.id] = n2 g.nodes[n2.id] = n2
@ -408,7 +407,7 @@ def test_graph_gets_subgraph_node():
n1.graph = Graph() n1.graph = Graph()
n1.graph.add_node n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1) n1.graph.add_node(n1_1)
g.add_node(n1) g.add_node(n1)
@ -485,7 +484,7 @@ def test_graph_fails_to_get_missing_subgraph_node():
n1.graph = Graph() n1.graph = Graph()
n1.graph.add_node n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1) n1.graph.add_node(n1_1)
g.add_node(n1) g.add_node(n1)
@ -499,7 +498,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
n1.graph = Graph() n1.graph = Graph()
n1.graph.add_node n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1) n1.graph.add_node(n1_1)
g.add_node(n1) g.add_node(n1)
@ -512,7 +511,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
def test_graph_gets_networkx_graph(): def test_graph_gets_networkx_graph():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -529,7 +528,7 @@ def test_graph_gets_networkx_graph():
# TODO: Graph serializes and deserializes # TODO: Graph serializes and deserializes
def test_graph_can_serialize(): def test_graph_can_serialize():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -541,7 +540,7 @@ def test_graph_can_serialize():
def test_graph_can_deserialize(): def test_graph_can_deserialize():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, Literal from typing import Any, Callable, Literal, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.image import ImageField from invokeai.app.invocations.image import ImageField
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
@ -43,14 +43,23 @@ class ImageTestInvocationOutput(BaseInvocationOutput):
image: ImageField = Field() image: ImageField = Field()
class ImageTestInvocation(BaseInvocation): class TextToImageTestInvocation(BaseInvocation):
type: Literal['test_image'] = 'test_image' type: Literal['test_text_to_image'] = 'test_text_to_image'
prompt: str = Field(default = "") prompt: str = Field(default = "")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
class ImageToImageTestInvocation(BaseInvocation):
type: Literal['test_image_to_image'] = 'test_image_to_image'
prompt: str = Field(default = "")
image: Union[ImageField, None] = Field(default=None)
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output' type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
collection: list[str] = Field(default_factory=list) collection: list[str] = Field(default_factory=list)
@ -62,7 +71,6 @@ class PromptCollectionTestInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Edge, EdgeConnection from invokeai.app.services.graph import Edge, EdgeConnection