InvokeAI/invokeai/app/api/routers/sessions.py

308 lines
11 KiB
Python
Raw Normal View History

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
2023-03-03 06:02:00 +00:00
from typing import Annotated, List, Optional, Union
2023-04-24 13:55:45 +00:00
from fastapi import Body, HTTPException, Path, Query, Response
2023-03-03 06:02:00 +00:00
from fastapi.routing import APIRouter
from pydantic.fields import Field
2023-03-03 06:02:00 +00:00
from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
2023-03-15 06:09:30 +00:00
Edge,
2023-03-03 06:02:00 +00:00
EdgeConnection,
Graph,
GraphExecutionState,
NodeAlreadyExecutedError,
)
2023-07-31 20:05:27 +00:00
from ...services.batch_manager import Batch, BatchProcess
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
2023-03-03 06:02:00 +00:00
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
2023-03-03 06:02:00 +00:00
@session_router.post(
    "/",
    operation_id="create_session",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid json"},
    },
)
async def create_session(
    graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
) -> GraphExecutionState:
    """Create a new execution session, optionally seeded with an invocation graph.

    Args:
        graph: Graph to initialize the session with. When omitted, ``None`` is
            handed to the invoker.

    Returns:
        The ``GraphExecutionState`` produced by ``invoker.create_execution_state``.
    """
    return ApiDependencies.invoker.create_execution_state(graph)
2023-07-31 19:45:35 +00:00
@session_router.post(
    "/batch",
    operation_id="create_batch",
    responses={
        200: {"model": BatchProcess},
        400: {"description": "Invalid json"},
    },
)
async def create_batch(
    graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
    batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
) -> BatchProcess:
    """Creates and starts a new batch process.

    Args:
        graph: Optional graph the batch configurations are applied to.
        batches: Batch configurations to apply to ``graph``.

    Returns:
        The ``BatchProcess`` returned by the batch manager's ``run_batch_process``.
    """
    # Note the argument order expected by the batch manager: batches first, then graph.
    batch_process = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph)
    return batch_process
2023-07-31 20:05:27 +00:00
@session_router.delete(
    # The leading "/" is required: the router prefix ("/v1/sessions") is joined to
    # this path verbatim, so without it the route would be registered as
    # "/v1/sessions{batch_process_id}/batch" instead of
    # "/v1/sessions/{batch_process_id}/batch".
    "/{batch_process_id}/batch",
    operation_id="cancel_batch",
    responses={202: {"description": "The batch is canceled"}},
)
async def cancel_batch(
    batch_process_id: str = Path(description="The id of the batch process to cancel"),
) -> Response:
    """Cancels a batch process.

    Delegates to the batch manager's ``cancel_batch_process`` and responds with
    HTTP 202 (Accepted).

    Args:
        batch_process_id: The id of the batch process to cancel.
    """
    ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
    return Response(status_code=202)
2023-03-03 06:02:00 +00:00
@session_router.get(
    "/",
    operation_id="list_sessions",
    responses={200: {"model": PaginatedResults[GraphExecutionState]}},
)
async def list_sessions(
    page: int = Query(default=0, description="The page of results to get"),
    per_page: int = Query(default=10, description="The number of results per page"),
    query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]:
    """List sessions one page at a time, optionally filtered by a search query.

    An empty ``query`` pages through all sessions via the store's ``list``.
    Any other value is forwarded to the store's ``search``.
    """
    manager = ApiDependencies.invoker.services.graph_execution_manager
    if query != "":
        return manager.search(query, page, per_page)
    return manager.list(page, per_page)
2023-03-03 06:02:00 +00:00
@session_router.get(
    "/{session_id}",
    operation_id="get_session",
    responses={
        200: {"model": GraphExecutionState},
        404: {"description": "Session not found"},
    },
)
async def get_session(
    session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
    """Fetch a single session by id.

    Raises:
        HTTPException: 404 when the session store returns ``None`` for ``session_id``.
    """
    found = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404)
