feat(backend): update batch_manager_storage.py docstrings

This commit is contained in:
psychedelicious 2023-08-17 12:31:51 +10:00
parent df7370f9d9
commit 1f355d5810

View File

@ -1,35 +1,40 @@
from abc import ABC, abstractmethod
from typing import cast
import uuid
import sqlite3 import sqlite3
import threading import threading
from typing import ( import uuid
List, from abc import ABC, abstractmethod
Literal, from typing import List, Literal, Optional, Union, cast
Optional,
Union,
)
from invokeai.app.invocations.baseinvocation import ( from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
BaseInvocation,
) from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.graph import Graph
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.graph import Graph
from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat, validator
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField] BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
class BatchData(BaseModel): class BatchData(BaseModel):
node_id: str """
field_name: str A batch data collection.
items: list[BatchDataType] """
node_id: str = Field(description="The node into which this batch data collection will be substituted.")
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
items: list[BatchDataType] = Field(
default_factory=list, description="The list of items to substitute into the node/field."
)
class Batch(BaseModel): class Batch(BaseModel):
data: list[list[BatchData]] """
A batch, consisting of a list of a list of batch data collections.
First, each inner list[BatchData] is zipped into a single batch data collection.
Then, the final batch collection is created by taking the Cartesian product of all batch data collections.
"""
data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
@validator("data") @validator("data")
def validate_len(cls, v: list[list[BatchData]]): def validate_len(cls, v: list[list[BatchData]]):
@ -61,11 +66,9 @@ class Batch(BaseModel):
class BatchSession(BaseModel): class BatchSession(BaseModel):
batch_id: str = Field(description="Identifier for which batch this Index belongs to") batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
session_id: str = Field(description="Session ID Created for this Batch Index") session_id: str = Field(description="The Session to which this BatchSession is attached.")
state: Literal["created", "completed", "inprogress", "error"] = Field( state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
description="Is this session created, completed, in progress, or errored?"
)
def uuid_string(): def uuid_string():
@ -74,16 +77,14 @@ def uuid_string():
class BatchProcess(BaseModel): class BatchProcess(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch") batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
batch: Batch = Field(description="List of batch configs to apply to this session") batch: Batch = Field(description="The Batch to apply to this session.")
canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False) canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
graph: Graph = Field(description="The graph being executed") graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")
class BatchSessionChanges(BaseModel, extra=Extra.forbid): class BatchSessionChanges(BaseModel, extra=Extra.forbid):
state: Literal["created", "completed", "inprogress", "error"] = Field( state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
description="Is this session created, completed, in progress, or errored?"
)
class BatchProcessNotFoundException(Exception): class BatchProcessNotFoundException(Exception):
@ -133,7 +134,7 @@ class BatchProcessStorageBase(ABC):
@abstractmethod @abstractmethod
def delete(self, batch_id: str) -> None: def delete(self, batch_id: str) -> None:
"""Deletes a Batch Process record.""" """Deletes a BatchProcess record."""
pass pass
@abstractmethod @abstractmethod
@ -141,7 +142,7 @@ class BatchProcessStorageBase(ABC):
self, self,
batch_process: BatchProcess, batch_process: BatchProcess,
) -> BatchProcess: ) -> BatchProcess:
"""Saves a Batch Process record.""" """Saves a BatchProcess record."""
pass pass
@abstractmethod @abstractmethod
@ -149,7 +150,7 @@ class BatchProcessStorageBase(ABC):
self, self,
batch_id: str, batch_id: str,
) -> BatchProcess: ) -> BatchProcess:
"""Gets a Batch Process record.""" """Gets a BatchProcess record."""
pass pass
@abstractmethod @abstractmethod
@ -157,7 +158,7 @@ class BatchProcessStorageBase(ABC):
self, self,
batch_id: str, batch_id: str,
): ):
"""Start Batch Process record.""" """Starts a BatchProcess record by marking its `canceled` attribute to False."""
pass pass
@abstractmethod @abstractmethod
@ -165,7 +166,7 @@ class BatchProcessStorageBase(ABC):
self, self,
batch_id: str, batch_id: str,
): ):
"""Cancel Batch Process record.""" """Cancel BatchProcess record by setting its `canceled` attribute to True."""
pass pass
@abstractmethod @abstractmethod
@ -173,22 +174,22 @@ class BatchProcessStorageBase(ABC):
self, self,
session: BatchSession, session: BatchSession,
) -> BatchSession: ) -> BatchSession:
"""Creates a Batch Session attached to a Batch Process.""" """Creates a BatchSession attached to a BatchProcess."""
pass pass
@abstractmethod @abstractmethod
def get_session(self, session_id: str) -> BatchSession: def get_session(self, session_id: str) -> BatchSession:
"""Gets session by session_id""" """Gets a BatchSession by session_id"""
pass pass
@abstractmethod @abstractmethod
def get_created_session(self, batch_id: str) -> BatchSession: def get_created_session(self, batch_id: str) -> BatchSession:
"""Gets all created Batch Sessions for a given Batch Process id.""" """Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
pass pass
@abstractmethod @abstractmethod
def get_created_sessions(self, batch_id: str) -> List[BatchSession]: def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
"""Gets all created Batch Sessions for a given Batch Process id.""" """Gets all BatchSession's with state `created`, for a given BatchProcess id."""
pass pass
@abstractmethod @abstractmethod
@ -198,7 +199,7 @@ class BatchProcessStorageBase(ABC):
session_id: str, session_id: str,
changes: BatchSessionChanges, changes: BatchSessionChanges,
) -> BatchSession: ) -> BatchSession:
"""Updates the state of a Batch Session record.""" """Updates the state of a BatchSession record."""
pass pass