Testing sqlite issues with batch_manager

This commit is contained in:
Brandon Rising 2023-08-10 11:38:28 -04:00
parent 835d76af45
commit e26e4740b3
5 changed files with 590 additions and 81 deletions

View File

@ -31,6 +31,7 @@ from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService from ..services.model_manager_service import ModelManagerService
from ..services.batch_manager import BatchManager from ..services.batch_manager import BatchManager
from ..services.batch_manager_storage import SqliteBatchProcessStorage
from .events import FastAPIEventService from .events import FastAPIEventService
@ -117,7 +118,8 @@ class ApiDependencies:
) )
) )
batch_manager = BatchManager() batch_manager_storage = SqliteBatchProcessStorage(db_location)
batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices( services = InvocationServices(
model_manager=ModelManagerService(config, logger), model_manager=ModelManagerService(config, logger),

View File

@ -51,8 +51,9 @@ async def create_batch(
batches: list[Batch] = Body(description="Batch config to apply to the given graph"), batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
) -> BatchProcess: ) -> BatchProcess:
"""Creates and starts a new new batch process""" """Creates and starts a new new batch process"""
session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph) batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
return session ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
return {"batch_id":batch_id}
@session_router.delete( @session_router.delete(

View File

@ -1,45 +1,22 @@
import networkx as nx import networkx as nx
import uuid
import copy import copy
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import product
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event from fastapi_events.typing import Event
from typing import (
Optional,
Union,
)
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
)
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph, GraphExecutionState from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.batch_manager_storage import (
BatchProcessStorageBase,
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore Batch,
BatchProcess,
BatchSession,
class Batch(BaseModel): BatchSessionChanges,
data: list[InvocationsUnion] = Field(description="Mapping of ") )
node_id: str = Field(description="ID of the node to batch")
class BatchProcess(BaseModel):
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
sessions: list[str] = Field(
description="Tracker for which batch is currently being processed", default_factory=list
)
batches: list[Batch] = Field(
description="List of batch configs to apply to this session",
default_factory=list,
)
batch_indices: list[int] = Field(
description="Tracker for which batch is currently being processed", default_factory=list
)
graph: Graph = Field(description="The graph being executed")
class BatchManagerBase(ABC): class BatchManagerBase(ABC):
@ -48,7 +25,11 @@ class BatchManagerBase(ABC):
pass pass
@abstractmethod @abstractmethod
def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess: def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
pass
@abstractmethod
def run_batch_process(self, batch_id: str):
pass pass
@abstractmethod @abstractmethod
@ -61,8 +42,13 @@ class BatchManager(BatchManagerBase):
__invoker: Invoker __invoker: Invoker
__batches: list[BatchProcess] __batches: list[BatchProcess]
__batch_process_storage: BatchProcessStorageBase
def start(self, invoker) -> None: def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
super().__init__()
self.__batch_process_storage = batch_process_storage
def start(self, invoker: Invoker) -> None:
# if we do want multithreading at some point, we could make this configurable # if we do want multithreading at some point, we could make this configurable
self.__invoker = invoker self.__invoker = invoker
self.__batches = list() self.__batches = list()
@ -73,34 +59,28 @@ class BatchManager(BatchManagerBase):
match event_name: match event_name:
case "graph_execution_state_complete": case "graph_execution_state_complete":
await self.process(event) await self.process(event, False)
case "invocation_error": case "invocation_error":
await self.process(event) await self.process(event, True)
return event return event
async def process(self, event: Event): async def process(self, event: Event, err: bool):
data = event[1]["data"] data = event[1]["data"]
batchTarget = None batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
for batch in self.__batches: if not batch_session:
if data["graph_execution_state_id"] in batch.sessions:
batchTarget = batch
break
if batchTarget == None:
return return
updateSession = BatchSessionChanges(
state='error' if err else 'completed'
)
batch_session = self.__batch_process_storage.update_session_state(
batch_session.batch_id,
batch_session.session_id,
updateSession,
)
self.run_batch_process(batch_session.batch_id)
if sum(batchTarget.batch_indices) == 0: def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
self.__batches = [batch for batch in self.__batches if batch != batchTarget]
return
batchTarget.batch_indices = self._next_batch_index(batchTarget)
ges = self._next_batch_session(batchTarget)
batchTarget.sessions.append(ges.id)
self.__invoker.services.graph_execution_manager.set(ges)
self.__invoker.invoke(ges, invoke_all=True)
def _next_batch_session(self, batch_process: BatchProcess) -> GraphExecutionState:
graph = copy.deepcopy(batch_process.graph) graph = copy.deepcopy(batch_process.graph)
batches = batch_process.batches batches = batch_process.batches
g = graph.nx_graph_flat() g = graph.nx_graph_flat()
@ -109,36 +89,47 @@ class BatchManager(BatchManagerBase):
node = graph.get_node(npath) node = graph.get_node(npath)
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None)) (index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
if batch: if batch:
batch_index = batch_process.batch_indices[index] batch_index = batch_indices[index]
datum = batch.data[batch_index] datum = batch.data[batch_index]
datum.id = node.id for key in datum:
graph.update_node(npath, datum) node.__dict__[key] = datum[key]
graph.update_node(npath, node)
return GraphExecutionState(graph=graph) return GraphExecutionState(graph=graph)
def _next_batch_index(self, batch_process: BatchProcess): def run_batch_process(self, batch_id: str):
batch_indicies = batch_process.batch_indices.copy() created_session = self.__batch_process_storage.get_created_session(batch_id)
for index in range(len(batch_indicies)): ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
if batch_indicies[index] > 0: self.__invoker.invoke(ges, invoke_all=True)
batch_indicies[index] -= 1
break
return batch_indicies
def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess: def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
batch_indices = list() return True
for batch in batches:
batch_indices.append(len(batch.data) - 1) def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
batch_process = BatchProcess( batch_process = BatchProcess(
batches=batches, batches=batches,
batch_indices=batch_indices,
graph=graph, graph=graph,
) )
ges = self._next_batch_session(batch_process) if not self._valid_batch_config(batch_process):
batch_process.sessions.append(ges.id) return None
self.__batches.append(batch_process) batch_process = self.__batch_process_storage.save(batch_process)
self.__invoker.services.graph_execution_manager.set(ges) self._create_sessions(batch_process)
self.__invoker.invoke(ges, invoke_all=True) return batch_process.batch_id
return batch_process
def _create_sessions(self, batch_process: BatchProcess):
batch_indices = list()
for batch in batch_process.batches:
batch_indices.append(list(range(len(batch.data))))
all_batch_indices = product(*batch_indices)
for bi in all_batch_indices:
ges = self._create_batch_session(batch_process, bi)
self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession(
batch_id=batch_process.batch_id,
session_id=ges.id,
state="created"
)
self.__batch_process_storage.create_session(batch_session)
def cancel_batch_process(self, batch_process_id: str): def cancel_batch_process(self, batch_process_id: str):
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id] self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]

View File

@ -0,0 +1,500 @@
from abc import ABC, abstractmethod
from typing import cast
import uuid
import sqlite3
import threading
from typing import (
Any,
List,
Literal,
Optional,
Union,
)
import json
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
)
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.models.image import ImageField
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from pydantic import BaseModel, Field, Extra, parse_raw_as
invocations = BaseInvocation.get_invocations()
InvocationsUnion = Union[invocations] # type: ignore
# A single value that may be substituted into a node field during a batch run.
BatchDataType = Union[str, int, float, ImageField]
class Batch(BaseModel):
    # Each entry maps field names on the target node to the value to apply for
    # one iteration of the batch.
    data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
    # The node whose fields are overridden on each iteration.
    node_id: str = Field(description="ID of the node to batch")
class BatchSession(BaseModel):
    # Junction record tying one graph-execution session to its parent batch
    # process, plus the session's lifecycle state.
    batch_id: str = Field(description="Identifier for which batch this Index belongs to")
    session_id: str = Field(description="Session ID Created for this Batch Index")
    state: Literal["created", "completed", "inprogress", "error"] = Field(
        description="Is this session created, completed, in progress, or errored?"
    )
class BatchProcess(BaseModel):
    """A batch run: a set of batch configs applied to a single graph template."""

    # FIX: the previous `default_factory=uuid.uuid4().__str__` called
    # uuid.uuid4() once at class-definition time and reused that single UUID's
    # bound __str__ as the factory, so every BatchProcess instance received
    # the SAME batch_id. Deferring generation with a lambda yields a fresh
    # id per instance.
    batch_id: Optional[str] = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Identifier for this batch",
    )
    batches: List[Batch] = Field(
        description="List of batch configs to apply to this session",
        default_factory=list,
    )
    # The graph each batch session is derived from.
    graph: Graph = Field(description="The graph being executed")
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
    # Patch object for update_session_state; Extra.forbid rejects unknown
    # fields so callers cannot silently attempt to change other columns.
    state: Literal["created", "completed", "inprogress", "error"] = Field(
        description="Is this session created, completed, in progress, or errored?"
    )
