InvokeAI/invokeai/app/services/batch_manager_storage.py

708 lines
23 KiB
Python

import sqlite3
import threading
import uuid
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Union, cast
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.graph import Graph
# The value types a single batch item may take: a strict str, int or float, or an ImageField.
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
class BatchData(BaseModel):
    """
    A batch data collection.
    """

    # Target node for the substitution.
    node_path: str = Field(
        description="The node into which this batch data collection will be substituted.",
    )
    # Target field on that node.
    field_name: str = Field(
        description="The field into which this batch data collection will be substituted.",
    )
    # Values to substitute; a fresh empty list is created when none are supplied.
    items: list[BatchDataType] = Field(
        description="The list of items to substitute into the node/field.",
        default_factory=list,
    )
class Batch(BaseModel):
    """
    A batch, consisting of a list of a list of batch data collections.
    First, each inner list[BatchData] is zipped into a single batch data collection.
    Then, the final batch collection is created by taking the Cartesian product of all batch data collections.
    """

    data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
    runs: int = Field(default=1, description="Int stating how many times to iterate through all possible batch indices")

    @validator("runs")
    def validate_positive_runs(cls, r: int):
        """Reject `runs` values smaller than 1.

        Raises:
            ValueError: if `r` is less than 1.
        """
        if r < 1:
            raise ValueError("runs must be a positive integer")
        return r

    @validator("data")
    def validate_len(cls, v: list[list[BatchData]]):
        """Ensure every BatchData within one inner (zipped) list has the same number of items.

        Raises:
            ValueError: if two BatchData in the same inner list differ in item count.
        """
        for batch_data in v:
            # More than one distinct length means the collections cannot be zipped evenly.
            if len({len(datum.items) for datum in batch_data}) > 1:
                raise ValueError("Zipped batch items must all have the same length")
        return v

    @validator("data")
    def validate_types(cls, v: list[list[BatchData]]):
        """Ensure all items of each BatchData share a single type.

        Raises:
            TypeError: if a BatchData mixes items of different types.
        """
        for batch_data in v:
            for datum in batch_data:
                # One pass over the items: a homogeneous list has exactly one distinct type.
                if len({type(item) for item in datum.items}) > 1:
                    raise TypeError("All items in a batch must have the same type")
        return v

    @validator("data")
    def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
        """Ensure no (node_path, field_name) pair is targeted by more than one BatchData.

        Raises:
            ValueError: if the same (node_path, field_name) pair appears twice anywhere in `v`.
        """
        seen: set[tuple[str, str]] = set()
        for batch_data in v:
            for datum in batch_data:
                target = (datum.node_path, datum.field_name)
                if target in seen:
                    raise ValueError("Each batch data must have unique node_path and field_name")
                seen.add(target)
        return v
def uuid_string():
    """Return a newly generated random (version 4) UUID rendered as a string."""
    return str(uuid.uuid4())
# The states a BatchSession can be in; used to type the `state` field of the session models below.
BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]
class BatchSession(BaseModel):
    # Links one session to its parent batch and records that session's state.

    # Required: no default is provided. (The keyword was previously misspelled as
    # `defaultdescription`, which dropped the description from the field.)
    batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
    # Defaults to a freshly generated UUID string.
    session_id: str = Field(
        default_factory=uuid_string, description="The Session to which this BatchSession is attached."
    )
    # Required: no default is provided.
    batch_index: int = Field(description="The index of this batch session in its parent batch process")
    state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
class BatchProcess(BaseModel):
    # A batch together with the graph it is applied to and its progress counters.

    # Defaults to a freshly generated UUID string.
    batch_id: str = Field(
        description="Identifier for this batch.",
        default_factory=uuid_string,
    )
    batch: Batch = Field(
        description="The Batch to apply to this session.",
    )
    current_batch_index: int = Field(
        description="The last executed batch index",
        default=0,
    )
    current_run: int = Field(
        description="The current run of the batch",
        default=0,
    )
    canceled: bool = Field(
        default=False,
        description="Whether or not to run sessions from this batch.",
    )
    graph: Graph = Field(
        description="The graph into which batch data will be inserted before being executed.",
    )
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
    # The set of changes that may be applied to a BatchSession; unknown fields are forbidden.

    state: BATCH_SESSION_STATE = Field(
        description="The state of this BatchSession",
    )
