From 1f355d581015e3732ad670850299250feb9fdfef Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 17 Aug 2023 12:31:51 +1000 Subject: [PATCH] feat(backend): update `batch_manager_storage.py` docstrings --- .../app/services/batch_manager_storage.py | 85 ++++++++++--------- 1 file changed, 43 insertions(+), 42 deletions(-) diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py index 490ac0a1bc..76c1664536 100644 --- a/invokeai/app/services/batch_manager_storage.py +++ b/invokeai/app/services/batch_manager_storage.py @@ -1,35 +1,40 @@ -from abc import ABC, abstractmethod -from typing import cast -import uuid import sqlite3 import threading -from typing import ( - List, - Literal, - Optional, - Union, -) +import uuid +from abc import ABC, abstractmethod +from typing import List, Literal, Optional, Union, cast -from invokeai.app.invocations.baseinvocation import ( - BaseInvocation, -) -from invokeai.app.services.graph import Graph +from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator + +from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.primitives import ImageField - -from pydantic import BaseModel, Field, Extra, parse_raw_as, StrictStr, StrictInt, StrictFloat, validator - +from invokeai.app.services.graph import Graph BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField] class BatchData(BaseModel): - node_id: str - field_name: str - items: list[BatchDataType] + """ + A batch data collection. + """ + + node_id: str = Field(description="The node into which this batch data collection will be substituted.") + field_name: str = Field(description="The field into which this batch data collection will be substituted.") + items: list[BatchDataType] = Field( + default_factory=list, description="The list of items to substitute into the node/field." + ) class Batch(BaseModel): - data: list[list[BatchData]] + """ + A batch, consisting of a list of a list of batch data collections. + + First, each inner list[BatchData] is zipped into a single batch data collection. + + Then, the final batch collection is created by taking the Cartesian product of all batch data collections. + """ + + data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.") @validator("data") def validate_len(cls, v: list[list[BatchData]]): @@ -61,11 +66,9 @@ class Batch(BaseModel): class BatchSession(BaseModel): - batch_id: str = Field(description="Identifier for which batch this Index belongs to") - session_id: str = Field(description="Session ID Created for this Batch Index") - state: Literal["created", "completed", "inprogress", "error"] = Field( - description="Is this session created, completed, in progress, or errored?" - ) + batch_id: str = Field(description="The Batch to which this BatchSession is attached.") + session_id: str = Field(description="The Session to which this BatchSession is attached.") + state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession") def uuid_string(): @@ -74,16 +77,14 @@ def uuid_string(): class BatchProcess(BaseModel): - batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch") - batch: Batch = Field(description="List of batch configs to apply to this session") - canceled: bool = Field(description="Flag for saying whether or not to run sessions from this batch", default=False) - graph: Graph = Field(description="The graph being executed") + batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.") + batch: Batch = Field(description="The Batch to apply to this session.") + canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False) + graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.") class BatchSessionChanges(BaseModel, extra=Extra.forbid): - state: Literal["created", "completed", "inprogress", "error"] = Field( - description="Is this session created, completed, in progress, or errored?" - ) + state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession") class BatchProcessNotFoundException(Exception): @@ -133,7 +134,7 @@ class BatchProcessStorageBase(ABC): @abstractmethod def delete(self, batch_id: str) -> None: - """Deletes a Batch Process record.""" + """Deletes a BatchProcess record.""" pass @abstractmethod @@ -141,7 +142,7 @@ class BatchProcessStorageBase(ABC): self, batch_process: BatchProcess, ) -> BatchProcess: - """Saves a Batch Process record.""" + """Saves a BatchProcess record.""" pass @abstractmethod @@ -149,7 +150,7 @@ class BatchProcessStorageBase(ABC): self, batch_id: str, ) -> BatchProcess: - """Gets a Batch Process record.""" + """Gets a BatchProcess record.""" pass @abstractmethod @@ -157,7 +158,7 @@ class BatchProcessStorageBase(ABC): self, batch_id: str, ): - """Start Batch Process record.""" + """Starts a BatchProcess record by marking its `canceled` attribute to False.""" pass @abstractmethod @@ -165,7 +166,7 @@ class BatchProcessStorageBase(ABC): self, batch_id: str, ): - """Cancel Batch Process record.""" + """Cancel BatchProcess record by setting its `canceled` attribute to True.""" pass @abstractmethod @@ -173,22 +174,22 @@ class BatchProcessStorageBase(ABC): self, session: BatchSession, ) -> BatchSession: - """Creates a Batch Session attached to a Batch Process.""" + """Creates a BatchSession attached to a BatchProcess.""" pass @abstractmethod def get_session(self, session_id: str) -> BatchSession: - """Gets session by session_id""" + """Gets a BatchSession by session_id""" pass @abstractmethod def get_created_session(self, batch_id: str) -> BatchSession: - """Gets all created Batch Sessions for a given Batch Process id.""" + """Gets the latest BatchSession with state `created`, for a given BatchProcess id.""" pass @abstractmethod def get_created_sessions(self, batch_id: str) -> List[BatchSession]: - """Gets all created Batch Sessions for a given Batch Process id.""" + """Gets all BatchSession's with state `created`, for a given BatchProcess id.""" pass @abstractmethod @@ -198,7 +199,7 @@ class BatchProcessStorageBase(ABC): session_id: str, changes: BatchSessionChanges, ) -> BatchSession: - """Updates the state of a Batch Session record.""" + """Updates the state of a BatchSession record.""" pass