InvokeAI/ldm/invoke/app/api_app.py
Kyle Schouviller 34e3aa1f88 parent 9eed1919c2
author Kyle Schouviller <kyle0654@hotmail.com> 1669872800 -0800
committer Kyle Schouviller <kyle0654@hotmail.com> 1676240900 -0800

Adding base node architecture

Fix type annotation errors

Runs and generates, but breaks in saving session

Fix default model value setting. Fix deprecation warning.

Fixed node api

Adding markdown docs

Simplifying Generate construction in apps

[nodes] A few minor changes (#2510)

* Pin api-related requirements

* Remove confusing extra CORS origins list

* Adds response models for HTTP 200

[nodes] Adding graph_execution_state to soon replace session. Adding tests with pytest.

Minor typing fixes

[nodes] Fix some small output query hookups

[node] Fixing some additional typing issues

[nodes] Move and expand graph code. Add base item storage and sqlite implementation.

Update startup to match new code

[nodes] Add callbacks to item storage

[nodes] Adding an InvocationContext object to use for invocations to provide easier extensibility

[nodes] New execution model that handles iteration

[nodes] Fixing the CLI

[nodes] Adding a note to the CLI

[nodes] Split processing thread into separate service

[node] Add error message on node processing failure

Removing old files and duplicated packages

Adding python-multipart
2023-02-24 18:57:02 -08:00

165 lines
4.8 KiB
Python

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from inspect import signature
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from fastapi.staticfiles import StaticFiles
from fastapi_events.middleware import EventHandlerASGIMiddleware
from fastapi_events.handlers.local import local_handler
from fastapi.middleware.cors import CORSMiddleware
from pydantic.schema import schema
import uvicorn
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .api.routers import images, sessions
from .api.dependencies import ApiDependencies
from ..args import Args
# Create the app.
# The built-in documentation routes are switched off here (docs_url and
# redoc_url are None) because this module registers its own "/docs" and
# "/redoc" handlers further down.
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
# Add event handler.
# The id() of the app object is used as the middleware id; the same value is
# passed to ApiDependencies.initialize() by the startup handler.
# TODO: consider doing this in services to support different configurations
event_handler_id: int = id(app)
app.add_middleware(
    EventHandlerASGIMiddleware,
    handlers=[local_handler],
    middleware_id=event_handler_id,
)
# Add CORS.
# The allowed-origins list is empty; every method and header is allowed and
# credentials are enabled.
# TODO: use configuration for this
origins = []
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Attach the project's SocketIO helper (imported from .api.sockets) to the app.
# NOTE(review): presumably this wires up websocket event delivery - confirm
# against api/sockets.py.
socket_io = SocketIO(app)
# Module-level configuration placeholder. startup_event() binds its own local
# `config` and does not update this dict.
config = {}
# Add startup event to load dependencies
@app.on_event('startup')
async def startup_event():
    """Initialise the API dependencies when the application starts.

    Builds an ``Args`` object, calls ``parse_args()`` on it, and passes both
    to ``ApiDependencies.initialize`` together with the module-level
    ``event_handler_id``.
    """
    args = Args()
    # Local only: the module-level `config` dict is not modified here.
    parsed_config = args.parse_args()
    ApiDependencies.initialize(
        args=args,
        config=parsed_config,
        event_handler_id=event_handler_id,
    )
# Shut down threads
@app.on_event('shutdown')
async def shutdown_event() -> None:
    """Call ``ApiDependencies.shutdown()`` when the app emits its shutdown event."""
    ApiDependencies.shutdown()
# Include all routers; each one is mounted under the '/api' prefix.
# TODO: REMOVE
# app.include_router(
#     invocation.invocation_router,
#     prefix = '/api')
app.include_router(sessions.session_router, prefix='/api')
app.include_router(images.images_router, prefix='/api')
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi():
    """Return the app's OpenAPI schema, extended with invocation output types.

    The schema is built on the first call and cached on ``app.openapi_schema``;
    later calls return the cached object. On top of what ``get_openapi``
    generates from the routes, this function:

    * adds a component schema for the return annotation of every
      invocation's ``invoke`` method, and
    * stores a ``$ref`` to that output schema under
      ``additionalProperties['outputs']`` of each invocation's own schema.

    Raises:
        KeyError: if an invocation's output type name or the invocation's
            class name has no matching entry in the generated schemas.
    """
    if app.openapi_schema:
        return app.openapi_schema

    spec = get_openapi(
        title=app.title,
        description="An API for invoking AI image operations",
        version="1.0.0",
        routes=app.routes,
    )

    # Add all outputs: collect the distinct return annotations of invoke().
    invocations = BaseInvocation.get_invocations()
    return_types = {
        signature(invocation.invoke).return_annotation
        for invocation in invocations
    }

    definitions = schema(
        return_types, ref_prefix="#/components/schemas/"
    )['definitions']
    titles_by_key = {}
    for key, definition in definitions.items():
        spec["components"]["schemas"][key] = definition
        # TODO: note that we assume the schema_key here is the TYPE.__name__
        # This could break in some cases, figure out a better way to do it
        titles_by_key[key] = definition['title']

    # Add a reference to the output type to additionalProperties of the
    # invoker schema.
    for invocation in invocations:
        return_type = signature(invocation.invoke).return_annotation
        title = titles_by_key[return_type.__name__]
        invocation_schema = spec["components"]["schemas"][invocation.__name__]
        invocation_schema.setdefault('additionalProperties', {})['outputs'] = {
            '$ref': f'#/components/schemas/{title}'
        }

    app.openapi_schema = spec
    return app.openapi_schema


# Make FastAPI use the extended schema builder.
app.openapi = custom_openapi
# Override API doc favicons
# Files under static/dream_web are served at the /static URL prefix; the custom
# docs pages in this module point their favicon at /static/favicon.ico.
# NOTE(review): the directory is a relative path - presumably resolved against
# the process working directory; confirm before launching from elsewhere.
app.mount('/static', StaticFiles(directory='static/dream_web'), name='static')
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
    """Serve the Swagger UI page with the favicon taken from /static.

    The route is excluded from the generated OpenAPI schema.
    """
    page_options = {
        "openapi_url": app.openapi_url,
        "title": app.title,
        "swagger_favicon_url": "/static/favicon.ico",
    }
    return get_swagger_ui_html(**page_options)
@app.get("/redoc", include_in_schema=False)
def overridden_redoc():
    """Serve the ReDoc page with the favicon taken from /static.

    The route is excluded from the generated OpenAPI schema.
    """
    page_options = {
        "openapi_url": app.openapi_url,
        "title": app.title,
        "redoc_favicon_url": "/static/favicon.ico",
    }
    return get_redoc_html(**page_options)
def invoke_api(host: str = "0.0.0.0", port: int = 9090):
    """Run the API under uvicorn on a freshly created asyncio event loop.

    Blocks until ``server.serve()`` completes.

    Args:
        host: Address to bind the server to. Defaults to ``"0.0.0.0"``, the
            value that was previously hard-coded.
        port: TCP port to listen on. Defaults to ``9090``, the value that was
            previously hard-coded.
    """
    # Start our own event loop for eventing usage
    # TODO: determine if there's a better way to do this
    loop = asyncio.new_event_loop()
    # Named server_config so it does not shadow the module-level `config`.
    # NOTE(review): the loop object is passed as uvicorn's `loop` setting,
    # unchanged from the original code - confirm this matches what the
    # installed uvicorn version expects for that option.
    server_config = uvicorn.Config(
        app=app,
        host=host,
        port=port,
        loop=loop,
    )
    # Use access_log to turn off logging
    server = uvicorn.Server(server_config)
    loop.run_until_complete(server.serve())
# Start the API server only when this module is executed as a script.
if __name__ == "__main__":
    invoke_api()