Add a few more endpoints for managing batches

This commit is contained in:
Brandon Rising 2023-08-18 15:38:16 -04:00
parent 0282f46c71
commit 3e26214b83
3 changed files with 186 additions and 2 deletions

View File

@ -6,7 +6,7 @@ from fastapi import Body, HTTPException, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic.fields import Field
from invokeai.app.services.batch_manager_storage import BatchSessionNotFoundException
from invokeai.app.services.batch_manager_storage import BatchProcess, BatchSession, BatchSessionNotFoundException
from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation
@ -82,6 +82,50 @@ async def cancel_batch(
return Response(status_code=202)
@session_router.get(
"/batch/incomplete",
operation_id="list_incomplete_batches",
responses={200: {"model": list[BatchProcessResponse]}},
)
async def list_incomplete_batches() -> list[BatchProcessResponse]:
"""Gets a list of sessions, optionally searching"""
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
@session_router.get(
"/batch",
operation_id="list__batches",
responses={200: {"model": list[BatchProcessResponse]}},
)
async def list_batches() -> list[BatchProcessResponse]:
"""Gets a list of sessions, optionally searching"""
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
@session_router.get(
"/batch/{batch_process_id}",
operation_id="get_batch",
responses={200: {"model": BatchProcess}},
)
async def get_batch(
batch_process_id: str = Path(description="The id of the batch process to get"),
) -> BatchProcess:
"""Gets a Batch Process"""
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
@session_router.get(
"/batch/{batch_process_id}/sessions",
operation_id="get_batch",
responses={200: {"model": list[BatchSession]}},
)
async def get_batch_sessions(
batch_process_id: str = Path(description="The id of the batch process to get"),
) -> list[BatchSession]:
"""Gets a list of BatchSessions Batch Process"""
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)
@session_router.get(
"/",
operation_id="list_sessions",

View File

@ -43,6 +43,22 @@ class BatchManagerBase(ABC):
def cancel_batch_process(self, batch_process_id: str) -> None:
pass
@abstractmethod
def get_batch(self, batch_id: str) -> BatchProcessResponse:
pass
@abstractmethod
def get_batch_processes(self) -> list[BatchProcessResponse]:
pass
@abstractmethod
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
pass
@abstractmethod
def get_sessions(self, batch_id: str) -> list[BatchSession]:
pass
class BatchManager(BatchManagerBase):
"""Responsible for managing currently running and scheduled batch jobs"""
@ -145,5 +161,37 @@ class BatchManager(BatchManagerBase):
sessions.append(self.__batch_process_storage.create_session(batch_session))
return sessions
def get_sessions(self, batch_id: str) -> list[BatchSession]:
return self.__batch_process_storage.get_sessions(batch_id)
def get_batch(self, batch_id: str) -> BatchProcess:
return self.__batch_process_storage.get(batch_id)
def get_batch_processes(self) -> list[BatchProcessResponse]:
bps = self.__batch_process_storage.get_all()
res = list()
for bp in bps:
sessions = self.__batch_process_storage.get_sessions(bp.batch_id)
res.append(
BatchProcessResponse(
batch_id=bp.batch_id,
session_ids=[session.session_id for session in sessions],
)
)
return res
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
bps = self.__batch_process_storage.get_incomplete()
res = list()
for bp in bps:
sessions = self.__batch_process_storage.get_sessions(bp.batch_id)
res.append(
BatchProcessResponse(
batch_id=bp.batch_id,
session_ids=[session.session_id for session in sessions],
)
)
return res
def cancel_batch_process(self, batch_process_id: str) -> None:
self.__batch_process_storage.cancel(batch_process_id)

View File

@ -6,7 +6,7 @@ from typing import List, Literal, Optional, Union, cast
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.graph import Graph
@ -160,6 +160,20 @@ class BatchProcessStorageBase(ABC):
"""Gets a BatchProcess record."""
pass
@abstractmethod
def get_all(
self,
) -> list[BatchProcess]:
"""Gets a BatchProcess record."""
pass
@abstractmethod
def get_incomplete(
self,
) -> list[BatchProcess]:
"""Gets a BatchProcess record."""
pass
@abstractmethod
def start(
self,
@ -189,6 +203,11 @@ class BatchProcessStorageBase(ABC):
"""Gets a BatchSession by session_id"""
pass
@abstractmethod
def get_sessions(self, batch_id: str) -> List[BatchSession]:
"""Gets all BatchSession's for a given BatchProcess id."""
pass
@abstractmethod
def get_created_session(self, batch_id: str) -> BatchSession:
"""Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
@ -400,6 +419,56 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchProcessNotFoundException
return self._deserialize_batch_process(dict(result))
def get_all(
self,
) -> list[BatchProcess]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_process
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessNotFoundException from e
finally:
self._lock.release()
if result is None:
return list()
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
def get_incomplete(
self,
) -> list[BatchProcess]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT bp.*
FROM batch_process bp
WHERE bp.batch_id IN
(
SELECT batch_id
FROM batch_session bs
WHERE state = 'created'
);
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessNotFoundException from e
finally:
self._lock.release()
if result is None:
return list()
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
def start(
self,
batch_id: str,
@ -536,6 +605,29 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions
def get_sessions(self, batch_id: str) -> List[BatchSession]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ?;
""",
(batch_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions
def update_session_state(
self,
batch_id: str,