mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add a few more endpoints for managing batches
This commit is contained in:
parent
0282f46c71
commit
3e26214b83
@ -6,7 +6,7 @@ from fastapi import Body, HTTPException, Path, Query, Response
|
|||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
from invokeai.app.services.batch_manager_storage import BatchSessionNotFoundException
|
from invokeai.app.services.batch_manager_storage import BatchProcess, BatchSession, BatchSessionNotFoundException
|
||||||
|
|
||||||
from ...invocations import *
|
from ...invocations import *
|
||||||
from ...invocations.baseinvocation import BaseInvocation
|
from ...invocations.baseinvocation import BaseInvocation
|
||||||
@ -82,6 +82,50 @@ async def cancel_batch(
|
|||||||
return Response(status_code=202)
|
return Response(status_code=202)
|
||||||
|
|
||||||
|
|
||||||
|
@session_router.get(
|
||||||
|
"/batch/incomplete",
|
||||||
|
operation_id="list_incomplete_batches",
|
||||||
|
responses={200: {"model": list[BatchProcessResponse]}},
|
||||||
|
)
|
||||||
|
async def list_incomplete_batches() -> list[BatchProcessResponse]:
|
||||||
|
"""Gets a list of sessions, optionally searching"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
|
||||||
|
|
||||||
|
|
||||||
|
@session_router.get(
|
||||||
|
"/batch",
|
||||||
|
operation_id="list__batches",
|
||||||
|
responses={200: {"model": list[BatchProcessResponse]}},
|
||||||
|
)
|
||||||
|
async def list_batches() -> list[BatchProcessResponse]:
|
||||||
|
"""Gets a list of sessions, optionally searching"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
|
||||||
|
|
||||||
|
|
||||||
|
@session_router.get(
|
||||||
|
"/batch/{batch_process_id}",
|
||||||
|
operation_id="get_batch",
|
||||||
|
responses={200: {"model": BatchProcess}},
|
||||||
|
)
|
||||||
|
async def get_batch(
|
||||||
|
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||||
|
) -> BatchProcess:
|
||||||
|
"""Gets a Batch Process"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
|
||||||
|
|
||||||
|
|
||||||
|
@session_router.get(
|
||||||
|
"/batch/{batch_process_id}/sessions",
|
||||||
|
operation_id="get_batch",
|
||||||
|
responses={200: {"model": list[BatchSession]}},
|
||||||
|
)
|
||||||
|
async def get_batch_sessions(
|
||||||
|
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||||
|
) -> list[BatchSession]:
|
||||||
|
"""Gets a list of BatchSessions Batch Process"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
@session_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_sessions",
|
operation_id="list_sessions",
|
||||||
|
@ -43,6 +43,22 @@ class BatchManagerBase(ABC):
|
|||||||
def cancel_batch_process(self, batch_process_id: str) -> None:
|
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_batch(self, batch_id: str) -> BatchProcessResponse:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_batch_processes(self) -> list[BatchProcessResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BatchManager(BatchManagerBase):
|
class BatchManager(BatchManagerBase):
|
||||||
"""Responsible for managing currently running and scheduled batch jobs"""
|
"""Responsible for managing currently running and scheduled batch jobs"""
|
||||||
@ -145,5 +161,37 @@ class BatchManager(BatchManagerBase):
|
|||||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
|
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||||
|
return self.__batch_process_storage.get_sessions(batch_id)
|
||||||
|
|
||||||
|
def get_batch(self, batch_id: str) -> BatchProcess:
|
||||||
|
return self.__batch_process_storage.get(batch_id)
|
||||||
|
|
||||||
|
def get_batch_processes(self) -> list[BatchProcessResponse]:
|
||||||
|
bps = self.__batch_process_storage.get_all()
|
||||||
|
res = list()
|
||||||
|
for bp in bps:
|
||||||
|
sessions = self.__batch_process_storage.get_sessions(bp.batch_id)
|
||||||
|
res.append(
|
||||||
|
BatchProcessResponse(
|
||||||
|
batch_id=bp.batch_id,
|
||||||
|
session_ids=[session.session_id for session in sessions],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
|
||||||
|
bps = self.__batch_process_storage.get_incomplete()
|
||||||
|
res = list()
|
||||||
|
for bp in bps:
|
||||||
|
sessions = self.__batch_process_storage.get_sessions(bp.batch_id)
|
||||||
|
res.append(
|
||||||
|
BatchProcessResponse(
|
||||||
|
batch_id=bp.batch_id,
|
||||||
|
session_ids=[session.session_id for session in sessions],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
def cancel_batch_process(self, batch_process_id: str) -> None:
|
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||||
self.__batch_process_storage.cancel(batch_process_id)
|
self.__batch_process_storage.cancel(batch_process_id)
|
||||||
|
@ -6,7 +6,7 @@ from typing import List, Literal, Optional, Union, cast
|
|||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
|
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
from invokeai.app.services.graph import Graph
|
from invokeai.app.services.graph import Graph
|
||||||
|
|
||||||
@ -160,6 +160,20 @@ class BatchProcessStorageBase(ABC):
|
|||||||
"""Gets a BatchProcess record."""
|
"""Gets a BatchProcess record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BatchProcess]:
|
||||||
|
"""Gets a BatchProcess record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_incomplete(
|
||||||
|
self,
|
||||||
|
) -> list[BatchProcess]:
|
||||||
|
"""Gets a BatchProcess record."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def start(
|
def start(
|
||||||
self,
|
self,
|
||||||
@ -189,6 +203,11 @@ class BatchProcessStorageBase(ABC):
|
|||||||
"""Gets a BatchSession by session_id"""
|
"""Gets a BatchSession by session_id"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||||
|
"""Gets all BatchSession's for a given BatchProcess id."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||||
"""Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
|
"""Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
|
||||||
@ -400,6 +419,56 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
raise BatchProcessNotFoundException
|
raise BatchProcessNotFoundException
|
||||||
return self._deserialize_batch_process(dict(result))
|
return self._deserialize_batch_process(dict(result))
|
||||||
|
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BatchProcess]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_process
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
return list()
|
||||||
|
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
|
||||||
|
|
||||||
|
def get_incomplete(
|
||||||
|
self,
|
||||||
|
) -> list[BatchProcess]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT bp.*
|
||||||
|
FROM batch_process bp
|
||||||
|
WHERE bp.batch_id IN
|
||||||
|
(
|
||||||
|
SELECT batch_id
|
||||||
|
FROM batch_session bs
|
||||||
|
WHERE state = 'created'
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
return list()
|
||||||
|
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
|
||||||
|
|
||||||
def start(
|
def start(
|
||||||
self,
|
self,
|
||||||
batch_id: str,
|
batch_id: str,
|
||||||
@ -536,6 +605,29 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
|
def get_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_session
|
||||||
|
WHERE batch_id = ?;
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BatchSessionNotFoundException
|
||||||
|
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||||
|
return sessions
|
||||||
|
|
||||||
def update_session_state(
|
def update_session_state(
|
||||||
self,
|
self,
|
||||||
batch_id: str,
|
batch_id: str,
|
||||||
|
Loading…
Reference in New Issue
Block a user