diff --git a/invokeai/app/api/routers/batches.py b/invokeai/app/api/routers/batches.py new file mode 100644 index 0000000000..1b989a46f4 --- /dev/null +++ b/invokeai/app/api/routers/batches.py @@ -0,0 +1,106 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from fastapi import Body, HTTPException, Path, Response +from fastapi.routing import APIRouter + +from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException + +# Importing * is bad karma but needed here for node detection +from ...invocations import * # noqa: F401 F403 +from ...services.batch_manager import Batch, BatchProcessResponse +from ...services.graph import Graph +from ..dependencies import ApiDependencies + +batches_router = APIRouter(prefix="/v1/batches", tags=["sessions"]) + + +@batches_router.post( + "/", + operation_id="create_batch", + responses={ + 200: {"model": BatchProcessResponse}, + 400: {"description": "Invalid json"}, + }, +) +async def create_batch( + graph: Graph = Body(description="The graph to initialize the session with"), + batch: Batch = Body(description="Batch config to apply to the given graph"), +) -> BatchProcessResponse: + """Creates a batch process""" + return ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph) + + +@batches_router.put( + "/b/{batch_process_id}/invoke", + operation_id="start_batch", + responses={ + 202: {"description": "Batch process started"}, + 404: {"description": "Batch session not found"}, + }, +) +async def start_batch( + batch_process_id: str = Path(description="ID of Batch to start"), +) -> Response: + """Executes a batch process""" + try: + ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id) + return Response(status_code=202) + except BatchSessionNotFoundException: + raise HTTPException(status_code=404, detail="Batch session not found") + + +@batches_router.delete( + "/b/{batch_process_id}", + operation_id="cancel_batch", + responses={202: {"description": "The batch is canceled"}}, +) +async def cancel_batch( + batch_process_id: str = Path(description="The id of the batch process to cancel"), +) -> Response: + """Cancels a batch process""" + ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id) + return Response(status_code=202) + + +@batches_router.get( + "/incomplete", + operation_id="list_incomplete_batches", + responses={200: {"model": list[BatchProcessResponse]}}, +) +async def list_incomplete_batches() -> list[BatchProcessResponse]: + """Lists incomplete batch processes""" + return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes() + + +@batches_router.get( + "/", + operation_id="list_batches", + responses={200: {"model": list[BatchProcessResponse]}}, +) +async def list_batches() -> list[BatchProcessResponse]: + """Lists all batch processes""" + return ApiDependencies.invoker.services.batch_manager.get_batch_processes() + + +@batches_router.get( + "/b/{batch_process_id}", + operation_id="get_batch", + responses={200: {"model": BatchProcessResponse}}, +) +async def get_batch( + batch_process_id: str = Path(description="The id of the batch process to get"), +) -> BatchProcessResponse: + """Gets a Batch Process""" + return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id) + + +@batches_router.get( + "/b/{batch_process_id}/sessions", + operation_id="get_batch_sessions", + responses={200: {"model": list[BatchSession]}}, +) +async def get_batch_sessions( + batch_process_id: str = Path(description="The id of the batch process to get"), +) -> list[BatchSession]: + """Gets a list of batch sessions for a given batch process""" + return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id) diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index 4adfff7ee7..e950e46e9f 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -6,12 +6,9 @@ from fastapi import Body, HTTPException, Path, Query, Response from fastapi.routing import APIRouter from pydantic.fields import Field -from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException - # Importing * is bad karma but needed here for node detection from ...invocations import * # noqa: F401 F403 from ...invocations.baseinvocation import BaseInvocation -from ...services.batch_manager import Batch, BatchProcessResponse from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError from ...services.item_storage import PaginatedResults from ..dependencies import ApiDependencies @@ -35,99 +32,6 @@ async def create_session( return session -@session_router.post( - "/batch", - operation_id="create_batch", - responses={ - 200: {"model": BatchProcessResponse}, - 400: {"description": "Invalid json"}, - }, -) -async def create_batch( - graph: Graph = Body(description="The graph to initialize the session with"), - batch: Batch = Body(description="Batch config to apply to the given graph"), -) -> BatchProcessResponse: - """Creates a batch process""" - batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph) - return batch_process_res - - -@session_router.put( - "/batch/{batch_process_id}/invoke", - operation_id="start_batch", - responses={ - 202: {"description": "Batch process started"}, - 404: {"description": "Batch session not found"}, - }, -) -async def start_batch( - batch_process_id: str = Path(description="ID of Batch to start"), -) -> Response: - """Executes a batch process""" - try: - ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id) - return Response(status_code=202) - except BatchSessionNotFoundException: - raise HTTPException(status_code=404, detail="Batch session not found") - - -@session_router.delete( - "/batch/{batch_process_id}", - operation_id="cancel_batch", - responses={202: {"description": "The batch is canceled"}}, -) -async def cancel_batch( - batch_process_id: str = Path(description="The id of the batch process to cancel"), -) -> Response: - """Cancels a batch process""" - ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id) - return Response(status_code=202) - - -@session_router.get( - "/batch/incomplete", - operation_id="list_incomplete_batches", - responses={200: {"model": list[BatchProcessResponse]}}, -) -async def list_incomplete_batches() -> list[BatchProcessResponse]: - """Lists incomplete batch processes""" - return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes() - - -@session_router.get( - "/batch", - operation_id="list_batches", - responses={200: {"model": list[BatchProcessResponse]}}, -) -async def list_batches() -> list[BatchProcessResponse]: - """Lists all batch processes""" - return ApiDependencies.invoker.services.batch_manager.get_batch_processes() - - -@session_router.get( - "/batch/{batch_process_id}", - operation_id="get_batch", - responses={200: {"model": BatchProcessResponse}}, -) -async def get_batch( - batch_process_id: str = Path(description="The id of the batch process to get"), -) -> BatchProcessResponse: - """Gets a Batch Process""" - return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id) - - -@session_router.get( - "/batch/{batch_process_id}/sessions", - operation_id="get_batch_sessions", - responses={200: {"model": list[BatchSession]}}, -) -async def get_batch_sessions( - batch_process_id: str = Path(description="The id of the batch process to get"), -) -> list[BatchSession]: - """Gets a list of batch sessions for a given batch process""" - return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id) - - @session_router.get( "/", operation_id="list_sessions", diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 902af0c02c..213c9dde5d 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -24,7 +24,7 @@ import invokeai.frontend.web as web_dir import mimetypes from .api.dependencies import ApiDependencies -from .api.routers import sessions, models, images, boards, board_images, app_info +from .api.routers import sessions, batches, models, images, boards, board_images, app_info from .api.sockets import SocketIO from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase @@ -90,6 +90,8 @@ async def shutdown_event(): app.include_router(sessions.session_router, prefix="/api") +app.include_router(batches.batches_router, prefix="/api") + app.include_router(models.models_router, prefix="/api") app.include_router(images.images_router, prefix="/api")