mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Testing out Spencer's batch data structure
This commit is contained in:
parent
d6a5c2dbe3
commit
796ff34c8a
@ -48,10 +48,10 @@ async def create_session(
|
||||
)
|
||||
async def create_batch(
|
||||
graph: Optional[Graph] = Body(description="The graph to initialize the session with"),
|
||||
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
||||
batch: Batch = Body(description="Batch config to apply to the given graph"),
|
||||
) -> BatchProcessResponse:
|
||||
"""Creates and starts a new new batch process"""
|
||||
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
||||
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
|
||||
return batch_process_res
|
||||
|
||||
|
||||
|
@ -31,7 +31,7 @@ class BatchManagerBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
|
||||
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -85,17 +85,20 @@ class BatchManager(BatchManagerBase):
|
||||
|
||||
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
||||
graph = batch_process.graph.copy(deep=True)
|
||||
batches = batch_process.batches
|
||||
batch = batch_process.batch
|
||||
g = graph.nx_graph_flat()
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
for npath in sorted_nodes:
|
||||
node = graph.get_node(npath)
|
||||
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
|
||||
if batch:
|
||||
batch_index = batch_indices[index]
|
||||
datum = batch.data[batch_index]
|
||||
for key in datum:
|
||||
node.__dict__[key] = datum[key]
|
||||
for index, bdl in enumerate(batch.data):
|
||||
relavent_bd = [bd for bd in bdl if bd.node_id in node.id]
|
||||
if not relavent_bd:
|
||||
continue
|
||||
for bd in relavent_bd:
|
||||
batch_index = batch_indices[index]
|
||||
datum = bd.items[batch_index]
|
||||
key = bd.field_name
|
||||
node.__dict__[key] = datum
|
||||
graph.update_node(npath, node)
|
||||
|
||||
return GraphExecutionState(graph=graph)
|
||||
@ -110,13 +113,13 @@ class BatchManager(BatchManagerBase):
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
|
||||
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
||||
# TODO: Check that the node_ids in the batches are unique
|
||||
# TODO: Check that the node_ids in the batch are unique
|
||||
# TODO: Validate data types are correct for each batch data
|
||||
return True
|
||||
|
||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
|
||||
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||
batch_process = BatchProcess(
|
||||
batches=batches,
|
||||
batch=batch,
|
||||
graph=graph,
|
||||
)
|
||||
if not self._valid_batch_config(batch_process):
|
||||
@ -131,8 +134,8 @@ class BatchManager(BatchManagerBase):
|
||||
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
||||
batch_indices = list()
|
||||
sessions = list()
|
||||
for batch in batch_process.batches:
|
||||
batch_indices.append(list(range(len(batch.data))))
|
||||
for batchdata in batch_process.batch.data:
|
||||
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||
all_batch_indices = product(*batch_indices)
|
||||
for bi in all_batch_indices:
|
||||
ges = self._create_batch_session(batch_process, bi)
|
||||
|
@ -4,13 +4,11 @@ import uuid
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
import json
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -18,7 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.services.graph import Graph
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
|
||||
from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat
|
||||
from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat, validator
|
||||
|
||||
invocations = BaseInvocation.get_invocations()
|
||||
InvocationsUnion = Union[invocations] # type: ignore
|
||||
@ -26,9 +24,42 @@ InvocationsUnion = Union[invocations] # type: ignore
|
||||
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
|
||||
|
||||
|
||||
class BatchData(BaseModel):
|
||||
node_id: str
|
||||
field_name: str
|
||||
items: list[BatchDataType]
|
||||
|
||||
|
||||
class Batch(BaseModel):
|
||||
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
|
||||
node_id: str = Field(description="ID of the node to batch")
|
||||
data: list[list[BatchData]]
|
||||
|
||||
@validator("data")
|
||||
def validate_len(cls, v: list[list[BatchData]]):
|
||||
for batch_data in v:
|
||||
if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
|
||||
raise ValueError("Zipped batch items must have all have same length")
|
||||
return v
|
||||
|
||||
@validator("data")
|
||||
def validate_types(cls, v: list[list[BatchData]]):
|
||||
for batch_data in v:
|
||||
for datum in batch_data:
|
||||
for item in datum.items:
|
||||
if not all(isinstance(item, type(i)) for i in datum.items):
|
||||
raise TypeError("All items in a batch must have have same type")
|
||||
return v
|
||||
|
||||
@validator("data")
|
||||
def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
|
||||
paths: set[tuple[str, str]] = set()
|
||||
count: int = 0
|
||||
for batch_data in v:
|
||||
for datum in batch_data:
|
||||
paths.add((datum.node_id, datum.field_name))
|
||||
count += 1
|
||||
if len(paths) != count:
|
||||
raise ValueError("Each batch data must have unique node_id and field_name")
|
||||
return v
|
||||
|
||||
|
||||
class BatchSession(BaseModel):
|
||||
@ -46,10 +77,7 @@ def uuid_string():
|
||||
|
||||
class BatchProcess(BaseModel):
|
||||
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
|
||||
batches: List[Batch] = Field(
|
||||
description="List of batch configs to apply to this session",
|
||||
default_factory=list,
|
||||
)
|
||||
batch: Batch = Field(description="List of batch configs to apply to this session")
|
||||
canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False)
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
@ -310,13 +338,12 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
) -> BatchProcess:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
batches = [batch.json() for batch in batch_process.batches]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
|
||||
VALUES (?, ?, ?);
|
||||
""",
|
||||
(batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
|
||||
(batch_process.batch_id, batch_process.batch.json(), batch_process.graph.json()),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
@ -332,13 +359,11 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
batch_id = session_dict.get("batch_id", "unknown")
|
||||
batches_raw = session_dict.get("batches", "unknown")
|
||||
batch_raw = session_dict.get("batches", "unknown")
|
||||
graph_raw = session_dict.get("graph", "unknown")
|
||||
canceled = session_dict.get("canceled", 0)
|
||||
batches = json.loads(batches_raw)
|
||||
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
||||
return BatchProcess(
|
||||
batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
|
||||
batch_id=batch_id, batch=parse_raw_as(Batch, batch_raw), graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
|
||||
)
|
||||
|
||||
def get(
|
||||
|
@ -45,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
parsed = parse_raw_as(item_type, item)
|
||||
return parsed
|
||||
return parse_raw_as(item_type, item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user