mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Testing out Spencer's batch data structure
This commit is contained in:
parent
d6a5c2dbe3
commit
796ff34c8a
@ -48,10 +48,10 @@ async def create_session(
|
|||||||
)
|
)
|
||||||
async def create_batch(
|
async def create_batch(
|
||||||
graph: Optional[Graph] = Body(description="The graph to initialize the session with"),
|
graph: Optional[Graph] = Body(description="The graph to initialize the session with"),
|
||||||
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
batch: Batch = Body(description="Batch config to apply to the given graph"),
|
||||||
) -> BatchProcessResponse:
|
) -> BatchProcessResponse:
|
||||||
"""Creates and starts a new new batch process"""
|
"""Creates and starts a new new batch process"""
|
||||||
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
|
||||||
return batch_process_res
|
return batch_process_res
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ class BatchManagerBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
|
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -85,17 +85,20 @@ class BatchManager(BatchManagerBase):
|
|||||||
|
|
||||||
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
||||||
graph = batch_process.graph.copy(deep=True)
|
graph = batch_process.graph.copy(deep=True)
|
||||||
batches = batch_process.batches
|
batch = batch_process.batch
|
||||||
g = graph.nx_graph_flat()
|
g = graph.nx_graph_flat()
|
||||||
sorted_nodes = nx.topological_sort(g)
|
sorted_nodes = nx.topological_sort(g)
|
||||||
for npath in sorted_nodes:
|
for npath in sorted_nodes:
|
||||||
node = graph.get_node(npath)
|
node = graph.get_node(npath)
|
||||||
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
|
for index, bdl in enumerate(batch.data):
|
||||||
if batch:
|
relavent_bd = [bd for bd in bdl if bd.node_id in node.id]
|
||||||
batch_index = batch_indices[index]
|
if not relavent_bd:
|
||||||
datum = batch.data[batch_index]
|
continue
|
||||||
for key in datum:
|
for bd in relavent_bd:
|
||||||
node.__dict__[key] = datum[key]
|
batch_index = batch_indices[index]
|
||||||
|
datum = bd.items[batch_index]
|
||||||
|
key = bd.field_name
|
||||||
|
node.__dict__[key] = datum
|
||||||
graph.update_node(npath, node)
|
graph.update_node(npath, node)
|
||||||
|
|
||||||
return GraphExecutionState(graph=graph)
|
return GraphExecutionState(graph=graph)
|
||||||
@ -110,13 +113,13 @@ class BatchManager(BatchManagerBase):
|
|||||||
self.__invoker.invoke(ges, invoke_all=True)
|
self.__invoker.invoke(ges, invoke_all=True)
|
||||||
|
|
||||||
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
||||||
# TODO: Check that the node_ids in the batches are unique
|
# TODO: Check that the node_ids in the batch are unique
|
||||||
# TODO: Validate data types are correct for each batch data
|
# TODO: Validate data types are correct for each batch data
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
|
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||||
batch_process = BatchProcess(
|
batch_process = BatchProcess(
|
||||||
batches=batches,
|
batch=batch,
|
||||||
graph=graph,
|
graph=graph,
|
||||||
)
|
)
|
||||||
if not self._valid_batch_config(batch_process):
|
if not self._valid_batch_config(batch_process):
|
||||||
@ -131,8 +134,8 @@ class BatchManager(BatchManagerBase):
|
|||||||
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
||||||
batch_indices = list()
|
batch_indices = list()
|
||||||
sessions = list()
|
sessions = list()
|
||||||
for batch in batch_process.batches:
|
for batchdata in batch_process.batch.data:
|
||||||
batch_indices.append(list(range(len(batch.data))))
|
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||||
all_batch_indices = product(*batch_indices)
|
all_batch_indices = product(*batch_indices)
|
||||||
for bi in all_batch_indices:
|
for bi in all_batch_indices:
|
||||||
ges = self._create_batch_session(batch_process, bi)
|
ges = self._create_batch_session(batch_process, bi)
|
||||||
|
@ -4,13 +4,11 @@ import uuid
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
|
||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
import json
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -18,7 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
from invokeai.app.services.graph import Graph
|
from invokeai.app.services.graph import Graph
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat
|
from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat, validator
|
||||||
|
|
||||||
invocations = BaseInvocation.get_invocations()
|
invocations = BaseInvocation.get_invocations()
|
||||||
InvocationsUnion = Union[invocations] # type: ignore
|
InvocationsUnion = Union[invocations] # type: ignore
|
||||||
@ -26,9 +24,42 @@ InvocationsUnion = Union[invocations] # type: ignore
|
|||||||
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
|
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
|
||||||
|
|
||||||
|
|
||||||
|
class BatchData(BaseModel):
|
||||||
|
node_id: str
|
||||||
|
field_name: str
|
||||||
|
items: list[BatchDataType]
|
||||||
|
|
||||||
|
|
||||||
class Batch(BaseModel):
|
class Batch(BaseModel):
|
||||||
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
|
data: list[list[BatchData]]
|
||||||
node_id: str = Field(description="ID of the node to batch")
|
|
||||||
|
@validator("data")
|
||||||
|
def validate_len(cls, v: list[list[BatchData]]):
|
||||||
|
for batch_data in v:
|
||||||
|
if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
|
||||||
|
raise ValueError("Zipped batch items must have all have same length")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("data")
|
||||||
|
def validate_types(cls, v: list[list[BatchData]]):
|
||||||
|
for batch_data in v:
|
||||||
|
for datum in batch_data:
|
||||||
|
for item in datum.items:
|
||||||
|
if not all(isinstance(item, type(i)) for i in datum.items):
|
||||||
|
raise TypeError("All items in a batch must have have same type")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("data")
|
||||||
|
def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
|
||||||
|
paths: set[tuple[str, str]] = set()
|
||||||
|
count: int = 0
|
||||||
|
for batch_data in v:
|
||||||
|
for datum in batch_data:
|
||||||
|
paths.add((datum.node_id, datum.field_name))
|
||||||
|
count += 1
|
||||||
|
if len(paths) != count:
|
||||||
|
raise ValueError("Each batch data must have unique node_id and field_name")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class BatchSession(BaseModel):
|
class BatchSession(BaseModel):
|
||||||
@ -46,10 +77,7 @@ def uuid_string():
|
|||||||
|
|
||||||
class BatchProcess(BaseModel):
|
class BatchProcess(BaseModel):
|
||||||
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
|
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
|
||||||
batches: List[Batch] = Field(
|
batch: Batch = Field(description="List of batch configs to apply to this session")
|
||||||
description="List of batch configs to apply to this session",
|
|
||||||
default_factory=list,
|
|
||||||
)
|
|
||||||
canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False)
|
canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False)
|
||||||
graph: Graph = Field(description="The graph being executed")
|
graph: Graph = Field(description="The graph being executed")
|
||||||
|
|
||||||
@ -310,13 +338,12 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
) -> BatchProcess:
|
) -> BatchProcess:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
batches = [batch.json() for batch in batch_process.batches]
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
|
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
|
||||||
VALUES (?, ?, ?);
|
VALUES (?, ?, ?);
|
||||||
""",
|
""",
|
||||||
(batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
|
(batch_process.batch_id, batch_process.batch.json(), batch_process.graph.json()),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
@ -332,13 +359,11 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
batch_id = session_dict.get("batch_id", "unknown")
|
batch_id = session_dict.get("batch_id", "unknown")
|
||||||
batches_raw = session_dict.get("batches", "unknown")
|
batch_raw = session_dict.get("batches", "unknown")
|
||||||
graph_raw = session_dict.get("graph", "unknown")
|
graph_raw = session_dict.get("graph", "unknown")
|
||||||
canceled = session_dict.get("canceled", 0)
|
canceled = session_dict.get("canceled", 0)
|
||||||
batches = json.loads(batches_raw)
|
|
||||||
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
|
||||||
return BatchProcess(
|
return BatchProcess(
|
||||||
batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
|
batch_id=batch_id, batch=parse_raw_as(Batch, batch_raw), graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
|
@ -45,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
item_type = get_args(self.__orig_class__)[0]
|
item_type = get_args(self.__orig_class__)[0]
|
||||||
parsed = parse_raw_as(item_type, item)
|
return parse_raw_as(item_type, item)
|
||||||
return parsed
|
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
try:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user