# Source: InvokeAI/invokeai/app/api_app.py (241 lines, 8.0 KiB, Python)

# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import asyncio
from inspect import signature
2023-03-03 06:02:00 +00:00
import logging
2023-03-03 06:02:00 +00:00
import uvicorn
import socket
from fastapi import FastAPI
2023-03-03 06:02:00 +00:00
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
2023-03-03 06:02:00 +00:00
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path
from pydantic.schema import schema
2023-03-03 06:02:00 +00:00
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
2023-07-27 14:54:01 +00:00
from invokeai.version.invokeai_version import __version__
import invokeai.frontend.web as web_dir
import mimetypes
2023-05-26 02:01:48 +00:00
2023-03-03 06:02:00 +00:00
from .api.dependencies import ApiDependencies
2023-07-08 09:31:17 +00:00
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
2023-07-27 14:54:01 +00:00
import torch
2023-08-17 22:45:25 +00:00
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
2023-07-27 14:54:01 +00:00
if torch.backends.mps.is_available():
2023-08-17 22:45:25 +00:00
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
# Load the application configuration (including command-line overrides) and
# build the InvokeAI logger from it.
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)

# Windows registry mimetype entries can be broken for these extensions, so
# register them explicitly (presumably so the static mounts serve JS/CSS with
# the expected Content-Type).
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type(type="application/javascript", ext=".js")
mimetypes.add_type(type="text/css", ext=".css")
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)

# Add event handler
event_handler_id: int = id(app)
app.add_middleware(
    EventHandlerASGIMiddleware,
    handlers=[local_handler],  # TODO: consider doing this in services to support different configurations
    middleware_id=event_handler_id,
)

socket_io = SocketIO(app)

# Register CORS at import time, not inside the startup hook: by the time the
# startup event runs, Starlette has already built the middleware stack, and
# newer Starlette versions raise RuntimeError on add_middleware() at that point.
# The settings only depend on app_config, which is already loaded above.
# Added last so it remains the outermost middleware, as it was before.
app.add_middleware(
    CORSMiddleware,
    allow_origins=app_config.allow_origins,
    allow_credentials=app_config.allow_credentials,
    allow_methods=app_config.allow_methods,
    allow_headers=app_config.allow_headers,
)


# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
    """Initialize the shared API dependencies when the server starts."""
    ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)


# Shut down threads
@app.on_event("shutdown")
async def shutdown_event():
    """Shut down the shared API dependencies when the server stops."""
    ApiDependencies.shutdown()


# Include all routers, each under the shared "/api" prefix.
# Registration order is preserved because Starlette matches routes in order.
for _api_router in (
    sessions.session_router,
    models.models_router,
    images.images_router,
    boards.boards_router,
    board_images.board_images_router,
    app_info.app_router,
):
    app.include_router(_api_router, prefix="/api")
del _api_router


# Build a custom OpenAPI schema that also publishes every invocation output type.
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi():
    """Return the app's OpenAPI schema, generating and caching it on first use.

    Beyond the FastAPI-generated schema this adds:
      * a schema for every invocation return type, tagged ``"class": "output"``;
      * the Node Editor UI helper schemas (UIConfigBase, _InputField, _OutputField);
      * an ``"output"`` ``$ref`` plus ``"class": "invocation"`` on each invocation schema;
      * plain string-enum schemas for model config enums not already present.

    The result is stored on ``app.openapi_schema`` and reused on later calls.
    """
    cached_schema = app.openapi_schema
    if cached_schema:
        return cached_schema

    openapi_schema = get_openapi(
        title=app.title,
        description="An API for invoking AI image operations",
        version="1.0.0",
        routes=app.routes,
    )
    ref_prefix = "#/components/schemas/"
    invocations = BaseInvocation.get_invocations()

    # Register a schema for each distinct invocation return type.
    output_types = {signature(invocation.invoke).return_annotation for invocation in invocations}
    output_definitions = schema(output_types, ref_prefix=ref_prefix)["definitions"]
    components = openapi_schema["components"]["schemas"]
    output_type_titles = {}
    for schema_key, output_schema in output_definitions.items():
        output_schema["class"] = "output"
        components[schema_key] = output_schema
        # TODO: this assumes schema_key is the output TYPE.__name__ (it is looked
        # up that way below); that could break in some cases.
        output_type_titles[schema_key] = output_schema["title"]

    # Add Node Editor UI helper schemas.
    ui_helper_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix=ref_prefix)
    components.update(ui_helper_schemas["definitions"])

    # Link each invocation schema to its output schema and tag it as an invocation.
    for invocation in invocations:
        return_type = signature(invocation.invoke).return_annotation
        output_title = output_type_titles[return_type.__name__]
        invocation_schema = components[invocation.__name__]
        invocation_schema["output"] = {"$ref": f"#/components/schemas/{output_title}"}
        invocation_schema["class"] = "invocation"

    # Imported locally, as in the original code (possibly to avoid import-time cost
    # or cycles; not verified here).
    from invokeai.backend.model_management.models import get_model_config_enums

    # Publish model config enums as simple string enums, e.g.
    # "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
    for enum_cls in set(get_model_config_enums()):
        enum_name = enum_cls.__qualname__
        if enum_name in components:
            continue  # already defined by FastAPI/pydantic
        components[enum_name] = {
            "title": enum_name,
            "description": "An enumeration.",
            "type": "string",
            "enum": [member.value for member in enum_cls],
        }

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi

