feat(backend): surface BatchSessionNodeFoundException

Catch this exception in the router and return an appropriate `HTTPException`.
This commit is contained in:
psychedelicious 2023-08-17 12:45:32 +10:00
parent e16b5f7cdc
commit 7e4beab4ff
2 changed files with 9 additions and 7 deletions

View File

@ -6,6 +6,8 @@ from fastapi import Body, HTTPException, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic.fields import Field from pydantic.fields import Field
from invokeai.app.services.batch_manager_storage import BatchSessionNotFoundException
from ...invocations import * from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation from ...invocations.baseinvocation import BaseInvocation
from ...services.batch_manager import Batch, BatchProcessResponse from ...services.batch_manager import Batch, BatchProcessResponse
@ -54,14 +56,17 @@ async def create_batch(
operation_id="start_batch", operation_id="start_batch",
responses={ responses={
202: {"description": "Batch process started"}, 202: {"description": "Batch process started"},
400: {"description": "Invalid json"}, 404: {"description": "Batch session not found"},
}, },
) )
async def start_batch( async def start_batch(
batch_process_id: str = Path(description="ID of Batch to start"), batch_process_id: str = Path(description="ID of Batch to start"),
) -> Response: ) -> Response:
try:
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id) ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
return Response(status_code=202) return Response(status_code=202)
except BatchSessionNotFoundException:
raise HTTPException(status_code=404, detail="Batch session not found")
@session_router.delete( @session_router.delete(

View File

@ -106,10 +106,7 @@ class BatchManager(BatchManagerBase):
def run_batch_process(self, batch_id: str) -> None: def run_batch_process(self, batch_id: str) -> None:
self.__batch_process_storage.start(batch_id) self.__batch_process_storage.start(batch_id)
try:
created_session = self.__batch_process_storage.get_created_session(batch_id) created_session = self.__batch_process_storage.get_created_session(batch_id)
except BatchSessionNotFoundException:
return
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id) ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
self.__invoker.invoke(ges, invoke_all=True) self.__invoker.invoke(ges, invoke_all=True)