mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(backend): update batch_manager_storage.py
docstrings
This commit is contained in:
parent
df7370f9d9
commit
1f355d5810
@ -1,35 +1,40 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import cast
|
|
||||||
import uuid
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from typing import (
|
import uuid
|
||||||
List,
|
from abc import ABC, abstractmethod
|
||||||
Literal,
|
from typing import List, Literal, Optional, Union, cast
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
|
||||||
BaseInvocation,
|
|
||||||
)
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
from invokeai.app.services.graph import Graph
|
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
|
from invokeai.app.services.graph import Graph
|
||||||
from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat, validator
|
|
||||||
|
|
||||||
|
|
||||||
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
|
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
|
||||||
|
|
||||||
|
|
||||||
class BatchData(BaseModel):
|
class BatchData(BaseModel):
|
||||||
node_id: str
|
"""
|
||||||
field_name: str
|
A batch data collection.
|
||||||
items: list[BatchDataType]
|
"""
|
||||||
|
|
||||||
|
node_id: str = Field(description="The node into which this batch data collection will be substituted.")
|
||||||
|
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
|
||||||
|
items: list[BatchDataType] = Field(
|
||||||
|
default_factory=list, description="The list of items to substitute into the node/field."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Batch(BaseModel):
|
class Batch(BaseModel):
|
||||||
data: list[list[BatchData]]
|
"""
|
||||||
|
A batch, consisting of a list of a list of batch data collections.
|
||||||
|
|
||||||
|
First, each inner list[BatchData] is zipped into a single batch data collection.
|
||||||
|
|
||||||
|
Then, the final batch collection is created by taking the Cartesian product of all batch data collections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
|
||||||
|
|
||||||
@validator("data")
|
@validator("data")
|
||||||
def validate_len(cls, v: list[list[BatchData]]):
|
def validate_len(cls, v: list[list[BatchData]]):
|
||||||
@ -61,11 +66,9 @@ class Batch(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class BatchSession(BaseModel):
|
class BatchSession(BaseModel):
|
||||||
batch_id: str = Field(description="Identifier for which batch this Index belongs to")
|
batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
|
||||||
session_id: str = Field(description="Session ID Created for this Batch Index")
|
session_id: str = Field(description="The Session to which this BatchSession is attached.")
|
||||||
state: Literal["created", "completed", "inprogress", "error"] = Field(
|
state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
|
||||||
description="Is this session created, completed, in progress, or errored?"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def uuid_string():
|
def uuid_string():
|
||||||
@ -74,16 +77,14 @@ def uuid_string():
|
|||||||
|
|
||||||
|
|
||||||
class BatchProcess(BaseModel):
|
class BatchProcess(BaseModel):
|
||||||
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch")
|
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
|
||||||
batch: Batch = Field(description="List of batch configs to apply to this session")
|
batch: Batch = Field(description="The Batch to apply to this session.")
|
||||||
canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False)
|
canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
|
||||||
graph: Graph = Field(description="The graph being executed")
|
graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")
|
||||||
|
|
||||||
|
|
||||||
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
||||||
state: Literal["created", "completed", "inprogress", "error"] = Field(
|
state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
|
||||||
description="Is this session created, completed, in progress, or errored?"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchProcessNotFoundException(Exception):
|
class BatchProcessNotFoundException(Exception):
|
||||||
@ -157,7 +158,7 @@ class BatchProcessStorageBase(ABC):
|
|||||||
self,
|
self,
|
||||||
batch_id: str,
|
batch_id: str,
|
||||||
):
|
):
|
||||||
"""Start Batch Process record."""
|
"""Starts a BatchProcess record by marking its `canceled` attribute to False."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -165,7 +166,7 @@ class BatchProcessStorageBase(ABC):
|
|||||||
self,
|
self,
|
||||||
batch_id: str,
|
batch_id: str,
|
||||||
):
|
):
|
||||||
"""Cancel Batch Process record."""
|
"""Cancel BatchProcess record by setting its `canceled` attribute to True."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -178,17 +179,17 @@ class BatchProcessStorageBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_session(self, session_id: str) -> BatchSession:
|
def get_session(self, session_id: str) -> BatchSession:
|
||||||
"""Gets session by session_id"""
|
"""Gets a BatchSession by session_id"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
"""Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
"""Gets all BatchSession's with state `created`, for a given BatchProcess id."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
Loading…
x
Reference in New Issue
Block a user