import asyncio
import mimetypes
import os
import signal
import socket
import sys
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available

# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage

from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import (
    app_info,
    board_images,
    boards,
    download_queue,
    images,
    model_manager,
    session_queue,
    utilities,
    workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
    BaseInvocation,
    UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra

# Application configuration, resolved once at import time and shared by everything below.
app_config = get_config()

# The MPS fixes module is imported only for its import-time side effects, and only
# when torch reports that the MPS backend is available.
if is_mps_available():
    import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)


# Module-level logger, built from the application configuration.
logger = InvokeAILogger.get_logger(config=app_config)
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
for _mime_type, _extension in (
    ("application/javascript", ".js"),
    ("text/css", ".css"),
):
    mimetypes.add_type(_mime_type, _extension)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context for the FastAPI app.

    Initializes the API dependencies before yielding control, and shuts them
    down once the context is exited.
    """
    # Startup: load dependencies. `event_handler_id` is a module-level name that is
    # assigned further down; it is only looked up when this context is entered.
    ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
    yield
    # Shutdown: stop threads
    ApiDependencies.shutdown()


# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(
    title="Invoke - Community Edition",
    docs_url=None,
    redoc_url=None,
    separate_input_output_schemas=False,
    lifespan=lifespan,
)

# Add event handler; the app's id doubles as the middleware id.
event_handler_id: int = id(app)
app.add_middleware(
    EventHandlerASGIMiddleware,
    handlers=[local_handler],  # TODO: consider doing this in services to support different configurations
    middleware_id=event_handler_id,
)

socket_io = SocketIO(app)

# CORS settings all come from the application configuration.
app.add_middleware(
    CORSMiddleware,
    allow_origins=app_config.allow_origins,
    allow_credentials=app_config.allow_credentials,
    allow_methods=app_config.allow_methods,
    allow_headers=app_config.allow_headers,
)

app.add_middleware(GZipMiddleware, minimum_size=1000)


# Include all routers, each under the "/api" prefix (registration order preserved).
for _api_router in (
    utilities.utilities_router,
    model_manager.model_manager_router,
    download_queue.download_queue_router,
    images.images_router,
    boards.boards_router,
    board_images.board_images_router,
    app_info.app_router,
    session_queue.session_queue_router,
    workflows.workflows_router,
):
    app.include_router(_api_router, prefix="/api")


# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
    """Build the app's OpenAPI schema, extended with invocation output models.

    The schema is generated once and cached on ``app.openapi_schema``; later
    calls return the cached value. On top of the schema produced by
    ``get_openapi`` this:

    * adds a schema for every invocation output type, tagged ``class: "output"``;
    * adds schemas for a few models that do not otherwise appear as standalone
      definitions;
    * tags every invocation schema with ``class: "invocation"`` and an
      ``output`` ``$ref`` pointing at its output schema.

    Returns:
        The (cached) OpenAPI schema dictionary.

    Raises:
        KeyError: if an invocation's output type name or the invocation's own
            name has no matching entry in the generated schemas.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        description="An API for invoking AI image operations",
        version="1.0.0",
        routes=app.routes,
        separate_input_output_schemas=False,  # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
    )
    # Alias of the nested dict: everything written here lands in `openapi_schema`.
    schemas = openapi_schema["components"]["schemas"]

    # Resolve every invocation's declared output type exactly once; the pairs are
    # used both to collect the output models and to link invocations to them.
    invocation_outputs = [
        (invoker, signature(invoker.invoke).return_annotation) for invoker in BaseInvocation.get_invocations()
    ]

    # Add all outputs
    output_types = {output_type for _, output_type in invocation_outputs}
    output_type_titles: dict[str, str] = {}
    output_schemas = models_json_schema(
        models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
    )
    for schema_key, output_schema in output_schemas[1]["$defs"].items():
        # TODO: note that we assume the schema_key here is the TYPE.__name__
        # This could break in some cases, figure out a better way to do it
        output_type_titles[schema_key] = output_schema["title"]
        output_schema["class"] = "output"
        schemas[schema_key] = output_schema

    # Some models don't end up in the schemas as standalone definitions
    additional_schemas = models_json_schema(
        [
            (UIConfigBase, "serialization"),
            (InputFieldJSONSchemaExtra, "serialization"),
            (OutputFieldJSONSchemaExtra, "serialization"),
            (ModelIdentifierField, "serialization"),
            (ProgressImage, "serialization"),
        ],
        ref_template="#/components/schemas/{model}",
    )
    schemas.update(additional_schemas[1]["$defs"])

    # Add a reference to the output type to additionalProperties of the invoker schema
    for invoker, output_type in invocation_outputs:
        invoker_name = invoker.__name__  # type: ignore [attr-defined] # this is a valid attribute
        output_type_title = output_type_titles[output_type.__name__]
        invoker_schema = schemas[invoker_name]
        invoker_schema["output"] = {"$ref": f"#/components/schemas/{output_type_title}"}
        invoker_schema["class"] = "invocation"

    # This code no longer seems to be necessary?
    # Leave it here just in case
    #
    # from invokeai.backend.model_manager import get_model_config_formats
    # formats = get_model_config_formats()
    # for model_config_name, enum_set in formats.items():

    #     if model_config_name in openapi_schema["components"]["schemas"]:
    #         # print(f"Config with name {name} already defined")
    #         continue
    #     openapi_schema["components"]["schemas"][model_config_name] = {
    #         "title": model_config_name,
    #         "description": "An enumeration.",
    #         "type": "string",
    #         "enum": [v.value for v in enum_set],
    #     }

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi  # type: ignore [method-assign] # this is a valid assignment


