Mirror of https://github.com/invoke-ai/InvokeAI
[nodes] Removed InvokerServices, simplifying service model
parent 34e3aa1f88
commit cd98d88fe7
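
This change folds the old InvokerServices container (invocation queue, graph execution storage, invocation processor) into InvocationServices, so nodes and the invoker share a single services object and Invoker now takes one argument. The sketch below is a minimal illustration of the consolidated wiring, modeled on the updated CLI and test-fixture code in this diff; the None values stand in for the generate/events/images services the sketch does not exercise, and it is not the authoritative setup code.

# Minimal sketch of the post-refactor wiring (modeled on this diff's test fixtures).
from ldm.invoke.app.services.graph import GraphExecutionState
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.invoker import Invoker
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory

# One container now carries both the node-facing services and the
# orchestration services that used to live on InvokerServices.
services = InvocationServices(
    generate = None,   # placeholder: the real CLI/API pass a Generate instance
    events = None,     # placeholder: the real CLI/API pass an event service
    images = None,     # placeholder: the real CLI/API pass DiskImageStorage(...)
    queue = MemoryInvocationQueue(),
    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
        filename = sqlite_memory, table_name = 'graph_executions'),
    processor = DefaultInvocationProcessor()
)

# The invoker is now constructed from that single object
# (previously Invoker(services, invoker_services)).
invoker = Invoker(services)
session = invoker.create_execution_state()
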
@@ -13,7 +13,7 @@ from ...globals import Globals
 from ..services.image_storage import DiskImageStorage
 from ..services.invocation_queue import MemoryInvocationQueue
 from ..services.invocation_services import InvocationServices
-from ..services.invoker import Invoker, InvokerServices
+from ..services.invoker import Invoker
 from ..services.generate_initializer import get_generate
 from .events import FastAPIEventService
 
@@ -60,22 +60,19 @@ class ApiDependencies:
 
         images = DiskImageStorage(output_folder)
 
-        services = InvocationServices(
-            generate = generate,
-            events = events,
-            images = images
-        )
-
         # TODO: build a file/path manager?
         db_location = os.path.join(output_folder, 'invokeai.db')
 
-        invoker_services = InvokerServices(
+        services = InvocationServices(
+            generate = generate,
+            events = events,
+            images = images,
             queue = MemoryInvocationQueue(),
             graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
             processor = DefaultInvocationProcessor()
         )
 
-        ApiDependencies.invoker = Invoker(services, invoker_services)
+        ApiDependencies.invoker = Invoker(services)
 
     @staticmethod
     def shutdown():
 
@@ -44,9 +44,9 @@ async def list_sessions(
 ) -> PaginatedResults[GraphExecutionState]:
     """Gets a list of sessions, optionally searching"""
     if filter == '':
-        result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page)
+        result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
     else:
-        result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page)
+        result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
     return result
 
 
@@ -60,7 +60,7 @@ async def get_session(
     session_id: str = Path(description = "The id of the session to get")
 ) -> GraphExecutionState:
     """Gets a session"""
-    session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
+    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
         return Response(status_code = 404)
     else:
@@ -80,13 +80,13 @@ async def add_node(
     node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add")
 ) -> str:
     """Adds a node to the graph"""
-    session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
+    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
         return Response(status_code = 404)
 
     try:
         session.add_node(node)
-        ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
+        ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
         return session.id
     except NodeAlreadyExecutedError:
         return Response(status_code = 400)
@@ -108,13 +108,13 @@ async def update_node(
     node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node")
 ) -> GraphExecutionState:
     """Updates a node in the graph and removes all linked edges"""
-    session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
+    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
         return Response(status_code = 404)
 
     try:
         session.update_node(node_path, node)
-        ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
+        ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
         return Response(status_code = 400)
@@ -135,13 +135,13 @@ async def delete_node(
     node_path: str = Path(description = "The path to the node to delete")
 ) -> GraphExecutionState:
     """Deletes a node in the graph and removes all linked edges"""
-    session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
+    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
         return Response(status_code = 404)
 
     try:
         session.delete_node(node_path)
-        ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
+        ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
         return Response(status_code = 400)
@@ -162,13 +162,13 @@ async def add_edge(
     edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add")
 ) -> GraphExecutionState:
     """Adds an edge to the graph"""
-    session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
+    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
         return Response(status_code = 404)
 
     try:
         session.add_edge(edge)
-        ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
+        ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
         return Response(status_code = 400)
@@ -193,14 +193,14 @@ async def delete_edge(
     to_field: str = Path(description = "The field of the node the edge is going to")
 ) -> GraphExecutionState:
     """Deletes an edge from the graph"""
-    session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
+    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
         return Response(status_code = 404)
 
     try:
         edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field))
         session.delete_edge(edge)
-        ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
+        ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
         return Response(status_code = 400)
@@ -221,7 +221,7 @@ async def invoke_session(
     all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations")
 ) -> None:
     """Invokes a session"""
-    session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
+    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
         return Response(status_code = 404)
 
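
Every session route above changes in the same way: the execution state is read and written through ApiDependencies.invoker.services.graph_execution_manager instead of the removed invoker_services attribute. The toy sketch below illustrates that shared get / mutate / set round trip with an in-memory stand-in for the storage; Session and FakeStore are illustrative names only, not InvokeAI classes.

# Toy illustration of the get / mutate / set pattern the routes share.
from dataclasses import dataclass, field

@dataclass
class Session:
    id: str
    nodes: list = field(default_factory = list)

class FakeStore:
    """In-memory stand-in for the graph execution manager used by the routes."""
    def __init__(self):
        self._items = {}
    def get(self, item_id):
        return self._items.get(item_id)
    def set(self, item):
        self._items[item.id] = item

store = FakeStore()
store.set(Session(id = 'abc'))

session = store.get('abc')       # fetch the execution state
session.nodes.append('txt2img')  # mutate it (add_node / add_edge / etc.)
store.set(session)               # persist the change back to storage
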
@@ -20,7 +20,7 @@ from .services.image_storage import DiskImageStorage
 from .services.invocation_queue import MemoryInvocationQueue
 from .invocations.baseinvocation import BaseInvocation
 from .services.invocation_services import InvocationServices
-from .services.invoker import Invoker, InvokerServices
+from .services.invoker import Invoker
 from .invocations import *
 from ..args import Args
 from .services.events import EventServiceBase
@@ -171,28 +171,25 @@ def invoke_cli():
 
     output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))
 
-    services = InvocationServices(
-        generate = generate,
-        events = events,
-        images = DiskImageStorage(output_folder)
-    )
-
     # TODO: build a file/path manager?
     db_location = os.path.join(output_folder, 'invokeai.db')
 
-    invoker_services = InvokerServices(
+    services = InvocationServices(
+        generate = generate,
+        events = events,
+        images = DiskImageStorage(output_folder),
        queue = MemoryInvocationQueue(),
        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
        processor = DefaultInvocationProcessor()
    )
 
-    invoker = Invoker(services, invoker_services)
+    invoker = Invoker(services)
     session = invoker.create_execution_state()
 
     parser = get_invocation_parser()
 
     # Uncomment to print out previous sessions at startup
-    # print(invoker_services.session_manager.list())
+    # print(services.session_manager.list())
 
     # Defaults storage
     defaults: Dict[str, Any] = dict()
@@ -213,7 +210,7 @@ def invoke_cli():
 
         try:
             # Refresh the state of the session
-            session = invoker.invoker_services.graph_execution_manager.get(session.id)
+            session = invoker.services.graph_execution_manager.get(session.id)
             history = list(get_graph_execution_history(session))
 
             # Split the command for piping
@@ -289,7 +286,7 @@ def invoke_cli():
             invoker.invoke(session, invoke_all = True)
             while not session.is_complete():
                 # Wait some time
-                session = invoker.invoker_services.graph_execution_manager.get(session.id)
+                session = invoker.services.graph_execution_manager.get(session.id)
                 time.sleep(0.1)
 
         except InvalidArgs:
 
@@ -1,4 +1,6 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+from .invocation_queue import InvocationQueueABC
+from .item_storage import ItemStorageABC
 from .image_storage import ImageStorageBase
 from .events import EventServiceBase
 from ....generate import Generate
@@ -9,12 +11,23 @@ class InvocationServices():
     generate: Generate # TODO: wrap Generate, or split it up from model?
     events: EventServiceBase
     images: ImageStorageBase
+    queue: InvocationQueueABC
+
+    # NOTE: we must forward-declare any types that include invocations, since invocations can use services
+    graph_execution_manager: ItemStorageABC['GraphExecutionState']
+    processor: 'InvocationProcessorABC'
 
     def __init__(self,
         generate: Generate,
         events: EventServiceBase,
-        images: ImageStorageBase
+        images: ImageStorageBase,
+        queue: InvocationQueueABC,
+        graph_execution_manager: ItemStorageABC['GraphExecutionState'],
+        processor: 'InvocationProcessorABC'
     ):
         self.generate = generate
         self.events = events
         self.images = images
+        self.queue = queue
+        self.graph_execution_manager = graph_execution_manager
+        self.processor = processor
 
@@ -9,34 +9,15 @@ from .invocation_services import InvocationServices
 from .invocation_queue import InvocationQueueABC, InvocationQueueItem
 
 
-class InvokerServices:
-    """Services used by the Invoker for execution"""
-
-    queue: InvocationQueueABC
-    graph_execution_manager: ItemStorageABC[GraphExecutionState]
-    processor: 'InvocationProcessorABC'
-
-    def __init__(self,
-        queue: InvocationQueueABC,
-        graph_execution_manager: ItemStorageABC[GraphExecutionState],
-        processor: 'InvocationProcessorABC'):
-        self.queue = queue
-        self.graph_execution_manager = graph_execution_manager
-        self.processor = processor
-
-
 class Invoker:
     """The invoker, used to execute invocations"""
 
     services: InvocationServices
-    invoker_services: InvokerServices
 
     def __init__(self,
-        services: InvocationServices, # Services used by nodes to perform invocations
-        invoker_services: InvokerServices # Services used by the invoker for orchestration
+        services: InvocationServices
     ):
         self.services = services
-        self.invoker_services = invoker_services
         self._start()
 
 
@@ -49,11 +30,11 @@ class Invoker:
             return None
 
         # Save the execution state
-        self.invoker_services.graph_execution_manager.set(graph_execution_state)
+        self.services.graph_execution_manager.set(graph_execution_state)
 
         # Queue the invocation
         print(f'queueing item {invocation.id}')
-        self.invoker_services.queue.put(InvocationQueueItem(
+        self.services.queue.put(InvocationQueueItem(
             #session_id = session.id,
             graph_execution_state_id = graph_execution_state.id,
             invocation_id = invocation.id,
@@ -66,7 +47,7 @@ class Invoker:
     def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
         """Creates a new execution state for the given graph"""
         new_state = GraphExecutionState(graph = Graph() if graph is None else graph)
-        self.invoker_services.graph_execution_manager.set(new_state)
+        self.services.graph_execution_manager.set(new_state)
         return new_state
 
 
@@ -86,8 +67,8 @@ class Invoker:
 
     def _start(self) -> None:
         """Starts the invoker. This is called automatically when the invoker is created."""
-        for service in vars(self.invoker_services):
-            self.__start_service(getattr(self.invoker_services, service))
+        for service in vars(self.services):
+            self.__start_service(getattr(self.services, service))
 
         for service in vars(self.services):
             self.__start_service(getattr(self.services, service))
@@ -99,10 +80,10 @@ class Invoker:
         for service in vars(self.services):
             self.__stop_service(getattr(self.services, service))
 
-        for service in vars(self.invoker_services):
-            self.__stop_service(getattr(self.invoker_services, service))
+        for service in vars(self.services):
+            self.__stop_service(getattr(self.services, service))
 
-        self.invoker_services.queue.put(None)
+        self.services.queue.put(None)
 
 
 class InvocationProcessorABC(ABC):
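
With only one services container left, the _start and stop logic above no longer needs a second pass over invoker_services; it simply walks the attributes of self.services. Below is a self-contained toy of that attribute walk; StubService, StubServices, and start_all are hypothetical stand-ins for illustration, not the real Invoker internals.

# Toy version of the attribute walk used by Invoker._start / Invoker.stop.
class StubService:
    def start(self, invoker) -> None:
        print(f'started {type(self).__name__}')

class StubServices:
    """Single container, mirroring the consolidated InvocationServices."""
    def __init__(self):
        self.queue = StubService()
        self.graph_execution_manager = StubService()
        self.processor = StubService()

def start_all(services: StubServices, invoker = None) -> None:
    # One loop over one object, instead of separate loops over
    # services and invoker_services.
    for name in vars(services):
        service = getattr(services, name)
        if hasattr(service, 'start'):
            service.start(invoker)

start_all(StubServices())
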
@@ -28,11 +28,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
     def __process(self, stop_event: Event):
         try:
             while not stop_event.is_set():
-                queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get()
+                queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                 if not queue_item: # Probably stopping
                     continue
 
-                graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id)
+                graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id)
                 invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
 
                 # Send starting event
@@ -52,7 +52,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 graph_execution_state.complete(invocation.id, outputs)
 
                 # Save the state changes
-                self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state)
+                self.__invoker.services.graph_execution_manager.set(graph_execution_state)
 
                 # Send complete event
                 self.__invoker.services.events.emit_invocation_complete(
 
@@ -1,10 +1,11 @@
 from .test_invoker import create_edge
 from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
 from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
+from ldm.invoke.app.services.processor import DefaultInvocationProcessor
+from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
+from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
 from ldm.invoke.app.services.invocation_services import InvocationServices
 from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
-from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
-from ldm.invoke.app.invocations.upscale import UpscaleInvocation
 import pytest
 
 
@@ -19,7 +20,14 @@ def simple_graph():
 @pytest.fixture
 def mock_services():
     # NOTE: none of these are actually called by the test invocations
-    return InvocationServices(generate = None, events = None, images = None)
+    return InvocationServices(
+        generate = None,
+        events = None,
+        images = None,
+        queue = MemoryInvocationQueue(),
+        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
+        processor = DefaultInvocationProcessor()
+    )
 
 def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
     n = g.next()
 
@@ -2,12 +2,10 @@ from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTe
 from ldm.invoke.app.services.processor import DefaultInvocationProcessor
 from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
 from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
-from ldm.invoke.app.services.invoker import Invoker, InvokerServices
+from ldm.invoke.app.services.invoker import Invoker
 from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
 from ldm.invoke.app.services.invocation_services import InvocationServices
 from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
-from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
-from ldm.invoke.app.invocations.upscale import UpscaleInvocation
 import pytest
 
 
@@ -22,21 +20,19 @@ def simple_graph():
 @pytest.fixture
 def mock_services() -> InvocationServices:
     # NOTE: none of these are actually called by the test invocations
-    return InvocationServices(generate = None, events = TestEventService(), images = None)
-
-@pytest.fixture()
-def mock_invoker_services() -> InvokerServices:
-    return InvokerServices(
+    return InvocationServices(
+        generate = None,
+        events = TestEventService(),
+        images = None,
         queue = MemoryInvocationQueue(),
         graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
         processor = DefaultInvocationProcessor()
     )
 
 @pytest.fixture()
-def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker:
+def mock_invoker(mock_services: InvocationServices) -> Invoker:
     return Invoker(
-        services = mock_services,
-        invoker_services = mock_invoker_services
+        services = mock_services
     )
 
 def test_can_create_graph_state(mock_invoker: Invoker):
@@ -60,13 +56,13 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
     assert invocation_id is not None
 
     def has_executed_any(g: GraphExecutionState):
-        g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
+        g = mock_invoker.services.graph_execution_manager.get(g.id)
         return len(g.executed) > 0
 
     wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1)
     mock_invoker.stop()
 
-    g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
+    g = mock_invoker.services.graph_execution_manager.get(g.id)
     assert len(g.executed) > 0
 
 def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
@@ -75,11 +71,11 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
     assert invocation_id is not None
 
     def has_executed_all(g: GraphExecutionState):
-        g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
+        g = mock_invoker.services.graph_execution_manager.get(g.id)
         return g.is_complete()
 
     wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
     mock_invoker.stop()
 
-    g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
+    g = mock_invoker.services.graph_execution_manager.get(g.id)
     assert g.is_complete()