# Serve static assets; the API docs pages below take their favicon from here.
app.mount(
    "/static",
    StaticFiles(directory=Path(web_dir.__path__[0]) / "static" / "dream_web"),
    name="static",
)


@app.get("/docs", include_in_schema=False)
def overridden_swagger():
    """Serve the Swagger UI page using the app's own favicon."""
    favicon_url = "/static/favicon.ico"
    return get_swagger_ui_html(
        title=app.title,
        openapi_url=app.openapi_url,
        swagger_favicon_url=favicon_url,
    )


@app.get("/redoc", include_in_schema=False)
def overridden_redoc():
    """Serve the ReDoc page using the app's own favicon."""
    favicon_url = "/static/favicon.ico"
    return get_redoc_html(
        title=app.title,
        openapi_url=app.openapi_url,
        redoc_favicon_url=favicon_url,
    )


# The catch-all web UI mount at "/" must be registered *after* every other
# route: routes are matched in registration order, so anything added later
# would be shadowed by this mount.
app.mount(
    "/",
    StaticFiles(directory=Path(web_dir.__path__[0]) / "dist", html=True),
    name="ui",
)


def invoke_api():
    """Run the InvokeAI API server.

    Verifies the InvokeAI root directory, picks the first free port at or above
    ``app_config.port``, routes uvicorn's log output through InvokeAI's handlers,
    and serves ``app`` on a freshly created asyncio event loop until shutdown.
    """

    def find_port(port: int):
        """Return the first port >= ``port`` with nothing accepting connections on localhost."""
        # Adapted from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
        # https://github.com/WaylonWalker
        # Iterative rather than recursive, so a long run of busy ports cannot
        # exhaust the recursion limit.
        while True:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                # connect_ex() returns 0 when something is already listening there.
                if s.connect_ex(("localhost", port)) != 0:
                    return port
            port += 1

    from invokeai.backend.install.check_root import check_invokeai_root

    check_invokeai_root(app_config)  # note, may exit with an exception if root not set up

    port = find_port(app_config.port)
    if port != app_config.port:
        # Logger.warn is a deprecated alias; use warning().
        logger.warning(f"Port {app_config.port} in use, using port {port}")

    # Start our own event loop for eventing usage
    loop = asyncio.new_event_loop()
    # NOTE(review): uvicorn documents Config's `loop` option as an implementation
    # name ("auto"/"asyncio"/"uvloop"); Server.serve() driven by our own loop below
    # does not appear to consult it — confirm before relying on this argument.
    config = uvicorn.Config(
        app=app,
        host=app_config.host,
        port=port,
        loop=loop,
        log_level=app_config.log_level,
    )
    server = uvicorn.Server(config)

    # replace uvicorn's loggers with InvokeAI's for consistent appearance
    for logname in ["uvicorn.access", "uvicorn"]:
        log = logging.getLogger(logname)
        log.handlers.clear()
        for ch in logger.handlers:
            log.addHandler(ch)

    loop.run_until_complete(server.serve())


if __name__ == "__main__":
    # `--version` only reports the version; otherwise start the API server.
    if not app_config.version:
        invoke_api()
    else:
        print(f"InvokeAI version {__version__}")