class BatchProcessNotFoundException(Exception):
    """Raised when a Batch Process record is not found."""

    def __init__(self, message: str = "BatchProcess record not found") -> None:
        super().__init__(message)
class BatchProcessSaveException(Exception):
    """Raised when a Batch Process record cannot be saved."""

    def __init__(self, message: str = "BatchProcess record not saved") -> None:
        super().__init__(message)
class BatchProcessDeleteException(Exception):
    """Raised when a Batch Process record cannot be deleted."""

    def __init__(self, message: str = "BatchProcess record not deleted") -> None:
        super().__init__(message)
class BatchSessionNotFoundException(Exception):
    """Raised when a Batch Session record is not found."""

    def __init__(self, message: str = "BatchSession record not found") -> None:
        super().__init__(message)
class BatchSessionSaveException(Exception):
    """Raised when a Batch Session record cannot be saved."""

    def __init__(self, message: str = "BatchSession record not saved") -> None:
        super().__init__(message)
class BatchSessionDeleteException(Exception):
    """Raised when a Batch Session record cannot be deleted."""

    def __init__(self, message: str = "BatchSession record not deleted") -> None:
        super().__init__(message)
class BatchProcessStorageBase(ABC):
    """Low-level service responsible for interfacing with the Batch Process record store."""
    @abstractmethod
    def delete(self, batch_id: str) -> None:
        """Deletes a Batch Process record."""
        pass
    @abstractmethod
    def save(
        self,
        batch_process: BatchProcess,
    ) -> BatchProcess:
        """Saves a Batch Process record and returns the stored copy."""
        pass
    @abstractmethod
    def get(
        self,
        batch_id: str,
    ) -> BatchProcess:
        """Gets a Batch Process record by its batch_id."""
        pass
    @abstractmethod
    def create_session(
        self,
        session: BatchSession,
    ) -> BatchSession:
        """Creates a Batch Session attached to a Batch Process."""
        pass
    @abstractmethod
    def get_session(
        self,
        session_id: str
    ) -> BatchSession:
        """Gets a Batch Session by session_id."""
        pass
    @abstractmethod
    def get_created_session(
        self,
        batch_id: str
    ) -> BatchSession:
        """Gets one Batch Session in the 'created' state for a given Batch Process id."""
        pass
    @abstractmethod
    def get_created_sessions(
        self,
        batch_id: str
    ) -> List[BatchSession]:
        """Gets all Batch Sessions in the 'created' state for a given Batch Process id."""
        pass
    @abstractmethod
    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        """Updates the state of a Batch Session record and returns the updated session."""
        pass
