diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 8d76b4ab15..83fc8d9e11 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -31,6 +31,7 @@ from ..services.processor import DefaultInvocationProcessor from ..services.sqlite import SqliteItemStorage from ..services.model_manager_service import ModelManagerService from ..services.batch_manager import BatchManager +from ..services.batch_manager_storage import SqliteBatchProcessStorage from .events import FastAPIEventService @@ -117,7 +118,8 @@ class ApiDependencies: ) ) - batch_manager = BatchManager() + batch_manager_storage = SqliteBatchProcessStorage(db_location) + batch_manager = BatchManager(batch_manager_storage) services = InvocationServices( model_manager=ModelManagerService(config, logger), diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index 5f6228a1e3..f2acc388c3 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -51,8 +51,9 @@ async def create_batch( batches: list[Batch] = Body(description="Batch config to apply to the given graph"), ) -> BatchProcess: """Creates and starts a new new batch process""" - session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph) - return session + batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph) + ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id) + return {"batch_id":batch_id} @session_router.delete( diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index 8acd5f9a7b..eba5a8d676 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -1,45 +1,22 @@ import networkx as nx -import uuid import copy from abc import ABC, abstractmethod +from itertools import product from pydantic import BaseModel, Field from fastapi_events.handlers.local 
import local_handler from fastapi_events.typing import Event -from typing import ( - Optional, - Union, -) -from invokeai.app.invocations.baseinvocation import ( - BaseInvocation, -) from invokeai.app.services.events import EventServiceBase from invokeai.app.services.graph import Graph, GraphExecutionState from invokeai.app.services.invoker import Invoker - - -InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore - - -class Batch(BaseModel): - data: list[InvocationsUnion] = Field(description="Mapping of ") - node_id: str = Field(description="ID of the node to batch") - - -class BatchProcess(BaseModel): - batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch") - sessions: list[str] = Field( - description="Tracker for which batch is currently being processed", default_factory=list - ) - batches: list[Batch] = Field( - description="List of batch configs to apply to this session", - default_factory=list, - ) - batch_indices: list[int] = Field( - description="Tracker for which batch is currently being processed", default_factory=list - ) - graph: Graph = Field(description="The graph being executed") +from invokeai.app.services.batch_manager_storage import ( + BatchProcessStorageBase, + Batch, + BatchProcess, + BatchSession, + BatchSessionChanges, +) class BatchManagerBase(ABC): @@ -48,7 +25,11 @@ class BatchManagerBase(ABC): pass @abstractmethod - def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess: + def create_batch_process(self, batches: list[Batch], graph: Graph) -> str: + pass + + @abstractmethod + def run_batch_process(self, batch_id: str): pass @abstractmethod @@ -61,8 +42,13 @@ class BatchManager(BatchManagerBase): __invoker: Invoker __batches: list[BatchProcess] + __batch_process_storage: BatchProcessStorageBase - def start(self, invoker) -> None: + def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None: + super().__init__() + 
self.__batch_process_storage = batch_process_storage + + def start(self, invoker: Invoker) -> None: # if we do want multithreading at some point, we could make this configurable self.__invoker = invoker self.__batches = list() @@ -73,34 +59,28 @@ class BatchManager(BatchManagerBase): match event_name: case "graph_execution_state_complete": - await self.process(event) + await self.process(event, False) case "invocation_error": - await self.process(event) + await self.process(event, True) return event - async def process(self, event: Event): + async def process(self, event: Event, err: bool): data = event[1]["data"] - batchTarget = None - for batch in self.__batches: - if data["graph_execution_state_id"] in batch.sessions: - batchTarget = batch - break - - if batchTarget == None: + batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"]) + if not batch_session: return + updateSession = BatchSessionChanges( + state='error' if err else 'completed' + ) + batch_session = self.__batch_process_storage.update_session_state( + batch_session.batch_id, + batch_session.session_id, + updateSession, + ) + self.run_batch_process(batch_session.batch_id) - if sum(batchTarget.batch_indices) == 0: - self.__batches = [batch for batch in self.__batches if batch != batchTarget] - return - - batchTarget.batch_indices = self._next_batch_index(batchTarget) - ges = self._next_batch_session(batchTarget) - batchTarget.sessions.append(ges.id) - self.__invoker.services.graph_execution_manager.set(ges) - self.__invoker.invoke(ges, invoke_all=True) - - def _next_batch_session(self, batch_process: BatchProcess) -> GraphExecutionState: + def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState: graph = copy.deepcopy(batch_process.graph) batches = batch_process.batches g = graph.nx_graph_flat() @@ -109,36 +89,47 @@ class BatchManager(BatchManagerBase): node = graph.get_node(npath) (index, batch) = next(((i, b) for 
i, b in enumerate(batches) if b.node_id in node.id), (None, None)) if batch: - batch_index = batch_process.batch_indices[index] + batch_index = batch_indices[index] datum = batch.data[batch_index] - datum.id = node.id - graph.update_node(npath, datum) + for key in datum: + node.__dict__[key] = datum[key] + graph.update_node(npath, node) return GraphExecutionState(graph=graph) - def _next_batch_index(self, batch_process: BatchProcess): - batch_indicies = batch_process.batch_indices.copy() - for index in range(len(batch_indicies)): - if batch_indicies[index] > 0: - batch_indicies[index] -= 1 - break - return batch_indicies - - def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess: - batch_indices = list() - for batch in batches: - batch_indices.append(len(batch.data) - 1) + def run_batch_process(self, batch_id: str): + created_session = self.__batch_process_storage.get_created_session(batch_id) + ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id) + self.__invoker.invoke(ges, invoke_all=True) + + def _valid_batch_config(self, batch_process: BatchProcess) -> bool: + return True + + def create_batch_process(self, batches: list[Batch], graph: Graph) -> str: batch_process = BatchProcess( batches=batches, - batch_indices=batch_indices, graph=graph, ) - ges = self._next_batch_session(batch_process) - batch_process.sessions.append(ges.id) - self.__batches.append(batch_process) - self.__invoker.services.graph_execution_manager.set(ges) - self.__invoker.invoke(ges, invoke_all=True) - return batch_process + if not self._valid_batch_config(batch_process): + return None + batch_process = self.__batch_process_storage.save(batch_process) + self._create_sessions(batch_process) + return batch_process.batch_id + + def _create_sessions(self, batch_process: BatchProcess): + batch_indices = list() + for batch in batch_process.batches: + batch_indices.append(list(range(len(batch.data)))) + all_batch_indices = 
product(*batch_indices) + for bi in all_batch_indices: + ges = self._create_batch_session(batch_process, bi) + self.__invoker.services.graph_execution_manager.set(ges) + batch_session = BatchSession( + batch_id=batch_process.batch_id, + session_id=ges.id, + state="created" + ) + self.__batch_process_storage.create_session(batch_session) def cancel_batch_process(self, batch_process_id: str): self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id] diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py new file mode 100644 index 0000000000..44fb5d928f --- /dev/null +++ b/invokeai/app/services/batch_manager_storage.py @@ -0,0 +1,500 @@ +from abc import ABC, abstractmethod +from typing import cast +import uuid +import sqlite3 +import threading +from typing import ( + Any, + List, + Literal, + Optional, + Union, +) +import json + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, +) +from invokeai.app.services.graph import Graph, GraphExecutionState +from invokeai.app.models.image import ImageField +from invokeai.app.services.image_record_storage import OffsetPaginatedResults + +from pydantic import BaseModel, Field, Extra, parse_raw_as + +invocations = BaseInvocation.get_invocations() +InvocationsUnion = Union[invocations] # type: ignore + +BatchDataType = Union[str, int, float, ImageField] + +class Batch(BaseModel): + data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value") + node_id: str = Field(description="ID of the node to batch") + + +class BatchSession(BaseModel): + batch_id: str = Field(description="Identifier for which batch this Index belongs to") + session_id: str = Field(description="Session ID Created for this Batch Index") + state: Literal["created", "completed", "inprogress", "error"] = Field( + description="Is this session created, completed, in progress, or errored?" 
class BatchProcess(BaseModel):
    # FIX: the previous default_factory was `uuid.uuid4().__str__`, which calls
    # uuid4() ONCE at class-definition time and then reuses that single id for
    # every BatchProcess instance. A lambda generates a fresh id per instance.
    batch_id: Optional[str] = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Identifier for this batch",
    )
    batches: List[Batch] = Field(
        description="List of batch configs to apply to this session",
        default_factory=list,
    )
    graph: Graph = Field(description="The graph being executed")


