mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[nodes] Removed InvokerServices, simplying service model
This commit is contained in:
@ -1,10 +1,11 @@
|
||||
from .test_invoker import create_edge
|
||||
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
|
||||
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
|
||||
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from ldm.invoke.app.services.invocation_services import InvocationServices
|
||||
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
||||
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
|
||||
import pytest
|
||||
|
||||
|
||||
@ -19,7 +20,14 @@ def simple_graph():
|
||||
@pytest.fixture
|
||||
def mock_services():
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
return InvocationServices(generate = None, events = None, images = None)
|
||||
return InvocationServices(
|
||||
generate = None,
|
||||
events = None,
|
||||
images = None,
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor()
|
||||
)
|
||||
|
||||
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
||||
n = g.next()
|
||||
|
@ -2,12 +2,10 @@ from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTe
|
||||
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
|
||||
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from ldm.invoke.app.services.invoker import Invoker, InvokerServices
|
||||
from ldm.invoke.app.services.invoker import Invoker
|
||||
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from ldm.invoke.app.services.invocation_services import InvocationServices
|
||||
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
||||
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
|
||||
import pytest
|
||||
|
||||
|
||||
@ -22,21 +20,19 @@ def simple_graph():
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
return InvocationServices(generate = None, events = TestEventService(), images = None)
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_invoker_services() -> InvokerServices:
|
||||
return InvokerServices(
|
||||
return InvocationServices(
|
||||
generate = None,
|
||||
events = TestEventService(),
|
||||
images = None,
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor()
|
||||
)
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker:
|
||||
def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
||||
return Invoker(
|
||||
services = mock_services,
|
||||
invoker_services = mock_invoker_services
|
||||
services = mock_services
|
||||
)
|
||||
|
||||
def test_can_create_graph_state(mock_invoker: Invoker):
|
||||
@ -60,13 +56,13 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
|
||||
assert invocation_id is not None
|
||||
|
||||
def has_executed_any(g: GraphExecutionState):
|
||||
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return len(g.executed) > 0
|
||||
|
||||
wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert len(g.executed) > 0
|
||||
|
||||
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
||||
@ -75,11 +71,11 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
||||
assert invocation_id is not None
|
||||
|
||||
def has_executed_all(g: GraphExecutionState):
|
||||
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return g.is_complete()
|
||||
|
||||
wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert g.is_complete()
|
||||
|
Reference in New Issue
Block a user