Testing out Spencer's batch data structure

Brandon Rising
2023-08-16 15:21:11 -04:00
parent d6a5c2dbe3
commit 796ff34c8a
4 changed files with 59 additions and 32 deletions


@@ -4,13 +4,11 @@ import uuid
import sqlite3
import threading
from typing import (
    Any,
    List,
    Literal,
    Optional,
    Union,
)
import json
from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
@@ -18,7 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.services.graph import Graph
from invokeai.app.invocations.primitives import ImageField
-from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat
+from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat, validator

invocations = BaseInvocation.get_invocations()
InvocationsUnion = Union[invocations]  # type: ignore
@@ -26,9 +24,42 @@ InvocationsUnion = Union[invocations]  # type: ignore

BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]

+class BatchData(BaseModel):
+    node_id: str
+    field_name: str
+    items: list[BatchDataType]
+
class Batch(BaseModel):
-    data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
-    node_id: str = Field(description="ID of the node to batch")
+    data: list[list[BatchData]]
+
+    @validator("data")
+    def validate_len(cls, v: list[list[BatchData]]):
+        for batch_data in v:
+            if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
+                raise ValueError("Zipped batch items must all have the same length")
+        return v
+
+    @validator("data")
+    def validate_types(cls, v: list[list[BatchData]]):
+        for batch_data in v:
+            for datum in batch_data:
+                for item in datum.items:
+                    if not all(isinstance(item, type(i)) for i in datum.items):
+                        raise TypeError("All items in a batch must have the same type")
+        return v
+
+    @validator("data")
+    def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
+        paths: set[tuple[str, str]] = set()
+        count: int = 0
+        for batch_data in v:
+            for datum in batch_data:
+                paths.add((datum.node_id, datum.field_name))
+                count += 1
+        if len(paths) != count:
+            raise ValueError("Each BatchData must have a unique (node_id, field_name) pair")
+        return v


class BatchSession(BaseModel):
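
The three validators above encode the invariants of the new structure: BatchData entries zipped together in one inner list must supply the same number of items, items within one BatchData must all share a type, and no (node_id, field_name) pair may appear twice. Below is a minimal standalone sketch of the first invariant under pydantic v1 (the version in use here); ImageField is dropped and only the length validator is reproduced so it runs outside InvokeAI, and the noise/seed and compel/prompt names are made up for illustration.

# Standalone sketch (pydantic v1); mirrors the Batch/BatchData shapes in the hunk above.
from typing import Union
from pydantic import BaseModel, StrictFloat, StrictInt, StrictStr, ValidationError, validator

BatchDataType = Union[StrictStr, StrictInt, StrictFloat]  # ImageField omitted

class BatchData(BaseModel):
    node_id: str
    field_name: str
    items: list[BatchDataType]

class Batch(BaseModel):
    data: list[list[BatchData]]

    @validator("data")
    def validate_len(cls, v):
        # Fields zipped together in one inner list must line up one-to-one.
        for batch_data in v:
            if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
                raise ValueError("Zipped batch items must all have the same length")
        return v

# Two fields zipped with equal item counts validate cleanly:
Batch(data=[[
    BatchData(node_id="noise", field_name="seed", items=[1, 2, 3]),
    BatchData(node_id="compel", field_name="prompt", items=["a", "b", "c"]),
]])

# A length mismatch inside one zipped group raises a ValidationError:
try:
    Batch(data=[[
        BatchData(node_id="noise", field_name="seed", items=[1, 2]),
        BatchData(node_id="compel", field_name="prompt", items=["a", "b", "c"]),
    ]])
except ValidationError as err:
    print(err)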
@@ -46,10 +77,7 @@ def uuid_string():

class BatchProcess(BaseModel):
    batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
-    batches: List[Batch] = Field(
-        description="List of batch configs to apply to this session",
-        default_factory=list,
-    )
+    batch: Batch = Field(description="Batch config to apply to this session")
    canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False)
    graph: Graph = Field(description="The graph being executed")
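
This hunk is the pivot of the commit: a BatchProcess previously carried a list of Batch configs and now carries exactly one. A trimmed, hypothetical version of the model (graph and canceled omitted, reusing the Batch/BatchData sketch above; uuid_string is assumed to wrap uuid.uuid4, matching the default_factory named in the hunk header):

# Trimmed BatchProcess mirroring the hunk above; graph/canceled omitted.
import uuid
from typing import Optional
from pydantic import BaseModel, Field

def uuid_string() -> str:
    # Assumed implementation of the helper named in the hunk header.
    return str(uuid.uuid4())

class BatchProcess(BaseModel):
    batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
    batch: Batch = Field(description="Batch config to apply to this session")

process = BatchProcess(batch=Batch(data=[[BatchData(node_id="noise", field_name="seed", items=[1, 2])]]))
print(process.batch_id)  # auto-generated UUID string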
@@ -310,13 +338,12 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
    ) -> BatchProcess:
        try:
            self._lock.acquire()
-            batches = [batch.json() for batch in batch_process.batches]
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
                VALUES (?, ?, ?);
                """,
-                (batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
+                (batch_process.batch_id, batch_process.batch.json(), batch_process.graph.json()),
            )
            self._conn.commit()
        except sqlite3.Error as e:
@@ -332,13 +359,11 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
        # Retrieve all the values, setting "reasonable" defaults if they are not present.
        batch_id = session_dict.get("batch_id", "unknown")
-        batches_raw = session_dict.get("batches", "unknown")
+        batch_raw = session_dict.get("batches", "unknown")
        graph_raw = session_dict.get("graph", "unknown")
        canceled = session_dict.get("canceled", 0)
-        batches = json.loads(batches_raw)
-        batches = [parse_raw_as(Batch, batch) for batch in batches]
        return BatchProcess(
-            batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
+            batch_id=batch_id, batch=parse_raw_as(Batch, batch_raw), graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
        )

    def get(
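
The last two hunks bring storage in line: the batches column now holds a single Batch serialized as JSON, and _deserialize_batch_process reads it back with one parse_raw_as call instead of json.loads plus a per-element parse. Note that the dict key stays "batches" even though the model field is now batch. A sketch of the round trip, reusing the models from the sketches above:

# Round-trip sketch (pydantic v1), continuing from the sketches above.
from pydantic import parse_raw_as

raw = process.batch.json()             # value bound into the INSERT
restored = parse_raw_as(Batch, raw)    # read path in _deserialize_batch_process
assert restored == process.batch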