class BatchSessionChanges(BaseModel, extra=Extra.forbid):
    # Partial-update payload for a batch session; only the state may change.
    state: Literal["created", "completed", "inprogress", "error"] = Field(
        description="Is this session created, completed, in progress, or errored?"
    )


class BatchProcessNotFoundException(Exception):
    """Raised when a Batch Process record is not found."""

    def __init__(self, message="BatchProcess record not found"):
        super().__init__(message)


class BatchProcessSaveException(Exception):
    """Raised when a Batch Process record cannot be saved."""

    def __init__(self, message="BatchProcess record not saved"):
        super().__init__(message)


class BatchProcessDeleteException(Exception):
    """Raised when a Batch Process record cannot be deleted."""

    def __init__(self, message="BatchProcess record not deleted"):
        super().__init__(message)


class BatchSessionNotFoundException(Exception):
    """Raised when a Batch Session record is not found."""

    def __init__(self, message="BatchSession record not found"):
        super().__init__(message)


class BatchSessionSaveException(Exception):
    """Raised when a Batch Session record cannot be saved."""

    def __init__(self, message="BatchSession record not saved"):
        super().__init__(message)


class BatchSessionDeleteException(Exception):
    """Raised when a Batch Session record cannot be deleted."""

    def __init__(self, message="BatchSession record not deleted"):
        super().__init__(message)
