mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(api): move batches to own router
This commit is contained in:
parent
3e7dadd7b3
commit
d22c4734ee
106
invokeai/app/api/routers/batches.py
Normal file
106
invokeai/app/api/routers/batches.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from fastapi import Body, HTTPException, Path, Response
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
|
|
||||||
|
from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException
|
||||||
|
|
||||||
|
# Importing * is bad karma but needed here for node detection
|
||||||
|
from ...invocations import * # noqa: F401 F403
|
||||||
|
from ...services.batch_manager import Batch, BatchProcessResponse
|
||||||
|
from ...services.graph import Graph
|
||||||
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
|
batches_router = APIRouter(prefix="/v1/batches", tags=["sessions"])
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.post(
|
||||||
|
"/",
|
||||||
|
operation_id="create_batch",
|
||||||
|
responses={
|
||||||
|
200: {"model": BatchProcessResponse},
|
||||||
|
400: {"description": "Invalid json"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def create_batch(
|
||||||
|
graph: Graph = Body(description="The graph to initialize the session with"),
|
||||||
|
batch: Batch = Body(description="Batch config to apply to the given graph"),
|
||||||
|
) -> BatchProcessResponse:
|
||||||
|
"""Creates a batch process"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.put(
|
||||||
|
"/b/{batch_process_id}/invoke",
|
||||||
|
operation_id="start_batch",
|
||||||
|
responses={
|
||||||
|
202: {"description": "Batch process started"},
|
||||||
|
404: {"description": "Batch session not found"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def start_batch(
|
||||||
|
batch_process_id: str = Path(description="ID of Batch to start"),
|
||||||
|
) -> Response:
|
||||||
|
"""Executes a batch process"""
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
|
||||||
|
return Response(status_code=202)
|
||||||
|
except BatchSessionNotFoundException:
|
||||||
|
raise HTTPException(status_code=404, detail="Batch session not found")
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.delete(
|
||||||
|
"/b/{batch_process_id}",
|
||||||
|
operation_id="cancel_batch",
|
||||||
|
responses={202: {"description": "The batch is canceled"}},
|
||||||
|
)
|
||||||
|
async def cancel_batch(
|
||||||
|
batch_process_id: str = Path(description="The id of the batch process to cancel"),
|
||||||
|
) -> Response:
|
||||||
|
"""Cancels a batch process"""
|
||||||
|
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
|
||||||
|
return Response(status_code=202)
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.get(
|
||||||
|
"/incomplete",
|
||||||
|
operation_id="list_incomplete_batches",
|
||||||
|
responses={200: {"model": list[BatchProcessResponse]}},
|
||||||
|
)
|
||||||
|
async def list_incomplete_batches() -> list[BatchProcessResponse]:
|
||||||
|
"""Lists incomplete batch processes"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.get(
|
||||||
|
"/",
|
||||||
|
operation_id="list_batches",
|
||||||
|
responses={200: {"model": list[BatchProcessResponse]}},
|
||||||
|
)
|
||||||
|
async def list_batches() -> list[BatchProcessResponse]:
|
||||||
|
"""Lists all batch processes"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.get(
|
||||||
|
"/b/{batch_process_id}",
|
||||||
|
operation_id="get_batch",
|
||||||
|
responses={200: {"model": BatchProcessResponse}},
|
||||||
|
)
|
||||||
|
async def get_batch(
|
||||||
|
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||||
|
) -> BatchProcessResponse:
|
||||||
|
"""Gets a Batch Process"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.get(
|
||||||
|
"/b/{batch_process_id}/sessions",
|
||||||
|
operation_id="get_batch_sessions",
|
||||||
|
responses={200: {"model": list[BatchSession]}},
|
||||||
|
)
|
||||||
|
async def get_batch_sessions(
|
||||||
|
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||||
|
) -> list[BatchSession]:
|
||||||
|
"""Gets a list of batch sessions for a given batch process"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)
|
@ -6,12 +6,9 @@ from fastapi import Body, HTTPException, Path, Query, Response
|
|||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException
|
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from ...invocations import * # noqa: F401 F403
|
from ...invocations import * # noqa: F401 F403
|
||||||
from ...invocations.baseinvocation import BaseInvocation
|
from ...invocations.baseinvocation import BaseInvocation
|
||||||
from ...services.batch_manager import Batch, BatchProcessResponse
|
|
||||||
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||||
from ...services.item_storage import PaginatedResults
|
from ...services.item_storage import PaginatedResults
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
@ -35,99 +32,6 @@ async def create_session(
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
@session_router.post(
|
|
||||||
"/batch",
|
|
||||||
operation_id="create_batch",
|
|
||||||
responses={
|
|
||||||
200: {"model": BatchProcessResponse},
|
|
||||||
400: {"description": "Invalid json"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def create_batch(
|
|
||||||
graph: Graph = Body(description="The graph to initialize the session with"),
|
|
||||||
batch: Batch = Body(description="Batch config to apply to the given graph"),
|
|
||||||
) -> BatchProcessResponse:
|
|
||||||
"""Creates a batch process"""
|
|
||||||
batch_process_res = ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
|
|
||||||
return batch_process_res
|
|
||||||
|
|
||||||
|
|
||||||
@session_router.put(
|
|
||||||
"/batch/{batch_process_id}/invoke",
|
|
||||||
operation_id="start_batch",
|
|
||||||
responses={
|
|
||||||
202: {"description": "Batch process started"},
|
|
||||||
404: {"description": "Batch session not found"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def start_batch(
|
|
||||||
batch_process_id: str = Path(description="ID of Batch to start"),
|
|
||||||
) -> Response:
|
|
||||||
"""Executes a batch process"""
|
|
||||||
try:
|
|
||||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
|
|
||||||
return Response(status_code=202)
|
|
||||||
except BatchSessionNotFoundException:
|
|
||||||
raise HTTPException(status_code=404, detail="Batch session not found")
|
|
||||||
|
|
||||||
|
|
||||||
@session_router.delete(
|
|
||||||
"/batch/{batch_process_id}",
|
|
||||||
operation_id="cancel_batch",
|
|
||||||
responses={202: {"description": "The batch is canceled"}},
|
|
||||||
)
|
|
||||||
async def cancel_batch(
|
|
||||||
batch_process_id: str = Path(description="The id of the batch process to cancel"),
|
|
||||||
) -> Response:
|
|
||||||
"""Cancels a batch process"""
|
|
||||||
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
|
|
||||||
return Response(status_code=202)
|
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
|
||||||
"/batch/incomplete",
|
|
||||||
operation_id="list_incomplete_batches",
|
|
||||||
responses={200: {"model": list[BatchProcessResponse]}},
|
|
||||||
)
|
|
||||||
async def list_incomplete_batches() -> list[BatchProcessResponse]:
|
|
||||||
"""Lists incomplete batch processes"""
|
|
||||||
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
|
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
|
||||||
"/batch",
|
|
||||||
operation_id="list_batches",
|
|
||||||
responses={200: {"model": list[BatchProcessResponse]}},
|
|
||||||
)
|
|
||||||
async def list_batches() -> list[BatchProcessResponse]:
|
|
||||||
"""Lists all batch processes"""
|
|
||||||
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
|
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
|
||||||
"/batch/{batch_process_id}",
|
|
||||||
operation_id="get_batch",
|
|
||||||
responses={200: {"model": BatchProcessResponse}},
|
|
||||||
)
|
|
||||||
async def get_batch(
|
|
||||||
batch_process_id: str = Path(description="The id of the batch process to get"),
|
|
||||||
) -> BatchProcessResponse:
|
|
||||||
"""Gets a Batch Process"""
|
|
||||||
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
|
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
|
||||||
"/batch/{batch_process_id}/sessions",
|
|
||||||
operation_id="get_batch_sessions",
|
|
||||||
responses={200: {"model": list[BatchSession]}},
|
|
||||||
)
|
|
||||||
async def get_batch_sessions(
|
|
||||||
batch_process_id: str = Path(description="The id of the batch process to get"),
|
|
||||||
) -> list[BatchSession]:
|
|
||||||
"""Gets a list of batch sessions for a given batch process"""
|
|
||||||
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)
|
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
@session_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_sessions",
|
operation_id="list_sessions",
|
||||||
|
@ -24,7 +24,7 @@ import invokeai.frontend.web as web_dir
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
|
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
from .api.routers import sessions, batches, models, images, boards, board_images, app_info
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||||
|
|
||||||
@ -90,6 +90,8 @@ async def shutdown_event():
|
|||||||
|
|
||||||
app.include_router(sessions.session_router, prefix="/api")
|
app.include_router(sessions.session_router, prefix="/api")
|
||||||
|
|
||||||
|
app.include_router(batches.batches_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(models.models_router, prefix="/api")
|
app.include_router(models.models_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(images.images_router, prefix="/api")
|
app.include_router(images.images_router, prefix="/api")
|
||||||
|
Loading…
Reference in New Issue
Block a user