mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
all vestiges of ldm.invoke removed
This commit is contained in:
@ -2,36 +2,37 @@
|
||||
|
||||
import asyncio
|
||||
from inspect import signature
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic.schema import schema
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.schema import schema
|
||||
|
||||
from ..args import Args
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import images, sessions
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .api.routers import images, sessions
|
||||
from .api.dependencies import ApiDependencies
|
||||
from ..args import Args
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
app = FastAPI(
|
||||
title = "Invoke AI",
|
||||
docs_url = None,
|
||||
redoc_url = None
|
||||
)
|
||||
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
||||
|
||||
# Add event handler
|
||||
event_handler_id: int = id(app)
|
||||
app.add_middleware(
|
||||
EventHandlerASGIMiddleware,
|
||||
handlers = [local_handler], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id = event_handler_id)
|
||||
handlers=[
|
||||
local_handler
|
||||
], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
# Add CORS
|
||||
# TODO: use configuration for this
|
||||
@ -48,38 +49,34 @@ socket_io = SocketIO(app)
|
||||
|
||||
config = {}
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event('startup')
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
args = Args()
|
||||
config = args.parse_args()
|
||||
|
||||
ApiDependencies.initialize(
|
||||
args = args,
|
||||
config = config,
|
||||
event_handler_id = event_handler_id
|
||||
args=args, config=config, event_handler_id=event_handler_id
|
||||
)
|
||||
|
||||
|
||||
# Shut down threads
|
||||
@app.on_event('shutdown')
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
|
||||
# Include all routers
|
||||
# TODO: REMOVE
|
||||
# app.include_router(
|
||||
# invocation.invocation_router,
|
||||
# prefix = '/api')
|
||||
|
||||
app.include_router(
|
||||
sessions.session_router,
|
||||
prefix = '/api'
|
||||
)
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(
|
||||
images.images_router,
|
||||
prefix = '/api'
|
||||
)
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
@ -87,10 +84,10 @@ def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
openapi_schema = get_openapi(
|
||||
title = app.title,
|
||||
description = "An API for invoking AI image operations",
|
||||
version = "1.0.0",
|
||||
routes = app.routes
|
||||
title=app.title,
|
||||
description="An API for invoking AI image operations",
|
||||
version="1.0.0",
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add all outputs
|
||||
@ -102,12 +99,12 @@ def custom_openapi():
|
||||
output_types.add(output_type)
|
||||
|
||||
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
||||
for schema_key, output_schema in output_schemas['definitions'].items():
|
||||
for schema_key, output_schema in output_schemas["definitions"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema['title']
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
@ -115,47 +112,47 @@ def custom_openapi():
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
||||
outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' }
|
||||
if 'additionalProperties' not in invoker_schema:
|
||||
invoker_schema['additionalProperties'] = {}
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
if "additionalProperties" not in invoker_schema:
|
||||
invoker_schema["additionalProperties"] = {}
|
||||
|
||||
invoker_schema["additionalProperties"]["outputs"] = outputs_ref
|
||||
|
||||
invoker_schema['additionalProperties']['outputs'] = outputs_ref
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
# Override API doc favicons
|
||||
app.mount('/static', StaticFiles(directory='static/dream_web'), name='static')
|
||||
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
|
||||
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger():
|
||||
return get_swagger_ui_html(
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=app.openapi_url,
|
||||
title=app.title,
|
||||
swagger_favicon_url="/static/favicon.ico"
|
||||
swagger_favicon_url="/static/favicon.ico",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/redoc", include_in_schema=False)
|
||||
def overridden_redoc():
|
||||
return get_redoc_html(
|
||||
return get_redoc_html(
|
||||
openapi_url=app.openapi_url,
|
||||
title=app.title,
|
||||
redoc_favicon_url="/static/favicon.ico"
|
||||
redoc_favicon_url="/static/favicon.ico",
|
||||
)
|
||||
|
||||
|
||||
def invoke_api():
|
||||
# Start our own event loop for eventing usage
|
||||
# TODO: determine if there's a better way to do this
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(
|
||||
app = app,
|
||||
host = "0.0.0.0",
|
||||
port = 9090,
|
||||
loop = loop)
|
||||
# Use access_log to turn off logging
|
||||
|
||||
config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
|
||||
# Use access_log to turn off logging
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
Reference in New Issue
Block a user