mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add a few more endpoints for managing batches
This commit is contained in:
@ -6,7 +6,7 @@ from typing import List, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.graph import Graph
|
||||
|
||||
@ -160,6 +160,20 @@ class BatchProcessStorageBase(ABC):
|
||||
"""Gets a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
"""Gets a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_incomplete(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
"""Gets a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(
|
||||
self,
|
||||
@ -189,6 +203,11 @@ class BatchProcessStorageBase(ABC):
|
||||
"""Gets a BatchSession by session_id"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||
"""Gets all BatchSession's for a given BatchProcess id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||
"""Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
|
||||
@ -400,6 +419,56 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
raise BatchProcessNotFoundException
|
||||
return self._deserialize_batch_process(dict(result))
|
||||
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_process
|
||||
"""
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return list()
|
||||
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
|
||||
|
||||
def get_incomplete(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT bp.*
|
||||
FROM batch_process bp
|
||||
WHERE bp.batch_id IN
|
||||
(
|
||||
SELECT batch_id
|
||||
FROM batch_session bs
|
||||
WHERE state = 'created'
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return list()
|
||||
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
|
||||
|
||||
def start(
|
||||
self,
|
||||
batch_id: str,
|
||||
@ -536,6 +605,29 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||
return sessions
|
||||
|
||||
def get_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||
return sessions
|
||||
|
||||
def update_session_state(
|
||||
self,
|
||||
batch_id: str,
|
||||
|
Reference in New Issue
Block a user