mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Testing out Spencer's batch data structure
@@ -4,13 +4,11 @@ import uuid
import sqlite3
import threading
from typing import (
    Any,
    List,
    Literal,
    Optional,
    Union,
)
import json

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
@@ -18,7 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.services.graph import Graph
from invokeai.app.invocations.primitives import ImageField

from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat
from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat, validator

invocations = BaseInvocation.get_invocations()
InvocationsUnion = Union[invocations] # type: ignore
@@ -26,9 +24,42 @@ InvocationsUnion = Union[invocations] # type: ignore
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]


class BatchData(BaseModel):
    node_id: str
    field_name: str
    items: list[BatchDataType]


class Batch(BaseModel):
    data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
    node_id: str = Field(description="ID of the node to batch")
    data: list[list[BatchData]]

    @validator("data")
    def validate_len(cls, v: list[list[BatchData]]):
        for batch_data in v:
            if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
                raise ValueError("Zipped batch items must all have the same length")
        return v

    @validator("data")
    def validate_types(cls, v: list[list[BatchData]]):
        for batch_data in v:
            for datum in batch_data:
                for item in datum.items:
                    if not all(isinstance(item, type(i)) for i in datum.items):
                        raise TypeError("All items in a batch must have the same type")
        return v

    @validator("data")
    def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
        paths: set[tuple[str, str]] = set()
        count: int = 0
        for batch_data in v:
            for datum in batch_data:
                paths.add((datum.node_id, datum.field_name))
                count += 1
        if len(paths) != count:
            raise ValueError("Each batch data must have unique node_id and field_name")
        return v


class BatchSession(BaseModel):
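For orientation, a minimal sketch of how the new structure above might be used, assuming BatchData and Batch are importable from this module; the node and field names are invented for illustration. Each inner list of data zips several BatchData entries together, and the three validators reject zipped groups whose item lists differ in length, items of mixed type within one BatchData, and duplicate (node_id, field_name) pairs.

# Hypothetical usage of the classes above (pydantic v1); node and field names are made up.
zipped = [
    BatchData(node_id="noise_node", field_name="seed", items=[1, 2, 3]),
    BatchData(node_id="noise_node", field_name="cfg_scale", items=[7.5, 7.0, 6.5]),
]
batch = Batch(data=[zipped])  # passes: equal lengths, uniform item types, unique (node_id, field_name)

# Mismatched item lengths within one zipped group would trip validate_len,
# surfacing as a pydantic ValidationError:
# Batch(data=[[BatchData(node_id="n", field_name="a", items=[1, 2]),
#              BatchData(node_id="n", field_name="b", items=[1])]])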
@@ -46,10 +77,7 @@ def uuid_string():

class BatchProcess(BaseModel):
    batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
    batches: List[Batch] = Field(
        description="List of batch configs to apply to this session",
        default_factory=list,
    )
    batch: Batch = Field(description="List of batch configs to apply to this session")
    canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False)
    graph: Graph = Field(description="The graph being executed")

@@ -310,13 +338,12 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
    ) -> BatchProcess:
        try:
            self._lock.acquire()
            batches = [batch.json() for batch in batch_process.batches]
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
                VALUES (?, ?, ?);
                """,
                (batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
                (batch_process.batch_id, batch_process.batch.json(), batch_process.graph.json()),
            )
            self._conn.commit()
        except sqlite3.Error as e:
@@ -332,13 +359,11 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
        # Retrieve all the values, setting "reasonable" defaults if they are not present.

        batch_id = session_dict.get("batch_id", "unknown")
        batches_raw = session_dict.get("batches", "unknown")
        batch_raw = session_dict.get("batches", "unknown")
        graph_raw = session_dict.get("graph", "unknown")
        canceled = session_dict.get("canceled", 0)
        batches = json.loads(batches_raw)
        batches = [parse_raw_as(Batch, batch) for batch in batches]
        return BatchProcess(
            batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
            batch_id=batch_id, batch=parse_raw_as(Batch, batch_raw), graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
        )

    def get(
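Taken together, the two storage hunks above swap a json.dumps/json.loads round trip over a list of batches for a single pydantic round trip on one Batch; the batches column name itself is unchanged in this commit. A hedged sketch of that round trip, assuming the Batch and BatchData classes defined earlier (values are made up):

from pydantic import parse_raw_as

batch = Batch(data=[[BatchData(node_id="n1", field_name="seed", items=[1, 2, 3])]])
raw = batch.json()                   # the string now written into the "batches" column
restored = parse_raw_as(Batch, raw)  # deserialization now parses one Batch instead of a list
assert restored == batch             # pydantic models compare field-by-field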