InvokeAI/invokeai/app/api/routers/sessions.py
2023-10-17 14:59:25 +11:00

277 lines
10 KiB
Python

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from fastapi import HTTPException, Path
from fastapi.routing import APIRouter
from ...services.shared.graph import GraphExecutionState
from ..dependencies import ApiDependencies
# Router collecting the session endpoints defined in this module; every route
# registered on it is served under the "/v1/sessions" prefix and tagged "sessions".
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
# @session_router.post(
# "/",
# operation_id="create_session",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid json"},
# },
# deprecated=True,
# )
# async def create_session(
# queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
# graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
# ) -> GraphExecutionState:
# """Creates a new session, optionally initializing it with an invocation graph"""
# session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
# return session
# @session_router.get(
# "/",
# operation_id="list_sessions",
# responses={200: {"model": PaginatedResults[GraphExecutionState]}},
# deprecated=True,
# )
# async def list_sessions(
# page: int = Query(default=0, description="The page of results to get"),
# per_page: int = Query(default=10, description="The number of results per page"),
# query: str = Query(default="", description="The query string to search for"),
# ) -> PaginatedResults[GraphExecutionState]:
# """Gets a list of sessions, optionally searching"""
# if query == "":
# result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
# else:
# result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
# return result
@session_router.get(
    "/{session_id}",
    operation_id="get_session",
    responses={
        200: {"model": GraphExecutionState},
        404: {"description": "Session not found"},
    },
)
async def get_session(
    session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
    """Gets a session"""
    # NOTE: the docstring above is intentionally left verbatim; FastAPI surfaces
    # a handler's docstring as the operation description in the OpenAPI schema.
    #
    # Look the session up by id through the graph execution manager service.
    execution_manager = ApiDependencies.invoker.services.graph_execution_manager
    found = execution_manager.get(session_id)

    # Happy path first: hand back whatever the manager returned for this id.
    if found is not None:
        return found

    # The manager returned None for this id: respond with a bare 404
    # (no custom detail, matching the documented 404 response).
    raise HTTPException(status_code=404)
# @session_router.post(
# "/{session_id}/nodes",
# operation_id="add_node",
# responses={
# 200: {"model": str},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def add_node(
# session_id: str = Path(description="The id of the session"),
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
# description="The node to add"
# ),
# ) -> str:
# """Adds a node to the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.add_node(node)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session.id
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.put(
# "/{session_id}/nodes/{node_path}",
# operation_id="update_node",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def update_node(
# session_id: str = Path(description="The id of the session"),
# node_path: str = Path(description="The path to the node in the graph"),
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
# description="The new node"
# ),
# ) -> GraphExecutionState:
# """Updates a node in the graph and removes all linked edges"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.update_node(node_path, node)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.delete(
# "/{session_id}/nodes/{node_path}",
# operation_id="delete_node",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def delete_node(
# session_id: str = Path(description="The id of the session"),
# node_path: str = Path(description="The path to the node to delete"),
# ) -> GraphExecutionState:
# """Deletes a node in the graph and removes all linked edges"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.delete_node(node_path)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.post(
# "/{session_id}/edges",
# operation_id="add_edge",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def add_edge(
# session_id: str = Path(description="The id of the session"),
# edge: Edge = Body(description="The edge to add"),
# ) -> GraphExecutionState:
# """Adds an edge to the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.add_edge(edge)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# # TODO: the edge being in the path here is really ugly, find a better solution
# @session_router.delete(
# "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
# operation_id="delete_edge",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def delete_edge(
# session_id: str = Path(description="The id of the session"),
# from_node_id: str = Path(description="The id of the node the edge is coming from"),
# from_field: str = Path(description="The field of the node the edge is coming from"),
# to_node_id: str = Path(description="The id of the node the edge is going to"),
# to_field: str = Path(description="The field of the node the edge is going to"),
# ) -> GraphExecutionState:
# """Deletes an edge from the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# edge = Edge(
# source=EdgeConnection(node_id=from_node_id, field=from_field),
# destination=EdgeConnection(node_id=to_node_id, field=to_field),
# )
# session.delete_edge(edge)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.put(
# "/{session_id}/invoke",
# operation_id="invoke_session",
# responses={
# 200: {"model": None},
# 202: {"description": "The invocation is queued"},
# 400: {"description": "The session has no invocations ready to invoke"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def invoke_session(
# queue_id: str = Query(description="The id of the queue to associate the session with"),
# session_id: str = Path(description="The id of the session to invoke"),
# all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
# ) -> Response:
# """Invokes a session"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# if session.is_complete():
# raise HTTPException(status_code=400)
# ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
# return Response(status_code=202)
# @session_router.delete(
# "/{session_id}/invoke",
# operation_id="cancel_session_invoke",
# responses={202: {"description": "The invocation is canceled"}},
# deprecated=True,
# )
# async def cancel_session_invoke(
# session_id: str = Path(description="The id of the session to cancel"),
# ) -> Response:
# """Cancels a session's invocation"""
# ApiDependencies.invoker.cancel(session_id)
# return Response(status_code=202)