class SqliteBatchProcessStorage(BatchProcessStorageBase):
    """SQLite implementation of the Batch Process record store.

    A single connection/cursor pair is shared by all callers; because the
    connection is created with check_same_thread=False, `_lock` serializes
    every database operation.
    """

    _filename: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: threading.Lock

    def __init__(self, filename: str) -> None:
        super().__init__()
        self._filename = filename
        self._conn = sqlite3.connect(filename, check_same_thread=False)
        # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()
        self._lock = threading.Lock()
        try:
            self._lock.acquire()
            # Enable foreign keys so batch_session rows cascade on delete.
            self._conn.execute("PRAGMA foreign_keys = ON;")
            self._create_tables()
            self._conn.commit()
        finally:
            self._lock.release()

    def _create_tables(self) -> None:
        """Creates the `batch_process` table and `batch_session` junction table."""
        # Create the `batch_process` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_process (
                batch_id TEXT NOT NULL PRIMARY KEY,
                batches TEXT NOT NULL,
                graph TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME
            );
            """
        )
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
            """
        )
        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
            AFTER UPDATE
            ON batch_process FOR EACH ROW
            BEGIN
                UPDATE batch_process SET updated_at = current_timestamp
                    WHERE batch_id = old.batch_id;
            END;
            """
        )
        # Create the `batch_session` junction table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_session (
                batch_id TEXT NOT NULL,
                session_id TEXT NOT NULL,
                state TEXT NOT NULL DEFAULT('created'),
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between batch_process and batch_session using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (batch_id,session_id),
                FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
            );
            """
        )
        # Add index for batch id
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
            """
        )
        # Add index for batch id, sorted by created_at
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
            """
        )
        # Add trigger for `updated_at`.
        # FIX: this trigger previously matched on `image_name`, a column that
        # does not exist on batch_session (copy-paste from an images table),
        # so every UPDATE on batch_session raised an error.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
            AFTER UPDATE
            ON batch_session FOR EACH ROW
            BEGIN
                UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE batch_id = old.batch_id AND session_id = old.session_id;
            END;
            """
        )

    def delete(self, batch_id: str) -> None:
        """Deletes the Batch Process with the given id (sessions cascade).

        Raises BatchProcessDeleteException on any error; the transaction is
        rolled back first.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                DELETE FROM batch_process
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            self._conn.commit()
        except Exception as e:
            # A single handler replaces the two duplicate branches that both
            # raised BatchProcessDeleteException for sqlite3.Error/Exception.
            self._conn.rollback()
            raise BatchProcessDeleteException from e
        finally:
            self._lock.release()

    def save(
        self,
        batch_process: BatchProcess,
    ) -> BatchProcess:
        """Inserts a Batch Process record and returns the stored copy.

        INSERT OR IGNORE makes re-saving an existing batch_id a no-op.
        Raises BatchProcessSaveException on database error.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
                VALUES (?, ?, ?);
                """,
                (
                    batch_process.batch_id,
                    # Stored as a JSON array of per-batch JSON strings; see
                    # _deserialize_batch_process for the matching decode.
                    json.dumps([batch.json() for batch in batch_process.batches]),
                    batch_process.graph.json(),
                ),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessSaveException from e
        finally:
            self._lock.release()
        return self.get(batch_process.batch_id)

    def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
        """Deserializes a batch_process row (as a dict) into a BatchProcess."""
        # Retrieve all the values, setting "reasonable" defaults if they are not present.
        batch_id = session_dict.get("batch_id", "unknown")
        batches_raw = session_dict.get("batches", "unknown")
        graph_raw = session_dict.get("graph", "unknown")
        return BatchProcess(
            batch_id=batch_id,
            batches=[parse_raw_as(Batch, batch) for batch in json.loads(batches_raw)],
            graph=parse_raw_as(Graph, graph_raw),
        )

    def get(
        self,
        batch_id: str,
    ) -> BatchProcess:
        """Gets the Batch Process with the given id.

        Raises BatchProcessNotFoundException if no such record exists or on
        database error.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_process
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchProcessNotFoundException
        return self._deserialize_batch_process(dict(result))

    def create_session(
        self,
        session: BatchSession,
    ) -> BatchSession:
        """Inserts a Batch Session row and returns the stored copy.

        Raises BatchSessionSaveException on database error.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
                VALUES (?, ?, ?);
                """,
                (session.batch_id, session.session_id, session.state),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_session(session.session_id)

    def get_session(
        self,
        session_id: str,
    ) -> BatchSession:
        """Gets the Batch Session with the given session id.

        Raises BatchSessionNotFoundException if missing or on database error.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE session_id = ?;
                """,
                (session_id,),
            )
            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        # Use the shared deserializer (instead of BatchSession(**dict(result)))
        # so extra columns like created_at are consistently dropped rather
        # than handed to the model.
        return self._deserialize_batch_session(dict(result))

    def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
        """Deserializes a batch_session row (as a dict) into a BatchSession."""
        # Retrieve all the values, setting "reasonable" defaults if they are not present.
        batch_id = session_dict.get("batch_id", "unknown")
        session_id = session_dict.get("session_id", "unknown")
        state = session_dict.get("state", "unknown")
        return BatchSession(
            batch_id=batch_id,
            session_id=session_id,
            state=state,
        )

    def get_created_session(
        self,
        batch_id: str,
    ) -> BatchSession:
        """Gets one Batch Session in the 'created' state for the given batch.

        Raises BatchSessionNotFoundException if none exists or on database
        error.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE batch_id = ? AND state = 'created';
                """,
                (batch_id,),
            )
            # FIX: fetchone() returns a single row (or None), not a list;
            # the previous cast to list[sqlite3.Row] was wrong.
            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        return self._deserialize_batch_session(dict(result))

    def get_created_sessions(
        self,
        batch_id: str,
    ) -> List[BatchSession]:
        """Gets all Batch Sessions in the 'created' state for the given batch.

        Raises BatchSessionNotFoundException on database error.
        """
        try:
            self._lock.acquire()
            # FIX: 'created' must be a quoted SQL string literal; the
            # previous unquoted `state = created` was parsed as a reference
            # to a nonexistent column and raised sqlite3.OperationalError.
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE batch_id = ? AND state = 'created';
                """,
                (batch_id,),
            )
            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        return [self._deserialize_batch_session(dict(row)) for row in result]

    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        """Applies `changes` to the identified Batch Session and returns it.

        Raises BatchSessionSaveException on database error.
        """
        try:
            self._lock.acquire()
            # Change the state of a batch session
            if changes.state is not None:
                # Plain string (the stray f-prefix on this literal SQL was
                # removed; there is nothing to interpolate).
                self._cursor.execute(
                    """--sql
                    UPDATE batch_session
                    SET state = ?
                    WHERE batch_id = ? AND session_id = ?;
                    """,
                    (changes.state, batch_id, session_id),
                )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_session(session_id)

