mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Break apart create/start logic
This commit is contained in:
parent
6cb90e01de
commit
15e7ca1baa
@ -52,10 +52,25 @@ async def create_batch(
|
|||||||
) -> BatchProcessResponse:
|
) -> BatchProcessResponse:
|
||||||
"""Creates and starts a new new batch process"""
|
"""Creates and starts a new new batch process"""
|
||||||
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
||||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
|
||||||
return batch_process_res
|
return batch_process_res
|
||||||
|
|
||||||
|
|
||||||
|
@session_router.put(
|
||||||
|
"/batch/{batch_process_id}/invoke",
|
||||||
|
operation_id="start_batch",
|
||||||
|
responses={
|
||||||
|
200: {"model": BatchProcessResponse},
|
||||||
|
400: {"description": "Invalid json"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def start_batch(
|
||||||
|
batch_process_id: str = Path(description="ID of Batch to start"),
|
||||||
|
) -> BatchProcessResponse:
|
||||||
|
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
|
||||||
|
|
||||||
|
return Response(status_code=202)
|
||||||
|
|
||||||
|
|
||||||
@session_router.delete(
|
@session_router.delete(
|
||||||
"/batch/{batch_process_id}",
|
"/batch/{batch_process_id}",
|
||||||
operation_id="cancel_batch",
|
operation_id="cancel_batch",
|
||||||
|
@ -47,7 +47,6 @@ class BatchManager(BatchManagerBase):
|
|||||||
"""Responsible for managing currently running and scheduled batch jobs"""
|
"""Responsible for managing currently running and scheduled batch jobs"""
|
||||||
|
|
||||||
__invoker: Invoker
|
__invoker: Invoker
|
||||||
__batches: list[BatchProcess]
|
|
||||||
__batch_process_storage: BatchProcessStorageBase
|
__batch_process_storage: BatchProcessStorageBase
|
||||||
|
|
||||||
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
|
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
|
||||||
@ -55,9 +54,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
self.__batch_process_storage = batch_process_storage
|
self.__batch_process_storage = batch_process_storage
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
def start(self, invoker: Invoker) -> None:
|
||||||
# if we do want multithreading at some point, we could make this configurable
|
|
||||||
self.__invoker = invoker
|
self.__invoker = invoker
|
||||||
self.__batches = list()
|
|
||||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
||||||
|
|
||||||
async def on_event(self, event: Event):
|
async def on_event(self, event: Event):
|
||||||
@ -87,7 +84,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
self.run_batch_process(batch_process.batch_id)
|
self.run_batch_process(batch_process.batch_id)
|
||||||
|
|
||||||
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
||||||
graph = copy.deepcopy(batch_process.graph)
|
graph = batch_process.graph.copy(deep=True)
|
||||||
batches = batch_process.batches
|
batches = batch_process.batches
|
||||||
g = graph.nx_graph_flat()
|
g = graph.nx_graph_flat()
|
||||||
sorted_nodes = nx.topological_sort(g)
|
sorted_nodes = nx.topological_sort(g)
|
||||||
@ -104,6 +101,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
return GraphExecutionState(graph=graph)
|
return GraphExecutionState(graph=graph)
|
||||||
|
|
||||||
def run_batch_process(self, batch_id: str):
|
def run_batch_process(self, batch_id: str):
|
||||||
|
self.__batch_process_storage.start(batch_id)
|
||||||
try:
|
try:
|
||||||
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
||||||
except BatchSessionNotFoundException:
|
except BatchSessionNotFoundException:
|
||||||
|
@ -126,6 +126,14 @@ class BatchProcessStorageBase(ABC):
|
|||||||
"""Gets a Batch Process record."""
|
"""Gets a Batch Process record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def start(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
):
|
||||||
|
"""Start Batch Process record."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cancel(
|
def cancel(
|
||||||
self,
|
self,
|
||||||
@ -360,6 +368,27 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
raise BatchProcessNotFoundException
|
raise BatchProcessNotFoundException
|
||||||
return self._deserialize_batch_process(dict(result))
|
return self._deserialize_batch_process(dict(result))
|
||||||
|
|
||||||
|
def start(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE batch_process
|
||||||
|
SET canceled = 0
|
||||||
|
WHERE batch_id = ?;
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
def cancel(
|
def cancel(
|
||||||
self,
|
self,
|
||||||
batch_id: str,
|
batch_id: str,
|
||||||
|
Loading…
Reference in New Issue
Block a user