feat(api): move batches to own router

This commit is contained in:
psychedelicious 2023-09-06 16:34:49 +10:00
parent 3e7dadd7b3
commit d22c4734ee
3 changed files with 109 additions and 97 deletions

View File

@ -0,0 +1,106 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from fastapi import Body, HTTPException, Path, Response
from fastapi.routing import APIRouter
from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException
# Importing * is bad karma but needed here for node detection
from ...invocations import * # noqa: F401 F403
from ...services.batch_manager import Batch, BatchProcessResponse
from ...services.graph import Graph
from ..dependencies import ApiDependencies
batches_router = APIRouter(prefix="/v1/batches", tags=["sessions"])
@batches_router.post(
"/",
operation_id="create_batch",
responses={
200: {"model": BatchProcessResponse},
400: {"description": "Invalid json"},
},
)
async def create_batch(
graph: Graph = Body(description="The graph to initialize the session with"),
batch: Batch = Body(description="Batch config to apply to the given graph"),
) -> BatchProcessResponse:
"""Creates a batch process"""
return ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
@batches_router.put(
"/b/{batch_process_id}/invoke",
operation_id="start_batch",
responses={
202: {"description": "Batch process started"},
404: {"description": "Batch session not found"},
},
)
async def start_batch(
batch_process_id: str = Path(description="ID of Batch to start"),
) -> Response:
"""Executes a batch process"""
try:
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
return Response(status_code=202)
except BatchSessionNotFoundException:
raise HTTPException(status_code=404, detail="Batch session not found")
@batches_router.delete(
"/b/{batch_process_id}",
operation_id="cancel_batch",
responses={202: {"description": "The batch is canceled"}},
)
async def cancel_batch(
batch_process_id: str = Path(description="The id of the batch process to cancel"),
) -> Response:
"""Cancels a batch process"""
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
return Response(status_code=202)
@batches_router.get(
"/incomplete",
operation_id="list_incomplete_batches",
responses={200: {"model": list[BatchProcessResponse]}},
)
async def list_incomplete_batches() -> list[BatchProcessResponse]:
"""Lists incomplete batch processes"""
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
@batches_router.get(
"/",
operation_id="list_batches",
responses={200: {"model": list[BatchProcessResponse]}},
)
async def list_batches() -> list[BatchProcessResponse]:
"""Lists all batch processes"""
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
@batches_router.get(
"/b/{batch_process_id}",
operation_id="get_batch",
responses={200: {"model": BatchProcessResponse}},
)
async def get_batch(
batch_process_id: str = Path(description="The id of the batch process to get"),
) -> BatchProcessResponse:
"""Gets a Batch Process"""
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
@batches_router.get(
"/b/{batch_process_id}/sessions",
operation_id="get_batch_sessions",
responses={200: {"model": list[BatchSession]}},
)
async def get_batch_sessions(
batch_process_id: str = Path(description="The id of the batch process to get"),
) -> list[BatchSession]:
"""Gets a list of batch sessions for a given batch process"""
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)

View File

@ -6,12 +6,9 @@ from fastapi import Body, HTTPException, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic.fields import Field from pydantic.fields import Field
from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException
# Importing * is bad karma but needed here for node detection # Importing * is bad karma but needed here for node detection
from ...invocations import * # noqa: F401 F403 from ...invocations import * # noqa: F401 F403
from ...invocations.baseinvocation import BaseInvocation from ...invocations.baseinvocation import BaseInvocation
from ...services.batch_manager import Batch, BatchProcessResponse
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
from ...services.item_storage import PaginatedResults from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -35,99 +32,6 @@ async def create_session(
return session return session
@session_router.post(
"/batch",
operation_id="create_batch",
responses={
200: {"model": BatchProcessResponse},
400: {"description": "Invalid json"},
},
)
async def create_batch(
graph: Graph = Body(description="The graph to initialize the session with"),
batch: Batch = Body(description="Batch config to apply to the given graph"),
) -> BatchProcessResponse:
"""Creates a batch process"""
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
return batch_process_res
@session_router.put(
"/batch/{batch_process_id}/invoke",
operation_id="start_batch",
responses={
202: {"description": "Batch process started"},
404: {"description": "Batch session not found"},
},
)
async def start_batch(
batch_process_id: str = Path(description="ID of Batch to start"),
) -> Response:
"""Executes a batch process"""
try:
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
return Response(status_code=202)
except BatchSessionNotFoundException:
raise HTTPException(status_code=404, detail="Batch session not found")
@session_router.delete(
"/batch/{batch_process_id}",
operation_id="cancel_batch",
responses={202: {"description": "The batch is canceled"}},
)
async def cancel_batch(
batch_process_id: str = Path(description="The id of the batch process to cancel"),
) -> Response:
"""Cancels a batch process"""
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
return Response(status_code=202)
@session_router.get(
"/batch/incomplete",
operation_id="list_incomplete_batches",
responses={200: {"model": list[BatchProcessResponse]}},
)
async def list_incomplete_batches() -> list[BatchProcessResponse]:
"""Lists incomplete batch processes"""
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
@session_router.get(
"/batch",
operation_id="list_batches",
responses={200: {"model": list[BatchProcessResponse]}},
)
async def list_batches() -> list[BatchProcessResponse]:
"""Lists all batch processes"""
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
@session_router.get(
"/batch/{batch_process_id}",
operation_id="get_batch",
responses={200: {"model": BatchProcessResponse}},
)
async def get_batch(
batch_process_id: str = Path(description="The id of the batch process to get"),
) -> BatchProcessResponse:
"""Gets a Batch Process"""
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
@session_router.get(
"/batch/{batch_process_id}/sessions",
operation_id="get_batch_sessions",
responses={200: {"model": list[BatchSession]}},
)
async def get_batch_sessions(
batch_process_id: str = Path(description="The id of the batch process to get"),
) -> list[BatchSession]:
"""Gets a list of batch sessions for a given batch process"""
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)
@session_router.get( @session_router.get(
"/", "/",
operation_id="list_sessions", operation_id="list_sessions",

View File

@ -24,7 +24,7 @@ import invokeai.frontend.web as web_dir
import mimetypes import mimetypes
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info from .api.routers import sessions, batches, models, images, boards, board_images, app_info
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
@ -90,6 +90,8 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api") app.include_router(sessions.session_router, prefix="/api")
app.include_router(batches.batches_router, prefix="/api")
app.include_router(models.models_router, prefix="/api") app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api") app.include_router(images.images_router, prefix="/api")