From f8d8b16267eb42f359f7c816fe7b89590a5f2f7d Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Mon, 14 Aug 2023 11:01:31 -0400 Subject: [PATCH] Run python black --- invokeai/app/services/batch_manager.py | 19 +++---- .../app/services/batch_manager_storage.py | 57 ++++++------------- invokeai/app/services/sqlite.py | 2 +- tests/nodes/test_invoker.py | 11 ++-- 4 files changed, 30 insertions(+), 59 deletions(-) diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index ce00fca708..1a67916cf7 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -19,6 +19,7 @@ from invokeai.app.services.batch_manager_storage import ( BatchSessionChanges, ) + class BatchProcessResponse(BaseModel): batch_id: str = Field(description="ID for the batch") session_ids: list[str] = Field(description="List of session IDs created for this batch") @@ -75,9 +76,7 @@ class BatchManager(BatchManagerBase): batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"]) if not batch_session: return - updateSession = BatchSessionChanges( - state='error' if err else 'completed' - ) + updateSession = BatchSessionChanges(state="error" if err else "completed") batch_session = self.__batch_process_storage.update_session_state( batch_session.batch_id, batch_session.session_id, @@ -105,18 +104,18 @@ class BatchManager(BatchManagerBase): return GraphExecutionState(graph=graph) def run_batch_process(self, batch_id: str): - try: + try: created_session = self.__batch_process_storage.get_created_session(batch_id) except BatchSessionNotFoundException: return ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id) self.__invoker.invoke(ges, invoke_all=True) - + def _valid_batch_config(self, batch_process: BatchProcess) -> bool: # TODO: Check that the node_ids in the batches are unique # TODO: Validate data types are correct for each batch data return True - + def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse: batch_process = BatchProcess( batches=batches, @@ -130,7 +129,7 @@ class BatchManager(BatchManagerBase): batch_id=batch_process.batch_id, session_ids=[session.session_id for session in sessions], ) - + def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]: batch_indices = list() sessions = list() @@ -140,11 +139,7 @@ class BatchManager(BatchManagerBase): for bi in all_batch_indices: ges = self._create_batch_session(batch_process, bi) self.__invoker.services.graph_execution_manager.set(ges) - batch_session = BatchSession( - batch_id=batch_process.batch_id, - session_id=ges.id, - state="created" - ) + batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created") sessions.append(self.__batch_process_storage.create_session(batch_session)) return sessions diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py index 025ee4a338..799be67e66 100644 --- a/invokeai/app/services/batch_manager_storage.py +++ b/invokeai/app/services/batch_manager_storage.py @@ -25,6 +25,7 @@ InvocationsUnion = Union[invocations] # type: ignore BatchDataType = Union[str, int, float, ImageField] + class Batch(BaseModel): data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value") node_id: str = Field(description="ID of the node to batch") @@ -42,6 +43,7 @@ def uuid_string(): res = uuid.uuid4() return str(res) + class BatchProcess(BaseModel): batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch") batches: List[Batch] = Field( @@ -134,33 +136,24 @@ class BatchProcessStorageBase(ABC): @abstractmethod def create_session( - self, - session: BatchSession, + self, + session: BatchSession, ) -> BatchSession: """Creates a Batch Session attached to a Batch Process.""" pass - + @abstractmethod - def get_session( - self, - session_id: str - ) -> BatchSession: + def get_session(self, session_id: str) -> BatchSession: """Gets session by session_id""" pass @abstractmethod - def get_created_session( - self, - batch_id: str - ) -> BatchSession: + def get_created_session(self, batch_id: str) -> BatchSession: """Gets all created Batch Sessions for a given Batch Process id.""" pass @abstractmethod - def get_created_sessions( - self, - batch_id: str - ) -> List[BatchSession]: + def get_created_sessions(self, batch_id: str) -> List[BatchSession]: """Gets all created Batch Sessions for a given Batch Process id.""" pass @@ -339,10 +332,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): batches = json.loads(batches_raw) batches = [parse_raw_as(Batch, batch) for batch in batches] return BatchProcess( - batch_id=batch_id, - batches=batches, - graph=parse_raw_as(Graph, graph_raw), - canceled = canceled == 1 + batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1 ) def get( @@ -357,7 +347,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): FROM batch_process WHERE batch_id = ?; """, - (batch_id,) + (batch_id,), ) result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) @@ -370,7 +360,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): raise BatchProcessNotFoundException return self._deserialize_batch_process(dict(result)) - def cancel( self, batch_id: str, @@ -393,8 +382,8 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): self._lock.release() def create_session( - self, - session: BatchSession, + self, + session: BatchSession, ) -> BatchSession: try: self._lock.acquire() @@ -413,11 +402,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): self._lock.release() return self.get_session(session.session_id) - - def get_session( - self, - session_id: str - ) -> BatchSession: + def get_session(self, session_id: str) -> BatchSession: try: self._lock.acquire() self._cursor.execute( @@ -454,11 +439,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): state=state, ) - - def get_created_session( - self, - batch_id: str - ) -> BatchSession: + def get_created_session(self, batch_id: str) -> BatchSession: try: self._lock.acquire() self._cursor.execute( @@ -481,11 +462,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): session = self._deserialize_batch_session(dict(result)) return session - - def get_created_sessions( - self, - batch_id: str - ) -> List[BatchSession]: + def get_created_sessions(self, batch_id: str) -> List[BatchSession]: try: self._lock.acquire() self._cursor.execute( @@ -497,7 +474,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): (batch_id,), ) - result = cast(list[sqlite3.Row], self._cursor.fetchall()) except sqlite3.Error as e: self._conn.rollback() @@ -508,7 +484,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): raise BatchSessionNotFoundException sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result)) return sessions - def update_session_state( self, @@ -536,4 +511,4 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): raise BatchSessionSaveException from e finally: self._lock.release() - return self.get_session(session_id) \ No newline at end of file + return self.get_session(session_id) diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index 1b70a98929..f016995e7b 100644 --- a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -29,7 +29,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._conn = sqlite3.connect( self._filename, check_same_thread=False ) # TODO: figure out a better threading solution - self._conn.execute('pragma journal_mode=wal') + self._conn.execute("pragma journal_mode=wal") self._cursor = self._conn.cursor() self._create_table() diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 9051fcb403..7e7a226023 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -39,7 +39,7 @@ def simple_batches(): batches = [ Batch( node_id=1, - data= [ + data=[ { "prompt": "Tomato sushi", }, @@ -55,11 +55,11 @@ def simple_batches(): { "prompt": "Tea sushi", }, - ] + ], ), Batch( node_id="2", - data= [ + data=[ { "prompt2": "Ume sushi", }, @@ -75,8 +75,8 @@ def simple_batches(): { "prompt2": "Cha sushi", }, - ] - ) + ], + ), ] return batches @@ -185,6 +185,7 @@ def test_handles_errors(mock_invoker: Invoker): assert all((i in g.errors for i in g.source_prepared_mapping["1"])) + def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches): batch_process_res = mock_invoker.services.batch_manager.create_batch_process( batches=simple_batches,