2023-08-21 15:23:39 +00:00
|
|
|
import networkx as nx
|
|
|
|
|
2023-07-31 19:45:35 +00:00
|
|
|
from abc import ABC, abstractmethod
|
2023-08-10 15:38:28 +00:00
|
|
|
from itertools import product
|
2023-08-21 15:23:39 +00:00
|
|
|
from pydantic import BaseModel, Field
|
2023-07-31 19:45:35 +00:00
|
|
|
from fastapi_events.handlers.local import local_handler
|
|
|
|
from fastapi_events.typing import Event
|
|
|
|
|
2023-08-21 15:23:39 +00:00
|
|
|
from invokeai.app.services.events import EventServiceBase
|
|
|
|
from invokeai.app.services.graph import Graph, GraphExecutionState
|
|
|
|
from invokeai.app.services.invoker import Invoker
|
2023-08-10 15:38:28 +00:00
|
|
|
from invokeai.app.services.batch_manager_storage import (
|
2023-08-21 15:23:39 +00:00
|
|
|
BatchProcessStorageBase,
|
|
|
|
BatchSessionNotFoundException,
|
2023-08-10 15:38:28 +00:00
|
|
|
Batch,
|
|
|
|
BatchProcess,
|
|
|
|
BatchSession,
|
|
|
|
BatchSessionChanges,
|
|
|
|
)
|
2023-07-31 19:45:35 +00:00
|
|
|
|
2023-08-14 15:01:31 +00:00
|
|
|
|
2023-08-11 15:45:27 +00:00
|
|
|
class BatchProcessResponse(BaseModel):
    # Summary returned for a batch: its id plus the ids of the sessions created for it.
    # (A `#` comment rather than a docstring, so the generated model schema is unchanged.)
    batch_id: str = Field(..., description="ID for the batch")
    session_ids: list[str] = Field(..., description="List of session IDs created for this batch")
|
|
|
|
|
2023-07-31 19:45:35 +00:00
|
|
|
|
|
|
|
class BatchManagerBase(ABC):
    """Interface for the service that creates, runs, cancels and reports on batch processes."""

    @abstractmethod
    def start(self, invoker: Invoker) -> None:
        """Starts the BatchManager service using the given invoker."""
        pass

    @abstractmethod
    def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
        """Creates a batch process for `graph` parameterized by `batch`, returning its id and session ids."""
        pass

    @abstractmethod
    def run_batch_process(self, batch_id: str) -> None:
        """Runs (the next pending session of) the batch process identified by `batch_id`."""
        pass

    @abstractmethod
    def cancel_batch_process(self, batch_process_id: str) -> None:
        """Cancels the batch process identified by `batch_process_id`."""
        pass

    @abstractmethod
    def get_batch(self, batch_id: str) -> BatchProcess:
        """Gets the stored batch process identified by `batch_id`."""
        pass

    @abstractmethod
    def get_batch_processes(self) -> list[BatchProcessResponse]:
        """Gets a summary (id and session ids) of every batch process."""
        pass

    @abstractmethod
    def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
        """Gets a summary (id and session ids) of every incomplete batch process."""
        pass

    @abstractmethod
    def get_sessions(self, batch_id: str) -> list[BatchSession]:
        """Gets the sessions associated with the batch identified by `batch_id`."""
        pass