2023-03-03 06:02:00 +00:00
@session_router.post(
    "/{session_id}/nodes",
    operation_id="add_node",
    responses={
        200: {"model": str},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def add_node(
    session_id: str = Path(description="The id of the session"),
    node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(  # type: ignore
        description="The node to add"
    ),
) -> str:
    """Add a node to a session's graph and persist the session.

    Returns:
        The id of the modified session.

    Raises:
        HTTPException: 404 if the session does not exist. 400 if adding the node
            raises ``NodeAlreadyExecutedError`` or ``IndexError``.
    """
    manager = ApiDependencies.invoker.services.graph_execution_manager
    target = manager.get(session_id)
    if target is None:
        raise HTTPException(status_code=404)

    try:
        target.add_node(node)
        # Write the mutated session back to the store.
        manager.set(target)  # TODO: can this be done automatically, or add node through an API?
    except (NodeAlreadyExecutedError, IndexError):
        raise HTTPException(status_code=400)
    else:
        return target.id
2023-03-03 06:02:00 +00:00
2023-03-03 06:02:00 +00:00
@session_router.put(
    "/{session_id}/nodes/{node_path}",
    operation_id="update_node",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def update_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node in the graph"),
    node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(  # type: ignore
        description="The new node"
    ),
) -> GraphExecutionState:
    """Updates the node at ``node_path`` in a session's graph and persists the session.

    Returns:
        The modified session.

    Raises:
        HTTPException: 404 if the session does not exist. 400 if the update raises
            ``NodeAlreadyExecutedError`` or ``IndexError``.
    """
    # NOTE(review): an earlier summary claimed linked edges are removed on update;
    # that depends on GraphExecutionState.update_node and should be confirmed there.
    manager = ApiDependencies.invoker.services.graph_execution_manager
    target = manager.get(session_id)
    if target is None:
        raise HTTPException(status_code=404)

    try:
        target.update_node(node_path, node)
        # Write the mutated session back to the store.
        manager.set(target)  # TODO: can this be done automatically, or add node through an API?
    except (NodeAlreadyExecutedError, IndexError):
        raise HTTPException(status_code=400)
    else:
        return target
2023-03-03 06:02:00 +00:00
@session_router.delete(
    "/{session_id}/nodes/{node_path}",
    operation_id="delete_node",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def delete_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node to delete"),
) -> GraphExecutionState:
    """Deletes a node in the graph and removes all linked edges, then persists the session.

    Returns:
        The modified session.

    Raises:
        HTTPException: 404 if the session does not exist. 400 if the deletion raises
            ``NodeAlreadyExecutedError`` or ``IndexError``.
    """
    manager = ApiDependencies.invoker.services.graph_execution_manager
    target = manager.get(session_id)
    if target is None:
        raise HTTPException(status_code=404)

    try:
        target.delete_node(node_path)
        # Write the mutated session back to the store.
        manager.set(target)  # TODO: can this be done automatically, or add node through an API?
    except (NodeAlreadyExecutedError, IndexError):
        raise HTTPException(status_code=400)
    else:
        return target
2023-03-03 06:02:00 +00:00
@session_router.post(
    "/{session_id}/edges",
    operation_id="add_edge",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def add_edge(
    session_id: str = Path(description="The id of the session"),
    edge: Edge = Body(description="The edge to add"),
) -> GraphExecutionState:
    """Add an edge to a session's graph and persist the session.

    Returns:
        The modified session.

    Raises:
        HTTPException: 404 if the session does not exist. 400 if adding the edge
            raises ``NodeAlreadyExecutedError`` or ``IndexError``.
    """
    manager = ApiDependencies.invoker.services.graph_execution_manager
    target = manager.get(session_id)
    if target is None:
        raise HTTPException(status_code=404)

    try:
        target.add_edge(edge)
        # Write the mutated session back to the store.
        manager.set(target)  # TODO: can this be done automatically, or add node through an API?
    except (NodeAlreadyExecutedError, IndexError):
        raise HTTPException(status_code=400)
    else:
        return target
# TODO: the edge being in the path here is really ugly, find a better solution
@session_router.delete(
    "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
    operation_id="delete_edge",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def delete_edge(
    session_id: str = Path(description="The id of the session"),
    from_node_id: str = Path(description="The id of the node the edge is coming from"),
    from_field: str = Path(description="The field of the node the edge is coming from"),
    to_node_id: str = Path(description="The id of the node the edge is going to"),
    to_field: str = Path(description="The field of the node the edge is going to"),
) -> GraphExecutionState:
    """Remove an edge, identified by its two endpoints, from a session's graph.

    The edge is rebuilt from the path parameters (source node/field and
    destination node/field) before being passed to the session.

    Returns:
        The modified session.

    Raises:
        HTTPException: 404 if the session does not exist. 400 if the deletion raises
            ``NodeAlreadyExecutedError`` or ``IndexError``.
    """
    manager = ApiDependencies.invoker.services.graph_execution_manager
    target = manager.get(session_id)
    if target is None:
        raise HTTPException(status_code=404)

    try:
        source_conn = EdgeConnection(node_id=from_node_id, field=from_field)
        dest_conn = EdgeConnection(node_id=to_node_id, field=to_field)
        target.delete_edge(Edge(source=source_conn, destination=dest_conn))
        # Write the mutated session back to the store.
        manager.set(target)  # TODO: can this be done automatically, or add node through an API?
    except (NodeAlreadyExecutedError, IndexError):
        raise HTTPException(status_code=400)
    else:
        return target
2023-03-03 06:02:00 +00:00
@session_router.put(
    "/{session_id}/invoke",
    operation_id="invoke_session",
    responses={
        200: {"model": None},
        202: {"description": "The invocation is queued"},
        400: {"description": "The session has no invocations ready to invoke"},
        404: {"description": "Session not found"},
    },
)
async def invoke_session(
    session_id: str = Path(description="The id of the session to invoke"),
    # Named `all` because it is the public query-parameter name; it shadows the
    # builtin `all` inside this function.
    all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
) -> Response:
    """Hand a session to the invoker and respond with HTTP 202 (Accepted).

    Args:
        session_id: The id of the session to invoke.
        all: Forwarded to the invoker as ``invoke_all``.

    Raises:
        HTTPException: 404 if the session does not exist. 400 if
            ``session.is_complete()`` is true.
    """
    target = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if target is None:
        raise HTTPException(status_code=404)
    elif target.is_complete():
        raise HTTPException(status_code=400)

    ApiDependencies.invoker.invoke(target, invoke_all=all)
    return Response(status_code=202)
2023-03-17 03:05:36 +00:00
@session_router.delete(
    "/{session_id}/invoke",
    operation_id="cancel_session_invoke",
    responses={202: {"description": "The invocation is canceled"}},
)
async def cancel_session_invoke(
    session_id: str = Path(description="The id of the session to cancel"),
) -> Response:
    """Cancels the invocation of a session.

    Delegates to the invoker's ``cancel`` and responds with HTTP 202 (Accepted).
    No existence check is made on ``session_id`` here.

    Args:
        session_id: The id of the session whose invocation should be canceled.
    """
    ApiDependencies.invoker.cancel(session_id)
    return Response(status_code=202)