feat(api): restore get_session route

This commit is contained in:
psychedelicious 2023-10-17 14:28:39 +11:00
parent 2c39557dc9
commit 685cda89ff
2 changed files with 226 additions and 234 deletions

View File

@ -1,57 +1,50 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, Optional, Union
from fastapi import Body, HTTPException, Path, Query, Response from fastapi import HTTPException, Path
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic.fields import Field
from invokeai.app.services.shared.pagination import PaginatedResults from ...services.shared.graph import GraphExecutionState
# Importing * is bad karma but needed here for node detection
from ...invocations import * # noqa: F401 F403
from ...invocations.baseinvocation import BaseInvocation
from ...services.shared.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"]) session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
@session_router.post( # @session_router.post(
"/", # "/",
operation_id="create_session", # operation_id="create_session",
responses={ # responses={
200: {"model": GraphExecutionState}, # 200: {"model": GraphExecutionState},
400: {"description": "Invalid json"}, # 400: {"description": "Invalid json"},
}, # },
deprecated=True, # deprecated=True,
) # )
async def create_session( # async def create_session(
queue_id: str = Query(default="", description="The id of the queue to associate the session with"), # queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"), # graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
) -> GraphExecutionState: # ) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph""" # """Creates a new session, optionally initializing it with an invocation graph"""
session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph) # session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
return session # return session
@session_router.get( # @session_router.get(
"/", # "/",
operation_id="list_sessions", # operation_id="list_sessions",
responses={200: {"model": PaginatedResults[GraphExecutionState]}}, # responses={200: {"model": PaginatedResults[GraphExecutionState]}},
deprecated=True, # deprecated=True,
) # )
async def list_sessions( # async def list_sessions(
page: int = Query(default=0, description="The page of results to get"), # page: int = Query(default=0, description="The page of results to get"),
per_page: int = Query(default=10, description="The number of results per page"), # per_page: int = Query(default=10, description="The number of results per page"),
query: str = Query(default="", description="The query string to search for"), # query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]: # ) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching""" # """Gets a list of sessions, optionally searching"""
if query == "": # if query == "":
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page) # result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
else: # else:
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page) # result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
return result # return result
@session_router.get( @session_router.get(
@ -61,7 +54,6 @@ async def list_sessions(
200: {"model": GraphExecutionState}, 200: {"model": GraphExecutionState},
404: {"description": "Session not found"}, 404: {"description": "Session not found"},
}, },
deprecated=True,
) )
async def get_session( async def get_session(
session_id: str = Path(description="The id of the session to get"), session_id: str = Path(description="The id of the session to get"),
@ -74,211 +66,211 @@ async def get_session(
return session return session
@session_router.post( # @session_router.post(
"/{session_id}/nodes", # "/{session_id}/nodes",
operation_id="add_node", # operation_id="add_node",
responses={ # responses={
200: {"model": str}, # 200: {"model": str},
400: {"description": "Invalid node or link"}, # 400: {"description": "Invalid node or link"},
404: {"description": "Session not found"}, # 404: {"description": "Session not found"},
}, # },
deprecated=True, # deprecated=True,
) # )
async def add_node( # async def add_node(
session_id: str = Path(description="The id of the session"), # session_id: str = Path(description="The id of the session"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore # node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
description="The node to add" # description="The node to add"
), # ),
) -> str: # ) -> str:
"""Adds a node to the graph""" # """Adds a node to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) # session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: # if session is None:
raise HTTPException(status_code=404) # raise HTTPException(status_code=404)
try: # try:
session.add_node(node) # session.add_node(node)
ApiDependencies.invoker.services.graph_execution_manager.set( # ApiDependencies.invoker.services.graph_execution_manager.set(
session # session
) # TODO: can this be done automatically, or add node through an API? # ) # TODO: can this be done automatically, or add node through an API?
return session.id # return session.id
except NodeAlreadyExecutedError: # except NodeAlreadyExecutedError:
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
except IndexError: # except IndexError:
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
@session_router.put( # @session_router.put(
"/{session_id}/nodes/{node_path}", # "/{session_id}/nodes/{node_path}",
operation_id="update_node", # operation_id="update_node",
responses={ # responses={
200: {"model": GraphExecutionState}, # 200: {"model": GraphExecutionState},
400: {"description": "Invalid node or link"}, # 400: {"description": "Invalid node or link"},
404: {"description": "Session not found"}, # 404: {"description": "Session not found"},
}, # },
deprecated=True, # deprecated=True,
) # )
async def update_node( # async def update_node(
session_id: str = Path(description="The id of the session"), # session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node in the graph"), # node_path: str = Path(description="The path to the node in the graph"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore # node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
description="The new node" # description="The new node"
), # ),
) -> GraphExecutionState: # ) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges""" # """Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) # session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: # if session is None:
raise HTTPException(status_code=404) # raise HTTPException(status_code=404)
try: # try:
session.update_node(node_path, node) # session.update_node(node_path, node)
ApiDependencies.invoker.services.graph_execution_manager.set( # ApiDependencies.invoker.services.graph_execution_manager.set(
session # session
) # TODO: can this be done automatically, or add node through an API? # ) # TODO: can this be done automatically, or add node through an API?
return session # return session
except NodeAlreadyExecutedError: # except NodeAlreadyExecutedError:
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
except IndexError: # except IndexError:
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
@session_router.delete( # @session_router.delete(
"/{session_id}/nodes/{node_path}", # "/{session_id}/nodes/{node_path}",
operation_id="delete_node", # operation_id="delete_node",
responses={ # responses={
200: {"model": GraphExecutionState}, # 200: {"model": GraphExecutionState},
400: {"description": "Invalid node or link"}, # 400: {"description": "Invalid node or link"},
404: {"description": "Session not found"}, # 404: {"description": "Session not found"},
}, # },
deprecated=True, # deprecated=True,
) # )
async def delete_node( # async def delete_node(
session_id: str = Path(description="The id of the session"), # session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node to delete"), # node_path: str = Path(description="The path to the node to delete"),
) -> GraphExecutionState: # ) -> GraphExecutionState:
"""Deletes a node in the graph and removes all linked edges""" # """Deletes a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) # session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: # if session is None:
raise HTTPException(status_code=404) # raise HTTPException(status_code=404)
try: # try:
session.delete_node(node_path) # session.delete_node(node_path)
ApiDependencies.invoker.services.graph_execution_manager.set( # ApiDependencies.invoker.services.graph_execution_manager.set(
session # session
) # TODO: can this be done automatically, or add node through an API? # ) # TODO: can this be done automatically, or add node through an API?
return session # return session
except NodeAlreadyExecutedError: # except NodeAlreadyExecutedError:
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
except IndexError: # except IndexError:
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
@session_router.post( # @session_router.post(
"/{session_id}/edges", # "/{session_id}/edges",
operation_id="add_edge", # operation_id="add_edge",
responses={ # responses={
200: {"model": GraphExecutionState}, # 200: {"model": GraphExecutionState},
400: {"description": "Invalid node or link"}, # 400: {"description": "Invalid node or link"},
404: {"description": "Session not found"}, # 404: {"description": "Session not found"},
}, # },
deprecated=True, # deprecated=True,
) # )
async def add_edge( # async def add_edge(
session_id: str = Path(description="The id of the session"), # session_id: str = Path(description="The id of the session"),
edge: Edge = Body(description="The edge to add"), # edge: Edge = Body(description="The edge to add"),
) -> GraphExecutionState: # ) -> GraphExecutionState:
"""Adds an edge to the graph""" # """Adds an edge to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) # session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: # if session is None:
raise HTTPException(status_code=404) # raise HTTPException(status_code=404)
try: # try:
session.add_edge(edge) # session.add_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set( # ApiDependencies.invoker.services.graph_execution_manager.set(
session # session
) # TODO: can this be done automatically, or add node through an API? # ) # TODO: can this be done automatically, or add node through an API?
return session # return session
except NodeAlreadyExecutedError: # except NodeAlreadyExecutedError:
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
except IndexError: # except IndexError:
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
# TODO: the edge being in the path here is really ugly, find a better solution # # TODO: the edge being in the path here is really ugly, find a better solution
@session_router.delete( # @session_router.delete(
"/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}", # "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
operation_id="delete_edge", # operation_id="delete_edge",
responses={ # responses={
200: {"model": GraphExecutionState}, # 200: {"model": GraphExecutionState},
400: {"description": "Invalid node or link"}, # 400: {"description": "Invalid node or link"},
404: {"description": "Session not found"}, # 404: {"description": "Session not found"},
}, # },
deprecated=True, # deprecated=True,
) # )
async def delete_edge( # async def delete_edge(
session_id: str = Path(description="The id of the session"), # session_id: str = Path(description="The id of the session"),
from_node_id: str = Path(description="The id of the node the edge is coming from"), # from_node_id: str = Path(description="The id of the node the edge is coming from"),
from_field: str = Path(description="The field of the node the edge is coming from"), # from_field: str = Path(description="The field of the node the edge is coming from"),
to_node_id: str = Path(description="The id of the node the edge is going to"), # to_node_id: str = Path(description="The id of the node the edge is going to"),
to_field: str = Path(description="The field of the node the edge is going to"), # to_field: str = Path(description="The field of the node the edge is going to"),
) -> GraphExecutionState: # ) -> GraphExecutionState:
"""Deletes an edge from the graph""" # """Deletes an edge from the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) # session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: # if session is None:
raise HTTPException(status_code=404) # raise HTTPException(status_code=404)
try: # try:
edge = Edge( # edge = Edge(
source=EdgeConnection(node_id=from_node_id, field=from_field), # source=EdgeConnection(node_id=from_node_id, field=from_field),
destination=EdgeConnection(node_id=to_node_id, field=to_field), # destination=EdgeConnection(node_id=to_node_id, field=to_field),
) # )
session.delete_edge(edge) # session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set( # ApiDependencies.invoker.services.graph_execution_manager.set(
session # session
) # TODO: can this be done automatically, or add node through an API? # ) # TODO: can this be done automatically, or add node through an API?
return session # return session
except NodeAlreadyExecutedError: # except NodeAlreadyExecutedError:
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
except IndexError: # except IndexError:
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
@session_router.put( # @session_router.put(
"/{session_id}/invoke", # "/{session_id}/invoke",
operation_id="invoke_session", # operation_id="invoke_session",
responses={ # responses={
200: {"model": None}, # 200: {"model": None},
202: {"description": "The invocation is queued"}, # 202: {"description": "The invocation is queued"},
400: {"description": "The session has no invocations ready to invoke"}, # 400: {"description": "The session has no invocations ready to invoke"},
404: {"description": "Session not found"}, # 404: {"description": "Session not found"},
}, # },
deprecated=True, # deprecated=True,
) # )
async def invoke_session( # async def invoke_session(
queue_id: str = Query(description="The id of the queue to associate the session with"), # queue_id: str = Query(description="The id of the queue to associate the session with"),
session_id: str = Path(description="The id of the session to invoke"), # session_id: str = Path(description="The id of the session to invoke"),
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"), # all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
) -> Response: # ) -> Response:
"""Invokes a session""" # """Invokes a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) # session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None: # if session is None:
raise HTTPException(status_code=404) # raise HTTPException(status_code=404)
if session.is_complete(): # if session.is_complete():
raise HTTPException(status_code=400) # raise HTTPException(status_code=400)
ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all) # ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
return Response(status_code=202) # return Response(status_code=202)
@session_router.delete( # @session_router.delete(
"/{session_id}/invoke", # "/{session_id}/invoke",
operation_id="cancel_session_invoke", # operation_id="cancel_session_invoke",
responses={202: {"description": "The invocation is canceled"}}, # responses={202: {"description": "The invocation is canceled"}},
deprecated=True, # deprecated=True,
) # )
async def cancel_session_invoke( # async def cancel_session_invoke(
session_id: str = Path(description="The id of the session to cancel"), # session_id: str = Path(description="The id of the session to cancel"),
) -> Response: # ) -> Response:
"""Invokes a session""" # """Invokes a session"""
ApiDependencies.invoker.cancel(session_id) # ApiDependencies.invoker.cancel(session_id)
return Response(status_code=202) # return Response(status_code=202)

View File

@ -31,7 +31,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from ..backend.util.logging import InvokeAILogger from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
from .api.routers import app_info, board_images, boards, images, models, session_queue, utilities from .api.routers import app_info, board_images, boards, images, models, sessions, session_queue, utilities
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
@ -85,7 +85,7 @@ async def shutdown_event():
# Include all routers # Include all routers
# app.include_router(sessions.session_router, prefix="/api") app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api")