mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
34e3aa1f88
author Kyle Schouviller <kyle0654@hotmail.com> 1669872800 -0800 committer Kyle Schouviller <kyle0654@hotmail.com> 1676240900 -0800 Adding base node architecture Fix type annotation errors Runs and generates, but breaks in saving session Fix default model value setting. Fix deprecation warning. Fixed node api Adding markdown docs Simplifying Generate construction in apps [nodes] A few minor changes (#2510) * Pin api-related requirements * Remove confusing extra CORS origins list * Adds response models for HTTP 200 [nodes] Adding graph_execution_state to soon replace session. Adding tests with pytest. Minor typing fixes [nodes] Fix some small output query hookups [node] Fixing some additional typing issues [nodes] Move and expand graph code. Add base item storage and sqlite implementation. Update startup to match new code [nodes] Add callbacks to item storage [nodes] Adding an InvocationContext object to use for invocations to provide easier extensibility [nodes] New execution model that handles iteration [nodes] Fixing the CLI [nodes] Adding a note to the CLI [nodes] Split processing thread into separate service [node] Add error message on node processing failure Removing old files and duplicated packages Adding python-multipart
165 lines
4.8 KiB
Python
165 lines
4.8 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
import asyncio
|
|
from inspect import signature
|
|
from fastapi import FastAPI
|
|
from fastapi.openapi.utils import get_openapi
|
|
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
|
from fastapi_events.handlers.local import local_handler
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic.schema import schema
|
|
import uvicorn
|
|
from .api.sockets import SocketIO
|
|
from .invocations import *
|
|
from .invocations.baseinvocation import BaseInvocation
|
|
from .api.routers import images, sessions
|
|
from .api.dependencies import ApiDependencies
|
|
from ..args import Args
|
|
|
|
# Create the app
|
|
# TODO: create this all in a method so configuration/etc. can be passed in?
|
|
app = FastAPI(
|
|
title = "Invoke AI",
|
|
docs_url = None,
|
|
redoc_url = None
|
|
)
|
|
|
|
# Add event handler
|
|
event_handler_id: int = id(app)
|
|
app.add_middleware(
|
|
EventHandlerASGIMiddleware,
|
|
handlers = [local_handler], # TODO: consider doing this in services to support different configurations
|
|
middleware_id = event_handler_id)
|
|
|
|
# Add CORS
|
|
# TODO: use configuration for this
|
|
origins = []
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=origins,
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
socket_io = SocketIO(app)
|
|
|
|
config = {}
|
|
|
|
# Add startup event to load dependencies
|
|
@app.on_event('startup')
|
|
async def startup_event():
|
|
args = Args()
|
|
config = args.parse_args()
|
|
|
|
ApiDependencies.initialize(
|
|
args = args,
|
|
config = config,
|
|
event_handler_id = event_handler_id
|
|
)
|
|
|
|
# Shut down threads
|
|
@app.on_event('shutdown')
|
|
async def shutdown_event():
|
|
ApiDependencies.shutdown()
|
|
|
|
# Include all routers
|
|
# TODO: REMOVE
|
|
# app.include_router(
|
|
# invocation.invocation_router,
|
|
# prefix = '/api')
|
|
|
|
app.include_router(
|
|
sessions.session_router,
|
|
prefix = '/api'
|
|
)
|
|
|
|
app.include_router(
|
|
images.images_router,
|
|
prefix = '/api'
|
|
)
|
|
|
|
# Build a custom OpenAPI to include all outputs
|
|
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
|
def custom_openapi():
|
|
if app.openapi_schema:
|
|
return app.openapi_schema
|
|
openapi_schema = get_openapi(
|
|
title = app.title,
|
|
description = "An API for invoking AI image operations",
|
|
version = "1.0.0",
|
|
routes = app.routes
|
|
)
|
|
|
|
# Add all outputs
|
|
all_invocations = BaseInvocation.get_invocations()
|
|
output_types = set()
|
|
output_type_titles = dict()
|
|
for invoker in all_invocations:
|
|
output_type = signature(invoker.invoke).return_annotation
|
|
output_types.add(output_type)
|
|
|
|
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
|
for schema_key, output_schema in output_schemas['definitions'].items():
|
|
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
|
|
|
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
|
# This could break in some cases, figure out a better way to do it
|
|
output_type_titles[schema_key] = output_schema['title']
|
|
|
|
# Add a reference to the output type to additionalProperties of the invoker schema
|
|
for invoker in all_invocations:
|
|
invoker_name = invoker.__name__
|
|
output_type = signature(invoker.invoke).return_annotation
|
|
output_type_title = output_type_titles[output_type.__name__]
|
|
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
|
outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' }
|
|
if 'additionalProperties' not in invoker_schema:
|
|
invoker_schema['additionalProperties'] = {}
|
|
|
|
invoker_schema['additionalProperties']['outputs'] = outputs_ref
|
|
|
|
app.openapi_schema = openapi_schema
|
|
return app.openapi_schema
|
|
|
|
app.openapi = custom_openapi
|
|
|
|
# Override API doc favicons
|
|
app.mount('/static', StaticFiles(directory='static/dream_web'), name='static')
|
|
|
|
@app.get("/docs", include_in_schema=False)
|
|
def overridden_swagger():
|
|
return get_swagger_ui_html(
|
|
openapi_url=app.openapi_url,
|
|
title=app.title,
|
|
swagger_favicon_url="/static/favicon.ico"
|
|
)
|
|
|
|
@app.get("/redoc", include_in_schema=False)
|
|
def overridden_redoc():
|
|
return get_redoc_html(
|
|
openapi_url=app.openapi_url,
|
|
title=app.title,
|
|
redoc_favicon_url="/static/favicon.ico"
|
|
)
|
|
|
|
def invoke_api():
|
|
# Start our own event loop for eventing usage
|
|
# TODO: determine if there's a better way to do this
|
|
loop = asyncio.new_event_loop()
|
|
config = uvicorn.Config(
|
|
app = app,
|
|
host = "0.0.0.0",
|
|
port = 9090,
|
|
loop = loop)
|
|
# Use access_log to turn off logging
|
|
|
|
server = uvicorn.Server(config)
|
|
loop.run_until_complete(server.serve())
|
|
|
|
|
|
if __name__ == "__main__":
|
|
invoke_api()
|