View File

@ -9,6 +9,7 @@ from .item_storage import ItemStorageABC, PaginatedResults
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
sqlite_memory = ":memory:" sqlite_memory = ":memory:"
import traceback
class SqliteItemStorage(ItemStorageABC, Generic[T]): class SqliteItemStorage(ItemStorageABC, Generic[T]):
@ -21,7 +22,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def __init__(self, filename: str, table_name: str, id_field: str = "id"): def __init__(self, filename: str, table_name: str, id_field: str = "id"):
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._table_name = table_name self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field self._id_field = id_field # TODO: validate that T has this field
@ -29,6 +29,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._conn = sqlite3.connect( self._conn = sqlite3.connect(
self._filename, check_same_thread=False self._filename, check_same_thread=False
) # TODO: figure out a better threading solution ) # TODO: figure out a better threading solution
self._conn.set_trace_callback(print)
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._create_table() self._create_table()
@ -54,11 +55,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def set(self, item: T): def set(self, item: T):
try: try:
self._lock.acquire() self._lock.acquire()
json = item.json()
print('-----------------locking db-----------------')
traceback.print_stack(limit=2)
self._cursor.execute( self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""", f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),), (json,),
) )
self._conn.commit() self._conn.commit()
self._cursor.close()
self._cursor = self._conn.cursor()
print('-----------------unlocking db-----------------')
except Exception as e:
print("Exception!")
print(e)
finally: finally:
self._lock.release() self._lock.release()
self._on_changed(item) self._on_changed(item)
@ -66,8 +77,12 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def get(self, id: str) -> Optional[T]: def get(self, id: str) -> Optional[T]:
try: try:
self._lock.acquire() self._lock.acquire()
print('-----------------locking db-----------------')
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)) self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone() result = self._cursor.fetchone()
self._cursor.close()
self._cursor = self._conn.cursor()
print('-----------------unlocking db-----------------')
finally: finally:
self._lock.release() self._lock.release()