all files migrated; tweaks needed

This commit is contained in:
Lincoln Stein
2023-03-03 00:02:15 -05:00
parent 3f0b0f3250
commit 6a990565ff
496 changed files with 276 additions and 934 deletions

View File

@ -0,0 +1,80 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from argparse import Namespace
import os
from ..services.processor import DefaultInvocationProcessor
from ..services.graph import GraphExecutionState
from ..services.sqlite import SqliteItemStorage
from ...globals import Globals
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.generate_initializer import get_generate
from .events import FastAPIEventService
# TODO: is there a better way to achieve this?
def check_internet()->bool:
'''
Return true if the internet is reachable.
It does this by pinging huggingface.co.
'''
import urllib.request
host = 'http://huggingface.co'
try:
urllib.request.urlopen(host,timeout=1)
return True
except:
return False
class ApiDependencies:
"""Contains and initializes all dependencies for the API"""
invoker: Invoker = None
@staticmethod
def initialize(
args,
config,
event_handler_id: int
):
Globals.try_patchmatch = args.patchmatch
Globals.always_use_cpu = args.always_use_cpu
Globals.internet_available = args.internet_available and check_internet()
Globals.disable_xformers = not args.xformers
Globals.ckpt_convert = args.ckpt_convert
# TODO: Use a logger
print(f'>> Internet connectivity is {Globals.internet_available}')
generate = get_generate(args, config)
events = FastAPIEventService(event_handler_id)
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../outputs'))
images = DiskImageStorage(output_folder)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db')
services = InvocationServices(
generate = generate,
events = events,
images = images,
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
ApiDependencies.invoker = Invoker(services)
@staticmethod
def shutdown():
if ApiDependencies.invoker:
ApiDependencies.invoker.stop()

View File

@ -0,0 +1,54 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events import EventServiceBase
import threading
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put(dict(
event_name = event_name,
payload = payload
))
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block = False)
if not event: # Probably stopping
continue
dispatch(
event.get('event_name'),
payload = event.get('payload'),
middleware_id = self.event_handler_id)
except Empty:
await asyncio.sleep(0.001)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -0,0 +1,57 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from fastapi import Path, UploadFile, Request
from fastapi.routing import APIRouter
from fastapi.responses import FileResponse, Response
from PIL import Image
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(
prefix = '/v1/images',
tags = ['images']
)
@images_router.get('/{image_type}/{image_name}',
operation_id = 'get_image'
)
async def get_image(
image_type: ImageType = Path(description = "The type of image to get"),
image_name: str = Path(description = "The name of the image to get")
):
"""Gets a result"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
return FileResponse(filename)
@images_router.post('/uploads/',
operation_id = 'upload_image',
responses = {
201: {'description': 'The image was uploaded successfully'},
404: {'description': 'Session not found'}
})
async def upload_image(
file: UploadFile,
request: Request
):
if not file.content_type.startswith('image'):
return Response(status_code = 415)
contents = await file.read()
try:
im = Image.open(contents)
except:
# Error opening the image
return Response(status_code = 415)
filename = f'{str(int(datetime.now(timezone.utc).timestamp()))}.png'
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
return Response(
status_code=201,
headers = {
'Location': request.url_for('get_image', image_type=ImageType.UPLOAD, image_name=filename)
}
)

View File

@ -0,0 +1,232 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import List, Optional, Union, Annotated
from fastapi import Query, Path, Body
from fastapi.routing import APIRouter
from fastapi.responses import Response
from pydantic.fields import Field
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
from ...invocations import *
session_router = APIRouter(
prefix = '/v1/sessions',
tags = ['sessions']
)
@session_router.post('/',
operation_id = 'create_session',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid json'}
})
async def create_session(
graph: Optional[Graph] = Body(default = None, description = "The graph to initialize the session with")
) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph"""
session = ApiDependencies.invoker.create_execution_state(graph)
return session
@session_router.get('/',
operation_id = 'list_sessions',
responses = {
200: {"model": PaginatedResults[GraphExecutionState]}
})
async def list_sessions(
page: int = Query(default = 0, description = "The page of results to get"),
per_page: int = Query(default = 10, description = "The number of results per page"),
query: str = Query(default = '', description = "The query string to search for")
) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching"""
if filter == '':
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
else:
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
return result
@session_router.get('/{session_id}',
operation_id = 'get_session',
responses = {
200: {"model": GraphExecutionState},
404: {'description': 'Session not found'}
})
async def get_session(
session_id: str = Path(description = "The id of the session to get")
) -> GraphExecutionState:
"""Gets a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
else:
return session
@session_router.post('/{session_id}/nodes',
operation_id = 'add_node',
responses = {
200: {"model": str},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def add_node(
session_id: str = Path(description = "The id of the session"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add")
) -> str:
"""Adds a node to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.add_node(node)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session.id
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.put('/{session_id}/nodes/{node_path}',
operation_id = 'update_node',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def update_node(
session_id: str = Path(description = "The id of the session"),
node_path: str = Path(description = "The path to the node in the graph"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node")
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.update_node(node_path, node)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.delete('/{session_id}/nodes/{node_path}',
operation_id = 'delete_node',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def delete_node(
session_id: str = Path(description = "The id of the session"),
node_path: str = Path(description = "The path to the node to delete")
) -> GraphExecutionState:
"""Deletes a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.delete_node(node_path)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.post('/{session_id}/edges',
operation_id = 'add_edge',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def add_edge(
session_id: str = Path(description = "The id of the session"),
edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add")
) -> GraphExecutionState:
"""Adds an edge to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.add_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
# TODO: the edge being in the path here is really ugly, find a better solution
@session_router.delete('/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
operation_id = 'delete_edge',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def delete_edge(
session_id: str = Path(description = "The id of the session"),
from_node_id: str = Path(description = "The id of the node the edge is coming from"),
from_field: str = Path(description = "The field of the node the edge is coming from"),
to_node_id: str = Path(description = "The id of the node the edge is going to"),
to_field: str = Path(description = "The field of the node the edge is going to")
) -> GraphExecutionState:
"""Deletes an edge from the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field))
session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.put('/{session_id}/invoke',
operation_id = 'invoke_session',
responses = {
200: {"model": None},
202: {'description': 'The invocation is queued'},
400: {'description': 'The session has no invocations ready to invoke'},
404: {'description': 'Session not found'}
})
async def invoke_session(
session_id: str = Path(description = "The id of the session to invoke"),
all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations")
) -> None:
"""Invokes a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
if session.is_complete():
return Response(status_code = 400)
ApiDependencies.invoker.invoke(session, invoke_all = all)
return Response(status_code=202)

View File

@ -0,0 +1,36 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from fastapi import FastAPI
from fastapi_socketio import SocketManager
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from ..services.events import EventServiceBase
class SocketIO:
__sio: SocketManager
def __init__(self, app: FastAPI):
self.__sio = SocketManager(app = app)
self.__sio.on('subscribe', handler=self._handle_sub)
self.__sio.on('unsubscribe', handler=self._handle_unsub)
local_handler.register(
event_name = EventServiceBase.session_event,
_func=self._handle_session_event
)
async def _handle_session_event(self, event: Event):
await self.__sio.emit(
event = event[1]['event'],
data = event[1]['data'],
room = event[1]['data']['graph_execution_state_id']
)
async def _handle_sub(self, sid, data, *args, **kwargs):
if 'session' in data:
self.__sio.enter_room(sid, data['session'])
# @app.sio.on('unsubscribe')
async def _handle_unsub(self, sid, data, *args, **kwargs):
if 'session' in data:
self.__sio.leave_room(sid, data['session'])