mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
4 Commits
feat/batch
...
fix/nodes/
Author | SHA1 | Date | |
---|---|---|---|
fba55a3acc | |||
1758ff35c3 | |||
c60a71f0e9 | |||
ad2247c314 |
@ -7,7 +7,7 @@ if TYPE_CHECKING:
|
|||||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||||
from invokeai.app.services.boards import BoardServiceABC
|
from invokeai.app.services.boards import BoardServiceABC
|
||||||
from invokeai.app.services.images import ImageServiceABC
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
from invokeai.backend import ModelManager
|
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
from invokeai.app.services.restoration_services import RestorationServices
|
from invokeai.app.services.restoration_services import RestorationServices
|
||||||
@ -22,46 +22,47 @@ class InvocationServices:
|
|||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
events: "EventServiceBase"
|
|
||||||
latents: "LatentsStorageBase"
|
|
||||||
queue: "InvocationQueueABC"
|
|
||||||
model_manager: "ModelManager"
|
|
||||||
restoration: "RestorationServices"
|
|
||||||
configuration: "InvokeAISettings"
|
|
||||||
images: "ImageServiceABC"
|
|
||||||
boards: "BoardServiceABC"
|
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
boards: "BoardServiceABC"
|
||||||
|
configuration: "InvokeAISettings"
|
||||||
|
events: "EventServiceBase"
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||||
|
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||||
|
images: "ImageServiceABC"
|
||||||
|
latents: "LatentsStorageBase"
|
||||||
|
logger: "Logger"
|
||||||
|
model_manager: "ModelManagerServiceBase"
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
|
queue: "InvocationQueueABC"
|
||||||
|
restoration: "RestorationServices"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_manager: "ModelManager",
|
|
||||||
events: "EventServiceBase",
|
|
||||||
logger: "Logger",
|
|
||||||
latents: "LatentsStorageBase",
|
|
||||||
images: "ImageServiceABC",
|
|
||||||
boards: "BoardServiceABC",
|
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
queue: "InvocationQueueABC",
|
boards: "BoardServiceABC",
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
|
||||||
processor: "InvocationProcessorABC",
|
|
||||||
restoration: "RestorationServices",
|
|
||||||
configuration: "InvokeAISettings",
|
configuration: "InvokeAISettings",
|
||||||
|
events: "EventServiceBase",
|
||||||
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||||
|
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||||
|
images: "ImageServiceABC",
|
||||||
|
latents: "LatentsStorageBase",
|
||||||
|
logger: "Logger",
|
||||||
|
model_manager: "ModelManagerServiceBase",
|
||||||
|
processor: "InvocationProcessorABC",
|
||||||
|
queue: "InvocationQueueABC",
|
||||||
|
restoration: "RestorationServices",
|
||||||
):
|
):
|
||||||
self.model_manager = model_manager
|
|
||||||
self.events = events
|
|
||||||
self.logger = logger
|
|
||||||
self.latents = latents
|
|
||||||
self.images = images
|
|
||||||
self.boards = boards
|
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.queue = queue
|
|
||||||
self.graph_library = graph_library
|
|
||||||
self.graph_execution_manager = graph_execution_manager
|
|
||||||
self.processor = processor
|
|
||||||
self.restoration = restoration
|
|
||||||
self.configuration = configuration
|
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
|
self.boards = boards
|
||||||
|
self.configuration = configuration
|
||||||
|
self.events = events
|
||||||
|
self.graph_execution_manager = graph_execution_manager
|
||||||
|
self.graph_library = graph_library
|
||||||
|
self.images = images
|
||||||
|
self.latents = latents
|
||||||
|
self.logger = logger
|
||||||
|
self.model_manager = model_manager
|
||||||
|
self.processor = processor
|
||||||
|
self.queue = queue
|
||||||
|
self.restoration = restoration
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from .test_invoker import create_edge
|
from .test_invoker import create_edge
|
||||||
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
|
from .test_nodes import TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
from invokeai.app.invocations.collections import RangeInvocation
|
from invokeai.app.invocations.collections import RangeInvocation
|
||||||
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
||||||
@ -15,7 +15,7 @@ import pytest
|
|||||||
def simple_graph():
|
def simple_graph():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
|
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
|
||||||
g.add_node(ImageTestInvocation(id = "2"))
|
g.add_node(TextToImageTestInvocation(id = "2"))
|
||||||
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from .test_nodes import ErrorInvocation, ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
|
from .test_nodes import ErrorInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
|
||||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -13,7 +13,7 @@ import pytest
|
|||||||
def simple_graph():
|
def simple_graph():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
|
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
|
||||||
g.add_node(ImageTestInvocation(id = "2"))
|
g.add_node(TextToImageTestInvocation(id = "2"))
|
||||||
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||||
return g
|
return g
|
||||||
|
|
||||||
@ -25,6 +25,8 @@ def mock_services() -> InvocationServices:
|
|||||||
events = TestEventService(),
|
events = TestEventService(),
|
||||||
logger = None, # type: ignore
|
logger = None, # type: ignore
|
||||||
images = None, # type: ignore
|
images = None, # type: ignore
|
||||||
|
boards = None, # type: ignore
|
||||||
|
board_images= None, # type: ignore
|
||||||
latents = None, # type: ignore
|
latents = None, # type: ignore
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
|
||||||
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||||
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
|
||||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
from invokeai.app.invocations.upscale import UpscaleInvocation
|
||||||
from invokeai.app.invocations.image import *
|
from invokeai.app.invocations.image import *
|
||||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||||
@ -18,7 +17,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
|
|||||||
|
|
||||||
# Tests
|
# Tests
|
||||||
def test_connections_are_compatible():
|
def test_connections_are_compatible():
|
||||||
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "image"
|
from_field = "image"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = UpscaleInvocation(id = "2")
|
||||||
to_field = "image"
|
to_field = "image"
|
||||||
@ -28,7 +27,7 @@ def test_connections_are_compatible():
|
|||||||
assert result == True
|
assert result == True
|
||||||
|
|
||||||
def test_connections_are_incompatible():
|
def test_connections_are_incompatible():
|
||||||
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "image"
|
from_field = "image"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = UpscaleInvocation(id = "2")
|
||||||
to_field = "strength"
|
to_field = "strength"
|
||||||
@ -38,7 +37,7 @@ def test_connections_are_incompatible():
|
|||||||
assert result == False
|
assert result == False
|
||||||
|
|
||||||
def test_connections_incompatible_with_invalid_fields():
|
def test_connections_incompatible_with_invalid_fields():
|
||||||
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "invalid_field"
|
from_field = "invalid_field"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = UpscaleInvocation(id = "2")
|
||||||
to_field = "image"
|
to_field = "image"
|
||||||
@ -56,28 +55,28 @@ def test_connections_incompatible_with_invalid_fields():
|
|||||||
|
|
||||||
def test_graph_can_add_node():
|
def test_graph_can_add_node():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
|
|
||||||
assert n.id in g.nodes
|
assert n.id in g.nodes
|
||||||
|
|
||||||
def test_graph_fails_to_add_node_with_duplicate_id():
|
def test_graph_fails_to_add_node_with_duplicate_id():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second")
|
n2 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi the second")
|
||||||
|
|
||||||
with pytest.raises(NodeAlreadyInGraphError):
|
with pytest.raises(NodeAlreadyInGraphError):
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
|
||||||
def test_graph_updates_node():
|
def test_graph_updates_node():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second")
|
n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
|
||||||
nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated")
|
nu = TextToImageTestInvocation(id = "1", prompt = "Banana sushi updated")
|
||||||
|
|
||||||
g.update_node("1", nu)
|
g.update_node("1", nu)
|
||||||
|
|
||||||
@ -85,7 +84,7 @@ def test_graph_updates_node():
|
|||||||
|
|
||||||
def test_graph_fails_to_update_node_if_type_changes():
|
def test_graph_fails_to_update_node_if_type_changes():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -97,14 +96,14 @@ def test_graph_fails_to_update_node_if_type_changes():
|
|||||||
|
|
||||||
def test_graph_allows_non_conflicting_id_change():
|
def test_graph_allows_non_conflicting_id_change():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e1 = create_edge(n.id,"image",n2.id,"image")
|
e1 = create_edge(n.id,"image",n2.id,"image")
|
||||||
g.add_edge(e1)
|
g.add_edge(e1)
|
||||||
|
|
||||||
nu = TextToImageInvocation(id = "3", prompt = "Banana sushi")
|
nu = TextToImageTestInvocation(id = "3", prompt = "Banana sushi")
|
||||||
g.update_node("1", nu)
|
g.update_node("1", nu)
|
||||||
|
|
||||||
with pytest.raises(NodeNotFoundError):
|
with pytest.raises(NodeNotFoundError):
|
||||||
@ -117,18 +116,18 @@ def test_graph_allows_non_conflicting_id_change():
|
|||||||
|
|
||||||
def test_graph_fails_to_update_node_id_if_conflict():
|
def test_graph_fails_to_update_node_id_if_conflict():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second")
|
n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
|
||||||
nu = TextToImageInvocation(id = "2", prompt = "Banana sushi")
|
nu = TextToImageTestInvocation(id = "2", prompt = "Banana sushi")
|
||||||
with pytest.raises(NodeAlreadyInGraphError):
|
with pytest.raises(NodeAlreadyInGraphError):
|
||||||
g.update_node("1", nu)
|
g.update_node("1", nu)
|
||||||
|
|
||||||
def test_graph_adds_edge():
|
def test_graph_adds_edge():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -148,7 +147,7 @@ def test_graph_fails_to_add_edge_with_cycle():
|
|||||||
|
|
||||||
def test_graph_fails_to_add_edge_with_long_cycle():
|
def test_graph_fails_to_add_edge_with_long_cycle():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
n3 = UpscaleInvocation(id = "3")
|
n3 = UpscaleInvocation(id = "3")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -164,7 +163,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
|
|||||||
|
|
||||||
def test_graph_fails_to_add_edge_with_missing_node_id():
|
def test_graph_fails_to_add_edge_with_missing_node_id():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -177,7 +176,7 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
|
|||||||
|
|
||||||
def test_graph_fails_to_add_edge_when_destination_exists():
|
def test_graph_fails_to_add_edge_when_destination_exists():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
n3 = UpscaleInvocation(id = "3")
|
n3 = UpscaleInvocation(id = "3")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -194,7 +193,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
|
|||||||
|
|
||||||
def test_graph_fails_to_add_edge_with_mismatched_types():
|
def test_graph_fails_to_add_edge_with_mismatched_types():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -204,8 +203,8 @@ def test_graph_fails_to_add_edge_with_mismatched_types():
|
|||||||
|
|
||||||
def test_graph_connects_collector():
|
def test_graph_connects_collector():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2")
|
n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi 2")
|
||||||
n3 = CollectInvocation(id = "3")
|
n3 = CollectInvocation(id = "3")
|
||||||
n4 = ListPassThroughInvocation(id = "4")
|
n4 = ListPassThroughInvocation(id = "4")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -224,7 +223,7 @@ def test_graph_connects_collector():
|
|||||||
|
|
||||||
def test_graph_collector_invalid_with_varying_input_types():
|
def test_graph_collector_invalid_with_varying_input_types():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2")
|
n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2")
|
||||||
n3 = CollectInvocation(id = "3")
|
n3 = CollectInvocation(id = "3")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -282,7 +281,7 @@ def test_graph_connects_iterator():
|
|||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = ListPassThroughInvocation(id = "1")
|
n1 = ListPassThroughInvocation(id = "1")
|
||||||
n2 = IterateInvocation(id = "2")
|
n2 = IterateInvocation(id = "2")
|
||||||
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
|
n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
g.add_node(n3)
|
g.add_node(n3)
|
||||||
@ -298,7 +297,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
|
|||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = ListPassThroughInvocation(id = "1")
|
n1 = ListPassThroughInvocation(id = "1")
|
||||||
n2 = IterateInvocation(id = "2")
|
n2 = IterateInvocation(id = "2")
|
||||||
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
|
n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi")
|
||||||
n4 = ListPassThroughInvocation(id = "4")
|
n4 = ListPassThroughInvocation(id = "4")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -316,7 +315,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
|
|||||||
|
|
||||||
def test_graph_iterator_invalid_if_input_not_list():
|
def test_graph_iterator_invalid_if_input_not_list():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = IterateInvocation(id = "2")
|
n2 = IterateInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -344,7 +343,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
|
|||||||
|
|
||||||
def test_graph_validates():
|
def test_graph_validates():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -355,7 +354,7 @@ def test_graph_validates():
|
|||||||
|
|
||||||
def test_graph_invalid_if_edges_reference_missing_nodes():
|
def test_graph_invalid_if_edges_reference_missing_nodes():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.nodes[n1.id] = n1
|
g.nodes[n1.id] = n1
|
||||||
e1 = create_edge("1","image","2","image")
|
e1 = create_edge("1","image","2","image")
|
||||||
g.edges.append(e1)
|
g.edges.append(e1)
|
||||||
@ -367,7 +366,7 @@ def test_graph_invalid_if_subgraph_invalid():
|
|||||||
n1 = GraphInvocation(id = "1")
|
n1 = GraphInvocation(id = "1")
|
||||||
n1.graph = Graph()
|
n1.graph = Graph()
|
||||||
|
|
||||||
n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi")
|
n1_1 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi")
|
||||||
n1.graph.nodes[n1_1.id] = n1_1
|
n1.graph.nodes[n1_1.id] = n1_1
|
||||||
e1 = create_edge("1","image","2","image")
|
e1 = create_edge("1","image","2","image")
|
||||||
n1.graph.edges.append(e1)
|
n1.graph.edges.append(e1)
|
||||||
@ -391,7 +390,7 @@ def test_graph_invalid_if_has_cycle():
|
|||||||
|
|
||||||
def test_graph_invalid_with_invalid_connection():
|
def test_graph_invalid_with_invalid_connection():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.nodes[n1.id] = n1
|
g.nodes[n1.id] = n1
|
||||||
g.nodes[n2.id] = n2
|
g.nodes[n2.id] = n2
|
||||||
@ -408,7 +407,7 @@ def test_graph_gets_subgraph_node():
|
|||||||
n1.graph = Graph()
|
n1.graph = Graph()
|
||||||
n1.graph.add_node
|
n1.graph.add_node
|
||||||
|
|
||||||
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n1.graph.add_node(n1_1)
|
n1.graph.add_node(n1_1)
|
||||||
|
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -485,7 +484,7 @@ def test_graph_fails_to_get_missing_subgraph_node():
|
|||||||
n1.graph = Graph()
|
n1.graph = Graph()
|
||||||
n1.graph.add_node
|
n1.graph.add_node
|
||||||
|
|
||||||
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n1.graph.add_node(n1_1)
|
n1.graph.add_node(n1_1)
|
||||||
|
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -499,7 +498,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
|||||||
n1.graph = Graph()
|
n1.graph = Graph()
|
||||||
n1.graph.add_node
|
n1.graph.add_node
|
||||||
|
|
||||||
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n1.graph.add_node(n1_1)
|
n1.graph.add_node(n1_1)
|
||||||
|
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -512,7 +511,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
|||||||
|
|
||||||
def test_graph_gets_networkx_graph():
|
def test_graph_gets_networkx_graph():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -529,7 +528,7 @@ def test_graph_gets_networkx_graph():
|
|||||||
# TODO: Graph serializes and deserializes
|
# TODO: Graph serializes and deserializes
|
||||||
def test_graph_can_serialize():
|
def test_graph_can_serialize():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -541,7 +540,7 @@ def test_graph_can_serialize():
|
|||||||
|
|
||||||
def test_graph_can_deserialize():
|
def test_graph_can_deserialize():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Callable, Literal
|
from typing import Any, Callable, Literal, Union
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
from invokeai.app.invocations.image import ImageField
|
from invokeai.app.invocations.image import ImageField
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
@ -43,14 +43,23 @@ class ImageTestInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
image: ImageField = Field()
|
image: ImageField = Field()
|
||||||
|
|
||||||
class ImageTestInvocation(BaseInvocation):
|
class TextToImageTestInvocation(BaseInvocation):
|
||||||
type: Literal['test_image'] = 'test_image'
|
type: Literal['test_text_to_image'] = 'test_text_to_image'
|
||||||
|
|
||||||
prompt: str = Field(default = "")
|
prompt: str = Field(default = "")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
|
|
||||||
|
class ImageToImageTestInvocation(BaseInvocation):
|
||||||
|
type: Literal['test_image_to_image'] = 'test_image_to_image'
|
||||||
|
|
||||||
|
prompt: str = Field(default = "")
|
||||||
|
image: Union[ImageField, None] = Field(default=None)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||||
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
|
|
||||||
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
||||||
type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
|
type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
|
||||||
collection: list[str] = Field(default_factory=list)
|
collection: list[str] = Field(default_factory=list)
|
||||||
@ -62,7 +71,6 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||||
|
|
||||||
|
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.graph import Edge, EdgeConnection
|
from invokeai.app.services.graph import Edge, EdgeConnection
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user