@app.get("/docs", include_in_schema=False)
def overridden_swagger() -> HTMLResponse:
    """Serve the Swagger UI page for the app's OpenAPI document, with the Invoke docs favicon."""
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,  # type: ignore [arg-type] # this is always a string
        title=f"{app.title} - Swagger UI",
        swagger_favicon_url="static/docs/invoke-favicon-docs.svg",
    )


@app.get("/redoc", include_in_schema=False)
def overridden_redoc() -> HTMLResponse:
    """Serve the ReDoc page for the app's OpenAPI document, with the Invoke docs favicon."""
    return get_redoc_html(
        openapi_url=app.openapi_url,  # type: ignore [arg-type] # this is always a string
        title=f"{app.title} - Redoc",
        redoc_favicon_url="static/docs/invoke-favicon-docs.svg",
    )


# Directory of the installed `invokeai.frontend.web` package.
web_root_path = Path(list(web_dir.__path__)[0])

try:
    app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
    # The RuntimeError is treated as "no built UI present"; the API still starts without it.
    logger.warning(f"No UI found at {web_root_path}/dist, skipping UI mount")
# NOTE(review): this mount is registered after the catch-all "/" mount above -
# confirm that requests under /static actually reach it when the UI mount succeeds.
app.mount(
    "/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
)  # docs favicon is in here


def invoke_api() -> None:
    """Start the web server for the API and block until it stops.

    Where ``os.fork`` exists, the process first forks: the child carries on and
    runs the server, while the parent only waits for the child and kills it on
    Ctrl+C. The server port is the configured one or, if that is already in
    use, the next free port above it.
    """

    class InterruptWatcher:
        """Fork the process and make the parent babysit the child.

        The child returns from the constructor and continues normally. The
        parent never returns from the constructor: it waits for the child and
        then exits.
        """

        def __init__(self):
            self.child = os.fork()
            if self.child == 0:
                # Child process: go on to run the server.
                return
            else:
                # Parent process: block here until the child is gone.
                self.watch()

        def watch(self) -> None:
            """Wait for the child; kill it on Ctrl+C. Always exits the parent afterwards."""
            try:
                os.wait()
            except KeyboardInterrupt:
                self.kill()
            # Exit unconditionally so the parent can never fall through and
            # start a second server after the child has terminated.
            sys.exit()

        def kill(self) -> None:
            """Send SIGKILL to the child process, ignoring OS-level failures."""
            try:
                os.kill(self.child, signal.SIGKILL)
            except OSError:
                # Deliberate best effort: e.g. the child may already be gone.
                pass

    def find_port(port: int) -> int:
        """Find a port not in use starting at given port"""
        # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
        # https://github.com/WaylonWalker
        while True:
            # A fresh socket per probe; it is closed before the next port is tried.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                # connect_ex() returns 0 when something accepted the connection,
                # i.e. the port is in use; anything else is treated as free.
                if s.connect_ex(("localhost", port)) != 0:
                    return port
            port += 1

    # os.fork is not available on every platform (notably Windows); skip the
    # watcher there instead of failing at startup with AttributeError.
    if hasattr(os, "fork"):
        InterruptWatcher()

    if app_config.dev_reload:
        try:
            import jurigged
        except ImportError as e:
            logger.error(
                'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
                exc_info=e,
            )
        else:
            jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)

    port = find_port(app_config.port)
    if port != app_config.port:
        logger.warning(f"Port {app_config.port} in use, using port {port}")

    # Start our own event loop for eventing usage
    loop = asyncio.new_event_loop()
    config = uvicorn.Config(
        app=app,
        host=app_config.host,
        port=port,
        loop="asyncio",
        log_level=app_config.log_level,
        ssl_certfile=app_config.ssl_certfile,
        ssl_keyfile=app_config.ssl_keyfile,
    )
    server = uvicorn.Server(config)

    # replace uvicorn's loggers with InvokeAI's for consistent appearance
    for logname in ["uvicorn.access", "uvicorn"]:
        log = InvokeAILogger.get_logger(logname)
        log.handlers.clear()
        for ch in logger.handlers:
            log.addHandler(ch)

    loop.run_until_complete(server.serve())


if __name__ == "__main__":
    invoke_api()