feat(batches): defer ges *and* batch session creation until execution time

This commit is contained in:
psychedelicious 2023-08-22 00:54:17 +10:00
parent 88ae19a768
commit 1652143671
2 changed files with 73 additions and 52 deletions

View File

@ -1,22 +1,24 @@
import networkx as nx
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import product from itertools import product
from pydantic import BaseModel, Field from typing import Optional
from uuid import uuid4
import networkx as nx
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event from fastapi_events.typing import Event
from pydantic import BaseModel, Field
from invokeai.app.services.batch_manager_storage import (
Batch,
BatchProcess,
BatchProcessStorageBase,
BatchSession,
BatchSessionChanges,
BatchSessionNotFoundException,
)
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph, GraphExecutionState from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.batch_manager_storage import (
BatchProcessStorageBase,
BatchSessionNotFoundException,
Batch,
BatchProcess,
BatchSession,
BatchSessionChanges,
)
class BatchProcessResponse(BaseModel): class BatchProcessResponse(BaseModel):
@ -107,7 +109,9 @@ class BatchManager(BatchManagerBase):
if not batch_process.canceled: if not batch_process.canceled:
self.run_batch_process(batch_process.batch_id) self.run_batch_process(batch_process.batch_id)
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: tuple[int]) -> GraphExecutionState: def _create_graph_execution_state(
self, batch_process: BatchProcess, batch_indices: tuple[int, ...]
) -> GraphExecutionState:
graph = batch_process.graph.copy(deep=True) graph = batch_process.graph.copy(deep=True)
batch = batch_process.batch batch = batch_process.batch
g = graph.nx_graph_flat() g = graph.nx_graph_flat()
@ -129,12 +133,31 @@ class BatchManager(BatchManagerBase):
def run_batch_process(self, batch_id: str) -> None: def run_batch_process(self, batch_id: str) -> None:
self.__batch_process_storage.start(batch_id) self.__batch_process_storage.start(batch_id)
try:
next_session = self.__batch_process_storage.get_next_session(batch_id)
except BatchSessionNotFoundException:
return
batch_process = self.__batch_process_storage.get(batch_id) batch_process = self.__batch_process_storage.get(batch_id)
ges = self._create_batch_session(batch_process=batch_process, batch_indices=tuple(next_session.batch_index)) next_batch_index = self._get_batch_index_tuple(batch_process)
if next_batch_index is None:
# finished with current run
if batch_process.current_run >= (batch_process.batch.runs - 1):
# finished with all runs
return
batch_process.current_batch_index = 0
batch_process.current_run += 1
next_batch_index = self._get_batch_index_tuple(batch_process)
if next_batch_index is None:
# shouldn't happen; satisfy types
return
# remember to increment the batch index
batch_process.current_batch_index += 1
self.__batch_process_storage.save(batch_process)
ges = self._create_graph_execution_state(batch_process=batch_process, batch_indices=next_batch_index)
next_session = self.__batch_process_storage.create_session(
BatchSession(
batch_id=batch_id,
session_id=str(uuid4()),
state="uninitialized",
batch_index=batch_process.current_batch_index,
)
)
ges.id = next_session.session_id ges.id = next_session.session_id
self.__invoker.services.graph_execution_manager.set(ges) self.__invoker.services.graph_execution_manager.set(ges)
self.__batch_process_storage.update_session_state( self.__batch_process_storage.update_session_state(
@ -150,26 +173,11 @@ class BatchManager(BatchManagerBase):
graph=graph, graph=graph,
) )
batch_process = self.__batch_process_storage.save(batch_process) batch_process = self.__batch_process_storage.save(batch_process)
sessions = self._create_sessions(batch_process)
return BatchProcessResponse( return BatchProcessResponse(
batch_id=batch_process.batch_id, batch_id=batch_process.batch_id,
session_ids=[session.session_id for session in sessions], session_ids=[],
) )
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
    """Build and persist the BatchSessions for every batch-index combination.

    For each entry in ``batch_process.batch.data`` the item count of its first
    element is read, and the Cartesian product of those index ranges gives the
    combinations to run. Each combination yields ``batch_process.batch.runs``
    sessions.

    If that produces no sessions although at least one combination exists
    (i.e. ``runs`` is less than one), a single session is still queued for the
    last combination. If there are no combinations at all (some batch datum
    has zero items) nothing is queued; previously this path failed with an
    unbound loop variable.

    Returns:
        Whatever the storage's ``create_sessions`` returns for the built list.
    """
    index_ranges = [range(len(batchdata[0].items)) for batchdata in batch_process.batch.data]
    # Materialised so the fallback below can refer to the last combination
    # without depending on a loop variable that may never have been bound.
    combinations = list(product(*index_ranges))

    sessions_to_create: list[BatchSession] = [
        BatchSession(batch_id=batch_process.batch_id, batch_index=list(combination))
        for combination in combinations
        for _ in range(batch_process.batch.runs)
    ]

    if not sessions_to_create and combinations:
        sessions_to_create.append(
            BatchSession(batch_id=batch_process.batch_id, batch_index=list(combinations[-1]))
        )

    return self.__batch_process_storage.create_sessions(sessions_to_create)
