Testing out Spencer's batch data structure

Brandon Rising 2023-08-16 15:21:11 -04:00
parent d6a5c2dbe3
commit 796ff34c8a
4 changed files with 59 additions and 32 deletions

View File

@@ -48,10 +48,10 @@ async def create_session(
)
async def create_batch(
    graph: Optional[Graph] = Body(description="The graph to initialize the session with"),
-    batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
+    batch: Batch = Body(description="Batch config to apply to the given graph"),
) -> BatchProcessResponse:
    """Creates and starts a new batch process"""
-    batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
+    batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
    return batch_process_res
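For illustration, a hedged sketch of what a request body to this endpoint could look like under the new single-batch schema. The node ids, field names, and values below are assumptions, not taken from this commit; only the Batch/BatchData shape comes from the models further down:

import json

# Hypothetical request body for the new create_batch route; node ids,
# field names, and values are illustrative assumptions.
body = {
    "graph": {},  # graph definition elided
    "batch": {
        "data": [
            # one inner list per zipped group: these items advance in lockstep
            [
                {"node_id": "noise", "field_name": "seed", "items": [1, 2, 3]},
                {"node_id": "compel", "field_name": "prompt", "items": ["a", "b", "c"]},
            ],
            # separate groups combine as a Cartesian product (3 * 2 = 6 sessions)
            [{"node_id": "cfg", "field_name": "scale", "items": [7.0, 7.5]}],
        ]
    },
}
print(json.dumps(body, indent=2))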

View File

@@ -31,7 +31,7 @@ class BatchManagerBase(ABC):
        pass

    @abstractmethod
-    def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
+    def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
        pass

    @abstractmethod
@@ -85,17 +85,20 @@ class BatchManager(BatchManagerBase):
    def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
        graph = batch_process.graph.copy(deep=True)
-        batches = batch_process.batches
+        batch = batch_process.batch
        g = graph.nx_graph_flat()
        sorted_nodes = nx.topological_sort(g)
        for npath in sorted_nodes:
            node = graph.get_node(npath)
-            (index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
-            if batch:
-                batch_index = batch_indices[index]
-                datum = batch.data[batch_index]
-                for key in datum:
-                    node.__dict__[key] = datum[key]
+            for index, bdl in enumerate(batch.data):
+                relevant_bd = [bd for bd in bdl if bd.node_id in node.id]
+                if not relevant_bd:
+                    continue
+                for bd in relevant_bd:
+                    batch_index = batch_indices[index]
+                    datum = bd.items[batch_index]
+                    key = bd.field_name
+                    node.__dict__[key] = datum
            graph.update_node(npath, node)
        return GraphExecutionState(graph=graph)
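To make the node-mutation step above concrete, here is a minimal runnable sketch with simplified stand-ins for BatchData and a graph node (ids and values are made up; the real code mutates node.__dict__ on a deep-copied graph):

from dataclasses import dataclass

@dataclass
class BatchData:
    node_id: str
    field_name: str
    items: list

@dataclass
class NoiseNode:
    id: str = "noise"
    seed: int = 0

batch_data = [                                      # batch.data
    [BatchData("noise", "seed", [111, 222, 333])],  # zipped group 0
    [BatchData("compel", "prompt", ["a", "b"])],    # zipped group 1
]
batch_indices = (2, 0)  # one index per zipped group, as produced by _create_sessions

node = NoiseNode()
for index, bdl in enumerate(batch_data):
    for bd in (b for b in bdl if b.node_id in node.id):
        # pick this group's item at the session's index and set the field
        setattr(node, bd.field_name, bd.items[batch_indices[index]])

assert node.seed == 333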
@@ -110,13 +113,13 @@ class BatchManager(BatchManagerBase):
        self.__invoker.invoke(ges, invoke_all=True)

    def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
-        # TODO: Check that the node_ids in the batches are unique
+        # TODO: Check that the node_ids in the batch are unique
        # TODO: Validate data types are correct for each batch data
        return True

-    def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
+    def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
        batch_process = BatchProcess(
-            batches=batches,
+            batch=batch,
            graph=graph,
        )
        if not self._valid_batch_config(batch_process):
@@ -131,8 +134,8 @@ class BatchManager(BatchManagerBase):
    def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
        batch_indices = list()
        sessions = list()
-        for batch in batch_process.batches:
-            batch_indices.append(list(range(len(batch.data))))
+        for batchdata in batch_process.batch.data:
+            batch_indices.append(list(range(len(batchdata[0].items))))
        all_batch_indices = product(*batch_indices)
        for bi in all_batch_indices:
            ges = self._create_batch_session(batch_process, bi)
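The fan-out above is just a Cartesian product of one index range per zipped group. A small sketch (the group lengths are made up):

from itertools import product

group_lengths = [3, 2]  # len(batchdata[0].items) for each zipped group
batch_indices = [list(range(n)) for n in group_lengths]

all_batch_indices = list(product(*batch_indices))
print(len(all_batch_indices))  # 6 sessions: (0, 0), (0, 1), (1, 0), ..., (2, 1)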

View File

@@ -4,13 +4,11 @@ import uuid
import sqlite3
import threading
from typing import (
-    Any,
-    List,
    Literal,
    Optional,
    Union,
)
import json
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@@ -18,7 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.services.graph import Graph
from invokeai.app.invocations.primitives import ImageField
-from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat
+from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat, validator

invocations = BaseInvocation.get_invocations()
InvocationsUnion = Union[invocations]  # type: ignore
@@ -26,9 +24,42 @@ InvocationsUnion = Union[invocations]  # type: ignore
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]

+class BatchData(BaseModel):
+    node_id: str
+    field_name: str
+    items: list[BatchDataType]

class Batch(BaseModel):
-    data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
-    node_id: str = Field(description="ID of the node to batch")
+    data: list[list[BatchData]]

+    @validator("data")
+    def validate_len(cls, v: list[list[BatchData]]):
+        for batch_data in v:
+            if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
+                raise ValueError("Zipped batch items must all have the same length")
+        return v

+    @validator("data")
+    def validate_types(cls, v: list[list[BatchData]]):
+        for batch_data in v:
+            for datum in batch_data:
+                for item in datum.items:
+                    if not all(isinstance(item, type(i)) for i in datum.items):
+                        raise TypeError("All items in a batch must have the same type")
+        return v

+    @validator("data")
+    def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
+        paths: set[tuple[str, str]] = set()
+        count: int = 0
+        for batch_data in v:
+            for datum in batch_data:
+                paths.add((datum.node_id, datum.field_name))
+                count += 1
+        if len(paths) != count:
+            raise ValueError("Each batch data must have a unique node_id and field_name pair")
+        return v
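As a standalone illustration of what the first validator rejects, here is a hedged re-implementation with pydantic v1 and string-only items (node ids and field names are made up):

from pydantic import BaseModel, ValidationError, validator

class BatchData(BaseModel):
    node_id: str
    field_name: str
    items: list[str]

class Batch(BaseModel):
    data: list[list[BatchData]]

    @validator("data")
    def validate_len(cls, v):
        for batch_data in v:
            if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
                raise ValueError("Zipped batch items must all have the same length")
        return v

try:
    Batch(data=[[
        BatchData(node_id="noise", field_name="seed", items=["1", "2"]),
        BatchData(node_id="compel", field_name="prompt", items=["only one"]),
    ]])
except ValidationError as err:
    print(err)  # the zipped group mixes lengths 2 and 1, so it is rejected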
class BatchSession(BaseModel):
@@ -46,10 +77,7 @@ def uuid_string():
class BatchProcess(BaseModel):
    batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
-    batches: List[Batch] = Field(
-        description="List of batch configs to apply to this session",
-        default_factory=list,
-    )
+    batch: Batch = Field(description="Batch config to apply to this session")
    canceled: bool = Field(description="Whether or not to run sessions from this batch", default=False)
    graph: Graph = Field(description="The graph being executed")
@@ -310,13 +338,12 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
    ) -> BatchProcess:
        try:
            self._lock.acquire()
-            batches = [batch.json() for batch in batch_process.batches]
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
                VALUES (?, ?, ?);
                """,
-                (batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
+                (batch_process.batch_id, batch_process.batch.json(), batch_process.graph.json()),
            )
            self._conn.commit()
        except sqlite3.Error as e:
@@ -332,13 +359,11 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
        # Retrieve all the values, setting "reasonable" defaults if they are not present.
        batch_id = session_dict.get("batch_id", "unknown")
-        batches_raw = session_dict.get("batches", "unknown")
+        batch_raw = session_dict.get("batches", "unknown")
        graph_raw = session_dict.get("graph", "unknown")
        canceled = session_dict.get("canceled", 0)
-        batches = json.loads(batches_raw)
-        batches = [parse_raw_as(Batch, batch) for batch in batches]
        return BatchProcess(
-            batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
+            batch_id=batch_id, batch=parse_raw_as(Batch, batch_raw), graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
        )
def get(
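To make the new storage format concrete, a minimal round-trip sketch: the whole Batch is serialized to one JSON string (still stored in the existing "batches" column) and re-hydrated with pydantic v1's parse_raw_as. The model below is a simplified stand-in for the real one:

from pydantic import BaseModel, parse_raw_as

class BatchData(BaseModel):
    node_id: str
    field_name: str
    items: list[int]

class Batch(BaseModel):
    data: list[list[BatchData]]

original = Batch(data=[[BatchData(node_id="noise", field_name="seed", items=[1, 2])]])
raw = original.json()               # the single string written to the batches column
restored = parse_raw_as(Batch, raw)
assert restored == original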

View File

@@ -45,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def _parse_item(self, item: str) -> T:
        item_type = get_args(self.__orig_class__)[0]
-        parsed = parse_raw_as(item_type, item)
-        return parsed
+        return parse_raw_as(item_type, item)

    def set(self, item: T):
        try: