InvokeAI/invokeai/app/api/routers/sessions.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

368 lines
13 KiB
Python
Raw Normal View History

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, Optional, Union
2023-03-03 06:02:00 +00:00
2023-04-24 13:55:45 +00:00
from fastapi import Body, HTTPException, Path, Query, Response
2023-03-03 06:02:00 +00:00
from fastapi.routing import APIRouter
from pydantic.fields import Field
from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException
2023-08-17 22:45:25 +00:00
# Importing * is bad karma but needed here for node detection
from ...invocations import * # noqa: F401 F403
2023-03-03 06:02:00 +00:00
from ...invocations.baseinvocation import BaseInvocation
2023-08-11 15:45:27 +00:00
from ...services.batch_manager import Batch, BatchProcessResponse
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
2023-03-03 06:02:00 +00:00
# Router for every session and batch endpoint; all routes below are mounted under /v1/sessions.
session_router = APIRouter(tags=["sessions"], prefix="/v1/sessions")
2023-03-03 06:02:00 +00:00
@session_router.post(
    "/",
    operation_id="create_session",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid json"},
    },
)
async def create_session(
    graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
) -> GraphExecutionState:
    """Creates a new session, optionally initializing it with an invocation graph"""
    # The invoker builds the execution state; `graph` may be None for an empty session.
    return ApiDependencies.invoker.create_execution_state(graph)
2023-07-31 19:45:35 +00:00
@session_router.post(
    "/batch",
    operation_id="create_batch",
    responses={
        200: {"model": BatchProcessResponse},
        400: {"description": "Invalid json"},
    },
)
async def create_batch(
    graph: Graph = Body(description="The graph to initialize the session with"),
    batch: Batch = Body(description="Batch config to apply to the given graph"),
) -> BatchProcessResponse:
    """Creates a batch process from a graph and a batch config"""
    batch_manager = ApiDependencies.invoker.services.batch_manager
    return batch_manager.create_batch_process(batch, graph)
2023-08-15 20:28:47 +00:00
@session_router.put(
    "/batch/{batch_process_id}/invoke",
    operation_id="start_batch",
    responses={
        202: {"description": "Batch process started"},
        404: {"description": "Batch session not found"},
    },
)
async def start_batch(
    batch_process_id: str = Path(description="ID of Batch to start"),
) -> Response:
    """Executes a batch process

    Responds 202 once the batch manager has accepted the run, or 404 when the batch
    manager raises BatchSessionNotFoundException for the given id.
    """
    try:
        ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
    except BatchSessionNotFoundException as e:
        # Chain the cause so the original lookup failure is kept as __cause__.
        raise HTTPException(status_code=404, detail="Batch session not found") from e
    return Response(status_code=202)
2023-08-15 20:28:47 +00:00
2023-07-31 20:05:27 +00:00
@session_router.delete(
    "/batch/{batch_process_id}",
    operation_id="cancel_batch",
    responses={202: {"description": "The batch is canceled"}},
)
async def cancel_batch(
    batch_process_id: str = Path(description="The id of the batch process to cancel"),
) -> Response:
    """Cancels a batch process"""
    batch_manager = ApiDependencies.invoker.services.batch_manager
    batch_manager.cancel_batch_process(batch_process_id)
    # 202 Accepted: the cancel request has been handed to the batch manager.
    return Response(status_code=202)
# NOTE: registered before "/batch/{batch_process_id}" so that the literal path
# segment "incomplete" is matched here rather than captured as a batch id.
@session_router.get(
    "/batch/incomplete",
    operation_id="list_incomplete_batches",
    responses={200: {"model": list[BatchProcessResponse]}},
)
async def list_incomplete_batches() -> list[BatchProcessResponse]:
    """Lists batch processes that have not completed"""
    batch_manager = ApiDependencies.invoker.services.batch_manager
    incomplete = batch_manager.get_incomplete_batch_processes()
    return incomplete
@session_router.get(
    "/batch",
    operation_id="list_batches",
    responses={200: {"model": list[BatchProcessResponse]}},
)
async def list_batches() -> list[BatchProcessResponse]:
    """Lists every batch process"""
    batch_manager = ApiDependencies.invoker.services.batch_manager
    all_batches = batch_manager.get_batch_processes()
    return all_batches
@session_router.get(
    "/batch/{batch_process_id}",
    operation_id="get_batch",
    responses={200: {"model": BatchProcessResponse}},
)
async def get_batch(
    batch_process_id: str = Path(description="The id of the batch process to get"),
) -> BatchProcessResponse:
    """Gets a batch process by its id"""
    batch_manager = ApiDependencies.invoker.services.batch_manager
    batch_process = batch_manager.get_batch(batch_process_id)
    return batch_process
@session_router.get(
    "/batch/{batch_process_id}/sessions",
    operation_id="get_batch_sessions",
    responses={200: {"model": list[BatchSession]}},
)
async def get_batch_sessions(
    batch_process_id: str = Path(description="The id of the batch process to get"),
) -> list[BatchSession]:
    """Gets the batch sessions belonging to a batch process"""
    batch_manager = ApiDependencies.invoker.services.batch_manager
    sessions = batch_manager.get_sessions(batch_process_id)
    return sessions
2023-03-03 06:02:00 +00:00
@session_router.get(
    "/",
    operation_id="list_sessions",
    responses={200: {"model": PaginatedResults[GraphExecutionState]}},
)
async def list_sessions(
    page: int = Query(default=0, description="The page of results to get"),
    per_page: int = Query(default=10, description="The number of results per page"),
    query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]:
    """Gets a page of sessions, filtered by a search query when one is given"""
    manager = ApiDependencies.invoker.services.graph_execution_manager
    if query != "":
        return manager.search(query, page, per_page)
    # An empty query means "no filter": fall back to the plain paginated listing.
    return manager.list(page, per_page)
2023-03-03 06:02:00 +00:00
@session_router.get(
    "/{session_id}",
    operation_id="get_session",
    responses={
        200: {"model": GraphExecutionState},
        404: {"description": "Session not found"},
    },
)
async def get_session(
    session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
    """Gets a session

    Responds 404 when the graph execution manager returns None for `session_id`.
    """
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        # Detail text mirrors the documented 404 response above.
        raise HTTPException(status_code=404, detail="Session not found")
    return session
2023-03-03 06:02:00 +00:00
@session_router.post(
    "/{session_id}/nodes",
    operation_id="add_node",
    responses={
        200: {"model": str},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def add_node(
    session_id: str = Path(description="The id of the session"),
    node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(  # type: ignore
        description="The node to add"
    ),
) -> str:
    """Adds a node to the graph

    Returns the session id. Responds 404 if the session does not exist, and 400 if
    the session rejects the node (NodeAlreadyExecutedError or IndexError).
    """
    manager = ApiDependencies.invoker.services.graph_execution_manager
    session = manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    # Only the graph mutation is a client error; storage failures below must not become a 400.
    try:
        session.add_node(node)
    except (NodeAlreadyExecutedError, IndexError) as e:
        raise HTTPException(status_code=400, detail="Invalid node or link") from e
    # TODO: can this be done automatically, or add node through an API?
    manager.set(session)
    return session.id
2023-03-03 06:02:00 +00:00
2023-03-03 06:02:00 +00:00
@session_router.put(
    "/{session_id}/nodes/{node_path}",
    operation_id="update_node",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def update_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node in the graph"),
    node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(  # type: ignore
        description="The new node"
    ),
) -> GraphExecutionState:
    """Updates a node in the graph and removes all linked edges

    Returns the updated session. Responds 404 if the session does not exist, and 400
    if the session rejects the update (NodeAlreadyExecutedError or IndexError).
    """
    manager = ApiDependencies.invoker.services.graph_execution_manager
    session = manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    # Only the graph mutation is a client error; storage failures below must not become a 400.
    try:
        session.update_node(node_path, node)
    except (NodeAlreadyExecutedError, IndexError) as e:
        raise HTTPException(status_code=400, detail="Invalid node or link") from e
    # TODO: can this be done automatically, or add node through an API?
    manager.set(session)
    return session
2023-03-03 06:02:00 +00:00
@session_router.delete(
    "/{session_id}/nodes/{node_path}",
    operation_id="delete_node",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def delete_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node to delete"),
) -> GraphExecutionState:
    """Deletes a node in the graph and removes all linked edges

    Returns the updated session. Responds 404 if the session does not exist, and 400
    if the session rejects the deletion (NodeAlreadyExecutedError or IndexError).
    """
    manager = ApiDependencies.invoker.services.graph_execution_manager
    session = manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    # Only the graph mutation is a client error; storage failures below must not become a 400.
    try:
        session.delete_node(node_path)
    except (NodeAlreadyExecutedError, IndexError) as e:
        raise HTTPException(status_code=400, detail="Invalid node or link") from e
    # TODO: can this be done automatically, or add node through an API?
    manager.set(session)
    return session
2023-03-03 06:02:00 +00:00
@session_router.post(
    "/{session_id}/edges",
    operation_id="add_edge",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def add_edge(
    session_id: str = Path(description="The id of the session"),
    edge: Edge = Body(description="The edge to add"),
) -> GraphExecutionState:
    """Adds an edge to the graph

    Returns the updated session. Responds 404 if the session does not exist, and 400
    if the session rejects the edge (NodeAlreadyExecutedError or IndexError).
    """
    manager = ApiDependencies.invoker.services.graph_execution_manager
    session = manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    # Only the graph mutation is a client error; storage failures below must not become a 400.
    try:
        session.add_edge(edge)
    except (NodeAlreadyExecutedError, IndexError) as e:
        raise HTTPException(status_code=400, detail="Invalid node or link") from e
    # TODO: can this be done automatically, or add node through an API?
    manager.set(session)
    return session
# TODO: the edge being in the path here is really ugly, find a better solution
@session_router.delete(
    "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
    operation_id="delete_edge",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def delete_edge(
    session_id: str = Path(description="The id of the session"),
    from_node_id: str = Path(description="The id of the node the edge is coming from"),
    from_field: str = Path(description="The field of the node the edge is coming from"),
    to_node_id: str = Path(description="The id of the node the edge is going to"),
    to_field: str = Path(description="The field of the node the edge is going to"),
) -> GraphExecutionState:
    """Deletes an edge from the graph

    The edge is rebuilt from the four path segments. Returns the updated session.
    Responds 404 if the session does not exist, and 400 if the session rejects the
    deletion (NodeAlreadyExecutedError or IndexError).
    """
    manager = ApiDependencies.invoker.services.graph_execution_manager
    session = manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    edge = Edge(
        source=EdgeConnection(node_id=from_node_id, field=from_field),
        destination=EdgeConnection(node_id=to_node_id, field=to_field),
    )
    # Only the graph mutation is a client error; storage failures below must not become a 400.
    try:
        session.delete_edge(edge)
    except (NodeAlreadyExecutedError, IndexError) as e:
        raise HTTPException(status_code=400, detail="Invalid node or link") from e
    # TODO: can this be done automatically, or add node through an API?
    manager.set(session)
    return session
2023-03-03 06:02:00 +00:00
@session_router.put(
    "/{session_id}/invoke",
    operation_id="invoke_session",
    responses={
        200: {"model": None},
        202: {"description": "The invocation is queued"},
        400: {"description": "The session has no invocations ready to invoke"},
        404: {"description": "Session not found"},
    },
)
async def invoke_session(
    session_id: str = Path(description="The id of the session to invoke"),
    all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
) -> Response:
    """Invokes a session

    Responds 202 after handing the session to the invoker, 404 if the session does not
    exist, and 400 if `session.is_complete()` reports nothing left to run.
    """
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    if session.is_complete():
        raise HTTPException(status_code=400, detail="The session has no invocations ready to invoke")

    # `all` is the public query-parameter name, so it intentionally shadows the builtin here.
    ApiDependencies.invoker.invoke(session, invoke_all=all)
    return Response(status_code=202)
2023-03-17 03:05:36 +00:00
@session_router.delete(
    "/{session_id}/invoke",
    operation_id="cancel_session_invoke",
    responses={202: {"description": "The invocation is canceled"}},
)
async def cancel_session_invoke(
    session_id: str = Path(description="The id of the session to cancel"),
) -> Response:
    """Cancels the invocation of a session"""
    ApiDependencies.invoker.cancel(session_id)
    # 202 Accepted: the cancel request has been forwarded to the invoker.
    return Response(status_code=202)