mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(batches): defer ges *and* batch session creation until execution time
This commit is contained in:
parent
88ae19a768
commit
1652143671
@ -1,22 +1,24 @@
|
||||
import networkx as nx
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import product
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import networkx as nx
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.batch_manager_storage import (
|
||||
Batch,
|
||||
BatchProcess,
|
||||
BatchProcessStorageBase,
|
||||
BatchSession,
|
||||
BatchSessionChanges,
|
||||
BatchSessionNotFoundException,
|
||||
)
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.batch_manager_storage import (
|
||||
BatchProcessStorageBase,
|
||||
BatchSessionNotFoundException,
|
||||
Batch,
|
||||
BatchProcess,
|
||||
BatchSession,
|
||||
BatchSessionChanges,
|
||||
)
|
||||
|
||||
|
||||
class BatchProcessResponse(BaseModel):
|
||||
@ -107,7 +109,9 @@ class BatchManager(BatchManagerBase):
|
||||
if not batch_process.canceled:
|
||||
self.run_batch_process(batch_process.batch_id)
|
||||
|
||||
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: tuple[int]) -> GraphExecutionState:
|
||||
def _create_graph_execution_state(
|
||||
self, batch_process: BatchProcess, batch_indices: tuple[int, ...]
|
||||
) -> GraphExecutionState:
|
||||
graph = batch_process.graph.copy(deep=True)
|
||||
batch = batch_process.batch
|
||||
g = graph.nx_graph_flat()
|
||||
@ -129,12 +133,31 @@ class BatchManager(BatchManagerBase):
|
||||
|
||||
def run_batch_process(self, batch_id: str) -> None:
|
||||
self.__batch_process_storage.start(batch_id)
|
||||
try:
|
||||
next_session = self.__batch_process_storage.get_next_session(batch_id)
|
||||
except BatchSessionNotFoundException:
|
||||
return
|
||||
batch_process = self.__batch_process_storage.get(batch_id)
|
||||
ges = self._create_batch_session(batch_process=batch_process, batch_indices=tuple(next_session.batch_index))
|
||||
next_batch_index = self._get_batch_index_tuple(batch_process)
|
||||
if next_batch_index is None:
|
||||
# finished with current run
|
||||
if batch_process.current_run >= (batch_process.batch.runs - 1):
|
||||
# finished with all runs
|
||||
return
|
||||
batch_process.current_batch_index = 0
|
||||
batch_process.current_run += 1
|
||||
next_batch_index = self._get_batch_index_tuple(batch_process)
|
||||
if next_batch_index is None:
|
||||
# shouldn't happen; satisfy types
|
||||
return
|
||||
# remember to increment the batch index
|
||||
batch_process.current_batch_index += 1
|
||||
self.__batch_process_storage.save(batch_process)
|
||||
ges = self._create_graph_execution_state(batch_process=batch_process, batch_indices=next_batch_index)
|
||||
next_session = self.__batch_process_storage.create_session(
|
||||
BatchSession(
|
||||
batch_id=batch_id,
|
||||
session_id=str(uuid4()),
|
||||
state="uninitialized",
|
||||
batch_index=batch_process.current_batch_index,
|
||||
)
|
||||
)
|
||||
ges.id = next_session.session_id
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
self.__batch_process_storage.update_session_state(
|
||||
@ -150,26 +173,11 @@ class BatchManager(BatchManagerBase):
|
||||
graph=graph,
|
||||
)
|
||||
batch_process = self.__batch_process_storage.save(batch_process)
|
||||
sessions = self._create_sessions(batch_process)
|
||||
return BatchProcessResponse(
|
||||
batch_id=batch_process.batch_id,
|
||||
session_ids=[session.session_id for session in sessions],
|
||||
session_ids=[],
|
||||
)
|
||||
|
||||
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
||||
batch_indices = list()
|
||||
sessions_to_create: list[BatchSession] = list()
|
||||
for batchdata in batch_process.batch.data:
|
||||
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||
all_batch_indices = product(*batch_indices)
|
||||
for bi in all_batch_indices:
|
||||
for _ in range(batch_process.batch.runs):
|
||||
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
|
||||
if not sessions_to_create:
|
||||
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
|
||||
created_sessions = self.__batch_process_storage.create_sessions(sessions_to_create)
|
||||
return created_sessions
|
||||
|
||||
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
|
||||
|
||||
@ -204,3 +212,12 @@ class BatchManager(BatchManagerBase):
|
||||
|
||||
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||
self.__batch_process_storage.cancel(batch_process_id)
|
||||
|
||||
def _get_batch_index_tuple(self, batch_process: BatchProcess) -> Optional[tuple[int, ...]]:
|
||||
batch_indices = list()
|
||||
for batchdata in batch_process.batch.data:
|
||||
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||
try:
|
||||
return list(product(*batch_indices))[batch_process.current_batch_index]
|
||||
except IndexError:
|
||||
return None
|
||||
|
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
import uuid
|
||||
@ -85,19 +84,15 @@ class BatchSession(BaseModel):
|
||||
session_id: str = Field(
|
||||
default_factory=uuid_string, description="The Session to which this BatchSession is attached."
|
||||
)
|
||||
batch_index: list[int] = Field(description="The index of the batch to be run in this BatchSession.")
|
||||
batch_index: int = Field(description="The index of this batch session in its parent batch process")
|
||||
state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
|
||||
|
||||
@validator("batch_index", pre=True)
|
||||
def parse_str_to_list(cls, v: Union[str, list[int]]):
|
||||
if isinstance(v, str):
|
||||
return json.loads(v)
|
||||
return v
|
||||
|
||||
|
||||
class BatchProcess(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
|
||||
batch: Batch = Field(description="The Batch to apply to this session.")
|
||||
current_batch_index: int = Field(default=0, description="The last executed batch index")
|
||||
current_run: int = Field(default=0, description="The current run of the batch")
|
||||
canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
|
||||
graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")
|
||||
|
||||
@ -279,8 +274,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS batch_process (
|
||||
batch_id TEXT NOT NULL PRIMARY KEY,
|
||||
batches TEXT NOT NULL,
|
||||
batch TEXT NOT NULL,
|
||||
graph TEXT NOT NULL,
|
||||
current_batch_index NUMBER NOT NULL,
|
||||
current_run NUMBER NOT NULL,
|
||||
canceled BOOLEAN NOT NULL DEFAULT(0),
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
@ -316,8 +313,8 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
CREATE TABLE IF NOT EXISTS batch_session (
|
||||
batch_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
batch_index TEXT NOT NULL,
|
||||
state TEXT NOT NULL,
|
||||
batch_index NUMBER NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
@ -386,10 +383,16 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
|
||||
VALUES (?, ?, ?);
|
||||
INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run)
|
||||
VALUES (?, ?, ?, ?, ?);
|
||||
""",
|
||||
(batch_process.batch_id, batch_process.batch.json(), batch_process.graph.json()),
|
||||
(
|
||||
batch_process.batch_id,
|
||||
batch_process.batch.json(),
|
||||
batch_process.graph.json(),
|
||||
batch_process.current_batch_index,
|
||||
batch_process.current_run,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
@ -405,13 +408,17 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
batch_id = session_dict.get("batch_id", "unknown")
|
||||
batch_raw = session_dict.get("batches", "unknown")
|
||||
batch_raw = session_dict.get("batch", "unknown")
|
||||
graph_raw = session_dict.get("graph", "unknown")
|
||||
current_batch_index = session_dict.get("current_batch_index", 0)
|
||||
current_run = session_dict.get("current_run", 0)
|
||||
canceled = session_dict.get("canceled", 0)
|
||||
return BatchProcess(
|
||||
batch_id=batch_id,
|
||||
batch=parse_raw_as(Batch, batch_raw),
|
||||
graph=parse_raw_as(Graph, graph_raw),
|
||||
current_batch_index=current_batch_index,
|
||||
current_run=current_run,
|
||||
canceled=canceled == 1,
|
||||
)
|
||||
|
||||
@ -543,7 +550,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)),
|
||||
(session.batch_id, session.session_id, session.state, session.batch_index),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
@ -559,10 +566,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
) -> list[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
session_data = [
|
||||
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index))
|
||||
for session in sessions
|
||||
]
|
||||
session_data = [(session.batch_id, session.session_id, session.state) for session in sessions]
|
||||
self._cursor.executemany(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||
|
Loading…
Reference in New Issue
Block a user