mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(batches): defer ges *and* batch session creation until execution time
This commit is contained in:
parent
88ae19a768
commit
1652143671
@ -1,22 +1,24 @@
|
|||||||
import networkx as nx
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from pydantic import BaseModel, Field
|
from typing import Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import networkx as nx
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.typing import Event
|
from fastapi_events.typing import Event
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.services.batch_manager_storage import (
|
||||||
|
Batch,
|
||||||
|
BatchProcess,
|
||||||
|
BatchProcessStorageBase,
|
||||||
|
BatchSession,
|
||||||
|
BatchSessionChanges,
|
||||||
|
BatchSessionNotFoundException,
|
||||||
|
)
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.batch_manager_storage import (
|
|
||||||
BatchProcessStorageBase,
|
|
||||||
BatchSessionNotFoundException,
|
|
||||||
Batch,
|
|
||||||
BatchProcess,
|
|
||||||
BatchSession,
|
|
||||||
BatchSessionChanges,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchProcessResponse(BaseModel):
|
class BatchProcessResponse(BaseModel):
|
||||||
@ -107,7 +109,9 @@ class BatchManager(BatchManagerBase):
|
|||||||
if not batch_process.canceled:
|
if not batch_process.canceled:
|
||||||
self.run_batch_process(batch_process.batch_id)
|
self.run_batch_process(batch_process.batch_id)
|
||||||
|
|
||||||
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: tuple[int]) -> GraphExecutionState:
|
def _create_graph_execution_state(
|
||||||
|
self, batch_process: BatchProcess, batch_indices: tuple[int, ...]
|
||||||
|
) -> GraphExecutionState:
|
||||||
graph = batch_process.graph.copy(deep=True)
|
graph = batch_process.graph.copy(deep=True)
|
||||||
batch = batch_process.batch
|
batch = batch_process.batch
|
||||||
g = graph.nx_graph_flat()
|
g = graph.nx_graph_flat()
|
||||||
@ -129,12 +133,31 @@ class BatchManager(BatchManagerBase):
|
|||||||
|
|
||||||
def run_batch_process(self, batch_id: str) -> None:
|
def run_batch_process(self, batch_id: str) -> None:
|
||||||
self.__batch_process_storage.start(batch_id)
|
self.__batch_process_storage.start(batch_id)
|
||||||
try:
|
|
||||||
next_session = self.__batch_process_storage.get_next_session(batch_id)
|
|
||||||
except BatchSessionNotFoundException:
|
|
||||||
return
|
|
||||||
batch_process = self.__batch_process_storage.get(batch_id)
|
batch_process = self.__batch_process_storage.get(batch_id)
|
||||||
ges = self._create_batch_session(batch_process=batch_process, batch_indices=tuple(next_session.batch_index))
|
next_batch_index = self._get_batch_index_tuple(batch_process)
|
||||||
|
if next_batch_index is None:
|
||||||
|
# finished with current run
|
||||||
|
if batch_process.current_run >= (batch_process.batch.runs - 1):
|
||||||
|
# finished with all runs
|
||||||
|
return
|
||||||
|
batch_process.current_batch_index = 0
|
||||||
|
batch_process.current_run += 1
|
||||||
|
next_batch_index = self._get_batch_index_tuple(batch_process)
|
||||||
|
if next_batch_index is None:
|
||||||
|
# shouldn't happen; satisfy types
|
||||||
|
return
|
||||||
|
# remember to increment the batch index
|
||||||
|
batch_process.current_batch_index += 1
|
||||||
|
self.__batch_process_storage.save(batch_process)
|
||||||
|
ges = self._create_graph_execution_state(batch_process=batch_process, batch_indices=next_batch_index)
|
||||||
|
next_session = self.__batch_process_storage.create_session(
|
||||||
|
BatchSession(
|
||||||
|
batch_id=batch_id,
|
||||||
|
session_id=str(uuid4()),
|
||||||
|
state="uninitialized",
|
||||||
|
batch_index=batch_process.current_batch_index,
|
||||||
|
)
|
||||||
|
)
|
||||||
ges.id = next_session.session_id
|
ges.id = next_session.session_id
|
||||||
self.__invoker.services.graph_execution_manager.set(ges)
|
self.__invoker.services.graph_execution_manager.set(ges)
|
||||||
self.__batch_process_storage.update_session_state(
|
self.__batch_process_storage.update_session_state(
|
||||||
@ -150,26 +173,11 @@ class BatchManager(BatchManagerBase):
|
|||||||
graph=graph,
|
graph=graph,
|
||||||
)
|
)
|
||||||
batch_process = self.__batch_process_storage.save(batch_process)
|
batch_process = self.__batch_process_storage.save(batch_process)
|
||||||
sessions = self._create_sessions(batch_process)
|
|
||||||
return BatchProcessResponse(
|
return BatchProcessResponse(
|
||||||
batch_id=batch_process.batch_id,
|
batch_id=batch_process.batch_id,
|
||||||
session_ids=[session.session_id for session in sessions],
|
session_ids=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
|
||||||
batch_indices = list()
|
|
||||||
sessions_to_create: list[BatchSession] = list()
|
|
||||||
for batchdata in batch_process.batch.data:
|
|
||||||
batch_indices.append(list(range(len(batchdata[0].items))))
|
|
||||||
all_batch_indices = product(*batch_indices)
|
|
||||||
for bi in all_batch_indices:
|
|
||||||
for _ in range(batch_process.batch.runs):
|
|
||||||
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
|
|
||||||
if not sessions_to_create:
|
|
||||||
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
|
|
||||||
created_sessions = self.__batch_process_storage.create_sessions(sessions_to_create)
|
|
||||||
return created_sessions
|
|
||||||
|
|
||||||
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||||
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
|
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
|
||||||
|
|
||||||
@ -204,3 +212,12 @@ class BatchManager(BatchManagerBase):
|
|||||||
|
|
||||||
def cancel_batch_process(self, batch_process_id: str) -> None:
|
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||||
self.__batch_process_storage.cancel(batch_process_id)
|
self.__batch_process_storage.cancel(batch_process_id)
|
||||||
|
|
||||||
|
def _get_batch_index_tuple(self, batch_process: BatchProcess) -> Optional[tuple[int, ...]]:
|
||||||
|
batch_indices = list()
|
||||||
|
for batchdata in batch_process.batch.data:
|
||||||
|
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||||
|
try:
|
||||||
|
return list(product(*batch_indices))[batch_process.current_batch_index]
|
||||||
|
except IndexError:
|
||||||
|
return None
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
@ -85,19 +84,15 @@ class BatchSession(BaseModel):
|
|||||||
session_id: str = Field(
|
session_id: str = Field(
|
||||||
default_factory=uuid_string, description="The Session to which this BatchSession is attached."
|
default_factory=uuid_string, description="The Session to which this BatchSession is attached."
|
||||||
)
|
)
|
||||||
batch_index: list[int] = Field(description="The index of the batch to be run in this BatchSession.")
|
batch_index: int = Field(description="The index of this batch session in its parent batch process")
|
||||||
state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
|
state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
|
||||||
|
|
||||||
@validator("batch_index", pre=True)
|
|
||||||
def parse_str_to_list(cls, v: Union[str, list[int]]):
|
|
||||||
if isinstance(v, str):
|
|
||||||
return json.loads(v)
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
class BatchProcess(BaseModel):
|
class BatchProcess(BaseModel):
|
||||||
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
|
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
|
||||||
batch: Batch = Field(description="The Batch to apply to this session.")
|
batch: Batch = Field(description="The Batch to apply to this session.")
|
||||||
|
current_batch_index: int = Field(default=0, description="The last executed batch index")
|
||||||
|
current_run: int = Field(default=0, description="The current run of the batch")
|
||||||
canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
|
canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
|
||||||
graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")
|
graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")
|
||||||
|
|
||||||
@ -279,8 +274,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
"""--sql
|
"""--sql
|
||||||
CREATE TABLE IF NOT EXISTS batch_process (
|
CREATE TABLE IF NOT EXISTS batch_process (
|
||||||
batch_id TEXT NOT NULL PRIMARY KEY,
|
batch_id TEXT NOT NULL PRIMARY KEY,
|
||||||
batches TEXT NOT NULL,
|
batch TEXT NOT NULL,
|
||||||
graph TEXT NOT NULL,
|
graph TEXT NOT NULL,
|
||||||
|
current_batch_index NUMBER NOT NULL,
|
||||||
|
current_run NUMBER NOT NULL,
|
||||||
canceled BOOLEAN NOT NULL DEFAULT(0),
|
canceled BOOLEAN NOT NULL DEFAULT(0),
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- Updated via trigger
|
-- Updated via trigger
|
||||||
@ -316,8 +313,8 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
CREATE TABLE IF NOT EXISTS batch_session (
|
CREATE TABLE IF NOT EXISTS batch_session (
|
||||||
batch_id TEXT NOT NULL,
|
batch_id TEXT NOT NULL,
|
||||||
session_id TEXT NOT NULL,
|
session_id TEXT NOT NULL,
|
||||||
batch_index TEXT NOT NULL,
|
|
||||||
state TEXT NOT NULL,
|
state TEXT NOT NULL,
|
||||||
|
batch_index NUMBER NOT NULL,
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- updated via trigger
|
-- updated via trigger
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
@ -386,10 +383,16 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
|
INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run)
|
||||||
VALUES (?, ?, ?);
|
VALUES (?, ?, ?, ?, ?);
|
||||||
""",
|
""",
|
||||||
(batch_process.batch_id, batch_process.batch.json(), batch_process.graph.json()),
|
(
|
||||||
|
batch_process.batch_id,
|
||||||
|
batch_process.batch.json(),
|
||||||
|
batch_process.graph.json(),
|
||||||
|
batch_process.current_batch_index,
|
||||||
|
batch_process.current_run,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
@ -405,13 +408,17 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
batch_id = session_dict.get("batch_id", "unknown")
|
batch_id = session_dict.get("batch_id", "unknown")
|
||||||
batch_raw = session_dict.get("batches", "unknown")
|
batch_raw = session_dict.get("batch", "unknown")
|
||||||
graph_raw = session_dict.get("graph", "unknown")
|
graph_raw = session_dict.get("graph", "unknown")
|
||||||
|
current_batch_index = session_dict.get("current_batch_index", 0)
|
||||||
|
current_run = session_dict.get("current_run", 0)
|
||||||
canceled = session_dict.get("canceled", 0)
|
canceled = session_dict.get("canceled", 0)
|
||||||
return BatchProcess(
|
return BatchProcess(
|
||||||
batch_id=batch_id,
|
batch_id=batch_id,
|
||||||
batch=parse_raw_as(Batch, batch_raw),
|
batch=parse_raw_as(Batch, batch_raw),
|
||||||
graph=parse_raw_as(Graph, graph_raw),
|
graph=parse_raw_as(Graph, graph_raw),
|
||||||
|
current_batch_index=current_batch_index,
|
||||||
|
current_run=current_run,
|
||||||
canceled=canceled == 1,
|
canceled=canceled == 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -543,7 +550,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||||
VALUES (?, ?, ?, ?);
|
VALUES (?, ?, ?, ?);
|
||||||
""",
|
""",
|
||||||
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)),
|
(session.batch_id, session.session_id, session.state, session.batch_index),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
@ -559,10 +566,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
) -> list[BatchSession]:
|
) -> list[BatchSession]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
session_data = [
|
session_data = [(session.batch_id, session.session_id, session.state) for session in sessions]
|
||||||
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index))
|
|
||||||
for session in sessions
|
|
||||||
]
|
|
||||||
self._cursor.executemany(
|
self._cursor.executemany(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||||
|
Loading…
Reference in New Issue
Block a user