def get_sessions(self, batch_id: str) -> list[BatchSession]: def get_sessions(self, batch_id: str) -> list[BatchSession]:
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id) return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
@ -204,3 +212,12 @@ class BatchManager(BatchManagerBase):
def cancel_batch_process(self, batch_process_id: str) -> None: def cancel_batch_process(self, batch_process_id: str) -> None:
self.__batch_process_storage.cancel(batch_process_id) self.__batch_process_storage.cancel(batch_process_id)
def _get_batch_index_tuple(self, batch_process: BatchProcess) -> Optional[tuple[int, ...]]:
    """Return the item-index tuple for the process's current batch index.

    The size of each dimension is the item count of the first element of the
    corresponding entry in ``batch_process.batch.data``. Combinations are
    numbered in the same order ``itertools.product`` would emit them (the
    last dimension varies fastest), and ``batch_process.current_batch_index``
    selects one of them.

    The tuple is computed directly from the index instead of materialising
    the whole Cartesian product on every call.

    Returns:
        The tuple of per-dimension item indices, or ``None`` when the index is
        outside the range of combinations. That includes negative indices and
        the case where any dimension has no items.
    """
    sizes = [len(batchdata[0].items) for batchdata in batch_process.batch.data]

    remaining = batch_process.current_batch_index
    if remaining < 0:
        # A negative index never names a combination; do not wrap around.
        return None

    # Mixed-radix decomposition, least significant (last) dimension first.
    reversed_positions: list[int] = []
    for size in reversed(sizes):
        if size == 0:
            # An empty dimension means there are no combinations at all.
            return None
        remaining, position = divmod(remaining, size)
        reversed_positions.append(position)

    if remaining:
        # The index is past the final combination.
        return None
    return tuple(reversed(reversed_positions))

View File

@ -1,4 +1,3 @@
import json
import sqlite3 import sqlite3
import threading import threading
import uuid import uuid
@ -85,19 +84,15 @@ class BatchSession(BaseModel):
session_id: str = Field( session_id: str = Field(
default_factory=uuid_string, description="The Session to which this BatchSession is attached." default_factory=uuid_string, description="The Session to which this BatchSession is attached."
) )
batch_index: list[int] = Field(description="The index of the batch to be run in this BatchSession.") batch_index: int = Field(description="The index of this batch session in its parent batch process")
state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession") state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
@validator("batch_index", pre=True)
def parse_str_to_list(cls, v: Union[str, list[int]]):
if isinstance(v, str):
return json.loads(v)
return v
class BatchProcess(BaseModel): class BatchProcess(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.") batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
batch: Batch = Field(description="The Batch to apply to this session.") batch: Batch = Field(description="The Batch to apply to this session.")
current_batch_index: int = Field(default=0, description="The last executed batch index")
current_run: int = Field(default=0, description="The current run of the batch")
canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False) canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.") graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")
@ -279,8 +274,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
"""--sql """--sql
CREATE TABLE IF NOT EXISTS batch_process ( CREATE TABLE IF NOT EXISTS batch_process (
batch_id TEXT NOT NULL PRIMARY KEY, batch_id TEXT NOT NULL PRIMARY KEY,
batches TEXT NOT NULL, batch TEXT NOT NULL,
graph TEXT NOT NULL, graph TEXT NOT NULL,
current_batch_index NUMBER NOT NULL,
current_run NUMBER NOT NULL,
canceled BOOLEAN NOT NULL DEFAULT(0), canceled BOOLEAN NOT NULL DEFAULT(0),
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger -- Updated via trigger
@ -316,8 +313,8 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
CREATE TABLE IF NOT EXISTS batch_session ( CREATE TABLE IF NOT EXISTS batch_session (
batch_id TEXT NOT NULL, batch_id TEXT NOT NULL,
session_id TEXT NOT NULL, session_id TEXT NOT NULL,
batch_index TEXT NOT NULL,
state TEXT NOT NULL, state TEXT NOT NULL,
batch_index NUMBER NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger -- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@ -386,10 +383,16 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph) INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run)
VALUES (?, ?, ?); VALUES (?, ?, ?, ?, ?);
""", """,
(batch_process.batch_id, batch_process.batch.json(), batch_process.graph.json()), (
batch_process.batch_id,
batch_process.batch.json(),
batch_process.graph.json(),
batch_process.current_batch_index,
batch_process.current_run,
),
) )
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
@ -405,13 +408,17 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
# Retrieve all the values, setting "reasonable" defaults if they are not present. # Retrieve all the values, setting "reasonable" defaults if they are not present.
batch_id = session_dict.get("batch_id", "unknown") batch_id = session_dict.get("batch_id", "unknown")
batch_raw = session_dict.get("batches", "unknown") batch_raw = session_dict.get("batch", "unknown")
graph_raw = session_dict.get("graph", "unknown") graph_raw = session_dict.get("graph", "unknown")
current_batch_index = session_dict.get("current_batch_index", 0)
current_run = session_dict.get("current_run", 0)
canceled = session_dict.get("canceled", 0) canceled = session_dict.get("canceled", 0)
return BatchProcess( return BatchProcess(
batch_id=batch_id, batch_id=batch_id,
batch=parse_raw_as(Batch, batch_raw), batch=parse_raw_as(Batch, batch_raw),
graph=parse_raw_as(Graph, graph_raw), graph=parse_raw_as(Graph, graph_raw),
current_batch_index=current_batch_index,
current_run=current_run,
canceled=canceled == 1, canceled=canceled == 1,
) )
@ -543,7 +550,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index) INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
VALUES (?, ?, ?, ?); VALUES (?, ?, ?, ?);
""", """,
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)), (session.batch_id, session.session_id, session.state, session.batch_index),
) )
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
@ -559,10 +566,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
) -> list[BatchSession]: ) -> list[BatchSession]:
try: try:
self._lock.acquire() self._lock.acquire()
session_data = [ session_data = [(session.batch_id, session.session_id, session.state) for session in sessions]
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index))
for session in sessions
]
self._cursor.executemany( self._cursor.executemany(
"""--sql """--sql
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index) INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)