mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(batches): defer ges creation until execution
This improves the overall responsiveness of the system substantially, but does make each iteration *slightly* slower, distributing the up-front cost across the batch. Two main changes: 1. Create BatchSessions immediately, but do not create a whole graph execution state until the batch is executed. BatchSessions are created with a `session_id` that does not exist in sessions database. The default state is changed to `"uninitialized"` to better represent this. Results: Time to create 5000 batches reduced from over 30s to 2.5s 2. Use `executemany()` to retrieve lists of created sessions. Results: time to create 5000 batches reduced from 2.5s to under 0.5s Other changes: - set BatchSession state to `"in_progress"` just before `invoke()` is called - rename a few methods to accomodate the new behaviour - remove unused `BatchProcessStorage.get_created_sessions()` method
This commit is contained in:
parent
50816432dc
commit
88ae19a768
@ -94,7 +94,7 @@ class BatchManager(BatchManagerBase):
|
||||
async def _process(self, event: Event, err: bool) -> None:
|
||||
data = event[1]["data"]
|
||||
try:
|
||||
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
|
||||
batch_session = self.__batch_process_storage.get_session_by_session_id(data["graph_execution_state_id"])
|
||||
except BatchSessionNotFoundException:
|
||||
return None
|
||||
changes = BatchSessionChanges(state="error" if err else "completed")
|
||||
@ -130,10 +130,18 @@ class BatchManager(BatchManagerBase):
|
||||
def run_batch_process(self, batch_id: str) -> None:
|
||||
self.__batch_process_storage.start(batch_id)
|
||||
try:
|
||||
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
||||
next_session = self.__batch_process_storage.get_next_session(batch_id)
|
||||
except BatchSessionNotFoundException:
|
||||
return
|
||||
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
|
||||
batch_process = self.__batch_process_storage.get(batch_id)
|
||||
ges = self._create_batch_session(batch_process=batch_process, batch_indices=tuple(next_session.batch_index))
|
||||
ges.id = next_session.session_id
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
self.__batch_process_storage.update_session_state(
|
||||
batch_id=next_session.batch_id,
|
||||
session_id=next_session.session_id,
|
||||
changes=BatchSessionChanges(state="in_progress"),
|
||||
)
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
|
||||
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||
@ -150,25 +158,20 @@ class BatchManager(BatchManagerBase):
|
||||
|
||||
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
||||
batch_indices = list()
|
||||
sessions = list()
|
||||
sessions_to_create: list[BatchSession] = list()
|
||||
for batchdata in batch_process.batch.data:
|
||||
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||
all_batch_indices = product(*batch_indices)
|
||||
for bi in all_batch_indices:
|
||||
for _ in range(batch_process.batch.runs):
|
||||
ges = self._create_batch_session(batch_process, bi)
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
|
||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||
if not sessions:
|
||||
ges = GraphExecutionState(graph=batch_process.graph)
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
|
||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||
return sessions
|
||||
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
|
||||
if not sessions_to_create:
|
||||
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
|
||||
created_sessions = self.__batch_process_storage.create_sessions(sessions_to_create)
|
||||
return created_sessions
|
||||
|
||||
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||
return self.__batch_process_storage.get_sessions(batch_id)
|
||||
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
|
||||
|
||||
def get_batch(self, batch_id: str) -> BatchProcess:
|
||||
return self.__batch_process_storage.get(batch_id)
|
||||
@ -177,7 +180,7 @@ class BatchManager(BatchManagerBase):
|
||||
bps = self.__batch_process_storage.get_all()
|
||||
res = list()
|
||||
for bp in bps:
|
||||
sessions = self.__batch_process_storage.get_sessions(bp.batch_id)
|
||||
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
|
||||
res.append(
|
||||
BatchProcessResponse(
|
||||
batch_id=bp.batch_id,
|
||||
@ -190,7 +193,7 @@ class BatchManager(BatchManagerBase):
|
||||
bps = self.__batch_process_storage.get_incomplete()
|
||||
res = list()
|
||||
for bp in bps:
|
||||
sessions = self.__batch_process_storage.get_sessions(bp.batch_id)
|
||||
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
|
||||
res.append(
|
||||
BatchProcessResponse(
|
||||
batch_id=bp.batch_id,
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
import uuid
|
||||
@ -71,17 +72,29 @@ class Batch(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
class BatchSession(BaseModel):
|
||||
batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
|
||||
session_id: str = Field(description="The Session to which this BatchSession is attached.")
|
||||
state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
|
||||
|
||||
|
||||
def uuid_string():
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
|
||||
BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]
|
||||
|
||||
|
||||
class BatchSession(BaseModel):
|
||||
batch_id: str = Field(defaultdescription="The Batch to which this BatchSession is attached.")
|
||||
session_id: str = Field(
|
||||
default_factory=uuid_string, description="The Session to which this BatchSession is attached."
|
||||
)
|
||||
batch_index: list[int] = Field(description="The index of the batch to be run in this BatchSession.")
|
||||
state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
|
||||
|
||||
@validator("batch_index", pre=True)
|
||||
def parse_str_to_list(cls, v: Union[str, list[int]]):
|
||||
if isinstance(v, str):
|
||||
return json.loads(v)
|
||||
return v
|
||||
|
||||
|
||||
class BatchProcess(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
|
||||
batch: Batch = Field(description="The Batch to apply to this session.")
|
||||
@ -90,7 +103,7 @@ class BatchProcess(BaseModel):
|
||||
|
||||
|
||||
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
||||
state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
|
||||
state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession")
|
||||
|
||||
|
||||
class BatchProcessNotFoundException(Exception):
|
||||
@ -198,23 +211,31 @@ class BatchProcessStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_session(self, session_id: str) -> BatchSession:
|
||||
def create_sessions(
|
||||
self,
|
||||
sessions: list[BatchSession],
|
||||
) -> list[BatchSession]:
|
||||
"""Creates many BatchSessions attached to a BatchProcess."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_session_by_session_id(self, session_id: str) -> BatchSession:
|
||||
"""Gets a BatchSession by session_id"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
|
||||
"""Gets all BatchSession's for a given BatchProcess id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||
"""Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
|
||||
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
|
||||
"""Gets all BatchSession's for a given list of session ids."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||
"""Gets all BatchSession's with state `created`, for a given BatchProcess id."""
|
||||
def get_next_session(self, batch_id: str) -> BatchSession:
|
||||
"""Gets the next BatchSession with state `uninitialized`, for a given BatchProcess id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -295,6 +316,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
CREATE TABLE IF NOT EXISTS batch_session (
|
||||
batch_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
batch_index TEXT NOT NULL,
|
||||
state TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
@ -453,7 +475,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
(
|
||||
SELECT batch_id
|
||||
FROM batch_session bs
|
||||
WHERE state = 'created'
|
||||
WHERE state IN ('uninitialized', 'in_progress')
|
||||
);
|
||||
"""
|
||||
)
|
||||
@ -518,10 +540,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
|
||||
VALUES (?, ?, ?);
|
||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(session.batch_id, session.session_id, session.state),
|
||||
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
@ -529,9 +551,34 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_session(session.session_id)
|
||||
return self.get_session_by_session_id(session.session_id)
|
||||
|
||||
def get_session(self, session_id: str) -> BatchSession:
|
||||
def create_sessions(
|
||||
self,
|
||||
sessions: list[BatchSession],
|
||||
) -> list[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
session_data = [
|
||||
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index))
|
||||
for session in sessions
|
||||
]
|
||||
self._cursor.executemany(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
session_data,
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_sessions_by_session_ids([session.session_id for session in sessions])
|
||||
|
||||
def get_session_by_session_id(self, session_id: str) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -558,14 +605,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
|
||||
return BatchSession.parse_obj(session_dict)
|
||||
|
||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||
def get_next_session(self, batch_id: str) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE batch_id = ? AND state = 'created';
|
||||
WHERE batch_id = ? AND state = 'uninitialized';
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
@ -581,14 +628,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
session = self._deserialize_batch_session(dict(result))
|
||||
return session
|
||||
|
||||
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE batch_id = ? AND state = created;
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
@ -604,16 +651,17 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||
return sessions
|
||||
|
||||
def get_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
placeholders = ",".join("?" * len(session_ids))
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE batch_id = ?;
|
||||
f"""--sql
|
||||
SELECT * FROM batch_session
|
||||
WHERE session_id
|
||||
IN ({placeholders})
|
||||
""",
|
||||
(batch_id,),
|
||||
tuple(session_ids),
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
@ -652,4 +700,4 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_session(session_id)
|
||||
return self.get_session_by_session_id(session_id)
|
||||
|
Loading…
Reference in New Issue
Block a user