mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(api): restore get_session route
This commit is contained in:
parent
2c39557dc9
commit
685cda89ff
@ -1,57 +1,50 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Annotated, Optional, Union
|
|
||||||
|
|
||||||
from fastapi import Body, HTTPException, Path, Query, Response
|
from fastapi import HTTPException, Path
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic.fields import Field
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
from ...services.shared.graph import GraphExecutionState
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
|
||||||
from ...invocations import * # noqa: F401 F403
|
|
||||||
from ...invocations.baseinvocation import BaseInvocation
|
|
||||||
from ...services.shared.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||||
|
|
||||||
|
|
||||||
@session_router.post(
|
# @session_router.post(
|
||||||
"/",
|
# "/",
|
||||||
operation_id="create_session",
|
# operation_id="create_session",
|
||||||
responses={
|
# responses={
|
||||||
200: {"model": GraphExecutionState},
|
# 200: {"model": GraphExecutionState},
|
||||||
400: {"description": "Invalid json"},
|
# 400: {"description": "Invalid json"},
|
||||||
},
|
# },
|
||||||
deprecated=True,
|
# deprecated=True,
|
||||||
)
|
# )
|
||||||
async def create_session(
|
# async def create_session(
|
||||||
queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
|
# queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
|
||||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
# graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||||
) -> GraphExecutionState:
|
# ) -> GraphExecutionState:
|
||||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
# """Creates a new session, optionally initializing it with an invocation graph"""
|
||||||
session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
|
# session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
|
||||||
return session
|
# return session
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
# @session_router.get(
|
||||||
"/",
|
# "/",
|
||||||
operation_id="list_sessions",
|
# operation_id="list_sessions",
|
||||||
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
# responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||||
deprecated=True,
|
# deprecated=True,
|
||||||
)
|
# )
|
||||||
async def list_sessions(
|
# async def list_sessions(
|
||||||
page: int = Query(default=0, description="The page of results to get"),
|
# page: int = Query(default=0, description="The page of results to get"),
|
||||||
per_page: int = Query(default=10, description="The number of results per page"),
|
# per_page: int = Query(default=10, description="The number of results per page"),
|
||||||
query: str = Query(default="", description="The query string to search for"),
|
# query: str = Query(default="", description="The query string to search for"),
|
||||||
) -> PaginatedResults[GraphExecutionState]:
|
# ) -> PaginatedResults[GraphExecutionState]:
|
||||||
"""Gets a list of sessions, optionally searching"""
|
# """Gets a list of sessions, optionally searching"""
|
||||||
if query == "":
|
# if query == "":
|
||||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
# result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||||
else:
|
# else:
|
||||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
# result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||||
return result
|
# return result
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
@session_router.get(
|
||||||
@ -61,7 +54,6 @@ async def list_sessions(
|
|||||||
200: {"model": GraphExecutionState},
|
200: {"model": GraphExecutionState},
|
||||||
404: {"description": "Session not found"},
|
404: {"description": "Session not found"},
|
||||||
},
|
},
|
||||||
deprecated=True,
|
|
||||||
)
|
)
|
||||||
async def get_session(
|
async def get_session(
|
||||||
session_id: str = Path(description="The id of the session to get"),
|
session_id: str = Path(description="The id of the session to get"),
|
||||||
@ -74,211 +66,211 @@ async def get_session(
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
@session_router.post(
|
# @session_router.post(
|
||||||
"/{session_id}/nodes",
|
# "/{session_id}/nodes",
|
||||||
operation_id="add_node",
|
# operation_id="add_node",
|
||||||
responses={
|
# responses={
|
||||||
200: {"model": str},
|
# 200: {"model": str},
|
||||||
400: {"description": "Invalid node or link"},
|
# 400: {"description": "Invalid node or link"},
|
||||||
404: {"description": "Session not found"},
|
# 404: {"description": "Session not found"},
|
||||||
},
|
# },
|
||||||
deprecated=True,
|
# deprecated=True,
|
||||||
)
|
# )
|
||||||
async def add_node(
|
# async def add_node(
|
||||||
session_id: str = Path(description="The id of the session"),
|
# session_id: str = Path(description="The id of the session"),
|
||||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||||
description="The node to add"
|
# description="The node to add"
|
||||||
),
|
# ),
|
||||||
) -> str:
|
# ) -> str:
|
||||||
"""Adds a node to the graph"""
|
# """Adds a node to the graph"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||||
if session is None:
|
# if session is None:
|
||||||
raise HTTPException(status_code=404)
|
# raise HTTPException(status_code=404)
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
session.add_node(node)
|
# session.add_node(node)
|
||||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||||
session
|
# session
|
||||||
) # TODO: can this be done automatically, or add node through an API?
|
# ) # TODO: can this be done automatically, or add node through an API?
|
||||||
return session.id
|
# return session.id
|
||||||
except NodeAlreadyExecutedError:
|
# except NodeAlreadyExecutedError:
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
except IndexError:
|
# except IndexError:
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
|
|
||||||
|
|
||||||
@session_router.put(
|
# @session_router.put(
|
||||||
"/{session_id}/nodes/{node_path}",
|
# "/{session_id}/nodes/{node_path}",
|
||||||
operation_id="update_node",
|
# operation_id="update_node",
|
||||||
responses={
|
# responses={
|
||||||
200: {"model": GraphExecutionState},
|
# 200: {"model": GraphExecutionState},
|
||||||
400: {"description": "Invalid node or link"},
|
# 400: {"description": "Invalid node or link"},
|
||||||
404: {"description": "Session not found"},
|
# 404: {"description": "Session not found"},
|
||||||
},
|
# },
|
||||||
deprecated=True,
|
# deprecated=True,
|
||||||
)
|
# )
|
||||||
async def update_node(
|
# async def update_node(
|
||||||
session_id: str = Path(description="The id of the session"),
|
# session_id: str = Path(description="The id of the session"),
|
||||||
node_path: str = Path(description="The path to the node in the graph"),
|
# node_path: str = Path(description="The path to the node in the graph"),
|
||||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||||
description="The new node"
|
# description="The new node"
|
||||||
),
|
# ),
|
||||||
) -> GraphExecutionState:
|
# ) -> GraphExecutionState:
|
||||||
"""Updates a node in the graph and removes all linked edges"""
|
# """Updates a node in the graph and removes all linked edges"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||||
if session is None:
|
# if session is None:
|
||||||
raise HTTPException(status_code=404)
|
# raise HTTPException(status_code=404)
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
session.update_node(node_path, node)
|
# session.update_node(node_path, node)
|
||||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||||
session
|
# session
|
||||||
) # TODO: can this be done automatically, or add node through an API?
|
# ) # TODO: can this be done automatically, or add node through an API?
|
||||||
return session
|
# return session
|
||||||
except NodeAlreadyExecutedError:
|
# except NodeAlreadyExecutedError:
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
except IndexError:
|
# except IndexError:
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
|
|
||||||
|
|
||||||
@session_router.delete(
|
# @session_router.delete(
|
||||||
"/{session_id}/nodes/{node_path}",
|
# "/{session_id}/nodes/{node_path}",
|
||||||
operation_id="delete_node",
|
# operation_id="delete_node",
|
||||||
responses={
|
# responses={
|
||||||
200: {"model": GraphExecutionState},
|
# 200: {"model": GraphExecutionState},
|
||||||
400: {"description": "Invalid node or link"},
|
# 400: {"description": "Invalid node or link"},
|
||||||
404: {"description": "Session not found"},
|
# 404: {"description": "Session not found"},
|
||||||
},
|
# },
|
||||||
deprecated=True,
|
# deprecated=True,
|
||||||
)
|
# )
|
||||||
async def delete_node(
|
# async def delete_node(
|
||||||
session_id: str = Path(description="The id of the session"),
|
# session_id: str = Path(description="The id of the session"),
|
||||||
node_path: str = Path(description="The path to the node to delete"),
|
# node_path: str = Path(description="The path to the node to delete"),
|
||||||
) -> GraphExecutionState:
|
# ) -> GraphExecutionState:
|
||||||
"""Deletes a node in the graph and removes all linked edges"""
|
# """Deletes a node in the graph and removes all linked edges"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||||
if session is None:
|
# if session is None:
|
||||||
raise HTTPException(status_code=404)
|
# raise HTTPException(status_code=404)
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
session.delete_node(node_path)
|
# session.delete_node(node_path)
|
||||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||||
session
|
# session
|
||||||
) # TODO: can this be done automatically, or add node through an API?
|
# ) # TODO: can this be done automatically, or add node through an API?
|
||||||
return session
|
# return session
|
||||||
except NodeAlreadyExecutedError:
|
# except NodeAlreadyExecutedError:
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
except IndexError:
|
# except IndexError:
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
|
|
||||||
|
|
||||||
@session_router.post(
|
# @session_router.post(
|
||||||
"/{session_id}/edges",
|
# "/{session_id}/edges",
|
||||||
operation_id="add_edge",
|
# operation_id="add_edge",
|
||||||
responses={
|
# responses={
|
||||||
200: {"model": GraphExecutionState},
|
# 200: {"model": GraphExecutionState},
|
||||||
400: {"description": "Invalid node or link"},
|
# 400: {"description": "Invalid node or link"},
|
||||||
404: {"description": "Session not found"},
|
# 404: {"description": "Session not found"},
|
||||||
},
|
# },
|
||||||
deprecated=True,
|
# deprecated=True,
|
||||||
)
|
# )
|
||||||
async def add_edge(
|
# async def add_edge(
|
||||||
session_id: str = Path(description="The id of the session"),
|
# session_id: str = Path(description="The id of the session"),
|
||||||
edge: Edge = Body(description="The edge to add"),
|
# edge: Edge = Body(description="The edge to add"),
|
||||||
) -> GraphExecutionState:
|
# ) -> GraphExecutionState:
|
||||||
"""Adds an edge to the graph"""
|
# """Adds an edge to the graph"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||||
if session is None:
|
# if session is None:
|
||||||
raise HTTPException(status_code=404)
|
# raise HTTPException(status_code=404)
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
session.add_edge(edge)
|
# session.add_edge(edge)
|
||||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||||
session
|
# session
|
||||||
) # TODO: can this be done automatically, or add node through an API?
|
# ) # TODO: can this be done automatically, or add node through an API?
|
||||||
return session
|
# return session
|
||||||
except NodeAlreadyExecutedError:
|
# except NodeAlreadyExecutedError:
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
except IndexError:
|
# except IndexError:
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
|
|
||||||
|
|
||||||
# TODO: the edge being in the path here is really ugly, find a better solution
|
# # TODO: the edge being in the path here is really ugly, find a better solution
|
||||||
@session_router.delete(
|
# @session_router.delete(
|
||||||
"/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
|
# "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
|
||||||
operation_id="delete_edge",
|
# operation_id="delete_edge",
|
||||||
responses={
|
# responses={
|
||||||
200: {"model": GraphExecutionState},
|
# 200: {"model": GraphExecutionState},
|
||||||
400: {"description": "Invalid node or link"},
|
# 400: {"description": "Invalid node or link"},
|
||||||
404: {"description": "Session not found"},
|
# 404: {"description": "Session not found"},
|
||||||
},
|
# },
|
||||||
deprecated=True,
|
# deprecated=True,
|
||||||
)
|
# )
|
||||||
async def delete_edge(
|
# async def delete_edge(
|
||||||
session_id: str = Path(description="The id of the session"),
|
# session_id: str = Path(description="The id of the session"),
|
||||||
from_node_id: str = Path(description="The id of the node the edge is coming from"),
|
# from_node_id: str = Path(description="The id of the node the edge is coming from"),
|
||||||
from_field: str = Path(description="The field of the node the edge is coming from"),
|
# from_field: str = Path(description="The field of the node the edge is coming from"),
|
||||||
to_node_id: str = Path(description="The id of the node the edge is going to"),
|
# to_node_id: str = Path(description="The id of the node the edge is going to"),
|
||||||
to_field: str = Path(description="The field of the node the edge is going to"),
|
# to_field: str = Path(description="The field of the node the edge is going to"),
|
||||||
) -> GraphExecutionState:
|
# ) -> GraphExecutionState:
|
||||||
"""Deletes an edge from the graph"""
|
# """Deletes an edge from the graph"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||||
if session is None:
|
# if session is None:
|
||||||
raise HTTPException(status_code=404)
|
# raise HTTPException(status_code=404)
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
edge = Edge(
|
# edge = Edge(
|
||||||
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
# source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||||
destination=EdgeConnection(node_id=to_node_id, field=to_field),
|
# destination=EdgeConnection(node_id=to_node_id, field=to_field),
|
||||||
)
|
# )
|
||||||
session.delete_edge(edge)
|
# session.delete_edge(edge)
|
||||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||||
session
|
# session
|
||||||
) # TODO: can this be done automatically, or add node through an API?
|
# ) # TODO: can this be done automatically, or add node through an API?
|
||||||
return session
|
# return session
|
||||||
except NodeAlreadyExecutedError:
|
# except NodeAlreadyExecutedError:
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
except IndexError:
|
# except IndexError:
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
|
|
||||||
|
|
||||||
@session_router.put(
|
# @session_router.put(
|
||||||
"/{session_id}/invoke",
|
# "/{session_id}/invoke",
|
||||||
operation_id="invoke_session",
|
# operation_id="invoke_session",
|
||||||
responses={
|
# responses={
|
||||||
200: {"model": None},
|
# 200: {"model": None},
|
||||||
202: {"description": "The invocation is queued"},
|
# 202: {"description": "The invocation is queued"},
|
||||||
400: {"description": "The session has no invocations ready to invoke"},
|
# 400: {"description": "The session has no invocations ready to invoke"},
|
||||||
404: {"description": "Session not found"},
|
# 404: {"description": "Session not found"},
|
||||||
},
|
# },
|
||||||
deprecated=True,
|
# deprecated=True,
|
||||||
)
|
# )
|
||||||
async def invoke_session(
|
# async def invoke_session(
|
||||||
queue_id: str = Query(description="The id of the queue to associate the session with"),
|
# queue_id: str = Query(description="The id of the queue to associate the session with"),
|
||||||
session_id: str = Path(description="The id of the session to invoke"),
|
# session_id: str = Path(description="The id of the session to invoke"),
|
||||||
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
# all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||||
) -> Response:
|
# ) -> Response:
|
||||||
"""Invokes a session"""
|
# """Invokes a session"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||||
if session is None:
|
# if session is None:
|
||||||
raise HTTPException(status_code=404)
|
# raise HTTPException(status_code=404)
|
||||||
|
|
||||||
if session.is_complete():
|
# if session.is_complete():
|
||||||
raise HTTPException(status_code=400)
|
# raise HTTPException(status_code=400)
|
||||||
|
|
||||||
ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
|
# ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
|
||||||
return Response(status_code=202)
|
# return Response(status_code=202)
|
||||||
|
|
||||||
|
|
||||||
@session_router.delete(
|
# @session_router.delete(
|
||||||
"/{session_id}/invoke",
|
# "/{session_id}/invoke",
|
||||||
operation_id="cancel_session_invoke",
|
# operation_id="cancel_session_invoke",
|
||||||
responses={202: {"description": "The invocation is canceled"}},
|
# responses={202: {"description": "The invocation is canceled"}},
|
||||||
deprecated=True,
|
# deprecated=True,
|
||||||
)
|
# )
|
||||||
async def cancel_session_invoke(
|
# async def cancel_session_invoke(
|
||||||
session_id: str = Path(description="The id of the session to cancel"),
|
# session_id: str = Path(description="The id of the session to cancel"),
|
||||||
) -> Response:
|
# ) -> Response:
|
||||||
"""Invokes a session"""
|
# """Invokes a session"""
|
||||||
ApiDependencies.invoker.cancel(session_id)
|
# ApiDependencies.invoker.cancel(session_id)
|
||||||
return Response(status_code=202)
|
# return Response(status_code=202)
|
||||||
|
@ -31,7 +31,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import app_info, board_images, boards, images, models, session_queue, utilities
|
from .api.routers import app_info, board_images, boards, images, models, sessions, session_queue, utilities
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||||
|
|
||||||
@ -85,7 +85,7 @@ async def shutdown_event():
|
|||||||
|
|
||||||
|
|
||||||
# Include all routers
|
# Include all routers
|
||||||
# app.include_router(sessions.session_router, prefix="/api")
|
app.include_router(sessions.session_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(utilities.utilities_router, prefix="/api")
|
app.include_router(utilities.utilities_router, prefix="/api")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user