mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
all files migrated; tweaks needed
This commit is contained in:
80
invokeai/app/api/dependencies.py
Normal file
80
invokeai/app/api/dependencies.py
Normal file
@ -0,0 +1,80 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from argparse import Namespace
|
||||
import os
|
||||
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
|
||||
from ...globals import Globals
|
||||
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.generate_initializer import get_generate
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
# TODO: is there a better way to achieve this?
|
||||
def check_internet()->bool:
|
||||
'''
|
||||
Return true if the internet is reachable.
|
||||
It does this by pinging huggingface.co.
|
||||
'''
|
||||
import urllib.request
|
||||
host = 'http://huggingface.co'
|
||||
try:
|
||||
urllib.request.urlopen(host,timeout=1)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
class ApiDependencies:
|
||||
"""Contains and initializes all dependencies for the API"""
|
||||
invoker: Invoker = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(
|
||||
args,
|
||||
config,
|
||||
event_handler_id: int
|
||||
):
|
||||
Globals.try_patchmatch = args.patchmatch
|
||||
Globals.always_use_cpu = args.always_use_cpu
|
||||
Globals.internet_available = args.internet_available and check_internet()
|
||||
Globals.disable_xformers = not args.xformers
|
||||
Globals.ckpt_convert = args.ckpt_convert
|
||||
|
||||
# TODO: Use a logger
|
||||
print(f'>> Internet connectivity is {Globals.internet_available}')
|
||||
|
||||
generate = get_generate(args, config)
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../outputs'))
|
||||
|
||||
images = DiskImageStorage(output_folder)
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, 'invokeai.db')
|
||||
|
||||
services = InvocationServices(
|
||||
generate = generate,
|
||||
events = events,
|
||||
images = images,
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor()
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@staticmethod
|
||||
def shutdown():
|
||||
if ApiDependencies.invoker:
|
||||
ApiDependencies.invoker.stop()
|
54
invokeai/app/api/events.py
Normal file
54
invokeai/app/api/events.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
from ..services.events import EventServiceBase
|
||||
import threading
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
event_handler_id: int
|
||||
__queue: Queue
|
||||
__stop_event: threading.Event
|
||||
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self.__queue = Queue()
|
||||
self.__stop_event = threading.Event()
|
||||
asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self.__stop_event.set()
|
||||
self.__queue.put(None)
|
||||
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.__queue.put(dict(
|
||||
event_name = event_name,
|
||||
payload = payload
|
||||
))
|
||||
|
||||
|
||||
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self.__queue.get(block = False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
|
||||
dispatch(
|
||||
event.get('event_name'),
|
||||
payload = event.get('payload'),
|
||||
middleware_id = self.event_handler_id)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.001)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
57
invokeai/app/api/routers/images.py
Normal file
57
invokeai/app/api/routers/images.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import Path, UploadFile, Request
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from PIL import Image
|
||||
from ...services.image_storage import ImageType
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(
|
||||
prefix = '/v1/images',
|
||||
tags = ['images']
|
||||
)
|
||||
|
||||
|
||||
@images_router.get('/{image_type}/{image_name}',
|
||||
operation_id = 'get_image'
|
||||
)
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description = "The type of image to get"),
|
||||
image_name: str = Path(description = "The name of the image to get")
|
||||
):
|
||||
"""Gets a result"""
|
||||
# TODO: This is not really secure at all. At least make sure only output results are served
|
||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
||||
return FileResponse(filename)
|
||||
|
||||
@images_router.post('/uploads/',
|
||||
operation_id = 'upload_image',
|
||||
responses = {
|
||||
201: {'description': 'The image was uploaded successfully'},
|
||||
404: {'description': 'Session not found'}
|
||||
})
|
||||
async def upload_image(
|
||||
file: UploadFile,
|
||||
request: Request
|
||||
):
|
||||
if not file.content_type.startswith('image'):
|
||||
return Response(status_code = 415)
|
||||
|
||||
contents = await file.read()
|
||||
try:
|
||||
im = Image.open(contents)
|
||||
except:
|
||||
# Error opening the image
|
||||
return Response(status_code = 415)
|
||||
|
||||
filename = f'{str(int(datetime.now(timezone.utc).timestamp()))}.png'
|
||||
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
|
||||
|
||||
return Response(
|
||||
status_code=201,
|
||||
headers = {
|
||||
'Location': request.url_for('get_image', image_type=ImageType.UPLOAD, image_name=filename)
|
||||
}
|
||||
)
|
232
invokeai/app/api/routers/sessions.py
Normal file
232
invokeai/app/api/routers/sessions.py
Normal file
@ -0,0 +1,232 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import List, Optional, Union, Annotated
|
||||
from fastapi import Query, Path, Body
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import Response
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ..dependencies import ApiDependencies
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||
from ...invocations import *
|
||||
|
||||
session_router = APIRouter(
|
||||
prefix = '/v1/sessions',
|
||||
tags = ['sessions']
|
||||
)
|
||||
|
||||
|
||||
@session_router.post('/',
|
||||
operation_id = 'create_session',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {'description': 'Invalid json'}
|
||||
})
|
||||
async def create_session(
|
||||
graph: Optional[Graph] = Body(default = None, description = "The graph to initialize the session with")
|
||||
) -> GraphExecutionState:
|
||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
||||
return session
|
||||
|
||||
|
||||
@session_router.get('/',
|
||||
operation_id = 'list_sessions',
|
||||
responses = {
|
||||
200: {"model": PaginatedResults[GraphExecutionState]}
|
||||
})
|
||||
async def list_sessions(
|
||||
page: int = Query(default = 0, description = "The page of results to get"),
|
||||
per_page: int = Query(default = 10, description = "The number of results per page"),
|
||||
query: str = Query(default = '', description = "The query string to search for")
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if filter == '':
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
else:
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
return result
|
||||
|
||||
|
||||
@session_router.get('/{session_id}',
|
||||
operation_id = 'get_session',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
404: {'description': 'Session not found'}
|
||||
})
|
||||
async def get_session(
|
||||
session_id: str = Path(description = "The id of the session to get")
|
||||
) -> GraphExecutionState:
|
||||
"""Gets a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
else:
|
||||
return session
|
||||
|
||||
|
||||
@session_router.post('/{session_id}/nodes',
|
||||
operation_id = 'add_node',
|
||||
responses = {
|
||||
200: {"model": str},
|
||||
400: {'description': 'Invalid node or link'},
|
||||
404: {'description': 'Session not found'}
|
||||
}
|
||||
)
|
||||
async def add_node(
|
||||
session_id: str = Path(description = "The id of the session"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add")
|
||||
) -> str:
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.add_node(node)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session.id
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
except IndexError:
|
||||
return Response(status_code = 400)
|
||||
|
||||
|
||||
@session_router.put('/{session_id}/nodes/{node_path}',
|
||||
operation_id = 'update_node',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {'description': 'Invalid node or link'},
|
||||
404: {'description': 'Session not found'}
|
||||
}
|
||||
)
|
||||
async def update_node(
|
||||
session_id: str = Path(description = "The id of the session"),
|
||||
node_path: str = Path(description = "The path to the node in the graph"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node")
|
||||
) -> GraphExecutionState:
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.update_node(node_path, node)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
except IndexError:
|
||||
return Response(status_code = 400)
|
||||
|
||||
|
||||
@session_router.delete('/{session_id}/nodes/{node_path}',
|
||||
operation_id = 'delete_node',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {'description': 'Invalid node or link'},
|
||||
404: {'description': 'Session not found'}
|
||||
}
|
||||
)
|
||||
async def delete_node(
|
||||
session_id: str = Path(description = "The id of the session"),
|
||||
node_path: str = Path(description = "The path to the node to delete")
|
||||
) -> GraphExecutionState:
|
||||
"""Deletes a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.delete_node(node_path)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
except IndexError:
|
||||
return Response(status_code = 400)
|
||||
|
||||
|
||||
@session_router.post('/{session_id}/edges',
|
||||
operation_id = 'add_edge',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {'description': 'Invalid node or link'},
|
||||
404: {'description': 'Session not found'}
|
||||
}
|
||||
)
|
||||
async def add_edge(
|
||||
session_id: str = Path(description = "The id of the session"),
|
||||
edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add")
|
||||
) -> GraphExecutionState:
|
||||
"""Adds an edge to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.add_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
except IndexError:
|
||||
return Response(status_code = 400)
|
||||
|
||||
|
||||
# TODO: the edge being in the path here is really ugly, find a better solution
|
||||
@session_router.delete('/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
|
||||
operation_id = 'delete_edge',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {'description': 'Invalid node or link'},
|
||||
404: {'description': 'Session not found'}
|
||||
}
|
||||
)
|
||||
async def delete_edge(
|
||||
session_id: str = Path(description = "The id of the session"),
|
||||
from_node_id: str = Path(description = "The id of the node the edge is coming from"),
|
||||
from_field: str = Path(description = "The field of the node the edge is coming from"),
|
||||
to_node_id: str = Path(description = "The id of the node the edge is going to"),
|
||||
to_field: str = Path(description = "The field of the node the edge is going to")
|
||||
) -> GraphExecutionState:
|
||||
"""Deletes an edge from the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field))
|
||||
session.delete_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
except IndexError:
|
||||
return Response(status_code = 400)
|
||||
|
||||
|
||||
@session_router.put('/{session_id}/invoke',
|
||||
operation_id = 'invoke_session',
|
||||
responses = {
|
||||
200: {"model": None},
|
||||
202: {'description': 'The invocation is queued'},
|
||||
400: {'description': 'The session has no invocations ready to invoke'},
|
||||
404: {'description': 'Session not found'}
|
||||
})
|
||||
async def invoke_session(
|
||||
session_id: str = Path(description = "The id of the session to invoke"),
|
||||
all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations")
|
||||
) -> None:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
if session.is_complete():
|
||||
return Response(status_code = 400)
|
||||
|
||||
ApiDependencies.invoker.invoke(session, invoke_all = all)
|
||||
return Response(status_code=202)
|
36
invokeai/app/api/sockets.py
Normal file
36
invokeai/app/api/sockets.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi_socketio import SocketManager
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from ..services.events import EventServiceBase
|
||||
|
||||
class SocketIO:
|
||||
__sio: SocketManager
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = SocketManager(app = app)
|
||||
self.__sio.on('subscribe', handler=self._handle_sub)
|
||||
self.__sio.on('unsubscribe', handler=self._handle_unsub)
|
||||
|
||||
local_handler.register(
|
||||
event_name = EventServiceBase.session_event,
|
||||
_func=self._handle_session_event
|
||||
)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event = event[1]['event'],
|
||||
data = event[1]['data'],
|
||||
room = event[1]['data']['graph_execution_state_id']
|
||||
)
|
||||
|
||||
async def _handle_sub(self, sid, data, *args, **kwargs):
|
||||
if 'session' in data:
|
||||
self.__sio.enter_room(sid, data['session'])
|
||||
|
||||
# @app.sio.on('unsubscribe')
|
||||
async def _handle_unsub(self, sid, data, *args, **kwargs):
|
||||
if 'session' in data:
|
||||
self.__sio.leave_room(sid, data['session'])
|
Reference in New Issue
Block a user