mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[nodes] Removed InvokerServices, simplying service model
This commit is contained in:
parent
34e3aa1f88
commit
cd98d88fe7
@ -13,7 +13,7 @@ from ...globals import Globals
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker, InvokerServices
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.generate_initializer import get_generate
|
||||
from .events import FastAPIEventService
|
||||
|
||||
@ -60,22 +60,19 @@ class ApiDependencies:
|
||||
|
||||
images = DiskImageStorage(output_folder)
|
||||
|
||||
services = InvocationServices(
|
||||
generate = generate,
|
||||
events = events,
|
||||
images = images
|
||||
)
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, 'invokeai.db')
|
||||
|
||||
invoker_services = InvokerServices(
|
||||
services = InvocationServices(
|
||||
generate = generate,
|
||||
events = events,
|
||||
images = images,
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor()
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services, invoker_services)
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@staticmethod
|
||||
def shutdown():
|
||||
|
@ -44,9 +44,9 @@ async def list_sessions(
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if filter == '':
|
||||
result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page)
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
else:
|
||||
result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page)
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
return result
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ async def get_session(
|
||||
session_id: str = Path(description = "The id of the session to get")
|
||||
) -> GraphExecutionState:
|
||||
"""Gets a session"""
|
||||
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
else:
|
||||
@ -80,13 +80,13 @@ async def add_node(
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add")
|
||||
) -> str:
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.add_node(node)
|
||||
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session.id
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
@ -108,13 +108,13 @@ async def update_node(
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node")
|
||||
) -> GraphExecutionState:
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.update_node(node_path, node)
|
||||
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
@ -135,13 +135,13 @@ async def delete_node(
|
||||
node_path: str = Path(description = "The path to the node to delete")
|
||||
) -> GraphExecutionState:
|
||||
"""Deletes a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.delete_node(node_path)
|
||||
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
@ -162,13 +162,13 @@ async def add_edge(
|
||||
edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add")
|
||||
) -> GraphExecutionState:
|
||||
"""Adds an edge to the graph"""
|
||||
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.add_edge(edge)
|
||||
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
@ -193,14 +193,14 @@ async def delete_edge(
|
||||
to_field: str = Path(description = "The field of the node the edge is going to")
|
||||
) -> GraphExecutionState:
|
||||
"""Deletes an edge from the graph"""
|
||||
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field))
|
||||
session.delete_edge(edge)
|
||||
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
@ -221,7 +221,7 @@ async def invoke_session(
|
||||
all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations")
|
||||
) -> None:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
|
@ -20,7 +20,7 @@ from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker, InvokerServices
|
||||
from .services.invoker import Invoker
|
||||
from .invocations import *
|
||||
from ..args import Args
|
||||
from .services.events import EventServiceBase
|
||||
@ -171,28 +171,25 @@ def invoke_cli():
|
||||
|
||||
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))
|
||||
|
||||
services = InvocationServices(
|
||||
generate = generate,
|
||||
events = events,
|
||||
images = DiskImageStorage(output_folder)
|
||||
)
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, 'invokeai.db')
|
||||
|
||||
invoker_services = InvokerServices(
|
||||
services = InvocationServices(
|
||||
generate = generate,
|
||||
events = events,
|
||||
images = DiskImageStorage(output_folder),
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor()
|
||||
)
|
||||
|
||||
invoker = Invoker(services, invoker_services)
|
||||
invoker = Invoker(services)
|
||||
session = invoker.create_execution_state()
|
||||
|
||||
parser = get_invocation_parser()
|
||||
|
||||
# Uncomment to print out previous sessions at startup
|
||||
# print(invoker_services.session_manager.list())
|
||||
# print(services.session_manager.list())
|
||||
|
||||
# Defaults storage
|
||||
defaults: Dict[str, Any] = dict()
|
||||
@ -213,7 +210,7 @@ def invoke_cli():
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
session = invoker.invoker_services.graph_execution_manager.get(session.id)
|
||||
session = invoker.services.graph_execution_manager.get(session.id)
|
||||
history = list(get_graph_execution_history(session))
|
||||
|
||||
# Split the command for piping
|
||||
@ -289,7 +286,7 @@ def invoke_cli():
|
||||
invoker.invoke(session, invoke_all = True)
|
||||
while not session.is_complete():
|
||||
# Wait some time
|
||||
session = invoker.invoker_services.graph_execution_manager.get(session.id)
|
||||
session = invoker.services.graph_execution_manager.get(session.id)
|
||||
time.sleep(0.1)
|
||||
|
||||
except InvalidArgs:
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
from .image_storage import ImageStorageBase
|
||||
from .events import EventServiceBase
|
||||
from ....generate import Generate
|
||||
@ -9,12 +11,23 @@ class InvocationServices():
|
||||
generate: Generate # TODO: wrap Generate, or split it up from model?
|
||||
events: EventServiceBase
|
||||
images: ImageStorageBase
|
||||
queue: InvocationQueueABC
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_execution_manager: ItemStorageABC['GraphExecutionState']
|
||||
processor: 'InvocationProcessorABC'
|
||||
|
||||
def __init__(self,
|
||||
generate: Generate,
|
||||
events: EventServiceBase,
|
||||
images: ImageStorageBase
|
||||
images: ImageStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_execution_manager: ItemStorageABC['GraphExecutionState'],
|
||||
processor: 'InvocationProcessorABC'
|
||||
):
|
||||
self.generate = generate
|
||||
self.events = events
|
||||
self.images = images
|
||||
self.queue = queue
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
|
@ -9,34 +9,15 @@ from .invocation_services import InvocationServices
|
||||
from .invocation_queue import InvocationQueueABC, InvocationQueueItem
|
||||
|
||||
|
||||
class InvokerServices:
|
||||
"""Services used by the Invoker for execution"""
|
||||
|
||||
queue: InvocationQueueABC
|
||||
graph_execution_manager: ItemStorageABC[GraphExecutionState]
|
||||
processor: 'InvocationProcessorABC'
|
||||
|
||||
def __init__(self,
|
||||
queue: InvocationQueueABC,
|
||||
graph_execution_manager: ItemStorageABC[GraphExecutionState],
|
||||
processor: 'InvocationProcessorABC'):
|
||||
self.queue = queue
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
|
||||
|
||||
class Invoker:
|
||||
"""The invoker, used to execute invocations"""
|
||||
|
||||
services: InvocationServices
|
||||
invoker_services: InvokerServices
|
||||
|
||||
def __init__(self,
|
||||
services: InvocationServices, # Services used by nodes to perform invocations
|
||||
invoker_services: InvokerServices # Services used by the invoker for orchestration
|
||||
services: InvocationServices
|
||||
):
|
||||
self.services = services
|
||||
self.invoker_services = invoker_services
|
||||
self._start()
|
||||
|
||||
|
||||
@ -49,11 +30,11 @@ class Invoker:
|
||||
return None
|
||||
|
||||
# Save the execution state
|
||||
self.invoker_services.graph_execution_manager.set(graph_execution_state)
|
||||
self.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Queue the invocation
|
||||
print(f'queueing item {invocation.id}')
|
||||
self.invoker_services.queue.put(InvocationQueueItem(
|
||||
self.services.queue.put(InvocationQueueItem(
|
||||
#session_id = session.id,
|
||||
graph_execution_state_id = graph_execution_state.id,
|
||||
invocation_id = invocation.id,
|
||||
@ -66,7 +47,7 @@ class Invoker:
|
||||
def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
|
||||
"""Creates a new execution state for the given graph"""
|
||||
new_state = GraphExecutionState(graph = Graph() if graph is None else graph)
|
||||
self.invoker_services.graph_execution_manager.set(new_state)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
return new_state
|
||||
|
||||
|
||||
@ -86,8 +67,8 @@ class Invoker:
|
||||
|
||||
def _start(self) -> None:
|
||||
"""Starts the invoker. This is called automatically when the invoker is created."""
|
||||
for service in vars(self.invoker_services):
|
||||
self.__start_service(getattr(self.invoker_services, service))
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
@ -99,10 +80,10 @@ class Invoker:
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.invoker_services):
|
||||
self.__stop_service(getattr(self.invoker_services, service))
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
self.invoker_services.queue.put(None)
|
||||
self.services.queue.put(None)
|
||||
|
||||
|
||||
class InvocationProcessorABC(ABC):
|
||||
|
@ -28,11 +28,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get()
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
if not queue_item: # Probably stopping
|
||||
continue
|
||||
|
||||
graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id)
|
||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id)
|
||||
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
|
||||
|
||||
# Send starting event
|
||||
@ -52,7 +52,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
graph_execution_state.complete(invocation.id, outputs)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state)
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
|
@ -1,10 +1,11 @@
|
||||
from .test_invoker import create_edge
|
||||
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
|
||||
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
|
||||
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from ldm.invoke.app.services.invocation_services import InvocationServices
|
||||
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
||||
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
|
||||
import pytest
|
||||
|
||||
|
||||
@ -19,7 +20,14 @@ def simple_graph():
|
||||
@pytest.fixture
|
||||
def mock_services():
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
return InvocationServices(generate = None, events = None, images = None)
|
||||
return InvocationServices(
|
||||
generate = None,
|
||||
events = None,
|
||||
images = None,
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor()
|
||||
)
|
||||
|
||||
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
||||
n = g.next()
|
||||
|
@ -2,12 +2,10 @@ from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTe
|
||||
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
|
||||
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from ldm.invoke.app.services.invoker import Invoker, InvokerServices
|
||||
from ldm.invoke.app.services.invoker import Invoker
|
||||
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from ldm.invoke.app.services.invocation_services import InvocationServices
|
||||
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
||||
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
|
||||
import pytest
|
||||
|
||||
|
||||
@ -22,21 +20,19 @@ def simple_graph():
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
return InvocationServices(generate = None, events = TestEventService(), images = None)
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_invoker_services() -> InvokerServices:
|
||||
return InvokerServices(
|
||||
return InvocationServices(
|
||||
generate = None,
|
||||
events = TestEventService(),
|
||||
images = None,
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor()
|
||||
)
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker:
|
||||
def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
||||
return Invoker(
|
||||
services = mock_services,
|
||||
invoker_services = mock_invoker_services
|
||||
services = mock_services
|
||||
)
|
||||
|
||||
def test_can_create_graph_state(mock_invoker: Invoker):
|
||||
@ -60,13 +56,13 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
|
||||
assert invocation_id is not None
|
||||
|
||||
def has_executed_any(g: GraphExecutionState):
|
||||
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return len(g.executed) > 0
|
||||
|
||||
wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert len(g.executed) > 0
|
||||
|
||||
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
||||
@ -75,11 +71,11 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
||||
assert invocation_id is not None
|
||||
|
||||
def has_executed_all(g: GraphExecutionState):
|
||||
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return g.is_complete()
|
||||
|
||||
wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert g.is_complete()
|
||||
|
Loading…
Reference in New Issue
Block a user