fix(api): fix start_batch route responses

This commit is contained in:
psychedelicious 2023-08-17 11:51:14 +10:00
parent f7277a8b21
commit f246b236dd

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, List, Optional, Union from typing import Annotated, Optional, Union
from fastapi import Body, HTTPException, Path, Query, Response from fastapi import Body, HTTPException, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
@ -8,14 +8,8 @@ from pydantic.fields import Field
from ...invocations import * from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
Edge,
EdgeConnection,
Graph,
GraphExecutionState,
NodeAlreadyExecutedError,
)
from ...services.batch_manager import Batch, BatchProcessResponse from ...services.batch_manager import Batch, BatchProcessResponse
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
from ...services.item_storage import PaginatedResults from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -47,7 +41,7 @@ async def create_session(
}, },
) )
async def create_batch( async def create_batch(
graph: Optional[Graph] = Body(description="The graph to initialize the session with"), graph: Graph = Body(description="The graph to initialize the session with"),
batch: Batch = Body(description="Batch config to apply to the given graph"), batch: Batch = Body(description="Batch config to apply to the given graph"),
) -> BatchProcessResponse: ) -> BatchProcessResponse:
"""Creates and starts a new new batch process""" """Creates and starts a new new batch process"""
@ -59,15 +53,14 @@ async def create_batch(
"/batch/{batch_process_id}/invoke", "/batch/{batch_process_id}/invoke",
operation_id="start_batch", operation_id="start_batch",
responses={ responses={
200: {"model": BatchProcessResponse}, 202: {"description": "Batch process started"},
400: {"description": "Invalid json"}, 400: {"description": "Invalid json"},
}, },
) )
async def start_batch( async def start_batch(
batch_process_id: str = Path(description="ID of Batch to start"), batch_process_id: str = Path(description="ID of Batch to start"),
) -> BatchProcessResponse: ) -> Response:
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id) ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
return Response(status_code=202) return Response(status_code=202)