[nodes] Removed InvokerServices, simplifying service model

Kyle Schouviller 2023-02-24 20:11:28 -08:00
parent 34e3aa1f88
commit cd98d88fe7
8 changed files with 81 additions and 89 deletions
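With this change, the queue, graph execution manager, and processor that previously lived on InvokerServices are constructed on the single InvocationServices object, and Invoker takes just that one argument. A minimal wiring sketch modeled on the test fixtures in this commit (the None values stand in for a real Generate instance, event service, and image storage):

from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.invoker import Invoker
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from ldm.invoke.app.services.graph import GraphExecutionState

# All services now live on one InvocationServices instance.
services = InvocationServices(
    generate = None,   # a real app passes a Generate instance
    events = None,     # e.g. the API server's event service
    images = None,     # e.g. DiskImageStorage(output_folder)
    queue = MemoryInvocationQueue(),
    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
        filename = sqlite_memory, table_name = 'graph_executions'),
    processor = DefaultInvocationProcessor()
)

# The Invoker constructor now takes a single argument instead of
# (services, invoker_services).
invoker = Invoker(services)
session = invoker.create_execution_state()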


@@ -13,7 +13,7 @@ from ...globals import Globals
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker, InvokerServices
from ..services.invoker import Invoker
from ..services.generate_initializer import get_generate
from .events import FastAPIEventService
@@ -60,22 +60,19 @@ class ApiDependencies:
images = DiskImageStorage(output_folder)
services = InvocationServices(
generate = generate,
events = events,
images = images
)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db')
invoker_services = InvokerServices(
services = InvocationServices(
generate = generate,
events = events,
images = images,
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
ApiDependencies.invoker = Invoker(services, invoker_services)
ApiDependencies.invoker = Invoker(services)
@staticmethod
def shutdown():


@@ -44,9 +44,9 @@ async def list_sessions(
) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching"""
if filter == '':
result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page)
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
else:
result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page)
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
return result
@@ -60,7 +60,7 @@ async def get_session(
session_id: str = Path(description = "The id of the session to get")
) -> GraphExecutionState:
"""Gets a session"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
else:
@@ -80,13 +80,13 @@ async def add_node(
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add")
) -> str:
"""Adds a node to the graph"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.add_node(node)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session.id
except NodeAlreadyExecutedError:
return Response(status_code = 400)
@@ -108,13 +108,13 @@ async def update_node(
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node")
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.update_node(node_path, node)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
@@ -135,13 +135,13 @@ async def delete_node(
node_path: str = Path(description = "The path to the node to delete")
) -> GraphExecutionState:
"""Deletes a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.delete_node(node_path)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
@@ -162,13 +162,13 @@ async def add_edge(
edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add")
) -> GraphExecutionState:
"""Adds an edge to the graph"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.add_edge(edge)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
@@ -193,14 +193,14 @@ async def delete_edge(
to_field: str = Path(description = "The field of the node the edge is going to")
) -> GraphExecutionState:
"""Deletes an edge from the graph"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field))
session.delete_edge(edge)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
@@ -221,7 +221,7 @@ async def invoke_session(
all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations")
) -> None:
"""Invokes a session"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)


@@ -20,7 +20,7 @@ from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .invocations.baseinvocation import BaseInvocation
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker, InvokerServices
from .services.invoker import Invoker
from .invocations import *
from ..args import Args
from .services.events import EventServiceBase
@@ -171,28 +171,25 @@ def invoke_cli():
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))
services = InvocationServices(
generate = generate,
events = events,
images = DiskImageStorage(output_folder)
)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db')
invoker_services = InvokerServices(
services = InvocationServices(
generate = generate,
events = events,
images = DiskImageStorage(output_folder),
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
invoker = Invoker(services, invoker_services)
invoker = Invoker(services)
session = invoker.create_execution_state()
parser = get_invocation_parser()
# Uncomment to print out previous sessions at startup
# print(invoker_services.session_manager.list())
# print(services.session_manager.list())
# Defaults storage
defaults: Dict[str, Any] = dict()
@@ -213,7 +210,7 @@ def invoke_cli():
try:
# Refresh the state of the session
session = invoker.invoker_services.graph_execution_manager.get(session.id)
session = invoker.services.graph_execution_manager.get(session.id)
history = list(get_graph_execution_history(session))
# Split the command for piping
@@ -289,7 +286,7 @@ def invoke_cli():
invoker.invoke(session, invoke_all = True)
while not session.is_complete():
# Wait some time
session = invoker.invoker_services.graph_execution_manager.get(session.id)
session = invoker.services.graph_execution_manager.get(session.id)
time.sleep(0.1)
except InvalidArgs:


@@ -1,4 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .image_storage import ImageStorageBase
from .events import EventServiceBase
from ....generate import Generate
@@ -9,12 +11,23 @@ class InvocationServices():
generate: Generate # TODO: wrap Generate, or split it up from model?
events: EventServiceBase
images: ImageStorageBase
queue: InvocationQueueABC
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_execution_manager: ItemStorageABC['GraphExecutionState']
processor: 'InvocationProcessorABC'
def __init__(self,
generate: Generate,
events: EventServiceBase,
images: ImageStorageBase
images: ImageStorageBase,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC['GraphExecutionState'],
processor: 'InvocationProcessorABC'
):
self.generate = generate
self.events = events
self.images = images
self.queue = queue
self.graph_execution_manager = graph_execution_manager
self.processor = processor


@@ -9,34 +9,15 @@ from .invocation_services import InvocationServices
from .invocation_queue import InvocationQueueABC, InvocationQueueItem
class InvokerServices:
"""Services used by the Invoker for execution"""
queue: InvocationQueueABC
graph_execution_manager: ItemStorageABC[GraphExecutionState]
processor: 'InvocationProcessorABC'
def __init__(self,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC[GraphExecutionState],
processor: 'InvocationProcessorABC'):
self.queue = queue
self.graph_execution_manager = graph_execution_manager
self.processor = processor
class Invoker:
"""The invoker, used to execute invocations"""
services: InvocationServices
invoker_services: InvokerServices
def __init__(self,
services: InvocationServices, # Services used by nodes to perform invocations
invoker_services: InvokerServices # Services used by the invoker for orchestration
services: InvocationServices
):
self.services = services
self.invoker_services = invoker_services
self._start()
@@ -49,11 +30,11 @@ class Invoker:
return None
# Save the execution state
self.invoker_services.graph_execution_manager.set(graph_execution_state)
self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation
print(f'queueing item {invocation.id}')
self.invoker_services.queue.put(InvocationQueueItem(
self.services.queue.put(InvocationQueueItem(
#session_id = session.id,
graph_execution_state_id = graph_execution_state.id,
invocation_id = invocation.id,
@@ -66,7 +47,7 @@ class Invoker:
def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph = Graph() if graph is None else graph)
self.invoker_services.graph_execution_manager.set(new_state)
self.services.graph_execution_manager.set(new_state)
return new_state
@@ -86,8 +67,8 @@ class Invoker:
def _start(self) -> None:
"""Starts the invoker. This is called automatically when the invoker is created."""
for service in vars(self.invoker_services):
self.__start_service(getattr(self.invoker_services, service))
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
@@ -99,10 +80,10 @@ class Invoker:
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
for service in vars(self.invoker_services):
self.__stop_service(getattr(self.invoker_services, service))
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
self.invoker_services.queue.put(None)
self.services.queue.put(None)
class InvocationProcessorABC(ABC):


@@ -28,11 +28,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event):
try:
while not stop_event.is_set():
queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get()
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
if not queue_item: # Probably stopping
continue
graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id)
graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id)
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
# Send starting event
@@ -52,7 +52,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_execution_state.complete(invocation.id, outputs)
# Save the state changes
self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state)
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
# Send complete event
self.__invoker.services.events.emit_invocation_complete(


@@ -1,10 +1,11 @@
from .test_invoker import create_edge
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest
@@ -19,7 +20,14 @@ def simple_graph():
@pytest.fixture
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(generate = None, events = None, images = None)
return InvocationServices(
generate = None,
events = None,
images = None,
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next()


@@ -2,12 +2,10 @@ from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTe
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invoker import Invoker, InvokerServices
from ldm.invoke.app.services.invoker import Invoker
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest
@@ -22,21 +20,19 @@ def simple_graph():
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(generate = None, events = TestEventService(), images = None)
@pytest.fixture()
def mock_invoker_services() -> InvokerServices:
return InvokerServices(
return InvocationServices(
generate = None,
events = TestEventService(),
images = None,
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker:
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(
services = mock_services,
invoker_services = mock_invoker_services
services = mock_services
)
def test_can_create_graph_state(mock_invoker: Invoker):
@@ -60,13 +56,13 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
assert invocation_id is not None
def has_executed_any(g: GraphExecutionState):
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
g = mock_invoker.services.graph_execution_manager.get(g.id)
return len(g.executed) > 0
wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1)
mock_invoker.stop()
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
@@ -75,11 +71,11 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
assert invocation_id is not None
def has_executed_all(g: GraphExecutionState):
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
g = mock_invoker.services.graph_execution_manager.get(g.id)
return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
mock_invoker.stop()
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete()