mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(batches): defer ges creation until execution
This improves the overall responsiveness of the system substantially, but does make each iteration *slightly* slower, distributing the up-front cost across the batch. Two main changes: 1. Create BatchSessions immediately, but do not create a whole graph execution state until the batch is executed. BatchSessions are created with a `session_id` that does not exist in sessions database. The default state is changed to `"uninitialized"` to better represent this. Results: Time to create 5000 batches reduced from over 30s to 2.5s 2. Use `executemany()` to retrieve lists of created sessions. Results: time to create 5000 batches reduced from 2.5s to under 0.5s Other changes: - set BatchSession state to `"in_progress"` just before `invoke()` is called - rename a few methods to accomodate the new behaviour - remove unused `BatchProcessStorage.get_created_sessions()` method
This commit is contained in:
parent
50816432dc
commit
88ae19a768
@ -94,7 +94,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
async def _process(self, event: Event, err: bool) -> None:
|
async def _process(self, event: Event, err: bool) -> None:
|
||||||
data = event[1]["data"]
|
data = event[1]["data"]
|
||||||
try:
|
try:
|
||||||
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
|
batch_session = self.__batch_process_storage.get_session_by_session_id(data["graph_execution_state_id"])
|
||||||
except BatchSessionNotFoundException:
|
except BatchSessionNotFoundException:
|
||||||
return None
|
return None
|
||||||
changes = BatchSessionChanges(state="error" if err else "completed")
|
changes = BatchSessionChanges(state="error" if err else "completed")
|
||||||
@ -130,10 +130,18 @@ class BatchManager(BatchManagerBase):
|
|||||||
def run_batch_process(self, batch_id: str) -> None:
|
def run_batch_process(self, batch_id: str) -> None:
|
||||||
self.__batch_process_storage.start(batch_id)
|
self.__batch_process_storage.start(batch_id)
|
||||||
try:
|
try:
|
||||||
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
next_session = self.__batch_process_storage.get_next_session(batch_id)
|
||||||
except BatchSessionNotFoundException:
|
except BatchSessionNotFoundException:
|
||||||
return
|
return
|
||||||
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
|
batch_process = self.__batch_process_storage.get(batch_id)
|
||||||
|
ges = self._create_batch_session(batch_process=batch_process, batch_indices=tuple(next_session.batch_index))
|
||||||
|
ges.id = next_session.session_id
|
||||||
|
self.__invoker.services.graph_execution_manager.set(ges)
|
||||||
|
self.__batch_process_storage.update_session_state(
|
||||||
|
batch_id=next_session.batch_id,
|
||||||
|
session_id=next_session.session_id,
|
||||||
|
changes=BatchSessionChanges(state="in_progress"),
|
||||||
|
)
|
||||||
self.__invoker.invoke(ges, invoke_all=True)
|
self.__invoker.invoke(ges, invoke_all=True)
|
||||||
|
|
||||||
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||||
@ -150,25 +158,20 @@ class BatchManager(BatchManagerBase):
|
|||||||
|
|
||||||
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
||||||
batch_indices = list()
|
batch_indices = list()
|
||||||
sessions = list()
|
sessions_to_create: list[BatchSession] = list()
|
||||||
for batchdata in batch_process.batch.data:
|
for batchdata in batch_process.batch.data:
|
||||||
batch_indices.append(list(range(len(batchdata[0].items))))
|
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||||
all_batch_indices = product(*batch_indices)
|
all_batch_indices = product(*batch_indices)
|
||||||
for bi in all_batch_indices:
|
for bi in all_batch_indices:
|
||||||
for _ in range(batch_process.batch.runs):
|
for _ in range(batch_process.batch.runs):
|
||||||
ges = self._create_batch_session(batch_process, bi)
|
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
|
||||||
self.__invoker.services.graph_execution_manager.set(ges)
|
if not sessions_to_create:
|
||||||
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
|
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
|
||||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
created_sessions = self.__batch_process_storage.create_sessions(sessions_to_create)
|
||||||
if not sessions:
|
return created_sessions
|
||||||
ges = GraphExecutionState(graph=batch_process.graph)
|
|
||||||
self.__invoker.services.graph_execution_manager.set(ges)
|
|
||||||
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
|
|
||||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
|
||||||
return sessions
|
|
||||||
|
|
||||||
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||||
return self.__batch_process_storage.get_sessions(batch_id)
|
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
|
||||||
|
|
||||||
def get_batch(self, batch_id: str) -> BatchProcess:
|
def get_batch(self, batch_id: str) -> BatchProcess:
|
||||||
return self.__batch_process_storage.get(batch_id)
|
return self.__batch_process_storage.get(batch_id)
|
||||||
@ -177,7 +180,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
bps = self.__batch_process_storage.get_all()
|
bps = self.__batch_process_storage.get_all()
|
||||||
res = list()
|
res = list()
|
||||||
for bp in bps:
|
for bp in bps:
|
||||||
sessions = self.__batch_process_storage.get_sessions(bp.batch_id)
|
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
|
||||||
res.append(
|
res.append(
|
||||||
BatchProcessResponse(
|
BatchProcessResponse(
|
||||||
batch_id=bp.batch_id,
|
batch_id=bp.batch_id,
|
||||||
@ -190,7 +193,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
bps = self.__batch_process_storage.get_incomplete()
|
bps = self.__batch_process_storage.get_incomplete()
|
||||||
res = list()
|
res = list()
|
||||||
for bp in bps:
|
for bp in bps:
|
||||||
sessions = self.__batch_process_storage.get_sessions(bp.batch_id)
|
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
|
||||||
res.append(
|
res.append(
|
||||||
BatchProcessResponse(
|
BatchProcessResponse(
|
||||||
batch_id=bp.batch_id,
|
batch_id=bp.batch_id,
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
@ -71,17 +72,29 @@ class Batch(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
class BatchSession(BaseModel):
|
|
||||||
batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
|
|
||||||
session_id: str = Field(description="The Session to which this BatchSession is attached.")
|
|
||||||
state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
|
|
||||||
|
|
||||||
|
|
||||||
def uuid_string():
|
def uuid_string():
|
||||||
res = uuid.uuid4()
|
res = uuid.uuid4()
|
||||||
return str(res)
|
return str(res)
|
||||||
|
|
||||||
|
|
||||||
|
BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSession(BaseModel):
|
||||||
|
batch_id: str = Field(defaultdescription="The Batch to which this BatchSession is attached.")
|
||||||
|
session_id: str = Field(
|
||||||
|
default_factory=uuid_string, description="The Session to which this BatchSession is attached."
|
||||||
|
)
|
||||||
|
batch_index: list[int] = Field(description="The index of the batch to be run in this BatchSession.")
|
||||||
|
state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
|
||||||
|
|
||||||
|
@validator("batch_index", pre=True)
|
||||||
|
def parse_str_to_list(cls, v: Union[str, list[int]]):
|
||||||
|
if isinstance(v, str):
|
||||||
|
return json.loads(v)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class BatchProcess(BaseModel):
|
class BatchProcess(BaseModel):
|
||||||
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
|
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
|
||||||
batch: Batch = Field(description="The Batch to apply to this session.")
|
batch: Batch = Field(description="The Batch to apply to this session.")
|
||||||
@ -90,7 +103,7 @@ class BatchProcess(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
||||||
state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
|
state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession")
|
||||||
|
|
||||||
|
|
||||||
class BatchProcessNotFoundException(Exception):
|
class BatchProcessNotFoundException(Exception):
|
||||||
@ -198,23 +211,31 @@ class BatchProcessStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_session(self, session_id: str) -> BatchSession:
|
def create_sessions(
|
||||||
|
self,
|
||||||
|
sessions: list[BatchSession],
|
||||||
|
) -> list[BatchSession]:
|
||||||
|
"""Creates many BatchSessions attached to a BatchProcess."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_session_by_session_id(self, session_id: str) -> BatchSession:
|
||||||
"""Gets a BatchSession by session_id"""
|
"""Gets a BatchSession by session_id"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_sessions(self, batch_id: str) -> List[BatchSession]:
|
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
|
||||||
"""Gets all BatchSession's for a given BatchProcess id."""
|
"""Gets all BatchSession's for a given BatchProcess id."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
|
||||||
"""Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
|
"""Gets all BatchSession's for a given list of session ids."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
def get_next_session(self, batch_id: str) -> BatchSession:
|
||||||
"""Gets all BatchSession's with state `created`, for a given BatchProcess id."""
|
"""Gets the next BatchSession with state `uninitialized`, for a given BatchProcess id."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -295,6 +316,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
CREATE TABLE IF NOT EXISTS batch_session (
|
CREATE TABLE IF NOT EXISTS batch_session (
|
||||||
batch_id TEXT NOT NULL,
|
batch_id TEXT NOT NULL,
|
||||||
session_id TEXT NOT NULL,
|
session_id TEXT NOT NULL,
|
||||||
|
batch_index TEXT NOT NULL,
|
||||||
state TEXT NOT NULL,
|
state TEXT NOT NULL,
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- updated via trigger
|
-- updated via trigger
|
||||||
@ -453,7 +475,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
(
|
(
|
||||||
SELECT batch_id
|
SELECT batch_id
|
||||||
FROM batch_session bs
|
FROM batch_session bs
|
||||||
WHERE state = 'created'
|
WHERE state IN ('uninitialized', 'in_progress')
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
@ -518,10 +540,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
|
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||||
VALUES (?, ?, ?);
|
VALUES (?, ?, ?, ?);
|
||||||
""",
|
""",
|
||||||
(session.batch_id, session.session_id, session.state),
|
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
@ -529,9 +551,34 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
raise BatchSessionSaveException from e
|
raise BatchSessionSaveException from e
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
return self.get_session(session.session_id)
|
return self.get_session_by_session_id(session.session_id)
|
||||||
|
|
||||||
def get_session(self, session_id: str) -> BatchSession:
|
def create_sessions(
|
||||||
|
self,
|
||||||
|
sessions: list[BatchSession],
|
||||||
|
) -> list[BatchSession]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
session_data = [
|
||||||
|
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index))
|
||||||
|
for session in sessions
|
||||||
|
]
|
||||||
|
self._cursor.executemany(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||||
|
VALUES (?, ?, ?, ?);
|
||||||
|
""",
|
||||||
|
session_data,
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get_sessions_by_session_ids([session.session_id for session in sessions])
|
||||||
|
|
||||||
|
def get_session_by_session_id(self, session_id: str) -> BatchSession:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -558,14 +605,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
|
|
||||||
return BatchSession.parse_obj(session_dict)
|
return BatchSession.parse_obj(session_dict)
|
||||||
|
|
||||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
def get_next_session(self, batch_id: str) -> BatchSession:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM batch_session
|
FROM batch_session
|
||||||
WHERE batch_id = ? AND state = 'created';
|
WHERE batch_id = ? AND state = 'uninitialized';
|
||||||
""",
|
""",
|
||||||
(batch_id,),
|
(batch_id,),
|
||||||
)
|
)
|
||||||
@ -581,14 +628,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
session = self._deserialize_batch_session(dict(result))
|
session = self._deserialize_batch_session(dict(result))
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM batch_session
|
FROM batch_session
|
||||||
WHERE batch_id = ? AND state = created;
|
WHERE batch_id = ?;
|
||||||
""",
|
""",
|
||||||
(batch_id,),
|
(batch_id,),
|
||||||
)
|
)
|
||||||
@ -604,16 +651,17 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
def get_sessions(self, batch_id: str) -> List[BatchSession]:
|
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
placeholders = ",".join("?" * len(session_ids))
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
f"""--sql
|
||||||
SELECT *
|
SELECT * FROM batch_session
|
||||||
FROM batch_session
|
WHERE session_id
|
||||||
WHERE batch_id = ?;
|
IN ({placeholders})
|
||||||
""",
|
""",
|
||||||
(batch_id,),
|
tuple(session_ids),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
@ -652,4 +700,4 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
raise BatchSessionSaveException from e
|
raise BatchSessionSaveException from e
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
return self.get_session(session_id)
|
return self.get_session_by_session_id(session_id)
|
||||||
|
Loading…
Reference in New Issue
Block a user