Add a few more endpoints for managing batches

This commit is contained in:
Brandon Rising
2023-08-18 15:38:16 -04:00
parent 0282f46c71
commit 3e26214b83
3 changed files with 186 additions and 2 deletions

View File

@ -6,7 +6,7 @@ from typing import List, Literal, Optional, Union, cast
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.graph import Graph
@ -160,6 +160,20 @@ class BatchProcessStorageBase(ABC):
"""Gets a BatchProcess record."""
pass
@abstractmethod
def get_all(
self,
) -> list[BatchProcess]:
"""Gets a BatchProcess record."""
pass
@abstractmethod
def get_incomplete(
self,
) -> list[BatchProcess]:
"""Gets a BatchProcess record."""
pass
@abstractmethod
def start(
self,
@ -189,6 +203,11 @@ class BatchProcessStorageBase(ABC):
"""Gets a BatchSession by session_id"""
pass
@abstractmethod
def get_sessions(self, batch_id: str) -> List[BatchSession]:
"""Gets all BatchSession's for a given BatchProcess id."""
pass
@abstractmethod
def get_created_session(self, batch_id: str) -> BatchSession:
"""Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
@ -400,6 +419,56 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchProcessNotFoundException
return self._deserialize_batch_process(dict(result))
def get_all(
self,
) -> list[BatchProcess]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_process
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessNotFoundException from e
finally:
self._lock.release()
if result is None:
return list()
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
def get_incomplete(
self,
) -> list[BatchProcess]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT bp.*
FROM batch_process bp
WHERE bp.batch_id IN
(
SELECT batch_id
FROM batch_session bs
WHERE state = 'created'
);
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessNotFoundException from e
finally:
self._lock.release()
if result is None:
return list()
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
def start(
self,
batch_id: str,
@ -536,6 +605,29 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions
def get_sessions(self, batch_id: str) -> List[BatchSession]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ?;
""",
(batch_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions
def update_session_state(
self,
batch_id: str,