|
|
|
|
|
2023-07-31 19:45:35 +00:00
|
|
|
|
|
|
|
class BatchManager(BatchManagerBase):
|
|
|
|
"""Responsible for managing currently running and scheduled batch jobs"""
|
2023-08-01 20:41:40 +00:00
|
|
|
|
2023-07-31 19:45:35 +00:00
|
|
|
__invoker: Invoker
|
2023-08-10 15:38:28 +00:00
|
|
|
__batch_process_storage: BatchProcessStorageBase
|
|
|
|
|
|
|
|
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
|
|
|
|
super().__init__()
|
|
|
|
self.__batch_process_storage = batch_process_storage
|
2023-07-31 19:45:35 +00:00
|
|
|
|
2023-08-10 15:38:28 +00:00
|
|
|
def start(self, invoker: Invoker) -> None:
|
2023-07-31 19:45:35 +00:00
|
|
|
self.__invoker = invoker
|
2023-08-01 20:41:40 +00:00
|
|
|
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
2023-07-31 19:45:35 +00:00
|
|
|
|
|
|
|
async def on_event(self, event: Event):
|
|
|
|
event_name = event[1]["event"]
|
|
|
|
|
|
|
|
match event_name:
|
|
|
|
case "graph_execution_state_complete":
|
2023-08-17 02:47:05 +00:00
|
|
|
await self._process(event, False)
|
2023-07-31 19:45:35 +00:00
|
|
|
case "invocation_error":
|
2023-08-17 02:47:05 +00:00
|
|
|
await self._process(event, True)
|
2023-07-31 19:45:35 +00:00
|
|
|
|
|
|
|
return event
|
2023-08-01 20:41:40 +00:00
|
|
|
|
2023-08-17 02:47:05 +00:00
|
|
|
async def _process(self, event: Event, err: bool) -> None:
|
2023-07-31 19:45:35 +00:00
|
|
|
data = event[1]["data"]
|
2023-08-17 03:58:11 +00:00
|
|
|
try:
|
2023-08-21 11:44:33 +00:00
|
|
|
batch_session = self.__batch_process_storage.get_session_by_session_id(data["graph_execution_state_id"])
|
2023-08-17 03:58:11 +00:00
|
|
|
except BatchSessionNotFoundException:
|
2023-08-17 02:07:20 +00:00
|
|
|
return None
|
2023-08-17 03:32:32 +00:00
|
|
|
changes = BatchSessionChanges(state="error" if err else "completed")
|
2023-08-10 15:38:28 +00:00
|
|
|
batch_session = self.__batch_process_storage.update_session_state(
|
|
|
|
batch_session.batch_id,
|
|
|
|
batch_session.session_id,
|
2023-08-17 03:32:32 +00:00
|
|
|
changes,
|
2023-08-10 15:38:28 +00:00
|
|
|
)
|
2023-08-11 19:52:49 +00:00
|
|
|
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
|
|
|
|
if not batch_process.canceled:
|
|
|
|
self.run_batch_process(batch_process.batch_id)
|
2023-08-01 20:41:40 +00:00
|
|
|
|
2023-08-21 15:23:39 +00:00
|
|
|
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: tuple[int]) -> GraphExecutionState:
|
2023-08-15 20:28:47 +00:00
|
|
|
graph = batch_process.graph.copy(deep=True)
|
2023-08-16 19:21:11 +00:00
|
|
|
batch = batch_process.batch
|
2023-07-31 19:45:35 +00:00
|
|
|
g = graph.nx_graph_flat()
|
|
|
|
sorted_nodes = nx.topological_sort(g)
|
|
|
|
for npath in sorted_nodes:
|
|
|
|
node = graph.get_node(npath)
|
2023-08-16 19:21:11 +00:00
|
|
|
for index, bdl in enumerate(batch.data):
|
2023-08-17 02:47:58 +00:00
|
|
|
relevant_bd = [bd for bd in bdl if bd.node_id in node.id]
|
|
|
|
if not relevant_bd:
|
2023-08-16 19:21:11 +00:00
|
|
|
continue
|
2023-08-17 02:47:58 +00:00
|
|
|
for bd in relevant_bd:
|
2023-08-16 19:21:11 +00:00
|
|
|
batch_index = batch_indices[index]
|
|
|
|
datum = bd.items[batch_index]
|
|
|
|
key = bd.field_name
|
|
|
|
node.__dict__[key] = datum
|
2023-08-10 15:38:28 +00:00
|
|
|
graph.update_node(npath, node)
|
2023-07-31 19:45:35 +00:00
|
|
|
|
2023-08-01 20:41:40 +00:00
|
|
|
return GraphExecutionState(graph=graph)
|
2023-07-31 19:45:35 +00:00
|
|
|
|
2023-08-17 02:07:20 +00:00
|
|
|
def run_batch_process(self, batch_id: str) -> None:
|
2023-08-15 20:28:47 +00:00
|
|
|
self.__batch_process_storage.start(batch_id)
|
2023-08-21 15:23:39 +00:00
|
|
|
try:
|
|
|
|
next_session = self.__batch_process_storage.get_next_session(batch_id)
|
|
|
|
except BatchSessionNotFoundException:
|
|
|
|
return
|
2023-08-21 11:44:33 +00:00
|
|
|
batch_process = self.__batch_process_storage.get(batch_id)
|
2023-08-21 15:23:39 +00:00
|
|
|
ges = self._create_batch_session(batch_process=batch_process, batch_indices=tuple(next_session.batch_index))
|
2023-08-21 11:44:33 +00:00
|
|
|
ges.id = next_session.session_id
|
|
|
|
self.__invoker.services.graph_execution_manager.set(ges)
|
|
|
|
self.__batch_process_storage.update_session_state(
|
|
|
|
batch_id=next_session.batch_id,
|
|
|
|
session_id=next_session.session_id,
|
|
|
|
changes=BatchSessionChanges(state="in_progress"),
|
|
|
|
)
|
2023-08-10 15:38:28 +00:00
|
|
|
self.__invoker.invoke(ges, invoke_all=True)
|
2023-08-14 15:01:31 +00:00
|
|
|
|
2023-08-16 19:21:11 +00:00
|
|
|
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
2023-07-31 19:45:35 +00:00
|
|
|
batch_process = BatchProcess(
|
2023-08-16 19:21:11 +00:00
|
|
|
batch=batch,
|
2023-08-01 20:41:40 +00:00
|
|
|
graph=graph,
|
2023-07-31 19:45:35 +00:00
|
|
|
)
|
2023-08-10 15:38:28 +00:00
|
|
|
batch_process = self.__batch_process_storage.save(batch_process)
|
2023-08-21 15:23:39 +00:00
|
|
|
sessions = self._create_sessions(batch_process)
|
2023-08-11 15:45:27 +00:00
|
|
|
return BatchProcessResponse(
|
|
|
|
batch_id=batch_process.batch_id,
|
2023-08-21 15:23:39 +00:00
|
|
|
session_ids=[session.session_id for session in sessions],
|
2023-08-11 15:45:27 +00:00
|
|
|
)
|
2023-08-14 15:01:31 +00:00
|
|
|
|
2023-08-21 15:23:39 +00:00
|
|
|
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
|
|
|
batch_indices = list()
|
|
|
|
sessions_to_create: list[BatchSession] = list()
|
|
|
|
for batchdata in batch_process.batch.data:
|
|
|
|
batch_indices.append(list(range(len(batchdata[0].items))))
|
|
|
|
all_batch_indices = product(*batch_indices)
|
|
|
|
for bi in all_batch_indices:
|
|
|
|
for _ in range(batch_process.batch.runs):
|
|
|
|
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
|
|
|
|
if not sessions_to_create:
|
|
|
|
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
|
|
|
|
created_sessions = self.__batch_process_storage.create_sessions(sessions_to_create)
|
|
|
|
return created_sessions
|
|
|
|
|
2023-08-18 19:38:16 +00:00
|
|
|
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
2023-08-21 11:44:33 +00:00
|
|
|
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
|
2023-08-18 19:38:16 +00:00
|
|
|
|
|
|
|
def get_batch(self, batch_id: str) -> BatchProcess:
|
|
|
|
return self.__batch_process_storage.get(batch_id)
|
|
|
|
|
|
|
|
def get_batch_processes(self) -> list[BatchProcessResponse]:
|
|
|
|
bps = self.__batch_process_storage.get_all()
|
|
|
|
res = list()
|
|
|
|
for bp in bps:
|
2023-08-21 11:44:33 +00:00
|
|
|
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
|
2023-08-18 19:38:16 +00:00
|
|
|
res.append(
|
|
|
|
BatchProcessResponse(
|
|
|
|
batch_id=bp.batch_id,
|
|
|
|
session_ids=[session.session_id for session in sessions],
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return res
|
|
|
|
|
|
|
|
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
|
|
|
|
bps = self.__batch_process_storage.get_incomplete()
|
|
|
|
res = list()
|
|
|
|
for bp in bps:
|
2023-08-21 11:44:33 +00:00
|
|
|
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
|
2023-08-18 19:38:16 +00:00
|
|
|
res.append(
|
|
|
|
BatchProcessResponse(
|
|
|
|
batch_id=bp.batch_id,
|
|
|
|
session_ids=[session.session_id for session in sessions],
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return res
|
|
|
|
|
2023-08-17 02:07:20 +00:00
|
|
|
def cancel_batch_process(self, batch_process_id: str) -> None:
|
2023-08-11 19:52:49 +00:00
|
|
|
self.__batch_process_storage.cancel(batch_process_id)
|