diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index c158dae5d6..e231bb63a4 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -6,7 +6,7 @@ from fastapi import Body, HTTPException, Path, Query, Response from fastapi.routing import APIRouter from pydantic.fields import Field -from invokeai.app.services.batch_manager_storage import BatchSessionNotFoundException +from invokeai.app.services.batch_manager_storage import BatchProcess, BatchSession, BatchSessionNotFoundException from ...invocations import * from ...invocations.baseinvocation import BaseInvocation @@ -82,6 +82,50 @@ async def cancel_batch( return Response(status_code=202) +@session_router.get( + "/batch/incomplete", + operation_id="list_incomplete_batches", + responses={200: {"model": list[BatchProcessResponse]}}, +) +async def list_incomplete_batches() -> list[BatchProcessResponse]: + """Gets a list of sessions, optionally searching""" + return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes() + + +@session_router.get( + "/batch", + operation_id="list__batches", + responses={200: {"model": list[BatchProcessResponse]}}, +) +async def list_batches() -> list[BatchProcessResponse]: + """Gets a list of sessions, optionally searching""" + return ApiDependencies.invoker.services.batch_manager.get_batch_processes() + + +@session_router.get( + "/batch/{batch_process_id}", + operation_id="get_batch", + responses={200: {"model": BatchProcess}}, +) +async def get_batch( + batch_process_id: str = Path(description="The id of the batch process to get"), +) -> BatchProcess: + """Gets a Batch Process""" + return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id) + + +@session_router.get( + "/batch/{batch_process_id}/sessions", + operation_id="get_batch", + responses={200: {"model": list[BatchSession]}}, +) +async def get_batch_sessions( + batch_process_id: str = Path(description="The id of the batch process to get"), +) -> list[BatchSession]: + """Gets a list of BatchSessions Batch Process""" + return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id) + + @session_router.get( "/", operation_id="list_sessions", diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index 016f1a8ba6..d92aeeae18 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -43,6 +43,22 @@ class BatchManagerBase(ABC): def cancel_batch_process(self, batch_process_id: str) -> None: pass + @abstractmethod + def get_batch(self, batch_id: str) -> BatchProcessResponse: + pass + + @abstractmethod + def get_batch_processes(self) -> list[BatchProcessResponse]: + pass + + @abstractmethod + def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]: + pass + + @abstractmethod + def get_sessions(self, batch_id: str) -> list[BatchSession]: + pass + class BatchManager(BatchManagerBase): """Responsible for managing currently running and scheduled batch jobs""" @@ -145,5 +161,37 @@ class BatchManager(BatchManagerBase): sessions.append(self.__batch_process_storage.create_session(batch_session)) return sessions + def get_sessions(self, batch_id: str) -> list[BatchSession]: + return self.__batch_process_storage.get_sessions(batch_id) + + def get_batch(self, batch_id: str) -> BatchProcess: + return self.__batch_process_storage.get(batch_id) + + def get_batch_processes(self) -> list[BatchProcessResponse]: + bps = self.__batch_process_storage.get_all() + res = list() + for bp in bps: + sessions = self.__batch_process_storage.get_sessions(bp.batch_id) + res.append( + BatchProcessResponse( + batch_id=bp.batch_id, + session_ids=[session.session_id for session in sessions], + ) + ) + return res + + def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]: + bps = self.__batch_process_storage.get_incomplete() + res = list() + for bp in bps: + sessions = self.__batch_process_storage.get_sessions(bp.batch_id) + res.append( + BatchProcessResponse( + batch_id=bp.batch_id, + session_ids=[session.session_id for session in sessions], + ) + ) + return res + def cancel_batch_process(self, batch_process_id: str) -> None: self.__batch_process_storage.cancel(batch_process_id) diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py index 1e0f0916be..704521b9aa 100644 --- a/invokeai/app/services/batch_manager_storage.py +++ b/invokeai/app/services/batch_manager_storage.py @@ -6,7 +6,7 @@ from typing import List, Literal, Optional, Union, cast from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator -from invokeai.app.invocations.baseinvocation import BaseInvocation +from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.invocations.primitives import ImageField from invokeai.app.services.graph import Graph @@ -160,6 +160,20 @@ class BatchProcessStorageBase(ABC): """Gets a BatchProcess record.""" pass + @abstractmethod + def get_all( + self, + ) -> list[BatchProcess]: + """Gets a BatchProcess record.""" + pass + + @abstractmethod + def get_incomplete( + self, + ) -> list[BatchProcess]: + """Gets a BatchProcess record.""" + pass + @abstractmethod def start( self, @@ -189,6 +203,11 @@ class BatchProcessStorageBase(ABC): """Gets a BatchSession by session_id""" pass + @abstractmethod + def get_sessions(self, batch_id: str) -> List[BatchSession]: + """Gets all BatchSession's for a given BatchProcess id.""" + pass + @abstractmethod def get_created_session(self, batch_id: str) -> BatchSession: """Gets the latest BatchSession with state `created`, for a given BatchProcess id.""" @@ -400,6 +419,56 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): raise BatchProcessNotFoundException return self._deserialize_batch_process(dict(result)) + def get_all( + self, + ) -> list[BatchProcess]: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT * + FROM batch_process + """ + ) + + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + except sqlite3.Error as e: + self._conn.rollback() + raise BatchProcessNotFoundException from e + finally: + self._lock.release() + if result is None: + return list() + return list(map(lambda r: self._deserialize_batch_process(dict(r)), result)) + + def get_incomplete( + self, + ) -> list[BatchProcess]: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT bp.* + FROM batch_process bp + WHERE bp.batch_id IN + ( + SELECT batch_id + FROM batch_session bs + WHERE state = 'created' + ); + """ + ) + + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + except sqlite3.Error as e: + self._conn.rollback() + raise BatchProcessNotFoundException from e + finally: + self._lock.release() + if result is None: + return list() + return list(map(lambda r: self._deserialize_batch_process(dict(r)), result)) + def start( self, batch_id: str, @@ -536,6 +605,29 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result)) return sessions + def get_sessions(self, batch_id: str) -> List[BatchSession]: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT * + FROM batch_session + WHERE batch_id = ?; + """, + (batch_id,), + ) + + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + except sqlite3.Error as e: + self._conn.rollback() + raise BatchSessionNotFoundException from e + finally: + self._lock.release() + if result is None: + raise BatchSessionNotFoundException + sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result)) + return sessions + def update_session_state( self, batch_id: str,