class SqliteBatchProcessStorage(BatchProcessStorageBase):
    """SQLite-backed store for BatchProcess records and their BatchSessions."""

    _filename: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: threading.Lock

    def __init__(self, filename: str) -> None:
        super().__init__()
        self._filename = filename
        self._conn = sqlite3.connect(filename, check_same_thread=False)
        # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()
        self._lock = threading.Lock()

        try:
            self._lock.acquire()
            # Enable foreign keys so batch_session rows cascade-delete with their batch.
            self._conn.execute("PRAGMA foreign_keys = ON;")
            self._create_tables()
            self._conn.commit()
        finally:
            self._lock.release()

    def _create_tables(self) -> None:
        """Creates the `batch_process` table and `batch_session` junction table."""

        # Create the `batch_process` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_process (
                batch_id TEXT NOT NULL PRIMARY KEY,
                batches TEXT NOT NULL,
                graph TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME
            );
            """
        )

        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
            """
        )

        # Add trigger for `updated_at`.
        # FIX: use the same STRFTIME format as the column defaults; the previous
        # `current_timestamp` wrote second-resolution values in a different format.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
            AFTER UPDATE
            ON batch_process FOR EACH ROW
            BEGIN
                UPDATE batch_process SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE batch_id = old.batch_id;
            END;
            """
        )

        # Create the `batch_session` junction table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_session (
                batch_id TEXT NOT NULL,
                session_id TEXT NOT NULL,
                state TEXT NOT NULL DEFAULT('created'),
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between batch_process and batch_session using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (batch_id,session_id),
                FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
            );
            """
        )

        # Add index for batch id
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
            """
        )

        # Add index for batch id, sorted by created_at
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
            """
        )

        # Add trigger for `updated_at`.
        # FIX: batch_session has no `image_name` column (copy-paste from the
        # image-record storage); the old trigger errored on every UPDATE.
        # Key the trigger on the table's real primary key instead.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
            AFTER UPDATE
            ON batch_session FOR EACH ROW
            BEGIN
                UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE batch_id = old.batch_id AND session_id = old.session_id;
            END;
            """
        )

    def delete(self, batch_id: str) -> None:
        """Hard-deletes a batch record; its sessions cascade via the FK."""
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                DELETE FROM batch_process
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            self._conn.commit()
        # FIX: the two previous except branches (sqlite3.Error / Exception) were
        # byte-identical; one broad handler preserves the behavior.
        except Exception as e:
            self._conn.rollback()
            raise BatchProcessDeleteException from e
        finally:
            self._lock.release()

    def save(
        self,
        batch_process: BatchProcess,
    ) -> BatchProcess:
        """Inserts the batch record (silently ignored if batch_id already exists)."""
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
                VALUES (?, ?, ?);
                """,
                (
                    batch_process.batch_id,
                    # Each Batch is serialized individually, then the list is
                    # JSON-encoded; _deserialize_batch_process reverses this exactly.
                    json.dumps([batch.json() for batch in batch_process.batches]),
                    batch_process.graph.json(),
                ),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessSaveException from e
        finally:
            self._lock.release()
        return self.get(batch_process.batch_id)

    def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
        """Deserializes a batch_process row into a BatchProcess model."""
        batch_id = session_dict.get("batch_id", "unknown")
        batches_raw = session_dict.get("batches", "unknown")
        graph_raw = session_dict.get("graph", "unknown")

        return BatchProcess(
            batch_id=batch_id,
            batches=[parse_raw_as(Batch, batch) for batch in json.loads(batches_raw)],
            graph=parse_raw_as(Graph, graph_raw),
        )

    def get(
        self,
        batch_id: str,
    ) -> BatchProcess:
        """Gets a batch record; raises BatchProcessNotFoundException if absent."""
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_process
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchProcessNotFoundException
        return self._deserialize_batch_process(dict(result))

    def create_session(
        self,
        session: BatchSession,
    ) -> BatchSession:
        """Attaches a session to a batch (ignored if the pair already exists)."""
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
                VALUES (?, ?, ?);
                """,
                (session.batch_id, session.session_id, session.state),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_session(session.session_id)

    def get_session(self, session_id: str) -> BatchSession:
        """Gets a session row; raises BatchSessionNotFoundException if absent."""
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE session_id = ?;
                """,
                (session_id,),
            )
            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        return BatchSession(**dict(result))

    def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
        """Deserializes a batch_session row into a BatchSession model."""
        return BatchSession(
            batch_id=session_dict.get("batch_id", "unknown"),
            session_id=session_dict.get("session_id", "unknown"),
            state=session_dict.get("state", "unknown"),
        )

    def get_created_session(self, batch_id: str) -> BatchSession:
        """Gets the oldest session still in the 'created' state for a batch.

        Raises BatchSessionNotFoundException when none remain.
        """
        try:
            self._lock.acquire()
            # FIX: added ORDER BY so the "next" session is deterministic.
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE batch_id = ? AND state = 'created'
                ORDER BY created_at
                LIMIT 1;
                """,
                (batch_id,),
            )
            # FIX: fetchone() returns a single Row (or None), not a list of rows.
            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        return self._deserialize_batch_session(dict(result))

    def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
        """Gets every session still in the 'created' state for a batch."""
        try:
            self._lock.acquire()
            # FIX: 'created' must be a quoted string literal; the bare identifier
            # made SQLite raise "no such column: created" on every call.
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE batch_id = ? AND state = 'created';
                """,
                (batch_id,),
            )
            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        # fetchall() returns [] when nothing matches, never None.
        return [self._deserialize_batch_session(dict(row)) for row in result]

    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        """Applies `changes` to one session and returns the updated row."""
        try:
            self._lock.acquire()

            # Change the state of a batch session
            if changes.state is not None:
                # FIX: dropped the pointless f-string prefix (no placeholders).
                self._cursor.execute(
                    """--sql
                    UPDATE batch_session
                    SET state = ?
                    WHERE batch_id = ? AND session_id = ?;
                    """,
                    (changes.state, batch_id, session_id),
                )

            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_session(session_id)
self._lock.acquire() + print('-----------------locking db-----------------') self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)) result = self._cursor.fetchone() + self._cursor.close() + self._cursor = self._conn.cursor() + print('-----------------unlocking db-----------------') finally: self._lock.release()