class BatchProcessNotFoundException(Exception):
    """Raised when a Batch Process record is not found."""

    def __init__(self, message: str = "BatchProcess record not found") -> None:
        # Forward the (possibly default) message to Exception so str(exc) yields it.
        super().__init__(message)
class BatchProcessSaveException(Exception):
    """Raised when a Batch Process record cannot be saved."""

    def __init__(self, message: str = "BatchProcess record not saved") -> None:
        # Forward the (possibly default) message to Exception so str(exc) yields it.
        super().__init__(message)
class BatchProcessDeleteException(Exception):
    """Raised when a Batch Process record cannot be deleted."""

    def __init__(self, message: str = "BatchProcess record not deleted") -> None:
        # Forward the (possibly default) message to Exception so str(exc) yields it.
        super().__init__(message)
class BatchSessionNotFoundException(Exception):
    """Raised when a Batch Session record is not found."""

    def __init__(self, message: str = "BatchSession record not found") -> None:
        # Forward the (possibly default) message to Exception so str(exc) yields it.
        super().__init__(message)
class BatchSessionSaveException(Exception):
    """Raised when a Batch Session record cannot be saved."""

    def __init__(self, message: str = "BatchSession record not saved") -> None:
        # Forward the (possibly default) message to Exception so str(exc) yields it.
        super().__init__(message)
class BatchSessionDeleteException(Exception):
    """Raised when a Batch Session record cannot be deleted."""

    def __init__(self, message: str = "BatchSession record not deleted") -> None:
        # Forward the (possibly default) message to Exception so str(exc) yields it.
        super().__init__(message)
class BatchProcessStorageBase(ABC):
    """Abstract low-level interface to the Batch Process record store."""

    @abstractmethod
    def delete(self, batch_id: str) -> None:
        """Deletes the BatchProcess record with the given batch id."""

    @abstractmethod
    def save(self, batch_process: BatchProcess) -> BatchProcess:
        """Saves a BatchProcess record and returns a BatchProcess."""

    @abstractmethod
    def get(self, batch_id: str) -> BatchProcess:
        """Gets the BatchProcess record with the given batch id."""

    @abstractmethod
    def get_all(self) -> list[BatchProcess]:
        """Gets all BatchProcess records."""

    @abstractmethod
    def get_incomplete(self) -> list[BatchProcess]:
        """Gets the BatchProcess records that are considered incomplete."""

    @abstractmethod
    def start(self, batch_id: str) -> None:
        """'Starts' a BatchProcess record by setting its `canceled` attribute to False."""

    @abstractmethod
    def cancel(self, batch_id: str) -> None:
        """'Cancels' a BatchProcess record by setting its `canceled` attribute to True."""

    @abstractmethod
    def create_session(self, session: BatchSession) -> BatchSession:
        """Creates one BatchSession attached to a BatchProcess."""

    @abstractmethod
    def create_sessions(self, sessions: list[BatchSession]) -> list[BatchSession]:
        """Creates several BatchSessions attached to a BatchProcess."""

    @abstractmethod
    def get_session_by_session_id(self, session_id: str) -> BatchSession:
        """Gets the BatchSession with the given session id."""

    @abstractmethod
    def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
        """Gets every BatchSession belonging to the given BatchProcess id."""

    @abstractmethod
    def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
        """Gets every BatchSession whose session id is in the given list."""

    @abstractmethod
    def get_next_session(self, batch_id: str) -> BatchSession:
        """Gets the next BatchSession in state `uninitialized` for the given BatchProcess id."""

    @abstractmethod
    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        """Updates the state of a BatchSession record and returns a BatchSession."""
