InvokeAI/invokeai/app/api/routers/sessions.py

233 lines
9.0 KiB
Python
Raw Normal View History

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import List, Optional, Union, Annotated
from fastapi import Query, Path, Body
from fastapi.routing import APIRouter
from fastapi.responses import Response
from pydantic.fields import Field
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
from ...invocations import *
session_router = APIRouter(
prefix = '/v1/sessions',
tags = ['sessions']
)
@session_router.post('/',
operation_id = 'create_session',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid json'}
})
async def create_session(
graph: Optional[Graph] = Body(default = None, description = "The graph to initialize the session with")
) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph"""
session = ApiDependencies.invoker.create_execution_state(graph)
return session
@session_router.get('/',
operation_id = 'list_sessions',
responses = {
200: {"model": PaginatedResults[GraphExecutionState]}
})
async def list_sessions(
page: int = Query(default = 0, description = "The page of results to get"),
per_page: int = Query(default = 10, description = "The number of results per page"),
query: str = Query(default = '', description = "The query string to search for")
) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching"""
if filter == '':
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
else:
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
return result
@session_router.get('/{session_id}',
operation_id = 'get_session',
responses = {
200: {"model": GraphExecutionState},
404: {'description': 'Session not found'}
})
async def get_session(
session_id: str = Path(description = "The id of the session to get")
) -> GraphExecutionState:
"""Gets a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
else:
return session
@session_router.post('/{session_id}/nodes',
operation_id = 'add_node',
responses = {
200: {"model": str},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def add_node(
session_id: str = Path(description = "The id of the session"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add")
) -> str:
"""Adds a node to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.add_node(node)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session.id
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.put('/{session_id}/nodes/{node_path}',
operation_id = 'update_node',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def update_node(
session_id: str = Path(description = "The id of the session"),
node_path: str = Path(description = "The path to the node in the graph"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node")
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.update_node(node_path, node)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.delete('/{session_id}/nodes/{node_path}',
operation_id = 'delete_node',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def delete_node(
session_id: str = Path(description = "The id of the session"),
node_path: str = Path(description = "The path to the node to delete")
) -> GraphExecutionState:
"""Deletes a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.delete_node(node_path)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.post('/{session_id}/edges',
operation_id = 'add_edge',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def add_edge(
session_id: str = Path(description = "The id of the session"),
edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add")
) -> GraphExecutionState:
"""Adds an edge to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.add_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
# TODO: the edge being in the path here is really ugly, find a better solution
@session_router.delete('/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
operation_id = 'delete_edge',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def delete_edge(
session_id: str = Path(description = "The id of the session"),
from_node_id: str = Path(description = "The id of the node the edge is coming from"),
from_field: str = Path(description = "The field of the node the edge is coming from"),
to_node_id: str = Path(description = "The id of the node the edge is going to"),
to_field: str = Path(description = "The field of the node the edge is going to")
) -> GraphExecutionState:
"""Deletes an edge from the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field))
session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.put('/{session_id}/invoke',
operation_id = 'invoke_session',
responses = {
200: {"model": None},
202: {'description': 'The invocation is queued'},
400: {'description': 'The session has no invocations ready to invoke'},
404: {'description': 'Session not found'}
})
async def invoke_session(
session_id: str = Path(description = "The id of the session to invoke"),
all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations")
) -> None:
"""Invokes a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
if session.is_complete():
return Response(status_code = 400)
ApiDependencies.invoker.invoke(session, invoke_all = all)
return Response(status_code=202)