class SqliteBatchProcessStorage(BatchProcessStorageBase):
    """SQLite implementation of `BatchProcessStorageBase`.

    Every statement runs on one shared connection and cursor. Each public method
    holds `self._lock` while it uses them; the lock is not reentrant, so methods
    that return a freshly read record release the lock before calling the getter.
    """

    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: threading.Lock

    # Maximum number of `?` parameters bound in a single statement. SQLite builds
    # older than 3.32 reject statements with more than 999 bound variables by
    # default, so long id lists are queried in chunks below that limit.
    _MAX_QUERY_VARIABLES = 900

    def __init__(self, conn: sqlite3.Connection) -> None:
        """Attach to `conn`, enable foreign keys and create the schema if missing.

        Note that this sets `row_factory` on the connection that was passed in.
        """
        super().__init__()
        self._conn = conn
        # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()
        self._lock = threading.Lock()

        with self._lock:
            # Enable foreign keys
            self._conn.execute("PRAGMA foreign_keys = ON;")
            self._create_tables()
            self._conn.commit()

    def _create_tables(self) -> None:
        """Creates the `batch_process` table and `batch_session` junction table."""

        # Create the `batch_process` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_process (
                batch_id TEXT NOT NULL PRIMARY KEY,
                batch TEXT NOT NULL,
                graph TEXT NOT NULL,
                current_batch_index NUMBER NOT NULL,
                current_run NUMBER NOT NULL,
                canceled BOOLEAN NOT NULL DEFAULT(0),
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME
            );
            """
        )

        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
            """
        )

        # Add trigger for `updated_at`.
        # Uses the same millisecond STRFTIME format as the column default (and as the
        # batch_session trigger) so that timestamps in the column stay comparable.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
            AFTER UPDATE
            ON batch_process FOR EACH ROW
            BEGIN
                UPDATE batch_process SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE batch_id = old.batch_id;
            END;
            """
        )

        # Create the `batch_session` junction table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_session (
                batch_id TEXT NOT NULL,
                session_id TEXT NOT NULL,
                state TEXT NOT NULL,
                batch_index NUMBER NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between batch_process and batch_session using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (batch_id,session_id),
                FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
            );
            """
        )

        # Add index for batch id
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
            """
        )

        # Add index for batch id, sorted by created_at
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
            """
        )

        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
            AFTER UPDATE
            ON batch_session FOR EACH ROW
            BEGIN
                UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE batch_id = old.batch_id AND session_id = old.session_id;
            END;
            """
        )

    def delete(self, batch_id: str) -> None:
        """Delete the batch_process row with `batch_id`.

        Raises:
            BatchProcessDeleteException: if the delete or commit fails for any reason.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    DELETE FROM batch_process
                    WHERE batch_id = ?;
                    """,
                    (batch_id,),
                )
                self._conn.commit()
            except Exception as e:
                self._conn.rollback()
                raise BatchProcessDeleteException from e

    def save(
        self,
        batch_process: BatchProcess,
    ) -> BatchProcess:
        """Insert or update the row for `batch_process` and return the stored record.

        An existing row is updated in place rather than replaced: `INSERT OR REPLACE`
        deletes the old row first, which (with foreign keys enabled and
        `ON DELETE CASCADE` on batch_session) would remove the batch's sessions and
        reset `canceled` and `created_at`.

        Raises:
            BatchProcessSaveException: if a SQLite error occurs while writing.
            BatchProcessNotFoundException: if the record cannot be read back.
        """
        values = (
            batch_process.batch.json(),
            batch_process.graph.json(),
            batch_process.current_batch_index,
            batch_process.current_run,
            int(batch_process.canceled),
        )
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    UPDATE batch_process
                    SET batch = ?, graph = ?, current_batch_index = ?, current_run = ?, canceled = ?
                    WHERE batch_id = ?;
                    """,
                    (*values, batch_process.batch_id),
                )
                # No row was updated, so this batch_id is new: insert it.
                if self._cursor.rowcount == 0:
                    self._cursor.execute(
                        """--sql
                        INSERT INTO batch_process (batch_id, batch, graph, current_batch_index, current_run, canceled)
                        VALUES (?, ?, ?, ?, ?, ?);
                        """,
                        (batch_process.batch_id, *values),
                    )
                self._conn.commit()
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchProcessSaveException from e
        return self.get(batch_process.batch_id)

    def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
        """Deserializes a batch process row (as a dict) into a BatchProcess."""

        # Retrieve all the values, setting "reasonable" defaults if they are not present.
        batch_id = session_dict.get("batch_id", "unknown")
        batch_raw = session_dict.get("batch", "unknown")
        graph_raw = session_dict.get("graph", "unknown")
        current_batch_index = session_dict.get("current_batch_index", 0)
        current_run = session_dict.get("current_run", 0)
        canceled = session_dict.get("canceled", 0)
        return BatchProcess(
            batch_id=batch_id,
            batch=parse_raw_as(Batch, batch_raw),
            graph=parse_raw_as(Graph, graph_raw),
            current_batch_index=current_batch_index,
            current_run=current_run,
            canceled=canceled == 1,
        )

    def get(
        self,
        batch_id: str,
    ) -> BatchProcess:
        """Return the BatchProcess stored under `batch_id`.

        Raises:
            BatchProcessNotFoundException: if no such row exists or the query fails.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    SELECT *
                    FROM batch_process
                    WHERE batch_id = ?;
                    """,
                    (batch_id,),
                )
                result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchProcessNotFoundException from e
        if result is None:
            raise BatchProcessNotFoundException
        return self._deserialize_batch_process(dict(result))

    def get_all(
        self,
    ) -> list[BatchProcess]:
        """Return every stored BatchProcess (an empty list when there are none).

        Raises:
            BatchProcessNotFoundException: if the query fails.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    SELECT *
                    FROM batch_process
                    """
                )
                # fetchall() returns a (possibly empty) list, never None.
                result = cast(list[sqlite3.Row], self._cursor.fetchall())
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchProcessNotFoundException from e
        return [self._deserialize_batch_process(dict(row)) for row in result]

    def get_incomplete(
        self,
    ) -> list[BatchProcess]:
        """Return every BatchProcess that has a session in state `uninitialized` or `in_progress`.

        Raises:
            BatchProcessNotFoundException: if the query fails.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    SELECT bp.*
                    FROM batch_process bp
                    WHERE bp.batch_id IN
                    (
                        SELECT batch_id
                        FROM batch_session bs
                        WHERE state IN ('uninitialized', 'in_progress')
                    );
                    """
                )
                # fetchall() returns a (possibly empty) list, never None.
                result = cast(list[sqlite3.Row], self._cursor.fetchall())
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchProcessNotFoundException from e
        return [self._deserialize_batch_process(dict(row)) for row in result]

    def _set_canceled(self, batch_id: str, canceled: bool) -> None:
        """Write the `canceled` flag of the batch_process row with `batch_id`."""
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    UPDATE batch_process
                    SET canceled = ?
                    WHERE batch_id = ?;
                    """,
                    (int(canceled), batch_id),
                )
                self._conn.commit()
            except sqlite3.Error as e:
                self._conn.rollback()
                # The session-level exception type is kept here (although the row is a
                # batch process) because it is the type existing callers may catch.
                raise BatchSessionSaveException from e

    def start(
        self,
        batch_id: str,
    ) -> None:
        """Clear the `canceled` flag of the BatchProcess with `batch_id`.

        Raises:
            BatchSessionSaveException: if a SQLite error occurs.
        """
        self._set_canceled(batch_id, False)

    def cancel(
        self,
        batch_id: str,
    ) -> None:
        """Set the `canceled` flag of the BatchProcess with `batch_id`.

        Raises:
            BatchSessionSaveException: if a SQLite error occurs.
        """
        self._set_canceled(batch_id, True)

    def create_session(
        self,
        session: BatchSession,
    ) -> BatchSession:
        """Insert `session` (ignored if its key already exists) and return the stored row.

        Raises:
            BatchSessionSaveException: if a SQLite error occurs while writing.
            BatchSessionNotFoundException: if the session cannot be read back.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
                    VALUES (?, ?, ?, ?);
                    """,
                    (session.batch_id, session.session_id, session.state, session.batch_index),
                )
                self._conn.commit()
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchSessionSaveException from e
        return self.get_session_by_session_id(session.session_id)

    def create_sessions(
        self,
        sessions: list[BatchSession],
    ) -> list[BatchSession]:
        """Insert many sessions (existing keys are ignored) and return the stored rows.

        Raises:
            BatchSessionSaveException: if a SQLite error occurs while writing.
            BatchSessionNotFoundException: if reading the sessions back fails.
        """
        # One 4-tuple per row, in the column order of the INSERT below.
        session_data = [
            (session.batch_id, session.session_id, session.state, session.batch_index) for session in sessions
        ]
        with self._lock:
            try:
                self._cursor.executemany(
                    """--sql
                    INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
                    VALUES (?, ?, ?, ?);
                    """,
                    session_data,
                )
                self._conn.commit()
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchSessionSaveException from e
        return self.get_sessions_by_session_ids([session.session_id for session in sessions])

    def get_session_by_session_id(self, session_id: str) -> BatchSession:
        """Return the BatchSession with `session_id`.

        Raises:
            BatchSessionNotFoundException: if no such row exists or the query fails.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    SELECT *
                    FROM batch_session
                    WHERE session_id= ?;
                    """,
                    (session_id,),
                )
                result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchSessionNotFoundException from e
        if result is None:
            raise BatchSessionNotFoundException
        return self._deserialize_batch_session(dict(result))

    def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
        """Deserializes a batch session row (as a dict) into a BatchSession."""
        return BatchSession.parse_obj(session_dict)

    def get_next_session(self, batch_id: str) -> BatchSession:
        """Return the first `uninitialized` BatchSession of the batch with `batch_id`.

        Sessions are ordered by `batch_index`, then `created_at`, so that repeated
        calls hand sessions out in a deterministic order.

        Raises:
            BatchSessionNotFoundException: if there is no such session or the query fails.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    SELECT *
                    FROM batch_session
                    WHERE batch_id = ? AND state = 'uninitialized'
                    ORDER BY batch_index ASC, created_at ASC
                    LIMIT 1;
                    """,
                    (batch_id,),
                )
                result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchSessionNotFoundException from e
        if result is None:
            raise BatchSessionNotFoundException
        return self._deserialize_batch_session(dict(result))

    def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
        """Return every BatchSession of the batch with `batch_id` (possibly an empty list).

        Raises:
            BatchSessionNotFoundException: if the query fails.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    SELECT *
                    FROM batch_session
                    WHERE batch_id = ?;
                    """,
                    (batch_id,),
                )
                # fetchall() returns a (possibly empty) list, never None.
                result = cast(list[sqlite3.Row], self._cursor.fetchall())
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchSessionNotFoundException from e
        return [self._deserialize_batch_session(dict(row)) for row in result]

    def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
        """Return every BatchSession whose session id is in `session_ids`.

        Ids are de-duplicated and queried in chunks so that a long list cannot exceed
        SQLite's limit on bound variables. An empty list yields an empty result.

        Raises:
            BatchSessionNotFoundException: if a query fails.
        """
        # De-duplicate (keeping first-seen order) so a row is never returned twice
        # when the same id would otherwise land in two different chunks.
        unique_ids = list(dict.fromkeys(session_ids))
        rows: list[sqlite3.Row] = []
        with self._lock:
            try:
                for offset in range(0, len(unique_ids), self._MAX_QUERY_VARIABLES):
                    chunk = unique_ids[offset : offset + self._MAX_QUERY_VARIABLES]
                    placeholders = ",".join("?" * len(chunk))
                    self._cursor.execute(
                        f"""--sql
                        SELECT * FROM batch_session
                        WHERE session_id
                        IN ({placeholders})
                        """,
                        tuple(chunk),
                    )
                    rows.extend(cast(list[sqlite3.Row], self._cursor.fetchall()))
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchSessionNotFoundException from e
        return [self._deserialize_batch_session(dict(row)) for row in rows]

    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        """Apply `changes.state` to the given session and return the stored row.

        Raises:
            BatchSessionSaveException: if a SQLite error occurs while writing.
            BatchSessionNotFoundException: if the session cannot be read back.
        """
        with self._lock:
            try:
                # Change the state of a batch session
                if changes.state is not None:
                    self._cursor.execute(
                        """--sql
                        UPDATE batch_session
                        SET state = ?
                        WHERE batch_id = ? AND session_id = ?;
                        """,
                        (changes.state, batch_id, session_id),
                    )
                    self._conn.commit()
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BatchSessionSaveException from e
        return self.get_session_by_session_id(session_id)